Release of PromptSRC with pretrained models.
49409  interpret_prompts/clip_words.csv  Normal file
File diff suppressed because it is too large
84  interpret_prompts/interpret_prompt.py  Normal file
@@ -0,0 +1,84 @@
import os
import sys
import argparse
import torch

from clip.simple_tokenizer import SimpleTokenizer
from clip import clip


# Available CLIP backbones include "ViT-B/16" and "RN50"
def load_clip_to_cpu(backbone_name="ViT-B/16"):
    url = clip._MODELS[backbone_name]
    model_path = clip._download(url)

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location="cpu").eval()
        state_dict = None

    except RuntimeError:
        state_dict = torch.load(model_path, map_location="cpu")

    model = clip.build_model(state_dict or model.state_dict())

    return model


# parser = argparse.ArgumentParser()
# parser.add_argument("fpath", type=str, help="Path to the learned prompt")
# parser.add_argument("topk", type=int, help="Select top-k similar words")
# args = parser.parse_args()

fpath = "./compound_prompt_weights/train_base/food101/shots_16/cocoop/vit_b16_c4_ep10_batch1_ctxv1/seed1/prompt_learner/model.pth.tar-5"
topk = 10

assert os.path.exists(fpath)

print(f"Returning the top-{topk} matched words")

tokenizer = SimpleTokenizer()
clip_model = load_clip_to_cpu()
token_embedding = clip_model.token_embedding.weight
print(f"Size of token embedding: {token_embedding.shape}")

prompt_learner = torch.load(fpath, map_location="cpu")["state_dict"]
# Extract the first-layer (input) context tokens
ctx = prompt_learner["prompt_learner.ctx"]
ctx = ctx.float()
# Now extract the intermediate-layer context tokens
intermediate_embeddings = []
depth = 9 - 1  # prompt depth of 9: the input ctx plus 8 compound prompts
for i in range(depth):
    # Extract the prompt embeddings for each deeper layer and store them
    query = "prompt_learner.compound_prompts_text." + str(i)
    temp = prompt_learner[query].float()
    intermediate_embeddings.append(temp)

print(f"Size of context: {ctx.shape}")

# Now repeat the nearest-word search for the context embeddings of every layer
all_layer_ctx = [ctx] + intermediate_embeddings

for idx, single_ctx in enumerate(all_layer_ctx):
    print("SHOWING RESULTS FOR CTX Vectors of Layer: ", idx + 1)
    ctx = single_ctx
    if ctx.dim() == 2:
        # Generic (class-agnostic) context
        distance = torch.cdist(ctx, token_embedding)
        print(f"Size of distance matrix: {distance.shape}")
        sorted_idxs = torch.argsort(distance, dim=1)
        sorted_idxs = sorted_idxs[:, :topk]

        for m, idxs in enumerate(sorted_idxs):
            words = [tokenizer.decoder[idx.item()] for idx in idxs]
            dist = [f"{distance[m, idx].item():.4f}" for idx in idxs]
            print(f"{m+1}: {words} {dist}")

    elif ctx.dim() == 3:
        # Class-specific context
        raise NotImplementedError

print("##############################")
print("##############################")
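The core of the script above is a nearest-word lookup in CLIP's token-embedding space. As a minimal, self-contained sketch of that step (the helper name `nearest_words` and its signature are illustrative, not part of the released script):

import torch
from clip.simple_tokenizer import SimpleTokenizer

def nearest_words(ctx, token_embedding, tokenizer, topk=10):
    # ctx: (n_ctx, dim) learned prompt vectors;
    # token_embedding: (vocab_size, dim) CLIP token-embedding matrix.
    # Returns, per context vector, the top-k vocabulary tokens by
    # Euclidean distance.
    distance = torch.cdist(ctx.float(), token_embedding.float())  # (n_ctx, vocab_size)
    idxs = torch.argsort(distance, dim=1)[:, :topk]               # closest first
    return [[tokenizer.decoder[i.item()] for i in row] for row in idxs]

# e.g.: nearest_words(ctx, clip_model.token_embedding.weight, SimpleTokenizer(), topk=10)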