Files
clip-symnets/res.py
2024-05-21 19:41:56 +08:00

60 lines
2.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import torch.nn as nn
import torch.nn.functional as F
class ModifiedBasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    Both convolutions use stride 1 and padding 1, so the spatial size of the
    input is preserved. The block output is ``relu(main_path(x) + shortcut(x))``.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the block.
    """

    # Channel multiplier read by the network builder
    # (output channels = out_channels * expansion).
    expansion = 1

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv3x3 = dict(kernel_size=3, stride=1, padding=1, bias=False)

        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv3x3)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv3x3)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # The shortcut must yield the same channel count as the main path.
        # When the counts already agree an empty Sequential acts as identity;
        # otherwise a 1x1 convolution (plus batch norm) adapts the channels.
        if in_channels == out_channels:
            projection = []
        else:
            projection = [
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_channels),
            ]
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Run the block on ``x`` and return the activated residual sum."""
        main = self.conv1(x)
        main = F.relu(self.bn1(main))
        main = self.bn2(self.conv2(main))
        # Add the shortcut connection before the final activation.
        return F.relu(main + self.shortcut(x))
class ModifiedResNet(nn.Module):
    """Constant-width residual network followed by a linear classifier.

    The network is a 3x3 stem convolution with batch norm, four stages built
    from ``block``, global average pooling, and a fully connected layer.
    Every convolution created here uses stride 1, so the spatial size is only
    reduced by the final pooling step.

    Args:
        block: Block class. It is called as ``block(in_channels, out_channels)``
            and must expose an integer ``expansion`` class attribute.
        layers: Sequence whose first four entries give the number of blocks
            in stages 1-4.
        num_classes: Size of the classifier output.
        channels: Width of the stem and of every stage. Defaults to 512,
            the previously hard-coded value.
    """

    def __init__(self, block, layers, num_classes=1000, *, channels=512):
        super().__init__()
        # Running input width for the next block; updated by _make_layer.
        self.in_channels = channels

        self.conv1 = nn.Conv2d(
            channels, channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(channels)

        self.layer1 = self._make_layer(block, channels, layers[0])
        self.layer2 = self._make_layer(block, channels, layers[1])
        self.layer3 = self._make_layer(block, channels, layers[2])
        self.layer4 = self._make_layer(block, channels, layers[3])

        self.linear = nn.Linear(channels * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, blocks):
        """Stack ``blocks`` instances of ``block`` into one ``nn.Sequential``.

        ``self.in_channels`` is advanced after each block so the next block
        receives the previous block's output width.
        """
        stage = []
        for _ in range(blocks):
            stage.append(block(self.in_channels, out_channels))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return class scores of shape ``(batch, num_classes)``.

        ``x`` is expected to be a 4-D tensor ``(batch, channels, H, W)``,
        as required by the ``BatchNorm2d`` layer after the stem convolution.
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling: collapse each feature map to 1x1 so the
        # flattened vector has exactly ``channels * expansion`` features, the
        # size the classifier expects. Using the batch and channel sizes as
        # the pooling target would blow the feature count up and make the
        # linear layer fail with a shape mismatch.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.flatten(1)
        return self.linear(out)