37 lines
1.2 KiB
Python
37 lines
1.2 KiB
Python
import torch
|
|
import numpy as np
|
|
|
|
|
|
def euclidean_dist(x, y):
    """Return the pairwise Euclidean distance matrix between rows of ``x`` and ``y``.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b so the whole
    matrix is produced with one matrix multiply.

    Args:
        x: 2-D tensor of shape (m, d).
        y: 2-D tensor of shape (n, d).

    Returns:
        Tensor of shape (m, n) whose entry (i, j) is ||x[i] - y[j]||.
        Squared distances are clamped to at least 1e-12 before the square
        root, so identical rows give ~1e-6 rather than exactly 0.
    """
    m, n = x.size(0), y.size(0)
    # Squared row norms: ||x_i||^2 repeated across columns, ||y_j||^2 across rows.
    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
    dist = xx + yy
    # dist <- 1 * dist + (-2) * (x @ y^T). beta/alpha are passed by keyword:
    # the positional addmm_(beta, alpha, mat1, mat2) overload is deprecated.
    dist.addmm_(x, y.t(), beta=1, alpha=-2)
    # The clamp removes small negatives caused by floating-point cancellation
    # (and presumably also avoids sqrt's infinite gradient at exactly 0).
    return dist.clamp(min=1e-12).sqrt()
|
|
|
|
|
|
def euclidean_dist_pair(x):
    """Return the pairwise Euclidean distance matrix among the rows of ``x``.

    Args:
        x: 2-D tensor of shape (m, d).

    Returns:
        Symmetric tensor of shape (m, m) whose entry (i, j) is
        ||x[i] - x[j]||. Squared distances are clamped to at least 1e-12
        before the square root, so the diagonal is ~1e-6, not exactly 0.
    """
    m = x.size(0)
    # ||x_i||^2 repeated across columns; adding its transpose gives
    # ||x_i||^2 + ||x_j||^2 at (i, j).
    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, m)
    dist = xx + xx.t()
    # dist <- 1 * dist + (-2) * (x @ x^T). beta/alpha are passed by keyword:
    # the positional addmm_(beta, alpha, mat1, mat2) overload is deprecated.
    dist.addmm_(x, x.t(), beta=1, alpha=-2)
    # The clamp removes small negatives caused by floating-point cancellation
    # (and presumably also avoids sqrt's infinite gradient on the diagonal).
    return dist.clamp(min=1e-12).sqrt()
|
|
|
|
def euclidean_dist_np(x, y):
    """Return the pairwise Euclidean distance matrix between rows of ``x`` and ``y``.

    NumPy counterpart of the torch helper, using
    ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b.

    Args:
        x: array-like of shape (m, d).
        y: array-like of shape (n, d).

    Returns:
        ndarray of shape (m, n) whose entry (i, j) is ||x[i] - y[j]||.
        Squared distances are clipped to at least 1e-12 before the square
        root, so identical rows give ~1e-6 rather than exactly 0.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    # Squared row norms as a column (m, 1) and a row (1, n); broadcasting
    # expands them to (m, n) without materialising repeated copies.
    x2 = np.sum(x * x, axis=1)[:, np.newaxis]
    y2 = np.sum(y * y, axis=1)[np.newaxis, :]
    sq_dist = x2 + y2 - 2. * np.dot(x, y.T)
    # Clip removes small negatives from floating-point cancellation.
    return np.sqrt(np.clip(sq_dist, 1e-12, None))
|
|
|
|
def euclidean_dist_pair_np(x):
    """Return the pairwise Euclidean distance matrix among the rows of ``x``.

    Args:
        x: array-like of shape (N, d).

    Returns:
        Symmetric ndarray of shape (N, N) whose entry (i, j) is
        ||x[i] - x[j]||. Squared distances are clipped to at least 1e-12
        before the square root, so the diagonal is ~1e-6 (sqrt(1e-12)),
        not exactly zero.
    """
    x = np.asarray(x)
    # Squared row norms as an (N, 1) column; adding its transpose broadcasts
    # to ||x_i||^2 + ||x_j||^2 at (i, j).
    sq_norms = np.sum(x * x, axis=1)[:, np.newaxis]
    sq_dist = sq_norms + sq_norms.T - 2. * np.dot(x, x.T)
    # Clip removes small negatives from floating-point cancellation.
    return np.sqrt(np.clip(sq_dist, 1e-12, None))
|