release code
Dassl.ProGrad.pytorch/dassl/modeling/ops/optimal_transport.py (new file, +147 lines)
@@ -0,0 +1,147 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
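
# Optimal-transport distances between feature batches: pairwise cost
# matrices (OptimalTransport.distance), an entropy-regularized Sinkhorn
# divergence with log-domain iterations, and a minibatch energy distance
# built on top of it.
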
class OptimalTransport(nn.Module):

    @staticmethod
    def distance(batch1, batch2, dist_metric="cosine"):
        if dist_metric == "cosine":
            batch1 = F.normalize(batch1, p=2, dim=1)
            batch2 = F.normalize(batch2, p=2, dim=1)
            dist_mat = 1 - torch.mm(batch1, batch2.t())
        elif dist_metric == "euclidean":
            m, n = batch1.size(0), batch2.size(0)
            dist_mat = (
                torch.pow(batch1, 2).sum(dim=1, keepdim=True).expand(m, n) +
                torch.pow(batch2, 2).sum(dim=1, keepdim=True).expand(n, m).t()
            )
            # squared euclidean distance: ||a||^2 + ||b||^2 - 2*a.b
            # (keyword beta/alpha; the positional form is deprecated)
            dist_mat.addmm_(batch1, batch2.t(), beta=1, alpha=-2)
        elif dist_metric == "fast_euclidean":
            # broadcasted squared euclidean distance
            batch1 = batch1.unsqueeze(-2)
            batch2 = batch2.unsqueeze(-3)
            dist_mat = torch.sum((torch.abs(batch1 - batch2))**2, -1)
        else:
            raise ValueError(
                "Unknown cost function: {}. Expected to be one of "
                "[cosine | euclidean | fast_euclidean]".format(dist_metric)
            )
        return dist_mat

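
# Note: with dist_metric="cosine" the entries of dist_mat lie in [0, 2]
# (one minus cosine similarity); both euclidean variants return squared
# distances, with no square root taken.
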
class SinkhornDivergence(OptimalTransport):
    # stopping threshold for the Sinkhorn iterations
    thre = 1e-3

    def __init__(
        self,
        dist_metric="cosine",
        eps=0.01,
        max_iter=5,
        bp_to_sinkhorn=False
    ):
        super().__init__()
        self.dist_metric = dist_metric
        self.eps = eps
        self.max_iter = max_iter
        self.bp_to_sinkhorn = bp_to_sinkhorn

    def forward(self, x, y):
        # x, y: two batches of data with shape (batch, dim)
        # Sinkhorn divergence: 2*W(x, y) - W(x, x) - W(y, y)
        W_xy = self.transport_cost(x, y)
        W_xx = self.transport_cost(x, x)
        W_yy = self.transport_cost(y, y)
        return 2*W_xy - W_xx - W_yy

    def transport_cost(self, x, y, return_pi=False):
        C = self.distance(x, y, dist_metric=self.dist_metric)
        pi = self.sinkhorn_iterate(C, self.eps, self.max_iter, self.thre)
        if not self.bp_to_sinkhorn:
            # treat the plan as a constant: gradients flow only through C
            pi = pi.detach()
        cost = torch.sum(pi * C)
        if return_pi:
            return cost, pi
        return cost

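    # sinkhorn_iterate solves the entropy-regularized OT problem
    #   min_{pi in U(mu, nu)} <pi, C> - eps * H(pi),
    # where U(mu, nu) is the set of couplings whose marginals are mu and
    # nu, by block ascent on the dual potentials (u, v) in the log domain.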
    @staticmethod
    def sinkhorn_iterate(C, eps, max_iter, thre):
        nx, ny = C.shape
        # uniform marginals over the two batches
        mu = torch.ones(nx, dtype=C.dtype, device=C.device) * (1.0/nx)
        nu = torch.ones(ny, dtype=C.dtype, device=C.device) * (1.0/ny)
        u = torch.zeros_like(mu)
        v = torch.zeros_like(nu)

        def M(_C, _u, _v):
            """Modified cost for logarithmic updates.

            Eq: M_{ij} = (-c_{ij} + u_i + v_j) / epsilon
            """
            return (-_C + _u.unsqueeze(-1) + _v.unsqueeze(-2)) / eps

        real_iter = 0  # check if algorithm terminates before max_iter
        # Sinkhorn iterations in the log domain (numerically stable)
        for i in range(max_iter):
            u0 = u
            u = eps * (
                torch.log(mu + 1e-8) - torch.logsumexp(M(C, u, v), dim=1)
            ) + u
            v = (
                eps * (
                    torch.log(nu + 1e-8) -
                    torch.logsumexp(M(C, u, v).permute(1, 0), dim=1)
                ) + v
            )
            err = (u - u0).abs().sum()
            real_iter += 1
            if err.item() < thre:
                break
        # Transport plan pi = diag(a)*K*diag(b)
        return torch.exp(M(C, u, v))

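
# Usage sketch (the shapes are illustrative; any (batch, dim) features
# work):
#   div = SinkhornDivergence(dist_metric="cosine", eps=0.01, max_iter=5)
#   loss = div(torch.randn(32, 128), torch.randn(32, 128))
# loss is a scalar, differentiable through the cost matrix (and through
# the plan as well when bp_to_sinkhorn=True).
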
class MinibatchEnergyDistance(SinkhornDivergence):

    def __init__(
        self,
        dist_metric="cosine",
        eps=0.01,
        max_iter=5,
        bp_to_sinkhorn=False
    ):
        super().__init__(
            dist_metric=dist_metric,
            eps=eps,
            max_iter=max_iter,
            bp_to_sinkhorn=bp_to_sinkhorn,
        )

    def forward(self, x, y):
        # energy distance built from pairwise transport costs between two
        # independent halves of each batch; batch sizes are assumed even
        # so torch.split yields exactly two chunks
        x1, x2 = torch.split(x, x.size(0) // 2, dim=0)
        y1, y2 = torch.split(y, y.size(0) // 2, dim=0)
        cost = 0
        cost += self.transport_cost(x1, y1)
        cost += self.transport_cost(x1, y2)
        cost += self.transport_cost(x2, y1)
        cost += self.transport_cost(x2, y2)
        cost -= 2 * self.transport_cost(x1, x2)
        cost -= 2 * self.transport_cost(y1, y2)
        return cost

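
# Usage sketch (even batch sizes are assumed so each batch splits into
# exactly two halves):
#   med = MinibatchEnergyDistance(dist_metric="cosine")
#   loss = med(torch.randn(16, 64), torch.randn(16, 64))
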
if __name__ == "__main__":
    # example: https://dfdazac.github.io/sinkhorn.html
    import numpy as np

    n_points = 5
    a = np.array([[i, 0] for i in range(n_points)])
    b = np.array([[i, 1] for i in range(n_points)])
    x = torch.tensor(a, dtype=torch.float)
    y = torch.tensor(b, dtype=torch.float)
    sinkhorn = SinkhornDivergence(
        dist_metric="euclidean", eps=0.01, max_iter=5
    )
    dist, pi = sinkhorn.transport_cost(x, y, True)
    print("transport cost:", dist.item())
    print("transport plan:\n", pi)
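    # Sanity check (assumption: the two euclidean code paths agree up to
    # float error):
    d1 = OptimalTransport.distance(x, y, "euclidean")
    d2 = OptimalTransport.distance(x, y, "fast_euclidean")
    print("euclidean matches fast_euclidean:",
          torch.allclose(d1, d2, atol=1e-5))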