release code
This commit is contained in:
1
Dassl.ProGrad.pytorch/configs/README.md
Normal file
1
Dassl.ProGrad.pytorch/configs/README.md
Normal file
@@ -0,0 +1 @@
|
||||
The `datasets/` folder contains dataset-specific config files which define the standard protocols (e.g., image size, data augmentation, network architecture) used by most papers. The `trainers/` folder contains method-specific config files which define optimization algorithms (e.g., optimizer, epoch) and hyperparameter settings.
|
||||
7
Dassl.ProGrad.pytorch/configs/datasets/da/cifar_stl.yaml
Normal file
7
Dassl.ProGrad.pytorch/configs/datasets/da/cifar_stl.yaml
Normal file
@@ -0,0 +1,7 @@
|
||||
INPUT:
|
||||
SIZE: (32, 32)
|
||||
PIXEL_MEAN: [0.5, 0.5, 0.5]
|
||||
PIXEL_STD: [0.5, 0.5, 0.5]
|
||||
|
||||
DATASET:
|
||||
NAME: "CIFARSTL"
|
||||
12
Dassl.ProGrad.pytorch/configs/datasets/da/digit5.yaml
Normal file
12
Dassl.ProGrad.pytorch/configs/datasets/da/digit5.yaml
Normal file
@@ -0,0 +1,12 @@
|
||||
INPUT:
|
||||
SIZE: (32, 32)
|
||||
PIXEL_MEAN: [0.5, 0.5, 0.5]
|
||||
PIXEL_STD: [0.5, 0.5, 0.5]
|
||||
TRANSFORMS: ["normalize"]
|
||||
|
||||
DATASET:
|
||||
NAME: "Digit5"
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "cnn_digit5_m3sda"
|
||||
10
Dassl.ProGrad.pytorch/configs/datasets/da/domainnet.yaml
Normal file
10
Dassl.ProGrad.pytorch/configs/datasets/da/domainnet.yaml
Normal file
@@ -0,0 +1,10 @@
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
TRANSFORMS: ["random_flip", "random_translation", "normalize"]
|
||||
|
||||
DATASET:
|
||||
NAME: "DomainNet"
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "resnet101"
|
||||
@@ -0,0 +1,10 @@
|
||||
INPUT:
|
||||
SIZE: (96, 96)
|
||||
TRANSFORMS: ["random_flip", "random_translation", "normalize"]
|
||||
|
||||
DATASET:
|
||||
NAME: "miniDomainNet"
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "resnet18"
|
||||
14
Dassl.ProGrad.pytorch/configs/datasets/da/office31.yaml
Normal file
14
Dassl.ProGrad.pytorch/configs/datasets/da/office31.yaml
Normal file
@@ -0,0 +1,14 @@
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
TRANSFORMS: ["random_flip", "random_translation", "normalize"]
|
||||
|
||||
DATASET:
|
||||
NAME: "Office31"
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "resnet50"
|
||||
HEAD:
|
||||
NAME: "mlp"
|
||||
HIDDEN_LAYERS: [256]
|
||||
DROPOUT: 0.
|
||||
@@ -0,0 +1,5 @@
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
|
||||
DATASET:
|
||||
NAME: "OfficeHome"
|
||||
13
Dassl.ProGrad.pytorch/configs/datasets/da/visda17.yaml
Normal file
13
Dassl.ProGrad.pytorch/configs/datasets/da/visda17.yaml
Normal file
@@ -0,0 +1,13 @@
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
TRANSFORMS: ["random_flip", "center_crop", "normalize"]
|
||||
|
||||
DATASET:
|
||||
NAME: "VisDA17"
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "resnet101"
|
||||
|
||||
TEST:
|
||||
PER_CLASS_RESULT: True
|
||||
14
Dassl.ProGrad.pytorch/configs/datasets/dg/cifar100_c.yaml
Normal file
14
Dassl.ProGrad.pytorch/configs/datasets/dg/cifar100_c.yaml
Normal file
@@ -0,0 +1,14 @@
|
||||
INPUT:
|
||||
SIZE: (32, 32)
|
||||
TRANSFORMS: ["random_flip", "random_crop", "normalize"]
|
||||
PIXEL_MEAN: [0.5, 0.5, 0.5]
|
||||
PIXEL_STD: [0.5, 0.5, 0.5]
|
||||
|
||||
DATASET:
|
||||
NAME: "CIFAR100C"
|
||||
CIFAR_C_TYPE: "fog"
|
||||
CIFAR_C_LEVEL: 5
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "wide_resnet_16_4"
|
||||
14
Dassl.ProGrad.pytorch/configs/datasets/dg/cifar10_c.yaml
Normal file
14
Dassl.ProGrad.pytorch/configs/datasets/dg/cifar10_c.yaml
Normal file
@@ -0,0 +1,14 @@
|
||||
INPUT:
|
||||
SIZE: (32, 32)
|
||||
TRANSFORMS: ["random_flip", "random_crop", "normalize"]
|
||||
PIXEL_MEAN: [0.5, 0.5, 0.5]
|
||||
PIXEL_STD: [0.5, 0.5, 0.5]
|
||||
|
||||
DATASET:
|
||||
NAME: "CIFAR10C"
|
||||
CIFAR_C_TYPE: "fog"
|
||||
CIFAR_C_LEVEL: 5
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "wide_resnet_16_4"
|
||||
12
Dassl.ProGrad.pytorch/configs/datasets/dg/digit_single.yaml
Normal file
12
Dassl.ProGrad.pytorch/configs/datasets/dg/digit_single.yaml
Normal file
@@ -0,0 +1,12 @@
|
||||
INPUT:
|
||||
SIZE: (32, 32)
|
||||
TRANSFORMS: ["normalize"]
|
||||
PIXEL_MEAN: [0.5, 0.5, 0.5]
|
||||
PIXEL_STD: [0.5, 0.5, 0.5]
|
||||
|
||||
DATASET:
|
||||
NAME: "DigitSingle"
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "cnn_digitsingle"
|
||||
12
Dassl.ProGrad.pytorch/configs/datasets/dg/digits_dg.yaml
Normal file
12
Dassl.ProGrad.pytorch/configs/datasets/dg/digits_dg.yaml
Normal file
@@ -0,0 +1,12 @@
|
||||
INPUT:
|
||||
SIZE: (32, 32)
|
||||
TRANSFORMS: ["normalize"]
|
||||
PIXEL_MEAN: [0.5, 0.5, 0.5]
|
||||
PIXEL_STD: [0.5, 0.5, 0.5]
|
||||
|
||||
DATASET:
|
||||
NAME: "DigitsDG"
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "cnn_digitsdg"
|
||||
@@ -0,0 +1,11 @@
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
TRANSFORMS: ["random_flip", "random_translation", "normalize"]
|
||||
|
||||
DATASET:
|
||||
NAME: "OfficeHomeDG"
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "resnet18"
|
||||
PRETRAINED: True
|
||||
11
Dassl.ProGrad.pytorch/configs/datasets/dg/pacs.yaml
Normal file
11
Dassl.ProGrad.pytorch/configs/datasets/dg/pacs.yaml
Normal file
@@ -0,0 +1,11 @@
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
TRANSFORMS: ["random_flip", "random_translation", "normalize"]
|
||||
|
||||
DATASET:
|
||||
NAME: "PACS"
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "resnet18"
|
||||
PRETRAINED: True
|
||||
11
Dassl.ProGrad.pytorch/configs/datasets/dg/vlcs.yaml
Normal file
11
Dassl.ProGrad.pytorch/configs/datasets/dg/vlcs.yaml
Normal file
@@ -0,0 +1,11 @@
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
TRANSFORMS: ["random_flip", "random_translation", "normalize"]
|
||||
|
||||
DATASET:
|
||||
NAME: "VLCS"
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "resnet18"
|
||||
PRETRAINED: True
|
||||
14
Dassl.ProGrad.pytorch/configs/datasets/ssl/cifar10.yaml
Normal file
14
Dassl.ProGrad.pytorch/configs/datasets/ssl/cifar10.yaml
Normal file
@@ -0,0 +1,14 @@
|
||||
INPUT:
|
||||
SIZE: (32, 32)
|
||||
TRANSFORMS: ["random_flip", "random_crop", "normalize"]
|
||||
PIXEL_MEAN: [0.5, 0.5, 0.5]
|
||||
PIXEL_STD: [0.5, 0.5, 0.5]
|
||||
|
||||
DATASET:
|
||||
NAME: "CIFAR10"
|
||||
NUM_LABELED: 4000
|
||||
VAL_PERCENT: 0.
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "wide_resnet_28_2"
|
||||
15
Dassl.ProGrad.pytorch/configs/datasets/ssl/cifar100.yaml
Normal file
15
Dassl.ProGrad.pytorch/configs/datasets/ssl/cifar100.yaml
Normal file
@@ -0,0 +1,15 @@
|
||||
INPUT:
|
||||
SIZE: (32, 32)
|
||||
TRANSFORMS: ["random_flip", "random_crop", "normalize"]
|
||||
PIXEL_MEAN: [0.5, 0.5, 0.5]
|
||||
PIXEL_STD: [0.5, 0.5, 0.5]
|
||||
CROP_PADDING: 4
|
||||
|
||||
DATASET:
|
||||
NAME: "CIFAR100"
|
||||
NUM_LABELED: 10000
|
||||
VAL_PERCENT: 0.
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "wide_resnet_28_2"
|
||||
14
Dassl.ProGrad.pytorch/configs/datasets/ssl/stl10.yaml
Normal file
14
Dassl.ProGrad.pytorch/configs/datasets/ssl/stl10.yaml
Normal file
@@ -0,0 +1,14 @@
|
||||
INPUT:
|
||||
SIZE: (96, 96)
|
||||
TRANSFORMS: ["random_flip", "random_crop", "normalize"]
|
||||
PIXEL_MEAN: [0.5, 0.5, 0.5]
|
||||
PIXEL_STD: [0.5, 0.5, 0.5]
|
||||
CROP_PADDING: 4
|
||||
|
||||
DATASET:
|
||||
NAME: "STL10"
|
||||
STL10_FOLD: 0
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "wide_resnet_28_2"
|
||||
15
Dassl.ProGrad.pytorch/configs/datasets/ssl/svhn.yaml
Normal file
15
Dassl.ProGrad.pytorch/configs/datasets/ssl/svhn.yaml
Normal file
@@ -0,0 +1,15 @@
|
||||
INPUT:
|
||||
SIZE: (32, 32)
|
||||
TRANSFORMS: ["random_crop", "normalize"]
|
||||
PIXEL_MEAN: [0.5, 0.5, 0.5]
|
||||
PIXEL_STD: [0.5, 0.5, 0.5]
|
||||
CROP_PADDING: 4
|
||||
|
||||
DATASET:
|
||||
NAME: "SVHN"
|
||||
NUM_LABELED: 1000
|
||||
VAL_PERCENT: 0.
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "wide_resnet_28_2"
|
||||
20
Dassl.ProGrad.pytorch/configs/trainers/da/dael/digit5.yaml
Normal file
20
Dassl.ProGrad.pytorch/configs/trainers/da/dael/digit5.yaml
Normal file
@@ -0,0 +1,20 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
SAMPLER: "RandomDomainSampler"
|
||||
BATCH_SIZE: 256
|
||||
TRAIN_U:
|
||||
SAME_AS_X: False
|
||||
BATCH_SIZE: 64
|
||||
TEST:
|
||||
BATCH_SIZE: 256
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.05
|
||||
STEPSIZE: [30]
|
||||
MAX_EPOCH: 30
|
||||
LR_SCHEDULER: "cosine"
|
||||
|
||||
TRAINER:
|
||||
DAEL:
|
||||
STRONG_TRANSFORMS: ["randaugment2", "normalize"]
|
||||
@@ -0,0 +1,20 @@
|
||||
DATALOADER:
|
||||
NUM_WORKERS: 4
|
||||
TRAIN_X:
|
||||
SAMPLER: "RandomDomainSampler"
|
||||
BATCH_SIZE: 30
|
||||
TRAIN_U:
|
||||
SAME_AS_X: False
|
||||
BATCH_SIZE: 6
|
||||
TEST:
|
||||
BATCH_SIZE: 30
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.002
|
||||
MAX_EPOCH: 40
|
||||
LR_SCHEDULER: "cosine"
|
||||
|
||||
TRAINER:
|
||||
DAEL:
|
||||
STRONG_TRANSFORMS: ["random_flip", "cutout", "randaugment2", "normalize"]
|
||||
@@ -0,0 +1,20 @@
|
||||
DATALOADER:
|
||||
NUM_WORKERS: 8
|
||||
TRAIN_X:
|
||||
SAMPLER: "RandomDomainSampler"
|
||||
BATCH_SIZE: 192
|
||||
TRAIN_U:
|
||||
SAME_AS_X: False
|
||||
BATCH_SIZE: 64
|
||||
TEST:
|
||||
BATCH_SIZE: 200
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.005
|
||||
MAX_EPOCH: 60
|
||||
LR_SCHEDULER: "cosine"
|
||||
|
||||
TRAINER:
|
||||
DAEL:
|
||||
STRONG_TRANSFORMS: ["random_flip", "cutout", "randaugment2", "normalize"]
|
||||
16
Dassl.ProGrad.pytorch/configs/trainers/da/m3sda/digit5.yaml
Normal file
16
Dassl.ProGrad.pytorch/configs/trainers/da/m3sda/digit5.yaml
Normal file
@@ -0,0 +1,16 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
SAMPLER: "RandomDomainSampler"
|
||||
BATCH_SIZE: 256
|
||||
TRAIN_U:
|
||||
SAME_AS_X: False
|
||||
BATCH_SIZE: 64
|
||||
TEST:
|
||||
BATCH_SIZE: 256
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.05
|
||||
STEPSIZE: [30]
|
||||
MAX_EPOCH: 30
|
||||
LR_SCHEDULER: "cosine"
|
||||
@@ -0,0 +1,16 @@
|
||||
DATALOADER:
|
||||
NUM_WORKERS: 4
|
||||
TRAIN_X:
|
||||
SAMPLER: "RandomDomainSampler"
|
||||
BATCH_SIZE: 30
|
||||
TRAIN_U:
|
||||
SAME_AS_X: False
|
||||
BATCH_SIZE: 6
|
||||
TEST:
|
||||
BATCH_SIZE: 30
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.002
|
||||
MAX_EPOCH: 40
|
||||
LR_SCHEDULER: "cosine"
|
||||
@@ -0,0 +1,16 @@
|
||||
DATALOADER:
|
||||
NUM_WORKERS: 8
|
||||
TRAIN_X:
|
||||
SAMPLER: "RandomDomainSampler"
|
||||
BATCH_SIZE: 192
|
||||
TRAIN_U:
|
||||
SAME_AS_X: False
|
||||
BATCH_SIZE: 64
|
||||
TEST:
|
||||
BATCH_SIZE: 200
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.005
|
||||
MAX_EPOCH: 60
|
||||
LR_SCHEDULER: "cosine"
|
||||
@@ -0,0 +1,12 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 256
|
||||
TEST:
|
||||
BATCH_SIZE: 256
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.05
|
||||
STEPSIZE: [30]
|
||||
MAX_EPOCH: 30
|
||||
LR_SCHEDULER: "cosine"
|
||||
@@ -0,0 +1,12 @@
|
||||
DATALOADER:
|
||||
NUM_WORKERS: 8
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 128
|
||||
TEST:
|
||||
BATCH_SIZE: 128
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.005
|
||||
MAX_EPOCH: 60
|
||||
LR_SCHEDULER: "cosine"
|
||||
@@ -0,0 +1,11 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 32
|
||||
TEST:
|
||||
BATCH_SIZE: 32
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.002
|
||||
STEPSIZE: [20]
|
||||
MAX_EPOCH: 20
|
||||
@@ -0,0 +1,15 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 32
|
||||
TEST:
|
||||
BATCH_SIZE: 32
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.0001
|
||||
STEPSIZE: [2]
|
||||
MAX_EPOCH: 2
|
||||
|
||||
TRAIN:
|
||||
PRINT_FREQ: 50
|
||||
COUNT_ITER: "train_u"
|
||||
@@ -0,0 +1,16 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
SAMPLER: "RandomDomainSampler"
|
||||
BATCH_SIZE: 120
|
||||
TEST:
|
||||
BATCH_SIZE: 100
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.05
|
||||
STEPSIZE: [20]
|
||||
MAX_EPOCH: 50
|
||||
|
||||
TRAINER:
|
||||
DAEL:
|
||||
STRONG_TRANSFORMS: ["randaugment2", "normalize"]
|
||||
@@ -0,0 +1,16 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
SAMPLER: "RandomDomainSampler"
|
||||
BATCH_SIZE: 30
|
||||
TEST:
|
||||
BATCH_SIZE: 100
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.002
|
||||
MAX_EPOCH: 40
|
||||
LR_SCHEDULER: "cosine"
|
||||
|
||||
TRAINER:
|
||||
DAEL:
|
||||
STRONG_TRANSFORMS: ["random_flip", "cutout", "randaugment2", "normalize"]
|
||||
16
Dassl.ProGrad.pytorch/configs/trainers/dg/dael/pacs.yaml
Normal file
16
Dassl.ProGrad.pytorch/configs/trainers/dg/dael/pacs.yaml
Normal file
@@ -0,0 +1,16 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
SAMPLER: "RandomDomainSampler"
|
||||
BATCH_SIZE: 30
|
||||
TEST:
|
||||
BATCH_SIZE: 100
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.002
|
||||
MAX_EPOCH: 40
|
||||
LR_SCHEDULER: "cosine"
|
||||
|
||||
TRAINER:
|
||||
DAEL:
|
||||
STRONG_TRANSFORMS: ["random_flip", "cutout", "randaugment2", "normalize"]
|
||||
@@ -0,0 +1,20 @@
|
||||
INPUT:
|
||||
PIXEL_MEAN: [0., 0., 0.]
|
||||
PIXEL_STD: [1., 1., 1.]
|
||||
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 128
|
||||
TEST:
|
||||
BATCH_SIZE: 128
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.05
|
||||
STEPSIZE: [20]
|
||||
MAX_EPOCH: 50
|
||||
|
||||
TRAINER:
|
||||
DDAIG:
|
||||
G_ARCH: "fcn_3x32_gctx"
|
||||
LMDA: 0.3
|
||||
@@ -0,0 +1,21 @@
|
||||
INPUT:
|
||||
PIXEL_MEAN: [0., 0., 0.]
|
||||
PIXEL_STD: [1., 1., 1.]
|
||||
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 16
|
||||
TEST:
|
||||
BATCH_SIZE: 16
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.0005
|
||||
STEPSIZE: [20]
|
||||
MAX_EPOCH: 25
|
||||
|
||||
TRAINER:
|
||||
DDAIG:
|
||||
G_ARCH: "fcn_3x64_gctx"
|
||||
WARMUP: 3
|
||||
LMDA: 0.3
|
||||
21
Dassl.ProGrad.pytorch/configs/trainers/dg/ddaig/pacs.yaml
Normal file
21
Dassl.ProGrad.pytorch/configs/trainers/dg/ddaig/pacs.yaml
Normal file
@@ -0,0 +1,21 @@
|
||||
INPUT:
|
||||
PIXEL_MEAN: [0., 0., 0.]
|
||||
PIXEL_STD: [1., 1., 1.]
|
||||
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 16
|
||||
TEST:
|
||||
BATCH_SIZE: 16
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.0005
|
||||
STEPSIZE: [20]
|
||||
MAX_EPOCH: 25
|
||||
|
||||
TRAINER:
|
||||
DDAIG:
|
||||
G_ARCH: "fcn_3x64_gctx"
|
||||
WARMUP: 3
|
||||
LMDA: 0.3
|
||||
@@ -0,0 +1,15 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 128
|
||||
TEST:
|
||||
BATCH_SIZE: 100
|
||||
NUM_WORKERS: 8
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.05
|
||||
STEPSIZE: [20]
|
||||
MAX_EPOCH: 50
|
||||
|
||||
TRAIN:
|
||||
PRINT_FREQ: 20
|
||||
@@ -0,0 +1,12 @@
|
||||
DATALOADER:
|
||||
NUM_WORKERS: 8
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 128
|
||||
TEST:
|
||||
BATCH_SIZE: 128
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.005
|
||||
MAX_EPOCH: 60
|
||||
LR_SCHEDULER: "cosine"
|
||||
@@ -0,0 +1,12 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 64
|
||||
TEST:
|
||||
BATCH_SIZE: 100
|
||||
NUM_WORKERS: 8
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.001
|
||||
MAX_EPOCH: 50
|
||||
LR_SCHEDULER: "cosine"
|
||||
12
Dassl.ProGrad.pytorch/configs/trainers/dg/vanilla/pacs.yaml
Normal file
12
Dassl.ProGrad.pytorch/configs/trainers/dg/vanilla/pacs.yaml
Normal file
@@ -0,0 +1,12 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 64
|
||||
TEST:
|
||||
BATCH_SIZE: 100
|
||||
NUM_WORKERS: 8
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.001
|
||||
MAX_EPOCH: 50
|
||||
LR_SCHEDULER: "cosine"
|
||||
@@ -0,0 +1,23 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 64
|
||||
TRAIN_U:
|
||||
SAME_AS_X: False
|
||||
BATCH_SIZE: 448
|
||||
TEST:
|
||||
BATCH_SIZE: 500
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.05
|
||||
STEPSIZE: [4000]
|
||||
MAX_EPOCH: 4000
|
||||
LR_SCHEDULER: "cosine"
|
||||
|
||||
TRAIN:
|
||||
COUNT_ITER: "train_u"
|
||||
PRINT_FREQ: 10
|
||||
|
||||
TRAINER:
|
||||
FIXMATCH:
|
||||
STRONG_TRANSFORMS: ["random_flip", "randaugment_fixmatch", "normalize", "cutout"]
|
||||
Reference in New Issue
Block a user