# Python source, 60 lines, 2.5 KiB.
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
|
||
class ModifiedBasicBlock(nn.Module):
    """Two-convolution residual block with an additive shortcut.

    The main path is ``conv3x3 -> BN -> ReLU -> conv3x3 -> BN``. The shortcut
    output is added to the main path and a final ReLU is applied.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by both convolutions.
        stride: Stride of the first convolution and of the shortcut
            projection. Defaults to 1, which keeps the spatial size
            unchanged (3x3 kernel with padding 1).
    """

    # Channel multiplier between ``out_channels`` and the block's real output
    # width; read by ModifiedResNet when chaining blocks and sizing its head.
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Make sure the shortcut path matches the main path's shape. An empty
        # Sequential passes the input through unchanged; when the channel
        # count or the spatial size changes, a 1x1 convolution (followed by
        # batch norm) projects the input instead.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels,
                    kernel_size=1, stride=stride, bias=False,
                ),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Apply the block to ``x`` and return the activated residual sum."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Add the shortcut connection. The in-place add targets the freshly
        # computed ``out`` tensor, so ``x`` itself is never modified.
        out += self.shortcut(x)
        return F.relu(out)
|
||
class ModifiedResNet(nn.Module):
    """ResNet-style classifier with a constant width of 512 channels.

    The network is a 3x3 stem convolution, four stages of residual blocks
    (all at stride 1, so the spatial size of the input is kept), global
    average pooling and a linear classification head.

    Args:
        block: Residual block class. It is called as
            ``block(in_channels, out_channels)`` and must expose an integer
            ``expansion`` class attribute.
        layers: Sequence with at least four integers; ``layers[i]`` is the
            number of blocks in stage ``i + 1``.
        num_classes: Size of the output of the linear head.
    """

    def __init__(self, block, layers, num_classes=1000):
        super().__init__()
        # The starting channel count is assumed to be 512.
        self.in_channels = 512
        self.conv1 = nn.Conv2d(
            512, 512, kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(512)
        self.layer1 = self._make_layer(block, 512, layers[0])
        self.layer2 = self._make_layer(block, 512, layers[1])
        self.layer3 = self._make_layer(block, 512, layers[2])
        self.layer4 = self._make_layer(block, 512, layers[3])
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, blocks):
        """Build one stage of ``blocks`` consecutive residual blocks.

        ``self.in_channels`` is updated after every block so that the next
        block (and the next stage) receives the correct input width.
        """
        layers = []
        for _ in range(blocks):
            layers.append(block(self.in_channels, out_channels))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class scores of shape ``(batch, num_classes)``.

        ``x`` must have 512 channels, as required by the stem convolution.
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling collapses each channel to a single value so
        # the flattened tensor has exactly 512 * block.expansion features,
        # which is what ``self.linear`` was built for. Pooling to
        # (x.size(0), x.size(1)) used the batch size and channel count as a
        # spatial size and made the linear layer fail on every input.
        # NOTE(review): the original comment mentioned a 256x512-dimensional
        # output; confirm whether a different head is intended.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.flatten(1)
        out = self.linear(out)
        return out
|