This commit is contained in:
2024-05-21 19:41:56 +08:00
commit ca67205608
217 changed files with 201004 additions and 0 deletions

126
test_t_sne.py Normal file
View File

@@ -0,0 +1,126 @@
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from opts import opts
from clip import clip
import torch
from utils.utils import prepare_directories, set_seed, get_dataset_loader, configure_clip_encoders, save_model, \
set_adapter_weights, get_text_feature, AverageMeter, accuracy, calculate_zeroshot_weights, gpt_clip_classifier,calculate_zeroshot_weights_GPT
import glob
import json
import os
import numpy as np
from PIL import Image
# 设置随机种子以确保可重复性
np.random.seed(0)
args = opts()
model, preprocess = clip.load(args.name)
model = model.cuda()
model.float()
CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder = configure_clip_encoders(args, model, 0, 1)
prepare_directories(args, CLIP_Text, CLIP_Image)
json_files = glob.glob('gpt_file/caltech_prompt.json')
for file_path in json_files:
# 打开并读取每个YAML文件
with open(file_path, 'r') as f:
gpt3_prompt = json.load(f)
dir = './weights/'
CLIP_Text=torch.load(dir+'CLIP_Text.pth')
CLIP_Image=torch.load(dir+'CLIP_Image.pth')
Image_Encoder=torch.load(dir+'Image_Encoder.pth')
Text_Encoder=torch.load(dir+'Text_Encoder.pth')
adapter=torch.load(dir+'adapter.pth')
image_folder = "./images"
image_folder2 = "./images_2"
image_folder3 = "./images_3"
# 读取图片并提取特征
features = []
for image_name in os.listdir(image_folder):
if image_name.lower().endswith(('.png', '.jpg', '.jpeg')):
# 读取图片
image_path = os.path.join(image_folder, image_name)
image = preprocess(Image.open(image_path)).unsqueeze(0).cuda()
with torch.no_grad():
input_target_temp = CLIP_Image(image)
input_target_add = Image_Encoder(input_target_temp)
# 存储提取的特征
# features.append(input_target_add.detach().numpy())
features.append(input_target_add.detach().cpu().numpy())
features2 = []
for image_name in os.listdir(image_folder2):
if image_name.lower().endswith(('.png', '.jpg', '.jpeg')):
# 读取图片
image_path = os.path.join(image_folder2, image_name)
image = preprocess(Image.open(image_path)).unsqueeze(0).cuda()
with torch.no_grad():
input_target_temp = CLIP_Image(image)
input_target_add = Image_Encoder(input_target_temp)
# 存储提取的特征
# features.append(input_target_add.detach().numpy())
features2.append(input_target_add.detach().cpu().numpy())
# 将特征列表转换为NumPy数组
features = np.vstack(features)
features2 = np.vstack(features2)
features3 = []
for image_name in os.listdir(image_folder3):
if image_name.lower().endswith(('.png', '.jpg', '.jpeg')):
# 读取图片
image_path = os.path.join(image_folder3, image_name)
image = preprocess(Image.open(image_path)).unsqueeze(0).cuda()
with torch.no_grad():
input_target_temp = CLIP_Image(image)
input_target_add = Image_Encoder(input_target_temp)
# 存储提取的特征
# features.append(input_target_add.detach().numpy())
features3.append(input_target_add.detach().cpu().numpy())
# 将特征列表转换为NumPy数组
features = np.vstack(features)
features2 = np.vstack(features2)
features3 = np.vstack(features3)
# 源域样本:一个二维特征
# source_domain_sample = np.array([[0.5, 0.5]]) # 示例特征
# 生成随机数据目标域20个样本每个样本1024维
# target_domain_samples = np.random.rand(20, 1024)
target_domain_samples = features
source_domain_samples=features2
# 由于源域只有一个二维样本我们不需要对它使用t-SNE
tsne = TSNE(n_components=2, random_state=0,perplexity=19)
source_domain_tsne_results = tsne.fit_transform(source_domain_samples)
# 直接对目标域样本应用t-SNE降维到2维
# tsne = TSNE(n_components=2, random_state=0,perplexity=7)
target_domain_tsne_results = tsne.fit_transform(target_domain_samples)
results3=tsne.fit_transform(features3)
# 可视化结果
plt.figure(figsize=(10, 6))
# plt.scatter(source_domain_tsne_results[:, 0], source_domain_tsne_results[:, 1], c='blue', label='Source Domain (Single Sample)', edgecolors='w', s=200)
plt.scatter(source_domain_tsne_results[:, 0], source_domain_tsne_results[:, 1], c='blue', label='Source Domain (Single Sample)')#, edgecolors='w')
plt.scatter(target_domain_tsne_results[:, 0], target_domain_tsne_results[:, 1], c='red', label='Target Domain')
plt.scatter(results3[:, 0], results3[:, 1], c='yellow', label='3')
plt.title('t-SNE visualization of feature representations')
plt.xlabel('Component 1')
plt.ylabel('Component 2')
plt.legend()
plt.show()