148 lines
4.6 KiB
Python
148 lines
4.6 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from torch.nn import functional as F
|
|
|
|
|
|
class OptimalTransport(nn.Module):
    """Base module for optimal-transport losses.

    Provides the pairwise cost matrix used by the transport solvers.
    """

    @staticmethod
    def distance(batch1, batch2, dist_metric="cosine"):
        """Compute the pairwise cost matrix between two batches.

        Args:
            batch1: tensor of shape (m, dim).
            batch2: tensor of shape (n, dim).
            dist_metric: one of ``"cosine"``, ``"euclidean"`` or
                ``"fast_euclidean"``.

        Returns:
            Tensor of shape (m, n). For ``"cosine"`` the entries are
            ``1 - cosine_similarity``; for both Euclidean variants they are
            *squared* Euclidean distances.

        Raises:
            ValueError: if ``dist_metric`` is not a supported name.
        """
        if dist_metric == "cosine":
            batch1 = F.normalize(batch1, p=2, dim=1)
            batch2 = F.normalize(batch2, p=2, dim=1)
            dist_mat = 1 - torch.mm(batch1, batch2.t())
        elif dist_metric == "euclidean":
            # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b
            sq_norm1 = torch.pow(batch1, 2).sum(dim=1, keepdim=True)  # (m, 1)
            sq_norm2 = torch.pow(batch2, 2).sum(dim=1, keepdim=True)  # (n, 1)
            # Keyword beta/alpha: the positional ``addmm_(1, -2, a, b)``
            # overload is deprecated and rejected by current PyTorch.
            dist_mat = torch.addmm(
                sq_norm1 + sq_norm2.t(), batch1, batch2.t(), beta=1, alpha=-2
            )  # squared euclidean distance
        elif dist_metric == "fast_euclidean":
            # Broadcast to (..., m, n, dim) and reduce over the feature axis.
            batch1 = batch1.unsqueeze(-2)
            batch2 = batch2.unsqueeze(-3)
            dist_mat = torch.sum((torch.abs(batch1 - batch2))**2, -1)
        else:
            raise ValueError(
                "Unknown cost function: {}. Expected to "
                "be one of [cosine | euclidean | fast_euclidean]".format(
                    dist_metric
                )
            )
        return dist_mat
|
|
|
|
|
|
class SinkhornDivergence(OptimalTransport):
    """Sinkhorn divergence between two batches.

    ``forward(x, y)`` returns ``2 * W(x, y) - W(x, x) - W(y, y)`` where
    ``W`` is the transport cost under the plan found by log-domain Sinkhorn
    iterations with uniform marginals.
    """

    # Early-stopping threshold on the L1 change of the dual variable ``u``.
    thre = 1e-3

    def __init__(
        self,
        dist_metric="cosine",
        eps=0.01,
        max_iter=5,
        bp_to_sinkhorn=False
    ):
        """Store the solver configuration.

        Args:
            dist_metric: cost name forwarded to ``distance``.
            eps: entropic regularisation strength (divides the modified
                cost inside the iterations).
            max_iter: maximum number of Sinkhorn iterations.
            bp_to_sinkhorn: if False, the transport plan is detached so
                gradients flow only through the cost matrix.
        """
        super().__init__()
        self.dist_metric = dist_metric
        self.eps = eps
        self.max_iter = max_iter
        self.bp_to_sinkhorn = bp_to_sinkhorn

    def forward(self, x, y):
        """Return the Sinkhorn divergence between ``x`` and ``y``."""
        # x, y: two batches of data with shape (batch, dim)
        W_xy = self.transport_cost(x, y)
        W_xx = self.transport_cost(x, x)
        W_yy = self.transport_cost(y, y)
        return 2*W_xy - W_xx - W_yy

    def transport_cost(self, x, y, return_pi=False):
        """Compute ``sum(pi * C)`` for the Sinkhorn plan between x and y.

        Args:
            x, y: batches passed to ``distance`` to build the cost ``C``.
            return_pi: if True, also return the transport plan.

        Returns:
            The scalar cost tensor, or ``(cost, pi)`` when ``return_pi``.
        """
        C = self.distance(x, y, dist_metric=self.dist_metric)
        # When gradients are not propagated through the solver, iterate on a
        # detached cost so no throwaway autograd graph is recorded.
        solver_cost = C if self.bp_to_sinkhorn else C.detach()
        pi = self.sinkhorn_iterate(
            solver_cost, self.eps, self.max_iter, self.thre
        )
        if not self.bp_to_sinkhorn:
            pi = pi.detach()
        cost = torch.sum(pi * C)
        if return_pi:
            return cost, pi
        return cost

    @staticmethod
    def sinkhorn_iterate(C, eps, max_iter, thre):
        """Run log-domain Sinkhorn iterations on a 2-D cost matrix.

        Args:
            C: cost matrix of shape (nx, ny).
            eps: entropic regularisation strength.
            max_iter: maximum number of iterations.
            thre: stop early once the L1 change in ``u`` drops below this.

        Returns:
            Transport plan of shape (nx, ny).
        """
        nx, ny = C.shape
        # Uniform marginals over both batches.
        mu = torch.ones(nx, dtype=C.dtype, device=C.device) * (1.0/nx)
        nu = torch.ones(ny, dtype=C.dtype, device=C.device) * (1.0/ny)
        u = torch.zeros_like(mu)
        v = torch.zeros_like(nu)

        def M(_C, _u, _v):
            """Modified cost for logarithmic updates.
            Eq: M_{ij} = (-c_{ij} + u_i + v_j) / epsilon
            """
            return (-_C + _u.unsqueeze(-1) + _v.unsqueeze(-2)) / eps

        # Sinkhorn iterations
        for _ in range(max_iter):
            u0 = u
            u = eps * (
                torch.log(mu + 1e-8) - torch.logsumexp(M(C, u, v), dim=1)
            ) + u
            v = (
                eps * (
                    torch.log(nu + 1e-8) -
                    torch.logsumexp(M(C, u, v).permute(1, 0), dim=1)
                ) + v
            )
            err = (u - u0).abs().sum()
            if err.item() < thre:
                break
        # Transport plan pi = diag(a)*K*diag(b)
        return torch.exp(M(C, u, v))
|
|
|
|
|
|
class MinibatchEnergyDistance(SinkhornDivergence):
    """Minibatch energy distance built from Sinkhorn transport costs.

    Each input batch is split into two halves and the cross-batch transport
    costs are contrasted with the within-batch ones.
    """

    def __init__(
        self,
        dist_metric="cosine",
        eps=0.01,
        max_iter=5,
        bp_to_sinkhorn=False
    ):
        """Forward all solver settings to ``SinkhornDivergence``."""
        super().__init__(
            dist_metric=dist_metric,
            eps=eps,
            max_iter=max_iter,
            bp_to_sinkhorn=bp_to_sinkhorn,
        )

    @staticmethod
    def _split_halves(batch):
        """Split ``batch`` along dim 0 into exactly two (near-)equal parts."""
        n = batch.size(0)
        if n < 2:
            raise ValueError(
                "MinibatchEnergyDistance needs at least 2 samples per "
                "batch, got {}".format(n)
            )
        half = n // 2
        # Explicit section sizes guarantee two chunks; a single split size
        # of n // 2 yields three chunks whenever n is odd.
        return torch.split(batch, [half, n - half], dim=0)

    def forward(self, x, y):
        """Return the minibatch energy distance between ``x`` and ``y``.

        Args:
            x, y: batches with at least two samples along dim 0.

        Raises:
            ValueError: if either batch has fewer than two samples.
        """
        x1, x2 = self._split_halves(x)
        y1, y2 = self._split_halves(y)
        cost = 0
        cost += self.transport_cost(x1, y1)
        cost += self.transport_cost(x1, y2)
        cost += self.transport_cost(x2, y1)
        cost += self.transport_cost(x2, y2)
        cost -= 2 * self.transport_cost(x1, x2)
        cost -= 2 * self.transport_cost(y1, y2)
        return cost
|
|
|
|
|
|
if __name__ == "__main__":
    # example: https://dfdazac.github.io/sinkhorn.html
    # Two parallel rows of points: (i, 0) and (i, 1) for i in [0, n_points).
    n_points = 5
    x = torch.tensor([[i, 0] for i in range(n_points)], dtype=torch.float)
    y = torch.tensor([[i, 1] for i in range(n_points)], dtype=torch.float)
    sinkhorn = SinkhornDivergence(
        dist_metric="euclidean", eps=0.01, max_iter=5
    )
    dist, pi = sinkhorn.transport_cost(x, y, True)
    # Print the results instead of dropping into an interactive debugger,
    # so the demo runs to completion unattended.
    print("transport cost:", dist.item())
    print("transport plan:")
    print(pi)
|