import torch
import torch.nn as nn
import torch.nn.functional as F


class ModifiedBasicBlock(nn.Module):
    """Residual block of two 3x3 stride-1 convolutions with a shortcut.

    Both convolutions use ``stride=1`` and ``padding=1``, so the spatial
    size of the input is preserved; only the channel count may change.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the block.
    """

    # Multiplier from ``out_channels`` to the block's real output width;
    # read by ModifiedResNet when chaining blocks and sizing the classifier.
    expansion = 1

    def __init__(self, in_channels, out_channels):
        super(ModifiedBasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Keep the shortcut path dimensionally compatible with the main
        # path: identity when the channel counts already match, otherwise
        # a 1x1 convolution (+ batch norm) projects to ``out_channels``.
        self.shortcut = nn.Sequential()
        if in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels,
                    kernel_size=1, stride=1, bias=False,
                ),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)  # residual (shortcut) connection
        out = F.relu(out)
        return out


class ModifiedResNet(nn.Module):
    """ResNet-style classifier operating on 512-channel feature maps.

    The stem and all four stages use stride-1 convolutions with 512
    channels, so spatial resolution is never reduced before the final
    global pooling.

    Args:
        block: Block class to stack. It is called as
            ``block(in_channels, out_channels)`` and must expose an
            ``expansion`` class attribute.
        layers: Sequence of four ints, the number of blocks per stage.
        num_classes: Size of the final linear layer's output.
    """

    def __init__(self, block, layers, num_classes=1000):
        super(ModifiedResNet, self).__init__()
        # The network assumes its input already has 512 channels.
        self.in_channels = 512
        self.conv1 = nn.Conv2d(
            512, 512, kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(512)

        self.layer1 = self._make_layer(block, 512, layers[0])
        self.layer2 = self._make_layer(block, 512, layers[1])
        self.layer3 = self._make_layer(block, 512, layers[2])
        self.layer4 = self._make_layer(block, 512, layers[3])
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, blocks):
        """Stack ``blocks`` instances of ``block`` into one ``nn.Sequential``.

        ``self.in_channels`` is updated after every block so that the next
        block (and the next stage) receives the correct input width.
        """
        layers = []
        for _ in range(blocks):
            layers.append(block(self.in_channels, out_channels))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Classify ``x``.

        Args:
            x: Tensor accepted by ``self.conv1``, i.e. with 512 channels
                (``(batch, 512, H, W)``).

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)

        # Global average pooling to 1x1 so that the flattened feature
        # vector has exactly ``512 * block.expansion`` entries, matching
        # ``self.linear``. Pooling to ``(x.size(0), x.size(1))`` (batch
        # size and channel count used as a spatial size) would make the
        # flattened width ``channels * batch * 512`` and fail in the
        # linear layer.
        # NOTE(review): the original comment mentioned wanting a
        # "256x512-dimensional" output obtained via adaptive pooling or a
        # reshape - confirm whether a non-global pooled size (with a
        # correspondingly sized head) is actually required.
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        out = self.linear(out)
        return out