This commit is contained in:
2024-05-21 19:41:56 +08:00
commit ca67205608
217 changed files with 201004 additions and 0 deletions

0
engine/__init__.py Normal file
View File

424
engine/partial_model.py Normal file
View File

@@ -0,0 +1,424 @@
import torch
from torch import nn
#clip文本编码器前半
class TransformerEncoder(nn.Module):
def __init__(self, dtype,
token_embedding,
positional_embedding=None,
transformer_encoder=None,
ln_final=None,
text_projection=None):
super().__init__()
self.dtype = dtype
self.token_embedding = token_embedding
self.positional_embedding = positional_embedding
self.transformer_encoder = transformer_encoder
self.ln_final = ln_final
self.text_projection = text_projection
if self.positional_embedding is None:
assert self.transformer_encoder is None
if self.transformer_encoder is None:
assert self.ln_final is None
if self.ln_final is None:
assert self.text_projection is None
def forward(self, text):
x = self.token_embedding(text).type(self.dtype) # (bs, seq_len, dim)
eot_indices = text.argmax(dim=-1)
if self.positional_embedding is not None:
x = x + self.positional_embedding.type(self.dtype)
if self.transformer_encoder is not None:
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer_encoder(x)
x = x.permute(1, 0, 2) # LND -> NLD
if self.ln_final is not None:
x = self.ln_final(x).type(self.dtype)
if self.text_projection is not None:
x = x[torch.arange(x.shape[0]), eot_indices] @ self.text_projection
return x, eot_indices
#clip文本编码器后半
class PartialTransformer(nn.Module):
def __init__(self, dtype,
logit_scale,
vocab_size,
positional_embedding=None,
partial_transformer=None,
ln_final=None,
text_projection=None):
super().__init__()
self.dtype = dtype
self.logit_scale = logit_scale
self.vocab_size = vocab_size
self.positional_embedding = positional_embedding
self.partial_transformer = partial_transformer
self.ln_final = ln_final
self.text_projection = text_projection
if self.positional_embedding is not None:
assert self.partial_transformer is not None
assert self.ln_final is not None
assert self.text_projection is not None
elif self.partial_transformer is not None:
assert self.ln_final is not None
assert self.text_projection is not None
elif self.ln_final is not None:
assert self.text_projection is not None
def forward(self, x, eot_indices):
if self.positional_embedding is not None:
x = x + self.positional_embedding.type(self.dtype)
if self.partial_transformer is not None:
x = x.permute(1, 0, 2)
x = self.partial_transformer(x)
x = x.permute(1, 0, 2)
if self.ln_final is not None:
x = self.ln_final(x).type(self.dtype)
x = x[torch.arange(x.shape[0]), eot_indices] @ self.text_projection
return x
#返回前后两个文本编码器
def get_text(clip_model, text_layer_idx=0):
# contains feature_extractor (does encode_text() from prompts) and partial_model (need to reverse the dim)
vocab_size = clip_model.vocab_size
token_embedding = clip_model.token_embedding
positional_embedding = clip_model.positional_embedding
transformer = clip_model.transformer
ln_final = clip_model.ln_final
text_projection = clip_model.text_projection
logit_scale = clip_model.logit_scale
dtype = clip_model.dtype
if text_layer_idx == -1:
# finetune all layers
feature_extractor = TransformerEncoder(
dtype, token_embedding)
partial_model = PartialTransformer(
dtype, logit_scale, vocab_size,
positional_embedding=positional_embedding,
partial_transformer=transformer,
ln_final=ln_final, text_projection=text_projection)
elif text_layer_idx == 0:
# finetune no layers
feature_extractor = TransformerEncoder(
dtype, token_embedding,
positional_embedding=positional_embedding, transformer_encoder=transformer,
ln_final=ln_final, text_projection=text_projection)
partial_model = PartialTransformer(dtype, logit_scale, vocab_size)
else:
# finetune some layers
transformer_encoder = transformer.resblocks[:-text_layer_idx]
partial_transformer = transformer.resblocks[-text_layer_idx:]
feature_extractor = TransformerEncoder(
dtype, token_embedding,
positional_embedding=positional_embedding,
transformer_encoder=transformer_encoder)
partial_model = PartialTransformer(
dtype, logit_scale, vocab_size,
positional_embedding=None,
partial_transformer=partial_transformer,
ln_final=ln_final, text_projection=text_projection)
return feature_extractor, partial_model
#RN50图像编码器分段
class PartialResNet(nn.Module):
def __init__(self, conv1=None,
bn1=None,
conv2=None,
bn2=None,
conv3=None,
bn3=None,
layer1=None,
layer2=None,
layer3=None,
layer4=None,
attnpool=None,
mode='feature_extractor'):
super().__init__()
assert mode in ['feature_extractor', 'partial_model']
self.conv1 = conv1
self.bn1 = bn1
self.conv2 = conv2
self.bn2 = bn2
self.conv3 = conv3
self.bn3 = bn3
self.relu = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(2)
self.layer1 = layer1
self.layer2 = layer2
self.layer3 = layer3
self.layer4 = layer4
self.attnpool = attnpool
self.apply_stem = self.conv3 != None
if mode == 'partial_model':
if self.conv1 is not None:
assert self.bn1 is not None
if self.bn1 is not None:
assert self.conv2 is not None
if self.conv2 is not None:
assert self.bn2 is not None
if self.bn2 is not None:
assert self.conv3 is not None
if self.conv3 is not None:
assert self.conv1 is not None # make sure entire stem is included
assert self.bn3 is not None
if self.bn3 is not None:
assert self.layer1 is not None
if self.layer1 is not None:
assert self.layer2 is not None
if self.layer2 is not None:
assert self.layer3 is not None
if self.layer3 is not None:
assert self.layer4 is not None
if self.layer4 is not None:
assert self.attnpool is not None
elif mode == 'feature_extractor':
if self.attnpool is not None:
assert self.layer4 is not None
if self.layer4 is not None:
assert self.layer3 is not None
if self.layer3 is not None:
assert self.layer2 is not None
if self.layer2 is not None:
assert self.layer1 is not None
if self.layer1 is not None:
assert self.bn3 is not None
if self.bn3 is not None:
assert self.conv3 is not None
if self.conv3 is not None:
assert self.bn2 is not None
if self.bn2 is not None:
assert self.conv2 is not None
if self.conv2 is not None:
assert self.bn1 is not None
if self.bn1 is not None:
assert self.conv1 is not None
def forward(self, x):
if self.apply_stem:
def stem(x):
for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
x = self.relu(bn(conv(x)))
x = self.avgpool(x)
return x
x = x.type(self.conv1.weight.dtype)
x = stem(x)
if self.layer1 is not None:
x = self.layer1(x)
if self.layer2 is not None:
x = self.layer2(x)
if self.layer3 is not None:
x = self.layer3(x)
if self.layer4 is not None:
x = self.layer4(x)
if self.attnpool is not None:
x = self.attnpool(x)
return x
#返回前后两个RN50图像编码器
def get_image_resnet(model, image_layer_idx=0):
# contains feature_extractor and partial_model
# the 3-layer stem
conv1 = model.conv1
bn1 = model.bn1
conv2 = model.conv2
bn2 = model.bn2
conv3 = model.conv3
bn3 = model.bn3
avgpool = model.avgpool
relu = model.relu
layer1 = model.layer1
layer2 = model.layer2
layer3 = model.layer3
layer4 = model.layer4
attnpool = model.attnpool
if image_layer_idx == -1:
# finetune all layers
feature_extractor = PartialResNet(mode='feature_extractor')
partial_model = PartialResNet(conv1=conv1,
bn1=bn1,
conv2=conv2,
bn2=bn2,
conv3=conv3,
bn3=bn3,
layer1=layer1,
layer2=layer2,
layer3=layer3,
layer4=layer4,
attnpool=attnpool,
mode='partial_model')
elif image_layer_idx == 0:
# finetune no layers
feature_extractor = PartialResNet(conv1=conv1,
bn1=bn1,
conv2=conv2,
bn2=bn2,
conv3=conv3,
bn3=bn3,
layer1=layer1,
layer2=layer2,
layer3=layer3,
layer4=layer4,
attnpool=attnpool,
mode='feature_extractor')
partial_model = PartialResNet(mode='partial_model')
elif image_layer_idx == 1:
# finetune attention pool
feature_extractor = PartialResNet(conv1=conv1,
bn1=bn1,
conv2=conv2,
bn2=bn2,
conv3=conv3,
bn3=bn3,
layer1=layer1,
layer2=layer2,
layer3=layer3,
layer4=layer4,
mode='feature_extractor')
partial_model = PartialResNet(attnpool=attnpool,
mode='partial_model')
elif image_layer_idx == 2:
# finetune attnpool and layer4
feature_extractor = PartialResNet(conv1=conv1,
bn1=bn1,
conv2=conv2,
bn2=bn2,
conv3=conv3,
bn3=bn3,
layer1=layer1,
layer2=layer2,
layer3=layer3,
mode='feature_extractor')
partial_model = PartialResNet(layer4=layer4,
attnpool=attnpool,
mode='partial_model')
else:
raise ValueError("Invalid layer index")
return feature_extractor, partial_model
#vit16图像编码器
class PartialViT(nn.Module):
def __init__(self, conv1=None,
class_embedding=None,
positional_embedding=None,
ln_pre=None,
transformer_encoder=None,
ln_post=None,
proj=None,
mode='feature_extractor'):
super().__init__()
assert mode in ['feature_extractor', 'partial_model']
self.conv1 = conv1
self.class_embedding = class_embedding
self.positional_embedding = positional_embedding
self.ln_pre = ln_pre
self.transformer_encoder = transformer_encoder
self.ln_post = ln_post
self.proj = proj
if mode == 'partial_model':
if self.conv1 is not None:
assert self.ln_pre is not None
if self.ln_pre is not None:
assert self.transformer_encoder is not None
if self.transformer_encoder is not None:
assert self.ln_post is not None
if self.ln_post is not None:
assert self.proj is not None
elif mode == 'feature_extractor':
if self.proj is not None:
assert self.ln_post is not None
if self.ln_post is not None:
assert self.transformer_encoder is not None
if self.transformer_encoder is not None:
assert self.ln_pre is not None
if self.ln_pre is not None:
assert self.conv1 is not None
def forward(self, x):
if self.conv1 is not None:
x = self.conv1(x)
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
if self.class_embedding is not None:
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype,
device=x.device), x],
dim=1) # shape = [*, grid ** 2 + 1, width]
if self.positional_embedding is not None:
x = x + self.positional_embedding.to(x.dtype)
if self.ln_pre is not None:
x = self.ln_pre(x)
if self.transformer_encoder is not None:
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer_encoder(x)
x = x.permute(1, 0, 2) # LND -> NLD
if self.ln_post is not None:
x = self.ln_post(x[:, 0, :])
if self.proj is not None:
x = x @ self.proj
return x
#返回前后两个vit16图像编码器
def get_image_vit(model, image_layer_idx=0):
# contains feature_extractor and partial_model
conv1 = model.conv1
class_embedding = model.class_embedding
positional_embedding = model.positional_embedding
ln_pre = model.ln_pre
transformer = model.transformer
ln_post = model.ln_post
proj = model.proj
if image_layer_idx == -1:
# finetune all layers
feature_extractor = PartialViT(mode='feature_extractor')
partial_model = PartialViT(conv1=conv1,
class_embedding=class_embedding,
positional_embedding=positional_embedding,
ln_pre=ln_pre,
transformer_encoder=transformer,
ln_post=ln_post,
proj=proj,
mode='partial_model')
elif image_layer_idx == 0:
# finetune no layers
feature_extractor = PartialViT(conv1=conv1,
class_embedding=class_embedding,
positional_embedding=positional_embedding,
ln_pre=ln_pre,
transformer_encoder=transformer,
ln_post=ln_post,
proj=proj,
mode='feature_extractor')
partial_model = PartialViT(mode='partial_model')
else:
# finetune some layers
transformer_encoder = transformer.resblocks[:-image_layer_idx]
partial_transformer = transformer.resblocks[-image_layer_idx:]
feature_extractor = PartialViT(conv1=conv1,
class_embedding=class_embedding,
positional_embedding=positional_embedding,
ln_pre=ln_pre,
transformer_encoder=transformer_encoder,
mode='feature_extractor')
partial_model = PartialViT(transformer_encoder=partial_transformer,
ln_post=ln_post,
proj=proj,
mode='partial_model')
return feature_extractor, partial_model