release code
This commit is contained in:
38
Dassl.ProGrad.pytorch/dassl/engine/da/adabn.py
Normal file
38
Dassl.ProGrad.pytorch/dassl/engine/da/adabn.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import torch
|
||||
|
||||
from dassl.utils import check_isfile
|
||||
from dassl.engine import TRAINER_REGISTRY, TrainerXU
|
||||
|
||||
|
||||
@TRAINER_REGISTRY.register()
|
||||
class AdaBN(TrainerXU):
|
||||
"""Adaptive Batch Normalization.
|
||||
|
||||
https://arxiv.org/abs/1603.04779.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg):
|
||||
super().__init__(cfg)
|
||||
self.done_reset_bn_stats = False
|
||||
|
||||
def check_cfg(self, cfg):
|
||||
assert check_isfile(
|
||||
cfg.MODEL.INIT_WEIGHTS
|
||||
), "The weights of source model must be provided"
|
||||
|
||||
def before_epoch(self):
|
||||
if not self.done_reset_bn_stats:
|
||||
for m in self.model.modules():
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("BatchNorm") != -1:
|
||||
m.reset_running_stats()
|
||||
|
||||
self.done_reset_bn_stats = True
|
||||
|
||||
def forward_backward(self, batch_x, batch_u):
|
||||
input_u = batch_u["img"].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
self.model(input_u)
|
||||
|
||||
return None
|
||||
Reference in New Issue
Block a user