import torch

from dassl.utils import check_isfile
from dassl.engine import TRAINER_REGISTRY, TrainerXU


@TRAINER_REGISTRY.register()
class AdaBN(TrainerXU):
    """Adaptive Batch Normalization.

    https://arxiv.org/abs/1603.04779.

    Adapts a model loaded from ``cfg.MODEL.INIT_WEIGHTS`` by resetting the
    running statistics of its BatchNorm layers once and then re-estimating
    them from unlabeled batches (``batch_u``). No parameters are optimized:
    ``forward_backward`` only runs forward passes under ``torch.no_grad()``.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        # Running stats are reset only before the first epoch; later epochs
        # keep accumulating statistics on top of that single reset.
        self.done_reset_bn_stats = False

    def check_cfg(self, cfg):
        """Require initial (source-model) weights to be provided.

        Raises:
            ValueError: if ``cfg.MODEL.INIT_WEIGHTS`` is not an existing file.
        """
        # Explicit raise instead of ``assert`` so the check is not stripped
        # when Python runs with -O.
        if not check_isfile(cfg.MODEL.INIT_WEIGHTS):
            raise ValueError("The weights of source model must be provided")

    def before_epoch(self):
        """Reset running stats of every BatchNorm-named module (first call only)."""
        if self.done_reset_bn_stats:
            return

        for m in self.model.modules():
            # Matched by class name: any module whose class name contains
            # "BatchNorm" is expected to provide ``reset_running_stats()``.
            # NOTE(review): with a numeric BN ``momentum`` (PyTorch default
            # 0.1) the re-estimated stats are an exponential moving average
            # over batches, not exact whole-domain statistics; momentum=None
            # would give a cumulative average. Confirm which is intended.
            if "BatchNorm" in m.__class__.__name__:
                m.reset_running_stats()

        self.done_reset_bn_stats = True

    def forward_backward(self, batch_x, batch_u):
        """Run one unlabeled batch through the model to update BN statistics.

        ``batch_x`` is unused. Returns ``None`` because no loss is computed.
        """
        input_u = batch_u["img"].to(self.device)

        # Gradients are not needed: the forward pass is executed only for its
        # side effect on BatchNorm running statistics.
        # NOTE(review): BN layers update running stats only in train mode;
        # assumes the base trainer puts the model in train mode before this
        # call - TODO confirm.
        with torch.no_grad():
            self.model(input_u)

        return None