release code

This commit is contained in:
miunangel
2025-08-16 20:46:31 +08:00
commit 3dc26db3b9
277 changed files with 60106 additions and 0 deletions

View File

@@ -0,0 +1,5 @@
from .entmin import EntMin
from .fixmatch import FixMatch
from .mixmatch import MixMatch
from .mean_teacher import MeanTeacher
from .sup_baseline import SupBaseline
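
These five trainers register themselves with Dassl's TRAINER_REGISTRY on import, so they can be selected by name through the config. A minimal usage sketch, assuming Dassl's standard get_cfg_default/build_trainer entry points (the dataset settings below are placeholders, not taken from this commit):

# Minimal sketch: build a registered trainer by name.
# Assumes Dassl's get_cfg_default/build_trainer; dataset values are placeholders.
from dassl.config import get_cfg_default
from dassl.engine import build_trainer

cfg = get_cfg_default()
cfg.TRAINER.NAME = "FixMatch"  # any of the names registered above
cfg.DATASET.NAME = "CIFAR10"   # placeholder; real runs need trainer-specific config too
trainer = build_trainer(cfg)   # looks up cfg.TRAINER.NAME in TRAINER_REGISTRY
trainer.train()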

View File

@@ -0,0 +1,41 @@
import torch
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy


@TRAINER_REGISTRY.register()
class EntMin(TrainerXU):
    """Entropy Minimization.

    http://papers.nips.cc/paper/2740-semi-supervised-learning-by-entropy-minimization.pdf.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.lmda = cfg.TRAINER.ENTMIN.LMDA

    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)

        # Supervised cross-entropy on labeled data
        output_x = self.model(input_x)
        loss_x = F.cross_entropy(output_x, label_x)

        # Entropy of the model's own predictions on unlabeled data
        output_u = F.softmax(self.model(input_u), 1)
        loss_u = (-output_u * torch.log(output_u + 1e-5)).sum(1).mean()

        loss = loss_x + loss_u * self.lmda
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(output_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
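
The unsupervised term above is the average prediction entropy, -sum_c p_c log p_c, computed from the model's own softmax output. A standalone check of that expression (the probabilities are illustrative):

# Standalone check of EntMin's entropy term (illustrative values).
import torch

probs = torch.tensor([[0.90, 0.05, 0.05],   # confident prediction -> low entropy
                      [0.34, 0.33, 0.33]])  # near-uniform -> high entropy
loss_u = (-probs * torch.log(probs + 1e-5)).sum(1).mean()
print(loss_u)  # roughly (0.39 + 1.10) / 2 ~= 0.75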

View File

@@ -0,0 +1,112 @@
import torch
from torch.nn import functional as F

from dassl.data import DataManager
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.data.transforms import build_transform


@TRAINER_REGISTRY.register()
class FixMatch(TrainerXU):
    """FixMatch: Simplifying Semi-Supervised Learning with
    Consistency and Confidence.

    https://arxiv.org/abs/2001.07685.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.weight_u = cfg.TRAINER.FIXMATCH.WEIGHT_U
        self.conf_thre = cfg.TRAINER.FIXMATCH.CONF_THRE

    def check_cfg(self, cfg):
        assert len(cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS) > 0

    def build_data_loader(self):
        cfg = self.cfg

        # Each training batch carries a weakly augmented view ("img")
        # and a strongly augmented view ("img2")
        tfm_train = build_transform(cfg, is_train=True)
        custom_tfm_train = [tfm_train]
        choices = cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS
        tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
        custom_tfm_train += [tfm_train_strong]

        self.dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
        self.train_loader_x = self.dm.train_loader_x
        self.train_loader_u = self.dm.train_loader_u
        self.val_loader = self.dm.val_loader
        self.test_loader = self.dm.test_loader
        self.num_classes = self.dm.num_classes

    def assess_y_pred_quality(self, y_pred, y_true, mask):
        n_masked_correct = (y_pred.eq(y_true).float() * mask).sum()
        acc_thre = n_masked_correct / (mask.sum() + 1e-5)  # accuracy of pseudo labels that pass the threshold
        acc_raw = y_pred.eq(y_true).sum() / y_pred.numel()  # raw accuracy
        keep_rate = mask.sum() / mask.numel()
        output = {
            "acc_thre": acc_thre,
            "acc_raw": acc_raw,
            "keep_rate": keep_rate
        }
        return output

    def forward_backward(self, batch_x, batch_u):
        parsed_data = self.parse_batch_train(batch_x, batch_u)
        input_x, input_x2, label_x, input_u, input_u2, label_u = parsed_data

        # Labeled images also pass through the pseudo-labeling branch
        input_u = torch.cat([input_x, input_u], 0)
        input_u2 = torch.cat([input_x2, input_u2], 0)
        n_x = input_x.size(0)

        # Generate pseudo labels from the weakly augmented views
        with torch.no_grad():
            output_u = F.softmax(self.model(input_u), 1)
            max_prob, label_u_pred = output_u.max(1)
            mask_u = (max_prob >= self.conf_thre).float()

        # Evaluate pseudo labels' accuracy
        y_u_pred_stats = self.assess_y_pred_quality(
            label_u_pred[n_x:], label_u, mask_u[n_x:]
        )

        # Supervised loss
        output_x = self.model(input_x)
        loss_x = F.cross_entropy(output_x, label_x)

        # Unsupervised loss: predictions on strongly augmented views are
        # pushed toward the confident pseudo labels
        output_u = self.model(input_u2)
        loss_u = F.cross_entropy(output_u, label_u_pred, reduction="none")
        loss_u = (loss_u * mask_u).mean()

        loss = loss_x + loss_u * self.weight_u
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(output_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
            "y_u_pred_acc_raw": y_u_pred_stats["acc_raw"],
            "y_u_pred_acc_thre": y_u_pred_stats["acc_thre"],
            "y_u_pred_keep": y_u_pred_stats["keep_rate"],
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        input_x = batch_x["img"]
        input_x2 = batch_x["img2"]
        label_x = batch_x["label"]
        input_u = batch_u["img"]
        input_u2 = batch_u["img2"]
        # label_u is used only for evaluating pseudo labels' accuracy
        label_u = batch_u["label"]

        input_x = input_x.to(self.device)
        input_x2 = input_x2.to(self.device)
        label_x = label_x.to(self.device)
        input_u = input_u.to(self.device)
        input_u2 = input_u2.to(self.device)
        label_u = label_u.to(self.device)

        return input_x, input_x2, label_x, input_u, input_u2, label_u
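
The core of FixMatch is the confidence mask: only samples whose weak-view prediction clears CONF_THRE contribute to the unsupervised loss. A standalone illustration of that masking (all values are illustrative):

# Standalone illustration of FixMatch's confidence masking (illustrative values).
import torch
from torch.nn import functional as F

conf_thre = 0.95
probs = torch.tensor([[0.97, 0.02, 0.01],   # passes the threshold
                      [0.60, 0.30, 0.10]])  # discarded this step
max_prob, pseudo_label = probs.max(1)
mask = (max_prob >= conf_thre).float()      # tensor([1., 0.])

logits_strong = torch.tensor([[2.0, 0.1, 0.1],
                              [0.5, 1.5, 0.2]])
loss_u = F.cross_entropy(logits_strong, pseudo_label, reduction="none")
loss_u = (loss_u * mask).mean()  # the second sample contributes nothing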

View File

@@ -0,0 +1,54 @@
import copy

from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling.ops.utils import sigmoid_rampup, ema_model_update


@TRAINER_REGISTRY.register()
class MeanTeacher(TrainerXU):
    """Mean teacher.

    https://arxiv.org/abs/1703.01780.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.weight_u = cfg.TRAINER.MEANTEA.WEIGHT_U
        self.ema_alpha = cfg.TRAINER.MEANTEA.EMA_ALPHA
        self.rampup = cfg.TRAINER.MEANTEA.RAMPUP

        # The teacher is an EMA copy of the student and is never
        # updated by gradient descent
        self.teacher = copy.deepcopy(self.model)
        self.teacher.train()
        for param in self.teacher.parameters():
            param.requires_grad_(False)

    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)

        # Supervised loss on labeled data
        logit_x = self.model(input_x)
        loss_x = F.cross_entropy(logit_x, label_x)

        # Consistency loss: student predictions should match the
        # teacher's soft targets on unlabeled data
        target_u = F.softmax(self.teacher(input_u), 1)
        prob_u = F.softmax(self.model(input_u), 1)
        loss_u = ((prob_u - target_u)**2).sum(1).mean()

        weight_u = self.weight_u * sigmoid_rampup(self.epoch, self.rampup)
        loss = loss_x + loss_u * weight_u
        self.model_backward_and_update(loss)

        # Update the teacher with an EMA of the student's weights
        global_step = self.batch_idx + self.epoch * self.num_batches
        ema_alpha = min(1 - 1 / (global_step + 1), self.ema_alpha)
        ema_model_update(self.model, self.teacher, ema_alpha)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(logit_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
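
sigmoid_rampup and ema_model_update are imported from dassl.modeling.ops.utils and are not part of this diff. Below are sketches of the standard formulations they presumably follow (Laine & Aila's sigmoid ramp-up; Tarvainen & Valpola's EMA teacher update); the actual dassl implementations may differ in detail:

# Hedged sketches of the helpers used above; not the dassl source.
import math

def sigmoid_rampup(current, rampup_length):
    # Ramps from 0 to 1 following exp(-5 * (1 - t)^2), t = current / rampup_length
    if rampup_length == 0:
        return 1.0
    current = min(max(current, 0.0), rampup_length)
    phase = 1.0 - current / rampup_length
    return math.exp(-5.0 * phase * phase)

def ema_model_update(model, ema_model, alpha):
    # teacher <- alpha * teacher + (1 - alpha) * student, parameter-wise
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.data.mul_(alpha).add_(p.data, alpha=1 - alpha)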

View File

@@ -0,0 +1,98 @@
import torch
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.modeling.ops import mixup
from dassl.modeling.ops.utils import (
    sharpen_prob, create_onehot, linear_rampup, shuffle_index
)


@TRAINER_REGISTRY.register()
class MixMatch(TrainerXU):
    """MixMatch: A Holistic Approach to Semi-Supervised Learning.

    https://arxiv.org/abs/1905.02249.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.weight_u = cfg.TRAINER.MIXMATCH.WEIGHT_U
        self.temp = cfg.TRAINER.MIXMATCH.TEMP
        self.beta = cfg.TRAINER.MIXMATCH.MIXUP_BETA
        self.rampup = cfg.TRAINER.MIXMATCH.RAMPUP

    def check_cfg(self, cfg):
        # MixMatch needs multiple augmented views of each unlabeled image
        assert cfg.DATALOADER.K_TRANSFORMS > 1

    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
        num_x = input_x.shape[0]
        global_step = self.batch_idx + self.epoch * self.num_batches
        weight_u = self.weight_u * linear_rampup(global_step, self.rampup)

        # Generate pseudo labels for unlabeled data: average the
        # predictions over all augmented views, then sharpen
        with torch.no_grad():
            output_u = 0
            for input_ui in input_u:
                output_ui = F.softmax(self.model(input_ui), 1)
                output_u += output_ui
            output_u /= len(input_u)
            label_u = sharpen_prob(output_u, self.temp)
            label_u = [label_u] * len(input_u)
            label_u = torch.cat(label_u, 0)
            input_u = torch.cat(input_u, 0)

        # Combine and shuffle labeled and unlabeled data
        input_xu = torch.cat([input_x, input_u], 0)
        label_xu = torch.cat([label_x, label_u], 0)
        input_xu, label_xu = shuffle_index(input_xu, label_xu)

        # Mixup
        input_x, label_x = mixup(
            input_x,
            input_xu[:num_x],
            label_x,
            label_xu[:num_x],
            self.beta,
            preserve_order=True,
        )
        input_u, label_u = mixup(
            input_u,
            input_xu[num_x:],
            label_u,
            label_xu[num_x:],
            self.beta,
            preserve_order=True,
        )

        # Compute losses: cross-entropy with soft labels on the labeled
        # part, squared error on the unlabeled part
        output_x = F.softmax(self.model(input_x), 1)
        loss_x = (-label_x * torch.log(output_x + 1e-5)).sum(1).mean()
        output_u = F.softmax(self.model(input_u), 1)
        loss_u = ((label_u - output_u)**2).mean()

        loss = loss_x + loss_u * weight_u
        self.model_backward_and_update(loss)

        loss_summary = {"loss_x": loss_x.item(), "loss_u": loss_u.item()}

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        # With K_TRANSFORMS > 1, "img" holds a list of augmented views;
        # only the first view of each labeled image is used
        input_x = batch_x["img"][0]
        label_x = batch_x["label"]
        label_x = create_onehot(label_x, self.num_classes)
        input_u = batch_u["img"]

        input_x = input_x.to(self.device)
        label_x = label_x.to(self.device)
        input_u = [input_ui.to(self.device) for input_ui in input_u]

        return input_x, label_x, input_u
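
sharpen_prob and mixup (with preserve_order=True) are imported from dassl and not shown in this diff. Sketches of the MixMatch-paper formulations they presumably implement; the actual dassl code may differ in detail:

# Hedged sketches of the dassl helpers used above; not the dassl source.
import torch

def sharpen_prob(p, temperature):
    # Temperature sharpening: raise probabilities to 1/T and renormalize
    p = p**(1 / temperature)
    return p / p.sum(1, keepdim=True)

def mixup(x1, x2, y1, y2, beta, preserve_order=False):
    # Sample lam ~ Beta(beta, beta); preserve_order=True forces lam >= 0.5,
    # keeping the mixed sample closer to (x1, y1)
    lam = torch.distributions.Beta(beta, beta).sample()
    if preserve_order:
        lam = torch.max(lam, 1 - lam)
    return lam * x1 + (1 - lam) * x2, lam * y1 + (1 - lam) * y2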

View File

@@ -0,0 +1,32 @@
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy


@TRAINER_REGISTRY.register()
class SupBaseline(TrainerXU):
    """Supervised baseline: trains on the labeled data only;
    the unlabeled batch is ignored."""

    def forward_backward(self, batch_x, batch_u):
        input, label = self.parse_batch_train(batch_x, batch_u)
        output = self.model(input)
        loss = F.cross_entropy(output, label)
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss": loss.item(),
            "acc": compute_accuracy(output, label)[0].item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        input = batch_x["img"]
        label = batch_x["label"]
        input = input.to(self.device)
        label = label.to(self.device)
        return input, label
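
For reference, the config entries read by the trainers in this commit, written as Python assignments on a yacs-style cfg node (every value below is a placeholder for illustration, not a default taken from this diff):

# Config knobs gathered from the code above; values are placeholders.
cfg.TRAINER.ENTMIN.LMDA = 1.0                             # weight on the entropy term
cfg.TRAINER.FIXMATCH.WEIGHT_U = 1.0                       # weight on the unsupervised loss
cfg.TRAINER.FIXMATCH.CONF_THRE = 0.95                     # pseudo-label confidence threshold
cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS = ["randaugment"]  # must be non-empty
cfg.TRAINER.MEANTEA.WEIGHT_U = 1.0                        # consistency-loss weight
cfg.TRAINER.MEANTEA.EMA_ALPHA = 0.999                     # cap on the teacher's EMA decay
cfg.TRAINER.MEANTEA.RAMPUP = 5                            # ramp-up length for weight_u
cfg.TRAINER.MIXMATCH.WEIGHT_U = 100.0                     # weight on the unsupervised loss
cfg.TRAINER.MIXMATCH.TEMP = 2.0                           # sharpening temperature
cfg.TRAINER.MIXMATCH.MIXUP_BETA = 0.75                    # Beta(beta, beta) mixup parameter
cfg.TRAINER.MIXMATCH.RAMPUP = 20000                       # ramp-up length in steps
cfg.DATALOADER.K_TRANSFORMS = 2                           # MixMatch requires > 1 views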