84 lines
2.5 KiB
Python
84 lines
2.5 KiB
Python
import os
|
|
import sys
|
|
import argparse
|
|
import torch
|
|
|
|
from clip.simple_tokenizer import SimpleTokenizer
|
|
from clip import clip
|
|
|
|
# Example backbone names for load_clip_to_cpu: "ViT-B/16" (the default) and "RN50".
|
|
def load_clip_to_cpu(backbone_name="ViT-B/16"):
    """Load released CLIP weights onto the CPU and build a model from them.

    The checkpoint URL is looked up in ``clip._MODELS`` and fetched with
    ``clip._download``. The file is first loaded as a TorchScript archive; if
    that raises ``RuntimeError`` it is loaded as a plain state dict instead.
    In both cases the resulting weights are handed to ``clip.build_model``.

    Args:
        backbone_name: key of ``clip._MODELS``, e.g. "ViT-B/16" (default) or
            "RN50".

    Returns:
        Whatever ``clip.build_model`` returns for those weights.

    Raises:
        KeyError: if ``backbone_name`` is not a key of ``clip._MODELS``.
    """
    url = clip._MODELS[backbone_name]
    model_path = clip._download(url)

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location="cpu").eval()
        state_dict = None
    except RuntimeError:
        # Not a TorchScript archive: fall back to a plain state dict.
        state_dict = torch.load(model_path, map_location="cpu")

    # `None` marks "take the weights from the JIT model". Test it explicitly:
    # the previous `state_dict or model.state_dict()` would touch the unbound
    # `model` if the fallback ever produced an empty (falsy) dict.
    if state_dict is None:
        state_dict = model.state_dict()

    model = clip.build_model(state_dict)
    return model
|
|
|
|
|
|
# Defaults used when the checkpoint path / top-k are not given on the command line.
DEFAULT_FPATH = "./compound_prompt_weights/train_base/food101/shots_16/cocoop/vit_b16_c4_ep10_batch1_ctxv1/seed1/prompt_learner/model.pth.tar-5"
DEFAULT_TOPK = 10


def _collect_layer_contexts(state_dict):
    """Gather the learned text-side context of every prompt layer, as float32.

    The first entry is ``prompt_learner.ctx`` (input-layer context). It is
    followed by ``prompt_learner.compound_prompts_text.0``, ``.1``, ... for as
    long as consecutive keys are present in ``state_dict``.

    Args:
        state_dict: mapping taken from the checkpoint's ``"state_dict"`` entry.

    Returns:
        List of tensors, input layer first.

    Raises:
        KeyError: if ``prompt_learner.ctx`` is missing.
    """
    contexts = [state_dict["prompt_learner.ctx"].float()]
    # Discover the prompt depth from the checkpoint rather than hard-coding it,
    # so checkpoints trained with another depth neither KeyError nor lose layers.
    layer = 0
    while True:
        key = f"prompt_learner.compound_prompts_text.{layer}"
        if key not in state_dict:
            break
        contexts.append(state_dict[key].float())
        layer += 1
    return contexts


def _report_layer(ctx, token_embedding, tokenizer, topk):
    """Print the ``topk`` vocabulary tokens closest (Euclidean) to each context row.

    Only 2-D (generic) context is handled; other ranks other than 3 print nothing.

    Args:
        ctx: 2-D tensor, one context vector per row; its width must match
            ``token_embedding``'s (required by ``torch.cdist``).
        token_embedding: 2-D tensor, one row per token id.
        tokenizer: object whose ``decoder`` maps token id -> token string.
        topk: number of nearest tokens to list per context vector; clamped to
            the vocabulary size.

    Raises:
        NotImplementedError: for 3-D (class-specific) context.
    """
    if ctx.dim() == 2:
        # Generic context
        distance = torch.cdist(ctx, token_embedding)
        print(f"Size of distance matrix: {distance.shape}")
        k = min(topk, distance.size(1))
        # topk(largest=False) returns the k smallest distances in ascending
        # order, avoiding a full argsort over the vocabulary.
        nearest_dist, nearest_ids = torch.topk(distance, k, dim=1, largest=False)
        for m, (ids, dists) in enumerate(zip(nearest_ids, nearest_dist)):
            words = [tokenizer.decoder[token_id.item()] for token_id in ids]
            dist = [f"{d.item():.4f}" for d in dists]
            print(f"{m+1}: {words} {dist}")
    elif ctx.dim() == 3:
        # Class-specific context
        raise NotImplementedError


def main():
    """Print, per prompt layer of a checkpoint, the vocabulary words nearest its context vectors."""
    parser = argparse.ArgumentParser()
    parser.add_argument("fpath", type=str, nargs="?", default=DEFAULT_FPATH,
                        help="Path to the learned prompt")
    parser.add_argument("topk", type=int, nargs="?", default=DEFAULT_TOPK,
                        help="Select top-k similar words")
    args = parser.parse_args()
    fpath, topk = args.fpath, args.topk

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(fpath):
        parser.error(f"checkpoint not found: {fpath}")
    if topk < 1:
        parser.error("topk must be a positive integer")

    print(f"Return the top-{topk} matched words")

    tokenizer = SimpleTokenizer()
    clip_model = load_clip_to_cpu()
    # Detach (inspection needs no autograd graph) and cast to float32 so the
    # dtype matches the float32 context vectors fed to torch.cdist.
    token_embedding = clip_model.token_embedding.weight.detach().float()
    print(f"Size of token embedding: {token_embedding.shape}")

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    prompt_learner = torch.load(fpath, map_location="cpu")["state_dict"]

    # Input-layer ctx followed by the intermediate (compound) text prompts.
    all_layer_ctx = _collect_layer_contexts(prompt_learner)
    print(f"Size of context: {all_layer_ctx[0].shape}")

    for layer_idx, ctx in enumerate(all_layer_ctx):
        print("SHOWING RESULTS FOR CTX Vectors of Layer: ", layer_idx + 1)
        _report_layer(ctx, token_embedding, tokenizer, topk)

        print("##############################")
        print("##############################")


if __name__ == "__main__":
    main()