import torch from torch import nn #clip文本编码器前半 class TransformerEncoder(nn.Module): def __init__(self, dtype, token_embedding, positional_embedding=None, transformer_encoder=None, ln_final=None, text_projection=None): super().__init__() self.dtype = dtype self.token_embedding = token_embedding self.positional_embedding = positional_embedding self.transformer_encoder = transformer_encoder self.ln_final = ln_final self.text_projection = text_projection if self.positional_embedding is None: assert self.transformer_encoder is None if self.transformer_encoder is None: assert self.ln_final is None if self.ln_final is None: assert self.text_projection is None def forward(self, text): x = self.token_embedding(text).type(self.dtype) # (bs, seq_len, dim) eot_indices = text.argmax(dim=-1) if self.positional_embedding is not None: x = x + self.positional_embedding.type(self.dtype) if self.transformer_encoder is not None: x = x.permute(1, 0, 2) # NLD -> LND x = self.transformer_encoder(x) x = x.permute(1, 0, 2) # LND -> NLD if self.ln_final is not None: x = self.ln_final(x).type(self.dtype) if self.text_projection is not None: x = x[torch.arange(x.shape[0]), eot_indices] @ self.text_projection return x, eot_indices #clip文本编码器后半 class PartialTransformer(nn.Module): def __init__(self, dtype, logit_scale, vocab_size, positional_embedding=None, partial_transformer=None, ln_final=None, text_projection=None): super().__init__() self.dtype = dtype self.logit_scale = logit_scale self.vocab_size = vocab_size self.positional_embedding = positional_embedding self.partial_transformer = partial_transformer self.ln_final = ln_final self.text_projection = text_projection if self.positional_embedding is not None: assert self.partial_transformer is not None assert self.ln_final is not None assert self.text_projection is not None elif self.partial_transformer is not None: assert self.ln_final is not None assert self.text_projection is not None elif self.ln_final is not None: assert self.text_projection is not None def forward(self, x, eot_indices): if self.positional_embedding is not None: x = x + self.positional_embedding.type(self.dtype) if self.partial_transformer is not None: x = x.permute(1, 0, 2) x = self.partial_transformer(x) x = x.permute(1, 0, 2) if self.ln_final is not None: x = self.ln_final(x).type(self.dtype) x = x[torch.arange(x.shape[0]), eot_indices] @ self.text_projection return x #返回前后两个文本编码器 def get_text(clip_model, text_layer_idx=0): # contains feature_extractor (does encode_text() from prompts) and partial_model (need to reverse the dim) vocab_size = clip_model.vocab_size token_embedding = clip_model.token_embedding positional_embedding = clip_model.positional_embedding transformer = clip_model.transformer ln_final = clip_model.ln_final text_projection = clip_model.text_projection logit_scale = clip_model.logit_scale dtype = clip_model.dtype if text_layer_idx == -1: # finetune all layers feature_extractor = TransformerEncoder( dtype, token_embedding) partial_model = PartialTransformer( dtype, logit_scale, vocab_size, positional_embedding=positional_embedding, partial_transformer=transformer, ln_final=ln_final, text_projection=text_projection) elif text_layer_idx == 0: # finetune no layers feature_extractor = TransformerEncoder( dtype, token_embedding, positional_embedding=positional_embedding, transformer_encoder=transformer, ln_final=ln_final, text_projection=text_projection) partial_model = PartialTransformer(dtype, logit_scale, vocab_size) else: # finetune some layers transformer_encoder = transformer.resblocks[:-text_layer_idx] partial_transformer = transformer.resblocks[-text_layer_idx:] feature_extractor = TransformerEncoder( dtype, token_embedding, positional_embedding=positional_embedding, transformer_encoder=transformer_encoder) partial_model = PartialTransformer( dtype, logit_scale, vocab_size, positional_embedding=None, partial_transformer=partial_transformer, ln_final=ln_final, text_projection=text_projection) return feature_extractor, partial_model #RN50图像编码器分段 class PartialResNet(nn.Module): def __init__(self, conv1=None, bn1=None, conv2=None, bn2=None, conv3=None, bn3=None, layer1=None, layer2=None, layer3=None, layer4=None, attnpool=None, mode='feature_extractor'): super().__init__() assert mode in ['feature_extractor', 'partial_model'] self.conv1 = conv1 self.bn1 = bn1 self.conv2 = conv2 self.bn2 = bn2 self.conv3 = conv3 self.bn3 = bn3 self.relu = nn.ReLU(inplace=True) self.avgpool = nn.AvgPool2d(2) self.layer1 = layer1 self.layer2 = layer2 self.layer3 = layer3 self.layer4 = layer4 self.attnpool = attnpool self.apply_stem = self.conv3 != None if mode == 'partial_model': if self.conv1 is not None: assert self.bn1 is not None if self.bn1 is not None: assert self.conv2 is not None if self.conv2 is not None: assert self.bn2 is not None if self.bn2 is not None: assert self.conv3 is not None if self.conv3 is not None: assert self.conv1 is not None # make sure entire stem is included assert self.bn3 is not None if self.bn3 is not None: assert self.layer1 is not None if self.layer1 is not None: assert self.layer2 is not None if self.layer2 is not None: assert self.layer3 is not None if self.layer3 is not None: assert self.layer4 is not None if self.layer4 is not None: assert self.attnpool is not None elif mode == 'feature_extractor': if self.attnpool is not None: assert self.layer4 is not None if self.layer4 is not None: assert self.layer3 is not None if self.layer3 is not None: assert self.layer2 is not None if self.layer2 is not None: assert self.layer1 is not None if self.layer1 is not None: assert self.bn3 is not None if self.bn3 is not None: assert self.conv3 is not None if self.conv3 is not None: assert self.bn2 is not None if self.bn2 is not None: assert self.conv2 is not None if self.conv2 is not None: assert self.bn1 is not None if self.bn1 is not None: assert self.conv1 is not None def forward(self, x): if self.apply_stem: def stem(x): for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: x = self.relu(bn(conv(x))) x = self.avgpool(x) return x x = x.type(self.conv1.weight.dtype) x = stem(x) if self.layer1 is not None: x = self.layer1(x) if self.layer2 is not None: x = self.layer2(x) if self.layer3 is not None: x = self.layer3(x) if self.layer4 is not None: x = self.layer4(x) if self.attnpool is not None: x = self.attnpool(x) return x #返回前后两个RN50图像编码器 def get_image_resnet(model, image_layer_idx=0): # contains feature_extractor and partial_model # the 3-layer stem conv1 = model.conv1 bn1 = model.bn1 conv2 = model.conv2 bn2 = model.bn2 conv3 = model.conv3 bn3 = model.bn3 avgpool = model.avgpool relu = model.relu layer1 = model.layer1 layer2 = model.layer2 layer3 = model.layer3 layer4 = model.layer4 attnpool = model.attnpool if image_layer_idx == -1: # finetune all layers feature_extractor = PartialResNet(mode='feature_extractor') partial_model = PartialResNet(conv1=conv1, bn1=bn1, conv2=conv2, bn2=bn2, conv3=conv3, bn3=bn3, layer1=layer1, layer2=layer2, layer3=layer3, layer4=layer4, attnpool=attnpool, mode='partial_model') elif image_layer_idx == 0: # finetune no layers feature_extractor = PartialResNet(conv1=conv1, bn1=bn1, conv2=conv2, bn2=bn2, conv3=conv3, bn3=bn3, layer1=layer1, layer2=layer2, layer3=layer3, layer4=layer4, attnpool=attnpool, mode='feature_extractor') partial_model = PartialResNet(mode='partial_model') elif image_layer_idx == 1: # finetune attention pool feature_extractor = PartialResNet(conv1=conv1, bn1=bn1, conv2=conv2, bn2=bn2, conv3=conv3, bn3=bn3, layer1=layer1, layer2=layer2, layer3=layer3, layer4=layer4, mode='feature_extractor') partial_model = PartialResNet(attnpool=attnpool, mode='partial_model') elif image_layer_idx == 2: # finetune attnpool and layer4 feature_extractor = PartialResNet(conv1=conv1, bn1=bn1, conv2=conv2, bn2=bn2, conv3=conv3, bn3=bn3, layer1=layer1, layer2=layer2, layer3=layer3, mode='feature_extractor') partial_model = PartialResNet(layer4=layer4, attnpool=attnpool, mode='partial_model') else: raise ValueError("Invalid layer index") return feature_extractor, partial_model #vit16图像编码器 class PartialViT(nn.Module): def __init__(self, conv1=None, class_embedding=None, positional_embedding=None, ln_pre=None, transformer_encoder=None, ln_post=None, proj=None, mode='feature_extractor'): super().__init__() assert mode in ['feature_extractor', 'partial_model'] self.conv1 = conv1 self.class_embedding = class_embedding self.positional_embedding = positional_embedding self.ln_pre = ln_pre self.transformer_encoder = transformer_encoder self.ln_post = ln_post self.proj = proj if mode == 'partial_model': if self.conv1 is not None: assert self.ln_pre is not None if self.ln_pre is not None: assert self.transformer_encoder is not None if self.transformer_encoder is not None: assert self.ln_post is not None if self.ln_post is not None: assert self.proj is not None elif mode == 'feature_extractor': if self.proj is not None: assert self.ln_post is not None if self.ln_post is not None: assert self.transformer_encoder is not None if self.transformer_encoder is not None: assert self.ln_pre is not None if self.ln_pre is not None: assert self.conv1 is not None def forward(self, x): if self.conv1 is not None: x = self.conv1(x) x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] if self.class_embedding is not None: x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] if self.positional_embedding is not None: x = x + self.positional_embedding.to(x.dtype) if self.ln_pre is not None: x = self.ln_pre(x) if self.transformer_encoder is not None: x = x.permute(1, 0, 2) # NLD -> LND x = self.transformer_encoder(x) x = x.permute(1, 0, 2) # LND -> NLD if self.ln_post is not None: x = self.ln_post(x[:, 0, :]) if self.proj is not None: x = x @ self.proj return x #返回前后两个vit16图像编码器 def get_image_vit(model, image_layer_idx=0): # contains feature_extractor and partial_model conv1 = model.conv1 class_embedding = model.class_embedding positional_embedding = model.positional_embedding ln_pre = model.ln_pre transformer = model.transformer ln_post = model.ln_post proj = model.proj if image_layer_idx == -1: # finetune all layers feature_extractor = PartialViT(mode='feature_extractor') partial_model = PartialViT(conv1=conv1, class_embedding=class_embedding, positional_embedding=positional_embedding, ln_pre=ln_pre, transformer_encoder=transformer, ln_post=ln_post, proj=proj, mode='partial_model') elif image_layer_idx == 0: # finetune no layers feature_extractor = PartialViT(conv1=conv1, class_embedding=class_embedding, positional_embedding=positional_embedding, ln_pre=ln_pre, transformer_encoder=transformer, ln_post=ln_post, proj=proj, mode='feature_extractor') partial_model = PartialViT(mode='partial_model') else: # finetune some layers transformer_encoder = transformer.resblocks[:-image_layer_idx] partial_transformer = transformer.resblocks[-image_layer_idx:] feature_extractor = PartialViT(conv1=conv1, class_embedding=class_embedding, positional_embedding=positional_embedding, ln_pre=ln_pre, transformer_encoder=transformer_encoder, mode='feature_extractor') partial_model = PartialViT(transformer_encoder=partial_transformer, ln_post=ln_post, proj=proj, mode='partial_model') return feature_extractor, partial_model