release code
This commit is contained in:
4
Dassl.ProGrad.pytorch/dassl/metrics/__init__.py
Normal file
4
Dassl.ProGrad.pytorch/dassl/metrics/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .accuracy import compute_accuracy
|
||||
from .distance import (
|
||||
cosine_distance, compute_distance_matrix, euclidean_squared_distance
|
||||
)
|
||||
30
Dassl.ProGrad.pytorch/dassl/metrics/accuracy.py
Normal file
30
Dassl.ProGrad.pytorch/dassl/metrics/accuracy.py
Normal file
@@ -0,0 +1,30 @@
|
||||
def compute_accuracy(output, target, topk=(1, )):
|
||||
"""Computes the accuracy over the k top predictions for
|
||||
the specified values of k.
|
||||
|
||||
Args:
|
||||
output (torch.Tensor): prediction matrix with shape (batch_size, num_classes).
|
||||
target (torch.LongTensor): ground truth labels with shape (batch_size).
|
||||
topk (tuple, optional): accuracy at top-k will be computed. For example,
|
||||
topk=(1, 5) means accuracy at top-1 and top-5 will be computed.
|
||||
|
||||
Returns:
|
||||
list: accuracy at top-k.
|
||||
"""
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
|
||||
if isinstance(output, (tuple, list)):
|
||||
output = output[0]
|
||||
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
|
||||
acc = correct_k.mul_(100.0 / batch_size)
|
||||
res.append(acc)
|
||||
|
||||
return res
|
||||
77
Dassl.ProGrad.pytorch/dassl/metrics/distance.py
Normal file
77
Dassl.ProGrad.pytorch/dassl/metrics/distance.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""
|
||||
Source: https://github.com/KaiyangZhou/deep-person-reid
|
||||
"""
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
def compute_distance_matrix(input1, input2, metric="euclidean"):
|
||||
"""A wrapper function for computing distance matrix.
|
||||
|
||||
Each input matrix has the shape (n_data, feature_dim).
|
||||
|
||||
Args:
|
||||
input1 (torch.Tensor): 2-D feature matrix.
|
||||
input2 (torch.Tensor): 2-D feature matrix.
|
||||
metric (str, optional): "euclidean" or "cosine".
|
||||
Default is "euclidean".
|
||||
|
||||
Returns:
|
||||
torch.Tensor: distance matrix.
|
||||
"""
|
||||
# check input
|
||||
assert isinstance(input1, torch.Tensor)
|
||||
assert isinstance(input2, torch.Tensor)
|
||||
assert input1.dim() == 2, "Expected 2-D tensor, but got {}-D".format(
|
||||
input1.dim()
|
||||
)
|
||||
assert input2.dim() == 2, "Expected 2-D tensor, but got {}-D".format(
|
||||
input2.dim()
|
||||
)
|
||||
assert input1.size(1) == input2.size(1)
|
||||
|
||||
if metric == "euclidean":
|
||||
distmat = euclidean_squared_distance(input1, input2)
|
||||
elif metric == "cosine":
|
||||
distmat = cosine_distance(input1, input2)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unknown distance metric: {}. "
|
||||
'Please choose either "euclidean" or "cosine"'.format(metric)
|
||||
)
|
||||
|
||||
return distmat
|
||||
|
||||
|
||||
def euclidean_squared_distance(input1, input2):
|
||||
"""Computes euclidean squared distance.
|
||||
|
||||
Args:
|
||||
input1 (torch.Tensor): 2-D feature matrix.
|
||||
input2 (torch.Tensor): 2-D feature matrix.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: distance matrix.
|
||||
"""
|
||||
m, n = input1.size(0), input2.size(0)
|
||||
mat1 = torch.pow(input1, 2).sum(dim=1, keepdim=True).expand(m, n)
|
||||
mat2 = torch.pow(input2, 2).sum(dim=1, keepdim=True).expand(n, m).t()
|
||||
distmat = mat1 + mat2
|
||||
distmat.addmm_(1, -2, input1, input2.t())
|
||||
return distmat
|
||||
|
||||
|
||||
def cosine_distance(input1, input2):
|
||||
"""Computes cosine distance.
|
||||
|
||||
Args:
|
||||
input1 (torch.Tensor): 2-D feature matrix.
|
||||
input2 (torch.Tensor): 2-D feature matrix.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: distance matrix.
|
||||
"""
|
||||
input1_normed = F.normalize(input1, p=2, dim=1)
|
||||
input2_normed = F.normalize(input2, p=2, dim=1)
|
||||
distmat = 1 - torch.mm(input1_normed, input2_normed.t())
|
||||
return distmat
|
||||
Reference in New Issue
Block a user