release code
This commit is contained in:
23
Dassl.ProGrad.pytorch/dassl/modeling/ops/mixup.py
Normal file
23
Dassl.ProGrad.pytorch/dassl/modeling/ops/mixup.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import torch
|
||||
|
||||
|
||||
def mixup(x1, x2, y1, y2, beta, preserve_order=False):
|
||||
"""Mixup.
|
||||
|
||||
Args:
|
||||
x1 (torch.Tensor): data with shape of (b, c, h, w).
|
||||
x2 (torch.Tensor): data with shape of (b, c, h, w).
|
||||
y1 (torch.Tensor): label with shape of (b, n).
|
||||
y2 (torch.Tensor): label with shape of (b, n).
|
||||
beta (float): hyper-parameter for Beta sampling.
|
||||
preserve_order (bool): apply lmda=max(lmda, 1-lmda).
|
||||
Default is False.
|
||||
"""
|
||||
lmda = torch.distributions.Beta(beta, beta).sample([x1.shape[0], 1, 1, 1])
|
||||
if preserve_order:
|
||||
lmda = torch.max(lmda, 1 - lmda)
|
||||
lmda = lmda.to(x1.device)
|
||||
xmix = x1*lmda + x2 * (1-lmda)
|
||||
lmda = lmda[:, :, 0, 0]
|
||||
ymix = y1*lmda + y2 * (1-lmda)
|
||||
return xmix, ymix
|
||||
Reference in New Issue
Block a user