35 lines
845 B
Python
35 lines
845 B
Python
import torch.nn as nn
|
|
from torch.autograd import Function
|
|
|
|
|
|
class _ReverseGrad(Function):
    """Autograd function that is the identity going forward and flips
    (and scales) the gradient going backward."""

    @staticmethod
    def forward(ctx, input, grad_scaling):
        # Only the scale factor is needed in backward; no tensors are saved.
        ctx.grad_scaling = grad_scaling
        # Hand back a view rather than `input` itself, so the output is a
        # new tensor object that shares `input`'s data.
        return input.view_as(input)

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient w.r.t. `input` is negated and scaled; no gradient is
        # produced for the `grad_scaling` argument, hence the trailing None.
        flipped = grad_output * (-ctx.grad_scaling)
        return flipped, None


# Functional entry point: reverse_grad(x, grad_scaling).
reverse_grad = _ReverseGrad.apply
|
|
|
|
|
|
class ReverseGrad(nn.Module):
    """Gradient reversal layer.

    It acts as an identity layer in the forward, but reverses the sign of
    the gradient (scaled by ``grad_scaling``) in the backward.
    """

    def forward(self, x, grad_scaling=1.0):
        """Pass ``x`` through unchanged; negate and scale its gradient.

        Args:
            x: input tensor.
            grad_scaling: non-negative factor; the gradient w.r.t. ``x``
                becomes ``-grad_scaling * grad_output``.

        Returns:
            ``reverse_grad(x, grad_scaling)``.

        Raises:
            AssertionError: if ``grad_scaling >= 0`` does not hold
                (negative values, or NaN).
        """
        # An explicit raise (unlike an `assert` statement) is not stripped
        # under `python -O`; without it a negative scale would silently
        # turn gradient reversal into plain gradient scaling. AssertionError
        # is kept so existing callers that catch it keep working.
        if not grad_scaling >= 0:
            raise AssertionError(
                f"grad_scaling must be non-negative, but got {grad_scaling}"
            )
        return reverse_grad(x, grad_scaling)
|