import math
import os
import random
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import clip

from utils.loss_utils import TargetDiscrimLoss, ConcatenatedCELoss


def zeroshot_classifier(classname, templates, CLIP_Text):
    """Tokenize every prompt template for one class name and encode the prompts
    with the frozen CLIP text tower, returning token features and EOT indices."""
    with torch.no_grad():
        classname = classname.replace('_', ' ')
        str_prompts = [template.format(classname) for template in templates]
        prompts = torch.cat([clip.tokenize(p) for p in str_prompts]).cuda()
        features, eot_indices = CLIP_Text(prompts)
    return features, eot_indices


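# A minimal sketch, not part of the original training flow: the class text
# embeddings built from zeroshot_classifier depend only on (classname, templates),
# so they could be computed once and cached instead of being re-encoded for every
# sample in every batch. `build_text_embedding_cache` is a hypothetical helper name.
def build_text_embedding_cache(classnames, templates, CLIP_Text, Text_Encoder):
    cache = []
    with torch.no_grad():
        for name in classnames:
            features, eot_indices = zeroshot_classifier(name, templates, CLIP_Text)
            class_embeddings = Text_Encoder(features, eot_indices)
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            cache.append(class_embedding / class_embedding.norm())
    # [num_classes, dim]; rows can then be gathered with a label tensor.
    return torch.stack(cache, dim=0)

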
def warm_train(classnames, templates, source_train_loader, source_train_loader_batch, model,
               adapter, criterion_classifier_source, criterion_classifier_target, optimizer,
               epoch, args, scheduler, criterion, CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder):
    """Warm-up step on a single batch: only the adapter is trained while the CLIP
    towers and both extra encoders stay frozen."""
    random.seed(1)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses_classifier = AverageMeter()
    losses_G = AverageMeter()
    losses_T = AverageMeter()
    top1_source = AverageMeter()
    top1_target = AverageMeter()

    CLIP_Text.eval()
    CLIP_Image.eval()
    Text_Encoder.eval()
    Image_Encoder.eval()
    logit_scale = math.exp(4.60517)  # CLIP's temperature, exp(4.60517) ≈ 100
    model.eval()
    adapter.train()
    new_epoch_flag = False
    end = time.time()
    concatenatedCELoss = ConcatenatedCELoss(num_classes=len(classnames)).cuda()

    try:
        (image, label, _) = next(source_train_loader_batch)[1]
    except StopIteration:
        epoch = epoch + 1
        new_epoch_flag = True
        source_train_loader_batch = enumerate(source_train_loader)
        (image, label, _) = next(source_train_loader_batch)[1]

    target_target = label.cuda()
    # self-supervised labels
    label_self_supervised = label.cuda()
    indices = torch.randperm(len(label))
    target_source = label[indices].cuda()
    # target_source = label.cuda()
    input_target = image.cuda()

    # Text embeddings for the shuffled source labels act as the "source" inputs.
    zeroshot_weights = []
    for i in range(len(target_source)):
        features, eot_indices = zeroshot_classifier(classnames[target_source[i]], templates, CLIP_Text)
        class_embeddings = Text_Encoder(features, eot_indices)
        class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
        class_embedding = class_embeddings.mean(dim=0)
        class_embedding = class_embedding / class_embedding.norm()
        zeroshot_weights.append(class_embedding)
    zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
    input_source = zeroshot_weights.T

    data_time.update(time.time() - end)

    # Shift labels into the second half of the 2 * num_classes classifier.
    target_target_temp = target_target + len(classnames)
    target_source_temp = target_source + len(classnames)
    target_target_temp = target_target_temp.cuda()

    # CLIP image encoder (frozen)
    with torch.no_grad():
        input_target_temp = CLIP_Image(input_target)
    input_target_add = Image_Encoder(input_target_temp)

    # Text features go straight into the fully connected adapter.
    output_source = adapter(input_source) * logit_scale
    # Encoded image features go through the same adapter.
    output_target = adapter(input_target_add) * logit_scale

    # Text embeddings for the (unshuffled) self-supervised labels.
    self_zeroshot_weights = []
    for i in range(len(label_self_supervised)):
        features, eot_indices = zeroshot_classifier(classnames[label_self_supervised[i]], templates, CLIP_Text)
        class_embeddings = Text_Encoder(features, eot_indices)
        class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
        class_embedding = class_embeddings.mean(dim=0)
        class_embedding = class_embedding / class_embedding.norm()
        self_zeroshot_weights.append(class_embedding)
    self_zeroshot_weights = torch.stack(self_zeroshot_weights, dim=1).cuda()
    self_input_source = self_zeroshot_weights.T

    # Self-supervised text branch through the fully connected adapter.
    self_output_source = adapter(self_input_source)
    self_output_source = F.normalize(self_output_source)
    # Self-supervised image features.
    self_output_target = output_target / logit_scale
    self_output_target = F.normalize(self_output_target)
    # Contrastive labels 0..batch_size-1 (e.g. 0-255): position i is its own positive.
    self_supervised_labels = torch.arange(self_output_source.shape[0], device="cuda:0", dtype=torch.long)
    logits_per_image = logit_scale * self_output_target @ self_output_source.T
    logits_per_text = logit_scale * self_output_source @ self_output_target.T
    loss_self_supervised = (
        F.cross_entropy(logits_per_image, self_supervised_labels) +
        F.cross_entropy(logits_per_text, self_supervised_labels)
    ) / 2

    # Supervised classification cross-entropy losses.
    loss_task_s_Cs = criterion(output_source[:, :len(classnames)], target_source)
    loss_task_s_Ct = criterion(output_target[:, len(classnames):], target_target)

    # For source data, push probability mass onto the upper half of the classifier;
    # for target data, onto the lower half.
    loss_domain_st_Cst_part1 = criterion(output_source, target_source)
    loss_domain_st_Cst_part2 = criterion(output_target, target_target_temp)

    # Class-level confusion
    loss_category_st_G = 0.5 * criterion(output_target, target_target) + \
                         0.5 * criterion(output_source, target_source_temp)
    # Domain-level confusion
    # loss_domain_st_G = 0.5 * criterion_classifier_target(output_source) + \
    #                    0.5 * criterion_classifier_source(output_target)

    lam = 2 / (1 + math.exp(-10 * epoch / args.epochs)) - 1
    if epoch < 30:
        self_lam = 3
    else:
        self_lam = 1 / 5

    loss_confusion_target = concatenatedCELoss(output_target)
    loss_classifier = loss_task_s_Cs + loss_task_s_Ct + loss_domain_st_Cst_part1 + loss_domain_st_Cst_part2
    loss_G = loss_category_st_G + lam * loss_confusion_target
    loss_T = loss_G + loss_classifier + self_lam * loss_self_supervised

    prec1_source, _ = accuracy(output_source.data[:, :len(classnames)], target_source, topk=(1, 5))
    prec1_target, _ = accuracy(output_target.data[:, len(classnames):], target_target, topk=(1, 5))

    losses_classifier.update(loss_classifier.item(), input_source.size(0))
    losses_G.update(loss_G.item(), input_source.size(0))
    losses_T.update(loss_T.item(), input_source.size(0))
    top1_source.update(prec1_source[0], input_source.size(0))
    top1_target.update(prec1_target[0], input_source.size(0))

    optimizer.zero_grad()
    # loss_classifier.backward(retain_graph=True)
    # optimizer.step()
    #
    # optimizer.zero_grad()
    # loss_G.backward()
    loss_T.backward()
    optimizer.step()
    scheduler.step()

    batch_time.update(time.time() - end)
    if (epoch + 1) % args.print_freq == 0 or epoch == 0:
        print('Train: [{0}/{1}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss@C {loss_c.val:.4f} ({loss_c.avg:.4f})\t'
              'Loss@G {loss_g.val:.4f} ({loss_g.avg:.4f})\t'
              'Loss@T {loss_t.val:.4f} ({loss_t.avg:.4f})\t'
              'top1S {top1S.val:.3f} ({top1S.avg:.3f})\t'
              'top1T {top1T.val:.3f} ({top1T.avg:.3f})\t'.format(
                  epoch, args.epochs, batch_time=batch_time,
                  data_time=data_time, loss_c=losses_classifier, loss_g=losses_G, loss_t=losses_T,
                  top1S=top1_source, top1T=top1_target))
    if new_epoch_flag:
        with open(os.path.join(args.log, 'log.txt'), 'a') as log:
            log.write("\n")
            log.write("Train: epoch: %d, loss@C: %.4f, loss@G: %.4f, Top1S acc: %.3f, Top1T acc: %.3f" % (
                epoch, losses_classifier.avg, losses_G.avg, top1_source.avg, top1_target.avg))
    return source_train_loader_batch, epoch, new_epoch_flag


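# A self-contained illustration, not from the original repo, of the symmetric
# CLIP-style contrastive loss computed in warm_train: each batch position is its
# own positive pair, and both softmax directions are averaged. The tensors below
# are random stand-ins for the text/image features.
def _demo_symmetric_contrastive_loss():
    torch.manual_seed(0)
    text = F.normalize(torch.randn(8, 512), dim=-1)   # stand-in text features
    image = F.normalize(torch.randn(8, 512), dim=-1)  # stand-in image features
    scale = math.exp(4.60517)                         # same temperature as above
    labels = torch.arange(8)                          # positives on the diagonal
    logits_per_image = scale * image @ text.T
    logits_per_text = scale * text @ image.T
    return (F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)) / 2

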
def train(classnames, templates, source_train_loader, source_train_loader_batch, model,
          adapter, criterion_classifier_source, criterion_classifier_target, optimizer,
          epoch, args, scheduler, criterion, CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder):
    """Main training step on a single batch: unlike warm_train, Image_Encoder is
    also updated, and the self-supervised objective sums two contrastive terms."""
    random.seed(1)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses_classifier = AverageMeter()
    losses_G = AverageMeter()
    losses_T = AverageMeter()
    top1_source = AverageMeter()
    top1_target = AverageMeter()

    CLIP_Text.eval()
    CLIP_Image.eval()
    Text_Encoder.eval()
    Image_Encoder.train()
    logit_scale = math.exp(4.60517)  # CLIP's temperature, exp(4.60517) ≈ 100
    model.eval()
    adapter.train()
    new_epoch_flag = False
    end = time.time()
    concatenatedCELoss = ConcatenatedCELoss(num_classes=len(classnames)).cuda()

    try:
        (image, label, _) = next(source_train_loader_batch)[1]
    except StopIteration:
        epoch = epoch + 1
        new_epoch_flag = True
        source_train_loader_batch = enumerate(source_train_loader)
        (image, label, _) = next(source_train_loader_batch)[1]

    target_target = label.cuda()
    # self-supervised labels
    label_self_supervised = label.cuda()
    indices = torch.randperm(len(label))
    target_source = label[indices].cuda()
    # target_source = label.cuda()
    input_target = image.cuda()

    # Text embeddings for the shuffled source labels act as the "source" inputs.
    zeroshot_weights = []
    for i in range(len(target_source)):
        features, eot_indices = zeroshot_classifier(classnames[target_source[i]], templates, CLIP_Text)
        class_embeddings = Text_Encoder(features, eot_indices)
        class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
        class_embedding = class_embeddings.mean(dim=0)
        class_embedding = class_embedding / class_embedding.norm()
        zeroshot_weights.append(class_embedding)
    zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
    input_source = zeroshot_weights.T

    data_time.update(time.time() - end)

    # Shift labels into the second half of the 2 * num_classes classifier.
    target_target_temp = target_target + len(classnames)
    target_source_temp = target_source + len(classnames)
    target_target_temp = target_target_temp.cuda()

    # CLIP image encoder (frozen)
    with torch.no_grad():
        input_target_temp = CLIP_Image(input_target)
    input_target_add = Image_Encoder(input_target_temp)

    # Text features go straight into the fully connected adapter.
    output_source = adapter(input_source) * logit_scale
    # Encoded image features go through the same adapter.
    output_target = adapter(input_target_add) * logit_scale

    # Text embeddings for the (unshuffled) self-supervised labels.
    self_zeroshot_weights = []
    for i in range(len(label_self_supervised)):
        features, eot_indices = zeroshot_classifier(classnames[label_self_supervised[i]], templates, CLIP_Text)
        class_embeddings = Text_Encoder(features, eot_indices)
        class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
        class_embedding = class_embeddings.mean(dim=0)
        class_embedding = class_embedding / class_embedding.norm()
        self_zeroshot_weights.append(class_embedding)
    self_zeroshot_weights = torch.stack(self_zeroshot_weights, dim=1).cuda()
    self_input_source = self_zeroshot_weights.T

    # First contrastive term on the raw, pre-adapter features.
    # self_output_source = adapter(self_input_source)
    # self_output_source = F.normalize(self_output_source[:, :len(classnames)])
    self_output_source = F.normalize(self_input_source)

    # self_output_target = output_target / logit_scale
    # self_output_target = F.normalize(self_output_target[:, len(classnames):])
    self_output_target = F.normalize(input_target_add)
    # Contrastive labels 0..batch_size-1 (e.g. 0-255): position i is its own positive.
    self_supervised_labels = torch.arange(self_output_source.shape[0], device="cuda:0", dtype=torch.long)
    logits_per_image = logit_scale * self_output_target @ self_output_source.T
    logits_per_text = logit_scale * self_output_source @ self_output_target.T
    loss_self_supervised_1 = (
        F.cross_entropy(logits_per_image, self_supervised_labels) +
        F.cross_entropy(logits_per_text, self_supervised_labels)
    ) / 2

    # Second contrastive term on the adapter outputs, one domain half each.
    self_output_source = adapter(self_input_source)
    self_output_source = F.normalize(self_output_source[:, :len(classnames)])
    # self_output_source = F.normalize(self_input_source)

    self_output_target = output_target / logit_scale
    self_output_target = F.normalize(self_output_target[:, len(classnames):])
    # self_output_target = F.normalize(input_target_add)
    self_supervised_labels = torch.arange(self_output_source.shape[0], device="cuda:0", dtype=torch.long)
    logits_per_image = logit_scale * self_output_target @ self_output_source.T
    logits_per_text = logit_scale * self_output_source @ self_output_target.T
    loss_self_supervised_2 = (
        F.cross_entropy(logits_per_image, self_supervised_labels) +
        F.cross_entropy(logits_per_text, self_supervised_labels)
    ) / 2

    loss_self_supervised = loss_self_supervised_1 + loss_self_supervised_2

    # Supervised classification cross-entropy losses.
    loss_task_s_Cs = criterion(output_source[:, :len(classnames)], target_source)
    loss_task_s_Ct = criterion(output_target[:, len(classnames):], target_target)

    # For source data, push probability mass onto the upper half of the classifier;
    # for target data, onto the lower half.
    loss_domain_st_Cst_part1 = criterion(output_source, target_source)
    loss_domain_st_Cst_part2 = criterion(output_target, target_target_temp)

    # Class-level confusion
    loss_category_st_G = 0.5 * criterion(output_target, target_target) + \
                         0.5 * criterion(output_source, target_source_temp)
    # Domain-level confusion
    # loss_domain_st_G = 0.5 * criterion_classifier_target(output_source) + \
    #                    0.5 * criterion_classifier_source(output_target)

    lam = 2 / (1 + math.exp(-10 * epoch / args.epochs)) - 1
    # if epoch < 30:
    #     self_lam = 5
    # else:
    self_lam = 0.6

    loss_confusion_target = concatenatedCELoss(output_target)
    loss_classifier = loss_task_s_Cs + loss_task_s_Ct + loss_domain_st_Cst_part1 + loss_domain_st_Cst_part2
    loss_G = loss_category_st_G + lam * loss_confusion_target
    loss_T = loss_G + loss_classifier + self_lam * loss_self_supervised

    prec1_source, _ = accuracy(output_source.data[:, :len(classnames)], target_source, topk=(1, 5))
    prec1_target, _ = accuracy(output_target.data[:, len(classnames):], target_target, topk=(1, 5))

    losses_classifier.update(loss_classifier.item(), input_source.size(0))
    losses_G.update(loss_G.item(), input_source.size(0))
    losses_T.update(loss_T.item(), input_source.size(0))
    top1_source.update(prec1_source[0], input_source.size(0))
    top1_target.update(prec1_target[0], input_source.size(0))

    optimizer.zero_grad()
    # loss_classifier.backward(retain_graph=True)
    # optimizer.step()
    #
    # optimizer.zero_grad()
    # loss_G.backward()
    loss_T.backward()
    optimizer.step()
    scheduler.step()

    batch_time.update(time.time() - end)
    if (epoch + 1) % args.print_freq == 0 or epoch == 0:
        print('Train: [{0}/{1}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss@C {loss_c.val:.4f} ({loss_c.avg:.4f})\t'
              'Loss@G {loss_g.val:.4f} ({loss_g.avg:.4f})\t'
              'Loss@T {loss_t.val:.4f} ({loss_t.avg:.4f})\t'
              'top1S {top1S.val:.3f} ({top1S.avg:.3f})\t'
              'top1T {top1T.val:.3f} ({top1T.avg:.3f})\t'.format(
                  epoch, args.epochs, batch_time=batch_time,
                  data_time=data_time, loss_c=losses_classifier, loss_g=losses_G, loss_t=losses_T,
                  top1S=top1_source, top1T=top1_target))
    if new_epoch_flag:
        with open(os.path.join(args.log, 'log.txt'), 'a') as log:
            log.write("\n")
            log.write("Train: epoch: %d, loss@C: %.4f, loss@G: %.4f, Top1S acc: %.3f, Top1T acc: %.3f" % (
                epoch, losses_classifier.avg, losses_G.avg, top1_source.avg, top1_target.avg))
    return source_train_loader_batch, epoch, new_epoch_flag


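# Note: train() differs from warm_train() in two ways. Image_Encoder is put in
# training mode instead of staying frozen, and the self-supervised objective sums
# two contrastive terms: one on the raw pre-adapter features and one on the
# domain-specific halves of the adapter outputs.

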
best_target_acc = 0
best_epoch = 0

def validate(classnames, templates, val_loader, model, adapter, epoch, args, zero_shots, criterion,
             CLIP_Text, Text_Encoder, CLIP_Image, Image_Encoder):
    global best_target_acc
    global best_epoch
    batch_time = AverageMeter()
    losses_source = AverageMeter()
    losses_target = AverageMeter()
    top1_source = AverageMeter()
    top1_target = AverageMeter()
    zero_acc_I_acc = AverageMeter()
    clip_acc_aver = AverageMeter()
    Compu4_acc = AverageMeter()

    # switch to evaluate mode
    CLIP_Text.eval()
    CLIP_Image.eval()
    Text_Encoder.eval()
    Image_Encoder.eval()
    model.eval()
    adapter.eval()
    end = time.time()
    logit_scale = math.exp(4.60517)  # CLIP's temperature, exp(4.60517) ≈ 100

    for i, (image, label, _) in enumerate(val_loader):
        image = image.cuda()
        label = label.cuda()

        zeroshot_weights = []
        for j in range(len(label)):
            features, eot_indices = zeroshot_classifier(classnames[label[j]], templates, CLIP_Text)
            with torch.no_grad():
                class_embeddings = Text_Encoder(features, eot_indices)
                class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
                class_embedding = class_embeddings.mean(dim=0)
                class_embedding = class_embedding / class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
        input_source = zeroshot_weights.T
        input_target = image.cuda()
        target_target = label.cuda()
        target_source = label.cuda()

        # CLIP image encoder (frozen)
        with torch.no_grad():
            input_target_temp = CLIP_Image(input_target)
            input_target_add = Image_Encoder(input_target_temp)
            # output_source = adapter(input_source) * logit_scale
            output_target = adapter(input_target_add) * logit_scale
            # Both accuracies are read off the same target logits: Top1@S uses the
            # first (source) half and Top1@T the second (target) half.
            output_source = output_target

        loss_source = criterion(output_source[:, :len(classnames)], target_target)
        loss_target = criterion(output_target[:, len(classnames):], target_target)

        # measure accuracy and record loss
        prec1_source, _ = accuracy(output_source.data[:, :len(classnames)], target_target, topk=(1, 5))
        prec1_target, _ = accuracy(output_target.data[:, len(classnames):], target_target, topk=(1, 5))

        losses_source.update(loss_source.item(), image.size(0))
        losses_target.update(loss_target.item(), image.size(0))
        top1_source.update(prec1_source[0], image.size(0))
        top1_target.update(prec1_target[0], image.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Test: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'LS {lossS.val:.4f} ({lossS.avg:.4f})\t'
                  'LT {lossT.val:.4f} ({lossT.avg:.4f})\t'
                  'top1S {top1S.val:.3f} ({top1S.avg:.3f})\t'
                  'top1T {top1T.val:.3f} ({top1T.avg:.3f})'.format(
                      epoch, i, len(val_loader), batch_time=batch_time, lossS=losses_source,
                      lossT=losses_target, top1S=top1_source, top1T=top1_target))

    print(' * Top1@S {top1S.avg:.3f} Top1@T {top1T.avg:.3f}'
          .format(top1S=top1_source, top1T=top1_target))
    if max(top1_target.avg, top1_source.avg) > best_target_acc:
        best_target_acc = max(top1_target.avg, top1_source.avg)
        best_epoch = epoch
    print('best_epoch', best_epoch, ' * Current_best_target@T:', best_target_acc.item())

    with open(os.path.join(args.log, 'log.txt'), 'a') as log:
        log.write("\n")
        log.write(" Test: epoch: %d, LS: %.4f, LT: %.4f, Top1S: %.3f, Top1T: %.3f" % (
            epoch, losses_source.avg, losses_target.avg, top1_source.avg, top1_target.avg))
    return best_target_acc.item()


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


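# Tiny usage illustration, not part of the original training flow:
def _demo_average_meter():
    meter = AverageMeter()
    meter.update(0.5, n=32)  # e.g. mean loss 0.5 over a 32-sample batch
    meter.update(0.3, n=32)
    return meter.avg         # (0.5 * 32 + 0.3 * 32) / 64 = 0.4

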
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


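# Self-contained check, illustrative only: with one of two predictions correct,
# top-1 precision is 50%.
def _demo_accuracy():
    logits = torch.tensor([[0.9, 0.1],
                           [0.2, 0.8]])
    target = torch.tensor([0, 0])    # second prediction is wrong
    top1, = accuracy(logits, target, topk=(1,))
    return top1                      # tensor([50.])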