release code
This commit is contained in:
211
.gitignore
vendored
Normal file
211
.gitignore
vendored
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[codz]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py.cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# UV
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
#uv.lock
|
||||||
|
|
||||||
|
# poetry
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||||
|
#poetry.lock
|
||||||
|
#poetry.toml
|
||||||
|
|
||||||
|
# pdm
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||||
|
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
||||||
|
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
||||||
|
#pdm.lock
|
||||||
|
#pdm.toml
|
||||||
|
.pdm-python
|
||||||
|
.pdm-build/
|
||||||
|
|
||||||
|
# pixi
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
||||||
|
#pixi.lock
|
||||||
|
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
||||||
|
# in the .venv directory. It is recommended not to include this directory in version control.
|
||||||
|
.pixi
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.envrc
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||||
|
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
|
#.idea/
|
||||||
|
|
||||||
|
# Abstra
|
||||||
|
# Abstra is an AI-powered process automation framework.
|
||||||
|
# Ignore directories containing user credentials, local state, and settings.
|
||||||
|
# Learn more at https://abstra.io/docs
|
||||||
|
.abstra/
|
||||||
|
|
||||||
|
# Visual Studio Code
|
||||||
|
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
||||||
|
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
||||||
|
# you could uncomment the following to ignore the entire vscode folder
|
||||||
|
# .vscode/
|
||||||
|
|
||||||
|
# Ruff stuff:
|
||||||
|
.ruff_cache/
|
||||||
|
|
||||||
|
# PyPI configuration file
|
||||||
|
.pypirc
|
||||||
|
|
||||||
|
# Cursor
|
||||||
|
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
|
||||||
|
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
|
||||||
|
# refer to https://docs.cursor.com/context/ignore-files
|
||||||
|
.cursorignore
|
||||||
|
.cursorindexingignore
|
||||||
|
|
||||||
|
# Marimo
|
||||||
|
marimo/_static/
|
||||||
|
marimo/_lsp/
|
||||||
|
__marimo__/
|
||||||
|
|
||||||
|
|
||||||
|
# Custom
|
||||||
|
**/output*
|
||||||
22
Dassl.ProGrad.pytorch/.flake8
Normal file
22
Dassl.ProGrad.pytorch/.flake8
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
[flake8]
|
||||||
|
ignore =
|
||||||
|
# At least two spaces before inline comment
|
||||||
|
E261,
|
||||||
|
# Line lengths are recommended to be no greater than 79 characters
|
||||||
|
E501,
|
||||||
|
# Missing whitespace around arithmetic operator
|
||||||
|
E226,
|
||||||
|
# Blank line contains whitespace
|
||||||
|
W293,
|
||||||
|
# Do not use bare 'except'
|
||||||
|
E722,
|
||||||
|
# Line break after binary operator
|
||||||
|
W504,
|
||||||
|
# Too many leading '#' for block comment
|
||||||
|
E266,
|
||||||
|
# line break before binary operator
|
||||||
|
W503,
|
||||||
|
# continuation line over-indented for hanging indent
|
||||||
|
E126
|
||||||
|
max-line-length = 79
|
||||||
|
exclude = __init__.py, build
|
||||||
140
Dassl.ProGrad.pytorch/.gitignore
vendored
Normal file
140
Dassl.ProGrad.pytorch/.gitignore
vendored
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
pip-wheel-metadata/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# OS X
|
||||||
|
.DS_Store
|
||||||
|
.Spotlight-V100
|
||||||
|
.Trashes
|
||||||
|
._*
|
||||||
|
|
||||||
|
# This project
|
||||||
|
output/
|
||||||
|
debug.sh
|
||||||
|
debug.py
|
||||||
10
Dassl.ProGrad.pytorch/.isort.cfg
Normal file
10
Dassl.ProGrad.pytorch/.isort.cfg
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
[isort]
|
||||||
|
line_length=79
|
||||||
|
multi_line_output=6
|
||||||
|
length_sort=true
|
||||||
|
known_standard_library=numpy,setuptools
|
||||||
|
known_myself=dassl
|
||||||
|
known_third_party=matplotlib,cv2,torch,torchvision,PIL,yacs,scipy,gdown
|
||||||
|
no_lines_before=STDLIB,THIRDPARTY
|
||||||
|
sections=FUTURE,STDLIB,THIRDPARTY,myself,FIRSTPARTY,LOCALFOLDER
|
||||||
|
default_section=FIRSTPARTY
|
||||||
7
Dassl.ProGrad.pytorch/.style.yapf
Normal file
7
Dassl.ProGrad.pytorch/.style.yapf
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
[style]
|
||||||
|
BASED_ON_STYLE = pep8
|
||||||
|
BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true
|
||||||
|
SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true
|
||||||
|
DEDENT_CLOSING_BRACKETS = true
|
||||||
|
SPACES_BEFORE_COMMENT = 2
|
||||||
|
ARITHMETIC_PRECEDENCE_INDICATION = true
|
||||||
313
Dassl.ProGrad.pytorch/DATASETS.md
Normal file
313
Dassl.ProGrad.pytorch/DATASETS.md
Normal file
@@ -0,0 +1,313 @@
|
|||||||
|
# How to Install Datasets
|
||||||
|
|
||||||
|
`$DATA` denotes the location where datasets are installed, e.g.
|
||||||
|
|
||||||
|
```
|
||||||
|
$DATA/
|
||||||
|
|–– office31/
|
||||||
|
|–– office_home/
|
||||||
|
|–– visda17/
|
||||||
|
```
|
||||||
|
|
||||||
|
[Domain Adaptation](#domain-adaptation)
|
||||||
|
- [Office-31](#office-31)
|
||||||
|
- [Office-Home](#office-home)
|
||||||
|
- [VisDA17](#visda17)
|
||||||
|
- [CIFAR10-STL10](#cifar10-stl10)
|
||||||
|
- [Digit-5](#digit-5)
|
||||||
|
- [DomainNet](#domainnet)
|
||||||
|
- [miniDomainNet](#miniDomainNet)
|
||||||
|
|
||||||
|
[Domain Generalization](#domain-generalization)
|
||||||
|
- [PACS](#pacs)
|
||||||
|
- [VLCS](#vlcs)
|
||||||
|
- [Office-Home-DG](#office-home-dg)
|
||||||
|
- [Digits-DG](#digits-dg)
|
||||||
|
- [Digit-Single](#digit-single)
|
||||||
|
- [CIFAR-10-C](#cifar-10-c)
|
||||||
|
- [CIFAR-100-C](#cifar-100-c)
|
||||||
|
|
||||||
|
[Semi-Supervised Learning](#semi-supervised-learning)
|
||||||
|
- [CIFAR10/100 and SVHN](#cifar10100-and-svhn)
|
||||||
|
- [STL10](#stl10)
|
||||||
|
|
||||||
|
## Domain Adaptation
|
||||||
|
|
||||||
|
### Office-31
|
||||||
|
|
||||||
|
Download link: https://people.eecs.berkeley.edu/~jhoffman/domainadapt/#datasets_code.
|
||||||
|
|
||||||
|
File structure:
|
||||||
|
|
||||||
|
```
|
||||||
|
office31/
|
||||||
|
|–– amazon/
|
||||||
|
| |–– back_pack/
|
||||||
|
| |–– bike/
|
||||||
|
| |–– ...
|
||||||
|
|–– dslr/
|
||||||
|
| |–– back_pack/
|
||||||
|
| |–– bike/
|
||||||
|
| |–– ...
|
||||||
|
|–– webcam/
|
||||||
|
| |–– back_pack/
|
||||||
|
| |–– bike/
|
||||||
|
| |–– ...
|
||||||
|
```
|
||||||
|
|
||||||
|
Note that within each domain folder you need to move all class folders out of the `images/` folder and then delete the `images/` folder.
|
||||||
|
|
||||||
|
### Office-Home
|
||||||
|
|
||||||
|
Download link: http://hemanthdv.org/OfficeHome-Dataset/.
|
||||||
|
|
||||||
|
File structure:
|
||||||
|
|
||||||
|
```
|
||||||
|
office_home/
|
||||||
|
|–– art/
|
||||||
|
|–– clipart/
|
||||||
|
|–– product/
|
||||||
|
|–– real_world/
|
||||||
|
```
|
||||||
|
|
||||||
|
### VisDA17
|
||||||
|
|
||||||
|
Download link: http://ai.bu.edu/visda-2017/.
|
||||||
|
|
||||||
|
The dataset can also be downloaded using our script at `datasets/da/visda17.sh`. Run the following command in your terminal under `Dassl.pytorch/datasets/da`,
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sh visda17.sh $DATA
|
||||||
|
```
|
||||||
|
|
||||||
|
Once the download is finished, the file structure will look like
|
||||||
|
|
||||||
|
```
|
||||||
|
visda17/
|
||||||
|
|–– train/
|
||||||
|
|–– test/
|
||||||
|
|–– validation/
|
||||||
|
```
|
||||||
|
|
||||||
|
### CIFAR10-STL10
|
||||||
|
|
||||||
|
Run the following command in your terminal under `Dassl.pytorch/datasets/da`,
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python cifar_stl.py $DATA/cifar_stl
|
||||||
|
```
|
||||||
|
|
||||||
|
This will create a folder named `cifar_stl` under `$DATA`. The file structure will look like
|
||||||
|
|
||||||
|
```
|
||||||
|
cifar_stl/
|
||||||
|
|–– cifar/
|
||||||
|
| |–– train/
|
||||||
|
| |–– test/
|
||||||
|
|–– stl/
|
||||||
|
| |–– train/
|
||||||
|
| |–– test/
|
||||||
|
```
|
||||||
|
|
||||||
|
Note that only 9 classes shared by both datasets are kept.
|
||||||
|
|
||||||
|
### Digit-5
|
||||||
|
|
||||||
|
Create a folder `$DATA/digit5` and download to this folder the dataset from [here](https://github.com/VisionLearningGroup/VisionLearningGroup.github.io/tree/master/M3SDA/code_MSDA_digit#digit-five-download). This should give you
|
||||||
|
|
||||||
|
```
|
||||||
|
digit5/
|
||||||
|
|–– Digit-Five/
|
||||||
|
```
|
||||||
|
|
||||||
|
Then, run the following command in your terminal under `Dassl.pytorch/datasets/da`,
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python digit5.py $DATA/digit5
|
||||||
|
```
|
||||||
|
|
||||||
|
This will extract the data and organize the file structure as
|
||||||
|
|
||||||
|
```
|
||||||
|
digit5/
|
||||||
|
|–– Digit-Five/
|
||||||
|
|–– mnist/
|
||||||
|
|–– mnist_m/
|
||||||
|
|–– usps/
|
||||||
|
|–– svhn/
|
||||||
|
|–– syn/
|
||||||
|
```
|
||||||
|
|
||||||
|
### DomainNet
|
||||||
|
|
||||||
|
Download link: http://ai.bu.edu/M3SDA/. (Please download the cleaned version of split files)
|
||||||
|
|
||||||
|
File structure:
|
||||||
|
|
||||||
|
```
|
||||||
|
domainnet/
|
||||||
|
|–– clipart/
|
||||||
|
|–– infograph/
|
||||||
|
|–– painting/
|
||||||
|
|–– quickdraw/
|
||||||
|
|–– real/
|
||||||
|
|–– sketch/
|
||||||
|
|–– splits/
|
||||||
|
| |–– clipart_train.txt
|
||||||
|
| |–– clipart_test.txt
|
||||||
|
| |–– ...
|
||||||
|
```
|
||||||
|
|
||||||
|
### miniDomainNet
|
||||||
|
|
||||||
|
You need to download the DomainNet dataset first. The miniDomainNet's split files can be downloaded at this [google drive](https://drive.google.com/open?id=15rrLDCrzyi6ZY-1vJar3u7plgLe4COL7). After the zip file is extracted, you should have the folder `$DATA/domainnet/splits_mini/`.
|
||||||
|
|
||||||
|
## Domain Generalization
|
||||||
|
|
||||||
|
### PACS
|
||||||
|
|
||||||
|
Download link: [google drive](https://drive.google.com/open?id=1m4X4fROCCXMO0lRLrr6Zz9Vb3974NWhE).
|
||||||
|
|
||||||
|
File structure:
|
||||||
|
|
||||||
|
```
|
||||||
|
pacs/
|
||||||
|
|–– images/
|
||||||
|
|–– splits/
|
||||||
|
```
|
||||||
|
|
||||||
|
You do not necessarily have to manually download this dataset. Once you run ``tools/train.py``, the code will detect if the dataset exists or not and automatically download the dataset to ``$DATA`` if missing. This also applies to VLCS, Office-Home-DG, and Digits-DG.
|
||||||
|
|
||||||
|
### VLCS
|
||||||
|
|
||||||
|
Download link: [google drive](https://drive.google.com/file/d/1r0WL5DDqKfSPp9E3tRENwHaXNs1olLZd/view?usp=sharing) (credit to https://github.com/fmcarlucci/JigenDG#vlcs)
|
||||||
|
|
||||||
|
File structure:
|
||||||
|
|
||||||
|
```
|
||||||
|
VLCS/
|
||||||
|
|–– CALTECH/
|
||||||
|
|–– LABELME/
|
||||||
|
|–– PASCAL/
|
||||||
|
|–– SUN/
|
||||||
|
```
|
||||||
|
|
||||||
|
### Office-Home-DG
|
||||||
|
|
||||||
|
Download link: [google drive](https://drive.google.com/open?id=1gkbf_KaxoBws-GWT3XIPZ7BnkqbAxIFa).
|
||||||
|
|
||||||
|
File structure:
|
||||||
|
|
||||||
|
```
|
||||||
|
office_home_dg/
|
||||||
|
|–– art/
|
||||||
|
|–– clipart/
|
||||||
|
|–– product/
|
||||||
|
|–– real_world/
|
||||||
|
```
|
||||||
|
|
||||||
|
### Digits-DG
|
||||||
|
|
||||||
|
Download link: [google driv](https://drive.google.com/open?id=15V7EsHfCcfbKgsDmzQKj_DfXt_XYp_P7).
|
||||||
|
|
||||||
|
File structure:
|
||||||
|
|
||||||
|
```
|
||||||
|
digits_dg/
|
||||||
|
|–– mnist/
|
||||||
|
|–– mnist_m/
|
||||||
|
|–– svhn/
|
||||||
|
|–– syn/
|
||||||
|
```
|
||||||
|
|
||||||
|
### Digit-Single
|
||||||
|
Follow the steps for [Digit-5](#digit-5) to organize the dataset.
|
||||||
|
|
||||||
|
### CIFAR-10-C
|
||||||
|
|
||||||
|
First download the CIFAR-10-C dataset from https://zenodo.org/record/2535967#.YFxHEWQzb0o to, e.g., $DATA, and extract the file under the same directory. Then, navigate to `Dassl.pytorch/datasets/dg` and run the following command in your terminal
|
||||||
|
```bash
|
||||||
|
python cifar_c.py $DATA/CIFAR-10-C
|
||||||
|
```
|
||||||
|
where the first argument denotes the path to the (uncompressed) CIFAR-10-C dataset.
|
||||||
|
|
||||||
|
The script will extract images from the `.npy` files and save them to `cifar10_c/` created under $DATA. The file structure will look like
|
||||||
|
```
|
||||||
|
cifar10_c/
|
||||||
|
|–– brightness/
|
||||||
|
| |–– 1/ # 5 intensity levels in total
|
||||||
|
| |–– 2/
|
||||||
|
| |–– 3/
|
||||||
|
| |–– 4/
|
||||||
|
| |–– 5/
|
||||||
|
|–– ... # 19 corruption types in total
|
||||||
|
```
|
||||||
|
|
||||||
|
Note that `cifar10_c/` only contains the test images. The training images are the normal CIFAR-10 images. See [CIFAR10/100 and SVHN](#cifar10100-and-svhn) for how to prepare the CIFAR-10 dataset.
|
||||||
|
|
||||||
|
### CIFAR-100-C
|
||||||
|
|
||||||
|
First download the CIFAR-100-C dataset from https://zenodo.org/record/3555552#.YFxpQmQzb0o to, e.g., $DATA, and extract the file under the same directory. Then, navigate to `Dassl.pytorch/datasets/dg` and run the following command in your terminal
|
||||||
|
```bash
|
||||||
|
python cifar_c.py $DATA/CIFAR-100-C
|
||||||
|
```
|
||||||
|
where the first argument denotes the path to the (uncompressed) CIFAR-100-C dataset.
|
||||||
|
|
||||||
|
The script will extract images from the `.npy` files and save them to `cifar100_c/` created under $DATA. The file structure will look like
|
||||||
|
```
|
||||||
|
cifar100_c/
|
||||||
|
|–– brightness/
|
||||||
|
| |–– 1/ # 5 intensity levels in total
|
||||||
|
| |–– 2/
|
||||||
|
| |–– 3/
|
||||||
|
| |–– 4/
|
||||||
|
| |–– 5/
|
||||||
|
|–– ... # 19 corruption types in total
|
||||||
|
```
|
||||||
|
|
||||||
|
Note that `cifar100_c/` only contains the test images. The training images are the normal CIFAR-100 images. See [CIFAR10/100 and SVHN](#cifar10100-and-svhn) for how to prepare the CIFAR-100 dataset.
|
||||||
|
|
||||||
|
## Semi-Supervised Learning
|
||||||
|
|
||||||
|
### CIFAR10/100 and SVHN
|
||||||
|
|
||||||
|
Run the following command in your terminal under `Dassl.pytorch/datasets/ssl`,
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python cifar10_cifar100_svhn.py $DATA
|
||||||
|
```
|
||||||
|
|
||||||
|
This will create three folders under `$DATA`, i.e.
|
||||||
|
|
||||||
|
```
|
||||||
|
cifar10/
|
||||||
|
|–– train/
|
||||||
|
|–– test/
|
||||||
|
cifar100/
|
||||||
|
|–– train/
|
||||||
|
|–– test/
|
||||||
|
svhn/
|
||||||
|
|–– train/
|
||||||
|
|–– test/
|
||||||
|
```
|
||||||
|
|
||||||
|
### STL10
|
||||||
|
|
||||||
|
Run the following command in your terminal under `Dassl.pytorch/datasets/ssl`,
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python stl10.py $DATA/stl10
|
||||||
|
```
|
||||||
|
|
||||||
|
This will create a folder named `stl10` under `$DATA` and extract the data into three folders, i.e. `train`, `test` and `unlabeled`. Then, download from http://ai.stanford.edu/~acoates/stl10/ the "Binary files" and extract it under `stl10`.
|
||||||
|
|
||||||
|
The file structure will look like
|
||||||
|
|
||||||
|
```
|
||||||
|
stl10/
|
||||||
|
|–– train/
|
||||||
|
|–– test/
|
||||||
|
|–– unlabeled/
|
||||||
|
|–– stl10_binary/
|
||||||
|
```
|
||||||
21
Dassl.ProGrad.pytorch/LICENSE
Normal file
21
Dassl.ProGrad.pytorch/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2020 Kaiyang
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
279
Dassl.ProGrad.pytorch/README.md
Normal file
279
Dassl.ProGrad.pytorch/README.md
Normal file
@@ -0,0 +1,279 @@
|
|||||||
|
# Dassl
|
||||||
|
|
||||||
|
## Introduction
|
||||||
|
|
||||||
|
Dassl is a [PyTorch](https://pytorch.org) toolbox initially developed for our project [Domain Adaptive Ensemble Learning (DAEL)](https://arxiv.org/abs/2003.07325) to support research in domain adaptation and generalization---since in DAEL we study how to unify these two problems in a single learning framework. Given that domain adaptation is closely related to semi-supervised learning---both study how to exploit unlabeled data---we also incorporate components that support research for the latter.
|
||||||
|
|
||||||
|
Why the name "Dassl"? Dassl combines the initials of domain adaptation (DA) and semi-supervised learning (SSL), which sounds natural and informative.
|
||||||
|
|
||||||
|
Dassl has a modular design and unified interfaces, allowing fast prototyping and experimentation of new DA/DG/SSL methods. With Dassl, a new method can be implemented with only a few lines of code. Don't believe? Take a look at the [engine](https://github.com/KaiyangZhou/Dassl.pytorch/tree/master/dassl/engine) folder, which contains the implementations of many existing methods (then you will come back and star this repo). :-)
|
||||||
|
|
||||||
|
Basically, Dassl is perfect for doing research in the following areas:
|
||||||
|
- Domain adaptation
|
||||||
|
- Domain generalization
|
||||||
|
- Semi-supervised learning
|
||||||
|
|
||||||
|
BUT, thanks to the neat design, Dassl can also be used as a codebase to develop any deep learning projects, like [this](https://github.com/KaiyangZhou/CoOp). :-)
|
||||||
|
|
||||||
|
A drawback of Dassl is that it doesn't (yet? hmm) support distributed multi-GPU training (Dassl uses `DataParallel` to wrap a model, which is less efficient than `DistributedDataParallel`).
|
||||||
|
|
||||||
|
We don't provide detailed documentations for Dassl, unlike another [project](https://kaiyangzhou.github.io/deep-person-reid/) of ours. This is because Dassl is developed for research purpose and as a researcher, we think it's important to be able to read source code and we highly encourage you to do so---definitely not because we are lazy. :-)
|
||||||
|
|
||||||
|
## What's new
|
||||||
|
- Mar 2022: A new domain generalization method [EFDM](https://arxiv.org/abs/2203.07740) developed by [Yabin Zhang (PolyU)](https://ybzh.github.io/) and to appear at CVPR'22 is added to this repo. See [here](https://github.com/KaiyangZhou/Dassl.pytorch/pull/36) for more details.
|
||||||
|
- Feb 2022: In case you don't know, a class in the painting domain of DomainNet (the official splits) only has test images (no training images), which could affect performance. See section 4.a in our [paper](https://arxiv.org/abs/2003.07325) for more details.
|
||||||
|
- Oct 2021: `v0.5.0`: **Important changes** made to `transforms.py`. 1) `center_crop` becomes a default transform in testing (applied after resizing the smaller edge to a certain size to keep the image aspect ratio). 2) For training, `Resize(cfg.INPUT.SIZE)` is deactivated when `random_crop` or `random_resized_crop` is used. These changes won't make any difference to the training transforms used in existing config files, nor to the testing transforms unless the raw images are not squared (the only difference is that now the image aspect ratio is respected).
|
||||||
|
- Oct 2021: `v0.4.3`: Copy the attributes in `self.dm` (data manager) to `SimpleTrainer` and make `self.dm` optional, which means from now on, you can build data loaders from any source you like rather than being forced to use `DataManager`.
|
||||||
|
- Sep 2021: `v0.4.2`: An important update is to set `drop_last=is_train and len(data_source)>=batch_size` when constructing a data loader to avoid 0-length.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>More</summary>
|
||||||
|
|
||||||
|
- Aug 2021: `v0.4.0`: The most noteworthy update is adding the learning rate warmup scheduler. The implementation is detailed [here](https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/optim/lr_scheduler.py#L10) and the config variables are specified [here](https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/config/defaults.py#L171).
|
||||||
|
- Jul 2021: `v0.3.4`: Adds a new function `generate_fewshot_dataset()` to the base dataset class, which allows for the generation of a few-shot learning setting. One can customize a few-shot dataset by specifying `_C.DATASET.NUM_SHOTS` and give it to `generate_fewshot_dataset()`.
|
||||||
|
- Jul 2021: `v0.3.2`: Adds `_C.INPUT.INTERPOLATION` (default: `bilinear`). Available interpolation modes are `bilinear`, `nearest`, and `bicubic`.
|
||||||
|
- Jul 2021 `v0.3.1`: Now you can use `*.register(force=True)` to replace previously registered modules.
|
||||||
|
- Jul 2021 `v0.3.0`: Allows to deploy the model with the best validation performance for final test (for the purpose of model selection). Specifically, a new config variable named `_C.TEST.FINAL_MODEL` is introduced, which takes either `"last_step"` (default) or `"best_val"`. When set to `"best_val"`, the model will be evaluated on the `val` set after each epoch and the one with the best validation performance will be saved and used for final test (see this [code](https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/engine/trainer.py#L412)).
|
||||||
|
- Jul 2021 `v0.2.7`: Adds attribute `classnames` to the base dataset class. Now you can get a list of class names ordered by numeric labels by calling `trainer.dm.dataset.classnames`.
|
||||||
|
- Jun 2021 `v0.2.6`: Merges `MixStyle2` to `MixStyle`. A new variable `self.mix` is used to switch between random mixing and cross-domain mixing. Please see [this](https://github.com/KaiyangZhou/Dassl.pytorch/issues/23) for more details on the new features.
|
||||||
|
- Jun 2021 `v0.2.5`: Fixs a [bug](https://github.com/KaiyangZhou/Dassl.pytorch/commit/29881c7faee7405f80f5f674de4bbbf80d5dc77a) in the calculation of per-class recognition accuracy.
|
||||||
|
- Jun 2021 `v0.2.4`: Adds `extend_cfg(cfg)` to `train.py`. This function is particularly useful when you build your own methods on top of Dassl.pytorch and need to define some custom variables. Please see the repository [mixstyle-release](https://github.com/KaiyangZhou/mixstyle-release) or [ssdg-benchmark](https://github.com/KaiyangZhou/ssdg-benchmark) for examples.
|
||||||
|
- Jun 2021 New benchmarks for semi-supervised domain generalization at https://github.com/KaiyangZhou/ssdg-benchmark.
|
||||||
|
- Apr 2021 Do you know you can use `tools/parse_test_res.py` to read the log files and automatically calculate and print out the results including mean and standard deviation? Check the instructions in `tools/parse_test_res.py` for more details.
|
||||||
|
- Apr 2021 `v0.2.3`: A [MixStyle](https://openreview.net/forum?id=6xHJ37MVxxp) layer can now be deactivated or activated by using `model.apply(deactivate_mixstyle)` or `model.apply(activate_mixstyle)` without modifying the source code. See [dassl/modeling/ops/mixstyle.py](dassl/modeling/ops/mixstyle.py) for the details.
|
||||||
|
- Apr 2021 `v0.2.2`: Adds `RandomClassSampler`, which samples from a certain number of classes a certain number of images to form a minibatch (the code is modified from [Torchreid](https://github.com/KaiyangZhou/deep-person-reid)).
|
||||||
|
- Apr 2021 `v0.2.1`: Slightly adjusts the ordering in `setup_cfg()` (see `tools/train.py`).
|
||||||
|
- Apr 2021 `v0.2.0`: Adds `_C.DATASET.ALL_AS_UNLABELED` (for the SSL setting) to the config variable list. When this variable is set to `True`, all labeled data will be included in the unlabeled data set.
|
||||||
|
- Apr 2021 `v0.1.9`: Adds [VLCS](https://people.csail.mit.edu/torralba/publications/datasets_cvpr11.pdf) to the benchmark datasets (see `dassl/data/datasets/dg/vlcs.py`).
|
||||||
|
- Mar 2021 `v0.1.8`: Allows `optim` and `sched` to be `None` in `register_model()`.
|
||||||
|
- Mar 2021 `v0.1.7`: Adds [MixStyle](https://openreview.net/forum?id=6xHJ37MVxxp) models to [dassl/modeling/backbone/resnet.py](dassl/modeling/backbone/resnet.py). The training configs in `configs/trainers/dg/vanilla` can be used to train MixStyle models.
|
||||||
|
- Mar 2021 `v0.1.6`: Adds [CIFAR-10/100-C](https://arxiv.org/abs/1807.01697) to the benchmark datasets for evaluating a model's robustness to image corruptions.
|
||||||
|
- Mar 2021 We have just released a survey on domain generalization at https://arxiv.org/abs/2103.02503, which summarizes the ten-year development in this topic with coverage on the history, related problems, datasets, methodologies, potential directions, and so on.
|
||||||
|
- Jan 2021 Our recent work, [MixStyle](https://openreview.net/forum?id=6xHJ37MVxxp) (mixing instance-level feature statistics of samples of different domains for improving domain generalization), is accepted to ICLR'21. The code is available at https://github.com/KaiyangZhou/mixstyle-release where the cross-domain image classification part is based on Dassl.pytorch.
|
||||||
|
- May 2020 `v0.1.3`: Adds the `Digit-Single` dataset for benchmarking single-source DG methods. The corresponding CNN model is [dassl/modeling/backbone/cnn_digitsingle.py](dassl/modeling/backbone/cnn_digitsingle.py) and the dataset config file is [configs/datasets/dg/digit_single.yaml](configs/datasets/dg/digit_single.yaml). See [Volpi et al. NIPS'18](https://arxiv.org/abs/1805.12018) for how to do evaluation.
|
||||||
|
- May 2020 `v0.1.2`: 1) Adds [EfficientNet](https://arxiv.org/abs/1905.11946) models (B0-B7) (credit to https://github.com/lukemelas/EfficientNet-PyTorch). To use EfficientNet, set `MODEL.BACKBONE.NAME` to `efficientnet_b{N}` where `N={0, ..., 7}`. 2) `dassl/modeling/models` is renamed to `dassl/modeling/network` (`build_model()` to `build_network()` and `MODEL_REGISTRY` to `NETWORK_RESIGTRY`).
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Dassl has implemented the following methods:
|
||||||
|
|
||||||
|
- Single-source domain adaptation
|
||||||
|
- [Semi-supervised Domain Adaptation via Minimax Entropy (ICCV'19)](https://arxiv.org/abs/1904.06487) [[dassl/engine/da/mme.py](dassl/engine/da/mme.py)]
|
||||||
|
- [Maximum Classifier Discrepancy for Unsupervised Domain Adaptation (CVPR'18)](https://arxiv.org/abs/1712.02560https://arxiv.org/abs/1712.02560) [[dassl/engine/da/mcd.py](dassl/engine/da/mcd.py)]
|
||||||
|
- [Self-ensembling for visual domain adaptation (ICLR'18)](https://arxiv.org/abs/1706.05208) [[dassl/engine/da/self_ensembling.py](dassl/engine/da/self_ensembling.py)]
|
||||||
|
- [Revisiting Batch Normalization For Practical Domain Adaptation (ICLR-W'17)](https://arxiv.org/abs/1603.04779) [[dassl/engine/da/adabn.py](dassl/engine/da/adabn.py)]
|
||||||
|
- [Adversarial Discriminative Domain Adaptation (CVPR'17)](https://arxiv.org/abs/1702.05464) [[dassl/engine/da/adda.py](dassl/engine/da/adda.py)]
|
||||||
|
- [Domain-Adversarial Training of Neural Networks (JMLR'16) ](https://arxiv.org/abs/1505.07818) [[dassl/engine/da/dann.py](dassl/engine/da/dann.py)]
|
||||||
|
|
||||||
|
- Multi-source domain adaptation
|
||||||
|
- [Domain Aadaptive Ensemble Learning](https://arxiv.org/abs/2003.07325) [[dassl/engine/da/dael.py](dassl/engine/da/dael.py)]
|
||||||
|
- [Moment Matching for Multi-Source Domain Adaptation (ICCV'19)](https://arxiv.org/abs/1812.01754) [[dassl/engine/da/m3sda.py](dassl/engine/da/m3sda.py)]
|
||||||
|
|
||||||
|
- Domain generalization
|
||||||
|
- [Exact Feature Distribution Matching for Arbitrary Style Transfer and Domain Generalization (CVPR'22)](https://arxiv.org/abs/2203.07740) [[dassl/modeling/ops/efdmix.py](dassl/modeling/ops/efdmix.py)]
|
||||||
|
- [Domain Generalization with MixStyle (ICLR'21)](https://openreview.net/forum?id=6xHJ37MVxxp) [[dassl/modeling/ops/mixstyle.py](dassl/modeling/ops/mixstyle.py)]
|
||||||
|
- [Deep Domain-Adversarial Image Generation for Domain Generalisation (AAAI'20)](https://arxiv.org/abs/2003.06054) [[dassl/engine/dg/ddaig.py](dassl/engine/dg/ddaig.py)]
|
||||||
|
- [Generalizing Across Domains via Cross-Gradient Training (ICLR'18)](https://arxiv.org/abs/1804.10745) [[dassl/engine/dg/crossgrad.py](dassl/engine/dg/crossgrad.py)]
|
||||||
|
|
||||||
|
- Semi-supervised learning
|
||||||
|
- [FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence](https://arxiv.org/abs/2001.07685) [[dassl/engine/ssl/fixmatch.py](dassl/engine/ssl/fixmatch.py)]
|
||||||
|
- [MixMatch: A Holistic Approach to Semi-Supervised Learning (NeurIPS'19)](https://arxiv.org/abs/1905.02249) [[dassl/engine/ssl/mixmatch.py](dassl/engine/ssl/mixmatch.py)]
|
||||||
|
- [Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results (NeurIPS'17)](https://arxiv.org/abs/1703.01780) [[dassl/engine/ssl/mean_teacher.py](dassl/engine/ssl/mean_teacher.py)]
|
||||||
|
- [Semi-supervised Learning by Entropy Minimization (NeurIPS'04)](http://papers.nips.cc/paper/2740-semi-supervised-learning-by-entropy-minimization.pdf) [[dassl/engine/ssl/entmin.py](dassl/engine/ssl/entmin.py)]
|
||||||
|
|
||||||
|
*Feel free to make a [PR](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork) to add your methods here to make it easier for others to benchmark!*
|
||||||
|
|
||||||
|
Dassl supports the following datasets:
|
||||||
|
|
||||||
|
- Domain adaptation
|
||||||
|
- [Office-31](https://scalable.mpi-inf.mpg.de/files/2013/04/saenko_eccv_2010.pdf)
|
||||||
|
- [Office-Home](http://hemanthdv.org/OfficeHome-Dataset/)
|
||||||
|
- [VisDA17](http://ai.bu.edu/visda-2017/)
|
||||||
|
- [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html)-[STL10](https://cs.stanford.edu/~acoates/stl10/)
|
||||||
|
- [Digit-5](https://github.com/VisionLearningGroup/VisionLearningGroup.github.io/tree/master/M3SDA/code_MSDA_digit#digit-five-download)
|
||||||
|
- [DomainNet](http://ai.bu.edu/M3SDA/)
|
||||||
|
- [miniDomainNet](https://arxiv.org/abs/2003.07325)
|
||||||
|
|
||||||
|
- Domain generalization
|
||||||
|
- [PACS](https://arxiv.org/abs/1710.03077)
|
||||||
|
- [VLCS](https://people.csail.mit.edu/torralba/publications/datasets_cvpr11.pdf)
|
||||||
|
- [Office-Home](http://hemanthdv.org/OfficeHome-Dataset/)
|
||||||
|
- [Digits-DG](https://arxiv.org/abs/2003.06054)
|
||||||
|
- [Digit-Single](https://arxiv.org/abs/1805.12018)
|
||||||
|
- [CIFAR-10-C](https://arxiv.org/abs/1807.01697)
|
||||||
|
- [CIFAR-100-C](https://arxiv.org/abs/1807.01697)
|
||||||
|
|
||||||
|
- Semi-supervised learning
|
||||||
|
- [CIFAR10/100](https://www.cs.toronto.edu/~kriz/cifar.html.)
|
||||||
|
- [SVHN](http://ufldl.stanford.edu/housenumbers/)
|
||||||
|
- [STL10](https://cs.stanford.edu/~acoates/stl10/)
|
||||||
|
|
||||||
|
## Get started
|
||||||
|
|
||||||
|
### Installation
|
||||||
|
|
||||||
|
Make sure [conda](https://www.anaconda.com/distribution/) is installed properly.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Clone this repo
|
||||||
|
git clone https://github.com/KaiyangZhou/Dassl.pytorch.git
|
||||||
|
cd Dassl.pytorch/
|
||||||
|
|
||||||
|
# Create a conda environment
|
||||||
|
conda create -n dassl python=3.7
|
||||||
|
|
||||||
|
# Activate the environment
|
||||||
|
conda activate dassl
|
||||||
|
|
||||||
|
# Install dependencies
|
||||||
|
pip install -r requirements.txt
|
||||||
|
|
||||||
|
# Install torch (version >= 1.7.1) and torchvision
|
||||||
|
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch
|
||||||
|
|
||||||
|
# Install this library (no need to re-build if the source code is modified)
|
||||||
|
python setup.py develop
|
||||||
|
```
|
||||||
|
|
||||||
|
Follow the instructions in [DATASETS.md](./DATASETS.md) to preprocess the datasets.
|
||||||
|
|
||||||
|
### Training
|
||||||
|
|
||||||
|
The main interface is implemented in `tools/train.py`, which basically does
|
||||||
|
|
||||||
|
1. initialize the config with `cfg = setup_cfg(args)` where `args` contains the command-line input (see `tools/train.py` for the list of input arguments);
|
||||||
|
2. instantiate a `trainer` with `build_trainer(cfg)` which loads the dataset and builds a deep neural network model;
|
||||||
|
3. call `trainer.train()` for training and evaluating the model.
|
||||||
|
|
||||||
|
Below we provide an example for training a source-only baseline on the popular domain adaptation dataset, Office-31,
|
||||||
|
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python tools/train.py \
|
||||||
|
--root $DATA \
|
||||||
|
--trainer SourceOnly \
|
||||||
|
--source-domains amazon \
|
||||||
|
--target-domains webcam \
|
||||||
|
--dataset-config-file configs/datasets/da/office31.yaml \
|
||||||
|
--config-file configs/trainers/da/source_only/office31.yaml \
|
||||||
|
--output-dir output/source_only_office31
|
||||||
|
```
|
||||||
|
|
||||||
|
`$DATA` denotes the location where datasets are installed. `--dataset-config-file` loads the common setting for the dataset (Office-31 in this case) such as image size and model architecture. `--config-file` loads the algorithm-specific setting such as hyper-parameters and optimization parameters.
|
||||||
|
|
||||||
|
To use multiple sources, namely the multi-source domain adaptation task, one just needs to add more sources to `--source-domains`. For instance, to train a source-only baseline on miniDomainNet, one can do
|
||||||
|
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python tools/train.py \
|
||||||
|
--root $DATA \
|
||||||
|
--trainer SourceOnly \
|
||||||
|
--source-domains clipart painting real \
|
||||||
|
--target-domains sketch \
|
||||||
|
--dataset-config-file configs/datasets/da/mini_domainnet.yaml \
|
||||||
|
--config-file configs/trainers/da/source_only/mini_domainnet.yaml \
|
||||||
|
--output-dir output/source_only_minidn
|
||||||
|
```
|
||||||
|
|
||||||
|
After the training finishes, the model weights will be saved under the specified output directory, along with a log file and a tensorboard file for visualization.
|
||||||
|
|
||||||
|
To print out the results saved in the log file (so you do not need to exhaustively go through all log files and calculate the mean/std by yourself), you can use `tools/parse_test_res.py`. The instruction can be found in the code.
|
||||||
|
|
||||||
|
For other trainers such as `MCD`, you can set `--trainer MCD` while keeping the config file unchanged, i.e. using the same training parameters as `SourceOnly` (in the simplest case). To modify the hyper-parameters in MCD, like `N_STEP_F` (number of steps to update the feature extractor), you can append `TRAINER.MCD.N_STEP_F 4` to the existing input arguments (otherwise the default value will be used). Alternatively, you can create a new `.yaml` config file to store your custom setting. See [here](https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/config/defaults.py#L176) for a complete list of algorithm-specific hyper-parameters.
|
||||||
|
|
||||||
|
### Test
|
||||||
|
Model testing can be done by using `--eval-only`, which asks the code to run `trainer.test()`. You also need to provide the trained model and specify which model file (i.e. saved at which epoch) to use. For example, to use `model.pth.tar-20` saved at `output/source_only_office31/model`, you can do
|
||||||
|
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python tools/train.py \
|
||||||
|
--root $DATA \
|
||||||
|
--trainer SourceOnly \
|
||||||
|
--source-domains amazon \
|
||||||
|
--target-domains webcam \
|
||||||
|
--dataset-config-file configs/datasets/da/office31.yaml \
|
||||||
|
--config-file configs/trainers/da/source_only/office31.yaml \
|
||||||
|
--output-dir output/source_only_office31_test \
|
||||||
|
--eval-only \
|
||||||
|
--model-dir output/source_only_office31 \
|
||||||
|
--load-epoch 20
|
||||||
|
```
|
||||||
|
|
||||||
|
Note that `--model-dir` takes as input the directory path which was specified in `--output-dir` in the training stage.
|
||||||
|
|
||||||
|
### Write a new trainer
|
||||||
|
A good practice is to go through `dassl/engine/trainer.py` to get familar with the base trainer classes, which provide generic functions and training loops. To write a trainer class for domain adaptation or semi-supervised learning, the new class can subclass `TrainerXU`. For domain generalization, the new class can subclass `TrainerX`. In particular, `TrainerXU` and `TrainerX` mainly differ in whether using a data loader for unlabeled data. With the base classes, a new trainer may only need to implement the `forward_backward()` method, which performs loss computation and model update. See `dassl/enigne/da/source_only.py` for example.
|
||||||
|
|
||||||
|
### Add a new backbone/head/network
|
||||||
|
`backbone` corresponds to a convolutional neural network model which performs feature extraction. `head` (which is an optional module) is mounted on top of `backbone` for further processing, which can be, for example, a MLP. `backbone` and `head` are basic building blocks for constructing a `SimpleNet()` (see `dassl/engine/trainer.py`) which serves as the primary model for a task. `network` contains custom neural network models, such as an image generator.
|
||||||
|
|
||||||
|
To add a new module, namely a backbone/head/network, you need to first register the module using the corresponding `registry`, i.e. `BACKBONE_REGISTRY` for `backbone`, `HEAD_REGISTRY` for `head` and `NETWORK_RESIGTRY` for `network`. Note that for a new `backbone`, we require the model to subclass `Backbone` as defined in `dassl/modeling/backbone/backbone.py` and specify the `self._out_features` attribute.
|
||||||
|
|
||||||
|
We provide an example below for how to add a new `backbone`.
|
||||||
|
```python
|
||||||
|
from dassl.modeling import Backbone, BACKBONE_REGISTRY
|
||||||
|
|
||||||
|
class MyBackbone(Backbone):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
# Create layers
|
||||||
|
self.conv = ...
|
||||||
|
|
||||||
|
self._out_features = 2048
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# Extract and return features
|
||||||
|
|
||||||
|
@BACKBONE_REGISTRY.register()
|
||||||
|
def my_backbone(**kwargs):
|
||||||
|
return MyBackbone()
|
||||||
|
```
|
||||||
|
Then, you can set `MODEL.BACKBONE.NAME` to `my_backbone` to use your own architecture. For more details, please refer to the source code in `dassl/modeling`.
|
||||||
|
|
||||||
|
### Add a dataset
|
||||||
|
An example code structure is shown below. Make sure you subclass `DatasetBase` and register the dataset with `@DATASET_REGISTRY.register()`. All you need is to load `train_x`, `train_u` (optional), `val` (optional) and `test`, among which `train_u` and `val` could be `None` or simply ignored. Each of these variables contains a list of `Datum` objects. A `Datum` object (implemented [here](https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/data/datasets/base_dataset.py#L12)) contains information for a single image, like `impath` (string) and `label` (int).
|
||||||
|
|
||||||
|
```python
|
||||||
|
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||||
|
|
||||||
|
@DATASET_REGISTRY.register()
|
||||||
|
class NewDataset(DatasetBase):
|
||||||
|
|
||||||
|
dataset_dir = ''
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
|
||||||
|
train_x = ...
|
||||||
|
train_u = ... # optional, can be None
|
||||||
|
val = ... # optional, can be None
|
||||||
|
test = ...
|
||||||
|
|
||||||
|
super().__init__(train_x=train_x, train_u=train_u, val=val, test=test)
|
||||||
|
```
|
||||||
|
|
||||||
|
We suggest you take a look at the datasets code in some projects like [this](https://github.com/KaiyangZhou/CoOp), which is built on top of Dassl.
|
||||||
|
|
||||||
|
## Relevant Research
|
||||||
|
|
||||||
|
We would like to share here our research relevant to Dassl.
|
||||||
|
|
||||||
|
- [Domain Adaptive Ensemble Learning](https://arxiv.org/abs/2003.07325), TIP, 2021.
|
||||||
|
- [MixStyle Neural Networks for Domain Generalization and Adaptation](https://arxiv.org/abs/2107.02053), arxiv preprint, 2021.
|
||||||
|
- [Semi-Supervised Domain Generalization with Stochastic StyleMatch](https://arxiv.org/abs/2106.00592), arxiv preprint, 2021.
|
||||||
|
- [Domain Generalization in Vision: A Survey](https://arxiv.org/abs/2103.02503), arxiv preprint, 2021.
|
||||||
|
- [Domain Generalization with MixStyle](https://openreview.net/forum?id=6xHJ37MVxxp), in ICLR 2021.
|
||||||
|
- [Learning to Generate Novel Domains for Domain Generalization](https://arxiv.org/abs/2007.03304), in ECCV 2020.
|
||||||
|
- [Deep Domain-Adversarial Image Generation for Domain Generalisation](https://arxiv.org/abs/2003.06054), in AAAI 2020.
|
||||||
|
|
||||||
|
## Citation
|
||||||
|
|
||||||
|
If you find this code useful to your research, please give credit to the following paper
|
||||||
|
|
||||||
|
```
|
||||||
|
@article{zhou2020domain,
|
||||||
|
title={Domain Adaptive Ensemble Learning},
|
||||||
|
author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao},
|
||||||
|
journal={IEEE Transactions on Image Processing (TIP)},
|
||||||
|
year={2021}
|
||||||
|
}
|
||||||
|
```
|
||||||
1
Dassl.ProGrad.pytorch/configs/README.md
Normal file
1
Dassl.ProGrad.pytorch/configs/README.md
Normal file
@@ -0,0 +1 @@
|
|||||||
|
The `datasets/` folder contains dataset-specific config files which define the standard protocols (e.g., image size, data augmentation, network architecture) used by most papers. The `trainers/` folder contains method-specific config files which define optimization algorithms (e.g., optimizer, epoch) and hyperparameter settings.
|
||||||
7
Dassl.ProGrad.pytorch/configs/datasets/da/cifar_stl.yaml
Normal file
7
Dassl.ProGrad.pytorch/configs/datasets/da/cifar_stl.yaml
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
INPUT:
|
||||||
|
SIZE: (32, 32)
|
||||||
|
PIXEL_MEAN: [0.5, 0.5, 0.5]
|
||||||
|
PIXEL_STD: [0.5, 0.5, 0.5]
|
||||||
|
|
||||||
|
DATASET:
|
||||||
|
NAME: "CIFARSTL"
|
||||||
12
Dassl.ProGrad.pytorch/configs/datasets/da/digit5.yaml
Normal file
12
Dassl.ProGrad.pytorch/configs/datasets/da/digit5.yaml
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
INPUT:
|
||||||
|
SIZE: (32, 32)
|
||||||
|
PIXEL_MEAN: [0.5, 0.5, 0.5]
|
||||||
|
PIXEL_STD: [0.5, 0.5, 0.5]
|
||||||
|
TRANSFORMS: ["normalize"]
|
||||||
|
|
||||||
|
DATASET:
|
||||||
|
NAME: "Digit5"
|
||||||
|
|
||||||
|
MODEL:
|
||||||
|
BACKBONE:
|
||||||
|
NAME: "cnn_digit5_m3sda"
|
||||||
10
Dassl.ProGrad.pytorch/configs/datasets/da/domainnet.yaml
Normal file
10
Dassl.ProGrad.pytorch/configs/datasets/da/domainnet.yaml
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
INPUT:
|
||||||
|
SIZE: (224, 224)
|
||||||
|
TRANSFORMS: ["random_flip", "random_translation", "normalize"]
|
||||||
|
|
||||||
|
DATASET:
|
||||||
|
NAME: "DomainNet"
|
||||||
|
|
||||||
|
MODEL:
|
||||||
|
BACKBONE:
|
||||||
|
NAME: "resnet101"
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
INPUT:
|
||||||
|
SIZE: (96, 96)
|
||||||
|
TRANSFORMS: ["random_flip", "random_translation", "normalize"]
|
||||||
|
|
||||||
|
DATASET:
|
||||||
|
NAME: "miniDomainNet"
|
||||||
|
|
||||||
|
MODEL:
|
||||||
|
BACKBONE:
|
||||||
|
NAME: "resnet18"
|
||||||
14
Dassl.ProGrad.pytorch/configs/datasets/da/office31.yaml
Normal file
14
Dassl.ProGrad.pytorch/configs/datasets/da/office31.yaml
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
INPUT:
|
||||||
|
SIZE: (224, 224)
|
||||||
|
TRANSFORMS: ["random_flip", "random_translation", "normalize"]
|
||||||
|
|
||||||
|
DATASET:
|
||||||
|
NAME: "Office31"
|
||||||
|
|
||||||
|
MODEL:
|
||||||
|
BACKBONE:
|
||||||
|
NAME: "resnet50"
|
||||||
|
HEAD:
|
||||||
|
NAME: "mlp"
|
||||||
|
HIDDEN_LAYERS: [256]
|
||||||
|
DROPOUT: 0.
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
INPUT:
|
||||||
|
SIZE: (224, 224)
|
||||||
|
|
||||||
|
DATASET:
|
||||||
|
NAME: "OfficeHome"
|
||||||
13
Dassl.ProGrad.pytorch/configs/datasets/da/visda17.yaml
Normal file
13
Dassl.ProGrad.pytorch/configs/datasets/da/visda17.yaml
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
INPUT:
|
||||||
|
SIZE: (224, 224)
|
||||||
|
TRANSFORMS: ["random_flip", "center_crop", "normalize"]
|
||||||
|
|
||||||
|
DATASET:
|
||||||
|
NAME: "VisDA17"
|
||||||
|
|
||||||
|
MODEL:
|
||||||
|
BACKBONE:
|
||||||
|
NAME: "resnet101"
|
||||||
|
|
||||||
|
TEST:
|
||||||
|
PER_CLASS_RESULT: True
|
||||||
14
Dassl.ProGrad.pytorch/configs/datasets/dg/cifar100_c.yaml
Normal file
14
Dassl.ProGrad.pytorch/configs/datasets/dg/cifar100_c.yaml
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
INPUT:
|
||||||
|
SIZE: (32, 32)
|
||||||
|
TRANSFORMS: ["random_flip", "random_crop", "normalize"]
|
||||||
|
PIXEL_MEAN: [0.5, 0.5, 0.5]
|
||||||
|
PIXEL_STD: [0.5, 0.5, 0.5]
|
||||||
|
|
||||||
|
DATASET:
|
||||||
|
NAME: "CIFAR100C"
|
||||||
|
CIFAR_C_TYPE: "fog"
|
||||||
|
CIFAR_C_LEVEL: 5
|
||||||
|
|
||||||
|
MODEL:
|
||||||
|
BACKBONE:
|
||||||
|
NAME: "wide_resnet_16_4"
|
||||||
14
Dassl.ProGrad.pytorch/configs/datasets/dg/cifar10_c.yaml
Normal file
14
Dassl.ProGrad.pytorch/configs/datasets/dg/cifar10_c.yaml
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
INPUT:
|
||||||
|
SIZE: (32, 32)
|
||||||
|
TRANSFORMS: ["random_flip", "random_crop", "normalize"]
|
||||||
|
PIXEL_MEAN: [0.5, 0.5, 0.5]
|
||||||
|
PIXEL_STD: [0.5, 0.5, 0.5]
|
||||||
|
|
||||||
|
DATASET:
|
||||||
|
NAME: "CIFAR10C"
|
||||||
|
CIFAR_C_TYPE: "fog"
|
||||||
|
CIFAR_C_LEVEL: 5
|
||||||
|
|
||||||
|
MODEL:
|
||||||
|
BACKBONE:
|
||||||
|
NAME: "wide_resnet_16_4"
|
||||||
12
Dassl.ProGrad.pytorch/configs/datasets/dg/digit_single.yaml
Normal file
12
Dassl.ProGrad.pytorch/configs/datasets/dg/digit_single.yaml
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
INPUT:
|
||||||
|
SIZE: (32, 32)
|
||||||
|
TRANSFORMS: ["normalize"]
|
||||||
|
PIXEL_MEAN: [0.5, 0.5, 0.5]
|
||||||
|
PIXEL_STD: [0.5, 0.5, 0.5]
|
||||||
|
|
||||||
|
DATASET:
|
||||||
|
NAME: "DigitSingle"
|
||||||
|
|
||||||
|
MODEL:
|
||||||
|
BACKBONE:
|
||||||
|
NAME: "cnn_digitsingle"
|
||||||
12
Dassl.ProGrad.pytorch/configs/datasets/dg/digits_dg.yaml
Normal file
12
Dassl.ProGrad.pytorch/configs/datasets/dg/digits_dg.yaml
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
INPUT:
|
||||||
|
SIZE: (32, 32)
|
||||||
|
TRANSFORMS: ["normalize"]
|
||||||
|
PIXEL_MEAN: [0.5, 0.5, 0.5]
|
||||||
|
PIXEL_STD: [0.5, 0.5, 0.5]
|
||||||
|
|
||||||
|
DATASET:
|
||||||
|
NAME: "DigitsDG"
|
||||||
|
|
||||||
|
MODEL:
|
||||||
|
BACKBONE:
|
||||||
|
NAME: "cnn_digitsdg"
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
INPUT:
|
||||||
|
SIZE: (224, 224)
|
||||||
|
TRANSFORMS: ["random_flip", "random_translation", "normalize"]
|
||||||
|
|
||||||
|
DATASET:
|
||||||
|
NAME: "OfficeHomeDG"
|
||||||
|
|
||||||
|
MODEL:
|
||||||
|
BACKBONE:
|
||||||
|
NAME: "resnet18"
|
||||||
|
PRETRAINED: True
|
||||||
11
Dassl.ProGrad.pytorch/configs/datasets/dg/pacs.yaml
Normal file
11
Dassl.ProGrad.pytorch/configs/datasets/dg/pacs.yaml
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
INPUT:
|
||||||
|
SIZE: (224, 224)
|
||||||
|
TRANSFORMS: ["random_flip", "random_translation", "normalize"]
|
||||||
|
|
||||||
|
DATASET:
|
||||||
|
NAME: "PACS"
|
||||||
|
|
||||||
|
MODEL:
|
||||||
|
BACKBONE:
|
||||||
|
NAME: "resnet18"
|
||||||
|
PRETRAINED: True
|
||||||
11
Dassl.ProGrad.pytorch/configs/datasets/dg/vlcs.yaml
Normal file
11
Dassl.ProGrad.pytorch/configs/datasets/dg/vlcs.yaml
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
INPUT:
|
||||||
|
SIZE: (224, 224)
|
||||||
|
TRANSFORMS: ["random_flip", "random_translation", "normalize"]
|
||||||
|
|
||||||
|
DATASET:
|
||||||
|
NAME: "VLCS"
|
||||||
|
|
||||||
|
MODEL:
|
||||||
|
BACKBONE:
|
||||||
|
NAME: "resnet18"
|
||||||
|
PRETRAINED: True
|
||||||
14
Dassl.ProGrad.pytorch/configs/datasets/ssl/cifar10.yaml
Normal file
14
Dassl.ProGrad.pytorch/configs/datasets/ssl/cifar10.yaml
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
INPUT:
|
||||||
|
SIZE: (32, 32)
|
||||||
|
TRANSFORMS: ["random_flip", "random_crop", "normalize"]
|
||||||
|
PIXEL_MEAN: [0.5, 0.5, 0.5]
|
||||||
|
PIXEL_STD: [0.5, 0.5, 0.5]
|
||||||
|
|
||||||
|
DATASET:
|
||||||
|
NAME: "CIFAR10"
|
||||||
|
NUM_LABELED: 4000
|
||||||
|
VAL_PERCENT: 0.
|
||||||
|
|
||||||
|
MODEL:
|
||||||
|
BACKBONE:
|
||||||
|
NAME: "wide_resnet_28_2"
|
||||||
15
Dassl.ProGrad.pytorch/configs/datasets/ssl/cifar100.yaml
Normal file
15
Dassl.ProGrad.pytorch/configs/datasets/ssl/cifar100.yaml
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
INPUT:
|
||||||
|
SIZE: (32, 32)
|
||||||
|
TRANSFORMS: ["random_flip", "random_crop", "normalize"]
|
||||||
|
PIXEL_MEAN: [0.5, 0.5, 0.5]
|
||||||
|
PIXEL_STD: [0.5, 0.5, 0.5]
|
||||||
|
CROP_PADDING: 4
|
||||||
|
|
||||||
|
DATASET:
|
||||||
|
NAME: "CIFAR100"
|
||||||
|
NUM_LABELED: 10000
|
||||||
|
VAL_PERCENT: 0.
|
||||||
|
|
||||||
|
MODEL:
|
||||||
|
BACKBONE:
|
||||||
|
NAME: "wide_resnet_28_2"
|
||||||
14
Dassl.ProGrad.pytorch/configs/datasets/ssl/stl10.yaml
Normal file
14
Dassl.ProGrad.pytorch/configs/datasets/ssl/stl10.yaml
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
INPUT:
|
||||||
|
SIZE: (96, 96)
|
||||||
|
TRANSFORMS: ["random_flip", "random_crop", "normalize"]
|
||||||
|
PIXEL_MEAN: [0.5, 0.5, 0.5]
|
||||||
|
PIXEL_STD: [0.5, 0.5, 0.5]
|
||||||
|
CROP_PADDING: 4
|
||||||
|
|
||||||
|
DATASET:
|
||||||
|
NAME: "STL10"
|
||||||
|
STL10_FOLD: 0
|
||||||
|
|
||||||
|
MODEL:
|
||||||
|
BACKBONE:
|
||||||
|
NAME: "wide_resnet_28_2"
|
||||||
15
Dassl.ProGrad.pytorch/configs/datasets/ssl/svhn.yaml
Normal file
15
Dassl.ProGrad.pytorch/configs/datasets/ssl/svhn.yaml
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
INPUT:
|
||||||
|
SIZE: (32, 32)
|
||||||
|
TRANSFORMS: ["random_crop", "normalize"]
|
||||||
|
PIXEL_MEAN: [0.5, 0.5, 0.5]
|
||||||
|
PIXEL_STD: [0.5, 0.5, 0.5]
|
||||||
|
CROP_PADDING: 4
|
||||||
|
|
||||||
|
DATASET:
|
||||||
|
NAME: "SVHN"
|
||||||
|
NUM_LABELED: 1000
|
||||||
|
VAL_PERCENT: 0.
|
||||||
|
|
||||||
|
MODEL:
|
||||||
|
BACKBONE:
|
||||||
|
NAME: "wide_resnet_28_2"
|
||||||
20
Dassl.ProGrad.pytorch/configs/trainers/da/dael/digit5.yaml
Normal file
20
Dassl.ProGrad.pytorch/configs/trainers/da/dael/digit5.yaml
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
DATALOADER:
|
||||||
|
TRAIN_X:
|
||||||
|
SAMPLER: "RandomDomainSampler"
|
||||||
|
BATCH_SIZE: 256
|
||||||
|
TRAIN_U:
|
||||||
|
SAME_AS_X: False
|
||||||
|
BATCH_SIZE: 64
|
||||||
|
TEST:
|
||||||
|
BATCH_SIZE: 256
|
||||||
|
|
||||||
|
OPTIM:
|
||||||
|
NAME: "sgd"
|
||||||
|
LR: 0.05
|
||||||
|
STEPSIZE: [30]
|
||||||
|
MAX_EPOCH: 30
|
||||||
|
LR_SCHEDULER: "cosine"
|
||||||
|
|
||||||
|
TRAINER:
|
||||||
|
DAEL:
|
||||||
|
STRONG_TRANSFORMS: ["randaugment2", "normalize"]
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
DATALOADER:
|
||||||
|
NUM_WORKERS: 4
|
||||||
|
TRAIN_X:
|
||||||
|
SAMPLER: "RandomDomainSampler"
|
||||||
|
BATCH_SIZE: 30
|
||||||
|
TRAIN_U:
|
||||||
|
SAME_AS_X: False
|
||||||
|
BATCH_SIZE: 6
|
||||||
|
TEST:
|
||||||
|
BATCH_SIZE: 30
|
||||||
|
|
||||||
|
OPTIM:
|
||||||
|
NAME: "sgd"
|
||||||
|
LR: 0.002
|
||||||
|
MAX_EPOCH: 40
|
||||||
|
LR_SCHEDULER: "cosine"
|
||||||
|
|
||||||
|
TRAINER:
|
||||||
|
DAEL:
|
||||||
|
STRONG_TRANSFORMS: ["random_flip", "cutout", "randaugment2", "normalize"]
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
DATALOADER:
|
||||||
|
NUM_WORKERS: 8
|
||||||
|
TRAIN_X:
|
||||||
|
SAMPLER: "RandomDomainSampler"
|
||||||
|
BATCH_SIZE: 192
|
||||||
|
TRAIN_U:
|
||||||
|
SAME_AS_X: False
|
||||||
|
BATCH_SIZE: 64
|
||||||
|
TEST:
|
||||||
|
BATCH_SIZE: 200
|
||||||
|
|
||||||
|
OPTIM:
|
||||||
|
NAME: "sgd"
|
||||||
|
LR: 0.005
|
||||||
|
MAX_EPOCH: 60
|
||||||
|
LR_SCHEDULER: "cosine"
|
||||||
|
|
||||||
|
TRAINER:
|
||||||
|
DAEL:
|
||||||
|
STRONG_TRANSFORMS: ["random_flip", "cutout", "randaugment2", "normalize"]
|
||||||
16
Dassl.ProGrad.pytorch/configs/trainers/da/m3sda/digit5.yaml
Normal file
16
Dassl.ProGrad.pytorch/configs/trainers/da/m3sda/digit5.yaml
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
DATALOADER:
|
||||||
|
TRAIN_X:
|
||||||
|
SAMPLER: "RandomDomainSampler"
|
||||||
|
BATCH_SIZE: 256
|
||||||
|
TRAIN_U:
|
||||||
|
SAME_AS_X: False
|
||||||
|
BATCH_SIZE: 64
|
||||||
|
TEST:
|
||||||
|
BATCH_SIZE: 256
|
||||||
|
|
||||||
|
OPTIM:
|
||||||
|
NAME: "sgd"
|
||||||
|
LR: 0.05
|
||||||
|
STEPSIZE: [30]
|
||||||
|
MAX_EPOCH: 30
|
||||||
|
LR_SCHEDULER: "cosine"
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
DATALOADER:
|
||||||
|
NUM_WORKERS: 4
|
||||||
|
TRAIN_X:
|
||||||
|
SAMPLER: "RandomDomainSampler"
|
||||||
|
BATCH_SIZE: 30
|
||||||
|
TRAIN_U:
|
||||||
|
SAME_AS_X: False
|
||||||
|
BATCH_SIZE: 6
|
||||||
|
TEST:
|
||||||
|
BATCH_SIZE: 30
|
||||||
|
|
||||||
|
OPTIM:
|
||||||
|
NAME: "sgd"
|
||||||
|
LR: 0.002
|
||||||
|
MAX_EPOCH: 40
|
||||||
|
LR_SCHEDULER: "cosine"
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
DATALOADER:
|
||||||
|
NUM_WORKERS: 8
|
||||||
|
TRAIN_X:
|
||||||
|
SAMPLER: "RandomDomainSampler"
|
||||||
|
BATCH_SIZE: 192
|
||||||
|
TRAIN_U:
|
||||||
|
SAME_AS_X: False
|
||||||
|
BATCH_SIZE: 64
|
||||||
|
TEST:
|
||||||
|
BATCH_SIZE: 200
|
||||||
|
|
||||||
|
OPTIM:
|
||||||
|
NAME: "sgd"
|
||||||
|
LR: 0.005
|
||||||
|
MAX_EPOCH: 60
|
||||||
|
LR_SCHEDULER: "cosine"
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
DATALOADER:
|
||||||
|
TRAIN_X:
|
||||||
|
BATCH_SIZE: 256
|
||||||
|
TEST:
|
||||||
|
BATCH_SIZE: 256
|
||||||
|
|
||||||
|
OPTIM:
|
||||||
|
NAME: "sgd"
|
||||||
|
LR: 0.05
|
||||||
|
STEPSIZE: [30]
|
||||||
|
MAX_EPOCH: 30
|
||||||
|
LR_SCHEDULER: "cosine"
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
DATALOADER:
|
||||||
|
NUM_WORKERS: 8
|
||||||
|
TRAIN_X:
|
||||||
|
BATCH_SIZE: 128
|
||||||
|
TEST:
|
||||||
|
BATCH_SIZE: 128
|
||||||
|
|
||||||
|
OPTIM:
|
||||||
|
NAME: "sgd"
|
||||||
|
LR: 0.005
|
||||||
|
MAX_EPOCH: 60
|
||||||
|
LR_SCHEDULER: "cosine"
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
DATALOADER:
|
||||||
|
TRAIN_X:
|
||||||
|
BATCH_SIZE: 32
|
||||||
|
TEST:
|
||||||
|
BATCH_SIZE: 32
|
||||||
|
|
||||||
|
OPTIM:
|
||||||
|
NAME: "sgd"
|
||||||
|
LR: 0.002
|
||||||
|
STEPSIZE: [20]
|
||||||
|
MAX_EPOCH: 20
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
DATALOADER:
|
||||||
|
TRAIN_X:
|
||||||
|
BATCH_SIZE: 32
|
||||||
|
TEST:
|
||||||
|
BATCH_SIZE: 32
|
||||||
|
|
||||||
|
OPTIM:
|
||||||
|
NAME: "sgd"
|
||||||
|
LR: 0.0001
|
||||||
|
STEPSIZE: [2]
|
||||||
|
MAX_EPOCH: 2
|
||||||
|
|
||||||
|
TRAIN:
|
||||||
|
PRINT_FREQ: 50
|
||||||
|
COUNT_ITER: "train_u"
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
DATALOADER:
|
||||||
|
TRAIN_X:
|
||||||
|
SAMPLER: "RandomDomainSampler"
|
||||||
|
BATCH_SIZE: 120
|
||||||
|
TEST:
|
||||||
|
BATCH_SIZE: 100
|
||||||
|
|
||||||
|
OPTIM:
|
||||||
|
NAME: "sgd"
|
||||||
|
LR: 0.05
|
||||||
|
STEPSIZE: [20]
|
||||||
|
MAX_EPOCH: 50
|
||||||
|
|
||||||
|
TRAINER:
|
||||||
|
DAEL:
|
||||||
|
STRONG_TRANSFORMS: ["randaugment2", "normalize"]
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
DATALOADER:
|
||||||
|
TRAIN_X:
|
||||||
|
SAMPLER: "RandomDomainSampler"
|
||||||
|
BATCH_SIZE: 30
|
||||||
|
TEST:
|
||||||
|
BATCH_SIZE: 100
|
||||||
|
|
||||||
|
OPTIM:
|
||||||
|
NAME: "sgd"
|
||||||
|
LR: 0.002
|
||||||
|
MAX_EPOCH: 40
|
||||||
|
LR_SCHEDULER: "cosine"
|
||||||
|
|
||||||
|
TRAINER:
|
||||||
|
DAEL:
|
||||||
|
STRONG_TRANSFORMS: ["random_flip", "cutout", "randaugment2", "normalize"]
|
||||||
16
Dassl.ProGrad.pytorch/configs/trainers/dg/dael/pacs.yaml
Normal file
16
Dassl.ProGrad.pytorch/configs/trainers/dg/dael/pacs.yaml
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
DATALOADER:
|
||||||
|
TRAIN_X:
|
||||||
|
SAMPLER: "RandomDomainSampler"
|
||||||
|
BATCH_SIZE: 30
|
||||||
|
TEST:
|
||||||
|
BATCH_SIZE: 100
|
||||||
|
|
||||||
|
OPTIM:
|
||||||
|
NAME: "sgd"
|
||||||
|
LR: 0.002
|
||||||
|
MAX_EPOCH: 40
|
||||||
|
LR_SCHEDULER: "cosine"
|
||||||
|
|
||||||
|
TRAINER:
|
||||||
|
DAEL:
|
||||||
|
STRONG_TRANSFORMS: ["random_flip", "cutout", "randaugment2", "normalize"]
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
INPUT:
|
||||||
|
PIXEL_MEAN: [0., 0., 0.]
|
||||||
|
PIXEL_STD: [1., 1., 1.]
|
||||||
|
|
||||||
|
DATALOADER:
|
||||||
|
TRAIN_X:
|
||||||
|
BATCH_SIZE: 128
|
||||||
|
TEST:
|
||||||
|
BATCH_SIZE: 128
|
||||||
|
|
||||||
|
OPTIM:
|
||||||
|
NAME: "sgd"
|
||||||
|
LR: 0.05
|
||||||
|
STEPSIZE: [20]
|
||||||
|
MAX_EPOCH: 50
|
||||||
|
|
||||||
|
TRAINER:
|
||||||
|
DDAIG:
|
||||||
|
G_ARCH: "fcn_3x32_gctx"
|
||||||
|
LMDA: 0.3
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
INPUT:
|
||||||
|
PIXEL_MEAN: [0., 0., 0.]
|
||||||
|
PIXEL_STD: [1., 1., 1.]
|
||||||
|
|
||||||
|
DATALOADER:
|
||||||
|
TRAIN_X:
|
||||||
|
BATCH_SIZE: 16
|
||||||
|
TEST:
|
||||||
|
BATCH_SIZE: 16
|
||||||
|
|
||||||
|
OPTIM:
|
||||||
|
NAME: "sgd"
|
||||||
|
LR: 0.0005
|
||||||
|
STEPSIZE: [20]
|
||||||
|
MAX_EPOCH: 25
|
||||||
|
|
||||||
|
TRAINER:
|
||||||
|
DDAIG:
|
||||||
|
G_ARCH: "fcn_3x64_gctx"
|
||||||
|
WARMUP: 3
|
||||||
|
LMDA: 0.3
|
||||||
21
Dassl.ProGrad.pytorch/configs/trainers/dg/ddaig/pacs.yaml
Normal file
21
Dassl.ProGrad.pytorch/configs/trainers/dg/ddaig/pacs.yaml
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
INPUT:
|
||||||
|
PIXEL_MEAN: [0., 0., 0.]
|
||||||
|
PIXEL_STD: [1., 1., 1.]
|
||||||
|
|
||||||
|
DATALOADER:
|
||||||
|
TRAIN_X:
|
||||||
|
BATCH_SIZE: 16
|
||||||
|
TEST:
|
||||||
|
BATCH_SIZE: 16
|
||||||
|
|
||||||
|
OPTIM:
|
||||||
|
NAME: "sgd"
|
||||||
|
LR: 0.0005
|
||||||
|
STEPSIZE: [20]
|
||||||
|
MAX_EPOCH: 25
|
||||||
|
|
||||||
|
TRAINER:
|
||||||
|
DDAIG:
|
||||||
|
G_ARCH: "fcn_3x64_gctx"
|
||||||
|
WARMUP: 3
|
||||||
|
LMDA: 0.3
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
DATALOADER:
|
||||||
|
TRAIN_X:
|
||||||
|
BATCH_SIZE: 128
|
||||||
|
TEST:
|
||||||
|
BATCH_SIZE: 100
|
||||||
|
NUM_WORKERS: 8
|
||||||
|
|
||||||
|
OPTIM:
|
||||||
|
NAME: "sgd"
|
||||||
|
LR: 0.05
|
||||||
|
STEPSIZE: [20]
|
||||||
|
MAX_EPOCH: 50
|
||||||
|
|
||||||
|
TRAIN:
|
||||||
|
PRINT_FREQ: 20
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
DATALOADER:
|
||||||
|
NUM_WORKERS: 8
|
||||||
|
TRAIN_X:
|
||||||
|
BATCH_SIZE: 128
|
||||||
|
TEST:
|
||||||
|
BATCH_SIZE: 128
|
||||||
|
|
||||||
|
OPTIM:
|
||||||
|
NAME: "sgd"
|
||||||
|
LR: 0.005
|
||||||
|
MAX_EPOCH: 60
|
||||||
|
LR_SCHEDULER: "cosine"
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
DATALOADER:
|
||||||
|
TRAIN_X:
|
||||||
|
BATCH_SIZE: 64
|
||||||
|
TEST:
|
||||||
|
BATCH_SIZE: 100
|
||||||
|
NUM_WORKERS: 8
|
||||||
|
|
||||||
|
OPTIM:
|
||||||
|
NAME: "sgd"
|
||||||
|
LR: 0.001
|
||||||
|
MAX_EPOCH: 50
|
||||||
|
LR_SCHEDULER: "cosine"
|
||||||
12
Dassl.ProGrad.pytorch/configs/trainers/dg/vanilla/pacs.yaml
Normal file
12
Dassl.ProGrad.pytorch/configs/trainers/dg/vanilla/pacs.yaml
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
DATALOADER:
|
||||||
|
TRAIN_X:
|
||||||
|
BATCH_SIZE: 64
|
||||||
|
TEST:
|
||||||
|
BATCH_SIZE: 100
|
||||||
|
NUM_WORKERS: 8
|
||||||
|
|
||||||
|
OPTIM:
|
||||||
|
NAME: "sgd"
|
||||||
|
LR: 0.001
|
||||||
|
MAX_EPOCH: 50
|
||||||
|
LR_SCHEDULER: "cosine"
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
DATALOADER:
|
||||||
|
TRAIN_X:
|
||||||
|
BATCH_SIZE: 64
|
||||||
|
TRAIN_U:
|
||||||
|
SAME_AS_X: False
|
||||||
|
BATCH_SIZE: 448
|
||||||
|
TEST:
|
||||||
|
BATCH_SIZE: 500
|
||||||
|
|
||||||
|
OPTIM:
|
||||||
|
NAME: "sgd"
|
||||||
|
LR: 0.05
|
||||||
|
STEPSIZE: [4000]
|
||||||
|
MAX_EPOCH: 4000
|
||||||
|
LR_SCHEDULER: "cosine"
|
||||||
|
|
||||||
|
TRAIN:
|
||||||
|
COUNT_ITER: "train_u"
|
||||||
|
PRINT_FREQ: 10
|
||||||
|
|
||||||
|
TRAINER:
|
||||||
|
FIXMATCH:
|
||||||
|
STRONG_TRANSFORMS: ["random_flip", "randaugment_fixmatch", "normalize", "cutout"]
|
||||||
18
Dassl.ProGrad.pytorch/dassl/__init__.py
Normal file
18
Dassl.ProGrad.pytorch/dassl/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
"""
|
||||||
|
Dassl
|
||||||
|
------
|
||||||
|
PyTorch toolbox for domain adaptation and semi-supervised learning.
|
||||||
|
|
||||||
|
URL: https://github.com/KaiyangZhou/Dassl.pytorch
|
||||||
|
|
||||||
|
@article{zhou2020domain,
|
||||||
|
title={Domain Adaptive Ensemble Learning},
|
||||||
|
author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao},
|
||||||
|
journal={arXiv preprint arXiv:2003.07325},
|
||||||
|
year={2020}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
__version__ = "0.5.0"
|
||||||
|
__author__ = "Kaiyang Zhou"
|
||||||
|
__homepage__ = "https://kaiyangzhou.github.io/"
|
||||||
5
Dassl.ProGrad.pytorch/dassl/config/__init__.py
Normal file
5
Dassl.ProGrad.pytorch/dassl/config/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
from .defaults import _C as cfg_default
|
||||||
|
|
||||||
|
|
||||||
|
def get_cfg_default():
|
||||||
|
return cfg_default.clone()
|
||||||
275
Dassl.ProGrad.pytorch/dassl/config/defaults.py
Normal file
275
Dassl.ProGrad.pytorch/dassl/config/defaults.py
Normal file
@@ -0,0 +1,275 @@
|
|||||||
|
from yacs.config import CfgNode as CN
|
||||||
|
|
||||||
|
###########################
|
||||||
|
# Config definition
|
||||||
|
###########################
|
||||||
|
|
||||||
|
_C = CN()
|
||||||
|
|
||||||
|
_C.VERSION = 1
|
||||||
|
|
||||||
|
# Directory to save the output files (like log.txt and model weights)
|
||||||
|
_C.OUTPUT_DIR = "./output"
|
||||||
|
# Path to a directory where the files were saved previously
|
||||||
|
_C.RESUME = ""
|
||||||
|
# Set seed to negative value to randomize everything
|
||||||
|
# Set seed to positive value to use a fixed seed
|
||||||
|
_C.SEED = -1
|
||||||
|
_C.USE_CUDA = True
|
||||||
|
# Print detailed information
|
||||||
|
# E.g. trainer, dataset, and backbone
|
||||||
|
_C.VERBOSE = True
|
||||||
|
|
||||||
|
###########################
|
||||||
|
# Input
|
||||||
|
###########################
|
||||||
|
_C.INPUT = CN()
|
||||||
|
_C.INPUT.SIZE = (224, 224)
|
||||||
|
# Mode of interpolation in resize functions
|
||||||
|
_C.INPUT.INTERPOLATION = "bilinear"
|
||||||
|
# For available choices please refer to transforms.py
|
||||||
|
_C.INPUT.TRANSFORMS = ()
|
||||||
|
# If True, tfm_train and tfm_test will be None
|
||||||
|
_C.INPUT.NO_TRANSFORM = False
|
||||||
|
# Default mean and std come from ImageNet
|
||||||
|
_C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406]
|
||||||
|
_C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225]
|
||||||
|
# Padding for random crop
|
||||||
|
_C.INPUT.CROP_PADDING = 4
|
||||||
|
# Cutout
|
||||||
|
_C.INPUT.CUTOUT_N = 1
|
||||||
|
_C.INPUT.CUTOUT_LEN = 16
|
||||||
|
# Gaussian noise
|
||||||
|
_C.INPUT.GN_MEAN = 0.0
|
||||||
|
_C.INPUT.GN_STD = 0.15
|
||||||
|
# RandomAugment
|
||||||
|
_C.INPUT.RANDAUGMENT_N = 2
|
||||||
|
_C.INPUT.RANDAUGMENT_M = 10
|
||||||
|
# ColorJitter (brightness, contrast, saturation, hue)
|
||||||
|
_C.INPUT.COLORJITTER_B = 0.4
|
||||||
|
_C.INPUT.COLORJITTER_C = 0.4
|
||||||
|
_C.INPUT.COLORJITTER_S = 0.4
|
||||||
|
_C.INPUT.COLORJITTER_H = 0.1
|
||||||
|
# Random gray scale's probability
|
||||||
|
_C.INPUT.RGS_P = 0.2
|
||||||
|
# Gaussian blur
|
||||||
|
_C.INPUT.GB_P = 0.5 # propability of applying this operation
|
||||||
|
_C.INPUT.GB_K = 21 # kernel size (should be an odd number)
|
||||||
|
|
||||||
|
###########################
|
||||||
|
# Dataset
|
||||||
|
###########################
|
||||||
|
_C.DATASET = CN()
|
||||||
|
# Directory where datasets are stored
|
||||||
|
_C.DATASET.ROOT = ""
|
||||||
|
_C.DATASET.NAME = ""
|
||||||
|
# List of names of source domains
|
||||||
|
_C.DATASET.SOURCE_DOMAINS = ()
|
||||||
|
# List of names of target domains
|
||||||
|
_C.DATASET.TARGET_DOMAINS = ()
|
||||||
|
# Number of labeled instances in total
|
||||||
|
# Useful for the semi-supervised learning
|
||||||
|
_C.DATASET.NUM_LABELED = -1
|
||||||
|
# Number of images per class
|
||||||
|
_C.DATASET.NUM_SHOTS = -1
|
||||||
|
# Percentage of validation data (only used for SSL datasets)
|
||||||
|
# Set to 0 if do not want to use val data
|
||||||
|
# Using val data for hyperparameter tuning was done in Oliver et al. 2018
|
||||||
|
_C.DATASET.VAL_PERCENT = 0.1
|
||||||
|
# Fold index for STL-10 dataset (normal range is 0 - 9)
|
||||||
|
# Negative number means None
|
||||||
|
_C.DATASET.STL10_FOLD = -1
|
||||||
|
# CIFAR-10/100-C's corruption type and intensity level
|
||||||
|
_C.DATASET.CIFAR_C_TYPE = ""
|
||||||
|
_C.DATASET.CIFAR_C_LEVEL = 1
|
||||||
|
# Use all data in the unlabeled data set (e.g. FixMatch)
|
||||||
|
_C.DATASET.ALL_AS_UNLABELED = False
|
||||||
|
|
||||||
|
###########################
|
||||||
|
# Dataloader
|
||||||
|
###########################
|
||||||
|
_C.DATALOADER = CN()
|
||||||
|
_C.DATALOADER.NUM_WORKERS = 4
|
||||||
|
# Apply transformations to an image K times (during training)
|
||||||
|
_C.DATALOADER.K_TRANSFORMS = 1
|
||||||
|
# img0 denotes image tensor without augmentation
|
||||||
|
# Useful for consistency learning
|
||||||
|
_C.DATALOADER.RETURN_IMG0 = False
|
||||||
|
# Setting for the train_x data-loader
|
||||||
|
_C.DATALOADER.TRAIN_X = CN()
|
||||||
|
_C.DATALOADER.TRAIN_X.SAMPLER = "RandomSampler"
|
||||||
|
_C.DATALOADER.TRAIN_X.BATCH_SIZE = 32
|
||||||
|
# Parameter for RandomDomainSampler
|
||||||
|
# 0 or -1 means sampling from all domains
|
||||||
|
_C.DATALOADER.TRAIN_X.N_DOMAIN = 0
|
||||||
|
# Parameter of RandomClassSampler
|
||||||
|
# Number of instances per class
|
||||||
|
_C.DATALOADER.TRAIN_X.N_INS = 16
|
||||||
|
|
||||||
|
# Setting for the train_u data-loader
|
||||||
|
_C.DATALOADER.TRAIN_U = CN()
|
||||||
|
# Set to false if you want to have unique
|
||||||
|
# data loader params for train_u
|
||||||
|
_C.DATALOADER.TRAIN_U.SAME_AS_X = True
|
||||||
|
_C.DATALOADER.TRAIN_U.SAMPLER = "RandomSampler"
|
||||||
|
_C.DATALOADER.TRAIN_U.BATCH_SIZE = 32
|
||||||
|
_C.DATALOADER.TRAIN_U.N_DOMAIN = 0
|
||||||
|
_C.DATALOADER.TRAIN_U.N_INS = 16
|
||||||
|
|
||||||
|
# Setting for the test data-loader
|
||||||
|
_C.DATALOADER.TEST = CN()
|
||||||
|
_C.DATALOADER.TEST.SAMPLER = "SequentialSampler"
|
||||||
|
_C.DATALOADER.TEST.BATCH_SIZE = 32
|
||||||
|
|
||||||
|
###########################
|
||||||
|
# Model
|
||||||
|
###########################
|
||||||
|
_C.MODEL = CN()
|
||||||
|
# Path to model weights (for initialization)
|
||||||
|
_C.MODEL.INIT_WEIGHTS = ""
|
||||||
|
_C.MODEL.BACKBONE = CN()
|
||||||
|
_C.MODEL.BACKBONE.NAME = ""
|
||||||
|
_C.MODEL.BACKBONE.PRETRAINED = True
|
||||||
|
# Definition of embedding layers
|
||||||
|
_C.MODEL.HEAD = CN()
|
||||||
|
# If none, do not construct embedding layers, the
|
||||||
|
# backbone's output will be passed to the classifier
|
||||||
|
_C.MODEL.HEAD.NAME = ""
|
||||||
|
# Structure of hidden layers (a list), e.g. [512, 512]
|
||||||
|
# If undefined, no embedding layer will be constructed
|
||||||
|
_C.MODEL.HEAD.HIDDEN_LAYERS = ()
|
||||||
|
_C.MODEL.HEAD.ACTIVATION = "relu"
|
||||||
|
_C.MODEL.HEAD.BN = True
|
||||||
|
_C.MODEL.HEAD.DROPOUT = 0.0
|
||||||
|
|
||||||
|
###########################
|
||||||
|
# Optimization
|
||||||
|
###########################
|
||||||
|
_C.OPTIM = CN()
|
||||||
|
_C.OPTIM.NAME = "adam"
|
||||||
|
_C.OPTIM.LR = 0.0003
|
||||||
|
_C.OPTIM.WEIGHT_DECAY = 5e-4
|
||||||
|
_C.OPTIM.MOMENTUM = 0.9
|
||||||
|
_C.OPTIM.SGD_DAMPNING = 0
|
||||||
|
_C.OPTIM.SGD_NESTEROV = False
|
||||||
|
_C.OPTIM.RMSPROP_ALPHA = 0.99
|
||||||
|
_C.OPTIM.ADAM_BETA1 = 0.9
|
||||||
|
_C.OPTIM.ADAM_BETA2 = 0.999
|
||||||
|
# STAGED_LR allows different layers to have
|
||||||
|
# different lr, e.g. pre-trained base layers
|
||||||
|
# can be assigned a smaller lr than the new
|
||||||
|
# classification layer
|
||||||
|
_C.OPTIM.STAGED_LR = False
|
||||||
|
_C.OPTIM.NEW_LAYERS = ()
|
||||||
|
_C.OPTIM.BASE_LR_MULT = 0.1
|
||||||
|
# Learning rate scheduler
|
||||||
|
_C.OPTIM.LR_SCHEDULER = "single_step"
|
||||||
|
# -1 or 0 means the stepsize is equal to max_epoch
|
||||||
|
_C.OPTIM.STEPSIZE = (-1, )
|
||||||
|
_C.OPTIM.GAMMA = 0.1
|
||||||
|
_C.OPTIM.MAX_EPOCH = 10
|
||||||
|
# Set WARMUP_EPOCH larger than 0 to activate warmup training
|
||||||
|
_C.OPTIM.WARMUP_EPOCH = -1
|
||||||
|
# Either linear or constant
|
||||||
|
_C.OPTIM.WARMUP_TYPE = "linear"
|
||||||
|
# Constant learning rate when type=constant
|
||||||
|
_C.OPTIM.WARMUP_CONS_LR = 1e-5
|
||||||
|
# Minimum learning rate when type=linear
|
||||||
|
_C.OPTIM.WARMUP_MIN_LR = 1e-5
|
||||||
|
# Recount epoch for the next scheduler (last_epoch=-1)
|
||||||
|
# Otherwise last_epoch=warmup_epoch
|
||||||
|
_C.OPTIM.WARMUP_RECOUNT = True
|
||||||
|
|
||||||
|
###########################
|
||||||
|
# Train
|
||||||
|
###########################
|
||||||
|
_C.TRAIN = CN()
|
||||||
|
# How often (epoch) to save model during training
|
||||||
|
# Set to 0 or negative value to only save the last one
|
||||||
|
_C.TRAIN.CHECKPOINT_FREQ = 0
|
||||||
|
# How often (batch) to print training information
|
||||||
|
_C.TRAIN.PRINT_FREQ = 10
|
||||||
|
# Use 'train_x', 'train_u' or 'smaller_one' to count
|
||||||
|
# the number of iterations in an epoch (for DA and SSL)
|
||||||
|
_C.TRAIN.COUNT_ITER = "train_x"
|
||||||
|
|
||||||
|
###########################
|
||||||
|
# Test
|
||||||
|
###########################
|
||||||
|
_C.TEST = CN()
|
||||||
|
_C.TEST.EVALUATOR = "Classification"
|
||||||
|
_C.TEST.PER_CLASS_RESULT = False
|
||||||
|
# Compute confusion matrix, which will be saved
|
||||||
|
# to $OUTPUT_DIR/cmat.pt
|
||||||
|
_C.TEST.COMPUTE_CMAT = False
|
||||||
|
# If NO_TEST=True, no testing will be conducted
|
||||||
|
_C.TEST.NO_TEST = False
|
||||||
|
# Use test or val set for FINAL evaluation
|
||||||
|
_C.TEST.SPLIT = "test"
|
||||||
|
# Which model to test after training
|
||||||
|
# Either last_step or best_val
|
||||||
|
_C.TEST.FINAL_MODEL = "last_step"
|
||||||
|
|
||||||
|
###########################
|
||||||
|
# Trainer specifics
|
||||||
|
###########################
|
||||||
|
_C.TRAINER = CN()
|
||||||
|
_C.TRAINER.NAME = ""
|
||||||
|
|
||||||
|
# MCD
|
||||||
|
_C.TRAINER.MCD = CN()
|
||||||
|
_C.TRAINER.MCD.N_STEP_F = 4 # number of steps to train F
|
||||||
|
# MME
|
||||||
|
_C.TRAINER.MME = CN()
|
||||||
|
_C.TRAINER.MME.LMDA = 0.1 # weight for the entropy loss
|
||||||
|
# SelfEnsembling
|
||||||
|
_C.TRAINER.SE = CN()
|
||||||
|
_C.TRAINER.SE.EMA_ALPHA = 0.999
|
||||||
|
_C.TRAINER.SE.CONF_THRE = 0.95
|
||||||
|
_C.TRAINER.SE.RAMPUP = 300
|
||||||
|
|
||||||
|
# M3SDA
|
||||||
|
_C.TRAINER.M3SDA = CN()
|
||||||
|
_C.TRAINER.M3SDA.LMDA = 0.5 # weight for the moment distance loss
|
||||||
|
_C.TRAINER.M3SDA.N_STEP_F = 4 # follow MCD
|
||||||
|
# DAEL
|
||||||
|
_C.TRAINER.DAEL = CN()
|
||||||
|
_C.TRAINER.DAEL.WEIGHT_U = 0.5 # weight on the unlabeled loss
|
||||||
|
_C.TRAINER.DAEL.CONF_THRE = 0.95 # confidence threshold
|
||||||
|
_C.TRAINER.DAEL.STRONG_TRANSFORMS = ()
|
||||||
|
|
||||||
|
# CrossGrad
|
||||||
|
_C.TRAINER.CG = CN()
|
||||||
|
_C.TRAINER.CG.EPS_F = 1.0 # scaling parameter for D's gradients
|
||||||
|
_C.TRAINER.CG.EPS_D = 1.0 # scaling parameter for F's gradients
|
||||||
|
_C.TRAINER.CG.ALPHA_F = 0.5 # balancing weight for the label net's loss
|
||||||
|
_C.TRAINER.CG.ALPHA_D = 0.5 # balancing weight for the domain net's loss
|
||||||
|
# DDAIG
|
||||||
|
_C.TRAINER.DDAIG = CN()
|
||||||
|
_C.TRAINER.DDAIG.G_ARCH = "" # generator's architecture
|
||||||
|
_C.TRAINER.DDAIG.LMDA = 0.3 # perturbation weight
|
||||||
|
_C.TRAINER.DDAIG.CLAMP = False # clamp perturbation values
|
||||||
|
_C.TRAINER.DDAIG.CLAMP_MIN = -1.0
|
||||||
|
_C.TRAINER.DDAIG.CLAMP_MAX = 1.0
|
||||||
|
_C.TRAINER.DDAIG.WARMUP = 0
|
||||||
|
_C.TRAINER.DDAIG.ALPHA = 0.5 # balancing weight for the losses
|
||||||
|
|
||||||
|
# EntMin
|
||||||
|
_C.TRAINER.ENTMIN = CN()
|
||||||
|
_C.TRAINER.ENTMIN.LMDA = 1e-3 # weight on the entropy loss
|
||||||
|
# Mean Teacher
|
||||||
|
_C.TRAINER.MEANTEA = CN()
|
||||||
|
_C.TRAINER.MEANTEA.WEIGHT_U = 1.0 # weight on the unlabeled loss
|
||||||
|
_C.TRAINER.MEANTEA.EMA_ALPHA = 0.999
|
||||||
|
_C.TRAINER.MEANTEA.RAMPUP = 5 # epochs used to ramp up the loss_u weight
|
||||||
|
# MixMatch
|
||||||
|
_C.TRAINER.MIXMATCH = CN()
|
||||||
|
_C.TRAINER.MIXMATCH.WEIGHT_U = 100.0 # weight on the unlabeled loss
|
||||||
|
_C.TRAINER.MIXMATCH.TEMP = 2.0 # temperature for sharpening the probability
|
||||||
|
_C.TRAINER.MIXMATCH.MIXUP_BETA = 0.75
|
||||||
|
_C.TRAINER.MIXMATCH.RAMPUP = 20000 # steps used to ramp up the loss_u weight
|
||||||
|
# FixMatch
|
||||||
|
_C.TRAINER.FIXMATCH = CN()
|
||||||
|
_C.TRAINER.FIXMATCH.WEIGHT_U = 1.0 # weight on the unlabeled loss
|
||||||
|
_C.TRAINER.FIXMATCH.CONF_THRE = 0.95 # confidence threshold
|
||||||
|
_C.TRAINER.FIXMATCH.STRONG_TRANSFORMS = ()
|
||||||
1
Dassl.ProGrad.pytorch/dassl/data/__init__.py
Normal file
1
Dassl.ProGrad.pytorch/dassl/data/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .data_manager import DataManager, DatasetWrapper
|
||||||
264
Dassl.ProGrad.pytorch/dassl/data/data_manager.py
Normal file
264
Dassl.ProGrad.pytorch/dassl/data/data_manager.py
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
import torch
|
||||||
|
import torchvision.transforms as T
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import Dataset as TorchDataset
|
||||||
|
|
||||||
|
from dassl.utils import read_image
|
||||||
|
|
||||||
|
from .datasets import build_dataset
|
||||||
|
from .samplers import build_sampler
|
||||||
|
from .transforms import build_transform
|
||||||
|
|
||||||
|
INTERPOLATION_MODES = {
|
||||||
|
"bilinear": Image.BILINEAR,
|
||||||
|
"bicubic": Image.BICUBIC,
|
||||||
|
"nearest": Image.NEAREST,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def build_data_loader(
|
||||||
|
cfg,
|
||||||
|
sampler_type="SequentialSampler",
|
||||||
|
data_source=None,
|
||||||
|
batch_size=64,
|
||||||
|
n_domain=0,
|
||||||
|
n_ins=2,
|
||||||
|
tfm=None,
|
||||||
|
is_train=True,
|
||||||
|
dataset_wrapper=None,
|
||||||
|
):
|
||||||
|
# Build sampler
|
||||||
|
sampler = build_sampler(
|
||||||
|
sampler_type,
|
||||||
|
cfg=cfg,
|
||||||
|
data_source=data_source,
|
||||||
|
batch_size=batch_size,
|
||||||
|
n_domain=n_domain,
|
||||||
|
n_ins=n_ins,
|
||||||
|
)
|
||||||
|
|
||||||
|
if dataset_wrapper is None:
|
||||||
|
dataset_wrapper = DatasetWrapper
|
||||||
|
|
||||||
|
# Build data loader
|
||||||
|
data_loader = torch.utils.data.DataLoader(
|
||||||
|
dataset_wrapper(cfg, data_source, transform=tfm, is_train=is_train),
|
||||||
|
batch_size=batch_size,
|
||||||
|
sampler=sampler,
|
||||||
|
num_workers=cfg.DATALOADER.NUM_WORKERS,
|
||||||
|
drop_last=is_train and len(data_source) >= batch_size,
|
||||||
|
pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA),
|
||||||
|
)
|
||||||
|
assert len(data_loader) > 0
|
||||||
|
|
||||||
|
return data_loader
|
||||||
|
|
||||||
|
|
||||||
|
class DataManager:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cfg,
|
||||||
|
custom_tfm_train=None,
|
||||||
|
custom_tfm_test=None,
|
||||||
|
dataset_wrapper=None
|
||||||
|
):
|
||||||
|
# Load dataset
|
||||||
|
dataset = build_dataset(cfg)
|
||||||
|
# Build transform
|
||||||
|
if custom_tfm_train is None:
|
||||||
|
tfm_train = build_transform(cfg, is_train=True)
|
||||||
|
else:
|
||||||
|
print("* Using custom transform for training")
|
||||||
|
tfm_train = custom_tfm_train
|
||||||
|
|
||||||
|
if custom_tfm_test is None:
|
||||||
|
tfm_test = build_transform(cfg, is_train=False)
|
||||||
|
else:
|
||||||
|
print("* Using custom transform for testing")
|
||||||
|
tfm_test = custom_tfm_test
|
||||||
|
|
||||||
|
# Build train_loader_x
|
||||||
|
train_loader_x = build_data_loader(
|
||||||
|
cfg,
|
||||||
|
sampler_type=cfg.DATALOADER.TRAIN_X.SAMPLER,
|
||||||
|
data_source=dataset.train_x,
|
||||||
|
batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE,
|
||||||
|
n_domain=cfg.DATALOADER.TRAIN_X.N_DOMAIN,
|
||||||
|
n_ins=cfg.DATALOADER.TRAIN_X.N_INS,
|
||||||
|
tfm=tfm_train,
|
||||||
|
is_train=True,
|
||||||
|
dataset_wrapper=dataset_wrapper,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build train_loader_u
|
||||||
|
train_loader_u = None
|
||||||
|
if dataset.train_u:
|
||||||
|
sampler_type_ = cfg.DATALOADER.TRAIN_U.SAMPLER
|
||||||
|
batch_size_ = cfg.DATALOADER.TRAIN_U.BATCH_SIZE
|
||||||
|
n_domain_ = cfg.DATALOADER.TRAIN_U.N_DOMAIN
|
||||||
|
n_ins_ = cfg.DATALOADER.TRAIN_U.N_INS
|
||||||
|
|
||||||
|
if cfg.DATALOADER.TRAIN_U.SAME_AS_X:
|
||||||
|
sampler_type_ = cfg.DATALOADER.TRAIN_X.SAMPLER
|
||||||
|
batch_size_ = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
|
||||||
|
n_domain_ = cfg.DATALOADER.TRAIN_X.N_DOMAIN
|
||||||
|
n_ins_ = cfg.DATALOADER.TRAIN_X.N_INS
|
||||||
|
|
||||||
|
train_loader_u = build_data_loader(
|
||||||
|
cfg,
|
||||||
|
sampler_type=sampler_type_,
|
||||||
|
data_source=dataset.train_u,
|
||||||
|
batch_size=batch_size_,
|
||||||
|
n_domain=n_domain_,
|
||||||
|
n_ins=n_ins_,
|
||||||
|
tfm=tfm_train,
|
||||||
|
is_train=True,
|
||||||
|
dataset_wrapper=dataset_wrapper,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build val_loader
|
||||||
|
val_loader = None
|
||||||
|
if dataset.val:
|
||||||
|
val_loader = build_data_loader(
|
||||||
|
cfg,
|
||||||
|
sampler_type=cfg.DATALOADER.TEST.SAMPLER,
|
||||||
|
data_source=dataset.val,
|
||||||
|
batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
|
||||||
|
tfm=tfm_test,
|
||||||
|
is_train=False,
|
||||||
|
dataset_wrapper=dataset_wrapper,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build test_loader
|
||||||
|
test_loader = build_data_loader(
|
||||||
|
cfg,
|
||||||
|
sampler_type=cfg.DATALOADER.TEST.SAMPLER,
|
||||||
|
data_source=dataset.test,
|
||||||
|
batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
|
||||||
|
tfm=tfm_test,
|
||||||
|
is_train=False,
|
||||||
|
dataset_wrapper=dataset_wrapper,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Attributes
|
||||||
|
self._num_classes = dataset.num_classes
|
||||||
|
self._num_source_domains = len(cfg.DATASET.SOURCE_DOMAINS)
|
||||||
|
self._lab2cname = dataset.lab2cname
|
||||||
|
|
||||||
|
# Dataset and data-loaders
|
||||||
|
self.dataset = dataset
|
||||||
|
self.train_loader_x = train_loader_x
|
||||||
|
self.train_loader_u = train_loader_u
|
||||||
|
self.val_loader = val_loader
|
||||||
|
self.test_loader = test_loader
|
||||||
|
|
||||||
|
if cfg.VERBOSE:
|
||||||
|
self.show_dataset_summary(cfg)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_classes(self):
|
||||||
|
return self._num_classes
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_source_domains(self):
|
||||||
|
return self._num_source_domains
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lab2cname(self):
|
||||||
|
return self._lab2cname
|
||||||
|
|
||||||
|
def show_dataset_summary(self, cfg):
|
||||||
|
print("***** Dataset statistics *****")
|
||||||
|
|
||||||
|
print(" Dataset: {}".format(cfg.DATASET.NAME))
|
||||||
|
|
||||||
|
if cfg.DATASET.SOURCE_DOMAINS:
|
||||||
|
print(" Source domains: {}".format(cfg.DATASET.SOURCE_DOMAINS))
|
||||||
|
if cfg.DATASET.TARGET_DOMAINS:
|
||||||
|
print(" Target domains: {}".format(cfg.DATASET.TARGET_DOMAINS))
|
||||||
|
|
||||||
|
print(" # classes: {:,}".format(self.num_classes))
|
||||||
|
|
||||||
|
print(" # train_x: {:,}".format(len(self.dataset.train_x)))
|
||||||
|
|
||||||
|
if self.dataset.train_u:
|
||||||
|
print(" # train_u: {:,}".format(len(self.dataset.train_u)))
|
||||||
|
|
||||||
|
if self.dataset.val:
|
||||||
|
print(" # val: {:,}".format(len(self.dataset.val)))
|
||||||
|
|
||||||
|
print(" # test: {:,}".format(len(self.dataset.test)))
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetWrapper(TorchDataset):
|
||||||
|
|
||||||
|
def __init__(self, cfg, data_source, transform=None, is_train=False):
|
||||||
|
self.cfg = cfg
|
||||||
|
self.data_source = data_source
|
||||||
|
self.transform = transform # accept list (tuple) as input
|
||||||
|
self.is_train = is_train
|
||||||
|
# Augmenting an image K>1 times is only allowed during training
|
||||||
|
self.k_tfm = cfg.DATALOADER.K_TRANSFORMS if is_train else 1
|
||||||
|
self.return_img0 = cfg.DATALOADER.RETURN_IMG0
|
||||||
|
|
||||||
|
if self.k_tfm > 1 and transform is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot augment the image {} times "
|
||||||
|
"because transform is None".format(self.k_tfm)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build transform that doesn't apply any data augmentation
|
||||||
|
interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]
|
||||||
|
to_tensor = []
|
||||||
|
to_tensor += [T.Resize(cfg.INPUT.SIZE, interpolation=interp_mode)]
|
||||||
|
to_tensor += [T.ToTensor()]
|
||||||
|
if "normalize" in cfg.INPUT.TRANSFORMS:
|
||||||
|
normalize = T.Normalize(
|
||||||
|
mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD
|
||||||
|
)
|
||||||
|
to_tensor += [normalize]
|
||||||
|
self.to_tensor = T.Compose(to_tensor)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data_source)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
item = self.data_source[idx]
|
||||||
|
|
||||||
|
output = {
|
||||||
|
"label": item.label,
|
||||||
|
"domain": item.domain,
|
||||||
|
"impath": item.impath
|
||||||
|
}
|
||||||
|
|
||||||
|
img0 = read_image(item.impath)
|
||||||
|
|
||||||
|
if self.transform is not None:
|
||||||
|
if isinstance(self.transform, (list, tuple)):
|
||||||
|
for i, tfm in enumerate(self.transform):
|
||||||
|
img = self._transform_image(tfm, img0)
|
||||||
|
keyname = "img"
|
||||||
|
if (i + 1) > 1:
|
||||||
|
keyname += str(i + 1)
|
||||||
|
output[keyname] = img
|
||||||
|
else:
|
||||||
|
img = self._transform_image(self.transform, img0)
|
||||||
|
output["img"] = img
|
||||||
|
|
||||||
|
if self.return_img0:
|
||||||
|
output["img0"] = self.to_tensor(img0)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def _transform_image(self, tfm, img0):
|
||||||
|
img_list = []
|
||||||
|
|
||||||
|
for k in range(self.k_tfm):
|
||||||
|
img_list.append(tfm(img0))
|
||||||
|
|
||||||
|
img = img_list
|
||||||
|
if len(img) == 1:
|
||||||
|
img = img[0]
|
||||||
|
|
||||||
|
return img
|
||||||
6
Dassl.ProGrad.pytorch/dassl/data/datasets/__init__.py
Normal file
6
Dassl.ProGrad.pytorch/dassl/data/datasets/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from .build import DATASET_REGISTRY, build_dataset # isort:skip
|
||||||
|
from .base_dataset import Datum, DatasetBase # isort:skip
|
||||||
|
|
||||||
|
from .da import *
|
||||||
|
from .dg import *
|
||||||
|
from .ssl import *
|
||||||
225
Dassl.ProGrad.pytorch/dassl/data/datasets/base_dataset.py
Normal file
225
Dassl.ProGrad.pytorch/dassl/data/datasets/base_dataset.py
Normal file
@@ -0,0 +1,225 @@
|
|||||||
|
import os
|
||||||
|
import random
|
||||||
|
import os.path as osp
|
||||||
|
import tarfile
|
||||||
|
import zipfile
|
||||||
|
from collections import defaultdict
|
||||||
|
import gdown
|
||||||
|
|
||||||
|
from dassl.utils import check_isfile
|
||||||
|
|
||||||
|
|
||||||
|
class Datum:
|
||||||
|
"""Data instance which defines the basic attributes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
impath (str): image path.
|
||||||
|
label (int): class label.
|
||||||
|
domain (int): domain label.
|
||||||
|
classname (str): class name.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, impath="", label=0, domain=0, classname=""):
|
||||||
|
assert isinstance(impath, str)
|
||||||
|
assert check_isfile(impath)
|
||||||
|
|
||||||
|
self._impath = impath
|
||||||
|
self._label = label
|
||||||
|
self._domain = domain
|
||||||
|
self._classname = classname
|
||||||
|
|
||||||
|
@property
|
||||||
|
def impath(self):
|
||||||
|
return self._impath
|
||||||
|
|
||||||
|
@property
|
||||||
|
def label(self):
|
||||||
|
return self._label
|
||||||
|
|
||||||
|
@property
|
||||||
|
def domain(self):
|
||||||
|
return self._domain
|
||||||
|
|
||||||
|
@property
|
||||||
|
def classname(self):
|
||||||
|
return self._classname
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetBase:
|
||||||
|
"""A unified dataset class for
|
||||||
|
1) domain adaptation
|
||||||
|
2) domain generalization
|
||||||
|
3) semi-supervised learning
|
||||||
|
"""
|
||||||
|
|
||||||
|
dataset_dir = "" # the directory where the dataset is stored
|
||||||
|
domains = [] # string names of all domains
|
||||||
|
|
||||||
|
def __init__(self, train_x=None, train_u=None, val=None, test=None):
|
||||||
|
self._train_x = train_x # labeled training data
|
||||||
|
self._train_u = train_u # unlabeled training data (optional)
|
||||||
|
self._val = val # validation data (optional)
|
||||||
|
self._test = test # test data
|
||||||
|
|
||||||
|
self._num_classes = self.get_num_classes(train_x)
|
||||||
|
self._lab2cname, self._classnames = self.get_lab2cname(train_x)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def train_x(self):
|
||||||
|
return self._train_x
|
||||||
|
|
||||||
|
@property
|
||||||
|
def train_u(self):
|
||||||
|
return self._train_u
|
||||||
|
|
||||||
|
@property
|
||||||
|
def val(self):
|
||||||
|
return self._val
|
||||||
|
|
||||||
|
@property
|
||||||
|
def test(self):
|
||||||
|
return self._test
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lab2cname(self):
|
||||||
|
return self._lab2cname
|
||||||
|
|
||||||
|
@property
|
||||||
|
def classnames(self):
|
||||||
|
return self._classnames
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_classes(self):
|
||||||
|
return self._num_classes
|
||||||
|
|
||||||
|
def get_num_classes(self, data_source):
|
||||||
|
"""Count number of classes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_source (list): a list of Datum objects.
|
||||||
|
"""
|
||||||
|
label_set = set()
|
||||||
|
for item in data_source:
|
||||||
|
label_set.add(item.label)
|
||||||
|
return max(label_set) + 1
|
||||||
|
|
||||||
|
def get_lab2cname(self, data_source):
|
||||||
|
"""Get a label-to-classname mapping (dict).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_source (list): a list of Datum objects.
|
||||||
|
"""
|
||||||
|
container = set()
|
||||||
|
for item in data_source:
|
||||||
|
container.add((item.label, item.classname))
|
||||||
|
mapping = {label: classname for label, classname in container}
|
||||||
|
labels = list(mapping.keys())
|
||||||
|
labels.sort()
|
||||||
|
classnames = [mapping[label] for label in labels]
|
||||||
|
return mapping, classnames
|
||||||
|
|
||||||
|
def check_input_domains(self, source_domains, target_domains):
|
||||||
|
self.is_input_domain_valid(source_domains)
|
||||||
|
self.is_input_domain_valid(target_domains)
|
||||||
|
|
||||||
|
def is_input_domain_valid(self, input_domains):
|
||||||
|
for domain in input_domains:
|
||||||
|
if domain not in self.domains:
|
||||||
|
raise ValueError(
|
||||||
|
"Input domain must belong to {}, "
|
||||||
|
"but got [{}]".format(self.domains, domain)
|
||||||
|
)
|
||||||
|
|
||||||
|
def download_data(self, url, dst, from_gdrive=True):
|
||||||
|
if not osp.exists(osp.dirname(dst)):
|
||||||
|
os.makedirs(osp.dirname(dst))
|
||||||
|
|
||||||
|
if from_gdrive:
|
||||||
|
gdown.download(url, dst, quiet=False)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
print("Extracting file ...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
tar = tarfile.open(dst)
|
||||||
|
tar.extractall(path=osp.dirname(dst))
|
||||||
|
tar.close()
|
||||||
|
except:
|
||||||
|
zip_ref = zipfile.ZipFile(dst, "r")
|
||||||
|
zip_ref.extractall(osp.dirname(dst))
|
||||||
|
zip_ref.close()
|
||||||
|
|
||||||
|
print("File extracted to {}".format(osp.dirname(dst)))
|
||||||
|
|
||||||
|
def generate_fewshot_dataset(
|
||||||
|
self, *data_sources, num_shots=-1, repeat=False
|
||||||
|
):
|
||||||
|
"""Generate a few-shot dataset (typically for the training set).
|
||||||
|
|
||||||
|
This function is useful when one wants to evaluate a model
|
||||||
|
in a few-shot learning setting where each class only contains
|
||||||
|
a few number of images.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_sources: each individual is a list containing Datum objects.
|
||||||
|
num_shots (int): number of instances per class to sample.
|
||||||
|
repeat (bool): repeat images if needed (default: False).
|
||||||
|
"""
|
||||||
|
if num_shots < 1:
|
||||||
|
if len(data_sources) == 1:
|
||||||
|
return data_sources[0]
|
||||||
|
return data_sources
|
||||||
|
|
||||||
|
print(f"Creating a {num_shots}-shot dataset")
|
||||||
|
|
||||||
|
output = []
|
||||||
|
|
||||||
|
for data_source in data_sources:
|
||||||
|
tracker = self.split_dataset_by_label(data_source)
|
||||||
|
dataset = []
|
||||||
|
|
||||||
|
for label, items in tracker.items():
|
||||||
|
if len(items) >= num_shots:
|
||||||
|
sampled_items = random.sample(items, num_shots)
|
||||||
|
else:
|
||||||
|
if repeat:
|
||||||
|
sampled_items = random.choices(items, k=num_shots)
|
||||||
|
else:
|
||||||
|
sampled_items = items
|
||||||
|
dataset.extend(sampled_items)
|
||||||
|
|
||||||
|
output.append(dataset)
|
||||||
|
|
||||||
|
if len(output) == 1:
|
||||||
|
return output[0]
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def split_dataset_by_label(self, data_source):
|
||||||
|
"""Split a dataset, i.e. a list of Datum objects,
|
||||||
|
into class-specific groups stored in a dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_source (list): a list of Datum objects.
|
||||||
|
"""
|
||||||
|
output = defaultdict(list)
|
||||||
|
|
||||||
|
for item in data_source:
|
||||||
|
output[item.label].append(item)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def split_dataset_by_domain(self, data_source):
|
||||||
|
"""Split a dataset, i.e. a list of Datum objects,
|
||||||
|
into domain-specific groups stored in a dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_source (list): a list of Datum objects.
|
||||||
|
"""
|
||||||
|
output = defaultdict(list)
|
||||||
|
|
||||||
|
for item in data_source:
|
||||||
|
output[item.domain].append(item)
|
||||||
|
|
||||||
|
return output
|
||||||
11
Dassl.ProGrad.pytorch/dassl/data/datasets/build.py
Normal file
11
Dassl.ProGrad.pytorch/dassl/data/datasets/build.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
from dassl.utils import Registry, check_availability
|
||||||
|
|
||||||
|
DATASET_REGISTRY = Registry("DATASET")
|
||||||
|
|
||||||
|
|
||||||
|
def build_dataset(cfg):
|
||||||
|
avai_datasets = DATASET_REGISTRY.registered_names()
|
||||||
|
check_availability(cfg.DATASET.NAME, avai_datasets)
|
||||||
|
if cfg.VERBOSE:
|
||||||
|
print("Loading dataset: {}".format(cfg.DATASET.NAME))
|
||||||
|
return DATASET_REGISTRY.get(cfg.DATASET.NAME)(cfg)
|
||||||
7
Dassl.ProGrad.pytorch/dassl/data/datasets/da/__init__.py
Normal file
7
Dassl.ProGrad.pytorch/dassl/data/datasets/da/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from .digit5 import Digit5
|
||||||
|
from .visda17 import VisDA17
|
||||||
|
from .cifarstl import CIFARSTL
|
||||||
|
from .office31 import Office31
|
||||||
|
from .domainnet import DomainNet
|
||||||
|
from .office_home import OfficeHome
|
||||||
|
from .mini_domainnet import miniDomainNet
|
||||||
68
Dassl.ProGrad.pytorch/dassl/data/datasets/da/cifarstl.py
Normal file
68
Dassl.ProGrad.pytorch/dassl/data/datasets/da/cifarstl.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
import os.path as osp
|
||||||
|
|
||||||
|
from dassl.utils import listdir_nohidden
|
||||||
|
|
||||||
|
from ..build import DATASET_REGISTRY
|
||||||
|
from ..base_dataset import Datum, DatasetBase
|
||||||
|
|
||||||
|
|
||||||
|
@DATASET_REGISTRY.register()
|
||||||
|
class CIFARSTL(DatasetBase):
|
||||||
|
"""CIFAR-10 and STL-10.
|
||||||
|
|
||||||
|
CIFAR-10:
|
||||||
|
- 60,000 32x32 colour images.
|
||||||
|
- 10 classes, with 6,000 images per class.
|
||||||
|
- 50,000 training images and 10,000 test images.
|
||||||
|
- URL: https://www.cs.toronto.edu/~kriz/cifar.html.
|
||||||
|
|
||||||
|
STL-10:
|
||||||
|
- 10 classes: airplane, bird, car, cat, deer, dog, horse,
|
||||||
|
monkey, ship, truck.
|
||||||
|
- Images are 96x96 pixels, color.
|
||||||
|
- 500 training images (10 pre-defined folds), 800 test images
|
||||||
|
per class.
|
||||||
|
- URL: https://cs.stanford.edu/~acoates/stl10/.
|
||||||
|
|
||||||
|
Reference:
|
||||||
|
- Krizhevsky. Learning Multiple Layers of Features
|
||||||
|
from Tiny Images. Tech report.
|
||||||
|
- Coates et al. An Analysis of Single Layer Networks in
|
||||||
|
Unsupervised Feature Learning. AISTATS 2011.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dataset_dir = "cifar_stl"
|
||||||
|
domains = ["cifar", "stl"]
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
|
||||||
|
self.dataset_dir = osp.join(root, self.dataset_dir)
|
||||||
|
|
||||||
|
self.check_input_domains(
|
||||||
|
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
|
||||||
|
)
|
||||||
|
|
||||||
|
train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
|
||||||
|
train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
|
||||||
|
test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")
|
||||||
|
|
||||||
|
super().__init__(train_x=train_x, train_u=train_u, test=test)
|
||||||
|
|
||||||
|
def _read_data(self, input_domains, split="train"):
|
||||||
|
items = []
|
||||||
|
|
||||||
|
for domain, dname in enumerate(input_domains):
|
||||||
|
data_dir = osp.join(self.dataset_dir, dname, split)
|
||||||
|
class_names = listdir_nohidden(data_dir)
|
||||||
|
|
||||||
|
for class_name in class_names:
|
||||||
|
class_dir = osp.join(data_dir, class_name)
|
||||||
|
imnames = listdir_nohidden(class_dir)
|
||||||
|
label = int(class_name.split("_")[0])
|
||||||
|
|
||||||
|
for imname in imnames:
|
||||||
|
impath = osp.join(class_dir, imname)
|
||||||
|
item = Datum(impath=impath, label=label, domain=domain)
|
||||||
|
items.append(item)
|
||||||
|
|
||||||
|
return items
|
||||||
124
Dassl.ProGrad.pytorch/dassl/data/datasets/da/digit5.py
Normal file
124
Dassl.ProGrad.pytorch/dassl/data/datasets/da/digit5.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
import random
|
||||||
|
import os.path as osp
|
||||||
|
|
||||||
|
from dassl.utils import listdir_nohidden
|
||||||
|
|
||||||
|
from ..build import DATASET_REGISTRY
|
||||||
|
from ..base_dataset import Datum, DatasetBase
|
||||||
|
|
||||||
|
# Folder names for train and test sets
|
||||||
|
MNIST = {"train": "train_images", "test": "test_images"}
|
||||||
|
MNIST_M = {"train": "train_images", "test": "test_images"}
|
||||||
|
SVHN = {"train": "train_images", "test": "test_images"}
|
||||||
|
SYN = {"train": "train_images", "test": "test_images"}
|
||||||
|
USPS = {"train": "train_images", "test": "test_images"}
|
||||||
|
|
||||||
|
|
||||||
|
def read_image_list(im_dir, n_max=None, n_repeat=None):
|
||||||
|
items = []
|
||||||
|
|
||||||
|
for imname in listdir_nohidden(im_dir):
|
||||||
|
imname_noext = osp.splitext(imname)[0]
|
||||||
|
label = int(imname_noext.split("_")[1])
|
||||||
|
impath = osp.join(im_dir, imname)
|
||||||
|
items.append((impath, label))
|
||||||
|
|
||||||
|
if n_max is not None:
|
||||||
|
items = random.sample(items, n_max)
|
||||||
|
|
||||||
|
if n_repeat is not None:
|
||||||
|
items *= n_repeat
|
||||||
|
|
||||||
|
return items
|
||||||
|
|
||||||
|
|
||||||
|
def load_mnist(dataset_dir, split="train"):
|
||||||
|
data_dir = osp.join(dataset_dir, MNIST[split])
|
||||||
|
n_max = 25000 if split == "train" else 9000
|
||||||
|
return read_image_list(data_dir, n_max=n_max)
|
||||||
|
|
||||||
|
|
||||||
|
def load_mnist_m(dataset_dir, split="train"):
|
||||||
|
data_dir = osp.join(dataset_dir, MNIST_M[split])
|
||||||
|
n_max = 25000 if split == "train" else 9000
|
||||||
|
return read_image_list(data_dir, n_max=n_max)
|
||||||
|
|
||||||
|
|
||||||
|
def load_svhn(dataset_dir, split="train"):
|
||||||
|
data_dir = osp.join(dataset_dir, SVHN[split])
|
||||||
|
n_max = 25000 if split == "train" else 9000
|
||||||
|
return read_image_list(data_dir, n_max=n_max)
|
||||||
|
|
||||||
|
|
||||||
|
def load_syn(dataset_dir, split="train"):
|
||||||
|
data_dir = osp.join(dataset_dir, SYN[split])
|
||||||
|
n_max = 25000 if split == "train" else 9000
|
||||||
|
return read_image_list(data_dir, n_max=n_max)
|
||||||
|
|
||||||
|
|
||||||
|
def load_usps(dataset_dir, split="train"):
|
||||||
|
data_dir = osp.join(dataset_dir, USPS[split])
|
||||||
|
n_repeat = 3 if split == "train" else None
|
||||||
|
return read_image_list(data_dir, n_repeat=n_repeat)
|
||||||
|
|
||||||
|
|
||||||
|
@DATASET_REGISTRY.register()
|
||||||
|
class Digit5(DatasetBase):
|
||||||
|
"""Five digit datasets.
|
||||||
|
|
||||||
|
It contains:
|
||||||
|
- MNIST: hand-written digits.
|
||||||
|
- MNIST-M: variant of MNIST with blended background.
|
||||||
|
- SVHN: street view house number.
|
||||||
|
- SYN: synthetic digits.
|
||||||
|
- USPS: hand-written digits, slightly different from MNIST.
|
||||||
|
|
||||||
|
For MNIST, MNIST-M, SVHN and SYN, we randomly sample 25,000 images from
|
||||||
|
the training set and 9,000 images from the test set. For USPS which has only
|
||||||
|
9,298 images in total, we use the entire dataset but replicate its training
|
||||||
|
set for 3 times so as to match the training set size of other domains.
|
||||||
|
|
||||||
|
Reference:
|
||||||
|
- Lecun et al. Gradient-based learning applied to document
|
||||||
|
recognition. IEEE 1998.
|
||||||
|
- Ganin et al. Domain-adversarial training of neural networks.
|
||||||
|
JMLR 2016.
|
||||||
|
- Netzer et al. Reading digits in natural images with unsupervised
|
||||||
|
feature learning. NIPS-W 2011.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dataset_dir = "digit5"
|
||||||
|
domains = ["mnist", "mnist_m", "svhn", "syn", "usps"]
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
|
||||||
|
self.dataset_dir = osp.join(root, self.dataset_dir)
|
||||||
|
|
||||||
|
self.check_input_domains(
|
||||||
|
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
|
||||||
|
)
|
||||||
|
|
||||||
|
train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
|
||||||
|
train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
|
||||||
|
test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")
|
||||||
|
|
||||||
|
super().__init__(train_x=train_x, train_u=train_u, test=test)
|
||||||
|
|
||||||
|
def _read_data(self, input_domains, split="train"):
|
||||||
|
items = []
|
||||||
|
|
||||||
|
for domain, dname in enumerate(input_domains):
|
||||||
|
func = "load_" + dname
|
||||||
|
domain_dir = osp.join(self.dataset_dir, dname)
|
||||||
|
items_d = eval(func)(domain_dir, split=split)
|
||||||
|
|
||||||
|
for impath, label in items_d:
|
||||||
|
item = Datum(
|
||||||
|
impath=impath,
|
||||||
|
label=label,
|
||||||
|
domain=domain,
|
||||||
|
classname=str(label)
|
||||||
|
)
|
||||||
|
items.append(item)
|
||||||
|
|
||||||
|
return items
|
||||||
69
Dassl.ProGrad.pytorch/dassl/data/datasets/da/domainnet.py
Normal file
69
Dassl.ProGrad.pytorch/dassl/data/datasets/da/domainnet.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
import os.path as osp
|
||||||
|
|
||||||
|
from ..build import DATASET_REGISTRY
|
||||||
|
from ..base_dataset import Datum, DatasetBase
|
||||||
|
|
||||||
|
|
||||||
|
@DATASET_REGISTRY.register()
|
||||||
|
class DomainNet(DatasetBase):
|
||||||
|
"""DomainNet.
|
||||||
|
|
||||||
|
Statistics:
|
||||||
|
- 6 distinct domains: Clipart, Infograph, Painting, Quickdraw,
|
||||||
|
Real, Sketch.
|
||||||
|
- Around 0.6M images.
|
||||||
|
- 345 categories.
|
||||||
|
- URL: http://ai.bu.edu/M3SDA/.
|
||||||
|
|
||||||
|
Special note: the t-shirt class (327) is missing in painting_train.txt.
|
||||||
|
|
||||||
|
Reference:
|
||||||
|
- Peng et al. Moment Matching for Multi-Source Domain
|
||||||
|
Adaptation. ICCV 2019.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dataset_dir = "domainnet"
|
||||||
|
domains = [
|
||||||
|
"clipart", "infograph", "painting", "quickdraw", "real", "sketch"
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
|
||||||
|
self.dataset_dir = osp.join(root, self.dataset_dir)
|
||||||
|
self.split_dir = osp.join(self.dataset_dir, "splits")
|
||||||
|
|
||||||
|
self.check_input_domains(
|
||||||
|
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
|
||||||
|
)
|
||||||
|
|
||||||
|
train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
|
||||||
|
train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
|
||||||
|
val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="test")
|
||||||
|
test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")
|
||||||
|
|
||||||
|
super().__init__(train_x=train_x, train_u=train_u, val=val, test=test)
|
||||||
|
|
||||||
|
def _read_data(self, input_domains, split="train"):
|
||||||
|
items = []
|
||||||
|
|
||||||
|
for domain, dname in enumerate(input_domains):
|
||||||
|
filename = dname + "_" + split + ".txt"
|
||||||
|
split_file = osp.join(self.split_dir, filename)
|
||||||
|
|
||||||
|
with open(split_file, "r") as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
for line in lines:
|
||||||
|
line = line.strip()
|
||||||
|
impath, label = line.split(" ")
|
||||||
|
classname = impath.split("/")[1]
|
||||||
|
impath = osp.join(self.dataset_dir, impath)
|
||||||
|
label = int(label)
|
||||||
|
item = Datum(
|
||||||
|
impath=impath,
|
||||||
|
label=label,
|
||||||
|
domain=domain,
|
||||||
|
classname=classname
|
||||||
|
)
|
||||||
|
items.append(item)
|
||||||
|
|
||||||
|
return items
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
import os.path as osp
|
||||||
|
|
||||||
|
from ..build import DATASET_REGISTRY
|
||||||
|
from ..base_dataset import Datum, DatasetBase
|
||||||
|
|
||||||
|
|
||||||
|
@DATASET_REGISTRY.register()
|
||||||
|
class miniDomainNet(DatasetBase):
|
||||||
|
"""A subset of DomainNet.
|
||||||
|
|
||||||
|
Reference:
|
||||||
|
- Peng et al. Moment Matching for Multi-Source Domain
|
||||||
|
Adaptation. ICCV 2019.
|
||||||
|
- Zhou et al. Domain Adaptive Ensemble Learning.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dataset_dir = "domainnet"
|
||||||
|
domains = ["clipart", "painting", "real", "sketch"]
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
|
||||||
|
self.dataset_dir = osp.join(root, self.dataset_dir)
|
||||||
|
self.split_dir = osp.join(self.dataset_dir, "splits_mini")
|
||||||
|
|
||||||
|
self.check_input_domains(
|
||||||
|
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
|
||||||
|
)
|
||||||
|
|
||||||
|
train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
|
||||||
|
train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
|
||||||
|
test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")
|
||||||
|
|
||||||
|
super().__init__(train_x=train_x, train_u=train_u, test=test)
|
||||||
|
|
||||||
|
def _read_data(self, input_domains, split="train"):
|
||||||
|
items = []
|
||||||
|
|
||||||
|
for domain, dname in enumerate(input_domains):
|
||||||
|
filename = dname + "_" + split + ".txt"
|
||||||
|
split_file = osp.join(self.split_dir, filename)
|
||||||
|
|
||||||
|
with open(split_file, "r") as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
for line in lines:
|
||||||
|
line = line.strip()
|
||||||
|
impath, label = line.split(" ")
|
||||||
|
classname = impath.split("/")[1]
|
||||||
|
impath = osp.join(self.dataset_dir, impath)
|
||||||
|
label = int(label)
|
||||||
|
item = Datum(
|
||||||
|
impath=impath,
|
||||||
|
label=label,
|
||||||
|
domain=domain,
|
||||||
|
classname=classname
|
||||||
|
)
|
||||||
|
items.append(item)
|
||||||
|
|
||||||
|
return items
|
||||||
63
Dassl.ProGrad.pytorch/dassl/data/datasets/da/office31.py
Normal file
63
Dassl.ProGrad.pytorch/dassl/data/datasets/da/office31.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
import os.path as osp
|
||||||
|
|
||||||
|
from dassl.utils import listdir_nohidden
|
||||||
|
|
||||||
|
from ..build import DATASET_REGISTRY
|
||||||
|
from ..base_dataset import Datum, DatasetBase
|
||||||
|
|
||||||
|
|
||||||
|
@DATASET_REGISTRY.register()
|
||||||
|
class Office31(DatasetBase):
|
||||||
|
"""Office-31.
|
||||||
|
|
||||||
|
Statistics:
|
||||||
|
- 4,110 images.
|
||||||
|
- 31 classes related to office objects.
|
||||||
|
- 3 domains: Amazon, Webcam, Dslr.
|
||||||
|
- URL: https://people.eecs.berkeley.edu/~jhoffman/domainadapt/.
|
||||||
|
|
||||||
|
Reference:
|
||||||
|
- Saenko et al. Adapting visual category models to
|
||||||
|
new domains. ECCV 2010.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dataset_dir = "office31"
|
||||||
|
domains = ["amazon", "webcam", "dslr"]
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
|
||||||
|
self.dataset_dir = osp.join(root, self.dataset_dir)
|
||||||
|
|
||||||
|
self.check_input_domains(
|
||||||
|
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
|
||||||
|
)
|
||||||
|
|
||||||
|
train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS)
|
||||||
|
train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS)
|
||||||
|
test = self._read_data(cfg.DATASET.TARGET_DOMAINS)
|
||||||
|
|
||||||
|
super().__init__(train_x=train_x, train_u=train_u, test=test)
|
||||||
|
|
||||||
|
def _read_data(self, input_domains):
|
||||||
|
items = []
|
||||||
|
|
||||||
|
for domain, dname in enumerate(input_domains):
|
||||||
|
domain_dir = osp.join(self.dataset_dir, dname)
|
||||||
|
class_names = listdir_nohidden(domain_dir)
|
||||||
|
class_names.sort()
|
||||||
|
|
||||||
|
for label, class_name in enumerate(class_names):
|
||||||
|
class_path = osp.join(domain_dir, class_name)
|
||||||
|
imnames = listdir_nohidden(class_path)
|
||||||
|
|
||||||
|
for imname in imnames:
|
||||||
|
impath = osp.join(class_path, imname)
|
||||||
|
item = Datum(
|
||||||
|
impath=impath,
|
||||||
|
label=label,
|
||||||
|
domain=domain,
|
||||||
|
classname=class_name
|
||||||
|
)
|
||||||
|
items.append(item)
|
||||||
|
|
||||||
|
return items
|
||||||
63
Dassl.ProGrad.pytorch/dassl/data/datasets/da/office_home.py
Normal file
63
Dassl.ProGrad.pytorch/dassl/data/datasets/da/office_home.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
import os.path as osp
|
||||||
|
|
||||||
|
from dassl.utils import listdir_nohidden
|
||||||
|
|
||||||
|
from ..build import DATASET_REGISTRY
|
||||||
|
from ..base_dataset import Datum, DatasetBase
|
||||||
|
|
||||||
|
|
||||||
|
@DATASET_REGISTRY.register()
|
||||||
|
class OfficeHome(DatasetBase):
|
||||||
|
"""Office-Home.
|
||||||
|
|
||||||
|
Statistics:
|
||||||
|
- Around 15,500 images.
|
||||||
|
- 65 classes related to office and home objects.
|
||||||
|
- 4 domains: Art, Clipart, Product, Real World.
|
||||||
|
- URL: http://hemanthdv.org/OfficeHome-Dataset/.
|
||||||
|
|
||||||
|
Reference:
|
||||||
|
- Venkateswara et al. Deep Hashing Network for Unsupervised
|
||||||
|
Domain Adaptation. CVPR 2017.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dataset_dir = "office_home"
|
||||||
|
domains = ["art", "clipart", "product", "real_world"]
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
|
||||||
|
self.dataset_dir = osp.join(root, self.dataset_dir)
|
||||||
|
|
||||||
|
self.check_input_domains(
|
||||||
|
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
|
||||||
|
)
|
||||||
|
|
||||||
|
train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS)
|
||||||
|
train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS)
|
||||||
|
test = self._read_data(cfg.DATASET.TARGET_DOMAINS)
|
||||||
|
|
||||||
|
super().__init__(train_x=train_x, train_u=train_u, test=test)
|
||||||
|
|
||||||
|
def _read_data(self, input_domains):
|
||||||
|
items = []
|
||||||
|
|
||||||
|
for domain, dname in enumerate(input_domains):
|
||||||
|
domain_dir = osp.join(self.dataset_dir, dname)
|
||||||
|
class_names = listdir_nohidden(domain_dir)
|
||||||
|
class_names.sort()
|
||||||
|
|
||||||
|
for label, class_name in enumerate(class_names):
|
||||||
|
class_path = osp.join(domain_dir, class_name)
|
||||||
|
imnames = listdir_nohidden(class_path)
|
||||||
|
|
||||||
|
for imname in imnames:
|
||||||
|
impath = osp.join(class_path, imname)
|
||||||
|
item = Datum(
|
||||||
|
impath=impath,
|
||||||
|
label=label,
|
||||||
|
domain=domain,
|
||||||
|
classname=class_name.lower(),
|
||||||
|
)
|
||||||
|
items.append(item)
|
||||||
|
|
||||||
|
return items
|
||||||
61
Dassl.ProGrad.pytorch/dassl/data/datasets/da/visda17.py
Normal file
61
Dassl.ProGrad.pytorch/dassl/data/datasets/da/visda17.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
import os.path as osp
|
||||||
|
|
||||||
|
from ..build import DATASET_REGISTRY
|
||||||
|
from ..base_dataset import Datum, DatasetBase
|
||||||
|
|
||||||
|
|
||||||
|
@DATASET_REGISTRY.register()
|
||||||
|
class VisDA17(DatasetBase):
|
||||||
|
"""VisDA17.
|
||||||
|
|
||||||
|
Focusing on simulation-to-reality domain shift.
|
||||||
|
|
||||||
|
URL: http://ai.bu.edu/visda-2017/.
|
||||||
|
|
||||||
|
Reference:
|
||||||
|
- Peng et al. VisDA: The Visual Domain Adaptation
|
||||||
|
Challenge. ArXiv 2017.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dataset_dir = "visda17"
|
||||||
|
domains = ["synthetic", "real"]
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
|
||||||
|
self.dataset_dir = osp.join(root, self.dataset_dir)
|
||||||
|
|
||||||
|
self.check_input_domains(
|
||||||
|
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
|
||||||
|
)
|
||||||
|
|
||||||
|
train_x = self._read_data("synthetic")
|
||||||
|
train_u = self._read_data("real")
|
||||||
|
test = self._read_data("real")
|
||||||
|
|
||||||
|
super().__init__(train_x=train_x, train_u=train_u, test=test)
|
||||||
|
|
||||||
|
def _read_data(self, dname):
|
||||||
|
filedir = "train" if dname == "synthetic" else "validation"
|
||||||
|
image_list = osp.join(self.dataset_dir, filedir, "image_list.txt")
|
||||||
|
items = []
|
||||||
|
# There is only one source domain
|
||||||
|
domain = 0
|
||||||
|
|
||||||
|
with open(image_list, "r") as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
line = line.strip()
|
||||||
|
impath, label = line.split(" ")
|
||||||
|
classname = impath.split("/")[0]
|
||||||
|
impath = osp.join(self.dataset_dir, filedir, impath)
|
||||||
|
label = int(label)
|
||||||
|
item = Datum(
|
||||||
|
impath=impath,
|
||||||
|
label=label,
|
||||||
|
domain=domain,
|
||||||
|
classname=classname
|
||||||
|
)
|
||||||
|
items.append(item)
|
||||||
|
|
||||||
|
return items
|
||||||
6
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/__init__.py
Normal file
6
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from .pacs import PACS
|
||||||
|
from .vlcs import VLCS
|
||||||
|
from .cifar_c import CIFAR10C, CIFAR100C
|
||||||
|
from .digits_dg import DigitsDG
|
||||||
|
from .digit_single import DigitSingle
|
||||||
|
from .office_home_dg import OfficeHomeDG
|
||||||
123
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/cifar_c.py
Normal file
123
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/cifar_c.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
import os.path as osp
|
||||||
|
|
||||||
|
from dassl.utils import listdir_nohidden
|
||||||
|
|
||||||
|
from ..build import DATASET_REGISTRY
|
||||||
|
from ..base_dataset import Datum, DatasetBase
|
||||||
|
|
||||||
|
AVAI_C_TYPES = [
|
||||||
|
"brightness",
|
||||||
|
"contrast",
|
||||||
|
"defocus_blur",
|
||||||
|
"elastic_transform",
|
||||||
|
"fog",
|
||||||
|
"frost",
|
||||||
|
"gaussian_blur",
|
||||||
|
"gaussian_noise",
|
||||||
|
"glass_blur",
|
||||||
|
"impulse_noise",
|
||||||
|
"jpeg_compression",
|
||||||
|
"motion_blur",
|
||||||
|
"pixelate",
|
||||||
|
"saturate",
|
||||||
|
"shot_noise",
|
||||||
|
"snow",
|
||||||
|
"spatter",
|
||||||
|
"speckle_noise",
|
||||||
|
"zoom_blur",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@DATASET_REGISTRY.register()
|
||||||
|
class CIFAR10C(DatasetBase):
|
||||||
|
"""CIFAR-10 -> CIFAR-10-C.
|
||||||
|
|
||||||
|
Dataset link: https://zenodo.org/record/2535967#.YFwtV2Qzb0o
|
||||||
|
|
||||||
|
Statistics:
|
||||||
|
- 2 domains: the normal CIFAR-10 vs. a corrupted CIFAR-10
|
||||||
|
- 10 categories
|
||||||
|
|
||||||
|
Reference:
|
||||||
|
- Hendrycks et al. Benchmarking neural network robustness
|
||||||
|
to common corruptions and perturbations. ICLR 2019.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dataset_dir = ""
|
||||||
|
domains = ["cifar10", "cifar10_c"]
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
|
||||||
|
self.dataset_dir = root
|
||||||
|
|
||||||
|
self.check_input_domains(
|
||||||
|
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
|
||||||
|
)
|
||||||
|
source_domain = cfg.DATASET.SOURCE_DOMAINS[0]
|
||||||
|
target_domain = cfg.DATASET.TARGET_DOMAINS[0]
|
||||||
|
assert source_domain == self.domains[0]
|
||||||
|
assert target_domain == self.domains[1]
|
||||||
|
|
||||||
|
c_type = cfg.DATASET.CIFAR_C_TYPE
|
||||||
|
c_level = cfg.DATASET.CIFAR_C_LEVEL
|
||||||
|
|
||||||
|
if not c_type:
|
||||||
|
raise ValueError(
|
||||||
|
"Please specify DATASET.CIFAR_C_TYPE in the config file"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
c_type in AVAI_C_TYPES
|
||||||
|
), f'C_TYPE is expected to belong to {AVAI_C_TYPES}, but got "{c_type}"'
|
||||||
|
assert 1 <= c_level <= 5
|
||||||
|
|
||||||
|
train_dir = osp.join(self.dataset_dir, source_domain, "train")
|
||||||
|
test_dir = osp.join(
|
||||||
|
self.dataset_dir, target_domain, c_type, str(c_level)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not osp.exists(test_dir):
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
train = self._read_data(train_dir)
|
||||||
|
test = self._read_data(test_dir)
|
||||||
|
|
||||||
|
super().__init__(train_x=train, test=test)
|
||||||
|
|
||||||
|
def _read_data(self, data_dir):
|
||||||
|
class_names = listdir_nohidden(data_dir)
|
||||||
|
class_names.sort()
|
||||||
|
items = []
|
||||||
|
|
||||||
|
for label, class_name in enumerate(class_names):
|
||||||
|
class_dir = osp.join(data_dir, class_name)
|
||||||
|
imnames = listdir_nohidden(class_dir)
|
||||||
|
|
||||||
|
for imname in imnames:
|
||||||
|
impath = osp.join(class_dir, imname)
|
||||||
|
item = Datum(impath=impath, label=label, domain=0)
|
||||||
|
items.append(item)
|
||||||
|
|
||||||
|
return items
|
||||||
|
|
||||||
|
|
||||||
|
@DATASET_REGISTRY.register()
|
||||||
|
class CIFAR100C(CIFAR10C):
|
||||||
|
"""CIFAR-100 -> CIFAR-100-C.
|
||||||
|
|
||||||
|
Dataset link: https://zenodo.org/record/3555552#.YFxpQmQzb0o
|
||||||
|
|
||||||
|
Statistics:
|
||||||
|
- 2 domains: the normal CIFAR-100 vs. a corrupted CIFAR-100
|
||||||
|
- 10 categories
|
||||||
|
|
||||||
|
Reference:
|
||||||
|
- Hendrycks et al. Benchmarking neural network robustness
|
||||||
|
to common corruptions and perturbations. ICLR 2019.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dataset_dir = ""
|
||||||
|
domains = ["cifar100", "cifar100_c"]
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
super().__init__(cfg)
|
||||||
124
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/digit_single.py
Normal file
124
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/digit_single.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
import os.path as osp
|
||||||
|
|
||||||
|
from dassl.utils import listdir_nohidden
|
||||||
|
|
||||||
|
from ..build import DATASET_REGISTRY
|
||||||
|
from ..base_dataset import Datum, DatasetBase
|
||||||
|
|
||||||
|
# Folder names for train and test sets
|
||||||
|
MNIST = {"train": "train_images", "test": "test_images"}
|
||||||
|
MNIST_M = {"train": "train_images", "test": "test_images"}
|
||||||
|
SVHN = {"train": "train_images", "test": "test_images"}
|
||||||
|
SYN = {"train": "train_images", "test": "test_images"}
|
||||||
|
USPS = {"train": "train_images", "test": "test_images"}
|
||||||
|
|
||||||
|
|
||||||
|
def read_image_list(im_dir, n_max=None, n_repeat=None):
|
||||||
|
items = []
|
||||||
|
|
||||||
|
for imname in listdir_nohidden(im_dir):
|
||||||
|
imname_noext = osp.splitext(imname)[0]
|
||||||
|
label = int(imname_noext.split("_")[1])
|
||||||
|
impath = osp.join(im_dir, imname)
|
||||||
|
items.append((impath, label))
|
||||||
|
|
||||||
|
if n_max is not None:
|
||||||
|
# Note that the sampling process is NOT random,
|
||||||
|
# which follows that in Volpi et al. NIPS'18.
|
||||||
|
items = items[:n_max]
|
||||||
|
|
||||||
|
if n_repeat is not None:
|
||||||
|
items *= n_repeat
|
||||||
|
|
||||||
|
return items
|
||||||
|
|
||||||
|
|
||||||
|
def load_mnist(dataset_dir, split="train"):
|
||||||
|
data_dir = osp.join(dataset_dir, MNIST[split])
|
||||||
|
n_max = 10000 if split == "train" else None
|
||||||
|
return read_image_list(data_dir, n_max=n_max)
|
||||||
|
|
||||||
|
|
||||||
|
def load_mnist_m(dataset_dir, split="train"):
|
||||||
|
data_dir = osp.join(dataset_dir, MNIST_M[split])
|
||||||
|
n_max = 10000 if split == "train" else None
|
||||||
|
return read_image_list(data_dir, n_max=n_max)
|
||||||
|
|
||||||
|
|
||||||
|
def load_svhn(dataset_dir, split="train"):
|
||||||
|
data_dir = osp.join(dataset_dir, SVHN[split])
|
||||||
|
n_max = 10000 if split == "train" else None
|
||||||
|
return read_image_list(data_dir, n_max=n_max)
|
||||||
|
|
||||||
|
|
||||||
|
def load_syn(dataset_dir, split="train"):
|
||||||
|
data_dir = osp.join(dataset_dir, SYN[split])
|
||||||
|
n_max = 10000 if split == "train" else None
|
||||||
|
return read_image_list(data_dir, n_max=n_max)
|
||||||
|
|
||||||
|
|
||||||
|
def load_usps(dataset_dir, split="train"):
|
||||||
|
data_dir = osp.join(dataset_dir, USPS[split])
|
||||||
|
return read_image_list(data_dir)
|
||||||
|
|
||||||
|
|
||||||
|
@DATASET_REGISTRY.register()
|
||||||
|
class DigitSingle(DatasetBase):
|
||||||
|
"""Digit recognition datasets for single-source domain generalization.
|
||||||
|
|
||||||
|
There are five digit datasets:
|
||||||
|
- MNIST: hand-written digits.
|
||||||
|
- MNIST-M: variant of MNIST with blended background.
|
||||||
|
- SVHN: street view house number.
|
||||||
|
- SYN: synthetic digits.
|
||||||
|
- USPS: hand-written digits, slightly different from MNIST.
|
||||||
|
|
||||||
|
Protocol:
|
||||||
|
Volpi et al. train a model using 10,000 images from MNIST and
|
||||||
|
evaluate the model on the test split of the other four datasets. However,
|
||||||
|
the code does not restrict you to only use MNIST as the source dataset.
|
||||||
|
Instead, you can use any dataset as the source. But note that only 10,000
|
||||||
|
images will be sampled from the source dataset for training.
|
||||||
|
|
||||||
|
Reference:
|
||||||
|
- Lecun et al. Gradient-based learning applied to document
|
||||||
|
recognition. IEEE 1998.
|
||||||
|
- Ganin et al. Domain-adversarial training of neural networks.
|
||||||
|
JMLR 2016.
|
||||||
|
- Netzer et al. Reading digits in natural images with unsupervised
|
||||||
|
feature learning. NIPS-W 2011.
|
||||||
|
- Volpi et al. Generalizing to Unseen Domains via Adversarial Data
|
||||||
|
Augmentation. NIPS 2018.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Reuse the digit-5 folder instead of creating a new folder
|
||||||
|
dataset_dir = "digit5"
|
||||||
|
domains = ["mnist", "mnist_m", "svhn", "syn", "usps"]
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
|
||||||
|
self.dataset_dir = osp.join(root, self.dataset_dir)
|
||||||
|
|
||||||
|
self.check_input_domains(
|
||||||
|
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
|
||||||
|
)
|
||||||
|
|
||||||
|
train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
|
||||||
|
val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="test")
|
||||||
|
test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")
|
||||||
|
|
||||||
|
super().__init__(train_x=train, val=val, test=test)
|
||||||
|
|
||||||
|
def _read_data(self, input_domains, split="train"):
|
||||||
|
items = []
|
||||||
|
|
||||||
|
for domain, dname in enumerate(input_domains):
|
||||||
|
func = "load_" + dname
|
||||||
|
domain_dir = osp.join(self.dataset_dir, dname)
|
||||||
|
items_d = eval(func)(domain_dir, split=split)
|
||||||
|
|
||||||
|
for impath, label in items_d:
|
||||||
|
item = Datum(impath=impath, label=label, domain=domain)
|
||||||
|
items.append(item)
|
||||||
|
|
||||||
|
return items
|
||||||
97
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/digits_dg.py
Normal file
97
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/digits_dg.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
import glob
|
||||||
|
import os.path as osp
|
||||||
|
|
||||||
|
from dassl.utils import listdir_nohidden
|
||||||
|
|
||||||
|
from ..build import DATASET_REGISTRY
|
||||||
|
from ..base_dataset import Datum, DatasetBase
|
||||||
|
|
||||||
|
|
||||||
|
@DATASET_REGISTRY.register()
|
||||||
|
class DigitsDG(DatasetBase):
|
||||||
|
"""Digits-DG.
|
||||||
|
|
||||||
|
It contains 4 digit datasets:
|
||||||
|
- MNIST: hand-written digits.
|
||||||
|
- MNIST-M: variant of MNIST with blended background.
|
||||||
|
- SVHN: street view house number.
|
||||||
|
- SYN: synthetic digits.
|
||||||
|
|
||||||
|
Reference:
|
||||||
|
- Lecun et al. Gradient-based learning applied to document
|
||||||
|
recognition. IEEE 1998.
|
||||||
|
- Ganin et al. Domain-adversarial training of neural networks.
|
||||||
|
JMLR 2016.
|
||||||
|
- Netzer et al. Reading digits in natural images with unsupervised
|
||||||
|
feature learning. NIPS-W 2011.
|
||||||
|
- Zhou et al. Deep Domain-Adversarial Image Generation for Domain
|
||||||
|
Generalisation. AAAI 2020.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dataset_dir = "digits_dg"
|
||||||
|
domains = ["mnist", "mnist_m", "svhn", "syn"]
|
||||||
|
data_url = "https://drive.google.com/uc?id=15V7EsHfCcfbKgsDmzQKj_DfXt_XYp_P7"
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
|
||||||
|
self.dataset_dir = osp.join(root, self.dataset_dir)
|
||||||
|
|
||||||
|
if not osp.exists(self.dataset_dir):
|
||||||
|
dst = osp.join(root, "digits_dg.zip")
|
||||||
|
self.download_data(self.data_url, dst, from_gdrive=True)
|
||||||
|
|
||||||
|
self.check_input_domains(
|
||||||
|
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
|
||||||
|
)
|
||||||
|
|
||||||
|
train = self.read_data(
|
||||||
|
self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "train"
|
||||||
|
)
|
||||||
|
val = self.read_data(
|
||||||
|
self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "val"
|
||||||
|
)
|
||||||
|
test = self.read_data(
|
||||||
|
self.dataset_dir, cfg.DATASET.TARGET_DOMAINS, "all"
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(train_x=train, val=val, test=test)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def read_data(dataset_dir, input_domains, split):
|
||||||
|
|
||||||
|
def _load_data_from_directory(directory):
|
||||||
|
folders = listdir_nohidden(directory)
|
||||||
|
folders.sort()
|
||||||
|
items_ = []
|
||||||
|
|
||||||
|
for label, folder in enumerate(folders):
|
||||||
|
impaths = glob.glob(osp.join(directory, folder, "*.jpg"))
|
||||||
|
|
||||||
|
for impath in impaths:
|
||||||
|
items_.append((impath, label))
|
||||||
|
|
||||||
|
return items_
|
||||||
|
|
||||||
|
items = []
|
||||||
|
|
||||||
|
for domain, dname in enumerate(input_domains):
|
||||||
|
if split == "all":
|
||||||
|
train_dir = osp.join(dataset_dir, dname, "train")
|
||||||
|
impath_label_list = _load_data_from_directory(train_dir)
|
||||||
|
val_dir = osp.join(dataset_dir, dname, "val")
|
||||||
|
impath_label_list += _load_data_from_directory(val_dir)
|
||||||
|
else:
|
||||||
|
split_dir = osp.join(dataset_dir, dname, split)
|
||||||
|
impath_label_list = _load_data_from_directory(split_dir)
|
||||||
|
|
||||||
|
for impath, label in impath_label_list:
|
||||||
|
class_name = impath.split("/")[-2].lower()
|
||||||
|
item = Datum(
|
||||||
|
impath=impath,
|
||||||
|
label=label,
|
||||||
|
domain=domain,
|
||||||
|
classname=class_name
|
||||||
|
)
|
||||||
|
items.append(item)
|
||||||
|
|
||||||
|
return items
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
import os.path as osp
|
||||||
|
|
||||||
|
from ..build import DATASET_REGISTRY
|
||||||
|
from .digits_dg import DigitsDG
|
||||||
|
from ..base_dataset import DatasetBase
|
||||||
|
|
||||||
|
|
||||||
|
@DATASET_REGISTRY.register()
|
||||||
|
class OfficeHomeDG(DatasetBase):
|
||||||
|
"""Office-Home.
|
||||||
|
|
||||||
|
Statistics:
|
||||||
|
- Around 15,500 images.
|
||||||
|
- 65 classes related to office and home objects.
|
||||||
|
- 4 domains: Art, Clipart, Product, Real World.
|
||||||
|
- URL: http://hemanthdv.org/OfficeHome-Dataset/.
|
||||||
|
|
||||||
|
Reference:
|
||||||
|
- Venkateswara et al. Deep Hashing Network for Unsupervised
|
||||||
|
Domain Adaptation. CVPR 2017.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dataset_dir = "office_home_dg"
|
||||||
|
domains = ["art", "clipart", "product", "real_world"]
|
||||||
|
data_url = "https://drive.google.com/uc?id=1gkbf_KaxoBws-GWT3XIPZ7BnkqbAxIFa"
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
|
||||||
|
self.dataset_dir = osp.join(root, self.dataset_dir)
|
||||||
|
|
||||||
|
if not osp.exists(self.dataset_dir):
|
||||||
|
dst = osp.join(root, "office_home_dg.zip")
|
||||||
|
self.download_data(self.data_url, dst, from_gdrive=True)
|
||||||
|
|
||||||
|
self.check_input_domains(
|
||||||
|
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
|
||||||
|
)
|
||||||
|
|
||||||
|
train = DigitsDG.read_data(
|
||||||
|
self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "train"
|
||||||
|
)
|
||||||
|
val = DigitsDG.read_data(
|
||||||
|
self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "val"
|
||||||
|
)
|
||||||
|
test = DigitsDG.read_data(
|
||||||
|
self.dataset_dir, cfg.DATASET.TARGET_DOMAINS, "all"
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(train_x=train, val=val, test=test)
|
||||||
94
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/pacs.py
Normal file
94
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/pacs.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
import os.path as osp
|
||||||
|
|
||||||
|
from ..build import DATASET_REGISTRY
|
||||||
|
from ..base_dataset import Datum, DatasetBase
|
||||||
|
|
||||||
|
|
||||||
|
@DATASET_REGISTRY.register()
|
||||||
|
class PACS(DatasetBase):
|
||||||
|
"""PACS.
|
||||||
|
|
||||||
|
Statistics:
|
||||||
|
- 4 domains: Photo (1,670), Art (2,048), Cartoon
|
||||||
|
(2,344), Sketch (3,929).
|
||||||
|
- 7 categories: dog, elephant, giraffe, guitar, horse,
|
||||||
|
house and person.
|
||||||
|
|
||||||
|
Reference:
|
||||||
|
- Li et al. Deeper, broader and artier domain generalization.
|
||||||
|
ICCV 2017.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dataset_dir = "pacs"
|
||||||
|
domains = ["art_painting", "cartoon", "photo", "sketch"]
|
||||||
|
data_url = "https://drive.google.com/uc?id=1m4X4fROCCXMO0lRLrr6Zz9Vb3974NWhE"
|
||||||
|
# the following images contain errors and should be ignored
|
||||||
|
_error_paths = ["sketch/dog/n02103406_4068-1.png"]
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
|
||||||
|
self.dataset_dir = osp.join(root, self.dataset_dir)
|
||||||
|
self.image_dir = osp.join(self.dataset_dir, "images")
|
||||||
|
self.split_dir = osp.join(self.dataset_dir, "splits")
|
||||||
|
|
||||||
|
if not osp.exists(self.dataset_dir):
|
||||||
|
dst = osp.join(root, "pacs.zip")
|
||||||
|
self.download_data(self.data_url, dst, from_gdrive=True)
|
||||||
|
|
||||||
|
self.check_input_domains(
|
||||||
|
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
|
||||||
|
)
|
||||||
|
|
||||||
|
train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "train")
|
||||||
|
val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "crossval")
|
||||||
|
test = self._read_data(cfg.DATASET.TARGET_DOMAINS, "all")
|
||||||
|
|
||||||
|
super().__init__(train_x=train, val=val, test=test)
|
||||||
|
|
||||||
|
def _read_data(self, input_domains, split):
|
||||||
|
items = []
|
||||||
|
|
||||||
|
for domain, dname in enumerate(input_domains):
|
||||||
|
if split == "all":
|
||||||
|
file_train = osp.join(
|
||||||
|
self.split_dir, dname + "_train_kfold.txt"
|
||||||
|
)
|
||||||
|
impath_label_list = self._read_split_pacs(file_train)
|
||||||
|
file_val = osp.join(
|
||||||
|
self.split_dir, dname + "_crossval_kfold.txt"
|
||||||
|
)
|
||||||
|
impath_label_list += self._read_split_pacs(file_val)
|
||||||
|
else:
|
||||||
|
file = osp.join(
|
||||||
|
self.split_dir, dname + "_" + split + "_kfold.txt"
|
||||||
|
)
|
||||||
|
impath_label_list = self._read_split_pacs(file)
|
||||||
|
|
||||||
|
for impath, label in impath_label_list:
|
||||||
|
classname = impath.split("/")[-2]
|
||||||
|
item = Datum(
|
||||||
|
impath=impath,
|
||||||
|
label=label,
|
||||||
|
domain=domain,
|
||||||
|
classname=classname
|
||||||
|
)
|
||||||
|
items.append(item)
|
||||||
|
|
||||||
|
return items
|
||||||
|
|
||||||
|
def _read_split_pacs(self, split_file):
|
||||||
|
items = []
|
||||||
|
|
||||||
|
with open(split_file, "r") as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
line = line.strip()
|
||||||
|
impath, label = line.split(" ")
|
||||||
|
if impath in self._error_paths:
|
||||||
|
continue
|
||||||
|
impath = osp.join(self.image_dir, impath)
|
||||||
|
label = int(label) - 1
|
||||||
|
items.append((impath, label))
|
||||||
|
|
||||||
|
return items
|
||||||
60
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/vlcs.py
Normal file
60
Dassl.ProGrad.pytorch/dassl/data/datasets/dg/vlcs.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
import glob
|
||||||
|
import os.path as osp
|
||||||
|
|
||||||
|
from dassl.utils import listdir_nohidden
|
||||||
|
|
||||||
|
from ..build import DATASET_REGISTRY
|
||||||
|
from ..base_dataset import Datum, DatasetBase
|
||||||
|
|
||||||
|
|
||||||
|
@DATASET_REGISTRY.register()
|
||||||
|
class VLCS(DatasetBase):
|
||||||
|
"""VLCS.
|
||||||
|
|
||||||
|
Statistics:
|
||||||
|
- 4 domains: CALTECH, LABELME, PASCAL, SUN
|
||||||
|
- 5 categories: bird, car, chair, dog, and person.
|
||||||
|
|
||||||
|
Reference:
|
||||||
|
- Torralba and Efros. Unbiased look at dataset bias. CVPR 2011.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dataset_dir = "VLCS"
|
||||||
|
domains = ["caltech", "labelme", "pascal", "sun"]
|
||||||
|
data_url = "https://drive.google.com/uc?id=1r0WL5DDqKfSPp9E3tRENwHaXNs1olLZd"
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
|
||||||
|
self.dataset_dir = osp.join(root, self.dataset_dir)
|
||||||
|
|
||||||
|
if not osp.exists(self.dataset_dir):
|
||||||
|
dst = osp.join(root, "vlcs.zip")
|
||||||
|
self.download_data(self.data_url, dst, from_gdrive=True)
|
||||||
|
|
||||||
|
self.check_input_domains(
|
||||||
|
cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
|
||||||
|
)
|
||||||
|
|
||||||
|
train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "train")
|
||||||
|
val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "crossval")
|
||||||
|
test = self._read_data(cfg.DATASET.TARGET_DOMAINS, "test")
|
||||||
|
|
||||||
|
super().__init__(train_x=train, val=val, test=test)
|
||||||
|
|
||||||
|
def _read_data(self, input_domains, split):
|
||||||
|
items = []
|
||||||
|
|
||||||
|
for domain, dname in enumerate(input_domains):
|
||||||
|
dname = dname.upper()
|
||||||
|
path = osp.join(self.dataset_dir, dname, split)
|
||||||
|
folders = listdir_nohidden(path)
|
||||||
|
folders.sort()
|
||||||
|
|
||||||
|
for label, folder in enumerate(folders):
|
||||||
|
impaths = glob.glob(osp.join(path, folder, "*.jpg"))
|
||||||
|
|
||||||
|
for impath in impaths:
|
||||||
|
item = Datum(impath=impath, label=label, domain=domain)
|
||||||
|
items.append(item)
|
||||||
|
|
||||||
|
return items
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
from .svhn import SVHN
|
||||||
|
from .cifar import CIFAR10, CIFAR100
|
||||||
|
from .stl10 import STL10
|
||||||
108
Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/cifar.py
Normal file
108
Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/cifar.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
import math
|
||||||
|
import random
|
||||||
|
import os.path as osp
|
||||||
|
|
||||||
|
from dassl.utils import listdir_nohidden
|
||||||
|
|
||||||
|
from ..build import DATASET_REGISTRY
|
||||||
|
from ..base_dataset import Datum, DatasetBase
|
||||||
|
|
||||||
|
|
||||||
|
@DATASET_REGISTRY.register()
|
||||||
|
class CIFAR10(DatasetBase):
|
||||||
|
"""CIFAR10 for SSL.
|
||||||
|
|
||||||
|
Reference:
|
||||||
|
- Krizhevsky. Learning Multiple Layers of Features
|
||||||
|
from Tiny Images. Tech report.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dataset_dir = "cifar10"
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
|
||||||
|
self.dataset_dir = osp.join(root, self.dataset_dir)
|
||||||
|
train_dir = osp.join(self.dataset_dir, "train")
|
||||||
|
test_dir = osp.join(self.dataset_dir, "test")
|
||||||
|
|
||||||
|
assert cfg.DATASET.NUM_LABELED > 0
|
||||||
|
|
||||||
|
train_x, train_u, val = self._read_data_train(
|
||||||
|
train_dir, cfg.DATASET.NUM_LABELED, cfg.DATASET.VAL_PERCENT
|
||||||
|
)
|
||||||
|
test = self._read_data_test(test_dir)
|
||||||
|
|
||||||
|
if cfg.DATASET.ALL_AS_UNLABELED:
|
||||||
|
train_u = train_u + train_x
|
||||||
|
|
||||||
|
if len(val) == 0:
|
||||||
|
val = None
|
||||||
|
|
||||||
|
super().__init__(train_x=train_x, train_u=train_u, val=val, test=test)
|
||||||
|
|
||||||
|
def _read_data_train(self, data_dir, num_labeled, val_percent):
|
||||||
|
class_names = listdir_nohidden(data_dir)
|
||||||
|
class_names.sort()
|
||||||
|
num_labeled_per_class = num_labeled / len(class_names)
|
||||||
|
items_x, items_u, items_v = [], [], []
|
||||||
|
|
||||||
|
for label, class_name in enumerate(class_names):
|
||||||
|
class_dir = osp.join(data_dir, class_name)
|
||||||
|
imnames = listdir_nohidden(class_dir)
|
||||||
|
|
||||||
|
# Split into train and val following Oliver et al. 2018
|
||||||
|
# Set cfg.DATASET.VAL_PERCENT to 0 to not use val data
|
||||||
|
num_val = math.floor(len(imnames) * val_percent)
|
||||||
|
imnames_train = imnames[num_val:]
|
||||||
|
imnames_val = imnames[:num_val]
|
||||||
|
|
||||||
|
# Note we do shuffle after split
|
||||||
|
random.shuffle(imnames_train)
|
||||||
|
|
||||||
|
for i, imname in enumerate(imnames_train):
|
||||||
|
impath = osp.join(class_dir, imname)
|
||||||
|
item = Datum(impath=impath, label=label)
|
||||||
|
|
||||||
|
if (i + 1) <= num_labeled_per_class:
|
||||||
|
items_x.append(item)
|
||||||
|
|
||||||
|
else:
|
||||||
|
items_u.append(item)
|
||||||
|
|
||||||
|
for imname in imnames_val:
|
||||||
|
impath = osp.join(class_dir, imname)
|
||||||
|
item = Datum(impath=impath, label=label)
|
||||||
|
items_v.append(item)
|
||||||
|
|
||||||
|
return items_x, items_u, items_v
|
||||||
|
|
||||||
|
def _read_data_test(self, data_dir):
|
||||||
|
class_names = listdir_nohidden(data_dir)
|
||||||
|
class_names.sort()
|
||||||
|
items = []
|
||||||
|
|
||||||
|
for label, class_name in enumerate(class_names):
|
||||||
|
class_dir = osp.join(data_dir, class_name)
|
||||||
|
imnames = listdir_nohidden(class_dir)
|
||||||
|
|
||||||
|
for imname in imnames:
|
||||||
|
impath = osp.join(class_dir, imname)
|
||||||
|
item = Datum(impath=impath, label=label)
|
||||||
|
items.append(item)
|
||||||
|
|
||||||
|
return items
|
||||||
|
|
||||||
|
|
||||||
|
@DATASET_REGISTRY.register()
|
||||||
|
class CIFAR100(CIFAR10):
|
||||||
|
"""CIFAR100 for SSL.
|
||||||
|
|
||||||
|
Reference:
|
||||||
|
- Krizhevsky. Learning Multiple Layers of Features
|
||||||
|
from Tiny Images. Tech report.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dataset_dir = "cifar100"
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
super().__init__(cfg)
|
||||||
87
Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/stl10.py
Normal file
87
Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/stl10.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
import numpy as np
|
||||||
|
import os.path as osp
|
||||||
|
|
||||||
|
from dassl.utils import listdir_nohidden
|
||||||
|
|
||||||
|
from ..build import DATASET_REGISTRY
|
||||||
|
from ..base_dataset import Datum, DatasetBase
|
||||||
|
|
||||||
|
|
||||||
|
@DATASET_REGISTRY.register()
|
||||||
|
class STL10(DatasetBase):
|
||||||
|
"""STL-10 dataset.
|
||||||
|
|
||||||
|
Description:
|
||||||
|
- 10 classes: airplane, bird, car, cat, deer, dog, horse,
|
||||||
|
monkey, ship, truck.
|
||||||
|
- Images are 96x96 pixels, color.
|
||||||
|
- 500 training images per class, 800 test images per class.
|
||||||
|
- 100,000 unlabeled images for unsupervised learning.
|
||||||
|
|
||||||
|
Reference:
|
||||||
|
- Coates et al. An Analysis of Single Layer Networks in
|
||||||
|
Unsupervised Feature Learning. AISTATS 2011.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dataset_dir = "stl10"
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
|
||||||
|
self.dataset_dir = osp.join(root, self.dataset_dir)
|
||||||
|
train_dir = osp.join(self.dataset_dir, "train")
|
||||||
|
test_dir = osp.join(self.dataset_dir, "test")
|
||||||
|
unlabeled_dir = osp.join(self.dataset_dir, "unlabeled")
|
||||||
|
fold_file = osp.join(
|
||||||
|
self.dataset_dir, "stl10_binary", "fold_indices.txt"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only use the first five splits
|
||||||
|
assert 0 <= cfg.DATASET.STL10_FOLD <= 4
|
||||||
|
|
||||||
|
train_x = self._read_data_train(
|
||||||
|
train_dir, cfg.DATASET.STL10_FOLD, fold_file
|
||||||
|
)
|
||||||
|
train_u = self._read_data_all(unlabeled_dir)
|
||||||
|
test = self._read_data_all(test_dir)
|
||||||
|
|
||||||
|
if cfg.DATASET.ALL_AS_UNLABELED:
|
||||||
|
train_u = train_u + train_x
|
||||||
|
|
||||||
|
super().__init__(train_x=train_x, train_u=train_u, test=test)
|
||||||
|
|
||||||
|
def _read_data_train(self, data_dir, fold, fold_file):
|
||||||
|
imnames = listdir_nohidden(data_dir)
|
||||||
|
imnames.sort()
|
||||||
|
items = []
|
||||||
|
|
||||||
|
list_idx = list(range(len(imnames)))
|
||||||
|
if fold >= 0:
|
||||||
|
with open(fold_file, "r") as f:
|
||||||
|
str_idx = f.read().splitlines()[fold]
|
||||||
|
list_idx = np.fromstring(str_idx, dtype=np.uint8, sep=" ")
|
||||||
|
|
||||||
|
for i in list_idx:
|
||||||
|
imname = imnames[i]
|
||||||
|
impath = osp.join(data_dir, imname)
|
||||||
|
label = osp.splitext(imname)[0].split("_")[1]
|
||||||
|
label = int(label)
|
||||||
|
item = Datum(impath=impath, label=label)
|
||||||
|
items.append(item)
|
||||||
|
|
||||||
|
return items
|
||||||
|
|
||||||
|
def _read_data_all(self, data_dir):
|
||||||
|
imnames = listdir_nohidden(data_dir)
|
||||||
|
items = []
|
||||||
|
|
||||||
|
for imname in imnames:
|
||||||
|
impath = osp.join(data_dir, imname)
|
||||||
|
label = osp.splitext(imname)[0].split("_")[1]
|
||||||
|
if label == "none":
|
||||||
|
label = -1
|
||||||
|
else:
|
||||||
|
label = int(label)
|
||||||
|
item = Datum(impath=impath, label=label)
|
||||||
|
items.append(item)
|
||||||
|
|
||||||
|
return items
|
||||||
17
Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/svhn.py
Normal file
17
Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/svhn.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
from .cifar import CIFAR10
|
||||||
|
from ..build import DATASET_REGISTRY
|
||||||
|
|
||||||
|
|
||||||
|
@DATASET_REGISTRY.register()
|
||||||
|
class SVHN(CIFAR10):
|
||||||
|
"""SVHN for SSL.
|
||||||
|
|
||||||
|
Reference:
|
||||||
|
- Netzer et al. Reading Digits in Natural Images with
|
||||||
|
Unsupervised Feature Learning. NIPS-W 2011.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dataset_dir = "svhn"
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
super().__init__(cfg)
|
||||||
205
Dassl.ProGrad.pytorch/dassl/data/samplers.py
Normal file
205
Dassl.ProGrad.pytorch/dassl/data/samplers.py
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
import copy
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
from collections import defaultdict
|
||||||
|
from torch.utils.data.sampler import Sampler, RandomSampler, SequentialSampler
|
||||||
|
|
||||||
|
|
||||||
|
class RandomDomainSampler(Sampler):
|
||||||
|
"""Randomly samples N domains each with K images
|
||||||
|
to form a minibatch of size N*K.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_source (list): list of Datums.
|
||||||
|
batch_size (int): batch size.
|
||||||
|
n_domain (int): number of domains to sample in a minibatch.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, data_source, batch_size, n_domain):
|
||||||
|
self.data_source = data_source
|
||||||
|
|
||||||
|
# Keep track of image indices for each domain
|
||||||
|
self.domain_dict = defaultdict(list)
|
||||||
|
for i, item in enumerate(data_source):
|
||||||
|
self.domain_dict[item.domain].append(i)
|
||||||
|
self.domains = list(self.domain_dict.keys())
|
||||||
|
|
||||||
|
# Make sure each domain has equal number of images
|
||||||
|
if n_domain is None or n_domain <= 0:
|
||||||
|
n_domain = len(self.domains)
|
||||||
|
assert batch_size % n_domain == 0
|
||||||
|
self.n_img_per_domain = batch_size // n_domain
|
||||||
|
|
||||||
|
self.batch_size = batch_size
|
||||||
|
# n_domain denotes number of domains sampled in a minibatch
|
||||||
|
self.n_domain = n_domain
|
||||||
|
self.length = len(list(self.__iter__()))
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
domain_dict = copy.deepcopy(self.domain_dict)
|
||||||
|
final_idxs = []
|
||||||
|
stop_sampling = False
|
||||||
|
|
||||||
|
while not stop_sampling:
|
||||||
|
selected_domains = random.sample(self.domains, self.n_domain)
|
||||||
|
|
||||||
|
for domain in selected_domains:
|
||||||
|
idxs = domain_dict[domain]
|
||||||
|
selected_idxs = random.sample(idxs, self.n_img_per_domain)
|
||||||
|
final_idxs.extend(selected_idxs)
|
||||||
|
|
||||||
|
for idx in selected_idxs:
|
||||||
|
domain_dict[domain].remove(idx)
|
||||||
|
|
||||||
|
remaining = len(domain_dict[domain])
|
||||||
|
if remaining < self.n_img_per_domain:
|
||||||
|
stop_sampling = True
|
||||||
|
|
||||||
|
return iter(final_idxs)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.length
|
||||||
|
|
||||||
|
|
||||||
|
class SeqDomainSampler(Sampler):
|
||||||
|
"""Sequential domain sampler, which randomly samples K
|
||||||
|
images from each domain to form a minibatch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_source (list): list of Datums.
|
||||||
|
batch_size (int): batch size.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, data_source, batch_size):
|
||||||
|
self.data_source = data_source
|
||||||
|
|
||||||
|
# Keep track of image indices for each domain
|
||||||
|
self.domain_dict = defaultdict(list)
|
||||||
|
for i, item in enumerate(data_source):
|
||||||
|
self.domain_dict[item.domain].append(i)
|
||||||
|
self.domains = list(self.domain_dict.keys())
|
||||||
|
self.domains.sort()
|
||||||
|
|
||||||
|
# Make sure each domain has equal number of images
|
||||||
|
n_domain = len(self.domains)
|
||||||
|
assert batch_size % n_domain == 0
|
||||||
|
self.n_img_per_domain = batch_size // n_domain
|
||||||
|
|
||||||
|
self.batch_size = batch_size
|
||||||
|
# n_domain denotes number of domains sampled in a minibatch
|
||||||
|
self.n_domain = n_domain
|
||||||
|
self.length = len(list(self.__iter__()))
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
domain_dict = copy.deepcopy(self.domain_dict)
|
||||||
|
final_idxs = []
|
||||||
|
stop_sampling = False
|
||||||
|
|
||||||
|
while not stop_sampling:
|
||||||
|
for domain in self.domains:
|
||||||
|
idxs = domain_dict[domain]
|
||||||
|
selected_idxs = random.sample(idxs, self.n_img_per_domain)
|
||||||
|
final_idxs.extend(selected_idxs)
|
||||||
|
|
||||||
|
for idx in selected_idxs:
|
||||||
|
domain_dict[domain].remove(idx)
|
||||||
|
|
||||||
|
remaining = len(domain_dict[domain])
|
||||||
|
if remaining < self.n_img_per_domain:
|
||||||
|
stop_sampling = True
|
||||||
|
|
||||||
|
return iter(final_idxs)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.length
|
||||||
|
|
||||||
|
|
||||||
|
class RandomClassSampler(Sampler):
|
||||||
|
"""Randomly samples N classes each with K instances to
|
||||||
|
form a minibatch of size N*K.
|
||||||
|
|
||||||
|
Modified from https://github.com/KaiyangZhou/deep-person-reid.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_source (list): list of Datums.
|
||||||
|
batch_size (int): batch size.
|
||||||
|
n_ins (int): number of instances per class to sample in a minibatch.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, data_source, batch_size, n_ins):
|
||||||
|
if batch_size < n_ins:
|
||||||
|
raise ValueError(
|
||||||
|
"batch_size={} must be no less "
|
||||||
|
"than n_ins={}".format(batch_size, n_ins)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.data_source = data_source
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.n_ins = n_ins
|
||||||
|
self.ncls_per_batch = self.batch_size // self.n_ins
|
||||||
|
self.index_dic = defaultdict(list)
|
||||||
|
for index, item in enumerate(data_source):
|
||||||
|
self.index_dic[item.label].append(index)
|
||||||
|
self.labels = list(self.index_dic.keys())
|
||||||
|
assert len(self.labels) >= self.ncls_per_batch
|
||||||
|
|
||||||
|
# estimate number of images in an epoch
|
||||||
|
self.length = len(list(self.__iter__()))
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
batch_idxs_dict = defaultdict(list)
|
||||||
|
|
||||||
|
for label in self.labels:
|
||||||
|
idxs = copy.deepcopy(self.index_dic[label])
|
||||||
|
if len(idxs) < self.n_ins:
|
||||||
|
idxs = np.random.choice(idxs, size=self.n_ins, replace=True)
|
||||||
|
random.shuffle(idxs)
|
||||||
|
batch_idxs = []
|
||||||
|
for idx in idxs:
|
||||||
|
batch_idxs.append(idx)
|
||||||
|
if len(batch_idxs) == self.n_ins:
|
||||||
|
batch_idxs_dict[label].append(batch_idxs)
|
||||||
|
batch_idxs = []
|
||||||
|
|
||||||
|
avai_labels = copy.deepcopy(self.labels)
|
||||||
|
final_idxs = []
|
||||||
|
|
||||||
|
while len(avai_labels) >= self.ncls_per_batch:
|
||||||
|
selected_labels = random.sample(avai_labels, self.ncls_per_batch)
|
||||||
|
for label in selected_labels:
|
||||||
|
batch_idxs = batch_idxs_dict[label].pop(0)
|
||||||
|
final_idxs.extend(batch_idxs)
|
||||||
|
if len(batch_idxs_dict[label]) == 0:
|
||||||
|
avai_labels.remove(label)
|
||||||
|
|
||||||
|
return iter(final_idxs)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.length
|
||||||
|
|
||||||
|
|
||||||
|
def build_sampler(
|
||||||
|
sampler_type,
|
||||||
|
cfg=None,
|
||||||
|
data_source=None,
|
||||||
|
batch_size=32,
|
||||||
|
n_domain=0,
|
||||||
|
n_ins=16
|
||||||
|
):
|
||||||
|
if sampler_type == "RandomSampler":
|
||||||
|
return RandomSampler(data_source)
|
||||||
|
|
||||||
|
elif sampler_type == "SequentialSampler":
|
||||||
|
return SequentialSampler(data_source)
|
||||||
|
|
||||||
|
elif sampler_type == "RandomDomainSampler":
|
||||||
|
return RandomDomainSampler(data_source, batch_size, n_domain)
|
||||||
|
|
||||||
|
elif sampler_type == "SeqDomainSampler":
|
||||||
|
return SeqDomainSampler(data_source, batch_size)
|
||||||
|
|
||||||
|
elif sampler_type == "RandomClassSampler":
|
||||||
|
return RandomClassSampler(data_source, batch_size, n_ins)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError("Unknown sampler type: {}".format(sampler_type))
|
||||||
1
Dassl.ProGrad.pytorch/dassl/data/transforms/__init__.py
Normal file
1
Dassl.ProGrad.pytorch/dassl/data/transforms/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .transforms import build_transform
|
||||||
273
Dassl.ProGrad.pytorch/dassl/data/transforms/autoaugment.py
Normal file
273
Dassl.ProGrad.pytorch/dassl/data/transforms/autoaugment.py
Normal file
@@ -0,0 +1,273 @@
|
|||||||
|
"""
|
||||||
|
Source: https://github.com/DeepVoltaire/AutoAugment
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
from PIL import Image, ImageOps, ImageEnhance
|
||||||
|
|
||||||
|
|
||||||
|
class ImageNetPolicy:
|
||||||
|
"""Randomly choose one of the best 24 Sub-policies on ImageNet.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> policy = ImageNetPolicy()
|
||||||
|
>>> transformed = policy(image)
|
||||||
|
|
||||||
|
Example as a PyTorch Transform:
|
||||||
|
>>> transform=transforms.Compose([
|
||||||
|
>>> transforms.Resize(256),
|
||||||
|
>>> ImageNetPolicy(),
|
||||||
|
>>> transforms.ToTensor()])
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, fillcolor=(128, 128, 128)):
|
||||||
|
self.policies = [
|
||||||
|
SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
|
||||||
|
SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
|
||||||
|
SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
|
||||||
|
SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
|
||||||
|
SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
|
||||||
|
SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
|
||||||
|
SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
|
||||||
|
SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
|
||||||
|
SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
|
||||||
|
SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),
|
||||||
|
SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
|
||||||
|
SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
|
||||||
|
SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
|
||||||
|
SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
|
||||||
|
SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
|
||||||
|
SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
|
||||||
|
SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
|
||||||
|
SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
|
||||||
|
SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
|
||||||
|
SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),
|
||||||
|
SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
|
||||||
|
SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
|
||||||
|
SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
|
||||||
|
SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
|
||||||
|
SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
|
||||||
|
]
|
||||||
|
|
||||||
|
def __call__(self, img):
|
||||||
|
policy_idx = random.randint(0, len(self.policies) - 1)
|
||||||
|
return self.policies[policy_idx](img)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "AutoAugment ImageNet Policy"
|
||||||
|
|
||||||
|
|
||||||
|
class CIFAR10Policy:
|
||||||
|
"""Randomly choose one of the best 25 Sub-policies on CIFAR10.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> policy = CIFAR10Policy()
|
||||||
|
>>> transformed = policy(image)
|
||||||
|
|
||||||
|
Example as a PyTorch Transform:
|
||||||
|
>>> transform=transforms.Compose([
|
||||||
|
>>> transforms.Resize(256),
|
||||||
|
>>> CIFAR10Policy(),
|
||||||
|
>>> transforms.ToTensor()])
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, fillcolor=(128, 128, 128)):
|
||||||
|
self.policies = [
|
||||||
|
SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
|
||||||
|
SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
|
||||||
|
SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
|
||||||
|
SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
|
||||||
|
SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),
|
||||||
|
SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
|
||||||
|
SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
|
||||||
|
SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
|
||||||
|
SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
|
||||||
|
SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),
|
||||||
|
SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
|
||||||
|
SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
|
||||||
|
SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
|
||||||
|
SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
|
||||||
|
SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),
|
||||||
|
SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
|
||||||
|
SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor),
|
||||||
|
SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
|
||||||
|
SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
|
||||||
|
SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),
|
||||||
|
SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
|
||||||
|
SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
|
||||||
|
SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
|
||||||
|
SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
|
||||||
|
SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor),
|
||||||
|
]
|
||||||
|
|
||||||
|
def __call__(self, img):
|
||||||
|
policy_idx = random.randint(0, len(self.policies) - 1)
|
||||||
|
return self.policies[policy_idx](img)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "AutoAugment CIFAR10 Policy"
|
||||||
|
|
||||||
|
|
||||||
|
class SVHNPolicy:
|
||||||
|
"""Randomly choose one of the best 25 Sub-policies on SVHN.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> policy = SVHNPolicy()
|
||||||
|
>>> transformed = policy(image)
|
||||||
|
|
||||||
|
Example as a PyTorch Transform:
|
||||||
|
>>> transform=transforms.Compose([
|
||||||
|
>>> transforms.Resize(256),
|
||||||
|
>>> SVHNPolicy(),
|
||||||
|
>>> transforms.ToTensor()])
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, fillcolor=(128, 128, 128)):
|
||||||
|
self.policies = [
|
||||||
|
SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
|
||||||
|
SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
|
||||||
|
SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
|
||||||
|
SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
|
||||||
|
SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),
|
||||||
|
SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
|
||||||
|
SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
|
||||||
|
SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
|
||||||
|
SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
|
||||||
|
SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),
|
||||||
|
SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
|
||||||
|
SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
|
||||||
|
SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
|
||||||
|
SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
|
||||||
|
SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),
|
||||||
|
SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
|
||||||
|
SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
|
||||||
|
SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
|
||||||
|
SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
|
||||||
|
SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),
|
||||||
|
SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
|
||||||
|
SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
|
||||||
|
SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
|
||||||
|
SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
|
||||||
|
SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor),
|
||||||
|
]
|
||||||
|
|
||||||
|
def __call__(self, img):
|
||||||
|
policy_idx = random.randint(0, len(self.policies) - 1)
|
||||||
|
return self.policies[policy_idx](img)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "AutoAugment SVHN Policy"
|
||||||
|
|
||||||
|
|
||||||
|
class SubPolicy(object):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
p1,
|
||||||
|
operation1,
|
||||||
|
magnitude_idx1,
|
||||||
|
p2,
|
||||||
|
operation2,
|
||||||
|
magnitude_idx2,
|
||||||
|
fillcolor=(128, 128, 128),
|
||||||
|
):
|
||||||
|
ranges = {
|
||||||
|
"shearX": np.linspace(0, 0.3, 10),
|
||||||
|
"shearY": np.linspace(0, 0.3, 10),
|
||||||
|
"translateX": np.linspace(0, 150 / 331, 10),
|
||||||
|
"translateY": np.linspace(0, 150 / 331, 10),
|
||||||
|
"rotate": np.linspace(0, 30, 10),
|
||||||
|
"color": np.linspace(0.0, 0.9, 10),
|
||||||
|
"posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int),
|
||||||
|
"solarize": np.linspace(256, 0, 10),
|
||||||
|
"contrast": np.linspace(0.0, 0.9, 10),
|
||||||
|
"sharpness": np.linspace(0.0, 0.9, 10),
|
||||||
|
"brightness": np.linspace(0.0, 0.9, 10),
|
||||||
|
"autocontrast": [0] * 10,
|
||||||
|
"equalize": [0] * 10,
|
||||||
|
"invert": [0] * 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
# from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
|
||||||
|
def rotate_with_fill(img, magnitude):
|
||||||
|
rot = img.convert("RGBA").rotate(magnitude)
|
||||||
|
return Image.composite(
|
||||||
|
rot, Image.new("RGBA", rot.size, (128, ) * 4), rot
|
||||||
|
).convert(img.mode)
|
||||||
|
|
||||||
|
func = {
|
||||||
|
"shearX":
|
||||||
|
lambda img, magnitude: img.transform(
|
||||||
|
img.size,
|
||||||
|
Image.AFFINE,
|
||||||
|
(1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
|
||||||
|
Image.BICUBIC,
|
||||||
|
fillcolor=fillcolor,
|
||||||
|
),
|
||||||
|
"shearY":
|
||||||
|
lambda img, magnitude: img.transform(
|
||||||
|
img.size,
|
||||||
|
Image.AFFINE,
|
||||||
|
(1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
|
||||||
|
Image.BICUBIC,
|
||||||
|
fillcolor=fillcolor,
|
||||||
|
),
|
||||||
|
"translateX":
|
||||||
|
lambda img, magnitude: img.transform(
|
||||||
|
img.size,
|
||||||
|
Image.AFFINE,
|
||||||
|
(
|
||||||
|
1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0,
|
||||||
|
1, 0
|
||||||
|
),
|
||||||
|
fillcolor=fillcolor,
|
||||||
|
),
|
||||||
|
"translateY":
|
||||||
|
lambda img, magnitude: img.transform(
|
||||||
|
img.size,
|
||||||
|
Image.AFFINE,
|
||||||
|
(
|
||||||
|
1, 0, 0, 0, 1, magnitude * img.size[1] * random.
|
||||||
|
choice([-1, 1])
|
||||||
|
),
|
||||||
|
fillcolor=fillcolor,
|
||||||
|
),
|
||||||
|
"rotate":
|
||||||
|
lambda img, magnitude: rotate_with_fill(img, magnitude),
|
||||||
|
"color":
|
||||||
|
lambda img, magnitude: ImageEnhance.Color(img).
|
||||||
|
enhance(1 + magnitude * random.choice([-1, 1])),
|
||||||
|
"posterize":
|
||||||
|
lambda img, magnitude: ImageOps.posterize(img, magnitude),
|
||||||
|
"solarize":
|
||||||
|
lambda img, magnitude: ImageOps.solarize(img, magnitude),
|
||||||
|
"contrast":
|
||||||
|
lambda img, magnitude: ImageEnhance.Contrast(img).
|
||||||
|
enhance(1 + magnitude * random.choice([-1, 1])),
|
||||||
|
"sharpness":
|
||||||
|
lambda img, magnitude: ImageEnhance.Sharpness(img).
|
||||||
|
enhance(1 + magnitude * random.choice([-1, 1])),
|
||||||
|
"brightness":
|
||||||
|
lambda img, magnitude: ImageEnhance.Brightness(img).
|
||||||
|
enhance(1 + magnitude * random.choice([-1, 1])),
|
||||||
|
"autocontrast":
|
||||||
|
lambda img, magnitude: ImageOps.autocontrast(img),
|
||||||
|
"equalize":
|
||||||
|
lambda img, magnitude: ImageOps.equalize(img),
|
||||||
|
"invert":
|
||||||
|
lambda img, magnitude: ImageOps.invert(img),
|
||||||
|
}
|
||||||
|
|
||||||
|
self.p1 = p1
|
||||||
|
self.operation1 = func[operation1]
|
||||||
|
self.magnitude1 = ranges[operation1][magnitude_idx1]
|
||||||
|
self.p2 = p2
|
||||||
|
self.operation2 = func[operation2]
|
||||||
|
self.magnitude2 = ranges[operation2][magnitude_idx2]
|
||||||
|
|
||||||
|
def __call__(self, img):
|
||||||
|
if random.random() < self.p1:
|
||||||
|
img = self.operation1(img, self.magnitude1)
|
||||||
|
if random.random() < self.p2:
|
||||||
|
img = self.operation2(img, self.magnitude2)
|
||||||
|
return img
|
||||||
363
Dassl.ProGrad.pytorch/dassl/data/transforms/randaugment.py
Normal file
363
Dassl.ProGrad.pytorch/dassl/data/transforms/randaugment.py
Normal file
@@ -0,0 +1,363 @@
|
|||||||
|
"""
|
||||||
|
Credit to
|
||||||
|
1) https://github.com/ildoonet/pytorch-randaugment
|
||||||
|
2) https://github.com/kakaobrain/fast-autoaugment
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
import PIL
|
||||||
|
import torch
|
||||||
|
import PIL.ImageOps
|
||||||
|
import PIL.ImageDraw
|
||||||
|
import PIL.ImageEnhance
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
def ShearX(img, v):
|
||||||
|
assert -0.3 <= v <= 0.3
|
||||||
|
if random.random() > 0.5:
|
||||||
|
v = -v
|
||||||
|
return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
|
||||||
|
|
||||||
|
|
||||||
|
def ShearY(img, v):
|
||||||
|
assert -0.3 <= v <= 0.3
|
||||||
|
if random.random() > 0.5:
|
||||||
|
v = -v
|
||||||
|
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
|
||||||
|
|
||||||
|
|
||||||
|
def TranslateX(img, v):
|
||||||
|
# [-150, 150] => percentage: [-0.45, 0.45]
|
||||||
|
assert -0.45 <= v <= 0.45
|
||||||
|
if random.random() > 0.5:
|
||||||
|
v = -v
|
||||||
|
v = v * img.size[0]
|
||||||
|
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
|
||||||
|
|
||||||
|
|
||||||
|
def TranslateXabs(img, v):
|
||||||
|
# [-150, 150] => percentage: [-0.45, 0.45]
|
||||||
|
assert 0 <= v
|
||||||
|
if random.random() > 0.5:
|
||||||
|
v = -v
|
||||||
|
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
|
||||||
|
|
||||||
|
|
||||||
|
def TranslateY(img, v):
|
||||||
|
# [-150, 150] => percentage: [-0.45, 0.45]
|
||||||
|
assert -0.45 <= v <= 0.45
|
||||||
|
if random.random() > 0.5:
|
||||||
|
v = -v
|
||||||
|
v = v * img.size[1]
|
||||||
|
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
|
||||||
|
|
||||||
|
|
||||||
|
def TranslateYabs(img, v):
|
||||||
|
# [-150, 150] => percentage: [-0.45, 0.45]
|
||||||
|
assert 0 <= v
|
||||||
|
if random.random() > 0.5:
|
||||||
|
v = -v
|
||||||
|
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
|
||||||
|
|
||||||
|
|
||||||
|
def Rotate(img, v):
|
||||||
|
assert -30 <= v <= 30
|
||||||
|
if random.random() > 0.5:
|
||||||
|
v = -v
|
||||||
|
return img.rotate(v)
|
||||||
|
|
||||||
|
|
||||||
|
def AutoContrast(img, _):
|
||||||
|
return PIL.ImageOps.autocontrast(img)
|
||||||
|
|
||||||
|
|
||||||
|
def Invert(img, _):
|
||||||
|
return PIL.ImageOps.invert(img)
|
||||||
|
|
||||||
|
|
||||||
|
def Equalize(img, _):
|
||||||
|
return PIL.ImageOps.equalize(img)
|
||||||
|
|
||||||
|
|
||||||
|
def Flip(img, _):
|
||||||
|
return PIL.ImageOps.mirror(img)
|
||||||
|
|
||||||
|
|
||||||
|
def Solarize(img, v):
|
||||||
|
assert 0 <= v <= 256
|
||||||
|
return PIL.ImageOps.solarize(img, v)
|
||||||
|
|
||||||
|
|
||||||
|
def SolarizeAdd(img, addition=0, threshold=128):
|
||||||
|
img_np = np.array(img).astype(np.int)
|
||||||
|
img_np = img_np + addition
|
||||||
|
img_np = np.clip(img_np, 0, 255)
|
||||||
|
img_np = img_np.astype(np.uint8)
|
||||||
|
img = Image.fromarray(img_np)
|
||||||
|
return PIL.ImageOps.solarize(img, threshold)
|
||||||
|
|
||||||
|
|
||||||
|
def Posterize(img, v):
|
||||||
|
assert 4 <= v <= 8
|
||||||
|
v = int(v)
|
||||||
|
return PIL.ImageOps.posterize(img, v)
|
||||||
|
|
||||||
|
|
||||||
|
def Contrast(img, v):
|
||||||
|
assert 0.0 <= v <= 2.0
|
||||||
|
return PIL.ImageEnhance.Contrast(img).enhance(v)
|
||||||
|
|
||||||
|
|
||||||
|
def Color(img, v):
|
||||||
|
assert 0.0 <= v <= 2.0
|
||||||
|
return PIL.ImageEnhance.Color(img).enhance(v)
|
||||||
|
|
||||||
|
|
||||||
|
def Brightness(img, v):
|
||||||
|
assert 0.0 <= v <= 2.0
|
||||||
|
return PIL.ImageEnhance.Brightness(img).enhance(v)
|
||||||
|
|
||||||
|
|
||||||
|
def Sharpness(img, v):
|
||||||
|
assert 0.0 <= v <= 2.0
|
||||||
|
return PIL.ImageEnhance.Sharpness(img).enhance(v)
|
||||||
|
|
||||||
|
|
||||||
|
def Cutout(img, v):
|
||||||
|
# [0, 60] => percentage: [0, 0.2]
|
||||||
|
assert 0.0 <= v <= 0.2
|
||||||
|
if v <= 0.0:
|
||||||
|
return img
|
||||||
|
|
||||||
|
v = v * img.size[0]
|
||||||
|
return CutoutAbs(img, v)
|
||||||
|
|
||||||
|
|
||||||
|
def CutoutAbs(img, v):
|
||||||
|
# [0, 60] => percentage: [0, 0.2]
|
||||||
|
# assert 0 <= v <= 20
|
||||||
|
if v < 0:
|
||||||
|
return img
|
||||||
|
w, h = img.size
|
||||||
|
x0 = np.random.uniform(w)
|
||||||
|
y0 = np.random.uniform(h)
|
||||||
|
|
||||||
|
x0 = int(max(0, x0 - v/2.0))
|
||||||
|
y0 = int(max(0, y0 - v/2.0))
|
||||||
|
x1 = min(w, x0 + v)
|
||||||
|
y1 = min(h, y0 + v)
|
||||||
|
|
||||||
|
xy = (x0, y0, x1, y1)
|
||||||
|
color = (125, 123, 114)
|
||||||
|
# color = (0, 0, 0)
|
||||||
|
img = img.copy()
|
||||||
|
PIL.ImageDraw.Draw(img).rectangle(xy, color)
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def SamplePairing(imgs):
|
||||||
|
# [0, 0.4]
|
||||||
|
def f(img1, v):
|
||||||
|
i = np.random.choice(len(imgs))
|
||||||
|
img2 = PIL.Image.fromarray(imgs[i])
|
||||||
|
return PIL.Image.blend(img1, img2, v)
|
||||||
|
|
||||||
|
return f
|
||||||
|
|
||||||
|
|
||||||
|
def Identity(img, v):
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
class Lighting:
|
||||||
|
"""Lighting noise (AlexNet - style PCA - based noise)."""
|
||||||
|
|
||||||
|
def __init__(self, alphastd, eigval, eigvec):
|
||||||
|
self.alphastd = alphastd
|
||||||
|
self.eigval = torch.Tensor(eigval)
|
||||||
|
self.eigvec = torch.Tensor(eigvec)
|
||||||
|
|
||||||
|
def __call__(self, img):
|
||||||
|
if self.alphastd == 0:
|
||||||
|
return img
|
||||||
|
|
||||||
|
alpha = img.new().resize_(3).normal_(0, self.alphastd)
|
||||||
|
rgb = (
|
||||||
|
self.eigvec.type_as(img).clone().mul(
|
||||||
|
alpha.view(1, 3).expand(3, 3)
|
||||||
|
).mul(self.eigval.view(1, 3).expand(3, 3)).sum(1).squeeze()
|
||||||
|
)
|
||||||
|
|
||||||
|
return img.add(rgb.view(3, 1, 1).expand_as(img))
|
||||||
|
|
||||||
|
|
||||||
|
class CutoutDefault:
|
||||||
|
"""
|
||||||
|
Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, length):
|
||||||
|
self.length = length
|
||||||
|
|
||||||
|
def __call__(self, img):
|
||||||
|
h, w = img.size(1), img.size(2)
|
||||||
|
mask = np.ones((h, w), np.float32)
|
||||||
|
y = np.random.randint(h)
|
||||||
|
x = np.random.randint(w)
|
||||||
|
|
||||||
|
y1 = np.clip(y - self.length // 2, 0, h)
|
||||||
|
y2 = np.clip(y + self.length // 2, 0, h)
|
||||||
|
x1 = np.clip(x - self.length // 2, 0, w)
|
||||||
|
x2 = np.clip(x + self.length // 2, 0, w)
|
||||||
|
|
||||||
|
mask[y1:y2, x1:x2] = 0.0
|
||||||
|
mask = torch.from_numpy(mask)
|
||||||
|
mask = mask.expand_as(img)
|
||||||
|
img *= mask
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def randaugment_list():
|
||||||
|
# 16 oeprations and their ranges
|
||||||
|
# https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57
|
||||||
|
# augs = [
|
||||||
|
# (Identity, 0., 1.0),
|
||||||
|
# (ShearX, 0., 0.3), # 0
|
||||||
|
# (ShearY, 0., 0.3), # 1
|
||||||
|
# (TranslateX, 0., 0.33), # 2
|
||||||
|
# (TranslateY, 0., 0.33), # 3
|
||||||
|
# (Rotate, 0, 30), # 4
|
||||||
|
# (AutoContrast, 0, 1), # 5
|
||||||
|
# (Invert, 0, 1), # 6
|
||||||
|
# (Equalize, 0, 1), # 7
|
||||||
|
# (Solarize, 0, 110), # 8
|
||||||
|
# (Posterize, 4, 8), # 9
|
||||||
|
# # (Contrast, 0.1, 1.9), # 10
|
||||||
|
# (Color, 0.1, 1.9), # 11
|
||||||
|
# (Brightness, 0.1, 1.9), # 12
|
||||||
|
# (Sharpness, 0.1, 1.9), # 13
|
||||||
|
# # (Cutout, 0, 0.2), # 14
|
||||||
|
# # (SamplePairing(imgs), 0, 0.4) # 15
|
||||||
|
# ]
|
||||||
|
|
||||||
|
# https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505
|
||||||
|
augs = [
|
||||||
|
(AutoContrast, 0, 1),
|
||||||
|
(Equalize, 0, 1),
|
||||||
|
(Invert, 0, 1),
|
||||||
|
(Rotate, 0, 30),
|
||||||
|
(Posterize, 4, 8),
|
||||||
|
(Solarize, 0, 256),
|
||||||
|
(SolarizeAdd, 0, 110),
|
||||||
|
(Color, 0.1, 1.9),
|
||||||
|
(Contrast, 0.1, 1.9),
|
||||||
|
(Brightness, 0.1, 1.9),
|
||||||
|
(Sharpness, 0.1, 1.9),
|
||||||
|
(ShearX, 0.0, 0.3),
|
||||||
|
(ShearY, 0.0, 0.3),
|
||||||
|
(CutoutAbs, 0, 40),
|
||||||
|
(TranslateXabs, 0.0, 100),
|
||||||
|
(TranslateYabs, 0.0, 100),
|
||||||
|
]
|
||||||
|
|
||||||
|
return augs
|
||||||
|
|
||||||
|
|
||||||
|
def randaugment_list2():
|
||||||
|
augs = [
|
||||||
|
(AutoContrast, 0, 1),
|
||||||
|
(Brightness, 0.1, 1.9),
|
||||||
|
(Color, 0.1, 1.9),
|
||||||
|
(Contrast, 0.1, 1.9),
|
||||||
|
(Equalize, 0, 1),
|
||||||
|
(Identity, 0, 1),
|
||||||
|
(Invert, 0, 1),
|
||||||
|
(Posterize, 4, 8),
|
||||||
|
(Rotate, -30, 30),
|
||||||
|
(Sharpness, 0.1, 1.9),
|
||||||
|
(ShearX, -0.3, 0.3),
|
||||||
|
(ShearY, -0.3, 0.3),
|
||||||
|
(Solarize, 0, 256),
|
||||||
|
(TranslateX, -0.3, 0.3),
|
||||||
|
(TranslateY, -0.3, 0.3),
|
||||||
|
]
|
||||||
|
|
||||||
|
return augs
|
||||||
|
|
||||||
|
|
||||||
|
def fixmatch_list():
|
||||||
|
# https://arxiv.org/abs/2001.07685
|
||||||
|
augs = [
|
||||||
|
(AutoContrast, 0, 1),
|
||||||
|
(Brightness, 0.05, 0.95),
|
||||||
|
(Color, 0.05, 0.95),
|
||||||
|
(Contrast, 0.05, 0.95),
|
||||||
|
(Equalize, 0, 1),
|
||||||
|
(Identity, 0, 1),
|
||||||
|
(Posterize, 4, 8),
|
||||||
|
(Rotate, -30, 30),
|
||||||
|
(Sharpness, 0.05, 0.95),
|
||||||
|
(ShearX, -0.3, 0.3),
|
||||||
|
(ShearY, -0.3, 0.3),
|
||||||
|
(Solarize, 0, 256),
|
||||||
|
(TranslateX, -0.3, 0.3),
|
||||||
|
(TranslateY, -0.3, 0.3),
|
||||||
|
]
|
||||||
|
|
||||||
|
return augs
|
||||||
|
|
||||||
|
|
||||||
|
class RandAugment:
|
||||||
|
|
||||||
|
def __init__(self, n=2, m=10):
|
||||||
|
assert 0 <= m <= 30
|
||||||
|
self.n = n
|
||||||
|
self.m = m
|
||||||
|
self.augment_list = randaugment_list()
|
||||||
|
|
||||||
|
def __call__(self, img):
|
||||||
|
ops = random.choices(self.augment_list, k=self.n)
|
||||||
|
|
||||||
|
for op, minval, maxval in ops:
|
||||||
|
val = (self.m / 30) * (maxval-minval) + minval
|
||||||
|
img = op(img, val)
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
class RandAugment2:
|
||||||
|
|
||||||
|
def __init__(self, n=2, p=0.6):
|
||||||
|
self.n = n
|
||||||
|
self.p = p
|
||||||
|
self.augment_list = randaugment_list2()
|
||||||
|
|
||||||
|
def __call__(self, img):
|
||||||
|
ops = random.choices(self.augment_list, k=self.n)
|
||||||
|
|
||||||
|
for op, minval, maxval in ops:
|
||||||
|
if random.random() > self.p:
|
||||||
|
continue
|
||||||
|
m = random.random()
|
||||||
|
val = m * (maxval-minval) + minval
|
||||||
|
img = op(img, val)
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
class RandAugmentFixMatch:
|
||||||
|
|
||||||
|
def __init__(self, n=2):
|
||||||
|
self.n = n
|
||||||
|
self.augment_list = fixmatch_list()
|
||||||
|
|
||||||
|
def __call__(self, img):
|
||||||
|
ops = random.choices(self.augment_list, k=self.n)
|
||||||
|
|
||||||
|
for op, minval, maxval in ops:
|
||||||
|
m = random.random()
|
||||||
|
val = m * (maxval-minval) + minval
|
||||||
|
img = op(img, val)
|
||||||
|
|
||||||
|
return img
|
||||||
341
Dassl.ProGrad.pytorch/dassl/data/transforms/transforms.py
Normal file
341
Dassl.ProGrad.pytorch/dassl/data/transforms/transforms.py
Normal file
@@ -0,0 +1,341 @@
|
|||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision.transforms import (
|
||||||
|
Resize, Compose, ToTensor, Normalize, CenterCrop, RandomCrop, ColorJitter,
|
||||||
|
RandomApply, GaussianBlur, RandomGrayscale, RandomResizedCrop,
|
||||||
|
RandomHorizontalFlip
|
||||||
|
)
|
||||||
|
|
||||||
|
from .autoaugment import SVHNPolicy, CIFAR10Policy, ImageNetPolicy
|
||||||
|
from .randaugment import RandAugment, RandAugment2, RandAugmentFixMatch
|
||||||
|
|
||||||
|
AVAI_CHOICES = [
|
||||||
|
"random_flip",
|
||||||
|
"random_resized_crop",
|
||||||
|
"normalize",
|
||||||
|
"instance_norm",
|
||||||
|
"random_crop",
|
||||||
|
"random_translation",
|
||||||
|
"center_crop", # This has become a default operation for test
|
||||||
|
"cutout",
|
||||||
|
"imagenet_policy",
|
||||||
|
"cifar10_policy",
|
||||||
|
"svhn_policy",
|
||||||
|
"randaugment",
|
||||||
|
"randaugment_fixmatch",
|
||||||
|
"randaugment2",
|
||||||
|
"gaussian_noise",
|
||||||
|
"colorjitter",
|
||||||
|
"randomgrayscale",
|
||||||
|
"gaussian_blur",
|
||||||
|
]
|
||||||
|
|
||||||
|
INTERPOLATION_MODES = {
|
||||||
|
"bilinear": Image.BILINEAR,
|
||||||
|
"bicubic": Image.BICUBIC,
|
||||||
|
"nearest": Image.NEAREST,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Random2DTranslation:
|
||||||
|
"""Given an image of (height, width), we resize it to
|
||||||
|
(height*1.125, width*1.125), and then perform random cropping.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
height (int): target image height.
|
||||||
|
width (int): target image width.
|
||||||
|
p (float, optional): probability that this operation takes place.
|
||||||
|
Default is 0.5.
|
||||||
|
interpolation (int, optional): desired interpolation. Default is
|
||||||
|
``PIL.Image.BILINEAR``
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR):
|
||||||
|
self.height = height
|
||||||
|
self.width = width
|
||||||
|
self.p = p
|
||||||
|
self.interpolation = interpolation
|
||||||
|
|
||||||
|
def __call__(self, img):
|
||||||
|
if random.uniform(0, 1) > self.p:
|
||||||
|
return img.resize((self.width, self.height), self.interpolation)
|
||||||
|
|
||||||
|
new_width = int(round(self.width * 1.125))
|
||||||
|
new_height = int(round(self.height * 1.125))
|
||||||
|
resized_img = img.resize((new_width, new_height), self.interpolation)
|
||||||
|
|
||||||
|
x_maxrange = new_width - self.width
|
||||||
|
y_maxrange = new_height - self.height
|
||||||
|
x1 = int(round(random.uniform(0, x_maxrange)))
|
||||||
|
y1 = int(round(random.uniform(0, y_maxrange)))
|
||||||
|
croped_img = resized_img.crop(
|
||||||
|
(x1, y1, x1 + self.width, y1 + self.height)
|
||||||
|
)
|
||||||
|
|
||||||
|
return croped_img
|
||||||
|
|
||||||
|
|
||||||
|
class InstanceNormalization:
|
||||||
|
"""Normalize data using per-channel mean and standard deviation.
|
||||||
|
|
||||||
|
Reference:
|
||||||
|
- Ulyanov et al. Instance normalization: The missing in- gredient
|
||||||
|
for fast stylization. ArXiv 2016.
|
||||||
|
- Shu et al. A DIRT-T Approach to Unsupervised Domain Adaptation.
|
||||||
|
ICLR 2018.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, eps=1e-8):
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def __call__(self, img):
|
||||||
|
C, H, W = img.shape
|
||||||
|
img_re = img.reshape(C, H * W)
|
||||||
|
mean = img_re.mean(1).view(C, 1, 1)
|
||||||
|
std = img_re.std(1).view(C, 1, 1)
|
||||||
|
return (img-mean) / (std + self.eps)
|
||||||
|
|
||||||
|
|
||||||
|
class Cutout:
|
||||||
|
"""Randomly mask out one or more patches from an image.
|
||||||
|
|
||||||
|
https://github.com/uoguelph-mlrg/Cutout
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_holes (int, optional): number of patches to cut out
|
||||||
|
of each image. Default is 1.
|
||||||
|
length (int, optinal): length (in pixels) of each square
|
||||||
|
patch. Default is 16.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, n_holes=1, length=16):
|
||||||
|
self.n_holes = n_holes
|
||||||
|
self.length = length
|
||||||
|
|
||||||
|
def __call__(self, img):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
img (Tensor): tensor image of size (C, H, W).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: image with n_holes of dimension
|
||||||
|
length x length cut out of it.
|
||||||
|
"""
|
||||||
|
h = img.size(1)
|
||||||
|
w = img.size(2)
|
||||||
|
|
||||||
|
mask = np.ones((h, w), np.float32)
|
||||||
|
|
||||||
|
for n in range(self.n_holes):
|
||||||
|
y = np.random.randint(h)
|
||||||
|
x = np.random.randint(w)
|
||||||
|
|
||||||
|
y1 = np.clip(y - self.length // 2, 0, h)
|
||||||
|
y2 = np.clip(y + self.length // 2, 0, h)
|
||||||
|
x1 = np.clip(x - self.length // 2, 0, w)
|
||||||
|
x2 = np.clip(x + self.length // 2, 0, w)
|
||||||
|
|
||||||
|
mask[y1:y2, x1:x2] = 0.0
|
||||||
|
|
||||||
|
mask = torch.from_numpy(mask)
|
||||||
|
mask = mask.expand_as(img)
|
||||||
|
return img * mask
|
||||||
|
|
||||||
|
|
||||||
|
class GaussianNoise:
|
||||||
|
"""Add gaussian noise."""
|
||||||
|
|
||||||
|
def __init__(self, mean=0, std=0.15, p=0.5):
|
||||||
|
self.mean = mean
|
||||||
|
self.std = std
|
||||||
|
self.p = p
|
||||||
|
|
||||||
|
def __call__(self, img):
|
||||||
|
if random.uniform(0, 1) > self.p:
|
||||||
|
return img
|
||||||
|
noise = torch.randn(img.size()) * self.std + self.mean
|
||||||
|
return img + noise
|
||||||
|
|
||||||
|
|
||||||
|
def build_transform(cfg, is_train=True, choices=None):
|
||||||
|
"""Build transformation function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg (CfgNode): config.
|
||||||
|
is_train (bool, optional): for training (True) or test (False).
|
||||||
|
Default is True.
|
||||||
|
choices (list, optional): list of strings which will overwrite
|
||||||
|
cfg.INPUT.TRANSFORMS if given. Default is None.
|
||||||
|
"""
|
||||||
|
if cfg.INPUT.NO_TRANSFORM:
|
||||||
|
print("Note: no transform is applied!")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if choices is None:
|
||||||
|
choices = cfg.INPUT.TRANSFORMS
|
||||||
|
|
||||||
|
for choice in choices:
|
||||||
|
assert choice in AVAI_CHOICES
|
||||||
|
|
||||||
|
target_size = f"{cfg.INPUT.SIZE[0]}x{cfg.INPUT.SIZE[1]}"
|
||||||
|
|
||||||
|
normalize = Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
|
||||||
|
|
||||||
|
if is_train:
|
||||||
|
return _build_transform_train(cfg, choices, target_size, normalize)
|
||||||
|
else:
|
||||||
|
return _build_transform_test(cfg, choices, target_size, normalize)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_transform_train(cfg, choices, target_size, normalize):
|
||||||
|
print("Building transform_train")
|
||||||
|
tfm_train = []
|
||||||
|
|
||||||
|
interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]
|
||||||
|
|
||||||
|
# Make sure the image size matches the target size
|
||||||
|
conditions = []
|
||||||
|
conditions += ["random_crop" not in choices]
|
||||||
|
conditions += ["random_resized_crop" not in choices]
|
||||||
|
if all(conditions):
|
||||||
|
print(f"+ resize to {target_size}")
|
||||||
|
tfm_train += [Resize(cfg.INPUT.SIZE, interpolation=interp_mode)]
|
||||||
|
|
||||||
|
if "random_translation" in choices:
|
||||||
|
print("+ random translation")
|
||||||
|
tfm_train += [
|
||||||
|
Random2DTranslation(cfg.INPUT.SIZE[0], cfg.INPUT.SIZE[1])
|
||||||
|
]
|
||||||
|
|
||||||
|
if "random_crop" in choices:
|
||||||
|
crop_padding = cfg.INPUT.CROP_PADDING
|
||||||
|
print("+ random crop (padding = {})".format(crop_padding))
|
||||||
|
tfm_train += [RandomCrop(cfg.INPUT.SIZE, padding=crop_padding)]
|
||||||
|
|
||||||
|
if "random_resized_crop" in choices:
|
||||||
|
print(f"+ random resized crop (size={cfg.INPUT.SIZE})")
|
||||||
|
tfm_train += [
|
||||||
|
RandomResizedCrop(cfg.INPUT.SIZE, interpolation=interp_mode)
|
||||||
|
]
|
||||||
|
|
||||||
|
if "center_crop" in choices:
|
||||||
|
print(f"+ center crop (size={cfg.INPUT.SIZE})")
|
||||||
|
tfm_train += [CenterCrop(cfg.INPUT.SIZE)]
|
||||||
|
|
||||||
|
if "random_flip" in choices:
|
||||||
|
print("+ random flip")
|
||||||
|
tfm_train += [RandomHorizontalFlip()]
|
||||||
|
|
||||||
|
if "imagenet_policy" in choices:
|
||||||
|
print("+ imagenet policy")
|
||||||
|
tfm_train += [ImageNetPolicy()]
|
||||||
|
|
||||||
|
if "cifar10_policy" in choices:
|
||||||
|
print("+ cifar10 policy")
|
||||||
|
tfm_train += [CIFAR10Policy()]
|
||||||
|
|
||||||
|
if "svhn_policy" in choices:
|
||||||
|
print("+ svhn policy")
|
||||||
|
tfm_train += [SVHNPolicy()]
|
||||||
|
|
||||||
|
if "randaugment" in choices:
|
||||||
|
n_ = cfg.INPUT.RANDAUGMENT_N
|
||||||
|
m_ = cfg.INPUT.RANDAUGMENT_M
|
||||||
|
print("+ randaugment (n={}, m={})".format(n_, m_))
|
||||||
|
tfm_train += [RandAugment(n_, m_)]
|
||||||
|
|
||||||
|
if "randaugment_fixmatch" in choices:
|
||||||
|
n_ = cfg.INPUT.RANDAUGMENT_N
|
||||||
|
print("+ randaugment_fixmatch (n={})".format(n_))
|
||||||
|
tfm_train += [RandAugmentFixMatch(n_)]
|
||||||
|
|
||||||
|
if "randaugment2" in choices:
|
||||||
|
n_ = cfg.INPUT.RANDAUGMENT_N
|
||||||
|
print("+ randaugment2 (n={})".format(n_))
|
||||||
|
tfm_train += [RandAugment2(n_)]
|
||||||
|
|
||||||
|
if "colorjitter" in choices:
|
||||||
|
print("+ color jitter")
|
||||||
|
tfm_train += [
|
||||||
|
ColorJitter(
|
||||||
|
brightness=cfg.INPUT.COLORJITTER_B,
|
||||||
|
contrast=cfg.INPUT.COLORJITTER_C,
|
||||||
|
saturation=cfg.INPUT.COLORJITTER_S,
|
||||||
|
hue=cfg.INPUT.COLORJITTER_H,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
if "randomgrayscale" in choices:
|
||||||
|
print("+ random gray scale")
|
||||||
|
tfm_train += [RandomGrayscale(p=cfg.INPUT.RGS_P)]
|
||||||
|
|
||||||
|
if "gaussian_blur" in choices:
|
||||||
|
print(f"+ gaussian blur (kernel={cfg.INPUT.GB_K})")
|
||||||
|
tfm_train += [
|
||||||
|
RandomApply([GaussianBlur(cfg.INPUT.GB_K)], p=cfg.INPUT.GB_P)
|
||||||
|
]
|
||||||
|
|
||||||
|
print("+ to torch tensor of range [0, 1]")
|
||||||
|
tfm_train += [ToTensor()]
|
||||||
|
|
||||||
|
if "cutout" in choices:
|
||||||
|
cutout_n = cfg.INPUT.CUTOUT_N
|
||||||
|
cutout_len = cfg.INPUT.CUTOUT_LEN
|
||||||
|
print("+ cutout (n_holes={}, length={})".format(cutout_n, cutout_len))
|
||||||
|
tfm_train += [Cutout(cutout_n, cutout_len)]
|
||||||
|
|
||||||
|
if "normalize" in choices:
|
||||||
|
print(
|
||||||
|
"+ normalization (mean={}, "
|
||||||
|
"std={})".format(cfg.INPUT.PIXEL_MEAN, cfg.INPUT.PIXEL_STD)
|
||||||
|
)
|
||||||
|
tfm_train += [normalize]
|
||||||
|
|
||||||
|
if "gaussian_noise" in choices:
|
||||||
|
print(
|
||||||
|
"+ gaussian noise (mean={}, std={})".format(
|
||||||
|
cfg.INPUT.GN_MEAN, cfg.INPUT.GN_STD
|
||||||
|
)
|
||||||
|
)
|
||||||
|
tfm_train += [GaussianNoise(cfg.INPUT.GN_MEAN, cfg.INPUT.GN_STD)]
|
||||||
|
|
||||||
|
if "instance_norm" in choices:
|
||||||
|
print("+ instance normalization")
|
||||||
|
tfm_train += [InstanceNormalization()]
|
||||||
|
|
||||||
|
tfm_train = Compose(tfm_train)
|
||||||
|
|
||||||
|
return tfm_train
|
||||||
|
|
||||||
|
|
||||||
|
def _build_transform_test(cfg, choices, target_size, normalize):
|
||||||
|
print("Building transform_test")
|
||||||
|
tfm_test = []
|
||||||
|
|
||||||
|
interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]
|
||||||
|
|
||||||
|
print(f"+ resize the smaller edge to {max(cfg.INPUT.SIZE)}")
|
||||||
|
tfm_test += [Resize(max(cfg.INPUT.SIZE), interpolation=interp_mode)]
|
||||||
|
|
||||||
|
print(f"+ {target_size} center crop")
|
||||||
|
tfm_test += [CenterCrop(cfg.INPUT.SIZE)]
|
||||||
|
|
||||||
|
print("+ to torch tensor of range [0, 1]")
|
||||||
|
tfm_test += [ToTensor()]
|
||||||
|
|
||||||
|
if "normalize" in choices:
|
||||||
|
print(
|
||||||
|
"+ normalization (mean={}, "
|
||||||
|
"std={})".format(cfg.INPUT.PIXEL_MEAN, cfg.INPUT.PIXEL_STD)
|
||||||
|
)
|
||||||
|
tfm_test += [normalize]
|
||||||
|
|
||||||
|
if "instance_norm" in choices:
|
||||||
|
print("+ instance normalization")
|
||||||
|
tfm_test += [InstanceNormalization()]
|
||||||
|
|
||||||
|
tfm_test = Compose(tfm_test)
|
||||||
|
|
||||||
|
return tfm_test
|
||||||
6
Dassl.ProGrad.pytorch/dassl/engine/__init__.py
Normal file
6
Dassl.ProGrad.pytorch/dassl/engine/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from .build import TRAINER_REGISTRY, build_trainer # isort:skip
|
||||||
|
from .trainer import TrainerX, TrainerXU, TrainerBase, SimpleTrainer, SimpleNet # isort:skip
|
||||||
|
|
||||||
|
from .da import *
|
||||||
|
from .dg import *
|
||||||
|
from .ssl import *
|
||||||
11
Dassl.ProGrad.pytorch/dassl/engine/build.py
Normal file
11
Dassl.ProGrad.pytorch/dassl/engine/build.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
from dassl.utils import Registry, check_availability
|
||||||
|
|
||||||
|
TRAINER_REGISTRY = Registry("TRAINER")
|
||||||
|
|
||||||
|
|
||||||
|
def build_trainer(cfg):
|
||||||
|
avai_trainers = TRAINER_REGISTRY.registered_names()
|
||||||
|
check_availability(cfg.TRAINER.NAME, avai_trainers)
|
||||||
|
if cfg.VERBOSE:
|
||||||
|
print("Loading trainer: {}".format(cfg.TRAINER.NAME))
|
||||||
|
return TRAINER_REGISTRY.get(cfg.TRAINER.NAME)(cfg)
|
||||||
9
Dassl.ProGrad.pytorch/dassl/engine/da/__init__.py
Normal file
9
Dassl.ProGrad.pytorch/dassl/engine/da/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
from .mcd import MCD
|
||||||
|
from .mme import MME
|
||||||
|
from .adda import ADDA
|
||||||
|
from .dael import DAEL
|
||||||
|
from .dann import DANN
|
||||||
|
from .adabn import AdaBN
|
||||||
|
from .m3sda import M3SDA
|
||||||
|
from .source_only import SourceOnly
|
||||||
|
from .self_ensembling import SelfEnsembling
|
||||||
38
Dassl.ProGrad.pytorch/dassl/engine/da/adabn.py
Normal file
38
Dassl.ProGrad.pytorch/dassl/engine/da/adabn.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from dassl.utils import check_isfile
|
||||||
|
from dassl.engine import TRAINER_REGISTRY, TrainerXU
|
||||||
|
|
||||||
|
|
||||||
|
@TRAINER_REGISTRY.register()
|
||||||
|
class AdaBN(TrainerXU):
|
||||||
|
"""Adaptive Batch Normalization.
|
||||||
|
|
||||||
|
https://arxiv.org/abs/1603.04779.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
super().__init__(cfg)
|
||||||
|
self.done_reset_bn_stats = False
|
||||||
|
|
||||||
|
def check_cfg(self, cfg):
|
||||||
|
assert check_isfile(
|
||||||
|
cfg.MODEL.INIT_WEIGHTS
|
||||||
|
), "The weights of source model must be provided"
|
||||||
|
|
||||||
|
def before_epoch(self):
|
||||||
|
if not self.done_reset_bn_stats:
|
||||||
|
for m in self.model.modules():
|
||||||
|
classname = m.__class__.__name__
|
||||||
|
if classname.find("BatchNorm") != -1:
|
||||||
|
m.reset_running_stats()
|
||||||
|
|
||||||
|
self.done_reset_bn_stats = True
|
||||||
|
|
||||||
|
def forward_backward(self, batch_x, batch_u):
|
||||||
|
input_u = batch_u["img"].to(self.device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
self.model(input_u)
|
||||||
|
|
||||||
|
return None
|
||||||
85
Dassl.ProGrad.pytorch/dassl/engine/da/adda.py
Normal file
85
Dassl.ProGrad.pytorch/dassl/engine/da/adda.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
import copy
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from dassl.optim import build_optimizer, build_lr_scheduler
|
||||||
|
from dassl.utils import check_isfile, count_num_param, open_specified_layers
|
||||||
|
from dassl.engine import TRAINER_REGISTRY, TrainerXU
|
||||||
|
from dassl.modeling import build_head
|
||||||
|
|
||||||
|
|
||||||
|
@TRAINER_REGISTRY.register()
|
||||||
|
class ADDA(TrainerXU):
|
||||||
|
"""Adversarial Discriminative Domain Adaptation.
|
||||||
|
|
||||||
|
https://arxiv.org/abs/1702.05464.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
super().__init__(cfg)
|
||||||
|
self.open_layers = ["backbone"]
|
||||||
|
if isinstance(self.model.head, nn.Module):
|
||||||
|
self.open_layers.append("head")
|
||||||
|
|
||||||
|
self.source_model = copy.deepcopy(self.model)
|
||||||
|
self.source_model.eval()
|
||||||
|
for param in self.source_model.parameters():
|
||||||
|
param.requires_grad_(False)
|
||||||
|
|
||||||
|
self.build_critic()
|
||||||
|
|
||||||
|
self.bce = nn.BCEWithLogitsLoss()
|
||||||
|
|
||||||
|
def check_cfg(self, cfg):
|
||||||
|
assert check_isfile(
|
||||||
|
cfg.MODEL.INIT_WEIGHTS
|
||||||
|
), "The weights of source model must be provided"
|
||||||
|
|
||||||
|
def build_critic(self):
|
||||||
|
cfg = self.cfg
|
||||||
|
|
||||||
|
print("Building critic network")
|
||||||
|
fdim = self.model.fdim
|
||||||
|
critic_body = build_head(
|
||||||
|
"mlp",
|
||||||
|
verbose=cfg.VERBOSE,
|
||||||
|
in_features=fdim,
|
||||||
|
hidden_layers=[fdim, fdim // 2],
|
||||||
|
activation="leaky_relu",
|
||||||
|
)
|
||||||
|
self.critic = nn.Sequential(critic_body, nn.Linear(fdim // 2, 1))
|
||||||
|
print("# params: {:,}".format(count_num_param(self.critic)))
|
||||||
|
self.critic.to(self.device)
|
||||||
|
self.optim_c = build_optimizer(self.critic, cfg.OPTIM)
|
||||||
|
self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM)
|
||||||
|
self.register_model("critic", self.critic, self.optim_c, self.sched_c)
|
||||||
|
|
||||||
|
def forward_backward(self, batch_x, batch_u):
|
||||||
|
open_specified_layers(self.model, self.open_layers)
|
||||||
|
input_x, _, input_u = self.parse_batch_train(batch_x, batch_u)
|
||||||
|
domain_x = torch.ones(input_x.shape[0], 1).to(self.device)
|
||||||
|
domain_u = torch.zeros(input_u.shape[0], 1).to(self.device)
|
||||||
|
|
||||||
|
_, feat_x = self.source_model(input_x, return_feature=True)
|
||||||
|
_, feat_u = self.model(input_u, return_feature=True)
|
||||||
|
|
||||||
|
logit_xd = self.critic(feat_x)
|
||||||
|
logit_ud = self.critic(feat_u.detach())
|
||||||
|
|
||||||
|
loss_critic = self.bce(logit_xd, domain_x)
|
||||||
|
loss_critic += self.bce(logit_ud, domain_u)
|
||||||
|
self.model_backward_and_update(loss_critic, "critic")
|
||||||
|
|
||||||
|
logit_ud = self.critic(feat_u)
|
||||||
|
loss_model = self.bce(logit_ud, 1 - domain_u)
|
||||||
|
self.model_backward_and_update(loss_model, "model")
|
||||||
|
|
||||||
|
loss_summary = {
|
||||||
|
"loss_critic": loss_critic.item(),
|
||||||
|
"loss_model": loss_model.item(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if (self.batch_idx + 1) == self.num_batches:
|
||||||
|
self.update_lr()
|
||||||
|
|
||||||
|
return loss_summary
|
||||||
210
Dassl.ProGrad.pytorch/dassl/engine/da/dael.py
Normal file
210
Dassl.ProGrad.pytorch/dassl/engine/da/dael.py
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from dassl.data import DataManager
|
||||||
|
from dassl.optim import build_optimizer, build_lr_scheduler
|
||||||
|
from dassl.utils import count_num_param
|
||||||
|
from dassl.engine import TRAINER_REGISTRY, TrainerXU
|
||||||
|
from dassl.metrics import compute_accuracy
|
||||||
|
from dassl.engine.trainer import SimpleNet
|
||||||
|
from dassl.data.transforms import build_transform
|
||||||
|
from dassl.modeling.ops.utils import create_onehot
|
||||||
|
|
||||||
|
|
||||||
|
class Experts(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, n_source, fdim, num_classes):
|
||||||
|
super().__init__()
|
||||||
|
self.linears = nn.ModuleList(
|
||||||
|
[nn.Linear(fdim, num_classes) for _ in range(n_source)]
|
||||||
|
)
|
||||||
|
self.softmax = nn.Softmax(dim=1)
|
||||||
|
|
||||||
|
def forward(self, i, x):
|
||||||
|
x = self.linears[i](x)
|
||||||
|
x = self.softmax(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@TRAINER_REGISTRY.register()
|
||||||
|
class DAEL(TrainerXU):
|
||||||
|
"""Domain Adaptive Ensemble Learning.
|
||||||
|
|
||||||
|
https://arxiv.org/abs/2003.07325.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
super().__init__(cfg)
|
||||||
|
n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
|
||||||
|
batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
|
||||||
|
if n_domain <= 0:
|
||||||
|
n_domain = self.num_source_domains
|
||||||
|
self.split_batch = batch_size // n_domain
|
||||||
|
self.n_domain = n_domain
|
||||||
|
|
||||||
|
self.weight_u = cfg.TRAINER.DAEL.WEIGHT_U
|
||||||
|
self.conf_thre = cfg.TRAINER.DAEL.CONF_THRE
|
||||||
|
|
||||||
|
def check_cfg(self, cfg):
|
||||||
|
assert cfg.DATALOADER.TRAIN_X.SAMPLER == "RandomDomainSampler"
|
||||||
|
assert not cfg.DATALOADER.TRAIN_U.SAME_AS_X
|
||||||
|
assert len(cfg.TRAINER.DAEL.STRONG_TRANSFORMS) > 0
|
||||||
|
|
||||||
|
def build_data_loader(self):
|
||||||
|
cfg = self.cfg
|
||||||
|
tfm_train = build_transform(cfg, is_train=True)
|
||||||
|
custom_tfm_train = [tfm_train]
|
||||||
|
choices = cfg.TRAINER.DAEL.STRONG_TRANSFORMS
|
||||||
|
tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
|
||||||
|
custom_tfm_train += [tfm_train_strong]
|
||||||
|
dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
|
||||||
|
self.train_loader_x = dm.train_loader_x
|
||||||
|
self.train_loader_u = dm.train_loader_u
|
||||||
|
self.val_loader = dm.val_loader
|
||||||
|
self.test_loader = dm.test_loader
|
||||||
|
self.num_classes = dm.num_classes
|
||||||
|
self.num_source_domains = dm.num_source_domains
|
||||||
|
self.lab2cname = dm.lab2cname
|
||||||
|
|
||||||
|
def build_model(self):
|
||||||
|
cfg = self.cfg
|
||||||
|
|
||||||
|
print("Building F")
|
||||||
|
self.F = SimpleNet(cfg, cfg.MODEL, 0)
|
||||||
|
self.F.to(self.device)
|
||||||
|
print("# params: {:,}".format(count_num_param(self.F)))
|
||||||
|
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
|
||||||
|
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
|
||||||
|
self.register_model("F", self.F, self.optim_F, self.sched_F)
|
||||||
|
fdim = self.F.fdim
|
||||||
|
|
||||||
|
print("Building E")
|
||||||
|
self.E = Experts(self.num_source_domains, fdim, self.num_classes)
|
||||||
|
self.E.to(self.device)
|
||||||
|
print("# params: {:,}".format(count_num_param(self.E)))
|
||||||
|
self.optim_E = build_optimizer(self.E, cfg.OPTIM)
|
||||||
|
self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM)
|
||||||
|
self.register_model("E", self.E, self.optim_E, self.sched_E)
|
||||||
|
|
||||||
|
def forward_backward(self, batch_x, batch_u):
|
||||||
|
parsed_data = self.parse_batch_train(batch_x, batch_u)
|
||||||
|
input_x, input_x2, label_x, domain_x, input_u, input_u2 = parsed_data
|
||||||
|
|
||||||
|
input_x = torch.split(input_x, self.split_batch, 0)
|
||||||
|
input_x2 = torch.split(input_x2, self.split_batch, 0)
|
||||||
|
label_x = torch.split(label_x, self.split_batch, 0)
|
||||||
|
domain_x = torch.split(domain_x, self.split_batch, 0)
|
||||||
|
domain_x = [d[0].item() for d in domain_x]
|
||||||
|
|
||||||
|
# Generate pseudo label
|
||||||
|
with torch.no_grad():
|
||||||
|
feat_u = self.F(input_u)
|
||||||
|
pred_u = []
|
||||||
|
for k in range(self.num_source_domains):
|
||||||
|
pred_uk = self.E(k, feat_u)
|
||||||
|
pred_uk = pred_uk.unsqueeze(1)
|
||||||
|
pred_u.append(pred_uk)
|
||||||
|
pred_u = torch.cat(pred_u, 1) # (B, K, C)
|
||||||
|
# Get the highest probability and index (label) for each expert
|
||||||
|
experts_max_p, experts_max_idx = pred_u.max(2) # (B, K)
|
||||||
|
# Get the most confident expert
|
||||||
|
max_expert_p, max_expert_idx = experts_max_p.max(1) # (B)
|
||||||
|
pseudo_label_u = []
|
||||||
|
for i, experts_label in zip(max_expert_idx, experts_max_idx):
|
||||||
|
pseudo_label_u.append(experts_label[i])
|
||||||
|
pseudo_label_u = torch.stack(pseudo_label_u, 0)
|
||||||
|
pseudo_label_u = create_onehot(pseudo_label_u, self.num_classes)
|
||||||
|
pseudo_label_u = pseudo_label_u.to(self.device)
|
||||||
|
label_u_mask = (max_expert_p >= self.conf_thre).float()
|
||||||
|
|
||||||
|
loss_x = 0
|
||||||
|
loss_cr = 0
|
||||||
|
acc_x = 0
|
||||||
|
|
||||||
|
feat_x = [self.F(x) for x in input_x]
|
||||||
|
feat_x2 = [self.F(x) for x in input_x2]
|
||||||
|
feat_u2 = self.F(input_u2)
|
||||||
|
|
||||||
|
for feat_xi, feat_x2i, label_xi, i in zip(
|
||||||
|
feat_x, feat_x2, label_x, domain_x
|
||||||
|
):
|
||||||
|
cr_s = [j for j in domain_x if j != i]
|
||||||
|
|
||||||
|
# Learning expert
|
||||||
|
pred_xi = self.E(i, feat_xi)
|
||||||
|
loss_x += (-label_xi * torch.log(pred_xi + 1e-5)).sum(1).mean()
|
||||||
|
expert_label_xi = pred_xi.detach()
|
||||||
|
acc_x += compute_accuracy(pred_xi.detach(),
|
||||||
|
label_xi.max(1)[1])[0].item()
|
||||||
|
|
||||||
|
# Consistency regularization
|
||||||
|
cr_pred = []
|
||||||
|
for j in cr_s:
|
||||||
|
pred_j = self.E(j, feat_x2i)
|
||||||
|
pred_j = pred_j.unsqueeze(1)
|
||||||
|
cr_pred.append(pred_j)
|
||||||
|
cr_pred = torch.cat(cr_pred, 1)
|
||||||
|
cr_pred = cr_pred.mean(1)
|
||||||
|
loss_cr += ((cr_pred - expert_label_xi)**2).sum(1).mean()
|
||||||
|
|
||||||
|
loss_x /= self.n_domain
|
||||||
|
loss_cr /= self.n_domain
|
||||||
|
acc_x /= self.n_domain
|
||||||
|
|
||||||
|
# Unsupervised loss
|
||||||
|
pred_u = []
|
||||||
|
for k in range(self.num_source_domains):
|
||||||
|
pred_uk = self.E(k, feat_u2)
|
||||||
|
pred_uk = pred_uk.unsqueeze(1)
|
||||||
|
pred_u.append(pred_uk)
|
||||||
|
pred_u = torch.cat(pred_u, 1)
|
||||||
|
pred_u = pred_u.mean(1)
|
||||||
|
l_u = (-pseudo_label_u * torch.log(pred_u + 1e-5)).sum(1)
|
||||||
|
loss_u = (l_u * label_u_mask).mean()
|
||||||
|
|
||||||
|
loss = 0
|
||||||
|
loss += loss_x
|
||||||
|
loss += loss_cr
|
||||||
|
loss += loss_u * self.weight_u
|
||||||
|
self.model_backward_and_update(loss)
|
||||||
|
|
||||||
|
loss_summary = {
|
||||||
|
"loss_x": loss_x.item(),
|
||||||
|
"acc_x": acc_x,
|
||||||
|
"loss_cr": loss_cr.item(),
|
||||||
|
"loss_u": loss_u.item(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if (self.batch_idx + 1) == self.num_batches:
|
||||||
|
self.update_lr()
|
||||||
|
|
||||||
|
return loss_summary
|
||||||
|
|
||||||
|
def parse_batch_train(self, batch_x, batch_u):
|
||||||
|
input_x = batch_x["img"]
|
||||||
|
input_x2 = batch_x["img2"]
|
||||||
|
label_x = batch_x["label"]
|
||||||
|
domain_x = batch_x["domain"]
|
||||||
|
input_u = batch_u["img"]
|
||||||
|
input_u2 = batch_u["img2"]
|
||||||
|
|
||||||
|
label_x = create_onehot(label_x, self.num_classes)
|
||||||
|
|
||||||
|
input_x = input_x.to(self.device)
|
||||||
|
input_x2 = input_x2.to(self.device)
|
||||||
|
label_x = label_x.to(self.device)
|
||||||
|
input_u = input_u.to(self.device)
|
||||||
|
input_u2 = input_u2.to(self.device)
|
||||||
|
|
||||||
|
return input_x, input_x2, label_x, domain_x, input_u, input_u2
|
||||||
|
|
||||||
|
def model_inference(self, input):
|
||||||
|
f = self.F(input)
|
||||||
|
p = []
|
||||||
|
for k in range(self.num_source_domains):
|
||||||
|
p_k = self.E(k, f)
|
||||||
|
p_k = p_k.unsqueeze(1)
|
||||||
|
p.append(p_k)
|
||||||
|
p = torch.cat(p, 1)
|
||||||
|
p = p.mean(1)
|
||||||
|
return p
|
||||||
78
Dassl.ProGrad.pytorch/dassl/engine/da/dann.py
Normal file
78
Dassl.ProGrad.pytorch/dassl/engine/da/dann.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from dassl.optim import build_optimizer, build_lr_scheduler
|
||||||
|
from dassl.utils import count_num_param
|
||||||
|
from dassl.engine import TRAINER_REGISTRY, TrainerXU
|
||||||
|
from dassl.metrics import compute_accuracy
|
||||||
|
from dassl.modeling import build_head
|
||||||
|
from dassl.modeling.ops import ReverseGrad
|
||||||
|
|
||||||
|
|
||||||
|
@TRAINER_REGISTRY.register()
|
||||||
|
class DANN(TrainerXU):
|
||||||
|
"""Domain-Adversarial Neural Networks.
|
||||||
|
|
||||||
|
https://arxiv.org/abs/1505.07818.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
super().__init__(cfg)
|
||||||
|
self.build_critic()
|
||||||
|
self.ce = nn.CrossEntropyLoss()
|
||||||
|
self.bce = nn.BCEWithLogitsLoss()
|
||||||
|
|
||||||
|
def build_critic(self):
|
||||||
|
cfg = self.cfg
|
||||||
|
|
||||||
|
print("Building critic network")
|
||||||
|
fdim = self.model.fdim
|
||||||
|
critic_body = build_head(
|
||||||
|
"mlp",
|
||||||
|
verbose=cfg.VERBOSE,
|
||||||
|
in_features=fdim,
|
||||||
|
hidden_layers=[fdim, fdim],
|
||||||
|
activation="leaky_relu",
|
||||||
|
)
|
||||||
|
self.critic = nn.Sequential(critic_body, nn.Linear(fdim, 1))
|
||||||
|
print("# params: {:,}".format(count_num_param(self.critic)))
|
||||||
|
self.critic.to(self.device)
|
||||||
|
self.optim_c = build_optimizer(self.critic, cfg.OPTIM)
|
||||||
|
self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM)
|
||||||
|
self.register_model("critic", self.critic, self.optim_c, self.sched_c)
|
||||||
|
self.revgrad = ReverseGrad()
|
||||||
|
|
||||||
|
def forward_backward(self, batch_x, batch_u):
|
||||||
|
input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
|
||||||
|
domain_x = torch.ones(input_x.shape[0], 1).to(self.device)
|
||||||
|
domain_u = torch.zeros(input_u.shape[0], 1).to(self.device)
|
||||||
|
|
||||||
|
global_step = self.batch_idx + self.epoch * self.num_batches
|
||||||
|
progress = global_step / (self.max_epoch * self.num_batches)
|
||||||
|
lmda = 2 / (1 + np.exp(-10 * progress)) - 1
|
||||||
|
|
||||||
|
logit_x, feat_x = self.model(input_x, return_feature=True)
|
||||||
|
_, feat_u = self.model(input_u, return_feature=True)
|
||||||
|
|
||||||
|
loss_x = self.ce(logit_x, label_x)
|
||||||
|
|
||||||
|
feat_x = self.revgrad(feat_x, grad_scaling=lmda)
|
||||||
|
feat_u = self.revgrad(feat_u, grad_scaling=lmda)
|
||||||
|
output_xd = self.critic(feat_x)
|
||||||
|
output_ud = self.critic(feat_u)
|
||||||
|
loss_d = self.bce(output_xd, domain_x) + self.bce(output_ud, domain_u)
|
||||||
|
|
||||||
|
loss = loss_x + loss_d
|
||||||
|
self.model_backward_and_update(loss)
|
||||||
|
|
||||||
|
loss_summary = {
|
||||||
|
"loss_x": loss_x.item(),
|
||||||
|
"acc_x": compute_accuracy(logit_x, label_x)[0].item(),
|
||||||
|
"loss_d": loss_d.item(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if (self.batch_idx + 1) == self.num_batches:
|
||||||
|
self.update_lr()
|
||||||
|
|
||||||
|
return loss_summary
|
||||||
208
Dassl.ProGrad.pytorch/dassl/engine/da/m3sda.py
Normal file
208
Dassl.ProGrad.pytorch/dassl/engine/da/m3sda.py
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from dassl.optim import build_optimizer, build_lr_scheduler
|
||||||
|
from dassl.utils import count_num_param
|
||||||
|
from dassl.engine import TRAINER_REGISTRY, TrainerXU
|
||||||
|
from dassl.engine.trainer import SimpleNet
|
||||||
|
|
||||||
|
|
||||||
|
class PairClassifiers(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, fdim, num_classes):
|
||||||
|
super().__init__()
|
||||||
|
self.c1 = nn.Linear(fdim, num_classes)
|
||||||
|
self.c2 = nn.Linear(fdim, num_classes)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
z1 = self.c1(x)
|
||||||
|
if not self.training:
|
||||||
|
return z1
|
||||||
|
z2 = self.c2(x)
|
||||||
|
return z1, z2
|
||||||
|
|
||||||
|
|
||||||
|
@TRAINER_REGISTRY.register()
|
||||||
|
class M3SDA(TrainerXU):
|
||||||
|
"""Moment Matching for Multi-Source Domain Adaptation.
|
||||||
|
|
||||||
|
https://arxiv.org/abs/1812.01754.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
super().__init__(cfg)
|
||||||
|
n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
|
||||||
|
batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
|
||||||
|
if n_domain <= 0:
|
||||||
|
n_domain = self.num_source_domains
|
||||||
|
self.split_batch = batch_size // n_domain
|
||||||
|
self.n_domain = n_domain
|
||||||
|
|
||||||
|
self.n_step_F = cfg.TRAINER.M3SDA.N_STEP_F
|
||||||
|
self.lmda = cfg.TRAINER.M3SDA.LMDA
|
||||||
|
|
||||||
|
def check_cfg(self, cfg):
|
||||||
|
assert cfg.DATALOADER.TRAIN_X.SAMPLER == "RandomDomainSampler"
|
||||||
|
assert not cfg.DATALOADER.TRAIN_U.SAME_AS_X
|
||||||
|
|
||||||
|
def build_model(self):
|
||||||
|
cfg = self.cfg
|
||||||
|
|
||||||
|
print("Building F")
|
||||||
|
self.F = SimpleNet(cfg, cfg.MODEL, 0)
|
||||||
|
self.F.to(self.device)
|
||||||
|
print("# params: {:,}".format(count_num_param(self.F)))
|
||||||
|
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
|
||||||
|
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
|
||||||
|
self.register_model("F", self.F, self.optim_F, self.sched_F)
|
||||||
|
fdim = self.F.fdim
|
||||||
|
|
||||||
|
print("Building C")
|
||||||
|
self.C = nn.ModuleList(
|
||||||
|
[
|
||||||
|
PairClassifiers(fdim, self.num_classes)
|
||||||
|
for _ in range(self.num_source_domains)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.C.to(self.device)
|
||||||
|
print("# params: {:,}".format(count_num_param(self.C)))
|
||||||
|
self.optim_C = build_optimizer(self.C, cfg.OPTIM)
|
||||||
|
self.sched_C = build_lr_scheduler(self.optim_C, cfg.OPTIM)
|
||||||
|
self.register_model("C", self.C, self.optim_C, self.sched_C)
|
||||||
|
|
||||||
|
def forward_backward(self, batch_x, batch_u):
|
||||||
|
parsed = self.parse_batch_train(batch_x, batch_u)
|
||||||
|
input_x, label_x, domain_x, input_u = parsed
|
||||||
|
|
||||||
|
input_x = torch.split(input_x, self.split_batch, 0)
|
||||||
|
label_x = torch.split(label_x, self.split_batch, 0)
|
||||||
|
domain_x = torch.split(domain_x, self.split_batch, 0)
|
||||||
|
domain_x = [d[0].item() for d in domain_x]
|
||||||
|
|
||||||
|
# Step A
|
||||||
|
loss_x = 0
|
||||||
|
feat_x = []
|
||||||
|
|
||||||
|
for x, y, d in zip(input_x, label_x, domain_x):
|
||||||
|
f = self.F(x)
|
||||||
|
z1, z2 = self.C[d](f)
|
||||||
|
loss_x += F.cross_entropy(z1, y) + F.cross_entropy(z2, y)
|
||||||
|
|
||||||
|
feat_x.append(f)
|
||||||
|
|
||||||
|
loss_x /= self.n_domain
|
||||||
|
|
||||||
|
feat_u = self.F(input_u)
|
||||||
|
loss_msda = self.moment_distance(feat_x, feat_u)
|
||||||
|
|
||||||
|
loss_step_A = loss_x + loss_msda * self.lmda
|
||||||
|
self.model_backward_and_update(loss_step_A)
|
||||||
|
|
||||||
|
# Step B
|
||||||
|
with torch.no_grad():
|
||||||
|
feat_u = self.F(input_u)
|
||||||
|
|
||||||
|
loss_x, loss_dis = 0, 0
|
||||||
|
|
||||||
|
for x, y, d in zip(input_x, label_x, domain_x):
|
||||||
|
with torch.no_grad():
|
||||||
|
f = self.F(x)
|
||||||
|
z1, z2 = self.C[d](f)
|
||||||
|
loss_x += F.cross_entropy(z1, y) + F.cross_entropy(z2, y)
|
||||||
|
|
||||||
|
z1, z2 = self.C[d](feat_u)
|
||||||
|
p1 = F.softmax(z1, 1)
|
||||||
|
p2 = F.softmax(z2, 1)
|
||||||
|
loss_dis += self.discrepancy(p1, p2)
|
||||||
|
|
||||||
|
loss_x /= self.n_domain
|
||||||
|
loss_dis /= self.n_domain
|
||||||
|
|
||||||
|
loss_step_B = loss_x - loss_dis
|
||||||
|
self.model_backward_and_update(loss_step_B, "C")
|
||||||
|
|
||||||
|
# Step C
|
||||||
|
for _ in range(self.n_step_F):
|
||||||
|
feat_u = self.F(input_u)
|
||||||
|
|
||||||
|
loss_dis = 0
|
||||||
|
|
||||||
|
for d in domain_x:
|
||||||
|
z1, z2 = self.C[d](feat_u)
|
||||||
|
p1 = F.softmax(z1, 1)
|
||||||
|
p2 = F.softmax(z2, 1)
|
||||||
|
loss_dis += self.discrepancy(p1, p2)
|
||||||
|
|
||||||
|
loss_dis /= self.n_domain
|
||||||
|
loss_step_C = loss_dis
|
||||||
|
|
||||||
|
self.model_backward_and_update(loss_step_C, "F")
|
||||||
|
|
||||||
|
loss_summary = {
|
||||||
|
"loss_step_A": loss_step_A.item(),
|
||||||
|
"loss_step_B": loss_step_B.item(),
|
||||||
|
"loss_step_C": loss_step_C.item(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if (self.batch_idx + 1) == self.num_batches:
|
||||||
|
self.update_lr()
|
||||||
|
|
||||||
|
return loss_summary
|
||||||
|
|
||||||
|
def moment_distance(self, x, u):
|
||||||
|
# x (list): a list of feature matrix.
|
||||||
|
# u (torch.Tensor): feature matrix.
|
||||||
|
x_mean = [xi.mean(0) for xi in x]
|
||||||
|
u_mean = u.mean(0)
|
||||||
|
dist1 = self.pairwise_distance(x_mean, u_mean)
|
||||||
|
|
||||||
|
x_var = [xi.var(0) for xi in x]
|
||||||
|
u_var = u.var(0)
|
||||||
|
dist2 = self.pairwise_distance(x_var, u_var)
|
||||||
|
|
||||||
|
return (dist1+dist2) / 2
|
||||||
|
|
||||||
|
def pairwise_distance(self, x, u):
|
||||||
|
# x (list): a list of feature vector.
|
||||||
|
# u (torch.Tensor): feature vector.
|
||||||
|
dist = 0
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
for xi in x:
|
||||||
|
dist += self.euclidean(xi, u)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
for i in range(len(x) - 1):
|
||||||
|
for j in range(i + 1, len(x)):
|
||||||
|
dist += self.euclidean(x[i], x[j])
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
return dist / count
|
||||||
|
|
||||||
|
def euclidean(self, input1, input2):
|
||||||
|
return ((input1 - input2)**2).sum().sqrt()
|
||||||
|
|
||||||
|
def discrepancy(self, y1, y2):
|
||||||
|
return (y1 - y2).abs().mean()
|
||||||
|
|
||||||
|
def parse_batch_train(self, batch_x, batch_u):
|
||||||
|
input_x = batch_x["img"]
|
||||||
|
label_x = batch_x["label"]
|
||||||
|
domain_x = batch_x["domain"]
|
||||||
|
input_u = batch_u["img"]
|
||||||
|
|
||||||
|
input_x = input_x.to(self.device)
|
||||||
|
label_x = label_x.to(self.device)
|
||||||
|
input_u = input_u.to(self.device)
|
||||||
|
|
||||||
|
return input_x, label_x, domain_x, input_u
|
||||||
|
|
||||||
|
def model_inference(self, input):
|
||||||
|
f = self.F(input)
|
||||||
|
p = 0
|
||||||
|
for C_i in self.C:
|
||||||
|
z = C_i(f)
|
||||||
|
p += F.softmax(z, 1)
|
||||||
|
p = p / len(self.C)
|
||||||
|
return p
|
||||||
105
Dassl.ProGrad.pytorch/dassl/engine/da/mcd.py
Normal file
105
Dassl.ProGrad.pytorch/dassl/engine/da/mcd.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from dassl.optim import build_optimizer, build_lr_scheduler
|
||||||
|
from dassl.utils import count_num_param
|
||||||
|
from dassl.engine import TRAINER_REGISTRY, TrainerXU
|
||||||
|
from dassl.engine.trainer import SimpleNet
|
||||||
|
|
||||||
|
|
||||||
|
@TRAINER_REGISTRY.register()
|
||||||
|
class MCD(TrainerXU):
|
||||||
|
"""Maximum Classifier Discrepancy.
|
||||||
|
|
||||||
|
https://arxiv.org/abs/1712.02560.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
super().__init__(cfg)
|
||||||
|
self.n_step_F = cfg.TRAINER.MCD.N_STEP_F
|
||||||
|
|
||||||
|
def build_model(self):
|
||||||
|
cfg = self.cfg
|
||||||
|
|
||||||
|
print("Building F")
|
||||||
|
self.F = SimpleNet(cfg, cfg.MODEL, 0)
|
||||||
|
self.F.to(self.device)
|
||||||
|
print("# params: {:,}".format(count_num_param(self.F)))
|
||||||
|
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
|
||||||
|
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
|
||||||
|
self.register_model("F", self.F, self.optim_F, self.sched_F)
|
||||||
|
fdim = self.F.fdim
|
||||||
|
|
||||||
|
print("Building C1")
|
||||||
|
self.C1 = nn.Linear(fdim, self.num_classes)
|
||||||
|
self.C1.to(self.device)
|
||||||
|
print("# params: {:,}".format(count_num_param(self.C1)))
|
||||||
|
self.optim_C1 = build_optimizer(self.C1, cfg.OPTIM)
|
||||||
|
self.sched_C1 = build_lr_scheduler(self.optim_C1, cfg.OPTIM)
|
||||||
|
self.register_model("C1", self.C1, self.optim_C1, self.sched_C1)
|
||||||
|
|
||||||
|
print("Building C2")
|
||||||
|
self.C2 = nn.Linear(fdim, self.num_classes)
|
||||||
|
self.C2.to(self.device)
|
||||||
|
print("# params: {:,}".format(count_num_param(self.C2)))
|
||||||
|
self.optim_C2 = build_optimizer(self.C2, cfg.OPTIM)
|
||||||
|
self.sched_C2 = build_lr_scheduler(self.optim_C2, cfg.OPTIM)
|
||||||
|
self.register_model("C2", self.C2, self.optim_C2, self.sched_C2)
|
||||||
|
|
||||||
|
def forward_backward(self, batch_x, batch_u):
|
||||||
|
parsed = self.parse_batch_train(batch_x, batch_u)
|
||||||
|
input_x, label_x, input_u = parsed
|
||||||
|
|
||||||
|
# Step A
|
||||||
|
feat_x = self.F(input_x)
|
||||||
|
logit_x1 = self.C1(feat_x)
|
||||||
|
logit_x2 = self.C2(feat_x)
|
||||||
|
loss_x1 = F.cross_entropy(logit_x1, label_x)
|
||||||
|
loss_x2 = F.cross_entropy(logit_x2, label_x)
|
||||||
|
loss_step_A = loss_x1 + loss_x2
|
||||||
|
self.model_backward_and_update(loss_step_A)
|
||||||
|
|
||||||
|
# Step B
|
||||||
|
with torch.no_grad():
|
||||||
|
feat_x = self.F(input_x)
|
||||||
|
logit_x1 = self.C1(feat_x)
|
||||||
|
logit_x2 = self.C2(feat_x)
|
||||||
|
loss_x1 = F.cross_entropy(logit_x1, label_x)
|
||||||
|
loss_x2 = F.cross_entropy(logit_x2, label_x)
|
||||||
|
loss_x = loss_x1 + loss_x2
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
feat_u = self.F(input_u)
|
||||||
|
pred_u1 = F.softmax(self.C1(feat_u), 1)
|
||||||
|
pred_u2 = F.softmax(self.C2(feat_u), 1)
|
||||||
|
loss_dis = self.discrepancy(pred_u1, pred_u2)
|
||||||
|
|
||||||
|
loss_step_B = loss_x - loss_dis
|
||||||
|
self.model_backward_and_update(loss_step_B, ["C1", "C2"])
|
||||||
|
|
||||||
|
# Step C
|
||||||
|
for _ in range(self.n_step_F):
|
||||||
|
feat_u = self.F(input_u)
|
||||||
|
pred_u1 = F.softmax(self.C1(feat_u), 1)
|
||||||
|
pred_u2 = F.softmax(self.C2(feat_u), 1)
|
||||||
|
loss_step_C = self.discrepancy(pred_u1, pred_u2)
|
||||||
|
self.model_backward_and_update(loss_step_C, "F")
|
||||||
|
|
||||||
|
loss_summary = {
|
||||||
|
"loss_step_A": loss_step_A.item(),
|
||||||
|
"loss_step_B": loss_step_B.item(),
|
||||||
|
"loss_step_C": loss_step_C.item(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if (self.batch_idx + 1) == self.num_batches:
|
||||||
|
self.update_lr()
|
||||||
|
|
||||||
|
return loss_summary
|
||||||
|
|
||||||
|
def discrepancy(self, y1, y2):
|
||||||
|
return (y1 - y2).abs().mean()
|
||||||
|
|
||||||
|
def model_inference(self, input):
|
||||||
|
feat = self.F(input)
|
||||||
|
return self.C1(feat)
|
||||||
86
Dassl.ProGrad.pytorch/dassl/engine/da/mme.py
Normal file
86
Dassl.ProGrad.pytorch/dassl/engine/da/mme.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from dassl.optim import build_optimizer, build_lr_scheduler
|
||||||
|
from dassl.utils import count_num_param
|
||||||
|
from dassl.engine import TRAINER_REGISTRY, TrainerXU
|
||||||
|
from dassl.metrics import compute_accuracy
|
||||||
|
from dassl.modeling.ops import ReverseGrad
|
||||||
|
from dassl.engine.trainer import SimpleNet
|
||||||
|
|
||||||
|
|
||||||
|
class Prototypes(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, fdim, num_classes, temp=0.05):
|
||||||
|
super().__init__()
|
||||||
|
self.prototypes = nn.Linear(fdim, num_classes, bias=False)
|
||||||
|
self.temp = temp
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = F.normalize(x, p=2, dim=1)
|
||||||
|
out = self.prototypes(x)
|
||||||
|
out = out / self.temp
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@TRAINER_REGISTRY.register()
|
||||||
|
class MME(TrainerXU):
|
||||||
|
"""Minimax Entropy.
|
||||||
|
|
||||||
|
https://arxiv.org/abs/1904.06487.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
super().__init__(cfg)
|
||||||
|
self.lmda = cfg.TRAINER.MME.LMDA
|
||||||
|
|
||||||
|
def build_model(self):
|
||||||
|
cfg = self.cfg
|
||||||
|
|
||||||
|
print("Building F")
|
||||||
|
self.F = SimpleNet(cfg, cfg.MODEL, 0)
|
||||||
|
self.F.to(self.device)
|
||||||
|
print("# params: {:,}".format(count_num_param(self.F)))
|
||||||
|
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
|
||||||
|
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
|
||||||
|
self.register_model("F", self.F, self.optim_F, self.sched_F)
|
||||||
|
|
||||||
|
print("Building C")
|
||||||
|
self.C = Prototypes(self.F.fdim, self.num_classes)
|
||||||
|
self.C.to(self.device)
|
||||||
|
print("# params: {:,}".format(count_num_param(self.C)))
|
||||||
|
self.optim_C = build_optimizer(self.C, cfg.OPTIM)
|
||||||
|
self.sched_C = build_lr_scheduler(self.optim_C, cfg.OPTIM)
|
||||||
|
self.register_model("C", self.C, self.optim_C, self.sched_C)
|
||||||
|
|
||||||
|
self.revgrad = ReverseGrad()
|
||||||
|
|
||||||
|
def forward_backward(self, batch_x, batch_u):
|
||||||
|
input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
|
||||||
|
|
||||||
|
feat_x = self.F(input_x)
|
||||||
|
logit_x = self.C(feat_x)
|
||||||
|
loss_x = F.cross_entropy(logit_x, label_x)
|
||||||
|
self.model_backward_and_update(loss_x)
|
||||||
|
|
||||||
|
feat_u = self.F(input_u)
|
||||||
|
feat_u = self.revgrad(feat_u)
|
||||||
|
logit_u = self.C(feat_u)
|
||||||
|
prob_u = F.softmax(logit_u, 1)
|
||||||
|
loss_u = -(-prob_u * torch.log(prob_u + 1e-5)).sum(1).mean()
|
||||||
|
self.model_backward_and_update(loss_u * self.lmda)
|
||||||
|
|
||||||
|
loss_summary = {
|
||||||
|
"loss_x": loss_x.item(),
|
||||||
|
"acc_x": compute_accuracy(logit_x, label_x)[0].item(),
|
||||||
|
"loss_u": loss_u.item(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if (self.batch_idx + 1) == self.num_batches:
|
||||||
|
self.update_lr()
|
||||||
|
|
||||||
|
return loss_summary
|
||||||
|
|
||||||
|
def model_inference(self, input):
|
||||||
|
return self.C(self.F(input))
|
||||||
78
Dassl.ProGrad.pytorch/dassl/engine/da/self_ensembling.py
Normal file
78
Dassl.ProGrad.pytorch/dassl/engine/da/self_ensembling.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
import copy
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from dassl.engine import TRAINER_REGISTRY, TrainerXU
|
||||||
|
from dassl.metrics import compute_accuracy
|
||||||
|
from dassl.modeling.ops.utils import sigmoid_rampup, ema_model_update
|
||||||
|
|
||||||
|
|
||||||
|
@TRAINER_REGISTRY.register()
|
||||||
|
class SelfEnsembling(TrainerXU):
|
||||||
|
"""Self-ensembling for visual domain adaptation.
|
||||||
|
|
||||||
|
https://arxiv.org/abs/1706.05208.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
super().__init__(cfg)
|
||||||
|
self.ema_alpha = cfg.TRAINER.SE.EMA_ALPHA
|
||||||
|
self.conf_thre = cfg.TRAINER.SE.CONF_THRE
|
||||||
|
self.rampup = cfg.TRAINER.SE.RAMPUP
|
||||||
|
|
||||||
|
self.teacher = copy.deepcopy(self.model)
|
||||||
|
self.teacher.train()
|
||||||
|
for param in self.teacher.parameters():
|
||||||
|
param.requires_grad_(False)
|
||||||
|
|
||||||
|
def check_cfg(self, cfg):
|
||||||
|
assert cfg.DATALOADER.K_TRANSFORMS == 2
|
||||||
|
|
||||||
|
def forward_backward(self, batch_x, batch_u):
|
||||||
|
global_step = self.batch_idx + self.epoch * self.num_batches
|
||||||
|
parsed = self.parse_batch_train(batch_x, batch_u)
|
||||||
|
input_x, label_x, input_u1, input_u2 = parsed
|
||||||
|
|
||||||
|
logit_x = self.model(input_x)
|
||||||
|
loss_x = F.cross_entropy(logit_x, label_x)
|
||||||
|
|
||||||
|
prob_u = F.softmax(self.model(input_u1), 1)
|
||||||
|
t_prob_u = F.softmax(self.teacher(input_u2), 1)
|
||||||
|
loss_u = ((prob_u - t_prob_u)**2).sum(1)
|
||||||
|
|
||||||
|
if self.conf_thre:
|
||||||
|
max_prob = t_prob_u.max(1)[0]
|
||||||
|
mask = (max_prob > self.conf_thre).float()
|
||||||
|
loss_u = (loss_u * mask).mean()
|
||||||
|
else:
|
||||||
|
weight_u = sigmoid_rampup(global_step, self.rampup)
|
||||||
|
loss_u = loss_u.mean() * weight_u
|
||||||
|
|
||||||
|
loss = loss_x + loss_u
|
||||||
|
self.model_backward_and_update(loss)
|
||||||
|
|
||||||
|
ema_alpha = min(1 - 1 / (global_step+1), self.ema_alpha)
|
||||||
|
ema_model_update(self.model, self.teacher, ema_alpha)
|
||||||
|
|
||||||
|
loss_summary = {
|
||||||
|
"loss_x": loss_x.item(),
|
||||||
|
"acc_x": compute_accuracy(logit_x, label_x)[0].item(),
|
||||||
|
"loss_u": loss_u.item(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if (self.batch_idx + 1) == self.num_batches:
|
||||||
|
self.update_lr()
|
||||||
|
|
||||||
|
return loss_summary
|
||||||
|
|
||||||
|
def parse_batch_train(self, batch_x, batch_u):
|
||||||
|
input_x = batch_x["img"][0]
|
||||||
|
label_x = batch_x["label"]
|
||||||
|
input_u = batch_u["img"]
|
||||||
|
input_u1, input_u2 = input_u
|
||||||
|
|
||||||
|
input_x = input_x.to(self.device)
|
||||||
|
label_x = label_x.to(self.device)
|
||||||
|
input_u1 = input_u1.to(self.device)
|
||||||
|
input_u2 = input_u2.to(self.device)
|
||||||
|
|
||||||
|
return input_x, label_x, input_u1, input_u2
|
||||||
34
Dassl.ProGrad.pytorch/dassl/engine/da/source_only.py
Normal file
34
Dassl.ProGrad.pytorch/dassl/engine/da/source_only.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from dassl.engine import TRAINER_REGISTRY, TrainerXU
|
||||||
|
from dassl.metrics import compute_accuracy
|
||||||
|
|
||||||
|
|
||||||
|
@TRAINER_REGISTRY.register()
|
||||||
|
class SourceOnly(TrainerXU):
|
||||||
|
"""Baseline model for domain adaptation, which is
|
||||||
|
trained using source data only.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward_backward(self, batch_x, batch_u):
|
||||||
|
input, label = self.parse_batch_train(batch_x, batch_u)
|
||||||
|
output = self.model(input)
|
||||||
|
loss = F.cross_entropy(output, label)
|
||||||
|
self.model_backward_and_update(loss)
|
||||||
|
|
||||||
|
loss_summary = {
|
||||||
|
"loss": loss.item(),
|
||||||
|
"acc": compute_accuracy(output, label)[0].item(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if (self.batch_idx + 1) == self.num_batches:
|
||||||
|
self.update_lr()
|
||||||
|
|
||||||
|
return loss_summary
|
||||||
|
|
||||||
|
def parse_batch_train(self, batch_x, batch_u):
|
||||||
|
input = batch_x["img"]
|
||||||
|
label = batch_x["label"]
|
||||||
|
input = input.to(self.device)
|
||||||
|
label = label.to(self.device)
|
||||||
|
return input, label
|
||||||
4
Dassl.ProGrad.pytorch/dassl/engine/dg/__init__.py
Normal file
4
Dassl.ProGrad.pytorch/dassl/engine/dg/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from .ddaig import DDAIG
|
||||||
|
from .daeldg import DAELDG
|
||||||
|
from .vanilla import Vanilla
|
||||||
|
from .crossgrad import CrossGrad
|
||||||
83
Dassl.ProGrad.pytorch/dassl/engine/dg/crossgrad.py
Normal file
83
Dassl.ProGrad.pytorch/dassl/engine/dg/crossgrad.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
import torch
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from dassl.optim import build_optimizer, build_lr_scheduler
|
||||||
|
from dassl.utils import count_num_param
|
||||||
|
from dassl.engine import TRAINER_REGISTRY, TrainerX
|
||||||
|
from dassl.engine.trainer import SimpleNet
|
||||||
|
|
||||||
|
|
||||||
|
@TRAINER_REGISTRY.register()
|
||||||
|
class CrossGrad(TrainerX):
|
||||||
|
"""Cross-gradient training.
|
||||||
|
|
||||||
|
https://arxiv.org/abs/1804.10745.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
super().__init__(cfg)
|
||||||
|
self.eps_f = cfg.TRAINER.CG.EPS_F
|
||||||
|
self.eps_d = cfg.TRAINER.CG.EPS_D
|
||||||
|
self.alpha_f = cfg.TRAINER.CG.ALPHA_F
|
||||||
|
self.alpha_d = cfg.TRAINER.CG.ALPHA_D
|
||||||
|
|
||||||
|
def build_model(self):
|
||||||
|
cfg = self.cfg
|
||||||
|
|
||||||
|
print("Building F")
|
||||||
|
self.F = SimpleNet(cfg, cfg.MODEL, self.num_classes)
|
||||||
|
self.F.to(self.device)
|
||||||
|
print("# params: {:,}".format(count_num_param(self.F)))
|
||||||
|
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
|
||||||
|
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
|
||||||
|
self.register_model("F", self.F, self.optim_F, self.sched_F)
|
||||||
|
|
||||||
|
print("Building D")
|
||||||
|
self.D = SimpleNet(cfg, cfg.MODEL, self.num_source_domains)
|
||||||
|
self.D.to(self.device)
|
||||||
|
print("# params: {:,}".format(count_num_param(self.D)))
|
||||||
|
self.optim_D = build_optimizer(self.D, cfg.OPTIM)
|
||||||
|
self.sched_D = build_lr_scheduler(self.optim_D, cfg.OPTIM)
|
||||||
|
self.register_model("D", self.D, self.optim_D, self.sched_D)
|
||||||
|
|
||||||
|
def forward_backward(self, batch):
|
||||||
|
input, label, domain = self.parse_batch_train(batch)
|
||||||
|
|
||||||
|
input.requires_grad = True
|
||||||
|
|
||||||
|
# Compute domain perturbation
|
||||||
|
loss_d = F.cross_entropy(self.D(input), domain)
|
||||||
|
loss_d.backward()
|
||||||
|
grad_d = torch.clamp(input.grad.data, min=-0.1, max=0.1)
|
||||||
|
input_d = input.data + self.eps_f * grad_d
|
||||||
|
|
||||||
|
# Compute label perturbation
|
||||||
|
input.grad.data.zero_()
|
||||||
|
loss_f = F.cross_entropy(self.F(input), label)
|
||||||
|
loss_f.backward()
|
||||||
|
grad_f = torch.clamp(input.grad.data, min=-0.1, max=0.1)
|
||||||
|
input_f = input.data + self.eps_d * grad_f
|
||||||
|
|
||||||
|
input = input.detach()
|
||||||
|
|
||||||
|
# Update label net
|
||||||
|
loss_f1 = F.cross_entropy(self.F(input), label)
|
||||||
|
loss_f2 = F.cross_entropy(self.F(input_d), label)
|
||||||
|
loss_f = (1 - self.alpha_f) * loss_f1 + self.alpha_f * loss_f2
|
||||||
|
self.model_backward_and_update(loss_f, "F")
|
||||||
|
|
||||||
|
# Update domain net
|
||||||
|
loss_d1 = F.cross_entropy(self.D(input), domain)
|
||||||
|
loss_d2 = F.cross_entropy(self.D(input_f), domain)
|
||||||
|
loss_d = (1 - self.alpha_d) * loss_d1 + self.alpha_d * loss_d2
|
||||||
|
self.model_backward_and_update(loss_d, "D")
|
||||||
|
|
||||||
|
loss_summary = {"loss_f": loss_f.item(), "loss_d": loss_d.item()}
|
||||||
|
|
||||||
|
if (self.batch_idx + 1) == self.num_batches:
|
||||||
|
self.update_lr()
|
||||||
|
|
||||||
|
return loss_summary
|
||||||
|
|
||||||
|
def model_inference(self, input):
|
||||||
|
return self.F(input)
|
||||||
169
Dassl.ProGrad.pytorch/dassl/engine/dg/daeldg.py
Normal file
169
Dassl.ProGrad.pytorch/dassl/engine/dg/daeldg.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from dassl.data import DataManager
|
||||||
|
from dassl.optim import build_optimizer, build_lr_scheduler
|
||||||
|
from dassl.utils import count_num_param
|
||||||
|
from dassl.engine import TRAINER_REGISTRY, TrainerX
|
||||||
|
from dassl.metrics import compute_accuracy
|
||||||
|
from dassl.engine.trainer import SimpleNet
|
||||||
|
from dassl.data.transforms import build_transform
|
||||||
|
from dassl.modeling.ops.utils import create_onehot
|
||||||
|
|
||||||
|
|
||||||
|
class Experts(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, n_source, fdim, num_classes):
|
||||||
|
super().__init__()
|
||||||
|
self.linears = nn.ModuleList(
|
||||||
|
[nn.Linear(fdim, num_classes) for _ in range(n_source)]
|
||||||
|
)
|
||||||
|
self.softmax = nn.Softmax(dim=1)
|
||||||
|
|
||||||
|
def forward(self, i, x):
|
||||||
|
x = self.linears[i](x)
|
||||||
|
x = self.softmax(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@TRAINER_REGISTRY.register()
|
||||||
|
class DAELDG(TrainerX):
|
||||||
|
"""Domain Adaptive Ensemble Learning.
|
||||||
|
|
||||||
|
DG version: only use labeled source data.
|
||||||
|
|
||||||
|
https://arxiv.org/abs/2003.07325.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
super().__init__(cfg)
|
||||||
|
n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
|
||||||
|
batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
|
||||||
|
if n_domain <= 0:
|
||||||
|
n_domain = self.num_source_domains
|
||||||
|
self.split_batch = batch_size // n_domain
|
||||||
|
self.n_domain = n_domain
|
||||||
|
|
||||||
|
self.conf_thre = cfg.TRAINER.DAEL.CONF_THRE
|
||||||
|
|
||||||
|
def check_cfg(self, cfg):
|
||||||
|
assert cfg.DATALOADER.TRAIN_X.SAMPLER == "RandomDomainSampler"
|
||||||
|
assert len(cfg.TRAINER.DAEL.STRONG_TRANSFORMS) > 0
|
||||||
|
|
||||||
|
def build_data_loader(self):
|
||||||
|
cfg = self.cfg
|
||||||
|
tfm_train = build_transform(cfg, is_train=True)
|
||||||
|
custom_tfm_train = [tfm_train]
|
||||||
|
choices = cfg.TRAINER.DAEL.STRONG_TRANSFORMS
|
||||||
|
tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
|
||||||
|
custom_tfm_train += [tfm_train_strong]
|
||||||
|
dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
|
||||||
|
self.train_loader_x = dm.train_loader_x
|
||||||
|
self.train_loader_u = dm.train_loader_u
|
||||||
|
self.val_loader = dm.val_loader
|
||||||
|
self.test_loader = dm.test_loader
|
||||||
|
self.num_classes = dm.num_classes
|
||||||
|
self.num_source_domains = dm.num_source_domains
|
||||||
|
self.lab2cname = dm.lab2cname
|
||||||
|
|
||||||
|
def build_model(self):
|
||||||
|
cfg = self.cfg
|
||||||
|
|
||||||
|
print("Building F")
|
||||||
|
self.F = SimpleNet(cfg, cfg.MODEL, 0)
|
||||||
|
self.F.to(self.device)
|
||||||
|
print("# params: {:,}".format(count_num_param(self.F)))
|
||||||
|
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
|
||||||
|
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
|
||||||
|
self.register_model("F", self.F, self.optim_F, self.sched_F)
|
||||||
|
fdim = self.F.fdim
|
||||||
|
|
||||||
|
print("Building E")
|
||||||
|
self.E = Experts(self.num_source_domains, fdim, self.num_classes)
|
||||||
|
self.E.to(self.device)
|
||||||
|
print("# params: {:,}".format(count_num_param(self.E)))
|
||||||
|
self.optim_E = build_optimizer(self.E, cfg.OPTIM)
|
||||||
|
self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM)
|
||||||
|
self.register_model("E", self.E, self.optim_E, self.sched_E)
|
||||||
|
|
||||||
|
def forward_backward(self, batch):
|
||||||
|
parsed_data = self.parse_batch_train(batch)
|
||||||
|
input, input2, label, domain = parsed_data
|
||||||
|
|
||||||
|
input = torch.split(input, self.split_batch, 0)
|
||||||
|
input2 = torch.split(input2, self.split_batch, 0)
|
||||||
|
label = torch.split(label, self.split_batch, 0)
|
||||||
|
domain = torch.split(domain, self.split_batch, 0)
|
||||||
|
domain = [d[0].item() for d in domain]
|
||||||
|
|
||||||
|
loss_x = 0
|
||||||
|
loss_cr = 0
|
||||||
|
acc = 0
|
||||||
|
|
||||||
|
feat = [self.F(x) for x in input]
|
||||||
|
feat2 = [self.F(x) for x in input2]
|
||||||
|
|
||||||
|
for feat_i, feat2_i, label_i, i in zip(feat, feat2, label, domain):
|
||||||
|
cr_s = [j for j in domain if j != i]
|
||||||
|
|
||||||
|
# Learning expert
|
||||||
|
pred_i = self.E(i, feat_i)
|
||||||
|
loss_x += (-label_i * torch.log(pred_i + 1e-5)).sum(1).mean()
|
||||||
|
expert_label_i = pred_i.detach()
|
||||||
|
acc += compute_accuracy(pred_i.detach(),
|
||||||
|
label_i.max(1)[1])[0].item()
|
||||||
|
|
||||||
|
# Consistency regularization
|
||||||
|
cr_pred = []
|
||||||
|
for j in cr_s:
|
||||||
|
pred_j = self.E(j, feat2_i)
|
||||||
|
pred_j = pred_j.unsqueeze(1)
|
||||||
|
cr_pred.append(pred_j)
|
||||||
|
cr_pred = torch.cat(cr_pred, 1)
|
||||||
|
cr_pred = cr_pred.mean(1)
|
||||||
|
loss_cr += ((cr_pred - expert_label_i)**2).sum(1).mean()
|
||||||
|
|
||||||
|
loss_x /= self.n_domain
|
||||||
|
loss_cr /= self.n_domain
|
||||||
|
acc /= self.n_domain
|
||||||
|
|
||||||
|
loss = 0
|
||||||
|
loss += loss_x
|
||||||
|
loss += loss_cr
|
||||||
|
self.model_backward_and_update(loss)
|
||||||
|
|
||||||
|
loss_summary = {
|
||||||
|
"loss_x": loss_x.item(),
|
||||||
|
"acc": acc,
|
||||||
|
"loss_cr": loss_cr.item()
|
||||||
|
}
|
||||||
|
|
||||||
|
if (self.batch_idx + 1) == self.num_batches:
|
||||||
|
self.update_lr()
|
||||||
|
|
||||||
|
return loss_summary
|
||||||
|
|
||||||
|
def parse_batch_train(self, batch):
|
||||||
|
input = batch["img"]
|
||||||
|
input2 = batch["img2"]
|
||||||
|
label = batch["label"]
|
||||||
|
domain = batch["domain"]
|
||||||
|
|
||||||
|
label = create_onehot(label, self.num_classes)
|
||||||
|
|
||||||
|
input = input.to(self.device)
|
||||||
|
input2 = input2.to(self.device)
|
||||||
|
label = label.to(self.device)
|
||||||
|
|
||||||
|
return input, input2, label, domain
|
||||||
|
|
||||||
|
def model_inference(self, input):
|
||||||
|
f = self.F(input)
|
||||||
|
p = []
|
||||||
|
for k in range(self.num_source_domains):
|
||||||
|
p_k = self.E(k, f)
|
||||||
|
p_k = p_k.unsqueeze(1)
|
||||||
|
p.append(p_k)
|
||||||
|
p = torch.cat(p, 1)
|
||||||
|
p = p.mean(1)
|
||||||
|
return p
|
||||||
107
Dassl.ProGrad.pytorch/dassl/engine/dg/ddaig.py
Normal file
107
Dassl.ProGrad.pytorch/dassl/engine/dg/ddaig.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
import torch
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from dassl.optim import build_optimizer, build_lr_scheduler
|
||||||
|
from dassl.utils import count_num_param
|
||||||
|
from dassl.engine import TRAINER_REGISTRY, TrainerX
|
||||||
|
from dassl.modeling import build_network
|
||||||
|
from dassl.engine.trainer import SimpleNet
|
||||||
|
|
||||||
|
|
||||||
|
@TRAINER_REGISTRY.register()
|
||||||
|
class DDAIG(TrainerX):
|
||||||
|
"""Deep Domain-Adversarial Image Generation.
|
||||||
|
|
||||||
|
https://arxiv.org/abs/2003.06054.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
super().__init__(cfg)
|
||||||
|
self.lmda = cfg.TRAINER.DDAIG.LMDA
|
||||||
|
self.clamp = cfg.TRAINER.DDAIG.CLAMP
|
||||||
|
self.clamp_min = cfg.TRAINER.DDAIG.CLAMP_MIN
|
||||||
|
self.clamp_max = cfg.TRAINER.DDAIG.CLAMP_MAX
|
||||||
|
self.warmup = cfg.TRAINER.DDAIG.WARMUP
|
||||||
|
self.alpha = cfg.TRAINER.DDAIG.ALPHA
|
||||||
|
|
||||||
|
def build_model(self):
|
||||||
|
cfg = self.cfg
|
||||||
|
|
||||||
|
print("Building F")
|
||||||
|
self.F = SimpleNet(cfg, cfg.MODEL, self.num_classes)
|
||||||
|
self.F.to(self.device)
|
||||||
|
print("# params: {:,}".format(count_num_param(self.F)))
|
||||||
|
self.optim_F = build_optimizer(self.F, cfg.OPTIM)
|
||||||
|
self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
|
||||||
|
self.register_model("F", self.F, self.optim_F, self.sched_F)
|
||||||
|
|
||||||
|
print("Building D")
|
||||||
|
self.D = SimpleNet(cfg, cfg.MODEL, self.num_source_domains)
|
||||||
|
self.D.to(self.device)
|
||||||
|
print("# params: {:,}".format(count_num_param(self.D)))
|
||||||
|
self.optim_D = build_optimizer(self.D, cfg.OPTIM)
|
||||||
|
self.sched_D = build_lr_scheduler(self.optim_D, cfg.OPTIM)
|
||||||
|
self.register_model("D", self.D, self.optim_D, self.sched_D)
|
||||||
|
|
||||||
|
print("Building G")
|
||||||
|
self.G = build_network(cfg.TRAINER.DDAIG.G_ARCH, verbose=cfg.VERBOSE)
|
||||||
|
self.G.to(self.device)
|
||||||
|
print("# params: {:,}".format(count_num_param(self.G)))
|
||||||
|
self.optim_G = build_optimizer(self.G, cfg.OPTIM)
|
||||||
|
self.sched_G = build_lr_scheduler(self.optim_G, cfg.OPTIM)
|
||||||
|
self.register_model("G", self.G, self.optim_G, self.sched_G)
|
||||||
|
|
||||||
|
def forward_backward(self, batch):
|
||||||
|
input, label, domain = self.parse_batch_train(batch)
|
||||||
|
|
||||||
|
#############
|
||||||
|
# Update G
|
||||||
|
#############
|
||||||
|
input_p = self.G(input, lmda=self.lmda)
|
||||||
|
if self.clamp:
|
||||||
|
input_p = torch.clamp(
|
||||||
|
input_p, min=self.clamp_min, max=self.clamp_max
|
||||||
|
)
|
||||||
|
loss_g = 0
|
||||||
|
# Minimize label loss
|
||||||
|
loss_g += F.cross_entropy(self.F(input_p), label)
|
||||||
|
# Maximize domain loss
|
||||||
|
loss_g -= F.cross_entropy(self.D(input_p), domain)
|
||||||
|
self.model_backward_and_update(loss_g, "G")
|
||||||
|
|
||||||
|
# Perturb data with new G
|
||||||
|
with torch.no_grad():
|
||||||
|
input_p = self.G(input, lmda=self.lmda)
|
||||||
|
if self.clamp:
|
||||||
|
input_p = torch.clamp(
|
||||||
|
input_p, min=self.clamp_min, max=self.clamp_max
|
||||||
|
)
|
||||||
|
|
||||||
|
#############
|
||||||
|
# Update F
|
||||||
|
#############
|
||||||
|
loss_f = F.cross_entropy(self.F(input), label)
|
||||||
|
if (self.epoch + 1) > self.warmup:
|
||||||
|
loss_fp = F.cross_entropy(self.F(input_p), label)
|
||||||
|
loss_f = (1.0 - self.alpha) * loss_f + self.alpha * loss_fp
|
||||||
|
self.model_backward_and_update(loss_f, "F")
|
||||||
|
|
||||||
|
#############
|
||||||
|
# Update D
|
||||||
|
#############
|
||||||
|
loss_d = F.cross_entropy(self.D(input), domain)
|
||||||
|
self.model_backward_and_update(loss_d, "D")
|
||||||
|
|
||||||
|
loss_summary = {
|
||||||
|
"loss_g": loss_g.item(),
|
||||||
|
"loss_f": loss_f.item(),
|
||||||
|
"loss_d": loss_d.item(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if (self.batch_idx + 1) == self.num_batches:
|
||||||
|
self.update_lr()
|
||||||
|
|
||||||
|
return loss_summary
|
||||||
|
|
||||||
|
def model_inference(self, input):
|
||||||
|
return self.F(input)
|
||||||
32
Dassl.ProGrad.pytorch/dassl/engine/dg/vanilla.py
Normal file
32
Dassl.ProGrad.pytorch/dassl/engine/dg/vanilla.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from dassl.engine import TRAINER_REGISTRY, TrainerX
|
||||||
|
from dassl.metrics import compute_accuracy
|
||||||
|
|
||||||
|
|
||||||
|
@TRAINER_REGISTRY.register()
|
||||||
|
class Vanilla(TrainerX):
|
||||||
|
"""Vanilla baseline."""
|
||||||
|
|
||||||
|
def forward_backward(self, batch):
|
||||||
|
input, label = self.parse_batch_train(batch)
|
||||||
|
output = self.model(input)
|
||||||
|
loss = F.cross_entropy(output, label)
|
||||||
|
self.model_backward_and_update(loss)
|
||||||
|
|
||||||
|
loss_summary = {
|
||||||
|
"loss": loss.item(),
|
||||||
|
"acc": compute_accuracy(output, label)[0].item(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if (self.batch_idx + 1) == self.num_batches:
|
||||||
|
self.update_lr()
|
||||||
|
|
||||||
|
return loss_summary
|
||||||
|
|
||||||
|
def parse_batch_train(self, batch):
|
||||||
|
input = batch["img"]
|
||||||
|
label = batch["label"]
|
||||||
|
input = input.to(self.device)
|
||||||
|
label = label.to(self.device)
|
||||||
|
return input, label
|
||||||
5
Dassl.ProGrad.pytorch/dassl/engine/ssl/__init__.py
Normal file
5
Dassl.ProGrad.pytorch/dassl/engine/ssl/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
from .entmin import EntMin
|
||||||
|
from .fixmatch import FixMatch
|
||||||
|
from .mixmatch import MixMatch
|
||||||
|
from .mean_teacher import MeanTeacher
|
||||||
|
from .sup_baseline import SupBaseline
|
||||||
41
Dassl.ProGrad.pytorch/dassl/engine/ssl/entmin.py
Normal file
41
Dassl.ProGrad.pytorch/dassl/engine/ssl/entmin.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
import torch
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from dassl.engine import TRAINER_REGISTRY, TrainerXU
|
||||||
|
from dassl.metrics import compute_accuracy
|
||||||
|
|
||||||
|
|
||||||
|
@TRAINER_REGISTRY.register()
|
||||||
|
class EntMin(TrainerXU):
|
||||||
|
"""Entropy Minimization.
|
||||||
|
|
||||||
|
http://papers.nips.cc/paper/2740-semi-supervised-learning-by-entropy-minimization.pdf.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
super().__init__(cfg)
|
||||||
|
self.lmda = cfg.TRAINER.ENTMIN.LMDA
|
||||||
|
|
||||||
|
def forward_backward(self, batch_x, batch_u):
|
||||||
|
input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
|
||||||
|
|
||||||
|
output_x = self.model(input_x)
|
||||||
|
loss_x = F.cross_entropy(output_x, label_x)
|
||||||
|
|
||||||
|
output_u = F.softmax(self.model(input_u), 1)
|
||||||
|
loss_u = (-output_u * torch.log(output_u + 1e-5)).sum(1).mean()
|
||||||
|
|
||||||
|
loss = loss_x + loss_u * self.lmda
|
||||||
|
|
||||||
|
self.model_backward_and_update(loss)
|
||||||
|
|
||||||
|
loss_summary = {
|
||||||
|
"loss_x": loss_x.item(),
|
||||||
|
"acc_x": compute_accuracy(output_x, label_x)[0].item(),
|
||||||
|
"loss_u": loss_u.item(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if (self.batch_idx + 1) == self.num_batches:
|
||||||
|
self.update_lr()
|
||||||
|
|
||||||
|
return loss_summary
|
||||||
112
Dassl.ProGrad.pytorch/dassl/engine/ssl/fixmatch.py
Normal file
112
Dassl.ProGrad.pytorch/dassl/engine/ssl/fixmatch.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
import torch
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from dassl.data import DataManager
|
||||||
|
from dassl.engine import TRAINER_REGISTRY, TrainerXU
|
||||||
|
from dassl.metrics import compute_accuracy
|
||||||
|
from dassl.data.transforms import build_transform
|
||||||
|
|
||||||
|
|
||||||
|
@TRAINER_REGISTRY.register()
|
||||||
|
class FixMatch(TrainerXU):
|
||||||
|
"""FixMatch: Simplifying Semi-Supervised Learning with
|
||||||
|
Consistency and Confidence.
|
||||||
|
|
||||||
|
https://arxiv.org/abs/2001.07685.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
super().__init__(cfg)
|
||||||
|
self.weight_u = cfg.TRAINER.FIXMATCH.WEIGHT_U
|
||||||
|
self.conf_thre = cfg.TRAINER.FIXMATCH.CONF_THRE
|
||||||
|
|
||||||
|
def check_cfg(self, cfg):
|
||||||
|
assert len(cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS) > 0
|
||||||
|
|
||||||
|
def build_data_loader(self):
|
||||||
|
cfg = self.cfg
|
||||||
|
tfm_train = build_transform(cfg, is_train=True)
|
||||||
|
custom_tfm_train = [tfm_train]
|
||||||
|
choices = cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS
|
||||||
|
tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
|
||||||
|
custom_tfm_train += [tfm_train_strong]
|
||||||
|
self.dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
|
||||||
|
self.train_loader_x = self.dm.train_loader_x
|
||||||
|
self.train_loader_u = self.dm.train_loader_u
|
||||||
|
self.val_loader = self.dm.val_loader
|
||||||
|
self.test_loader = self.dm.test_loader
|
||||||
|
self.num_classes = self.dm.num_classes
|
||||||
|
|
||||||
|
def assess_y_pred_quality(self, y_pred, y_true, mask):
|
||||||
|
n_masked_correct = (y_pred.eq(y_true).float() * mask).sum()
|
||||||
|
acc_thre = n_masked_correct / (mask.sum() + 1e-5)
|
||||||
|
acc_raw = y_pred.eq(y_true).sum() / y_pred.numel() # raw accuracy
|
||||||
|
keep_rate = mask.sum() / mask.numel()
|
||||||
|
output = {
|
||||||
|
"acc_thre": acc_thre,
|
||||||
|
"acc_raw": acc_raw,
|
||||||
|
"keep_rate": keep_rate
|
||||||
|
}
|
||||||
|
return output
|
||||||
|
|
||||||
|
def forward_backward(self, batch_x, batch_u):
|
||||||
|
parsed_data = self.parse_batch_train(batch_x, batch_u)
|
||||||
|
input_x, input_x2, label_x, input_u, input_u2, label_u = parsed_data
|
||||||
|
input_u = torch.cat([input_x, input_u], 0)
|
||||||
|
input_u2 = torch.cat([input_x2, input_u2], 0)
|
||||||
|
n_x = input_x.size(0)
|
||||||
|
|
||||||
|
# Generate pseudo labels
|
||||||
|
with torch.no_grad():
|
||||||
|
output_u = F.softmax(self.model(input_u), 1)
|
||||||
|
max_prob, label_u_pred = output_u.max(1)
|
||||||
|
mask_u = (max_prob >= self.conf_thre).float()
|
||||||
|
|
||||||
|
# Evaluate pseudo labels' accuracy
|
||||||
|
y_u_pred_stats = self.assess_y_pred_quality(
|
||||||
|
label_u_pred[n_x:], label_u, mask_u[n_x:]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Supervised loss
|
||||||
|
output_x = self.model(input_x)
|
||||||
|
loss_x = F.cross_entropy(output_x, label_x)
|
||||||
|
|
||||||
|
# Unsupervised loss
|
||||||
|
output_u = self.model(input_u2)
|
||||||
|
loss_u = F.cross_entropy(output_u, label_u_pred, reduction="none")
|
||||||
|
loss_u = (loss_u * mask_u).mean()
|
||||||
|
|
||||||
|
loss = loss_x + loss_u * self.weight_u
|
||||||
|
self.model_backward_and_update(loss)
|
||||||
|
|
||||||
|
loss_summary = {
|
||||||
|
"loss_x": loss_x.item(),
|
||||||
|
"acc_x": compute_accuracy(output_x, label_x)[0].item(),
|
||||||
|
"loss_u": loss_u.item(),
|
||||||
|
"y_u_pred_acc_raw": y_u_pred_stats["acc_raw"],
|
||||||
|
"y_u_pred_acc_thre": y_u_pred_stats["acc_thre"],
|
||||||
|
"y_u_pred_keep": y_u_pred_stats["keep_rate"],
|
||||||
|
}
|
||||||
|
|
||||||
|
if (self.batch_idx + 1) == self.num_batches:
|
||||||
|
self.update_lr()
|
||||||
|
|
||||||
|
return loss_summary
|
||||||
|
|
||||||
|
def parse_batch_train(self, batch_x, batch_u):
|
||||||
|
input_x = batch_x["img"]
|
||||||
|
input_x2 = batch_x["img2"]
|
||||||
|
label_x = batch_x["label"]
|
||||||
|
input_u = batch_u["img"]
|
||||||
|
input_u2 = batch_u["img2"]
|
||||||
|
# label_u is used only for evaluating pseudo labels' accuracy
|
||||||
|
label_u = batch_u["label"]
|
||||||
|
|
||||||
|
input_x = input_x.to(self.device)
|
||||||
|
input_x2 = input_x2.to(self.device)
|
||||||
|
label_x = label_x.to(self.device)
|
||||||
|
input_u = input_u.to(self.device)
|
||||||
|
input_u2 = input_u2.to(self.device)
|
||||||
|
label_u = label_u.to(self.device)
|
||||||
|
|
||||||
|
return input_x, input_x2, label_x, input_u, input_u2, label_u
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user