import torch.nn as nn
from torch.autograd import Function


class _ReverseGrad(Function):
    """Autograd function that is the identity forward and negates gradients backward.

    The incoming gradient is multiplied by ``-grad_scaling`` in the backward pass.
    """

    @staticmethod
    def forward(ctx, input, grad_scaling):
        # Keep the scaling factor on ``ctx`` for use in ``backward``; presumably a
        # Python scalar hyperparameter (it is never differentiated) — confirm at callers.
        ctx.grad_scaling = grad_scaling
        # ``view_as`` yields a new tensor object sharing ``input``'s data, so the
        # custom backward is attached to the output instead of returning ``input`` itself.
        return input.view_as(input)

    @staticmethod
    def backward(ctx, grad_output):
        grad_scaling = ctx.grad_scaling
        # One gradient per forward input: the reversed, scaled gradient for
        # ``input`` and ``None`` for ``grad_scaling``.
        return -grad_scaling * grad_output, None


# Functional form of the gradient reversal op: ``reverse_grad(x, grad_scaling)``.
reverse_grad = _ReverseGrad.apply


class ReverseGrad(nn.Module):
    """Gradient reversal layer.

    It acts as an identity layer in the forward pass, but reverses the sign of
    the gradient (and multiplies it by ``grad_scaling``) in the backward pass.
    """

    def forward(self, x, grad_scaling=1.0):
        """Return ``x`` unchanged while reversing its gradient on backward.

        Args:
            x: input tensor.
            grad_scaling: non-negative factor applied to the reversed gradient.

        Raises:
            ValueError: if ``grad_scaling`` is negative (or NaN).
        """
        # Explicit check rather than ``assert`` so validation is not stripped
        # under ``python -O``; ``not (... >= 0)`` also rejects NaN, as before.
        if not (grad_scaling >= 0):
            raise ValueError(
                "grad_scaling must be non-negative, "
                "but got {}".format(grad_scaling)
            )
        return reverse_grad(x, grad_scaling)