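"""Maximum mean discrepancy (MMD) two-sample statistic in PyTorch."""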
import torch
import torch.nn as nn
from torch.nn import functional as F


class MaximumMeanDiscrepancy(nn.Module):
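    """Squared maximum mean discrepancy (MMD^2) between two batches.

    Supports linear, polynomial, and RBF-mixture kernels. Diagonal
    (self-similarity) terms are excluded from the within-sample means,
    following the standard unbiased MMD^2 estimator.
    """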

    def __init__(self, kernel_type="rbf", normalize=False):
        super().__init__()
        self.kernel_type = kernel_type
        self.normalize = normalize

    def forward(self, x, y):
        # x, y: two batches of data with shape (batch, dim)
        # MMD^2(x, y) = E[k(x, x')] - 2 E[k(x, y)] + E[k(y, y')]
        if self.normalize:
            # Optionally compare directions only, on the unit sphere.
            x = F.normalize(x, dim=1)
            y = F.normalize(y, dim=1)
        if self.kernel_type == "linear":
            return self.linear_mmd(x, y)
        elif self.kernel_type == "poly":
            return self.poly_mmd(x, y)
        elif self.kernel_type == "rbf":
            return self.rbf_mmd(x, y)
        else:
            raise NotImplementedError(f"Unknown kernel type: {self.kernel_type}")

    def linear_mmd(self, x, y):
        # k(x, y) = x^T y
        # The diagonal of each within-sample Gram matrix is removed so
        # self-similarity terms do not enter the means.
        k_xx = self.remove_self_distance(torch.mm(x, x.t()))
        k_yy = self.remove_self_distance(torch.mm(y, y.t()))
        k_xy = torch.mm(x, y.t())
        return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()

    def poly_mmd(self, x, y, alpha=1.0, c=2.0, d=2):
        # k(x, y) = (alpha * x^T y + c)^d
        k_xx = self.remove_self_distance(torch.mm(x, x.t()))
        k_xx = (alpha * k_xx + c).pow(d)
        k_yy = self.remove_self_distance(torch.mm(y, y.t()))
        k_yy = (alpha * k_yy + c).pow(d)
        k_xy = torch.mm(x, y.t())
        k_xy = (alpha * k_xy + c).pow(d)
        return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()

    def rbf_mmd(self, x, y):
        # k_xx: within-sample kernel values for x (diagonal removed)
        d_xx = self.euclidean_squared_distance(x, x)
        d_xx = self.remove_self_distance(d_xx)
        k_xx = self.rbf_kernel_mixture(d_xx)
        # k_yy: within-sample kernel values for y (diagonal removed)
        d_yy = self.euclidean_squared_distance(y, y)
        d_yy = self.remove_self_distance(d_yy)
        k_yy = self.rbf_kernel_mixture(d_yy)
        # k_xy: cross-sample kernel values
        d_xy = self.euclidean_squared_distance(x, y)
        k_xy = self.rbf_kernel_mixture(d_xy)
        return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()

    @staticmethod
    def rbf_kernel_mixture(exponent, sigmas=(1, 5, 10)):
        # Sum of RBF kernels exp(-d / (2 * sigma^2)) over several fixed
        # bandwidths, which avoids tuning a single sigma. (A common
        # alternative is the median heuristic.)
        K = 0
        for sigma in sigmas:
            gamma = 1.0 / (2.0 * sigma**2)
            K += torch.exp(-gamma * exponent)
        return K

    @staticmethod
    def remove_self_distance(distmat):
        # Drop the diagonal of a square (batch, batch) matrix, returning
        # the off-diagonal entries as a (batch, batch - 1) matrix.
        n = distmat.size(0)
        mask = ~torch.eye(n, dtype=torch.bool, device=distmat.device)
        return distmat[mask].view(n, n - 1)

    @staticmethod
    def euclidean_squared_distance(x, y):
        # ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 x_i^T y_j
        m, n = x.size(0), y.size(0)
        distmat = (
            torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) +
            torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
        )
        distmat.addmm_(x, y.t(), beta=1, alpha=-2)
        # Clamp tiny negatives caused by floating-point cancellation.
        return distmat.clamp(min=0)


if __name__ == "__main__":
    mmd = MaximumMeanDiscrepancy(kernel_type="rbf")
    input1, input2 = torch.rand(3, 100), torch.rand(3, 100)
    d = mmd(input1, input2)
    print(d.item())
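    # Illustrative check: MMD should be near zero for two batches drawn
    # from the same distribution, and larger once one batch is shifted.
    same = mmd(torch.rand(64, 100), torch.rand(64, 100))
    shifted = mmd(torch.rand(64, 100), torch.rand(64, 100) + 1.0)
    print(f"same: {same.item():.4f}  shifted: {shifted.item():.4f}")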