init

engine/__init__.py       (new file, 0 lines)
engine/partial_model.py  (new file, 424 lines)

@@ -0,0 +1,424 @@
import torch
from torch import nn


# First half of the CLIP text encoder: everything up to the split point.
class TransformerEncoder(nn.Module):
    def __init__(self, dtype,
                 token_embedding,
                 positional_embedding=None,
                 transformer_encoder=None,
                 ln_final=None,
                 text_projection=None):
        super().__init__()
        self.dtype = dtype
        self.token_embedding = token_embedding
        self.positional_embedding = positional_embedding
        self.transformer_encoder = transformer_encoder
        self.ln_final = ln_final
        self.text_projection = text_projection
        # The supplied modules must form a contiguous prefix of the text
        # encoder: a later stage is only allowed if every earlier stage
        # is present.
        if self.positional_embedding is None:
            assert self.transformer_encoder is None
        if self.transformer_encoder is None:
            assert self.ln_final is None
        if self.ln_final is None:
            assert self.text_projection is None

    def forward(self, text):
        x = self.token_embedding(text).type(self.dtype)  # (bs, seq_len, dim)
        # The EOT token has the largest id in CLIP's vocabulary, so an
        # argmax over token ids gives its position in each sequence.
        eot_indices = text.argmax(dim=-1)
        if self.positional_embedding is not None:
            x = x + self.positional_embedding.type(self.dtype)

        if self.transformer_encoder is not None:
            x = x.permute(1, 0, 2)  # NLD -> LND
            x = self.transformer_encoder(x)
            x = x.permute(1, 0, 2)  # LND -> NLD

        if self.ln_final is not None:
            x = self.ln_final(x).type(self.dtype)

        if self.text_projection is not None:
            # Features at the EOT position, projected to the joint space.
            x = x[torch.arange(x.shape[0]), eot_indices] @ self.text_projection
        return x, eot_indices


# Second half of the CLIP text encoder: everything after the split point.
class PartialTransformer(nn.Module):
    def __init__(self, dtype,
                 logit_scale,
                 vocab_size,
                 positional_embedding=None,
                 partial_transformer=None,
                 ln_final=None,
                 text_projection=None):
        super().__init__()
        self.dtype = dtype
        self.logit_scale = logit_scale
        self.vocab_size = vocab_size
        self.positional_embedding = positional_embedding
        self.partial_transformer = partial_transformer
        self.ln_final = ln_final
        self.text_projection = text_projection
        # The supplied modules must form a contiguous suffix of the text
        # encoder: an earlier stage is only allowed if every later stage
        # is present.
        if self.positional_embedding is not None:
            assert self.partial_transformer is not None
            assert self.ln_final is not None
            assert self.text_projection is not None
        elif self.partial_transformer is not None:
            assert self.ln_final is not None
            assert self.text_projection is not None
        elif self.ln_final is not None:
            assert self.text_projection is not None

    def forward(self, x, eot_indices):
        if self.positional_embedding is not None:
            x = x + self.positional_embedding.type(self.dtype)

        if self.partial_transformer is not None:
            x = x.permute(1, 0, 2)  # NLD -> LND
            x = self.partial_transformer(x)
            x = x.permute(1, 0, 2)  # LND -> NLD

        if self.ln_final is not None:
            x = self.ln_final(x).type(self.dtype)
            x = x[torch.arange(x.shape[0]), eot_indices] @ self.text_projection
        return x


# Returns the text encoder split into front and back halves.
def get_text(clip_model, text_layer_idx=0):
    # feature_extractor runs encode_text() on tokenized prompts up to the
    # split point; partial_model runs the remaining layers (permuting
    # NLD <-> LND around the transformer blocks on each side).
    vocab_size = clip_model.vocab_size
    token_embedding = clip_model.token_embedding
    positional_embedding = clip_model.positional_embedding
    transformer = clip_model.transformer
    ln_final = clip_model.ln_final
    text_projection = clip_model.text_projection
    logit_scale = clip_model.logit_scale
    dtype = clip_model.dtype

    if text_layer_idx == -1:
        # finetune all layers
        feature_extractor = TransformerEncoder(
            dtype, token_embedding)
        partial_model = PartialTransformer(
            dtype, logit_scale, vocab_size,
            positional_embedding=positional_embedding,
            partial_transformer=transformer,
            ln_final=ln_final, text_projection=text_projection)
    elif text_layer_idx == 0:
        # finetune no layers
        feature_extractor = TransformerEncoder(
            dtype, token_embedding,
            positional_embedding=positional_embedding, transformer_encoder=transformer,
            ln_final=ln_final, text_projection=text_projection)
        partial_model = PartialTransformer(dtype, logit_scale, vocab_size)
    else:
        # finetune the last `text_layer_idx` residual blocks
        transformer_encoder = transformer.resblocks[:-text_layer_idx]
        partial_transformer = transformer.resblocks[-text_layer_idx:]
        feature_extractor = TransformerEncoder(
            dtype, token_embedding,
            positional_embedding=positional_embedding,
            transformer_encoder=transformer_encoder)
        partial_model = PartialTransformer(
            dtype, logit_scale, vocab_size,
            positional_embedding=None,
            partial_transformer=partial_transformer,
            ln_final=ln_final, text_projection=text_projection)
    return feature_extractor, partial_model
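

# Usage sketch (hedged; assumes the OpenAI `clip` package, which this
# module does not import). Because the two halves are a contiguous
# prefix/suffix of the original pipeline, chaining them should reproduce
# encode_text() up to floating-point noise:
#
#     import clip
#     model, _ = clip.load("RN50", device="cpu")
#     feature_extractor, partial_model = get_text(model, text_layer_idx=2)
#     tokens = clip.tokenize(["a photo of a dog"])
#     with torch.no_grad():
#         x, eot_indices = feature_extractor(tokens)
#         text_features = partial_model(x, eot_indices)
#         reference = model.encode_text(tokens)
#     assert torch.allclose(text_features, reference, atol=1e-5)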


# The RN50 image encoder, split into segments. Depending on `mode`, the
# supplied modules must form either a contiguous prefix
# ('feature_extractor') or a contiguous suffix ('partial_model') of the
# network.
class PartialResNet(nn.Module):
    def __init__(self, conv1=None,
                 bn1=None,
                 conv2=None,
                 bn2=None,
                 conv3=None,
                 bn3=None,
                 layer1=None,
                 layer2=None,
                 layer3=None,
                 layer4=None,
                 attnpool=None,
                 mode='feature_extractor'):
        super().__init__()
        assert mode in ['feature_extractor', 'partial_model']
        self.conv1 = conv1
        self.bn1 = bn1
        self.conv2 = conv2
        self.bn2 = bn2
        self.conv3 = conv3
        self.bn3 = bn3
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)
        self.layer1 = layer1
        self.layer2 = layer2
        self.layer3 = layer3
        self.layer4 = layer4
        self.attnpool = attnpool
        self.apply_stem = self.conv3 is not None
        if mode == 'partial_model':
            # suffix check: each present stage implies the next is present
            if self.conv1 is not None:
                assert self.bn1 is not None
            if self.bn1 is not None:
                assert self.conv2 is not None
            if self.conv2 is not None:
                assert self.bn2 is not None
            if self.bn2 is not None:
                assert self.conv3 is not None
            if self.conv3 is not None:
                assert self.conv1 is not None  # make sure the entire stem is included
                assert self.bn3 is not None
            if self.bn3 is not None:
                assert self.layer1 is not None
            if self.layer1 is not None:
                assert self.layer2 is not None
            if self.layer2 is not None:
                assert self.layer3 is not None
            if self.layer3 is not None:
                assert self.layer4 is not None
            if self.layer4 is not None:
                assert self.attnpool is not None
        elif mode == 'feature_extractor':
            # prefix check: each present stage implies the previous is present
            if self.attnpool is not None:
                assert self.layer4 is not None
            if self.layer4 is not None:
                assert self.layer3 is not None
            if self.layer3 is not None:
                assert self.layer2 is not None
            if self.layer2 is not None:
                assert self.layer1 is not None
            if self.layer1 is not None:
                assert self.bn3 is not None
            if self.bn3 is not None:
                assert self.conv3 is not None
            if self.conv3 is not None:
                assert self.bn2 is not None
            if self.bn2 is not None:
                assert self.conv2 is not None
            if self.conv2 is not None:
                assert self.bn1 is not None
            if self.bn1 is not None:
                assert self.conv1 is not None

    def forward(self, x):
        if self.apply_stem:
            def stem(x):
                for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
                    x = self.relu(bn(conv(x)))
                x = self.avgpool(x)
                return x
            x = x.type(self.conv1.weight.dtype)
            x = stem(x)
        if self.layer1 is not None:
            x = self.layer1(x)
        if self.layer2 is not None:
            x = self.layer2(x)
        if self.layer3 is not None:
            x = self.layer3(x)
        if self.layer4 is not None:
            x = self.layer4(x)
        if self.attnpool is not None:
            x = self.attnpool(x)
        return x


# Returns the RN50 image encoder split into front and back halves.
def get_image_resnet(model, image_layer_idx=0):
    # contains feature_extractor and partial_model
    # the 3-layer stem
    conv1 = model.conv1
    bn1 = model.bn1
    conv2 = model.conv2
    bn2 = model.bn2
    conv3 = model.conv3
    bn3 = model.bn3
    # extracted for completeness; PartialResNet builds its own relu/avgpool
    avgpool = model.avgpool
    relu = model.relu

    layer1 = model.layer1
    layer2 = model.layer2
    layer3 = model.layer3
    layer4 = model.layer4

    attnpool = model.attnpool

    if image_layer_idx == -1:
        # finetune all layers
        feature_extractor = PartialResNet(mode='feature_extractor')
        partial_model = PartialResNet(conv1=conv1,
                                      bn1=bn1,
                                      conv2=conv2,
                                      bn2=bn2,
                                      conv3=conv3,
                                      bn3=bn3,
                                      layer1=layer1,
                                      layer2=layer2,
                                      layer3=layer3,
                                      layer4=layer4,
                                      attnpool=attnpool,
                                      mode='partial_model')
    elif image_layer_idx == 0:
        # finetune no layers
        feature_extractor = PartialResNet(conv1=conv1,
                                          bn1=bn1,
                                          conv2=conv2,
                                          bn2=bn2,
                                          conv3=conv3,
                                          bn3=bn3,
                                          layer1=layer1,
                                          layer2=layer2,
                                          layer3=layer3,
                                          layer4=layer4,
                                          attnpool=attnpool,
                                          mode='feature_extractor')
        partial_model = PartialResNet(mode='partial_model')
    elif image_layer_idx == 1:
        # finetune the attention pool only
        feature_extractor = PartialResNet(conv1=conv1,
                                          bn1=bn1,
                                          conv2=conv2,
                                          bn2=bn2,
                                          conv3=conv3,
                                          bn3=bn3,
                                          layer1=layer1,
                                          layer2=layer2,
                                          layer3=layer3,
                                          layer4=layer4,
                                          mode='feature_extractor')
        partial_model = PartialResNet(attnpool=attnpool,
                                      mode='partial_model')
    elif image_layer_idx == 2:
        # finetune the attention pool and layer4
        feature_extractor = PartialResNet(conv1=conv1,
                                          bn1=bn1,
                                          conv2=conv2,
                                          bn2=bn2,
                                          conv3=conv3,
                                          bn3=bn3,
                                          layer1=layer1,
                                          layer2=layer2,
                                          layer3=layer3,
                                          mode='feature_extractor')
        partial_model = PartialResNet(layer4=layer4,
                                      attnpool=attnpool,
                                      mode='partial_model')
    else:
        raise ValueError("Invalid layer index")

    return feature_extractor, partial_model
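

# Training sketch (hedged; `clip`, `images`, and the optimizer settings are
# assumptions, not part of this module). The apparent intent is to freeze
# the feature_extractor and fine-tune only the partial_model:
#
#     model, _ = clip.load("RN50", device="cpu")
#     feature_extractor, partial_model = get_image_resnet(
#         model.visual, image_layer_idx=2)
#     for p in feature_extractor.parameters():
#         p.requires_grad_(False)                # frozen front half
#     optimizer = torch.optim.SGD(partial_model.parameters(), lr=1e-3)
#     with torch.no_grad():
#         features = feature_extractor(images)   # hypothetical image batch
#     embeddings = partial_model(features)       # only these layers train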


# The ViT-B/16 image encoder, split into segments.
class PartialViT(nn.Module):
    def __init__(self, conv1=None,
                 class_embedding=None,
                 positional_embedding=None,
                 ln_pre=None,
                 transformer_encoder=None,
                 ln_post=None,
                 proj=None,
                 mode='feature_extractor'):
        super().__init__()
        assert mode in ['feature_extractor', 'partial_model']
        self.conv1 = conv1
        self.class_embedding = class_embedding
        self.positional_embedding = positional_embedding
        self.ln_pre = ln_pre
        self.transformer_encoder = transformer_encoder
        self.ln_post = ln_post
        self.proj = proj
        if mode == 'partial_model':
            # suffix check: each present stage implies the next is present
            if self.conv1 is not None:
                assert self.ln_pre is not None
            if self.ln_pre is not None:
                assert self.transformer_encoder is not None
            if self.transformer_encoder is not None:
                assert self.ln_post is not None
            if self.ln_post is not None:
                assert self.proj is not None
        elif mode == 'feature_extractor':
            # prefix check: each present stage implies the previous is present
            if self.proj is not None:
                assert self.ln_post is not None
            if self.ln_post is not None:
                assert self.transformer_encoder is not None
            if self.transformer_encoder is not None:
                assert self.ln_pre is not None
            if self.ln_pre is not None:
                assert self.conv1 is not None

    def forward(self, x):
        if self.conv1 is not None:
            x = self.conv1(x)
            x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
            x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

        if self.class_embedding is not None:
            # prepend the learnable CLS token to the patch sequence
            x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype,
                                                                          device=x.device), x],
                          dim=1)  # shape = [*, grid ** 2 + 1, width]

        if self.positional_embedding is not None:
            x = x + self.positional_embedding.to(x.dtype)

        if self.ln_pre is not None:
            x = self.ln_pre(x)

        if self.transformer_encoder is not None:
            x = x.permute(1, 0, 2)  # NLD -> LND
            x = self.transformer_encoder(x)
            x = x.permute(1, 0, 2)  # LND -> NLD

        if self.ln_post is not None:
            x = self.ln_post(x[:, 0, :])  # pool the CLS token

        if self.proj is not None:
            x = x @ self.proj

        return x
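

# Shape walk-through with toy stand-ins (hedged; these are hypothetical
# modules, not CLIP weights, used only to show the CLS-token pooling):
#
#     toy = PartialViT(transformer_encoder=nn.Identity(),
#                      ln_post=nn.LayerNorm(8),
#                      proj=torch.randn(8, 4),
#                      mode='partial_model')
#     tokens = torch.randn(2, 50, 8)  # (batch, grid ** 2 + 1, width)
#     out = toy(tokens)               # ln_post pools x[:, 0, :], then @ proj
#     print(out.shape)                # torch.Size([2, 4])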


# Returns the ViT-B/16 image encoder split into front and back halves.
def get_image_vit(model, image_layer_idx=0):
    # contains feature_extractor and partial_model
    conv1 = model.conv1
    class_embedding = model.class_embedding
    positional_embedding = model.positional_embedding
    ln_pre = model.ln_pre
    transformer = model.transformer
    ln_post = model.ln_post
    proj = model.proj

    if image_layer_idx == -1:
        # finetune all layers
        feature_extractor = PartialViT(mode='feature_extractor')
        partial_model = PartialViT(conv1=conv1,
                                   class_embedding=class_embedding,
                                   positional_embedding=positional_embedding,
                                   ln_pre=ln_pre,
                                   transformer_encoder=transformer,
                                   ln_post=ln_post,
                                   proj=proj,
                                   mode='partial_model')
    elif image_layer_idx == 0:
        # finetune no layers
        feature_extractor = PartialViT(conv1=conv1,
                                       class_embedding=class_embedding,
                                       positional_embedding=positional_embedding,
                                       ln_pre=ln_pre,
                                       transformer_encoder=transformer,
                                       ln_post=ln_post,
                                       proj=proj,
                                       mode='feature_extractor')
        partial_model = PartialViT(mode='partial_model')
    else:
        # finetune the last `image_layer_idx` residual blocks
        transformer_encoder = transformer.resblocks[:-image_layer_idx]
        partial_transformer = transformer.resblocks[-image_layer_idx:]
        feature_extractor = PartialViT(conv1=conv1,
                                       class_embedding=class_embedding,
                                       positional_embedding=positional_embedding,
                                       ln_pre=ln_pre,
                                       transformer_encoder=transformer_encoder,
                                       mode='feature_extractor')
        partial_model = PartialViT(transformer_encoder=partial_transformer,
                                   ln_post=ln_post,
                                   proj=proj,
                                   mode='partial_model')
    return feature_extractor, partial_model
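

if __name__ == "__main__":
    # Hedged smoke test, not part of the original module: assumes the OpenAI
    # `clip` package and downloads ViT-B/16 weights. Chaining the two halves
    # should match encode_image() up to floating-point noise.
    import clip

    model, _ = clip.load("ViT-B/16", device="cpu")
    feature_extractor, partial_model = get_image_vit(model.visual,
                                                     image_layer_idx=2)
    image = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        split_out = partial_model(feature_extractor(image))
        reference = model.encode_image(image)
    print(torch.allclose(split_out, reference, atol=1e-4))  # expected: True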