import torch.nn as nn
from torch.autograd import Function


class _ReverseGrad(Function):
    """Autograd function that is the identity in the forward pass and
    multiplies the incoming gradient by -grad_scaling in the backward pass."""

    @staticmethod
    def forward(ctx, input, grad_scaling):
        ctx.grad_scaling = grad_scaling
        return input.view_as(input)

    @staticmethod
    def backward(ctx, grad_output):
        grad_scaling = ctx.grad_scaling
        # Negate and scale the incoming gradient; the second return value is
        # the gradient w.r.t. grad_scaling, which is not a tensor input.
        return -grad_scaling * grad_output, None


reverse_grad = _ReverseGrad.apply


class ReverseGrad(nn.Module):
    """Gradient reversal layer.

    It acts as an identity layer in the forward pass,
    but reverses the sign of the gradient in the backward pass.
    """

    def forward(self, x, grad_scaling=1.0):
        assert grad_scaling >= 0, (
            "grad_scaling must be non-negative, but got {}".format(grad_scaling)
        )
        return reverse_grad(x, grad_scaling)
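

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal, illustrative check
# of the gradient reversal behaviour, assuming only the classes defined above
# plus torch itself.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    layer = ReverseGrad()
    x = torch.randn(4, 8, requires_grad=True)

    # Forward pass is the identity.
    y = layer(x, grad_scaling=0.5)
    assert torch.equal(y, x)

    # d(sum(y))/dx is all ones, so after reversal and scaling by 0.5
    # the gradient on x should be filled with -0.5.
    y.sum().backward()
    print(x.grad)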