release code
This commit is contained in:
5
Dassl.ProGrad.pytorch/dassl/engine/ssl/__init__.py
Normal file
5
Dassl.ProGrad.pytorch/dassl/engine/ssl/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .entmin import EntMin
|
||||
from .fixmatch import FixMatch
|
||||
from .mixmatch import MixMatch
|
||||
from .mean_teacher import MeanTeacher
|
||||
from .sup_baseline import SupBaseline
|
||||
41
Dassl.ProGrad.pytorch/dassl/engine/ssl/entmin.py
Normal file
41
Dassl.ProGrad.pytorch/dassl/engine/ssl/entmin.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from dassl.engine import TRAINER_REGISTRY, TrainerXU
|
||||
from dassl.metrics import compute_accuracy
|
||||
|
||||
|
||||
@TRAINER_REGISTRY.register()
class EntMin(TrainerXU):
    """Entropy Minimization.

    Adds an entropy penalty on the model's predictions for unlabeled
    data on top of the standard supervised cross-entropy loss.

    http://papers.nips.cc/paper/2740-semi-supervised-learning-by-entropy-minimization.pdf.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        # Weight balancing the entropy term against the supervised loss.
        self.lmda = cfg.TRAINER.ENTMIN.LMDA

    def forward_backward(self, batch_x, batch_u):
        """Run one optimization step on a labeled + unlabeled batch pair.

        Returns a dict of scalar statistics for logging.
        """
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)

        # Supervised cross-entropy on the labeled batch.
        logits_x = self.model(input_x)
        loss_x = F.cross_entropy(logits_x, label_x)

        # Entropy of the predicted class distribution on the unlabeled
        # batch; the epsilon guards log() against exact zeros.
        prob_u = F.softmax(self.model(input_u), 1)
        loss_u = (-prob_u * torch.log(prob_u + 1e-5)).sum(1).mean()

        self.model_backward_and_update(loss_x + loss_u * self.lmda)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(logits_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
        }

        # Step the LR scheduler once per epoch, after the last batch.
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
|
||||
112
Dassl.ProGrad.pytorch/dassl/engine/ssl/fixmatch.py
Normal file
112
Dassl.ProGrad.pytorch/dassl/engine/ssl/fixmatch.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from dassl.data import DataManager
|
||||
from dassl.engine import TRAINER_REGISTRY, TrainerXU
|
||||
from dassl.metrics import compute_accuracy
|
||||
from dassl.data.transforms import build_transform
|
||||
|
||||
|
||||
@TRAINER_REGISTRY.register()
class FixMatch(TrainerXU):
    """FixMatch: Simplifying Semi-Supervised Learning with
    Consistency and Confidence.

    https://arxiv.org/abs/2001.07685.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        # Weight on the unsupervised (pseudo-label) loss term.
        self.weight_u = cfg.TRAINER.FIXMATCH.WEIGHT_U
        # Confidence threshold for accepting a pseudo label.
        self.conf_thre = cfg.TRAINER.FIXMATCH.CONF_THRE

    def check_cfg(self, cfg):
        # FixMatch needs at least one strong augmentation for the
        # consistency branch.
        assert len(cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS) > 0

    def build_data_loader(self):
        """Build data loaders with a (weak, strong) transform pair.

        Each sample yields two views: "img" (weak) and "img2" (strong);
        see parse_batch_train below.
        """
        cfg = self.cfg
        tfm_train = build_transform(cfg, is_train=True)
        custom_tfm_train = [tfm_train]
        choices = cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS
        tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
        custom_tfm_train += [tfm_train_strong]
        self.dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
        self.train_loader_x = self.dm.train_loader_x
        self.train_loader_u = self.dm.train_loader_u
        self.val_loader = self.dm.val_loader
        self.test_loader = self.dm.test_loader
        self.num_classes = self.dm.num_classes

    def assess_y_pred_quality(self, y_pred, y_true, mask):
        """Measure pseudo-label quality (diagnostics only, no gradient use).

        Returns a dict with accuracy over confident samples (acc_thre),
        accuracy over all samples (acc_raw), and the fraction of samples
        passing the confidence threshold (keep_rate).
        NOTE(review): values are 0-dim tensors, not Python floats —
        callers log them as-is; confirm the logger handles tensors.
        """
        n_masked_correct = (y_pred.eq(y_true).float() * mask).sum()
        # Epsilon avoids division by zero when no sample is confident.
        acc_thre = n_masked_correct / (mask.sum() + 1e-5)
        acc_raw = y_pred.eq(y_true).sum() / y_pred.numel()  # raw accuracy
        keep_rate = mask.sum() / mask.numel()
        output = {
            "acc_thre": acc_thre,
            "acc_raw": acc_raw,
            "keep_rate": keep_rate
        }
        return output

    def forward_backward(self, batch_x, batch_u):
        """One FixMatch step: pseudo-label weak views, train on strong views."""
        parsed_data = self.parse_batch_train(batch_x, batch_u)
        input_x, input_x2, label_x, input_u, input_u2, label_u = parsed_data
        # Labeled images are prepended so they also contribute to the
        # consistency loss; n_x marks the boundary for slicing them off.
        input_u = torch.cat([input_x, input_u], 0)
        input_u2 = torch.cat([input_x2, input_u2], 0)
        n_x = input_x.size(0)

        # Generate pseudo labels
        # Predictions on the weakly augmented views; no_grad so the
        # pseudo-label targets carry no gradient.
        with torch.no_grad():
            output_u = F.softmax(self.model(input_u), 1)
            max_prob, label_u_pred = output_u.max(1)
            # 1.0 where the top probability clears the threshold, else 0.0.
            mask_u = (max_prob >= self.conf_thre).float()

        # Evaluate pseudo labels' accuracy
        # Only the truly unlabeled slice ([n_x:]) is compared against
        # label_u, which exists purely for this diagnostic.
        y_u_pred_stats = self.assess_y_pred_quality(
            label_u_pred[n_x:], label_u, mask_u[n_x:]
        )

        # Supervised loss
        output_x = self.model(input_x)
        loss_x = F.cross_entropy(output_x, label_x)

        # Unsupervised loss
        # Strong views trained toward the weak-view pseudo labels;
        # masking then .mean() divides by the full batch size, matching
        # the FixMatch formulation.
        output_u = self.model(input_u2)
        loss_u = F.cross_entropy(output_u, label_u_pred, reduction="none")
        loss_u = (loss_u * mask_u).mean()

        loss = loss_x + loss_u * self.weight_u
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(output_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
            "y_u_pred_acc_raw": y_u_pred_stats["acc_raw"],
            "y_u_pred_acc_thre": y_u_pred_stats["acc_thre"],
            "y_u_pred_keep": y_u_pred_stats["keep_rate"],
        }

        # Step the LR scheduler once per epoch, after the last batch.
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        """Unpack (weak, strong) views and labels, moving them to device."""
        input_x = batch_x["img"]
        input_x2 = batch_x["img2"]
        label_x = batch_x["label"]
        input_u = batch_u["img"]
        input_u2 = batch_u["img2"]
        # label_u is used only for evaluating pseudo labels' accuracy
        label_u = batch_u["label"]

        input_x = input_x.to(self.device)
        input_x2 = input_x2.to(self.device)
        label_x = label_x.to(self.device)
        input_u = input_u.to(self.device)
        input_u2 = input_u2.to(self.device)
        label_u = label_u.to(self.device)

        return input_x, input_x2, label_x, input_u, input_u2, label_u
|
||||
54
Dassl.ProGrad.pytorch/dassl/engine/ssl/mean_teacher.py
Normal file
54
Dassl.ProGrad.pytorch/dassl/engine/ssl/mean_teacher.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import copy
|
||||
from torch.nn import functional as F
|
||||
|
||||
from dassl.engine import TRAINER_REGISTRY, TrainerXU
|
||||
from dassl.metrics import compute_accuracy
|
||||
from dassl.modeling.ops.utils import sigmoid_rampup, ema_model_update
|
||||
|
||||
|
||||
@TRAINER_REGISTRY.register()
class MeanTeacher(TrainerXU):
    """Mean teacher.

    https://arxiv.org/abs/1703.01780.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        # Weight on the consistency (student-vs-teacher) loss.
        self.weight_u = cfg.TRAINER.MEANTEA.WEIGHT_U
        # Upper bound on the EMA decay rate for the teacher.
        self.ema_alpha = cfg.TRAINER.MEANTEA.EMA_ALPHA
        # Ramp-up length (epochs) for the consistency weight.
        self.rampup = cfg.TRAINER.MEANTEA.RAMPUP

        # The teacher starts as a copy of the student and is updated only
        # via EMA; its parameters never receive gradients. It is kept in
        # train() mode so e.g. normalization layers behave as in training.
        self.teacher = copy.deepcopy(self.model)
        self.teacher.train()
        for param in self.teacher.parameters():
            param.requires_grad_(False)

    def forward_backward(self, batch_x, batch_u):
        """One step: supervised loss + MSE consistency to the teacher."""
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)

        # Supervised cross-entropy on the labeled batch.
        logit_x = self.model(input_x)
        loss_x = F.cross_entropy(logit_x, label_x)

        # Teacher predictions are the consistency targets; no gradient
        # flows into the teacher since its params have requires_grad=False.
        target_u = F.softmax(self.teacher(input_u), 1)
        prob_u = F.softmax(self.model(input_u), 1)
        # Per-sample squared error over classes, averaged over the batch.
        loss_u = ((prob_u - target_u)**2).sum(1).mean()

        # Consistency weight ramps up over epochs to avoid early noise.
        weight_u = self.weight_u * sigmoid_rampup(self.epoch, self.rampup)
        loss = loss_x + loss_u*weight_u
        self.model_backward_and_update(loss)

        # EMA update of the teacher AFTER the student step; the effective
        # decay grows from 0 toward ema_alpha as training progresses.
        global_step = self.batch_idx + self.epoch * self.num_batches
        ema_alpha = min(1 - 1 / (global_step+1), self.ema_alpha)
        ema_model_update(self.model, self.teacher, ema_alpha)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(logit_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
        }

        # Step the LR scheduler once per epoch, after the last batch.
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
|
||||
98
Dassl.ProGrad.pytorch/dassl/engine/ssl/mixmatch.py
Normal file
98
Dassl.ProGrad.pytorch/dassl/engine/ssl/mixmatch.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from dassl.engine import TRAINER_REGISTRY, TrainerXU
|
||||
from dassl.modeling.ops import mixup
|
||||
from dassl.modeling.ops.utils import (
|
||||
sharpen_prob, create_onehot, linear_rampup, shuffle_index
|
||||
)
|
||||
|
||||
|
||||
@TRAINER_REGISTRY.register()
class MixMatch(TrainerXU):
    """MixMatch: A Holistic Approach to Semi-Supervised Learning.

    https://arxiv.org/abs/1905.02249.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        # Weight on the unsupervised consistency loss.
        self.weight_u = cfg.TRAINER.MIXMATCH.WEIGHT_U
        # Sharpening temperature for the averaged pseudo label.
        self.temp = cfg.TRAINER.MIXMATCH.TEMP
        # Beta-distribution parameter for mixup interpolation.
        self.beta = cfg.TRAINER.MIXMATCH.MIXUP_BETA
        # Ramp-up length (in global steps) for the consistency weight.
        self.rampup = cfg.TRAINER.MIXMATCH.RAMPUP

    def check_cfg(self, cfg):
        # MixMatch averages predictions over K>1 augmented views of each
        # unlabeled image, so multiple transforms per sample are required.
        assert cfg.DATALOADER.K_TRANSFORMS > 1

    def forward_backward(self, batch_x, batch_u):
        """One MixMatch step: pseudo-label, shuffle, mixup, and train."""
        # input_u is a list of K augmented views (see parse_batch_train).
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
        num_x = input_x.shape[0]

        global_step = self.batch_idx + self.epoch * self.num_batches
        weight_u = self.weight_u * linear_rampup(global_step, self.rampup)

        # Generate pseudo-label for unlabeled data
        # Average softmax predictions over the K views, then sharpen;
        # no_grad keeps the targets free of gradient history.
        with torch.no_grad():
            output_u = 0
            for input_ui in input_u:
                output_ui = F.softmax(self.model(input_ui), 1)
                output_u += output_ui
            output_u /= len(input_u)
            label_u = sharpen_prob(output_u, self.temp)
            # Every view of a sample shares the same sharpened label.
            label_u = [label_u] * len(input_u)
            label_u = torch.cat(label_u, 0)
            input_u = torch.cat(input_u, 0)

        # Combine and shuffle labeled and unlabeled data
        input_xu = torch.cat([input_x, input_u], 0)
        label_xu = torch.cat([label_x, label_u], 0)
        input_xu, label_xu = shuffle_index(input_xu, label_xu)

        # Mixup
        # Labeled data is mixed with the first num_x shuffled samples;
        # preserve_order presumably keeps the original sample dominant
        # in each mixed pair — TODO confirm against mixup's definition.
        input_x, label_x = mixup(
            input_x,
            input_xu[:num_x],
            label_x,
            label_xu[:num_x],
            self.beta,
            preserve_order=True,
        )

        input_u, label_u = mixup(
            input_u,
            input_xu[num_x:],
            label_u,
            label_xu[num_x:],
            self.beta,
            preserve_order=True,
        )

        # Compute losses
        # Soft cross-entropy against one-hot/mixed labels; epsilon guards
        # the log against exact zeros.
        output_x = F.softmax(self.model(input_x), 1)
        loss_x = (-label_x * torch.log(output_x + 1e-5)).sum(1).mean()

        # L2 consistency between predictions and mixed pseudo labels.
        output_u = F.softmax(self.model(input_u), 1)
        loss_u = ((label_u - output_u)**2).mean()

        loss = loss_x + loss_u*weight_u
        self.model_backward_and_update(loss)

        loss_summary = {"loss_x": loss_x.item(), "loss_u": loss_u.item()}

        # Step the LR scheduler once per epoch, after the last batch.
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        """Unpack a batch pair and move tensors to the training device.

        Returns (input_x, label_x, input_u) where label_x is one-hot and
        input_u is a LIST of K augmented views of the unlabeled batch.
        """
        # Only the first view of the labeled batch is used.
        input_x = batch_x["img"][0]
        label_x = batch_x["label"]
        label_x = create_onehot(label_x, self.num_classes)
        input_u = batch_u["img"]

        input_x = input_x.to(self.device)
        label_x = label_x.to(self.device)
        input_u = [input_ui.to(self.device) for input_ui in input_u]

        return input_x, label_x, input_u
|
||||
32
Dassl.ProGrad.pytorch/dassl/engine/ssl/sup_baseline.py
Normal file
32
Dassl.ProGrad.pytorch/dassl/engine/ssl/sup_baseline.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from torch.nn import functional as F
|
||||
|
||||
from dassl.engine import TRAINER_REGISTRY, TrainerXU
|
||||
from dassl.metrics import compute_accuracy
|
||||
|
||||
|
||||
@TRAINER_REGISTRY.register()
class SupBaseline(TrainerXU):
    """Supervised Baseline.

    Trains on the labeled batch only; the unlabeled batch is ignored.
    """

    def forward_backward(self, batch_x, batch_u):
        """Perform one supervised update and report loss/accuracy."""
        input, label = self.parse_batch_train(batch_x, batch_u)

        output = self.model(input)
        loss = F.cross_entropy(output, label)
        self.model_backward_and_update(loss)

        acc = compute_accuracy(output, label)[0].item()
        loss_summary = {"loss": loss.item(), "acc": acc}

        # Step the LR scheduler once per epoch, after the last batch.
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        """Move the labeled image/label pair to the device; batch_u unused."""
        input = batch_x["img"].to(self.device)
        label = batch_x["label"].to(self.device)
        return input, label
|
||||
Reference in New Issue
Block a user