Files
clip-symnets/test_t_sne.py
2024-05-21 19:41:56 +08:00

127 lines
4.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from opts import opts
from clip import clip
import torch
from utils.utils import prepare_directories, set_seed, get_dataset_loader, configure_clip_encoders, save_model, \
set_adapter_weights, get_text_feature, AverageMeter, accuracy, calculate_zeroshot_weights, gpt_clip_classifier,calculate_zeroshot_weights_GPT
import glob
import json
import os
import numpy as np
from PIL import Image
# Seed NumPy's global RNG so that NumPy-based randomness is repeatable.
np.random.seed(0)

# Build the CLIP backbone on the GPU in float32.
args = opts()
model, preprocess = clip.load(args.name)
model = model.cuda().float()

# Derive the text/image branches from the backbone and let the project helper
# prepare its directories for this run.
(
    CLIP_Text,
    Text_Encoder,
    CLIP_Image,
    Image_Encoder,
) = configure_clip_encoders(args, model, 0, 1)
prepare_directories(args, CLIP_Text, CLIP_Image)

# Load the GPT-generated prompt file (JSON). If the glob matches nothing,
# gpt3_prompt is simply left undefined; with several matches the last one wins.
json_files = glob.glob('gpt_file/caltech_prompt.json')
for file_path in json_files:
    with open(file_path, 'r') as prompt_file:
        gpt3_prompt = json.load(prompt_file)

# Replace the freshly configured branches with the checkpoints saved under
# ./weights/. The loaded objects are called directly later in the script, so
# they are presumably whole pickled modules rather than state dicts -
# NOTE(review): confirm against the code that wrote these files.
dir = './weights/'
CLIP_Text, CLIP_Image, Image_Encoder, Text_Encoder, adapter = (
    torch.load(dir + checkpoint_name)
    for checkpoint_name in (
        'CLIP_Text.pth',
        'CLIP_Image.pth',
        'Image_Encoder.pth',
        'Text_Encoder.pth',
        'adapter.pth',
    )
)
image_folder = "./images"
image_folder2 = "./images_2"
image_folder3 = "./images_3"


def _extract_features(folder):
    """Encode every image in *folder* and stack the results into one array.

    Each file whose name ends in .png/.jpg/.jpeg (case-insensitive) is run
    through ``preprocess`` -> ``CLIP_Image`` -> ``Image_Encoder`` on the GPU
    with gradients disabled, and the per-image outputs are stacked with
    ``np.vstack`` (presumably giving one row per image - confirm against the
    encoder's output shape).

    Args:
        folder: Directory to scan (non-recursively) for image files.

    Returns:
        A NumPy array holding the vertically stacked encoder outputs.

    Raises:
        ValueError: If *folder* contains no matching image files.
    """
    # os.listdir order is arbitrary; sorting keeps the row order (and hence
    # the downstream t-SNE layout) reproducible from run to run.
    image_names = sorted(
        name
        for name in os.listdir(folder)
        if name.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
    if not image_names:
        raise ValueError(f"no .png/.jpg/.jpeg images found in {folder!r}")

    stacked_rows = []
    with torch.no_grad():
        for image_name in image_names:
            image_path = os.path.join(folder, image_name)
            # Open in a context manager so the file handle is released as
            # soon as the image has been converted to a tensor.
            with Image.open(image_path) as img:
                image = preprocess(img).unsqueeze(0).cuda()
            input_target_temp = CLIP_Image(image)
            input_target_add = Image_Encoder(input_target_temp)
            # Move to host memory before converting to NumPy.
            stacked_rows.append(input_target_add.detach().cpu().numpy())
    return np.vstack(stacked_rows)


# One feature matrix per image folder.
features = _extract_features(image_folder)
features2 = _extract_features(image_folder2)
features3 = _extract_features(image_folder3)
# Domain roles: `features` (from image_folder) is treated as the target
# domain and `features2` (from image_folder2) as the source domain.
target_domain_samples = features
source_domain_samples = features2

# Embed all three groups with ONE t-SNE fit. Separate fit_transform calls each
# produce their own arbitrary coordinate system, so overlaying their outputs
# on a single plot says nothing about how the groups relate to each other.
all_samples = np.vstack([source_domain_samples, target_domain_samples, features3])
n_total = all_samples.shape[0]
if n_total < 2:
    raise ValueError("t-SNE needs at least two samples in total")

# scikit-learn requires perplexity < n_samples; keep the original value of 19
# whenever enough samples are available.
perplexity = min(19, n_total - 1)
tsne = TSNE(n_components=2, random_state=0, perplexity=perplexity)
embedded = tsne.fit_transform(all_samples)

# Split the joint embedding back into the per-group arrays, in stacking order.
n_source = source_domain_samples.shape[0]
n_target = target_domain_samples.shape[0]
source_domain_tsne_results, target_domain_tsne_results, results3 = np.split(
    embedded, [n_source, n_source + n_target]
)

# Visualise the embedding, one colour per group.
# NOTE(review): the "(Single Sample)" legend text looks stale - the source
# group is a whole folder of images; confirm the intended wording.
plt.figure(figsize=(10, 6))
plt.scatter(source_domain_tsne_results[:, 0], source_domain_tsne_results[:, 1], c='blue', label='Source Domain (Single Sample)')
plt.scatter(target_domain_tsne_results[:, 0], target_domain_tsne_results[:, 1], c='red', label='Target Domain')
plt.scatter(results3[:, 0], results3[:, 1], c='yellow', label='3')
plt.title('t-SNE visualization of feature representations')
plt.xlabel('Component 1')
plt.ylabel('Component 2')
plt.legend()
plt.show()