release code
Dassl.ProGrad.pytorch/dassl/modeling/ops/reverse_grad.py (new file, 34 lines)
@@ -0,0 +1,34 @@
import torch.nn as nn
from torch.autograd import Function


class _ReverseGrad(Function):

    @staticmethod
    def forward(ctx, input, grad_scaling):
        # The forward pass is an identity mapping; stash the scaling
        # factor on the context for use in the backward pass.
        ctx.grad_scaling = grad_scaling
        return input.view_as(input)

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign of the incoming gradient and scale it. The
        # second return value (None) is the gradient w.r.t.
        # grad_scaling, which is not a tensor input.
        grad_scaling = ctx.grad_scaling
        return -grad_scaling * grad_output, None


reverse_grad = _ReverseGrad.apply


class ReverseGrad(nn.Module):
    """Gradient reversal layer.

    It acts as an identity layer in the forward,
    but reverses the sign of the gradient in
    the backward.
    """

    def forward(self, x, grad_scaling=1.0):
        assert grad_scaling >= 0, (
            "grad_scaling must be non-negative, "
            "but got {}".format(grad_scaling)
        )
        return reverse_grad(x, grad_scaling)
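The class docstring describes the classic gradient-reversal trick used in domain-adversarial training: identity in the forward pass, sign-flipped (and scaled) gradient in the backward pass. Below is a minimal sanity-check sketch of how the layer might be exercised; the tensor shape, the grad_scaling value of 0.5, and the import path (inferred from the file location above) are illustrative assumptions, not part of this commit.

import torch

# Import path assumed from the file location shown in the diff header.
from dassl.modeling.ops.reverse_grad import ReverseGrad

layer = ReverseGrad()
x = torch.randn(4, 8, requires_grad=True)  # arbitrary example shape

# Forward pass: the output is numerically identical to the input.
y = layer(x, grad_scaling=0.5)
assert torch.equal(y, x)

# Backward pass: d(sum(y))/dy is all ones, so after reversal and
# scaling every entry of x.grad should be -0.5.
y.sum().backward()
print(x.grad)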