37 lines
1.3 KiB
Python
37 lines
1.3 KiB
Python
import matplotlib.pyplot as plt
|
||
from matplotlib.font_manager import FontProperties
|
||
# 指定支持中文的字体(Windows系统示例,黑体)
|
||
chinese_font = FontProperties(fname='C:/Windows/Fonts/simhei.ttf')
|
||
|
||
# 数据集名称
|
||
datasets = ["Caltech", "ImageNet", "DTD", "EuroSAT", "Aircraft", "Food", "Flowers", "Pets", "Cars", "SUN397", "UCF101", "Average"]
|
||
|
||
# CLIP在每个数据集上的准确率
|
||
clip_accuracy = [86.29, 58.18, 42.32, 37.56, 17.28, 77.31, 66.14, 85.77, 55.61, 58.52, 61.46, 58.77]
|
||
|
||
# GPT-CLIP在每个数据集上的准确率
|
||
gpt_clip_accuracy = [88.84, 61.46, 50.06, 37.80, 20.61, 77.65, 64.91, 86.40, 57.00, 62.00, 62.83, 60.87]
|
||
|
||
# 创建折线图,调整线型和标记,以便更容易区分
|
||
plt.figure(figsize=(14, 8))
|
||
plt.plot(datasets, clip_accuracy, marker='o', linestyle='--', color='blue', linewidth=2, markersize=8, label='CLIP')
|
||
plt.plot(datasets, gpt_clip_accuracy, marker='s', linestyle='-.', color='red', linewidth=2, markersize=8, label='GPT-CLIP')
|
||
|
||
# 添加图例
|
||
plt.legend()
|
||
|
||
# 设置图表标题和轴标签
|
||
# plt.title('不同数据集上的准确率对比')
|
||
plt.title('不同数据集上的准确率对比', fontproperties=chinese_font, fontsize=20)
|
||
plt.xlabel('Dataset')
|
||
|
||
plt.ylabel('Accuracy')
|
||
|
||
# 优化横坐标标签显示
|
||
plt.xticks(rotation=45)
|
||
plt.savefig('./GTP-CLIP_ACC.jpg')
|
||
# 显示图表
|
||
plt.tight_layout()
|
||
plt.show()
|
||
|