init
This commit is contained in:
59
res.py
Normal file
59
res.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class ModifiedBasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super(ModifiedBasicBlock, self).__init__()
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(out_channels)
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(out_channels)
|
||||
|
||||
# 确保shortcut路径的维度匹配,如果不匹配则通过1x1卷积进行调整
|
||||
self.shortcut = nn.Sequential()
|
||||
if in_channels != out_channels:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
|
||||
nn.BatchNorm2d(out_channels)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = self.bn2(self.conv2(out))
|
||||
out += self.shortcut(x) # 添加shortcut连接
|
||||
out = F.relu(out)
|
||||
return out
|
||||
class ModifiedResNet(nn.Module):
|
||||
def __init__(self, block, layers, num_classes=1000):
|
||||
super(ModifiedResNet, self).__init__()
|
||||
self.in_channels = 512 # 假设起始通道数为512
|
||||
self.conv1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(512)
|
||||
self.layer1 = self._make_layer(block, 512, layers[0])
|
||||
self.layer2 = self._make_layer(block, 512, layers[1])
|
||||
self.layer3 = self._make_layer(block, 512, layers[2])
|
||||
self.layer4 = self._make_layer(block, 512, layers[3])
|
||||
self.linear = nn.Linear(512*block.expansion, num_classes)
|
||||
|
||||
def _make_layer(self, block, out_channels, blocks):
|
||||
layers = []
|
||||
for _ in range(blocks):
|
||||
layers.append(block(self.in_channels, out_channels))
|
||||
self.in_channels = out_channels * block.expansion
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out = self.layer1(out)
|
||||
out = self.layer2(out)
|
||||
out = self.layer3(out)
|
||||
out = self.layer4(out)
|
||||
# 假设你有一个特定的方式来处理最终的特征图,以得到256x512维的输出
|
||||
# 例如,可以使用自适应池化层调整尺寸或者直接reshape(根据实际需求)
|
||||
out = F.adaptive_avg_pool2d(out, (x.size(0), x.size(1)))
|
||||
out = out.view(out.size(0), -1)
|
||||
out = self.linear(out)
|
||||
return out
|
||||
Reference in New Issue
Block a user