Upload to Main
This commit is contained in:
Generated
+8
@@ -0,0 +1,8 @@
|
||||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# Editor-based HTTP Client requests
|
||||
/httpRequests/
|
||||
# Datasource local storage ignored files
|
||||
/dataSources/
|
||||
/dataSources.local.xml
|
||||
+58
@@ -0,0 +1,58 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<profile version="1.0">
|
||||
<option name="myName" value="Project Default" />
|
||||
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
|
||||
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
||||
<option name="ignoredPackages">
|
||||
<value>
|
||||
<list size="44">
|
||||
<item index="0" class="java.lang.String" itemvalue="interrogate" />
|
||||
<item index="1" class="java.lang.String" itemvalue="pytest" />
|
||||
<item index="2" class="java.lang.String" itemvalue="cityscapesscripts" />
|
||||
<item index="3" class="java.lang.String" itemvalue="isort" />
|
||||
<item index="4" class="java.lang.String" itemvalue="xdoctest" />
|
||||
<item index="5" class="java.lang.String" itemvalue="codecov" />
|
||||
<item index="6" class="java.lang.String" itemvalue="flake8" />
|
||||
<item index="7" class="java.lang.String" itemvalue="pandas" />
|
||||
<item index="8" class="java.lang.String" itemvalue="scikit-image" />
|
||||
<item index="9" class="java.lang.String" itemvalue="scipy" />
|
||||
<item index="10" class="java.lang.String" itemvalue="scikit-learn" />
|
||||
<item index="11" class="java.lang.String" itemvalue="torch" />
|
||||
<item index="12" class="java.lang.String" itemvalue="numpy" />
|
||||
<item index="13" class="java.lang.String" itemvalue="torchvision" />
|
||||
<item index="14" class="java.lang.String" itemvalue="sklearn" />
|
||||
<item index="15" class="java.lang.String" itemvalue="accelerate" />
|
||||
<item index="16" class="java.lang.String" itemvalue="fire" />
|
||||
<item index="17" class="java.lang.String" itemvalue="opencv-python-headless" />
|
||||
<item index="18" class="java.lang.String" itemvalue="tqdm" />
|
||||
<item index="19" class="java.lang.String" itemvalue="mat73" />
|
||||
<item index="20" class="java.lang.String" itemvalue="panda" />
|
||||
<item index="21" class="java.lang.String" itemvalue="imageio" />
|
||||
<item index="22" class="java.lang.String" itemvalue="opencv-python" />
|
||||
<item index="23" class="java.lang.String" itemvalue="h5py" />
|
||||
<item index="24" class="java.lang.String" itemvalue="matplotlib" />
|
||||
<item index="25" class="java.lang.String" itemvalue="pydensecrf" />
|
||||
<item index="26" class="java.lang.String" itemvalue="pyparsing" />
|
||||
<item index="27" class="java.lang.String" itemvalue="Markdown" />
|
||||
<item index="28" class="java.lang.String" itemvalue="Pillow" />
|
||||
<item index="29" class="java.lang.String" itemvalue="termcolor" />
|
||||
<item index="30" class="java.lang.String" itemvalue="spacy" />
|
||||
<item index="31" class="java.lang.String" itemvalue="transformers" />
|
||||
<item index="32" class="java.lang.String" itemvalue="datadings" />
|
||||
<item index="33" class="java.lang.String" itemvalue="nltk" />
|
||||
<item index="34" class="java.lang.String" itemvalue="wandb" />
|
||||
<item index="35" class="java.lang.String" itemvalue="webdataset" />
|
||||
<item index="36" class="java.lang.String" itemvalue="ipython" />
|
||||
<item index="37" class="java.lang.String" itemvalue="einops" />
|
||||
<item index="38" class="java.lang.String" itemvalue="ftfy" />
|
||||
<item index="39" class="java.lang.String" itemvalue="seaborn" />
|
||||
<item index="40" class="java.lang.String" itemvalue="tensorboard" />
|
||||
<item index="41" class="java.lang.String" itemvalue="torchattacks" />
|
||||
<item index="42" class="java.lang.String" itemvalue="ipdb" />
|
||||
<item index="43" class="java.lang.String" itemvalue="openml" />
|
||||
</list>
|
||||
</value>
|
||||
</option>
|
||||
</inspection_tool>
|
||||
</profile>
|
||||
</component>
|
||||
+6
@@ -0,0 +1,6 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
||||
Generated
+4
@@ -0,0 +1,4 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7 (py37)" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
Generated
+8
@@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/multi-research.iml" filepath="$PROJECT_DIR$/.idea/multi-research.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
||||
Generated
+14
@@ -0,0 +1,14 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$">
|
||||
<excludeFolder url="file://$MODULE_DIR$/venv" />
|
||||
</content>
|
||||
<orderEntry type="jdk" jdkName="Python 3.7 (py37)" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="PyDocumentationSettings">
|
||||
<option name="format" value="PLAIN" />
|
||||
<option name="myDocStringFormat" value="Plain" />
|
||||
</component>
|
||||
</module>
|
||||
Generated
+25
@@ -0,0 +1,25 @@
|
||||
<component name="ProjectRunConfigurationManager">
|
||||
<configuration default="false" name="traial" type="PythonConfigurationType" factoryName="Python" singleton="false">
|
||||
<module name="multi-research" />
|
||||
<option name="INTERPRETER_OPTIONS" value="" />
|
||||
<option name="PARENT_ENVS" value="true" />
|
||||
<envs>
|
||||
<env name="PYTHONUNBUFFERED" value="1" />
|
||||
<env name="CUDA_VISIBLE_DEVICES" value="1" />
|
||||
</envs>
|
||||
<option name="SDK_HOME" value="$USER_HOME$/anaconda3/envs/py37/bin/python" />
|
||||
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
|
||||
<option name="IS_MODULE_SDK" value="false" />
|
||||
<option name="ADD_CONTENT_ROOTS" value="true" />
|
||||
<option name="ADD_SOURCE_ROOTS" value="true" />
|
||||
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
|
||||
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/train.py" />
|
||||
<option name="PARAMETERS" value="--root $USER_HOME$/Data_file/few_shot_data --seed 1 --trainer MaPLe --dataset-config-file configs/datasets/oxford_pets.yaml --config-file configs/trainers/MaPLe/vit_b16_t.yaml --output-dir output/DAPT DATASET.NUM_SHOTS 1 DATASET.SELECTION_RATIO 1.0" />
|
||||
<option name="SHOW_COMMAND_LINE" value="false" />
|
||||
<option name="EMULATE_TERMINAL" value="false" />
|
||||
<option name="MODULE_MODE" value="false" />
|
||||
<option name="REDIRECT_INPUT" value="false" />
|
||||
<option name="INPUT_FILE" value="" />
|
||||
<method v="2" />
|
||||
</configuration>
|
||||
</component>
|
||||
@@ -0,0 +1,22 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2022 Muhammad Uzair Khattak
|
||||
Copyright (c) 2021 Kaiyang Zhou
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -0,0 +1,89 @@
|
||||
# DAPT [T-PAMI 2025]
|
||||
|
||||
|
||||
|
||||
> [**Decouple before Align: Visual Disentanglement Enhances Prompt Tuning**](https://arxiv.org/abs/2210.03117)<br>
|
||||
> [Fei Zhang](https://scholar.google.com/citations?hl=zh-CN&user=Omrg6UkAAAAJ), [Tianfei Zhou](https://www.tfzhou.com/), [Jiangchao Yao](https://sunarker.github.io/), [Ya Zhang](http://scholar.google.com/citations?user=pbjw9sMAAAAJ&hl=zh-CN), [Ivor W. Tsang](https://scholar.google.com.sg/citations?user=rJMOlVsAAAAJ&hl=en), [Yanfeng Wang](https://ieeexplore.ieee.org/author/37085615187)
|
||||
|
||||
|
||||
|
||||
Official implementation of the paper "[Decouple before Align: Visual Disentanglement
|
||||
Enhances Prompt Tuning](https://arxiv.org/pdf/2508.00395)".
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
## Highlights
|
||||
|
||||

|
||||
> **<p align="justify"> Abstract:** *Prompt tuning (PT), as an emerging resource-efficient fine-tuning paradigm, has showcased remarkable effectiveness in
|
||||
improving the task-specific transferability of vision-language models. This paper delves into a previously overlooked information
|
||||
asymmetry issue in PT, where the visual modality mostly conveys more context than the object-oriented textual modality. Correspondingly,
|
||||
coarsely aligning these two modalities could result in the biased attention, driving the model to merely focus on the context area. To
|
||||
address this, we propose DAPT, an effective PT framework based on an intuitive decouple-before-align concept. First, we propose to
|
||||
explicitly decouple the visual modality into the foreground and background representation via exploiting coarse-and-fine visual
|
||||
segmenting cues, and then both of these decoupled patterns are aligned with the original foreground texts and the hand-crafted
|
||||
background classes, thereby symmetrically strengthening the modal alignment. To further enhance the visual concentration, we propose
|
||||
a visual pull-push regularization tailored for the foreground-background patterns, directing the original visual representation towards
|
||||
unbiased attention on the region-of-interest object. We demonstrate the power of architecture-free DAPT through few-shot learning,
|
||||
base-to-novel generalization, and data-efficient learning, all of which yield superior performance across prevailing benchmarks.* </p>
|
||||
|
||||
## Main Contributions
|
||||
|
||||
1) **Multi-modal prompt learning:** Adapt CLIP using a novel prompting technique which prompts both the vision and language branch of CLIP.
|
||||
2) **Vision Decoupling:** We propose the visual disentanglement that exploits the
|
||||
visual cues of different levels to highlight the text-oriented
|
||||
object in the visual modality.
|
||||
3) **Fine-grained v.s. Coarse Visual Decoupling:** Different Masks are explored to serve effective decoupling visual signal.
|
||||
|
||||
## Results
|
||||
>SOTA performance is made, and such a method could be seamlessly integrated on other methods.
|
||||
|
||||

|
||||
|
||||
|
||||
## Installation
|
||||
For installation and other package requirements, please follow the instructions detailed in [INSTALL.md](docs/INSTALL.md).
|
||||
|
||||
## Data preparation
|
||||
Please follow the instructions at [DATASETS.md](docs/DATASETS.md) to prepare all datasets.
|
||||
|
||||
|
||||
**DAPT-S**: Then, you should download the **Segementation MASK** from [Here](https://drive.google.com/file/d/12BDM8X3ZynLVNqmkAMEvxVMk7vU9ILzv/view?usp=sharing), and put them correspondingly to each root data directory.
|
||||
|
||||
These masks are generated with [SEEM](https://github.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once), and you can also generate masks by other tools, Just suit yourself.
|
||||
|
||||
(PS: It contains all segmentation masks for ImageNet, so it is convenient to use them for other studies)
|
||||
|
||||
|
||||
|
||||
## Training and Evaluation
|
||||
Please refer to the [RUN.md](docs/RUN.md) for detailed instructions on training, evaluating and reproducing the results using our pre-trained models. (All implementations could also refer to [MaPLe](https://github.com/muzairkhattak/multimodal-prompt-learning/tree/main))
|
||||
|
||||
|
||||
<hr />
|
||||
|
||||
## Citation
|
||||
If you use our work, please consider citing:
|
||||
```bibtex
|
||||
@ARTICLE{11106768,
|
||||
author={Zhang, Fei and Zhou, Tianfei and Yao, Jiangchao and Zhang, Ya and Tsang, Ivor W. and Wang, Yanfeng},
|
||||
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
|
||||
title={Decouple Before Align: Visual Disentanglement Enhances Prompt Tuning},
|
||||
year={2025},
|
||||
volume={47},
|
||||
number={11},
|
||||
pages={10619-10632},
|
||||
keywords={Visualization;Tuning;Semantics;Artificial intelligence;Object oriented modeling;Accuracy;Image recognition;Context modeling;Training;Technological innovation;Prompt tuning;visual disentanglement;multi-modal learning},
|
||||
doi={10.1109/TPAMI.2025.3594894}}
|
||||
```
|
||||
|
||||
## Contact
|
||||
If you have any questions, please create an issue on this repository or contact at ferenas@sjtu.edu.cn.
|
||||
|
||||
|
||||
## Acknowledgements
|
||||
|
||||
Our code is based on [MaPLe](https://github.com/muzairkhattak/multimodal-prompt-learning/tree/main) repository. We thank the authors for releasing their code. If you use our model and code, please consider citing these works as well.
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
from .clip import *
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+221
@@ -0,0 +1,221 @@
|
||||
import hashlib
|
||||
import os
|
||||
import urllib
|
||||
import warnings
|
||||
from typing import Union, List
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
||||
from tqdm import tqdm
|
||||
|
||||
from .model import build_model
|
||||
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
|
||||
|
||||
try:
|
||||
from torchvision.transforms import InterpolationMode
|
||||
BICUBIC = InterpolationMode.BICUBIC
|
||||
except ImportError:
|
||||
BICUBIC = Image.BICUBIC
|
||||
|
||||
|
||||
if torch.__version__.split(".") < ["1", "7", "1"]:
|
||||
warnings.warn("PyTorch version 1.7.1 or higher is recommended")
|
||||
|
||||
|
||||
__all__ = ["available_models", "load", "tokenize"]
|
||||
_tokenizer = _Tokenizer()
|
||||
|
||||
_MODELS = {
|
||||
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
|
||||
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
|
||||
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
|
||||
"RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
|
||||
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
|
||||
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
|
||||
}
|
||||
|
||||
|
||||
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
|
||||
os.makedirs(root, exist_ok=True)
|
||||
filename = os.path.basename(url)
|
||||
|
||||
expected_sha256 = url.split("/")[-2]
|
||||
download_target = os.path.join(root, filename)
|
||||
|
||||
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
||||
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
||||
|
||||
if os.path.isfile(download_target):
|
||||
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
|
||||
return download_target
|
||||
else:
|
||||
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
||||
|
||||
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
||||
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
|
||||
while True:
|
||||
buffer = source.read(8192)
|
||||
if not buffer:
|
||||
break
|
||||
|
||||
output.write(buffer)
|
||||
loop.update(len(buffer))
|
||||
|
||||
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
|
||||
raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
|
||||
|
||||
return download_target
|
||||
|
||||
|
||||
def _transform(n_px):
|
||||
return Compose([
|
||||
Resize(n_px, interpolation=BICUBIC),
|
||||
CenterCrop(n_px),
|
||||
lambda image: image.convert("RGB"),
|
||||
ToTensor(),
|
||||
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
||||
])
|
||||
|
||||
|
||||
def available_models() -> List[str]:
|
||||
"""Returns the names of available CLIP models"""
|
||||
return list(_MODELS.keys())
|
||||
|
||||
|
||||
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False):
|
||||
"""Load a CLIP model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
|
||||
|
||||
device : Union[str, torch.device]
|
||||
The device to put the loaded model
|
||||
|
||||
jit : bool
|
||||
Whether to load the optimized JIT model or more hackable non-JIT model (default).
|
||||
|
||||
Returns
|
||||
-------
|
||||
model : torch.nn.Module
|
||||
The CLIP model
|
||||
|
||||
preprocess : Callable[[PIL.Image], torch.Tensor]
|
||||
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
||||
"""
|
||||
if name in _MODELS:
|
||||
model_path = _download(_MODELS[name])
|
||||
elif os.path.isfile(name):
|
||||
model_path = name
|
||||
else:
|
||||
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
||||
|
||||
try:
|
||||
# loading JIT archive
|
||||
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
|
||||
state_dict = None
|
||||
except RuntimeError:
|
||||
# loading saved state dict
|
||||
if jit:
|
||||
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
|
||||
jit = False
|
||||
state_dict = torch.load(model_path, map_location="cpu")
|
||||
|
||||
if not jit:
|
||||
model = build_model(state_dict or model.state_dict()).to(device)
|
||||
if str(device) == "cpu":
|
||||
model.float()
|
||||
return model, _transform(model.visual.input_resolution)
|
||||
|
||||
# patch the device names
|
||||
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
|
||||
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
|
||||
|
||||
def patch_device(module):
|
||||
try:
|
||||
graphs = [module.graph] if hasattr(module, "graph") else []
|
||||
except RuntimeError:
|
||||
graphs = []
|
||||
|
||||
if hasattr(module, "forward1"):
|
||||
graphs.append(module.forward1.graph)
|
||||
|
||||
for graph in graphs:
|
||||
for node in graph.findAllNodes("prim::Constant"):
|
||||
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
|
||||
node.copyAttributes(device_node)
|
||||
|
||||
model.apply(patch_device)
|
||||
patch_device(model.encode_image)
|
||||
patch_device(model.encode_text)
|
||||
|
||||
# patch dtype to float32 on CPU
|
||||
if str(device) == "cpu":
|
||||
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
|
||||
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
||||
float_node = float_input.node()
|
||||
|
||||
def patch_float(module):
|
||||
try:
|
||||
graphs = [module.graph] if hasattr(module, "graph") else []
|
||||
except RuntimeError:
|
||||
graphs = []
|
||||
|
||||
if hasattr(module, "forward1"):
|
||||
graphs.append(module.forward1.graph)
|
||||
|
||||
for graph in graphs:
|
||||
for node in graph.findAllNodes("aten::to"):
|
||||
inputs = list(node.inputs())
|
||||
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
|
||||
if inputs[i].node()["value"] == 5:
|
||||
inputs[i].node().copyAttributes(float_node)
|
||||
|
||||
model.apply(patch_float)
|
||||
patch_float(model.encode_image)
|
||||
patch_float(model.encode_text)
|
||||
|
||||
model.float()
|
||||
|
||||
return model, _transform(model.input_resolution.item())
|
||||
|
||||
|
||||
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
|
||||
"""
|
||||
Returns the tokenized representation of given input string(s)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
texts : Union[str, List[str]]
|
||||
An input string or a list of input strings to tokenize
|
||||
|
||||
context_length : int
|
||||
The context length to use; all CLIP models use 77 as the context length
|
||||
|
||||
truncate: bool
|
||||
Whether to truncate the text in case its encoding is longer than the context length
|
||||
|
||||
Returns
|
||||
-------
|
||||
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
|
||||
"""
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
|
||||
sot_token = _tokenizer.encoder["<|startoftext|>"]
|
||||
eot_token = _tokenizer.encoder["<|endoftext|>"]
|
||||
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
|
||||
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
||||
|
||||
for i, tokens in enumerate(all_tokens):
|
||||
if len(tokens) > context_length:
|
||||
if truncate:
|
||||
tokens = tokens[:context_length]
|
||||
tokens[-1] = eot_token
|
||||
else:
|
||||
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
|
||||
result[i, :len(tokens)] = torch.tensor(tokens)
|
||||
|
||||
return result
|
||||
+741
@@ -0,0 +1,741 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1):
|
||||
super().__init__()
|
||||
|
||||
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
|
||||
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
|
||||
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
||||
|
||||
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
||||
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = None
|
||||
self.stride = stride
|
||||
|
||||
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
||||
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
||||
self.downsample = nn.Sequential(OrderedDict([
|
||||
("-1", nn.AvgPool2d(stride)),
|
||||
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
||||
("1", nn.BatchNorm2d(planes * self.expansion))
|
||||
]))
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
identity = x
|
||||
|
||||
out = self.relu(self.bn1(self.conv1(x)))
|
||||
out = self.relu(self.bn2(self.conv2(out)))
|
||||
out = self.avgpool(out)
|
||||
out = self.bn3(self.conv3(out))
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class AttentionPool2d(nn.Module):
|
||||
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
||||
super().__init__()
|
||||
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
|
||||
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
||||
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
||||
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
||||
self.num_heads = num_heads
|
||||
|
||||
def forward(self, x):
|
||||
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
|
||||
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
||||
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
||||
x, _ = F.multi_head_attention_forward(
|
||||
query=x, key=x, value=x,
|
||||
embed_dim_to_check=x.shape[-1],
|
||||
num_heads=self.num_heads,
|
||||
q_proj_weight=self.q_proj.weight,
|
||||
k_proj_weight=self.k_proj.weight,
|
||||
v_proj_weight=self.v_proj.weight,
|
||||
in_proj_weight=None,
|
||||
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
||||
bias_k=None,
|
||||
bias_v=None,
|
||||
add_zero_attn=False,
|
||||
dropout_p=0,
|
||||
out_proj_weight=self.c_proj.weight,
|
||||
out_proj_bias=self.c_proj.bias,
|
||||
use_separate_proj_weight=True,
|
||||
training=self.training,
|
||||
need_weights=False
|
||||
)
|
||||
|
||||
return x[0]
|
||||
|
||||
|
||||
class ModifiedResNet(nn.Module):
|
||||
"""
|
||||
A ResNet class that is similar to torchvision's but contains the following changes:
|
||||
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
||||
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
||||
- The final pooling layer is a QKV attention instead of an average pool
|
||||
"""
|
||||
|
||||
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
|
||||
super().__init__()
|
||||
self.output_dim = output_dim
|
||||
self.input_resolution = input_resolution
|
||||
|
||||
# the 3-layer stem
|
||||
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(width // 2)
|
||||
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(width // 2)
|
||||
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(width)
|
||||
self.avgpool = nn.AvgPool2d(2)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
# residual layers
|
||||
self._inplanes = width # this is a *mutable* variable used during construction
|
||||
self.layer1 = self._make_layer(width, layers[0])
|
||||
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
||||
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
||||
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
||||
|
||||
embed_dim = width * 32 # the ResNet feature dimension
|
||||
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
|
||||
|
||||
def _make_layer(self, planes, blocks, stride=1):
|
||||
layers = [Bottleneck(self._inplanes, planes, stride)]
|
||||
|
||||
self._inplanes = planes * Bottleneck.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(Bottleneck(self._inplanes, planes))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
def stem(x):
|
||||
for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
|
||||
x = self.relu(bn(conv(x)))
|
||||
x = self.avgpool(x)
|
||||
return x
|
||||
|
||||
x = x.type(self.conv1.weight.dtype)
|
||||
x = stem(x)
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
x = self.attnpool(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
"""Subclass torch's LayerNorm to handle fp16."""
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
orig_type = x.dtype
|
||||
ret = super().forward(x.type(torch.float32))
|
||||
return ret.type(orig_type)
|
||||
|
||||
|
||||
class QuickGELU(nn.Module):
|
||||
def forward(self, x: torch.Tensor):
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
|
||||
super().__init__()
|
||||
|
||||
self.attn = nn.MultiheadAttention(d_model, n_head)
|
||||
self.ln_1 = LayerNorm(d_model)
|
||||
self.mlp = nn.Sequential(OrderedDict([
|
||||
("c_fc", nn.Linear(d_model, d_model * 4)),
|
||||
("gelu", QuickGELU()),
|
||||
("c_proj", nn.Linear(d_model * 4, d_model))
|
||||
]))
|
||||
self.ln_2 = LayerNorm(d_model)
|
||||
self.attn_mask = attn_mask
|
||||
|
||||
def attention(self, x: torch.Tensor):
|
||||
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
||||
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = x + self.attention(self.ln_1(x))
|
||||
x = x + self.mlp(self.ln_2(x))
|
||||
return x
|
||||
|
||||
|
||||
class ResidualAttentionBlock_IVLP(nn.Module):
|
||||
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, add_prompt=False,
|
||||
text_layer=False, i=0, design_details=None):
|
||||
super().__init__()
|
||||
|
||||
self.attn = nn.MultiheadAttention(d_model, n_head)
|
||||
self.ln_1 = LayerNorm(d_model)
|
||||
self.mlp = nn.Sequential(OrderedDict([
|
||||
("c_fc", nn.Linear(d_model, d_model * 4)),
|
||||
("gelu", QuickGELU()),
|
||||
("c_proj", nn.Linear(d_model * 4, d_model))
|
||||
]))
|
||||
self.ln_2 = LayerNorm(d_model)
|
||||
# Only add learnable tokens if flag is set True
|
||||
# For the first iteration i, we should not add the learnable parameters
|
||||
# as it is already been taken care of in the very start, for both text
|
||||
# and the visual branch
|
||||
self.text_layer = text_layer
|
||||
self.attn_mask = attn_mask
|
||||
if i != 0:
|
||||
self.add_prompt = add_prompt
|
||||
if self.add_prompt:
|
||||
if self.text_layer:
|
||||
self.n_ctx_text = design_details["language_ctx"] # hyperparameter
|
||||
ctx_vectors = torch.empty(self.n_ctx_text, d_model)
|
||||
else:
|
||||
self.n_ctx_visual = design_details["vision_ctx"] # hyperparameter
|
||||
ctx_vectors = torch.empty(self.n_ctx_visual, d_model)
|
||||
# Code snippet for per layer visual prompts
|
||||
nn.init.normal_(ctx_vectors, std=0.02)
|
||||
self.VPT_shallow = nn.Parameter(ctx_vectors)
|
||||
else:
|
||||
self.add_prompt = False
|
||||
|
||||
def attention(self, x: torch.Tensor):
|
||||
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
||||
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
# Will need to append the learnable tokens for this layer here
|
||||
# Check if flag was set for this layer or not
|
||||
if self.add_prompt:
|
||||
# Also see if this is textual transformer layer or not
|
||||
if not self.text_layer:
|
||||
# Remove the outputs produced by learnable tokens of previous layer
|
||||
prefix = x[0:x.shape[0] - self.n_ctx_visual, :, :]
|
||||
# Create/configure learnable tokens of this layer
|
||||
visual_context = self.VPT_shallow.expand(x.shape[1], -1, -1).permute(1, 0, 2).half()
|
||||
# Add the learnable tokens of this layer with the input, by replacing the previous
|
||||
# layer learnable tokens
|
||||
x = torch.cat([prefix, visual_context], dim=0)
|
||||
else:
|
||||
# Appending the learnable tokens in different way
|
||||
# x -> [77, NCLS, DIM]
|
||||
# First remove the learnable tokens from previous layer
|
||||
prefix = x[:1, :, :]
|
||||
suffix = x[1 + self.n_ctx_text:, :, :]
|
||||
# Create/configure learnable tokens of this layer
|
||||
textual_context = self.VPT_shallow.expand(x.shape[1], -1, -1).permute(1, 0, 2).half()
|
||||
# Add the learnable tokens of this layer with the input, replaced by previous
|
||||
# layer learnable tokens
|
||||
x = torch.cat([prefix, textual_context, suffix], dim=0)
|
||||
|
||||
x = x + self.attention(self.ln_1(x))
|
||||
x = x + self.mlp(self.ln_2(x))
|
||||
return x
|
||||
|
||||
|
||||
class ResidualAttentionBlock_MaPLe(nn.Module):
|
||||
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, design_details=None,
|
||||
text_layer=False, i=0):
|
||||
super().__init__()
|
||||
|
||||
self.attn = nn.MultiheadAttention(d_model, n_head)
|
||||
self.ln_1 = LayerNorm(d_model)
|
||||
self.mlp = nn.Sequential(OrderedDict([
|
||||
("c_fc", nn.Linear(d_model, d_model * 4)),
|
||||
("gelu", QuickGELU()),
|
||||
("c_proj", nn.Linear(d_model * 4, d_model))
|
||||
]))
|
||||
self.ln_2 = LayerNorm(d_model)
|
||||
# For the first iteration i, we do not need to add the learnable parameters here
|
||||
# as it will be added in the beginning, for both text and the vision branch
|
||||
self.text_layer = text_layer
|
||||
self.attn_mask = attn_mask
|
||||
# This must be consistent with the config file prompt
|
||||
self.compound_prompt_nctx = design_details['maple_length']
|
||||
if i == 0:
|
||||
self.first_layer = True
|
||||
else:
|
||||
self.first_layer = False
|
||||
|
||||
def attention(self, x: torch.Tensor):
|
||||
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
||||
return self.attn(x, x, x, need_weights=True, attn_mask=self.attn_mask)
|
||||
|
||||
def forward(self, inputs):
|
||||
# For the first layer, we do not need to add any duplicate, as it is already added
|
||||
# as the shallow version
|
||||
x = inputs[0]
|
||||
compound_prompts_deeper = inputs[1]
|
||||
counter = inputs[2]
|
||||
if not self.first_layer:
|
||||
if len(compound_prompts_deeper) > 0:
|
||||
# This means that deeper compound prompts are turned on
|
||||
# Here it behaves differently for text and visual side
|
||||
# Forward function is same for both
|
||||
|
||||
if not self.text_layer:
|
||||
# First check if the ith layer needs compound prompts or not
|
||||
if not (counter > len(compound_prompts_deeper) - 1):
|
||||
# Remove the outputs produced by learnable tokens of previous layer
|
||||
prefix = x[0:x.shape[0] - self.compound_prompt_nctx, :, :]
|
||||
# Create/configure learnable tokens of this layer
|
||||
visual_context = compound_prompts_deeper[counter] # extract the correct index
|
||||
visual_context = visual_context.expand(x.shape[1], -1, -1).permute(1, 0, 2).half()
|
||||
# Add the learnable tokens of this layer with the input, by replacing previous
|
||||
# layer learnable tokens
|
||||
x = torch.cat([prefix, visual_context], dim=0)
|
||||
|
||||
# Once done, update the counter, so that the next time, it does not use same learnable tokens
|
||||
counter += 1
|
||||
else:
|
||||
# First check if the ith layer needs compound prompts or not
|
||||
if not (counter > len(compound_prompts_deeper) - 1):
|
||||
# Appending the learnable tokens in different way
|
||||
# x -> [77, NCLS, DIM]
|
||||
# First remove the learnable tokens from previous layer
|
||||
prefix = x[:1, :, :]
|
||||
suffix = x[1 + self.compound_prompt_nctx:, :, :]
|
||||
# Create/configure learnable tokens of this layer
|
||||
textual_context = compound_prompts_deeper[counter]
|
||||
textual_context = textual_context.expand(x.shape[1], -1, -1).permute(1, 0, 2).half()
|
||||
# Add the learnable tokens of this layer with the input, replaced by previous
|
||||
# layer learnable tokens
|
||||
x = torch.cat([prefix, textual_context, suffix], dim=0)
|
||||
# Once done, update the counter, so that the next time, it does not use same learnable tokens
|
||||
counter += 1
|
||||
inp,attn_mask = self.attention(self.ln_1(x))
|
||||
x = x + inp
|
||||
x = x + self.mlp(self.ln_2(x))
|
||||
if self.text_layer:
|
||||
return [x, compound_prompts_deeper, counter]
|
||||
else:
|
||||
return [x, compound_prompts_deeper, counter, attn_mask] # return again as a list, so that nn.seq can work
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, prompts_needed=0,
|
||||
text_layer=False, design_details=None):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.layers = layers
|
||||
# Implements respective encoder blocks for a given design choice
|
||||
current_trainer = design_details['trainer']
|
||||
if current_trainer == 'IVLP' or current_trainer == 'VPT':
|
||||
self.resblocks = nn.Sequential(*[ResidualAttentionBlock_IVLP(width, heads, attn_mask, True,
|
||||
text_layer, i,
|
||||
design_details) if prompts_needed > i
|
||||
else ResidualAttentionBlock_IVLP(width, heads, attn_mask, False,
|
||||
text_layer, i, design_details)
|
||||
for i in range(layers)])
|
||||
elif current_trainer == 'MaPLe':
|
||||
self.resblocks = nn.Sequential(
|
||||
*[ResidualAttentionBlock_MaPLe(width, heads, attn_mask, design_details, text_layer, i)
|
||||
for i in range(layers)])
|
||||
else:
|
||||
# Corresponds to default CoOp or CoCoOp
|
||||
assert current_trainer == 'CoOp' or current_trainer == 'CoCoOp'
|
||||
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return self.resblocks(x)
|
||||
|
||||
|
||||
class Transformer_normal(nn.Module):
|
||||
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, prompts_needed=0,
|
||||
text_layer=False, design_details=None):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.layers = layers
|
||||
# Implements respective encoder blocks for a given design choice
|
||||
# current_trainer = design_details['trainer']
|
||||
# if current_trainer == 'IVLP' or current_trainer == 'VPT':
|
||||
# self.resblocks = nn.Sequential(*[ResidualAttentionBlock_IVLP(width, heads, attn_mask, True,
|
||||
# text_layer, i,
|
||||
# design_details) if prompts_needed > i
|
||||
# else ResidualAttentionBlock_IVLP(width, heads, attn_mask, False,
|
||||
# text_layer, i, design_details)
|
||||
# for i in range(layers)])
|
||||
# elif current_trainer == 'MaPLe':
|
||||
# self.resblocks = nn.Sequential(
|
||||
# *[ResidualAttentionBlock_MaPLe(width, heads, attn_mask, design_details, text_layer, i)
|
||||
# for i in range(layers)])
|
||||
# else:
|
||||
# # Corresponds to default CoOp or CoCoOp
|
||||
# assert current_trainer == 'CoOp' or current_trainer == 'CoCoOp'
|
||||
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return self.resblocks(x)
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
|
||||
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int,
|
||||
output_dim: int, design_details):
|
||||
super().__init__()
|
||||
self.input_resolution = input_resolution
|
||||
self.output_dim = output_dim
|
||||
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
||||
if design_details["vision_depth"] == 0:
|
||||
self.VPT_shallow = False
|
||||
else:
|
||||
self.VPT_shallow = True
|
||||
if self.VPT_shallow:
|
||||
# Add visual prompt tokens here
|
||||
n_ctx = design_details["vision_ctx"] # hyperparameter
|
||||
ctx_vectors = torch.empty(n_ctx, width)
|
||||
nn.init.normal_(ctx_vectors, std=0.02)
|
||||
self.VPT = nn.Parameter(ctx_vectors)
|
||||
# self.VPT.half()
|
||||
scale = width ** -0.5
|
||||
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
||||
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
|
||||
self.ln_pre = LayerNorm(width)
|
||||
# hyper-parameter if need to add prompt embeddings inside to the input
|
||||
# of transformer block or not:
|
||||
self.prompt_till_layer_visual = design_details["vision_depth"]
|
||||
self.transformer = Transformer_normal(width, layers, heads, prompts_needed=self.prompt_till_layer_visual,
|
||||
design_details=design_details)
|
||||
|
||||
self.ln_post = LayerNorm(width)
|
||||
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = self.conv1(x) # shape = [*, width, grid, grid]
|
||||
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
||||
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
||||
x = torch.cat(
|
||||
[self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
|
||||
x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
||||
x = x + self.positional_embedding.to(x.dtype)
|
||||
|
||||
# After positional embeddings, we will attach prompts with the model, remember only those
|
||||
# are trainable parameters here in whole image encoder.
|
||||
if self.VPT_shallow:
|
||||
visual_ctx = self.VPT.expand(x.shape[0], -1, -1).half()
|
||||
x = torch.cat([x, visual_ctx], dim=1)
|
||||
else:
|
||||
assert self.prompt_till_layer_visual == 0
|
||||
|
||||
# Normal code as before
|
||||
x = self.ln_pre(x)
|
||||
|
||||
x = x.permute(1, 0, 2) # NLD -> LND
|
||||
x = self.transformer(x)
|
||||
x = x.permute(1, 0, 2) # LND -> NLD
|
||||
|
||||
x = self.ln_post(x[:, 0, :])
|
||||
|
||||
if self.proj is not None:
|
||||
x = x @ self.proj
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class VisionTransformer_MaPLe(nn.Module):
|
||||
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int,
|
||||
design_details):
|
||||
super().__init__()
|
||||
self.input_resolution = input_resolution
|
||||
self.output_dim = output_dim
|
||||
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
||||
self.VPT_shallow = True
|
||||
scale = width ** -0.5
|
||||
self.patch_num = self.input_resolution // patch_size
|
||||
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
||||
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
|
||||
self.ln_pre = LayerNorm(width)
|
||||
# hyper-parameter if need to add prompt embeddings inside to the input
|
||||
# of transformer block or not:
|
||||
self.prompt_till_layer_visual = 0
|
||||
self.transformer = Transformer(width, layers, heads, design_details=design_details)
|
||||
|
||||
self.ln_post = LayerNorm(width)
|
||||
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
||||
|
||||
def forward(self, x: torch.Tensor, shared_ctx, compound_deeper_prompts):
|
||||
x = self.conv1(x) # shape = [*, width, grid, grid]
|
||||
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
||||
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, output_size]
|
||||
x = torch.cat(
|
||||
[self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
|
||||
x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
||||
x = x + self.positional_embedding.to(x.dtype)
|
||||
|
||||
# After positional embeddings, we will attach prompts with the model, remember only those
|
||||
# are trainable parameters here in whole image encoder.
|
||||
if self.VPT_shallow:
|
||||
visual_ctx = shared_ctx.expand(x.shape[0], -1, -1).half()
|
||||
x = torch.cat([x, visual_ctx], dim=1)
|
||||
else:
|
||||
assert self.prompt_till_layer_visual == 0
|
||||
|
||||
# Normal code as before
|
||||
x = self.ln_pre(x)
|
||||
|
||||
x = x.permute(1, 0, 2) # NLD -> LND
|
||||
# Again combine the inputs, so nn.sequential can work
|
||||
outputs = self.transformer([x, compound_deeper_prompts, 0]) # third argument is counter
|
||||
x = outputs[0]
|
||||
mask = outputs[3]
|
||||
x = x.permute(1, 0, 2) # LND -> NLD
|
||||
visual_ctx = x[:,-shared_ctx.shape[0]:,:]
|
||||
x = self.ln_post(x[:, 0, :]) #only cls embedding is selected
|
||||
visual_ctx = self.ln_post(visual_ctx)
|
||||
if self.proj is not None:
|
||||
x = x @ self.proj
|
||||
visual_ctx = visual_ctx @ self.proj
|
||||
return x,visual_ctx,mask
|
||||
|
||||
|
||||
class CLIP(nn.Module):
|
||||
def __init__(self,
|
||||
embed_dim: int,
|
||||
# vision
|
||||
image_resolution: int,
|
||||
vision_layers: Union[Tuple[int, int, int, int], int],
|
||||
vision_width: int,
|
||||
vision_patch_size: int,
|
||||
# text
|
||||
context_length: int,
|
||||
vocab_size: int,
|
||||
transformer_width: int,
|
||||
transformer_heads: int,
|
||||
transformer_layers: int,
|
||||
design_details
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.context_length = context_length
|
||||
trainer = design_details['trainer']
|
||||
|
||||
if isinstance(vision_layers, (tuple, list)):
|
||||
vision_heads = vision_width * 32 // 64
|
||||
self.visual = ModifiedResNet(
|
||||
layers=vision_layers,
|
||||
output_dim=embed_dim,
|
||||
heads=vision_heads,
|
||||
input_resolution=image_resolution,
|
||||
width=vision_width
|
||||
)
|
||||
|
||||
else:
|
||||
vision_heads = vision_width // 64
|
||||
if trainer == "MaPLe":
|
||||
self.visual = VisionTransformer_MaPLe(
|
||||
input_resolution=image_resolution,
|
||||
patch_size=vision_patch_size,
|
||||
width=vision_width,
|
||||
layers=vision_layers,
|
||||
heads=vision_heads,
|
||||
output_dim=embed_dim,
|
||||
design_details=design_details
|
||||
)
|
||||
self.visual_ori = VisionTransformer(
|
||||
input_resolution=image_resolution,
|
||||
patch_size=vision_patch_size,
|
||||
width=vision_width,
|
||||
layers=vision_layers,
|
||||
heads=vision_heads,
|
||||
output_dim=embed_dim,
|
||||
design_details=design_details
|
||||
)
|
||||
else:
|
||||
self.visual = VisionTransformer(
|
||||
input_resolution=image_resolution,
|
||||
patch_size=vision_patch_size,
|
||||
width=vision_width,
|
||||
layers=vision_layers,
|
||||
heads=vision_heads,
|
||||
output_dim=embed_dim,
|
||||
design_details=design_details
|
||||
)
|
||||
# hyper-parameter if need to add prompt embeddings inside to the input
|
||||
# of transformer block or not:
|
||||
prompt_till_layer_text = design_details['language_depth']
|
||||
self.transformer = Transformer(
|
||||
width=transformer_width,
|
||||
layers=transformer_layers,
|
||||
heads=transformer_heads,
|
||||
attn_mask=self.build_attention_mask(),
|
||||
prompts_needed=prompt_till_layer_text,
|
||||
text_layer=True,
|
||||
design_details=design_details
|
||||
)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
|
||||
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
|
||||
self.ln_final = LayerNorm(transformer_width)
|
||||
|
||||
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
|
||||
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
||||
|
||||
self.initialize_parameters()
|
||||
|
||||
def initialize_parameters(self):
|
||||
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
||||
nn.init.normal_(self.positional_embedding, std=0.01)
|
||||
|
||||
if isinstance(self.visual, ModifiedResNet):
|
||||
if self.visual.attnpool is not None:
|
||||
std = self.visual.attnpool.c_proj.in_features ** -0.5
|
||||
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
|
||||
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
|
||||
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
|
||||
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
|
||||
|
||||
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
|
||||
for name, param in resnet_block.named_parameters():
|
||||
if name.endswith("bn3.weight"):
|
||||
nn.init.zeros_(param)
|
||||
|
||||
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
||||
attn_std = self.transformer.width ** -0.5
|
||||
fc_std = (2 * self.transformer.width) ** -0.5
|
||||
for block in self.transformer.resblocks:
|
||||
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
||||
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
||||
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
||||
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
||||
|
||||
if self.text_projection is not None:
|
||||
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
||||
|
||||
def build_attention_mask(self):
|
||||
# lazily create causal attention mask, with full attention between the vision tokens
|
||||
# pytorch uses additive attention mask; fill with -inf
|
||||
mask = torch.empty(self.context_length, self.context_length)
|
||||
mask.fill_(float("-inf"))
|
||||
mask.triu_(1) # zero out the lower diagonal
|
||||
return mask
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.visual.conv1.weight.dtype
|
||||
|
||||
def encode_image(self, image):
|
||||
return self.visual(image.type(self.dtype))
|
||||
|
||||
def encode_text(self, text):
|
||||
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
|
||||
|
||||
x = x + self.positional_embedding.type(self.dtype)
|
||||
x = x.permute(1, 0, 2) # NLD -> LND
|
||||
x = self.transformer(x)
|
||||
x = x.permute(1, 0, 2) # LND -> NLD
|
||||
x = self.ln_final(x).type(self.dtype)
|
||||
|
||||
# x.shape = [batch_size, n_ctx, transformer.width]
|
||||
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, image, text):
|
||||
image_features = self.encode_image(image)
|
||||
text_features = self.encode_text(text)
|
||||
|
||||
# normalized features
|
||||
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
|
||||
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
# cosine similarity as logits
|
||||
logit_scale = self.logit_scale.exp()
|
||||
logits_per_image = logit_scale * image_features @ text_features.t()
|
||||
logits_per_text = logit_scale * text_features @ image_features.t()
|
||||
|
||||
# shape = [global_batch_size, global_batch_size]
|
||||
return logits_per_image, logits_per_text
|
||||
|
||||
|
||||
def convert_weights(model: nn.Module):
|
||||
"""Convert applicable model parameters to fp16"""
|
||||
|
||||
def _convert_weights_to_fp16(l):
|
||||
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
||||
l.weight.data = l.weight.data.half()
|
||||
if l.bias is not None:
|
||||
l.bias.data = l.bias.data.half()
|
||||
|
||||
if isinstance(l, nn.MultiheadAttention):
|
||||
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
||||
tensor = getattr(l, attr)
|
||||
if tensor is not None:
|
||||
tensor.data = tensor.data.half()
|
||||
|
||||
for name in ["text_projection", "proj"]:
|
||||
if hasattr(l, name):
|
||||
attr = getattr(l, name)
|
||||
if attr is not None:
|
||||
attr.data = attr.data.half()
|
||||
|
||||
model.apply(_convert_weights_to_fp16)
|
||||
|
||||
|
||||
def build_model(state_dict: dict, design_details):
|
||||
vit = "visual.proj" in state_dict
|
||||
|
||||
if vit:
|
||||
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
||||
vision_layers = len(
|
||||
[k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
||||
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
||||
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
||||
image_resolution = vision_patch_size * grid_size
|
||||
else:
|
||||
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in
|
||||
[1, 2, 3, 4]]
|
||||
vision_layers = tuple(counts)
|
||||
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
||||
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
||||
vision_patch_size = None
|
||||
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
||||
image_resolution = output_width * 32
|
||||
|
||||
embed_dim = state_dict["text_projection"].shape[1]
|
||||
context_length = state_dict["positional_embedding"].shape[0]
|
||||
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
||||
transformer_width = state_dict["ln_final.weight"].shape[0]
|
||||
transformer_heads = transformer_width // 64
|
||||
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
|
||||
|
||||
model = CLIP(
|
||||
embed_dim,
|
||||
image_resolution, vision_layers, vision_width, vision_patch_size,
|
||||
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, design_details
|
||||
)
|
||||
|
||||
for key in ["input_resolution", "context_length", "vocab_size"]:
|
||||
if key in state_dict:
|
||||
del state_dict[key]
|
||||
|
||||
convert_weights(model)
|
||||
try:
|
||||
model.load_state_dict(state_dict)
|
||||
except:
|
||||
missing_keys, _ = model.load_state_dict(state_dict, strict=False)
|
||||
print('Weights not found for some missing keys: ', missing_keys)
|
||||
return model.eval()
|
||||
@@ -0,0 +1,132 @@
|
||||
import gzip
|
||||
import html
|
||||
import os
|
||||
from functools import lru_cache
|
||||
|
||||
import ftfy
|
||||
import regex as re
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def default_bpe():
|
||||
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def bytes_to_unicode():
|
||||
"""
|
||||
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
||||
The reversible bpe codes work on unicode strings.
|
||||
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
||||
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
||||
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
||||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||
"""
|
||||
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
||||
cs = bs[:]
|
||||
n = 0
|
||||
for b in range(2**8):
|
||||
if b not in bs:
|
||||
bs.append(b)
|
||||
cs.append(2**8+n)
|
||||
n += 1
|
||||
cs = [chr(n) for n in cs]
|
||||
return dict(zip(bs, cs))
|
||||
|
||||
|
||||
def get_pairs(word):
|
||||
"""Return set of symbol pairs in a word.
|
||||
Word is represented as tuple of symbols (symbols being variable-length strings).
|
||||
"""
|
||||
pairs = set()
|
||||
prev_char = word[0]
|
||||
for char in word[1:]:
|
||||
pairs.add((prev_char, char))
|
||||
prev_char = char
|
||||
return pairs
|
||||
|
||||
|
||||
def basic_clean(text):
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text))
|
||||
return text.strip()
|
||||
|
||||
|
||||
def whitespace_clean(text):
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
|
||||
class SimpleTokenizer(object):
|
||||
def __init__(self, bpe_path: str = default_bpe()):
|
||||
self.byte_encoder = bytes_to_unicode()
|
||||
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
||||
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
|
||||
merges = merges[1:49152-256-2+1]
|
||||
merges = [tuple(merge.split()) for merge in merges]
|
||||
vocab = list(bytes_to_unicode().values())
|
||||
vocab = vocab + [v+'</w>' for v in vocab]
|
||||
for merge in merges:
|
||||
vocab.append(''.join(merge))
|
||||
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
||||
self.encoder = dict(zip(vocab, range(len(vocab))))
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
|
||||
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
|
||||
|
||||
def bpe(self, token):
|
||||
if token in self.cache:
|
||||
return self.cache[token]
|
||||
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
||||
pairs = get_pairs(word)
|
||||
|
||||
if not pairs:
|
||||
return token+'</w>'
|
||||
|
||||
while True:
|
||||
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
first, second = bigram
|
||||
new_word = []
|
||||
i = 0
|
||||
while i < len(word):
|
||||
try:
|
||||
j = word.index(first, i)
|
||||
new_word.extend(word[i:j])
|
||||
i = j
|
||||
except:
|
||||
new_word.extend(word[i:])
|
||||
break
|
||||
|
||||
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
||||
new_word.append(first+second)
|
||||
i += 2
|
||||
else:
|
||||
new_word.append(word[i])
|
||||
i += 1
|
||||
new_word = tuple(new_word)
|
||||
word = new_word
|
||||
if len(word) == 1:
|
||||
break
|
||||
else:
|
||||
pairs = get_pairs(word)
|
||||
word = ' '.join(word)
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
def encode(self, text):
|
||||
bpe_tokens = []
|
||||
text = whitespace_clean(basic_clean(text)).lower()
|
||||
for token in re.findall(self.pat, text):
|
||||
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
||||
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
||||
return bpe_tokens
|
||||
|
||||
def decode(self, tokens):
|
||||
text = ''.join([self.decoder[token] for token in tokens])
|
||||
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
||||
return text
|
||||
+49409
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,2 @@
|
||||
DATASET:
|
||||
NAME: "Caltech101"
|
||||
@@ -0,0 +1,2 @@
|
||||
DATASET:
|
||||
NAME: "DescribableTextures"
|
||||
@@ -0,0 +1,2 @@
|
||||
DATASET:
|
||||
NAME: "EuroSAT"
|
||||
@@ -0,0 +1,2 @@
|
||||
DATASET:
|
||||
NAME: "FGVCAircraft"
|
||||
@@ -0,0 +1,2 @@
|
||||
DATASET:
|
||||
NAME: "Food101"
|
||||
@@ -0,0 +1,2 @@
|
||||
DATASET:
|
||||
NAME: "ImageNet"
|
||||
@@ -0,0 +1,2 @@
|
||||
DATASET:
|
||||
NAME: "ImageNetA"
|
||||
@@ -0,0 +1,2 @@
|
||||
DATASET:
|
||||
NAME: "ImageNetR"
|
||||
@@ -0,0 +1,2 @@
|
||||
DATASET:
|
||||
NAME: "ImageNetSketch"
|
||||
@@ -0,0 +1,2 @@
|
||||
DATASET:
|
||||
NAME: "ImageNetV2"
|
||||
@@ -0,0 +1,2 @@
|
||||
DATASET:
|
||||
NAME: "OxfordFlowers"
|
||||
@@ -0,0 +1,2 @@
|
||||
DATASET:
|
||||
NAME: "OxfordPets"
|
||||
@@ -0,0 +1,2 @@
|
||||
DATASET:
|
||||
NAME: "VOC12"
|
||||
@@ -0,0 +1,2 @@
|
||||
DATASET:
|
||||
NAME: "StanfordCars"
|
||||
@@ -0,0 +1,2 @@
|
||||
DATASET:
|
||||
NAME: "SUN397"
|
||||
@@ -0,0 +1,2 @@
|
||||
DATASET:
|
||||
NAME: "UCF101"
|
||||
@@ -0,0 +1,100 @@
|
||||
MODEL:
|
||||
META_ARCHITECTURE: "GeneralizedVLRCNN"
|
||||
WEIGHT: "swin_tiny_patch4_window7_224.pth"
|
||||
RPN_ONLY: True
|
||||
RPN_ARCHITECTURE: "VLDYHEAD"
|
||||
|
||||
BACKBONE:
|
||||
CONV_BODY: "SWINT-FPN-RETINANET"
|
||||
OUT_CHANNELS: 256
|
||||
FREEZE_CONV_BODY_AT: -1
|
||||
|
||||
LANGUAGE_BACKBONE:
|
||||
FREEZE: False
|
||||
MODEL_TYPE: "bert-base-uncased" # "roberta-base", "clip"
|
||||
MASK_SPECIAL: False
|
||||
|
||||
RPN:
|
||||
USE_FPN: True
|
||||
ANCHOR_SIZES: (64, 128, 256, 512, 1024)
|
||||
ANCHOR_STRIDE: (8, 16, 32, 64, 128)
|
||||
ASPECT_RATIOS: (1.0,)
|
||||
SCALES_PER_OCTAVE: 1
|
||||
|
||||
DYHEAD:
|
||||
CHANNELS: 256
|
||||
NUM_CONVS: 6
|
||||
USE_GN: True
|
||||
USE_DYRELU: True
|
||||
USE_DFCONV: True
|
||||
USE_DYFUSE: True
|
||||
TOPK: 9 # topk for selecting candidate positive samples from each level
|
||||
SCORE_AGG: "MEAN"
|
||||
LOG_SCALE: 0.0
|
||||
|
||||
FUSE_CONFIG:
|
||||
EARLY_FUSE_ON: True
|
||||
TYPE: "MHA-B"
|
||||
USE_CLASSIFICATION_LOSS: False
|
||||
USE_TOKEN_LOSS: False
|
||||
USE_CONTRASTIVE_ALIGN_LOSS: False
|
||||
CONTRASTIVE_HIDDEN_DIM: 64
|
||||
USE_DOT_PRODUCT_TOKEN_LOSS: True
|
||||
USE_FUSED_FEATURES_DOT_PRODUCT: True
|
||||
USE_LAYER_SCALE: True
|
||||
CLAMP_MIN_FOR_UNDERFLOW: True
|
||||
CLAMP_MAX_FOR_OVERFLOW: True
|
||||
CLAMP_BERTATTN_MIN_FOR_UNDERFLOW: True
|
||||
CLAMP_BERTATTN_MAX_FOR_OVERFLOW: True
|
||||
CLAMP_DOT_PRODUCT: True
|
||||
|
||||
USE_CHECKPOINT: True
|
||||
|
||||
TEST:
|
||||
DURING_TRAINING: False
|
||||
IMS_PER_BATCH: 64
|
||||
|
||||
# use for grounding model
|
||||
DATASETS:
|
||||
TRAIN: ("object365_dt_train", "mixed_train_no_coco", "flickr30k_train", )
|
||||
TEST: ("coco_2014_val", )
|
||||
DISABLE_SHUFFLE: False
|
||||
ADD_DET_PROMPT: False
|
||||
RANDOM_SAMPLE_NEG: 85
|
||||
CONTROL_PROB: (0.0, 0.0, 0.5, 0.0)
|
||||
|
||||
SEPARATION_TOKENS: ". "
|
||||
|
||||
INPUT:
|
||||
PIXEL_MEAN: [ 103.530, 116.280, 123.675 ]
|
||||
PIXEL_STD: [ 57.375, 57.120, 58.395 ]
|
||||
MIN_SIZE_TRAIN: 800
|
||||
MAX_SIZE_TRAIN: 1333
|
||||
MIN_SIZE_TEST: 800
|
||||
MAX_SIZE_TEST: 1333
|
||||
|
||||
AUGMENT:
|
||||
MULT_MIN_SIZE_TRAIN: (480,560,640,720,800)
|
||||
|
||||
DATALOADER:
|
||||
SIZE_DIVISIBILITY: 32
|
||||
|
||||
SOLVER:
|
||||
OPTIMIZER: ADAMW
|
||||
BASE_LR: 0.0001
|
||||
LANG_LR: 0.00001
|
||||
WEIGHT_DECAY: 0.0001
|
||||
STEPS: (0.67, 0.89)
|
||||
MAX_EPOCH: 30
|
||||
IMS_PER_BATCH: 64
|
||||
WARMUP_ITERS: 2000
|
||||
WARMUP_FACTOR: 0.001
|
||||
USE_AMP: True
|
||||
MODEL_EMA: 0.999
|
||||
FIND_UNUSED_PARAMETERS: False
|
||||
|
||||
CLIP_GRADIENTS:
|
||||
ENABLED: True
|
||||
CLIP_TYPE: "full_model"
|
||||
CLIP_VALUE: 1.0
|
||||
NORM_TYPE: 2.0
|
||||
@@ -0,0 +1,35 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 1
|
||||
TEST:
|
||||
BATCH_SIZE: 100
|
||||
NUM_WORKERS: 8
|
||||
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
INTERPOLATION: "bicubic"
|
||||
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
|
||||
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
|
||||
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.002
|
||||
MAX_EPOCH: 10
|
||||
LR_SCHEDULER: "cosine"
|
||||
WARMUP_EPOCH: 1
|
||||
WARMUP_TYPE: "constant"
|
||||
WARMUP_CONS_LR: 1e-5
|
||||
|
||||
TRAIN:
|
||||
PRINT_FREQ: 20
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "ViT-B/16"
|
||||
|
||||
TRAINER:
|
||||
COCOOP:
|
||||
N_CTX: 16
|
||||
CTX_INIT: ""
|
||||
PREC: "fp16"
|
||||
@@ -0,0 +1,35 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 1
|
||||
TEST:
|
||||
BATCH_SIZE: 100
|
||||
NUM_WORKERS: 8
|
||||
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
INTERPOLATION: "bicubic"
|
||||
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
|
||||
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
|
||||
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.002
|
||||
MAX_EPOCH: 10
|
||||
LR_SCHEDULER: "cosine"
|
||||
WARMUP_EPOCH: 1
|
||||
WARMUP_TYPE: "constant"
|
||||
WARMUP_CONS_LR: 1e-5
|
||||
|
||||
TRAIN:
|
||||
PRINT_FREQ: 20
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "ViT-B/16"
|
||||
|
||||
TRAINER:
|
||||
COCOOP:
|
||||
N_CTX: 4
|
||||
CTX_INIT: ""
|
||||
PREC: "fp16"
|
||||
@@ -0,0 +1,35 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 1
|
||||
TEST:
|
||||
BATCH_SIZE: 100
|
||||
NUM_WORKERS: 8
|
||||
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
INTERPOLATION: "bicubic"
|
||||
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
|
||||
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
|
||||
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.002
|
||||
MAX_EPOCH: 10
|
||||
LR_SCHEDULER: "cosine"
|
||||
WARMUP_EPOCH: 1
|
||||
WARMUP_TYPE: "constant"
|
||||
WARMUP_CONS_LR: 1e-5
|
||||
|
||||
TRAIN:
|
||||
PRINT_FREQ: 20
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "ViT-B/16"
|
||||
|
||||
TRAINER:
|
||||
COCOOP:
|
||||
N_CTX: 4
|
||||
CTX_INIT: "a photo of a"
|
||||
PREC: "fp16"
|
||||
@@ -0,0 +1,35 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 1
|
||||
TEST:
|
||||
BATCH_SIZE: 100
|
||||
NUM_WORKERS: 8
|
||||
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
INTERPOLATION: "bicubic"
|
||||
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
|
||||
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
|
||||
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.002
|
||||
MAX_EPOCH: 10
|
||||
LR_SCHEDULER: "cosine"
|
||||
WARMUP_EPOCH: 1
|
||||
WARMUP_TYPE: "constant"
|
||||
WARMUP_CONS_LR: 1e-5
|
||||
|
||||
TRAIN:
|
||||
PRINT_FREQ: 20
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "ViT-B/16"
|
||||
|
||||
TRAINER:
|
||||
COCOOP:
|
||||
N_CTX: 8
|
||||
CTX_INIT: ""
|
||||
PREC: "fp16"
|
||||
@@ -0,0 +1,29 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 32
|
||||
TEST:
|
||||
BATCH_SIZE: 100
|
||||
NUM_WORKERS: 8
|
||||
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
INTERPOLATION: "bicubic"
|
||||
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
|
||||
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
|
||||
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.002
|
||||
MAX_EPOCH: 200
|
||||
LR_SCHEDULER: "cosine"
|
||||
WARMUP_EPOCH: 1
|
||||
WARMUP_TYPE: "constant"
|
||||
WARMUP_CONS_LR: 1e-5
|
||||
|
||||
TRAIN:
|
||||
PRINT_FREQ: 5
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "RN101"
|
||||
@@ -0,0 +1,29 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 32
|
||||
TEST:
|
||||
BATCH_SIZE: 100
|
||||
NUM_WORKERS: 8
|
||||
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
INTERPOLATION: "bicubic"
|
||||
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
|
||||
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
|
||||
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.002
|
||||
MAX_EPOCH: 50
|
||||
LR_SCHEDULER: "cosine"
|
||||
WARMUP_EPOCH: 1
|
||||
WARMUP_TYPE: "constant"
|
||||
WARMUP_CONS_LR: 1e-5
|
||||
|
||||
TRAIN:
|
||||
PRINT_FREQ: 5
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "RN101"
|
||||
@@ -0,0 +1,29 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 32
|
||||
TEST:
|
||||
BATCH_SIZE: 100
|
||||
NUM_WORKERS: 8
|
||||
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
INTERPOLATION: "bicubic"
|
||||
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
|
||||
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
|
||||
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.002
|
||||
MAX_EPOCH: 200
|
||||
LR_SCHEDULER: "cosine"
|
||||
WARMUP_EPOCH: 1
|
||||
WARMUP_TYPE: "constant"
|
||||
WARMUP_CONS_LR: 1e-5
|
||||
|
||||
TRAIN:
|
||||
PRINT_FREQ: 5
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "RN50"
|
||||
@@ -0,0 +1,33 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 32
|
||||
TEST:
|
||||
BATCH_SIZE: 100
|
||||
NUM_WORKERS: 8
|
||||
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
INTERPOLATION: "bicubic"
|
||||
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
|
||||
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
|
||||
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.002
|
||||
MAX_EPOCH: 200
|
||||
LR_SCHEDULER: "cosine"
|
||||
WARMUP_EPOCH: 1
|
||||
WARMUP_TYPE: "constant"
|
||||
WARMUP_CONS_LR: 1e-5
|
||||
|
||||
TRAIN:
|
||||
PRINT_FREQ: 5
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "RN50"
|
||||
|
||||
TRAINER:
|
||||
COOP:
|
||||
CTX_INIT: "a photo of a"
|
||||
@@ -0,0 +1,29 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 32
|
||||
TEST:
|
||||
BATCH_SIZE: 100
|
||||
NUM_WORKERS: 8
|
||||
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
INTERPOLATION: "bicubic"
|
||||
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
|
||||
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
|
||||
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.002
|
||||
MAX_EPOCH: 100
|
||||
LR_SCHEDULER: "cosine"
|
||||
WARMUP_EPOCH: 1
|
||||
WARMUP_TYPE: "constant"
|
||||
WARMUP_CONS_LR: 1e-5
|
||||
|
||||
TRAIN:
|
||||
PRINT_FREQ: 5
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "RN50"
|
||||
@@ -0,0 +1,29 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 32
|
||||
TEST:
|
||||
BATCH_SIZE: 100
|
||||
NUM_WORKERS: 8
|
||||
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
INTERPOLATION: "bicubic"
|
||||
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
|
||||
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
|
||||
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.002
|
||||
MAX_EPOCH: 50
|
||||
LR_SCHEDULER: "cosine"
|
||||
WARMUP_EPOCH: 1
|
||||
WARMUP_TYPE: "constant"
|
||||
WARMUP_CONS_LR: 1e-5
|
||||
|
||||
TRAIN:
|
||||
PRINT_FREQ: 5
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "RN50"
|
||||
@@ -0,0 +1,33 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 32
|
||||
TEST:
|
||||
BATCH_SIZE: 100
|
||||
NUM_WORKERS: 8
|
||||
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
INTERPOLATION: "bicubic"
|
||||
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
|
||||
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
|
||||
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.002
|
||||
MAX_EPOCH: 50
|
||||
LR_SCHEDULER: "cosine"
|
||||
WARMUP_EPOCH: 1
|
||||
WARMUP_TYPE: "constant"
|
||||
WARMUP_CONS_LR: 1e-5
|
||||
|
||||
TRAIN:
|
||||
PRINT_FREQ: 5
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "RN50"
|
||||
|
||||
TRAINER:
|
||||
COOP:
|
||||
CTX_INIT: "a photo of a"
|
||||
@@ -0,0 +1,17 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 200
|
||||
TEST:
|
||||
BATCH_SIZE: 200
|
||||
NUM_WORKERS: 8
|
||||
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
INTERPOLATION: "bicubic"
|
||||
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
|
||||
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
|
||||
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "RN50"
|
||||
@@ -0,0 +1,29 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 32
|
||||
TEST:
|
||||
BATCH_SIZE: 100
|
||||
NUM_WORKERS: 8
|
||||
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
INTERPOLATION: "bicubic"
|
||||
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
|
||||
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
|
||||
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.002
|
||||
MAX_EPOCH: 200
|
||||
LR_SCHEDULER: "cosine"
|
||||
WARMUP_EPOCH: 1
|
||||
WARMUP_TYPE: "constant"
|
||||
WARMUP_CONS_LR: 1e-5
|
||||
|
||||
TRAIN:
|
||||
PRINT_FREQ: 5
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "ViT-B/16"
|
||||
@@ -0,0 +1,29 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 32
|
||||
TEST:
|
||||
BATCH_SIZE: 100
|
||||
NUM_WORKERS: 8
|
||||
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
INTERPOLATION: "bicubic"
|
||||
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
|
||||
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
|
||||
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.002
|
||||
MAX_EPOCH: 100
|
||||
LR_SCHEDULER: "cosine"
|
||||
WARMUP_EPOCH: 1
|
||||
WARMUP_TYPE: "constant"
|
||||
WARMUP_CONS_LR: 1e-5
|
||||
|
||||
TRAIN:
|
||||
PRINT_FREQ: 5
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "ViT-B/16"
|
||||
@@ -0,0 +1,29 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 32
|
||||
TEST:
|
||||
BATCH_SIZE: 100
|
||||
NUM_WORKERS: 8
|
||||
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
INTERPOLATION: "bicubic"
|
||||
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
|
||||
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
|
||||
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.002
|
||||
MAX_EPOCH: 50
|
||||
LR_SCHEDULER: "cosine"
|
||||
WARMUP_EPOCH: 1
|
||||
WARMUP_TYPE: "constant"
|
||||
WARMUP_CONS_LR: 1e-5
|
||||
|
||||
TRAIN:
|
||||
PRINT_FREQ: 5
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "ViT-B/16"
|
||||
@@ -0,0 +1,29 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 32
|
||||
TEST:
|
||||
BATCH_SIZE: 100
|
||||
NUM_WORKERS: 8
|
||||
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
INTERPOLATION: "bicubic"
|
||||
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
|
||||
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
|
||||
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.002
|
||||
MAX_EPOCH: 200
|
||||
LR_SCHEDULER: "cosine"
|
||||
WARMUP_EPOCH: 1
|
||||
WARMUP_TYPE: "constant"
|
||||
WARMUP_CONS_LR: 1e-5
|
||||
|
||||
TRAIN:
|
||||
PRINT_FREQ: 5
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "ViT-B/32"
|
||||
@@ -0,0 +1,29 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 32
|
||||
TEST:
|
||||
BATCH_SIZE: 100
|
||||
NUM_WORKERS: 8
|
||||
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
INTERPOLATION: "bicubic"
|
||||
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
|
||||
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
|
||||
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.002
|
||||
MAX_EPOCH: 50
|
||||
LR_SCHEDULER: "cosine"
|
||||
WARMUP_EPOCH: 1
|
||||
WARMUP_TYPE: "constant"
|
||||
WARMUP_CONS_LR: 1e-5
|
||||
|
||||
TRAIN:
|
||||
PRINT_FREQ: 5
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "ViT-B/32"
|
||||
@@ -0,0 +1,39 @@
|
||||
# Deep independent V-L Prompting
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 4
|
||||
TEST:
|
||||
BATCH_SIZE: 100
|
||||
NUM_WORKERS: 8
|
||||
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
INTERPOLATION: "bicubic"
|
||||
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
|
||||
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
|
||||
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.0035
|
||||
MAX_EPOCH: 5
|
||||
LR_SCHEDULER: "cosine"
|
||||
WARMUP_EPOCH: 1
|
||||
WARMUP_TYPE: "constant"
|
||||
WARMUP_CONS_LR: 1e-5
|
||||
|
||||
TRAIN:
|
||||
PRINT_FREQ: 20
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "ViT-B/16"
|
||||
|
||||
TRAINER:
|
||||
IVLP:
|
||||
N_CTX_VISION: 2
|
||||
N_CTX_TEXT: 2
|
||||
CTX_INIT: "a photo of a"
|
||||
PREC: "fp16"
|
||||
PROMPT_DEPTH_VISION: 12
|
||||
PROMPT_DEPTH_TEXT: 12
|
||||
@@ -0,0 +1,39 @@
|
||||
# Deep language prompting
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 4
|
||||
TEST:
|
||||
BATCH_SIZE: 100
|
||||
NUM_WORKERS: 8
|
||||
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
INTERPOLATION: "bicubic"
|
||||
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
|
||||
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
|
||||
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.0025
|
||||
MAX_EPOCH: 5
|
||||
LR_SCHEDULER: "cosine"
|
||||
WARMUP_EPOCH: 1
|
||||
WARMUP_TYPE: "constant"
|
||||
WARMUP_CONS_LR: 1e-5
|
||||
|
||||
TRAIN:
|
||||
PRINT_FREQ: 20
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "ViT-B/16"
|
||||
|
||||
TRAINER:
|
||||
IVLP:
|
||||
N_CTX_VISION: 0
|
||||
N_CTX_TEXT: 4
|
||||
CTX_INIT: "a photo of a"
|
||||
PREC: "fp16"
|
||||
PROMPT_DEPTH_VISION: 0
|
||||
PROMPT_DEPTH_TEXT: 12
|
||||
@@ -0,0 +1,53 @@
|
||||
DATASET:
|
||||
SELECTION_BATCH_SIZE: 50
|
||||
SUBSAMPLE_CLASSES: all
|
||||
|
||||
|
||||
DATALOADER:
|
||||
RETURN_IMG0: true
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 16
|
||||
TEST:
|
||||
BATCH_SIZE: 64
|
||||
NUM_WORKERS: 2
|
||||
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
INTERPOLATION: "bicubic"
|
||||
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
|
||||
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
|
||||
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
|
||||
# CUTOUT_N: 1
|
||||
# CUTOUT_LEN: 128
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.0026 #0.0035 0.0026 for crossdata
|
||||
MAX_EPOCH: 5
|
||||
LR_SCHEDULER: "cosine"
|
||||
WARMUP_EPOCH: 1
|
||||
WARMUP_TYPE: "constant"
|
||||
WARMUP_CONS_LR: 1e-5
|
||||
|
||||
TRAIN:
|
||||
PRINT_FREQ: 20
|
||||
CHECKPOINT_FREQ: 1
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "ViT-B/16"
|
||||
|
||||
|
||||
|
||||
TEST:
|
||||
PER_CLASS_RESULT: false
|
||||
FINAL_MODEL: "best_val"
|
||||
|
||||
|
||||
|
||||
TRAINER:
|
||||
MAPLEG:
|
||||
N_CTX: 4
|
||||
CTX_INIT: "a photo of a"
|
||||
PREC: "fp16"
|
||||
PROMPT_DEPTH: 9
|
||||
@@ -0,0 +1,52 @@
|
||||
DATASET:
|
||||
SELECTION_BATCH_SIZE: 50
|
||||
SUBSAMPLE_CLASSES: base
|
||||
|
||||
|
||||
DATALOADER:
|
||||
RETURN_IMG0: true
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 1
|
||||
TEST:
|
||||
BATCH_SIZE: 256
|
||||
NUM_WORKERS: 4
|
||||
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
INTERPOLATION: "bicubic"
|
||||
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
|
||||
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
|
||||
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
|
||||
# CUTOUT_N: 1
|
||||
# CUTOUT_LEN: 128
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.0035
|
||||
MAX_EPOCH: 5
|
||||
LR_SCHEDULER: "cosine"
|
||||
WARMUP_EPOCH: 1
|
||||
WARMUP_TYPE: "constant"
|
||||
WARMUP_CONS_LR: 1e-5
|
||||
|
||||
TRAIN:
|
||||
PRINT_FREQ: 20
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "ViT-B/16"
|
||||
|
||||
|
||||
|
||||
TEST:
|
||||
PER_CLASS_RESULT: false
|
||||
FINAL_MODEL: "best_val"
|
||||
|
||||
|
||||
|
||||
TRAINER:
|
||||
MAPLE:
|
||||
N_CTX: 2
|
||||
CTX_INIT: "a photo of a"
|
||||
PREC: "fp16"
|
||||
PROMPT_DEPTH: 9
|
||||
@@ -0,0 +1,41 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 4
|
||||
TEST:
|
||||
BATCH_SIZE: 100
|
||||
NUM_WORKERS: 4
|
||||
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
INTERPOLATION: "bicubic"
|
||||
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
|
||||
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
|
||||
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.0035
|
||||
MAX_EPOCH: 5
|
||||
LR_SCHEDULER: "cosine"
|
||||
WARMUP_EPOCH: 1
|
||||
WARMUP_TYPE: "constant"
|
||||
WARMUP_CONS_LR: 1e-5
|
||||
|
||||
|
||||
|
||||
TRAIN:
|
||||
PRINT_FREQ: 20
|
||||
|
||||
TEST:
|
||||
FINAL_MODEL: "best_val"
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "ViT-B/16"
|
||||
|
||||
TRAINER:
|
||||
MAPLE:
|
||||
N_CTX: 2
|
||||
CTX_INIT: "a photo of a"
|
||||
PREC: "fp16"
|
||||
PROMPT_DEPTH: 9
|
||||
@@ -0,0 +1,36 @@
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 4
|
||||
TEST:
|
||||
BATCH_SIZE: 100
|
||||
NUM_WORKERS: 8
|
||||
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
INTERPOLATION: "bicubic"
|
||||
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
|
||||
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
|
||||
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.0026
|
||||
MAX_EPOCH: 2
|
||||
LR_SCHEDULER: "cosine"
|
||||
WARMUP_EPOCH: 1
|
||||
WARMUP_TYPE: "constant"
|
||||
WARMUP_CONS_LR: 1e-5
|
||||
|
||||
TRAIN:
|
||||
PRINT_FREQ: 20
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "ViT-B/16"
|
||||
|
||||
TRAINER:
|
||||
MAPLE:
|
||||
N_CTX: 2
|
||||
CTX_INIT: "a photo of a"
|
||||
PREC: "fp16"
|
||||
PROMPT_DEPTH: 9
|
||||
@@ -0,0 +1,53 @@
|
||||
DATASET:
|
||||
SELECTION_BATCH_SIZE: 50
|
||||
SUBSAMPLE_CLASSES: all
|
||||
|
||||
|
||||
DATALOADER:
|
||||
RETURN_IMG0: true
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 4
|
||||
TEST:
|
||||
BATCH_SIZE: 128
|
||||
NUM_WORKERS: 4
|
||||
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
INTERPOLATION: "bicubic"
|
||||
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
|
||||
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
|
||||
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
|
||||
# CUTOUT_N: 1
|
||||
# CUTOUT_LEN: 128
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.0035 #0.0035 0.0026 for crossdata
|
||||
MAX_EPOCH: 10
|
||||
LR_SCHEDULER: "cosine"
|
||||
WARMUP_EPOCH: 1
|
||||
WARMUP_TYPE: "constant"
|
||||
WARMUP_CONS_LR: 1e-5
|
||||
|
||||
TRAIN:
|
||||
PRINT_FREQ: 20
|
||||
CHECKPOINT_FREQ: 1
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "ViT-B/16"
|
||||
|
||||
|
||||
|
||||
TEST:
|
||||
PER_CLASS_RESULT: false
|
||||
FINAL_MODEL: "best_val"
|
||||
|
||||
|
||||
|
||||
TRAINER:
|
||||
MAPLE:
|
||||
N_CTX: 2
|
||||
CTX_INIT: "A photo of a"
|
||||
PREC: "fp16"
|
||||
PROMPT_DEPTH: 9
|
||||
@@ -0,0 +1,37 @@
|
||||
# Deep vision prompting
|
||||
DATALOADER:
|
||||
TRAIN_X:
|
||||
BATCH_SIZE: 4
|
||||
TEST:
|
||||
BATCH_SIZE: 100
|
||||
NUM_WORKERS: 8
|
||||
|
||||
INPUT:
|
||||
SIZE: (224, 224)
|
||||
INTERPOLATION: "bicubic"
|
||||
PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
|
||||
PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
|
||||
TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
|
||||
|
||||
OPTIM:
|
||||
NAME: "sgd"
|
||||
LR: 0.0025
|
||||
MAX_EPOCH: 5
|
||||
LR_SCHEDULER: "cosine"
|
||||
WARMUP_EPOCH: 1
|
||||
WARMUP_TYPE: "constant"
|
||||
WARMUP_CONS_LR: 1e-5
|
||||
|
||||
TRAIN:
|
||||
PRINT_FREQ: 20
|
||||
|
||||
MODEL:
|
||||
BACKBONE:
|
||||
NAME: "ViT-B/16"
|
||||
|
||||
TRAINER:
|
||||
VPT:
|
||||
N_CTX_VISION: 8
|
||||
CTX_INIT: "a photo of a"
|
||||
PREC: "fp16"
|
||||
PROMPT_DEPTH_VISION: 12
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,63 @@
|
||||
import os
|
||||
import pickle
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import mkdir_if_missing
|
||||
|
||||
from .oxford_pets import OxfordPets
|
||||
from .dtd import DescribableTextures as DTD
|
||||
import deepcore.methods as s_method
|
||||
import numpy as np
|
||||
|
||||
IGNORED = ["BACKGROUND_Google", "Faces_easy"]
|
||||
NEW_CNAMES = {
|
||||
"airplanes": "airplane",
|
||||
"Faces": "face",
|
||||
"Leopards": "leopard",
|
||||
"Motorbikes": "motorbike",
|
||||
}
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class Caltech101(DatasetBase):
|
||||
|
||||
dataset_dir = "caltech-101"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.image_dir = os.path.join(self.dataset_dir, "101_ObjectCategories")
|
||||
self.split_path = os.path.join(self.dataset_dir, "split_zhou_Caltech101.json")
|
||||
self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
|
||||
mkdir_if_missing(self.split_fewshot_dir)
|
||||
|
||||
if os.path.exists(self.split_path):
|
||||
train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
|
||||
else:
|
||||
train, val, test = DTD.read_and_split_data(self.image_dir, ignored=IGNORED, new_cnames=NEW_CNAMES)
|
||||
OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
|
||||
|
||||
num_shots = cfg.DATASET.NUM_SHOTS
|
||||
|
||||
if num_shots >= 1:
|
||||
seed = cfg.SEED
|
||||
preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
|
||||
|
||||
if os.path.exists(preprocessed):
|
||||
print(f"Loading preprocessed few-shot data from {preprocessed}")
|
||||
with open(preprocessed, "rb") as file:
|
||||
data = pickle.load(file)
|
||||
train, val = data["train"], data["val"]
|
||||
else:
|
||||
train = self.generate_fewshot_dataset(train, num_shots=num_shots)
|
||||
val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
|
||||
data = {"train": train, "val": val}
|
||||
print(f"Saving preprocessed few-shot data to {preprocessed}")
|
||||
with open(preprocessed, "wb") as file:
|
||||
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
|
||||
train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample)
|
||||
|
||||
super().__init__(train_x=train, val=val, test=test)
|
||||
|
||||
@@ -0,0 +1,481 @@
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
import numpy as np
|
||||
from tabulate import tabulate
|
||||
from torch.utils.data import Dataset as TorchDataset
|
||||
import os
|
||||
from dassl.utils import read_image
|
||||
|
||||
from dassl.data.datasets import build_dataset
|
||||
from dassl.data.samplers import build_sampler
|
||||
from dassl.data.transforms import INTERPOLATION_MODES, build_transform
|
||||
from .new_da import RandomResizedCropPair, build_transform_pair
|
||||
from PIL import Image
|
||||
|
||||
def build_data_loader(
|
||||
cfg,
|
||||
sampler_type="SequentialSampler",
|
||||
data_source=None,
|
||||
batch_size=64,
|
||||
n_domain=0,
|
||||
n_ins=2,
|
||||
tfm=None,
|
||||
is_train=True,
|
||||
dataset_wrapper=None,
|
||||
weight=None,
|
||||
):
|
||||
# Build sampler
|
||||
sampler = build_sampler(
|
||||
sampler_type,
|
||||
cfg=cfg,
|
||||
data_source=data_source,
|
||||
batch_size=batch_size,
|
||||
n_domain=n_domain,
|
||||
n_ins=n_ins
|
||||
)
|
||||
|
||||
if dataset_wrapper is None:
|
||||
dataset_wrapper = DatasetWrapper
|
||||
|
||||
# Build data loader
|
||||
data_loader = torch.utils.data.DataLoader(
|
||||
dataset_wrapper(cfg, data_source,transform=tfm, is_train=is_train,weight=weight),
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
num_workers=cfg.DATALOADER.NUM_WORKERS,
|
||||
drop_last=is_train and len(data_source) >= batch_size,
|
||||
pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA)
|
||||
)
|
||||
assert len(data_loader) > 0
|
||||
|
||||
return data_loader
|
||||
|
||||
|
||||
|
||||
def build_data_loader_mask(
|
||||
cfg,
|
||||
dataset,
|
||||
sampler_type="SequentialSampler",
|
||||
data_source=None,
|
||||
batch_size=64,
|
||||
n_domain=0,
|
||||
n_ins=2,
|
||||
tfm=None,
|
||||
is_train=True,
|
||||
dataset_wrapper=None,
|
||||
weight=None,
|
||||
):
|
||||
# Build sampler
|
||||
sampler = build_sampler(
|
||||
sampler_type,
|
||||
cfg=cfg,
|
||||
data_source=data_source,
|
||||
batch_size=batch_size,
|
||||
n_domain=n_domain,
|
||||
n_ins=n_ins
|
||||
)
|
||||
|
||||
if dataset_wrapper is None:
|
||||
dataset_wrapper = DatasetWrapperMask
|
||||
|
||||
# Build data loader
|
||||
data_loader = torch.utils.data.DataLoader(
|
||||
dataset_wrapper(cfg, dataset,data_source,transform=tfm, is_train=is_train,weight=weight),
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
num_workers=cfg.DATALOADER.NUM_WORKERS,
|
||||
drop_last=is_train and len(data_source) >= batch_size,
|
||||
pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA)
|
||||
)
|
||||
assert len(data_loader) > 0
|
||||
|
||||
return data_loader
|
||||
|
||||
def select_dm_loader(cfg,dataset,s_ind=None,is_train=False):
|
||||
|
||||
tfm = build_transform(cfg, is_train=is_train)
|
||||
if is_train:
|
||||
dataloader = build_data_loader(
|
||||
cfg,
|
||||
sampler_type=cfg.DATALOADER.TRAIN_X.SAMPLER,
|
||||
data_source=list(np.asarray(dataset)[s_ind]) if s_ind is not None else dataset,
|
||||
batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE, #cfg.DATALOADER.TRAIN_X.BATCH_SIZE*
|
||||
n_domain=cfg.DATALOADER.TRAIN_X.N_DOMAIN,
|
||||
n_ins=cfg.DATALOADER.TRAIN_X.N_INS,
|
||||
tfm=tfm,
|
||||
is_train=is_train,
|
||||
dataset_wrapper=None,
|
||||
)
|
||||
else:
|
||||
dataloader = build_data_loader(
|
||||
cfg,
|
||||
sampler_type=cfg.DATALOADER.TEST.SAMPLER,
|
||||
data_source=list(np.asarray(dataset)[s_ind]) if s_ind is not None else dataset,
|
||||
batch_size=cfg.DATASET.SELECTION_BATCH_SIZE,
|
||||
n_domain=cfg.DATALOADER.TRAIN_X.N_DOMAIN,
|
||||
n_ins=cfg.DATALOADER.TRAIN_X.N_INS,
|
||||
tfm=tfm,
|
||||
is_train=is_train,
|
||||
dataset_wrapper=None,
|
||||
)
|
||||
|
||||
return dataloader
|
||||
|
||||
class DataManager:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg,
|
||||
dataset,
|
||||
s_ind=None,
|
||||
custom_tfm_train=None,
|
||||
custom_tfm_test=None,
|
||||
dataset_wrapper=None,
|
||||
weight=None,
|
||||
):
|
||||
# # Load dataset
|
||||
# dataset = build_dataset(cfg)
|
||||
|
||||
# Build transform
|
||||
if custom_tfm_train is None:
|
||||
###pair is for
|
||||
tfm_train_pair = build_transform_pair(cfg, is_train=True)
|
||||
tfm_train = build_transform(cfg,is_train=True)
|
||||
else:
|
||||
print("* Using custom transform for training")
|
||||
tfm_train = custom_tfm_train
|
||||
|
||||
if custom_tfm_test is None:
|
||||
tfm_test = build_transform(cfg, is_train=False)
|
||||
else:
|
||||
print("* Using custom transform for testing")
|
||||
tfm_test = custom_tfm_test
|
||||
|
||||
|
||||
# Build train_loader_x
|
||||
|
||||
train_loader_x = build_data_loader_mask(
|
||||
cfg,
|
||||
dataset,
|
||||
sampler_type=cfg.DATALOADER.TRAIN_X.SAMPLER,
|
||||
data_source=list(np.asarray(dataset.train_x)[s_ind]) if s_ind is not None else dataset.train_x,
|
||||
batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE,
|
||||
n_domain=cfg.DATALOADER.TRAIN_X.N_DOMAIN,
|
||||
n_ins=cfg.DATALOADER.TRAIN_X.N_INS,
|
||||
tfm=tfm_train_pair,
|
||||
is_train=True,
|
||||
dataset_wrapper=dataset_wrapper,
|
||||
weight=weight
|
||||
)
|
||||
|
||||
|
||||
train_loader_xmore = build_data_loader(
|
||||
cfg,
|
||||
sampler_type=cfg.DATALOADER.TRAIN_X.SAMPLER,
|
||||
data_source=list(np.asarray(dataset.train_x)[s_ind]) if s_ind is not None else dataset.train_x,
|
||||
batch_size=cfg.DATASET.SELECTION_BATCH_SIZE,
|
||||
n_domain=cfg.DATALOADER.TRAIN_X.N_DOMAIN,
|
||||
n_ins=cfg.DATALOADER.TRAIN_X.N_INS,
|
||||
tfm=tfm_train,
|
||||
is_train=True,
|
||||
dataset_wrapper=dataset_wrapper,
|
||||
weight=weight
|
||||
)
|
||||
|
||||
# Build train_loader_u
|
||||
train_loader_u = None
|
||||
if dataset.train_u:
|
||||
sampler_type_ = cfg.DATALOADER.TRAIN_U.SAMPLER
|
||||
batch_size_ = cfg.DATALOADER.TRAIN_U.BATCH_SIZE
|
||||
n_domain_ = cfg.DATALOADER.TRAIN_U.N_DOMAIN
|
||||
n_ins_ = cfg.DATALOADER.TRAIN_U.N_INS
|
||||
|
||||
if cfg.DATALOADER.TRAIN_U.SAME_AS_X:
|
||||
sampler_type_ = cfg.DATALOADER.TRAIN_X.SAMPLER
|
||||
batch_size_ = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
|
||||
n_domain_ = cfg.DATALOADER.TRAIN_X.N_DOMAIN
|
||||
n_ins_ = cfg.DATALOADER.TRAIN_X.N_INS
|
||||
|
||||
train_loader_u = build_data_loader(
|
||||
cfg,
|
||||
sampler_type=sampler_type_,
|
||||
data_source=dataset.train_u,
|
||||
batch_size=batch_size_,
|
||||
n_domain=n_domain_,
|
||||
n_ins=n_ins_,
|
||||
tfm=tfm_train,
|
||||
is_train=True,
|
||||
dataset_wrapper=dataset_wrapper
|
||||
)
|
||||
|
||||
# Build val_loader
|
||||
val_loader = None
|
||||
if dataset.val:
|
||||
val_loader = build_data_loader(
|
||||
cfg,
|
||||
sampler_type=cfg.DATALOADER.TEST.SAMPLER,
|
||||
data_source=dataset.val,
|
||||
batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
|
||||
tfm=tfm_test,
|
||||
is_train=False,
|
||||
dataset_wrapper=dataset_wrapper
|
||||
)
|
||||
|
||||
# Build test_loader
|
||||
test_loader = build_data_loader(
|
||||
cfg,
|
||||
sampler_type=cfg.DATALOADER.TEST.SAMPLER,
|
||||
data_source=dataset.test,
|
||||
batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
|
||||
tfm=tfm_test,
|
||||
is_train=False,
|
||||
dataset_wrapper=dataset_wrapper
|
||||
)
|
||||
|
||||
# Attributes
|
||||
self._num_classes = dataset.num_classes
|
||||
self._num_source_domains = len(cfg.DATASET.SOURCE_DOMAINS)
|
||||
self._lab2cname = dataset.lab2cname
|
||||
|
||||
# Dataset and data-loaders
|
||||
self.dataset = dataset
|
||||
self.train_loader_x = train_loader_x
|
||||
self.train_loader_u = train_loader_u
|
||||
self.train_loader_xmore = train_loader_xmore
|
||||
self.val_loader = val_loader
|
||||
self.test_loader = test_loader
|
||||
|
||||
if cfg.VERBOSE:
|
||||
self.show_dataset_summary(cfg)
|
||||
|
||||
@property
|
||||
def num_classes(self):
|
||||
return self._num_classes
|
||||
|
||||
@property
|
||||
def num_source_domains(self):
|
||||
return self._num_source_domains
|
||||
|
||||
@property
|
||||
def lab2cname(self):
|
||||
return self._lab2cname
|
||||
|
||||
def show_dataset_summary(self, cfg):
|
||||
dataset_name = cfg.DATASET.NAME
|
||||
source_domains = cfg.DATASET.SOURCE_DOMAINS
|
||||
target_domains = cfg.DATASET.TARGET_DOMAINS
|
||||
|
||||
table = []
|
||||
table.append(["Dataset", dataset_name])
|
||||
if source_domains:
|
||||
table.append(["Source", source_domains])
|
||||
if target_domains:
|
||||
table.append(["Target", target_domains])
|
||||
table.append(["# classes", f"{self.num_classes:,}"])
|
||||
table.append(["# train_x", f"{len(self.dataset.train_x):,}"])
|
||||
if self.dataset.train_u:
|
||||
table.append(["# train_u", f"{len(self.dataset.train_u):,}"])
|
||||
if self.dataset.val:
|
||||
table.append(["# val", f"{len(self.dataset.val):,}"])
|
||||
table.append(["# test", f"{len(self.dataset.test):,}"])
|
||||
|
||||
print(tabulate(table))
|
||||
|
||||
|
||||
class DatasetWrapperMask(TorchDataset):
|
||||
|
||||
def __init__(self, cfg, dataset,data_source,transform=None, is_train=False,weight=None):
|
||||
self.cfg = cfg
|
||||
self.data_source = data_source
|
||||
self.transform = transform # accept list (tuple) as input
|
||||
self.is_train = is_train
|
||||
self.data_path = dataset.dataset_dir
|
||||
self.mask_path = os.path.join(dataset.dataset_dir,'mask')
|
||||
# Augmenting an image K>1 times is only allowed during training
|
||||
self.k_tfm = cfg.DATALOADER.K_TRANSFORMS if is_train else 1
|
||||
self.return_img0 = cfg.DATALOADER.RETURN_IMG0
|
||||
|
||||
if weight is not None:
|
||||
self.weight = weight
|
||||
else:
|
||||
self.weight = None
|
||||
|
||||
if self.k_tfm > 1 and transform is None:
|
||||
raise ValueError(
|
||||
"Cannot augment the image {} times "
|
||||
"because transform is None".format(self.k_tfm)
|
||||
)
|
||||
|
||||
# Build transform that doesn't apply any data augmentation
|
||||
interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]
|
||||
to_tensor = []
|
||||
to_tensor += [T.Resize(cfg.INPUT.SIZE, interpolation=interp_mode)]
|
||||
to_tensor += [T.ToTensor()]
|
||||
if "normalize" in cfg.INPUT.TRANSFORMS:
|
||||
normalize = T.Normalize(
|
||||
mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD
|
||||
)
|
||||
to_tensor += [normalize]
|
||||
self.to_tensor = T.Compose(to_tensor)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_source)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.data_source[idx]
|
||||
|
||||
if self.weight is None:
|
||||
output = {
|
||||
"label": item.label,
|
||||
"domain": item.domain,
|
||||
"impath": item.impath,
|
||||
"index": idx
|
||||
}
|
||||
else:
|
||||
output = {
|
||||
"label": item.label,
|
||||
"domain": item.domain,
|
||||
"impath": item.impath,
|
||||
"index": idx,
|
||||
"weight": self.weight[idx]
|
||||
}
|
||||
|
||||
# img_path = os.path.join('/'.join(item.impath.split('/')[:-1]),'mask',item.impath.split('/')[-1]) ('/').join(item.impath.split('/')[-2:])
|
||||
if self.cfg.DATASET.NAME in ['Food101','Caltech101','DescribableTextures','EuroSAT','UCF101']:
|
||||
mask = read_image(os.path.join(self.mask_path,('/').join(item.impath.split('/')[-2:])))
|
||||
elif self.cfg.DATASET.NAME in ['SUN397']:
|
||||
mask = read_image(os.path.join(self.mask_path,('/').join(item.impath.split('/')[7:])))
|
||||
elif self.cfg.DATASET.NAME in ['ImageNet']:
|
||||
mask = read_image(os.path.join(self.mask_path,('/').join(item.impath.split('/')[7:])))
|
||||
elif self.cfg.DATASET.NAME in ['VOC12']:
|
||||
mask_path = os.path.join(self.data_path,'VOCdevkit/VOC2012/SegmentationClass_All',item.impath.split('/')[-1][:-3]+'png')
|
||||
mask = read_image(mask_path)
|
||||
else:
|
||||
mask = read_image(os.path.join(self.mask_path, item.impath.split('/')[-1]))
|
||||
img0 = read_image(item.impath)
|
||||
mask = mask.resize(img0.size)
|
||||
if self.transform is not None:
|
||||
if isinstance(self.transform, (list, tuple)):
|
||||
for i, tfm in enumerate(self.transform):
|
||||
img = self._transform_image(tfm, img0,img0)
|
||||
keyname = "img"
|
||||
if (i + 1) > 1:
|
||||
keyname += str(i + 1)
|
||||
output[keyname] = img
|
||||
else:
|
||||
img,mask = self._transform_image(self.transform, img0,mask)
|
||||
output["img"] = img
|
||||
output["mask"] = mask
|
||||
else:
|
||||
output["img"] = img0
|
||||
|
||||
if self.return_img0:
|
||||
output["img0"] = self.to_tensor(img0) # without any augmentation
|
||||
|
||||
return output
|
||||
|
||||
def _transform_image(self, tfm, img0,mask):
|
||||
img_list = []
|
||||
for k in range(self.k_tfm):
|
||||
img_list.append(tfm(img0,mask))
|
||||
|
||||
img = img_list
|
||||
if len(img_list) == 1:
|
||||
img = img_list[0][0]
|
||||
mask = img_list[0][1]
|
||||
|
||||
return img,mask
|
||||
|
||||
|
||||
class DatasetWrapper(TorchDataset):
|
||||
|
||||
def __init__(self, cfg, data_source,transform=None, is_train=False,weight=None):
|
||||
self.cfg = cfg
|
||||
self.data_source = data_source
|
||||
self.transform = transform # accept list (tuple) as input
|
||||
self.is_train = is_train
|
||||
self.mask_path = ('/').join(data_source[0].impath.split('/')[:-2])+'/mask'
|
||||
# Augmenting an image K>1 times is only allowed during training
|
||||
self.k_tfm = cfg.DATALOADER.K_TRANSFORMS if is_train else 1
|
||||
self.return_img0 = cfg.DATALOADER.RETURN_IMG0
|
||||
|
||||
if weight is not None:
|
||||
self.weight = weight
|
||||
else:
|
||||
self.weight = None
|
||||
|
||||
if self.k_tfm > 1 and transform is None:
|
||||
raise ValueError(
|
||||
"Cannot augment the image {} times "
|
||||
"because transform is None".format(self.k_tfm)
|
||||
)
|
||||
|
||||
# Build transform that doesn't apply any data augmentation
|
||||
interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]
|
||||
to_tensor = []
|
||||
to_tensor += [T.Resize(cfg.INPUT.SIZE, interpolation=interp_mode)]
|
||||
to_tensor += [T.ToTensor()]
|
||||
if "normalize" in cfg.INPUT.TRANSFORMS:
|
||||
normalize = T.Normalize(
|
||||
mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD
|
||||
)
|
||||
to_tensor += [normalize]
|
||||
self.to_tensor = T.Compose(to_tensor)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_source)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.data_source[idx]
|
||||
|
||||
if self.weight is None:
|
||||
output = {
|
||||
"label": item.label,
|
||||
"domain": item.domain,
|
||||
"impath": item.impath,
|
||||
"index": idx
|
||||
}
|
||||
else:
|
||||
output = {
|
||||
"label": item.label,
|
||||
"domain": item.domain,
|
||||
"impath": item.impath,
|
||||
"index": idx,
|
||||
"weight": self.weight[idx]
|
||||
}
|
||||
|
||||
# img0 = read_image(item.impath)
|
||||
img0 = read_image(item.impath)
|
||||
# img0 = img0.resize(mask.size)
|
||||
# mask = read_image(item.impath.split('/')[:-1].join('/'))
|
||||
if self.transform is not None:
|
||||
if isinstance(self.transform, (list, tuple)):
|
||||
for i, tfm in enumerate(self.transform):
|
||||
img = self._transform_image(tfm, img0)
|
||||
keyname = "img"
|
||||
if (i + 1) > 1:
|
||||
keyname += str(i + 1)
|
||||
output[keyname] = img
|
||||
else:
|
||||
img = self._transform_image(self.transform, img0)
|
||||
output["img"] = img
|
||||
output['mask'] = 1
|
||||
else:
|
||||
output["img"] = img0
|
||||
|
||||
if self.return_img0:
|
||||
output["img0"] = self.to_tensor(img0) # without any augmentation
|
||||
|
||||
return output
|
||||
|
||||
def _transform_image(self, tfm, img0):
|
||||
img_list = []
|
||||
|
||||
for k in range(self.k_tfm):
|
||||
img_list.append(tfm(img0))
|
||||
|
||||
img = img_list
|
||||
if len(img) == 1:
|
||||
img = img[0]
|
||||
|
||||
return img
|
||||
@@ -0,0 +1,95 @@
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import listdir_nohidden, mkdir_if_missing
|
||||
|
||||
from .oxford_pets import OxfordPets
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class DescribableTextures(DatasetBase):
|
||||
|
||||
dataset_dir = "dtd"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.image_dir = os.path.join(self.dataset_dir, "images")
|
||||
self.split_path = os.path.join(self.dataset_dir, "split_zhou_DescribableTextures.json")
|
||||
self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
|
||||
mkdir_if_missing(self.split_fewshot_dir)
|
||||
|
||||
if os.path.exists(self.split_path):
|
||||
train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
|
||||
else:
|
||||
train, val, test = self.read_and_split_data(self.image_dir)
|
||||
OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
|
||||
|
||||
num_shots = cfg.DATASET.NUM_SHOTS
|
||||
if num_shots >= 1:
|
||||
seed = cfg.SEED
|
||||
preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
|
||||
|
||||
if os.path.exists(preprocessed):
|
||||
print(f"Loading preprocessed few-shot data from {preprocessed}")
|
||||
with open(preprocessed, "rb") as file:
|
||||
data = pickle.load(file)
|
||||
train, val = data["train"], data["val"]
|
||||
else:
|
||||
train = self.generate_fewshot_dataset(train, num_shots=num_shots)
|
||||
val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
|
||||
data = {"train": train, "val": val}
|
||||
print(f"Saving preprocessed few-shot data to {preprocessed}")
|
||||
with open(preprocessed, "wb") as file:
|
||||
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
|
||||
train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample)
|
||||
|
||||
super().__init__(train_x=train, val=val, test=test)
|
||||
|
||||
@staticmethod
|
||||
def read_and_split_data(image_dir, p_trn=0.5, p_val=0.2, ignored=[], new_cnames=None):
|
||||
# The data are supposed to be organized into the following structure
|
||||
# =============
|
||||
# images/
|
||||
# dog/
|
||||
# cat/
|
||||
# horse/
|
||||
# =============
|
||||
categories = listdir_nohidden(image_dir)
|
||||
categories = [c for c in categories if c not in ignored]
|
||||
categories.sort()
|
||||
|
||||
p_tst = 1 - p_trn - p_val
|
||||
print(f"Splitting into {p_trn:.0%} train, {p_val:.0%} val, and {p_tst:.0%} test")
|
||||
|
||||
def _collate(ims, y, c):
|
||||
items = []
|
||||
for im in ims:
|
||||
item = Datum(impath=im, label=y, classname=c) # is already 0-based
|
||||
items.append(item)
|
||||
return items
|
||||
|
||||
train, val, test = [], [], []
|
||||
for label, category in enumerate(categories):
|
||||
category_dir = os.path.join(image_dir, category)
|
||||
images = listdir_nohidden(category_dir)
|
||||
images = [os.path.join(category_dir, im) for im in images]
|
||||
random.shuffle(images)
|
||||
n_total = len(images)
|
||||
n_train = round(n_total * p_trn)
|
||||
n_val = round(n_total * p_val)
|
||||
n_test = n_total - n_train - n_val
|
||||
assert n_train > 0 and n_val > 0 and n_test > 0
|
||||
|
||||
if new_cnames is not None and category in new_cnames:
|
||||
category = new_cnames[category]
|
||||
|
||||
train.extend(_collate(images[:n_train], label, category))
|
||||
val.extend(_collate(images[n_train : n_train + n_val], label, category))
|
||||
test.extend(_collate(images[n_train + n_val :], label, category))
|
||||
|
||||
return train, val, test
|
||||
@@ -0,0 +1,73 @@
|
||||
import os
|
||||
import pickle
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import mkdir_if_missing
|
||||
|
||||
from .oxford_pets import OxfordPets
|
||||
from .dtd import DescribableTextures as DTD
|
||||
|
||||
NEW_CNAMES = {
|
||||
"AnnualCrop": "Annual Crop Land",
|
||||
"Forest": "Forest",
|
||||
"HerbaceousVegetation": "Herbaceous Vegetation Land",
|
||||
"Highway": "Highway or Road",
|
||||
"Industrial": "Industrial Buildings",
|
||||
"Pasture": "Pasture Land",
|
||||
"PermanentCrop": "Permanent Crop Land",
|
||||
"Residential": "Residential Buildings",
|
||||
"River": "River",
|
||||
"SeaLake": "Sea or Lake",
|
||||
}
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class EuroSAT(DatasetBase):
|
||||
|
||||
dataset_dir = "eurosat"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.image_dir = os.path.join(self.dataset_dir, "2750")
|
||||
self.split_path = os.path.join(self.dataset_dir, "split_zhou_EuroSAT.json")
|
||||
self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
|
||||
mkdir_if_missing(self.split_fewshot_dir)
|
||||
|
||||
if os.path.exists(self.split_path):
|
||||
train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
|
||||
else:
|
||||
train, val, test = DTD.read_and_split_data(self.image_dir, new_cnames=NEW_CNAMES)
|
||||
OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
|
||||
|
||||
num_shots = cfg.DATASET.NUM_SHOTS
|
||||
if num_shots >= 1:
|
||||
seed = cfg.SEED
|
||||
preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
|
||||
|
||||
if os.path.exists(preprocessed):
|
||||
print(f"Loading preprocessed few-shot data from {preprocessed}")
|
||||
with open(preprocessed, "rb") as file:
|
||||
data = pickle.load(file)
|
||||
train, val = data["train"], data["val"]
|
||||
else:
|
||||
train = self.generate_fewshot_dataset(train, num_shots=num_shots)
|
||||
val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
|
||||
data = {"train": train, "val": val}
|
||||
print(f"Saving preprocessed few-shot data to {preprocessed}")
|
||||
with open(preprocessed, "wb") as file:
|
||||
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
|
||||
train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample)
|
||||
|
||||
super().__init__(train_x=train, val=val, test=test)
|
||||
|
||||
def update_classname(self, dataset_old):
|
||||
dataset_new = []
|
||||
for item_old in dataset_old:
|
||||
cname_old = item_old.classname
|
||||
cname_new = NEW_CNAMES[cname_old]
|
||||
item_new = Datum(impath=item_old.impath, label=item_old.label, classname=cname_new)
|
||||
dataset_new.append(item_new)
|
||||
return dataset_new
|
||||
@@ -0,0 +1,71 @@
|
||||
import os
|
||||
import pickle
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import mkdir_if_missing
|
||||
|
||||
from .oxford_pets import OxfordPets
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class FGVCAircraft(DatasetBase):
|
||||
|
||||
dataset_dir = "fgvc_aircraft"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.image_dir = os.path.join(self.dataset_dir, "images")
|
||||
self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
|
||||
mkdir_if_missing(self.split_fewshot_dir)
|
||||
|
||||
classnames = []
|
||||
with open(os.path.join(self.dataset_dir, "variants.txt"), "r") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
classnames.append(line.strip())
|
||||
cname2lab = {c: i for i, c in enumerate(classnames)}
|
||||
|
||||
train = self.read_data(cname2lab, "images_variant_train.txt")
|
||||
val = self.read_data(cname2lab, "images_variant_val.txt")
|
||||
test = self.read_data(cname2lab, "images_variant_test.txt")
|
||||
|
||||
num_shots = cfg.DATASET.NUM_SHOTS
|
||||
if num_shots >= 1:
|
||||
seed = cfg.SEED
|
||||
preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
|
||||
|
||||
if os.path.exists(preprocessed):
|
||||
print(f"Loading preprocessed few-shot data from {preprocessed}")
|
||||
with open(preprocessed, "rb") as file:
|
||||
data = pickle.load(file)
|
||||
train, val = data["train"], data["val"]
|
||||
else:
|
||||
train = self.generate_fewshot_dataset(train, num_shots=num_shots)
|
||||
val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
|
||||
data = {"train": train, "val": val}
|
||||
print(f"Saving preprocessed few-shot data to {preprocessed}")
|
||||
with open(preprocessed, "wb") as file:
|
||||
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
|
||||
train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample)
|
||||
|
||||
super().__init__(train_x=train, val=val, test=test)
|
||||
|
||||
def read_data(self, cname2lab, split_file):
|
||||
filepath = os.path.join(self.dataset_dir, split_file)
|
||||
items = []
|
||||
|
||||
with open(filepath, "r") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
line = line.strip().split(" ")
|
||||
imname = line[0] + ".jpg"
|
||||
classname = " ".join(line[1:])
|
||||
impath = os.path.join(self.image_dir, imname)
|
||||
label = cname2lab[classname]
|
||||
item = Datum(impath=impath, label=label, classname=classname)
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
@@ -0,0 +1,51 @@
|
||||
import os
|
||||
import pickle
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import mkdir_if_missing
|
||||
|
||||
from .oxford_pets import OxfordPets
|
||||
from .dtd import DescribableTextures as DTD
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class Food101(DatasetBase):
|
||||
|
||||
dataset_dir = "food-101"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.image_dir = os.path.join(self.dataset_dir, "images")
|
||||
self.split_path = os.path.join(self.dataset_dir, "split_zhou_Food101.json")
|
||||
self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
|
||||
mkdir_if_missing(self.split_fewshot_dir)
|
||||
|
||||
if os.path.exists(self.split_path):
|
||||
train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
|
||||
else:
|
||||
train, val, test = DTD.read_and_split_data(self.image_dir)
|
||||
OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
|
||||
|
||||
num_shots = cfg.DATASET.NUM_SHOTS
|
||||
if num_shots >= 1:
|
||||
seed = cfg.SEED
|
||||
preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
|
||||
|
||||
if os.path.exists(preprocessed):
|
||||
print(f"Loading preprocessed few-shot data from {preprocessed}")
|
||||
with open(preprocessed, "rb") as file:
|
||||
data = pickle.load(file)
|
||||
train, val = data["train"], data["val"]
|
||||
else:
|
||||
train = self.generate_fewshot_dataset(train, num_shots=num_shots)
|
||||
val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
|
||||
data = {"train": train, "val": val}
|
||||
print(f"Saving preprocessed few-shot data to {preprocessed}")
|
||||
with open(preprocessed, "wb") as file:
|
||||
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
|
||||
train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample)
|
||||
|
||||
super().__init__(train_x=train, val=val, test=test)
|
||||
@@ -0,0 +1,92 @@
|
||||
import os
|
||||
import pickle
|
||||
from collections import OrderedDict
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import listdir_nohidden, mkdir_if_missing
|
||||
|
||||
from .oxford_pets import OxfordPets
|
||||
from random import sample
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class ImageNet(DatasetBase):
|
||||
|
||||
dataset_dir = "imagenet"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.image_dir = os.path.join(self.dataset_dir, "images")
|
||||
self.preprocessed = os.path.join(self.dataset_dir, "preprocessed.pkl")
|
||||
self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
|
||||
mkdir_if_missing(self.split_fewshot_dir)
|
||||
|
||||
if os.path.exists(self.preprocessed):
|
||||
with open(self.preprocessed, "rb") as f:
|
||||
preprocessed = pickle.load(f)
|
||||
train = preprocessed["train"]
|
||||
test = preprocessed["test"]
|
||||
else:
|
||||
text_file = os.path.join(self.dataset_dir, "classnames.txt")
|
||||
classnames = self.read_classnames(text_file)
|
||||
train = self.read_data(classnames, "train")
|
||||
# Follow standard practice to perform evaluation on the val set
|
||||
# Also used as the val set (so evaluate the last-step model)
|
||||
test = self.read_data(classnames, "val")
|
||||
|
||||
preprocessed = {"train": train, "test": test}
|
||||
with open(self.preprocessed, "wb") as f:
|
||||
pickle.dump(preprocessed, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
num_shots = cfg.DATASET.NUM_SHOTS
|
||||
if num_shots >= 1000:
|
||||
seed = cfg.SEED
|
||||
preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
|
||||
|
||||
if os.path.exists(preprocessed):
|
||||
print(f"Loading preprocessed few-shot data from {preprocessed}")
|
||||
with open(preprocessed, "rb") as file:
|
||||
data = pickle.load(file)
|
||||
train = data["train"]
|
||||
else:
|
||||
train = self.generate_fewshot_dataset(train, num_shots=num_shots)
|
||||
data = {"train": train}
|
||||
print(f"Saving preprocessed few-shot data to {preprocessed}")
|
||||
with open(preprocessed, "wb") as file:
|
||||
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
|
||||
train, test = OxfordPets.subsample_classes(train, test, subsample=subsample)
|
||||
|
||||
|
||||
super().__init__(train_x=sample(train,int(len(train)*0.8)), val=sample(test,5000), test=test)
|
||||
|
||||
@staticmethod
|
||||
def read_classnames(text_file):
|
||||
"""Return a dictionary containing
|
||||
key-value pairs of <folder name>: <class name>.
|
||||
"""
|
||||
classnames = OrderedDict()
|
||||
with open(text_file, "r") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
line = line.strip().split(" ")
|
||||
folder = line[0]
|
||||
classname = " ".join(line[1:])
|
||||
classnames[folder] = classname
|
||||
return classnames
|
||||
|
||||
def read_data(self, classnames, split_dir):
|
||||
split_dir = os.path.join(self.image_dir, split_dir)
|
||||
folders = sorted(f.name for f in os.scandir(split_dir) if f.is_dir())
|
||||
items = []
|
||||
|
||||
for label, folder in enumerate(folders): ##sub evaluation
|
||||
imnames = listdir_nohidden(os.path.join(split_dir, folder))
|
||||
classname = classnames[folder]
|
||||
for imname in imnames:
|
||||
impath = os.path.join(split_dir, folder, imname)
|
||||
item = Datum(impath=impath, label=label, classname=classname)
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
@@ -0,0 +1,46 @@
|
||||
import os
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import listdir_nohidden
|
||||
|
||||
from .imagenet import ImageNet
|
||||
|
||||
TO_BE_IGNORED = ["README.txt"]
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class ImageNetA(DatasetBase):
|
||||
"""ImageNet-A(dversarial).
|
||||
|
||||
This dataset is used for testing only.
|
||||
"""
|
||||
|
||||
dataset_dir = "imagenet-adversarial"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.image_dir = os.path.join(self.dataset_dir, "imagenet-a")
|
||||
|
||||
text_file = os.path.join(self.dataset_dir, "classnames.txt")
|
||||
classnames = ImageNet.read_classnames(text_file)
|
||||
|
||||
data = self.read_data(classnames)
|
||||
|
||||
super().__init__(train_x=data, test=data)
|
||||
|
||||
def read_data(self, classnames):
|
||||
image_dir = self.image_dir
|
||||
folders = listdir_nohidden(image_dir, sort=True)
|
||||
folders = [f for f in folders if f not in TO_BE_IGNORED]
|
||||
items = []
|
||||
|
||||
for label, folder in enumerate(folders):
|
||||
imnames = listdir_nohidden(os.path.join(image_dir, folder))
|
||||
classname = classnames[folder]
|
||||
for imname in imnames:
|
||||
impath = os.path.join(image_dir, folder, imname)
|
||||
item = Datum(impath=impath, label=label, classname=classname)
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
@@ -0,0 +1,46 @@
|
||||
import os
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import listdir_nohidden
|
||||
|
||||
from .imagenet import ImageNet
|
||||
|
||||
TO_BE_IGNORED = ["README.txt"]
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class ImageNetR(DatasetBase):
|
||||
"""ImageNet-R(endition).
|
||||
|
||||
This dataset is used for testing only.
|
||||
"""
|
||||
|
||||
dataset_dir = "imagenet-rendition"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.image_dir = os.path.join(self.dataset_dir, "imagenet-r")
|
||||
|
||||
text_file = os.path.join(self.dataset_dir, "classnames.txt")
|
||||
classnames = ImageNet.read_classnames(text_file)
|
||||
|
||||
data = self.read_data(classnames)
|
||||
|
||||
super().__init__(train_x=data, test=data)
|
||||
|
||||
def read_data(self, classnames):
|
||||
image_dir = self.image_dir
|
||||
folders = listdir_nohidden(image_dir, sort=True)
|
||||
folders = [f for f in folders if f not in TO_BE_IGNORED]
|
||||
items = []
|
||||
|
||||
for label, folder in enumerate(folders):
|
||||
imnames = listdir_nohidden(os.path.join(image_dir, folder))
|
||||
classname = classnames[folder]
|
||||
for imname in imnames:
|
||||
impath = os.path.join(image_dir, folder, imname)
|
||||
item = Datum(impath=impath, label=label, classname=classname)
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
@@ -0,0 +1,43 @@
|
||||
import os
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import listdir_nohidden
|
||||
|
||||
from .imagenet import ImageNet
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class ImageNetSketch(DatasetBase):
|
||||
"""ImageNet-Sketch.
|
||||
|
||||
This dataset is used for testing only.
|
||||
"""
|
||||
|
||||
dataset_dir = "imagenet-sketch"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.image_dir = os.path.join(self.dataset_dir, "images")
|
||||
|
||||
text_file = os.path.join(self.dataset_dir, "classnames.txt")
|
||||
classnames = ImageNet.read_classnames(text_file)
|
||||
|
||||
data = self.read_data(classnames)
|
||||
|
||||
super().__init__(train_x=data, test=data)
|
||||
|
||||
def read_data(self, classnames):
|
||||
image_dir = self.image_dir
|
||||
folders = listdir_nohidden(image_dir, sort=True)
|
||||
items = []
|
||||
|
||||
for label, folder in enumerate(folders):
|
||||
imnames = listdir_nohidden(os.path.join(image_dir, folder))
|
||||
classname = classnames[folder]
|
||||
for imname in imnames:
|
||||
impath = os.path.join(image_dir, folder, imname)
|
||||
item = Datum(impath=impath, label=label, classname=classname)
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
@@ -0,0 +1,46 @@
|
||||
import os
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import listdir_nohidden
|
||||
|
||||
from .imagenet import ImageNet
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class ImageNetV2(DatasetBase):
|
||||
"""ImageNetV2.
|
||||
|
||||
This dataset is used for testing only.
|
||||
"""
|
||||
|
||||
dataset_dir = "imagenetv2"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
image_dir = "imagenetv2-matched-frequency-format-val"
|
||||
self.image_dir = os.path.join(self.dataset_dir, image_dir)
|
||||
|
||||
text_file = os.path.join(self.dataset_dir, "classnames.txt")
|
||||
classnames = ImageNet.read_classnames(text_file)
|
||||
|
||||
data = self.read_data(classnames)
|
||||
|
||||
super().__init__(train_x=data, test=data)
|
||||
|
||||
def read_data(self, classnames):
|
||||
image_dir = self.image_dir
|
||||
folders = list(classnames.keys())
|
||||
items = []
|
||||
|
||||
for label in range(1000):
|
||||
class_dir = os.path.join(image_dir, str(label))
|
||||
imnames = listdir_nohidden(class_dir)
|
||||
folder = folders[label]
|
||||
classname = classnames[folder]
|
||||
for imname in imnames:
|
||||
impath = os.path.join(class_dir, imname)
|
||||
item = Datum(impath=impath, label=label, classname=classname)
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
@@ -0,0 +1,567 @@
|
||||
import torch
|
||||
from torchvision.transforms import RandomResizedCrop,InterpolationMode
|
||||
from torchvision.transforms import functional as F
|
||||
import numpy as np
|
||||
import random
|
||||
import torch
|
||||
import torchvision.transforms.functional as F
|
||||
from torchvision.transforms import (
|
||||
Resize, Compose, ToTensor, Normalize, CenterCrop, RandomCrop, ColorJitter,
|
||||
RandomApply, GaussianBlur, RandomGrayscale, RandomResizedCrop,
|
||||
RandomHorizontalFlip
|
||||
)
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
from dassl.data.transforms.transforms import SVHNPolicy, CIFAR10Policy, ImageNetPolicy
|
||||
from dassl.data.transforms.transforms import RandAugment, RandAugment2, RandAugmentFixMatch
|
||||
from PIL import Image, ImageFilter
|
||||
|
||||
class RandomResizedCropPair(RandomResizedCrop):
|
||||
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=InterpolationMode.BILINEAR):
|
||||
super(RandomResizedCropPair, self).__init__(size, scale, ratio, interpolation)
|
||||
|
||||
def __call__(self, img,mask):
|
||||
i,j,h,w = self.get_params(img,self.scale,self.ratio)
|
||||
return F.resized_crop(img,i,j,h,w,self.size,self.interpolation),F.resized_crop(mask,i,j,h,w,self.size,self.interpolation)
|
||||
|
||||
|
||||
class ComposePair:
|
||||
def __init__(self, transforms):
|
||||
self.transforms = transforms
|
||||
|
||||
def __call__(self, img,mask):
|
||||
|
||||
for t in self.transforms:
|
||||
if isinstance(t,Normalize):
|
||||
img = t(img)
|
||||
elif isinstance(t,ToTensor):
|
||||
img = t(img)
|
||||
mask = torch.from_numpy(np.array(mask,dtype=np.float16)).permute(2,0,1)[:1]
|
||||
|
||||
|
||||
###design the mask split
|
||||
mask[mask==255] = 0
|
||||
mask[mask > 1] = 1
|
||||
else:
|
||||
img,mask = t(img,mask)
|
||||
|
||||
return img,mask
|
||||
|
||||
def __repr__(self):
|
||||
format_string = self.__class__.__name__ + '('
|
||||
for t in self.transforms:
|
||||
format_string += '\n'
|
||||
format_string += ' {0}'.format(t)
|
||||
format_string += '\n)'
|
||||
return format_string
|
||||
|
||||
class RandomHorizontalFlipPair(RandomHorizontalFlip):
|
||||
def __init__(self, p=0.5):
|
||||
super().__init__(p)
|
||||
|
||||
def __call__(self, img, mask):
|
||||
if torch.rand(1) < self.p:
|
||||
return F.hflip(img),F.hflip(mask)
|
||||
return img,mask
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
AVAI_CHOICES = [
|
||||
"random_flip",
|
||||
"random_resized_crop",
|
||||
"normalize",
|
||||
"instance_norm",
|
||||
"random_crop",
|
||||
"random_translation",
|
||||
"center_crop", # This has become a default operation during testing
|
||||
"cutout",
|
||||
"imagenet_policy",
|
||||
"cifar10_policy",
|
||||
"svhn_policy",
|
||||
"randaugment",
|
||||
"randaugment_fixmatch",
|
||||
"randaugment2",
|
||||
"gaussian_noise",
|
||||
"colorjitter",
|
||||
"randomgrayscale",
|
||||
"gaussian_blur",
|
||||
|
||||
"random_flip_pair",
|
||||
"random_resized_crop_pair",
|
||||
]
|
||||
|
||||
INTERPOLATION_MODES = {
|
||||
"bilinear": InterpolationMode.BILINEAR,
|
||||
"bicubic": InterpolationMode.BICUBIC,
|
||||
"nearest": InterpolationMode.NEAREST,
|
||||
}
|
||||
|
||||
|
||||
class Random2DTranslation:
|
||||
"""Given an image of (height, width), we resize it to
|
||||
(height*1.125, width*1.125), and then perform random cropping.
|
||||
|
||||
Args:
|
||||
height (int): target image height.
|
||||
width (int): target image width.
|
||||
p (float, optional): probability that this operation takes place.
|
||||
Default is 0.5.
|
||||
interpolation (int, optional): desired interpolation. Default is
|
||||
``torchvision.transforms.functional.InterpolationMode.BILINEAR``
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, height, width, p=0.5, interpolation=InterpolationMode.BILINEAR
|
||||
):
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.p = p
|
||||
self.interpolation = interpolation
|
||||
|
||||
def __call__(self, img):
|
||||
if random.uniform(0, 1) > self.p:
|
||||
return F.resize(
|
||||
img=img,
|
||||
size=[self.height, self.width],
|
||||
interpolation=self.interpolation
|
||||
)
|
||||
|
||||
new_width = int(round(self.width * 1.125))
|
||||
new_height = int(round(self.height * 1.125))
|
||||
resized_img = F.resize(
|
||||
img=img,
|
||||
size=[new_height, new_width],
|
||||
interpolation=self.interpolation
|
||||
)
|
||||
x_maxrange = new_width - self.width
|
||||
y_maxrange = new_height - self.height
|
||||
x1 = int(round(random.uniform(0, x_maxrange)))
|
||||
y1 = int(round(random.uniform(0, y_maxrange)))
|
||||
croped_img = F.crop(
|
||||
img=resized_img,
|
||||
top=y1,
|
||||
left=x1,
|
||||
height=self.height,
|
||||
width=self.width
|
||||
)
|
||||
|
||||
return croped_img
|
||||
|
||||
|
||||
class InstanceNormalization:
|
||||
"""Normalize data using per-channel mean and standard deviation.
|
||||
|
||||
Reference:
|
||||
- Ulyanov et al. Instance normalization: The missing in- gredient
|
||||
for fast stylization. ArXiv 2016.
|
||||
- Shu et al. A DIRT-T Approach to Unsupervised Domain Adaptation.
|
||||
ICLR 2018.
|
||||
"""
|
||||
|
||||
def __init__(self, eps=1e-8):
|
||||
self.eps = eps
|
||||
|
||||
def __call__(self, img):
|
||||
C, H, W = img.shape
|
||||
img_re = img.reshape(C, H * W)
|
||||
mean = img_re.mean(1).view(C, 1, 1)
|
||||
std = img_re.std(1).view(C, 1, 1)
|
||||
return (img-mean) / (std + self.eps)
|
||||
|
||||
|
||||
class Cutout:
|
||||
"""Randomly mask out one or more patches from an image.
|
||||
|
||||
https://github.com/uoguelph-mlrg/Cutout
|
||||
|
||||
Args:
|
||||
n_holes (int, optional): number of patches to cut out
|
||||
of each image. Default is 1.
|
||||
length (int, optinal): length (in pixels) of each square
|
||||
patch. Default is 16.
|
||||
"""
|
||||
|
||||
def __init__(self, n_holes=1, length=16):
|
||||
self.n_holes = n_holes
|
||||
self.length = length
|
||||
|
||||
def __call__(self, img):
|
||||
"""
|
||||
Args:
|
||||
img (Tensor): tensor image of size (C, H, W).
|
||||
|
||||
Returns:
|
||||
Tensor: image with n_holes of dimension
|
||||
length x length cut out of it.
|
||||
"""
|
||||
h = img.size(1)
|
||||
w = img.size(2)
|
||||
|
||||
mask = np.ones((h, w), np.float32)
|
||||
|
||||
for n in range(self.n_holes):
|
||||
y = np.random.randint(h)
|
||||
x = np.random.randint(w)
|
||||
|
||||
y1 = np.clip(y - self.length // 2, 0, h)
|
||||
y2 = np.clip(y + self.length // 2, 0, h)
|
||||
x1 = np.clip(x - self.length // 2, 0, w)
|
||||
x2 = np.clip(x + self.length // 2, 0, w)
|
||||
|
||||
mask[y1:y2, x1:x2] = 0.0
|
||||
|
||||
mask = torch.from_numpy(mask)
|
||||
mask = mask.expand_as(img)
|
||||
return img * mask
|
||||
|
||||
|
||||
class GaussianNoise:
|
||||
"""Add gaussian noise."""
|
||||
|
||||
def __init__(self, mean=0, std=0.15, p=0.5):
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
self.p = p
|
||||
|
||||
def __call__(self, img):
|
||||
if random.uniform(0, 1) > self.p:
|
||||
return img
|
||||
noise = torch.randn(img.size()) * self.std + self.mean
|
||||
return img + noise
|
||||
|
||||
|
||||
def build_transform(cfg, is_train=True, choices=None):
|
||||
"""Build transformation function.
|
||||
|
||||
Args:
|
||||
cfg (CfgNode): config.
|
||||
is_train (bool, optional): for training (True) or test (False).
|
||||
Default is True.
|
||||
choices (list, optional): list of strings which will overwrite
|
||||
cfg.INPUT.TRANSFORMS if given. Default is None.
|
||||
"""
|
||||
if cfg.INPUT.NO_TRANSFORM:
|
||||
print("Note: no transform is applied!")
|
||||
return None
|
||||
|
||||
if choices is None:
|
||||
choices = cfg.INPUT.TRANSFORMS
|
||||
|
||||
for choice in choices:
|
||||
assert choice in AVAI_CHOICES
|
||||
|
||||
target_size = f"{cfg.INPUT.SIZE[0]}x{cfg.INPUT.SIZE[1]}"
|
||||
|
||||
normalize = Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
|
||||
|
||||
if is_train:
|
||||
return _build_transform_train(cfg, choices, target_size, normalize)
|
||||
else:
|
||||
return _build_transform_test(cfg, choices, target_size, normalize)
|
||||
|
||||
|
||||
def build_transform_pair(cfg, is_train=True, choices=None):
|
||||
"""Build transformation function.
|
||||
|
||||
Args:
|
||||
cfg (CfgNode): config.
|
||||
is_train (bool, optional): for training (True) or test (False).
|
||||
Default is True.
|
||||
choices (list, optional): list of strings which will overwrite
|
||||
cfg.INPUT.TRANSFORMS if given. Default is None.
|
||||
"""
|
||||
if cfg.INPUT.NO_TRANSFORM:
|
||||
print("Note: no transform is applied!")
|
||||
return None
|
||||
|
||||
if choices is None:
|
||||
choices = cfg.INPUT.TRANSFORMS
|
||||
|
||||
for choice in choices:
|
||||
assert choice in AVAI_CHOICES
|
||||
|
||||
target_size = f"{cfg.INPUT.SIZE[0]}x{cfg.INPUT.SIZE[1]}"
|
||||
|
||||
normalize = Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
|
||||
|
||||
if is_train:
|
||||
return _build_transform_train_pair(cfg, choices, target_size, normalize)
|
||||
else:
|
||||
return _build_transform_test(cfg, choices, target_size, normalize)
|
||||
|
||||
def _build_transform_train_pair(cfg, choices, target_size, normalize):
|
||||
print("Building transform_train_pair")
|
||||
tfm_train = []
|
||||
|
||||
interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]
|
||||
input_size = cfg.INPUT.SIZE
|
||||
|
||||
# Make sure the image size matches the target size
|
||||
conditions = []
|
||||
conditions += ["random_crop" not in choices]
|
||||
conditions += ["random_resized_crop" not in choices]
|
||||
if all(conditions):
|
||||
print(f"+ resize to {target_size}")
|
||||
tfm_train += [Resize(input_size, interpolation=interp_mode)]
|
||||
|
||||
# if "random_translation" in choices:
|
||||
# print("+ random translation")
|
||||
# tfm_train += [Random2DTranslation(input_size[0], input_size[1])]
|
||||
#
|
||||
# if "random_crop" in choices:
|
||||
# crop_padding = cfg.INPUT.CROP_PADDING
|
||||
# print(f"+ random crop (padding = {crop_padding})")
|
||||
# tfm_train += [RandomCrop(input_size, padding=crop_padding)]
|
||||
|
||||
if "random_resized_crop" in choices:
|
||||
s_ = cfg.INPUT.RRCROP_SCALE
|
||||
print(f"+ random resized crop pair (size={input_size}, scale={s_})")
|
||||
tfm_train += [
|
||||
RandomResizedCropPair(input_size, scale=s_, interpolation=interp_mode)
|
||||
]
|
||||
|
||||
if "random_flip" in choices:
|
||||
print("+ random flip pair")
|
||||
tfm_train += [RandomHorizontalFlipPair()]
|
||||
|
||||
if "imagenet_policy" in choices:
|
||||
print("+ imagenet policy")
|
||||
tfm_train += [ImageNetPolicy()]
|
||||
|
||||
if "cifar10_policy" in choices:
|
||||
print("+ cifar10 policy")
|
||||
tfm_train += [CIFAR10Policy()]
|
||||
|
||||
if "svhn_policy" in choices:
|
||||
print("+ svhn policy")
|
||||
tfm_train += [SVHNPolicy()]
|
||||
|
||||
if "randaugment" in choices:
|
||||
n_ = cfg.INPUT.RANDAUGMENT_N
|
||||
m_ = cfg.INPUT.RANDAUGMENT_M
|
||||
print(f"+ randaugment (n={n_}, m={m_})")
|
||||
tfm_train += [RandAugment(n_, m_)]
|
||||
|
||||
if "randaugment_fixmatch" in choices:
|
||||
n_ = cfg.INPUT.RANDAUGMENT_N
|
||||
print(f"+ randaugment_fixmatch (n={n_})")
|
||||
tfm_train += [RandAugmentFixMatch(n_)]
|
||||
|
||||
if "randaugment2" in choices:
|
||||
n_ = cfg.INPUT.RANDAUGMENT_N
|
||||
print(f"+ randaugment2 (n={n_})")
|
||||
tfm_train += [RandAugment2(n_)]
|
||||
|
||||
if "colorjitter" in choices:
|
||||
b_ = cfg.INPUT.COLORJITTER_B
|
||||
c_ = cfg.INPUT.COLORJITTER_C
|
||||
s_ = cfg.INPUT.COLORJITTER_S
|
||||
h_ = cfg.INPUT.COLORJITTER_H
|
||||
print(
|
||||
f"+ color jitter (brightness={b_}, "
|
||||
f"contrast={c_}, saturation={s_}, hue={h_})"
|
||||
)
|
||||
tfm_train += [
|
||||
ColorJitter(
|
||||
brightness=b_,
|
||||
contrast=c_,
|
||||
saturation=s_,
|
||||
hue=h_,
|
||||
)
|
||||
]
|
||||
|
||||
if "randomgrayscale" in choices:
|
||||
print("+ random gray scale")
|
||||
tfm_train += [RandomGrayscale(p=cfg.INPUT.RGS_P)]
|
||||
|
||||
if "gaussian_blur" in choices:
|
||||
print(f"+ gaussian blur (kernel={cfg.INPUT.GB_K})")
|
||||
gb_k, gb_p = cfg.INPUT.GB_K, cfg.INPUT.GB_P
|
||||
tfm_train += [RandomApply([GaussianBlur(gb_k)], p=gb_p)]
|
||||
|
||||
print("+ to torch tensor of range [0, 1]")
|
||||
tfm_train += [ToTensor()]
|
||||
|
||||
if "cutout" in choices:
|
||||
cutout_n = cfg.INPUT.CUTOUT_N
|
||||
cutout_len = cfg.INPUT.CUTOUT_LEN
|
||||
print(f"+ cutout (n_holes={cutout_n}, length={cutout_len})")
|
||||
tfm_train += [Cutout(cutout_n, cutout_len)]
|
||||
|
||||
if "normalize" in choices:
|
||||
print(
|
||||
f"+ normalization (mean={cfg.INPUT.PIXEL_MEAN}, std={cfg.INPUT.PIXEL_STD})"
|
||||
)
|
||||
tfm_train += [normalize]
|
||||
|
||||
if "gaussian_noise" in choices:
|
||||
print(
|
||||
f"+ gaussian noise (mean={cfg.INPUT.GN_MEAN}, std={cfg.INPUT.GN_STD})"
|
||||
)
|
||||
tfm_train += [GaussianNoise(cfg.INPUT.GN_MEAN, cfg.INPUT.GN_STD)]
|
||||
|
||||
if "instance_norm" in choices:
|
||||
print("+ instance normalization")
|
||||
tfm_train += [InstanceNormalization()]
|
||||
|
||||
tfm_train = ComposePair(tfm_train)
|
||||
|
||||
|
||||
return tfm_train
|
||||
|
||||
|
||||
def _build_transform_train(cfg, choices, target_size, normalize):
|
||||
print("Building transform_train")
|
||||
tfm_train = []
|
||||
|
||||
interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]
|
||||
input_size = cfg.INPUT.SIZE
|
||||
|
||||
# Make sure the image size matches the target size
|
||||
conditions = []
|
||||
conditions += ["random_crop" not in choices]
|
||||
conditions += ["random_resized_crop" not in choices]
|
||||
if all(conditions):
|
||||
print(f"+ resize to {target_size}")
|
||||
tfm_train += [Resize(input_size, interpolation=interp_mode)]
|
||||
|
||||
if "random_translation" in choices:
|
||||
print("+ random translation")
|
||||
tfm_train += [Random2DTranslation(input_size[0], input_size[1])]
|
||||
|
||||
if "random_crop" in choices:
|
||||
crop_padding = cfg.INPUT.CROP_PADDING
|
||||
print(f"+ random crop (padding = {crop_padding})")
|
||||
tfm_train += [RandomCrop(input_size, padding=crop_padding)]
|
||||
|
||||
if "random_resized_crop" in choices:
|
||||
s_ = cfg.INPUT.RRCROP_SCALE
|
||||
print(f"+ random resized crop (size={input_size}, scale={s_})")
|
||||
tfm_train += [
|
||||
RandomResizedCrop(input_size, scale=s_, interpolation=interp_mode)
|
||||
]
|
||||
|
||||
if "random_flip" in choices:
|
||||
print("+ random flip")
|
||||
tfm_train += [RandomHorizontalFlip()]
|
||||
|
||||
if "imagenet_policy" in choices:
|
||||
print("+ imagenet policy")
|
||||
tfm_train += [ImageNetPolicy()]
|
||||
|
||||
if "cifar10_policy" in choices:
|
||||
print("+ cifar10 policy")
|
||||
tfm_train += [CIFAR10Policy()]
|
||||
|
||||
if "svhn_policy" in choices:
|
||||
print("+ svhn policy")
|
||||
tfm_train += [SVHNPolicy()]
|
||||
|
||||
if "randaugment" in choices:
|
||||
n_ = cfg.INPUT.RANDAUGMENT_N
|
||||
m_ = cfg.INPUT.RANDAUGMENT_M
|
||||
print(f"+ randaugment (n={n_}, m={m_})")
|
||||
tfm_train += [RandAugment(n_, m_)]
|
||||
|
||||
if "randaugment_fixmatch" in choices:
|
||||
n_ = cfg.INPUT.RANDAUGMENT_N
|
||||
print(f"+ randaugment_fixmatch (n={n_})")
|
||||
tfm_train += [RandAugmentFixMatch(n_)]
|
||||
|
||||
if "randaugment2" in choices:
|
||||
n_ = cfg.INPUT.RANDAUGMENT_N
|
||||
print(f"+ randaugment2 (n={n_})")
|
||||
tfm_train += [RandAugment2(n_)]
|
||||
|
||||
if "colorjitter" in choices:
|
||||
b_ = cfg.INPUT.COLORJITTER_B
|
||||
c_ = cfg.INPUT.COLORJITTER_C
|
||||
s_ = cfg.INPUT.COLORJITTER_S
|
||||
h_ = cfg.INPUT.COLORJITTER_H
|
||||
print(
|
||||
f"+ color jitter (brightness={b_}, "
|
||||
f"contrast={c_}, saturation={s_}, hue={h_})"
|
||||
)
|
||||
tfm_train += [
|
||||
ColorJitter(
|
||||
brightness=b_,
|
||||
contrast=c_,
|
||||
saturation=s_,
|
||||
hue=h_,
|
||||
)
|
||||
]
|
||||
|
||||
if "randomgrayscale" in choices:
|
||||
print("+ random gray scale")
|
||||
tfm_train += [RandomGrayscale(p=cfg.INPUT.RGS_P)]
|
||||
|
||||
if "gaussian_blur" in choices:
|
||||
print(f"+ gaussian blur (kernel={cfg.INPUT.GB_K})")
|
||||
gb_k, gb_p = cfg.INPUT.GB_K, cfg.INPUT.GB_P
|
||||
tfm_train += [RandomApply([GaussianBlur(gb_k)], p=gb_p)]
|
||||
|
||||
print("+ to torch tensor of range [0, 1]")
|
||||
tfm_train += [ToTensor()]
|
||||
|
||||
if "cutout" in choices:
|
||||
cutout_n = cfg.INPUT.CUTOUT_N
|
||||
cutout_len = cfg.INPUT.CUTOUT_LEN
|
||||
print(f"+ cutout (n_holes={cutout_n}, length={cutout_len})")
|
||||
tfm_train += [Cutout(cutout_n, cutout_len)]
|
||||
|
||||
if "normalize" in choices:
|
||||
print(
|
||||
f"+ normalization (mean={cfg.INPUT.PIXEL_MEAN}, std={cfg.INPUT.PIXEL_STD})"
|
||||
)
|
||||
tfm_train += [normalize]
|
||||
|
||||
if "gaussian_noise" in choices:
|
||||
print(
|
||||
f"+ gaussian noise (mean={cfg.INPUT.GN_MEAN}, std={cfg.INPUT.GN_STD})"
|
||||
)
|
||||
tfm_train += [GaussianNoise(cfg.INPUT.GN_MEAN, cfg.INPUT.GN_STD)]
|
||||
|
||||
if "instance_norm" in choices:
|
||||
print("+ instance normalization")
|
||||
tfm_train += [InstanceNormalization()]
|
||||
|
||||
tfm_train = Compose(tfm_train)
|
||||
|
||||
return tfm_train
|
||||
|
||||
|
||||
def _build_transform_test(cfg, choices, target_size, normalize):
|
||||
print("Building transform_test")
|
||||
tfm_test = []
|
||||
|
||||
interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]
|
||||
input_size = cfg.INPUT.SIZE
|
||||
|
||||
print(f"+ resize the smaller edge to {max(input_size)}")
|
||||
tfm_test += [Resize(max(input_size), interpolation=interp_mode)]
|
||||
|
||||
print(f"+ {target_size} center crop")
|
||||
tfm_test += [CenterCrop(input_size)]
|
||||
|
||||
print("+ to torch tensor of range [0, 1]")
|
||||
tfm_test += [ToTensor()]
|
||||
|
||||
if "normalize" in choices:
|
||||
print(
|
||||
f"+ normalization (mean={cfg.INPUT.PIXEL_MEAN}, std={cfg.INPUT.PIXEL_STD})"
|
||||
)
|
||||
tfm_test += [normalize]
|
||||
|
||||
if "instance_norm" in choices:
|
||||
print("+ instance normalization")
|
||||
tfm_test += [InstanceNormalization()]
|
||||
|
||||
tfm_test = Compose(tfm_test)
|
||||
|
||||
return tfm_test
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
from scipy.io import loadmat
|
||||
from collections import defaultdict
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import read_json, mkdir_if_missing
|
||||
|
||||
from .oxford_pets import OxfordPets
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class OxfordFlowers(DatasetBase):
|
||||
|
||||
dataset_dir = "oxford_flowers"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.image_dir = os.path.join(self.dataset_dir, "jpg")
|
||||
self.label_file = os.path.join(self.dataset_dir, "imagelabels.mat")
|
||||
self.lab2cname_file = os.path.join(self.dataset_dir, "cat_to_name.json")
|
||||
self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordFlowers.json")
|
||||
self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
|
||||
mkdir_if_missing(self.split_fewshot_dir)
|
||||
|
||||
if os.path.exists(self.split_path):
|
||||
train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
|
||||
else:
|
||||
train, val, test = self.read_data()
|
||||
OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
|
||||
|
||||
num_shots = cfg.DATASET.NUM_SHOTS
|
||||
if num_shots >= 1:
|
||||
seed = cfg.SEED
|
||||
preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
|
||||
|
||||
if os.path.exists(preprocessed):
|
||||
print(f"Loading preprocessed few-shot data from {preprocessed}")
|
||||
with open(preprocessed, "rb") as file:
|
||||
data = pickle.load(file)
|
||||
train, val = data["train"], data["val"]
|
||||
else:
|
||||
train = self.generate_fewshot_dataset(train, num_shots=num_shots)
|
||||
val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
|
||||
data = {"train": train, "val": val}
|
||||
print(f"Saving preprocessed few-shot data to {preprocessed}")
|
||||
with open(preprocessed, "wb") as file:
|
||||
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
|
||||
train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample)
|
||||
|
||||
super().__init__(train_x=train, val=val, test=test)
|
||||
|
||||
def read_data(self):
|
||||
tracker = defaultdict(list)
|
||||
label_file = loadmat(self.label_file)["labels"][0]
|
||||
for i, label in enumerate(label_file):
|
||||
imname = f"image_{str(i + 1).zfill(5)}.jpg"
|
||||
impath = os.path.join(self.image_dir, imname)
|
||||
label = int(label)
|
||||
tracker[label].append(impath)
|
||||
|
||||
print("Splitting data into 50% train, 20% val, and 30% test")
|
||||
|
||||
def _collate(ims, y, c):
|
||||
items = []
|
||||
for im in ims:
|
||||
item = Datum(impath=im, label=y - 1, classname=c) # convert to 0-based label
|
||||
items.append(item)
|
||||
return items
|
||||
|
||||
lab2cname = read_json(self.lab2cname_file)
|
||||
train, val, test = [], [], []
|
||||
for label, impaths in tracker.items():
|
||||
random.shuffle(impaths)
|
||||
n_total = len(impaths)
|
||||
n_train = round(n_total * 0.5)
|
||||
n_val = round(n_total * 0.2)
|
||||
n_test = n_total - n_train - n_val
|
||||
assert n_train > 0 and n_val > 0 and n_test > 0
|
||||
cname = lab2cname[str(label)]
|
||||
train.extend(_collate(impaths[:n_train], label, cname))
|
||||
val.extend(_collate(impaths[n_train : n_train + n_val], label, cname))
|
||||
test.extend(_collate(impaths[n_train + n_val :], label, cname))
|
||||
|
||||
return train, val, test
|
||||
@@ -0,0 +1,186 @@
|
||||
import os
|
||||
import pickle
|
||||
import math
|
||||
import random
|
||||
from collections import defaultdict
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import read_json, write_json, mkdir_if_missing
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class OxfordPets(DatasetBase):
|
||||
|
||||
dataset_dir = "oxford_pets"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.image_dir = os.path.join(self.dataset_dir, "images")
|
||||
self.anno_dir = os.path.join(self.dataset_dir, "annotations")
|
||||
self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordPets.json")
|
||||
self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
|
||||
mkdir_if_missing(self.split_fewshot_dir)
|
||||
|
||||
if os.path.exists(self.split_path):
|
||||
train, val, test = self.read_split(self.split_path, self.image_dir)
|
||||
else:
|
||||
trainval = self.read_data(split_file="trainval.txt")
|
||||
test = self.read_data(split_file="test.txt")
|
||||
train, val = self.split_trainval(trainval)
|
||||
self.save_split(train, val, test, self.split_path, self.image_dir)
|
||||
|
||||
num_shots = cfg.DATASET.NUM_SHOTS
|
||||
if num_shots >= 1:
|
||||
seed = cfg.SEED
|
||||
preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
|
||||
|
||||
if os.path.exists(preprocessed):
|
||||
print(f"Loading preprocessed few-shot data from {preprocessed}")
|
||||
with open(preprocessed, "rb") as file:
|
||||
data = pickle.load(file)
|
||||
train, val = data["train"], data["val"]
|
||||
else:
|
||||
train = self.generate_fewshot_dataset(train, num_shots=num_shots)
|
||||
val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
|
||||
data = {"train": train, "val": val}
|
||||
print(f"Saving preprocessed few-shot data to {preprocessed}")
|
||||
with open(preprocessed, "wb") as file:
|
||||
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
|
||||
train, val, test = self.subsample_classes(train, val, test, subsample=subsample)
|
||||
|
||||
super().__init__(train_x=train, val=val, test=test)
|
||||
|
||||
def read_data(self, split_file):
|
||||
filepath = os.path.join(self.anno_dir, split_file)
|
||||
items = []
|
||||
|
||||
with open(filepath, "r") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
imname, label, species, _ = line.split(" ")
|
||||
breed = imname.split("_")[:-1]
|
||||
breed = "_".join(breed)
|
||||
breed = breed.lower()
|
||||
imname += ".jpg"
|
||||
impath = os.path.join(self.image_dir, imname)
|
||||
label = int(label) - 1 # convert to 0-based index
|
||||
item = Datum(impath=impath, label=label, classname=breed)
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
|
||||
@staticmethod
|
||||
def split_trainval(trainval, p_val=0.2):
|
||||
p_trn = 1 - p_val
|
||||
print(f"Splitting trainval into {p_trn:.0%} train and {p_val:.0%} val")
|
||||
tracker = defaultdict(list)
|
||||
for idx, item in enumerate(trainval):
|
||||
label = item.label
|
||||
tracker[label].append(idx)
|
||||
|
||||
train, val = [], []
|
||||
for label, idxs in tracker.items():
|
||||
n_val = round(len(idxs) * p_val)
|
||||
assert n_val > 0
|
||||
random.shuffle(idxs)
|
||||
for n, idx in enumerate(idxs):
|
||||
item = trainval[idx]
|
||||
if n < n_val:
|
||||
val.append(item)
|
||||
else:
|
||||
train.append(item)
|
||||
|
||||
return train, val
|
||||
|
||||
@staticmethod
|
||||
def save_split(train, val, test, filepath, path_prefix):
|
||||
def _extract(items):
|
||||
out = []
|
||||
for item in items:
|
||||
impath = item.impath
|
||||
label = item.label
|
||||
classname = item.classname
|
||||
impath = impath.replace(path_prefix, "")
|
||||
if impath.startswith("/"):
|
||||
impath = impath[1:]
|
||||
out.append((impath, label, classname))
|
||||
return out
|
||||
|
||||
train = _extract(train)
|
||||
val = _extract(val)
|
||||
test = _extract(test)
|
||||
|
||||
split = {"train": train, "val": val, "test": test}
|
||||
|
||||
write_json(split, filepath)
|
||||
print(f"Saved split to {filepath}")
|
||||
|
||||
@staticmethod
|
||||
def read_split(filepath, path_prefix):
|
||||
def _convert(items):
|
||||
out = []
|
||||
for impath, label, classname in items:
|
||||
impath = os.path.join(path_prefix, impath)
|
||||
item = Datum(impath=impath, label=int(label), classname=classname)
|
||||
out.append(item)
|
||||
return out
|
||||
|
||||
print(f"Reading split from {filepath}")
|
||||
split = read_json(filepath)
|
||||
train = _convert(split["train"])
|
||||
val = _convert(split["val"])
|
||||
test = _convert(split["test"])
|
||||
|
||||
return train, val, test
|
||||
|
||||
@staticmethod
|
||||
def subsample_classes(*args, subsample="all"):
|
||||
"""Divide classes into two groups. The first group
|
||||
represents base classes while the second group represents
|
||||
new classes.
|
||||
|
||||
Args:
|
||||
args: a list of datasets, e.g. train, val and test.
|
||||
subsample (str): what classes to subsample.
|
||||
"""
|
||||
assert subsample in ["all", "base", "new"]
|
||||
|
||||
if subsample == "all":
|
||||
return args
|
||||
|
||||
dataset = args[0]
|
||||
labels = set()
|
||||
for item in dataset:
|
||||
labels.add(item.label)
|
||||
labels = list(labels)
|
||||
labels.sort()
|
||||
n = len(labels)
|
||||
# Divide classes into two halves
|
||||
m = math.ceil(n / 2)
|
||||
|
||||
print(f"SUBSAMPLE {subsample.upper()} CLASSES!")
|
||||
if subsample == "base":
|
||||
selected = labels[:m] # take the first half
|
||||
else:
|
||||
selected = labels[m:] # take the second half
|
||||
relabeler = {y: y_new for y_new, y in enumerate(selected)}
|
||||
|
||||
output = []
|
||||
for dataset in args:
|
||||
dataset_new = []
|
||||
for item in dataset:
|
||||
if item.label not in selected:
|
||||
continue
|
||||
item_new = Datum(
|
||||
impath=item.impath,
|
||||
label=relabeler[item.label],
|
||||
classname=item.classname
|
||||
)
|
||||
dataset_new.append(item_new)
|
||||
output.append(dataset_new)
|
||||
|
||||
return output
|
||||
@@ -0,0 +1,229 @@
|
||||
import os
|
||||
import pickle
|
||||
from collections import OrderedDict
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import listdir_nohidden, mkdir_if_missing
|
||||
|
||||
from .oxford_pets import OxfordPets
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
import random
|
||||
import math
|
||||
CAT_LIST = ['aeroplane',
|
||||
'bicycle',
|
||||
'bird',
|
||||
'boat',
|
||||
'bottle',
|
||||
'bus',
|
||||
'car',
|
||||
'cat',
|
||||
'chair',
|
||||
'cow',
|
||||
'table',
|
||||
'dog',
|
||||
'horse',
|
||||
'motorbike',
|
||||
'person',
|
||||
'plant',
|
||||
'sheep',
|
||||
'sofa',
|
||||
'train',
|
||||
'tvmonitor']
|
||||
|
||||
CAT_LIST_TO_NAME = dict(zip(range(len(CAT_LIST)) ,CAT_LIST))
|
||||
|
||||
|
||||
def _collate(ims, y, c):
|
||||
return Datum(impath=ims, label=y, classname=c)
|
||||
|
||||
def load_img_name_list(dataset_path):
|
||||
|
||||
img_gt_name_list = open(dataset_path).readlines()
|
||||
img_name_list = [img_gt_name.strip() for img_gt_name in img_gt_name_list]
|
||||
|
||||
return img_name_list
|
||||
|
||||
def load_image_label_list_from_npy(data_root,img_name_list, label_file_path=None):
|
||||
if label_file_path is None:
|
||||
label_file_path = 'voc12/cls_labels.npy'
|
||||
cls_labels_dict = np.load(label_file_path, allow_pickle=True).item()
|
||||
label_list = []
|
||||
data_dtm = []
|
||||
|
||||
for id in img_name_list:
|
||||
if id not in cls_labels_dict.keys():
|
||||
img_name = id + '.jpg'
|
||||
else:
|
||||
img_name = id
|
||||
label = cls_labels_dict[img_name]
|
||||
label_idx = np.where(label==1)[0]
|
||||
class_name = [CAT_LIST[idx] for idx in range(len(label_idx))]
|
||||
data_dtm.append(_collate(os.path.join(data_root,img_name+'.jpg'),label,class_name))
|
||||
|
||||
return data_dtm
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class VOC12(DatasetBase):
|
||||
dataset_dir = "voc12data"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.image_dir = os.path.join(self.dataset_dir,'VOCdevkit/VOC2012/JPEGImages')
|
||||
train_img_name_list_path = os.path.join('voc12/train_aug_id.txt')
|
||||
val_img_name_list_path = os.path.join('voc12/val_id.txt')
|
||||
|
||||
train = load_image_label_list_from_npy(self.image_dir,load_img_name_list(train_img_name_list_path))
|
||||
val = load_image_label_list_from_npy(self.image_dir,load_img_name_list(val_img_name_list_path))
|
||||
self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
|
||||
mkdir_if_missing(self.split_fewshot_dir)
|
||||
|
||||
num_shots = cfg.DATASET.NUM_SHOTS
|
||||
if num_shots >= 1:
|
||||
seed = cfg.SEED
|
||||
preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
|
||||
|
||||
if os.path.exists(preprocessed):
|
||||
print(f"Loading preprocessed few-shot data from {preprocessed}")
|
||||
with open(preprocessed, "rb") as file:
|
||||
data = pickle.load(file)
|
||||
train = data["train"]
|
||||
else:
|
||||
train = self.generate_fewshot_dataset(train, num_shots=num_shots)
|
||||
data = {"train": train}
|
||||
print(f"Saving preprocessed few-shot data to {preprocessed}")
|
||||
with open(preprocessed, "wb") as file:
|
||||
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
|
||||
train, val = self.subsample_classes(train, val, subsample=subsample)
|
||||
|
||||
super().__init__(train_x=train, val=val, test=val)
|
||||
|
||||
@staticmethod
|
||||
def subsample_classes(*args, subsample="all"):
|
||||
"""Divide classes into two groups. The first group
|
||||
represents base classes while the second group represents
|
||||
new classes.
|
||||
|
||||
Args:
|
||||
args: a list of datasets, e.g. train, val and test.
|
||||
subsample (str): what classes to subsample.
|
||||
"""
|
||||
assert subsample in ["all", "base", "new"]
|
||||
|
||||
if subsample == "all":
|
||||
return args
|
||||
|
||||
dataset = args[0]
|
||||
labels = set()
|
||||
for item in dataset:
|
||||
label_idx = random.choices(np.where(item.label == 1)[0])[0]
|
||||
labels.add(label_idx)
|
||||
labels = list(labels)
|
||||
labels.sort()
|
||||
n = len(labels)
|
||||
# Divide classes into two halves
|
||||
m = math.ceil(n / 2)
|
||||
|
||||
print(f"SUBSAMPLE {subsample.upper()} CLASSES!")
|
||||
if subsample == "base":
|
||||
selected = labels[:m] # take the first half
|
||||
else:
|
||||
selected = labels[m:] # take the second half
|
||||
relabeler = {y: y_new for y_new, y in enumerate(selected)}
|
||||
|
||||
output = []
|
||||
for dataset in args:
|
||||
dataset_new = []
|
||||
for item in dataset:
|
||||
label_idx = random.choices(np.where(item.label == 1)[0])[0]
|
||||
if label_idx not in selected:
|
||||
continue
|
||||
|
||||
item_new = Datum(
|
||||
impath=item.impath,
|
||||
label=item.label,
|
||||
classname=item.classname
|
||||
)
|
||||
dataset_new.append(item_new)
|
||||
output.append(dataset_new)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_num_classes(data_source):
|
||||
"""Count number of classes.
|
||||
|
||||
Args:
|
||||
data_source (list): a list of Datum objects.
|
||||
"""
|
||||
return len(CAT_LIST)
|
||||
|
||||
@staticmethod
|
||||
def get_lab2cname(data_source):
|
||||
"""Get a label-to-classname mapping (dict).
|
||||
|
||||
Args:
|
||||
data_source (list): a list of Datum objects.
|
||||
"""
|
||||
return CAT_LIST_TO_NAME, CAT_LIST
|
||||
|
||||
def split_dataset_by_label(self, data_source):
|
||||
"""Split a dataset, i.e. a list of Datum objects,
|
||||
into class-specific groups stored in a dictionary.
|
||||
|
||||
Args:
|
||||
data_source (list): a list of Datum objects.
|
||||
"""
|
||||
output = defaultdict(list)
|
||||
|
||||
for item in data_source:
|
||||
one_hot_label = item.label
|
||||
label_idx = random.choices(np.where(one_hot_label==1)[0])[0]
|
||||
output[label_idx].append(item)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def read_classnames(text_file):
|
||||
"""Return a dictionary containing
|
||||
key-value pairs of <folder name>: <class name>.
|
||||
"""
|
||||
classnames = OrderedDict()
|
||||
with open(text_file, "r") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
line = line.strip().split(" ")
|
||||
folder = line[0]
|
||||
classname = " ".join(line[1:])
|
||||
classnames[folder] = classname
|
||||
return classnames
|
||||
|
||||
def read_data(self, classnames, split_dir):
|
||||
split_dir = os.path.join(self.image_dir, split_dir)
|
||||
folders = sorted(f.name for f in os.scandir(split_dir) if f.is_dir())
|
||||
items = []
|
||||
|
||||
for label, folder in enumerate(folders):
|
||||
imnames = listdir_nohidden(os.path.join(split_dir, folder))
|
||||
classname = classnames[folder]
|
||||
for imname in imnames:
|
||||
impath = os.path.join(split_dir, folder, imname)
|
||||
item = Datum(impath=impath, label=label, classname=classname)
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
import os
|
||||
import pickle
|
||||
from scipy.io import loadmat
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import mkdir_if_missing
|
||||
|
||||
from .oxford_pets import OxfordPets
|
||||
import numpy as np
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class StanfordCars(DatasetBase):
|
||||
|
||||
dataset_dir = "stanford_cars"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.split_path = os.path.join(self.dataset_dir, "split_zhou_StanfordCars.json")
|
||||
self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
|
||||
mkdir_if_missing(self.split_fewshot_dir)
|
||||
|
||||
if os.path.exists(self.split_path):
|
||||
train, val, test = OxfordPets.read_split(self.split_path, self.dataset_dir)
|
||||
else:
|
||||
trainval_file = os.path.join(self.dataset_dir, "devkit", "cars_train_annos.mat")
|
||||
test_file = os.path.join(self.dataset_dir, "cars_test_annos_withlabels.mat")
|
||||
meta_file = os.path.join(self.dataset_dir, "devkit", "cars_meta.mat")
|
||||
trainval = self.read_data("cars_train", trainval_file, meta_file)
|
||||
test = self.read_data("cars_test", test_file, meta_file)
|
||||
train, val = OxfordPets.split_trainval(trainval)
|
||||
OxfordPets.save_split(train, val, test, self.split_path, self.dataset_dir)
|
||||
|
||||
num_shots = cfg.DATASET.NUM_SHOTS
|
||||
if num_shots >= 1:
|
||||
seed = cfg.SEED
|
||||
preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
|
||||
|
||||
if os.path.exists(preprocessed):
|
||||
print(f"Loading preprocessed few-shot data from {preprocessed}")
|
||||
with open(preprocessed, "rb") as file:
|
||||
data = pickle.load(file)
|
||||
train, val = data["train"], data["val"]
|
||||
else:
|
||||
train = self.generate_fewshot_dataset(train, num_shots=num_shots)
|
||||
val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
|
||||
data = {"train": train, "val": val}
|
||||
print(f"Saving preprocessed few-shot data to {preprocessed}")
|
||||
with open(preprocessed, "wb") as file:
|
||||
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
|
||||
train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample)
|
||||
|
||||
super().__init__(train_x=train, val=val, test=test)
|
||||
|
||||
def read_data(self, image_dir, anno_file, meta_file):
|
||||
anno_file = loadmat(anno_file)["annotations"][0]
|
||||
meta_file = loadmat(meta_file)["class_names"][0]
|
||||
items = []
|
||||
|
||||
for i in range(len(anno_file)):
|
||||
imname = anno_file[i]["fname"][0]
|
||||
impath = os.path.join(self.dataset_dir, image_dir, imname)
|
||||
label = anno_file[i]["class"][0, 0]
|
||||
label = int(label) - 1 # convert to 0-based index
|
||||
classname = meta_file[label][0]
|
||||
names = classname.split(" ")
|
||||
year = names.pop(-1)
|
||||
names.insert(0, year)
|
||||
classname = " ".join(names)
|
||||
item = Datum(impath=impath, label=label, classname=classname)
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
@@ -0,0 +1,81 @@
|
||||
import os
|
||||
import pickle
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import mkdir_if_missing
|
||||
|
||||
from .oxford_pets import OxfordPets
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class SUN397(DatasetBase):
|
||||
|
||||
dataset_dir = "sun397"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.image_dir = os.path.join(self.dataset_dir, "SUN397")
|
||||
self.split_path = os.path.join(self.dataset_dir, "split_zhou_SUN397.json")
|
||||
self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
|
||||
mkdir_if_missing(self.split_fewshot_dir)
|
||||
|
||||
if os.path.exists(self.split_path):
|
||||
train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
|
||||
else:
|
||||
classnames = []
|
||||
with open(os.path.join(self.dataset_dir, "ClassName.txt"), "r") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()[1:] # remove /
|
||||
classnames.append(line)
|
||||
cname2lab = {c: i for i, c in enumerate(classnames)}
|
||||
trainval = self.read_data(cname2lab, "Training_01.txt")
|
||||
test = self.read_data(cname2lab, "Testing_01.txt")
|
||||
train, val = OxfordPets.split_trainval(trainval)
|
||||
OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
|
||||
|
||||
num_shots = cfg.DATASET.NUM_SHOTS
|
||||
if num_shots >= 1:
|
||||
seed = cfg.SEED
|
||||
preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
|
||||
|
||||
if os.path.exists(preprocessed):
|
||||
print(f"Loading preprocessed few-shot data from {preprocessed}")
|
||||
with open(preprocessed, "rb") as file:
|
||||
data = pickle.load(file)
|
||||
train, val = data["train"], data["val"]
|
||||
else:
|
||||
train = self.generate_fewshot_dataset(train, num_shots=num_shots)
|
||||
val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
|
||||
data = {"train": train, "val": val}
|
||||
print(f"Saving preprocessed few-shot data to {preprocessed}")
|
||||
with open(preprocessed, "wb") as file:
|
||||
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
|
||||
train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample)
|
||||
|
||||
super().__init__(train_x=train, val=val, test=test)
|
||||
|
||||
def read_data(self, cname2lab, text_file):
|
||||
text_file = os.path.join(self.dataset_dir, text_file)
|
||||
items = []
|
||||
|
||||
with open(text_file, "r") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
imname = line.strip()[1:] # remove /
|
||||
classname = os.path.dirname(imname)
|
||||
label = cname2lab[classname]
|
||||
impath = os.path.join(self.image_dir, imname)
|
||||
|
||||
names = classname.split("/")[1:] # remove 1st letter
|
||||
names = names[::-1] # put words like indoor/outdoor at first
|
||||
classname = " ".join(names)
|
||||
|
||||
item = Datum(impath=impath, label=label, classname=classname)
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
@@ -0,0 +1,84 @@
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
|
||||
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
|
||||
from dassl.utils import mkdir_if_missing
|
||||
|
||||
from .oxford_pets import OxfordPets
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class UCF101(DatasetBase):
|
||||
|
||||
dataset_dir = "ucf101"
|
||||
|
||||
def __init__(self, cfg):
|
||||
root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
|
||||
self.dataset_dir = os.path.join(root, self.dataset_dir)
|
||||
self.image_dir = os.path.join(self.dataset_dir, "UCF-101-midframes")
|
||||
self.split_path = os.path.join(self.dataset_dir, "split_zhou_UCF101.json")
|
||||
self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
|
||||
mkdir_if_missing(self.split_fewshot_dir)
|
||||
|
||||
if os.path.exists(self.split_path):
|
||||
train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
|
||||
else:
|
||||
cname2lab = {}
|
||||
filepath = os.path.join(self.dataset_dir, "ucfTrainTestlist/classInd.txt")
|
||||
with open(filepath, "r") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
label, classname = line.strip().split(" ")
|
||||
label = int(label) - 1 # conver to 0-based index
|
||||
cname2lab[classname] = label
|
||||
|
||||
trainval = self.read_data(cname2lab, "ucfTrainTestlist/trainlist01.txt")
|
||||
test = self.read_data(cname2lab, "ucfTrainTestlist/testlist01.txt")
|
||||
train, val = OxfordPets.split_trainval(trainval)
|
||||
OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
|
||||
|
||||
num_shots = cfg.DATASET.NUM_SHOTS
|
||||
if num_shots >= 1:
|
||||
seed = cfg.SEED
|
||||
preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
|
||||
|
||||
if os.path.exists(preprocessed):
|
||||
print(f"Loading preprocessed few-shot data from {preprocessed}")
|
||||
with open(preprocessed, "rb") as file:
|
||||
data = pickle.load(file)
|
||||
train, val = data["train"], data["val"]
|
||||
else:
|
||||
train = self.generate_fewshot_dataset(train, num_shots=num_shots)
|
||||
val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
|
||||
data = {"train": train, "val": val}
|
||||
print(f"Saving preprocessed few-shot data to {preprocessed}")
|
||||
with open(preprocessed, "wb") as file:
|
||||
pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
subsample = cfg.DATASET.SUBSAMPLE_CLASSES
|
||||
train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample)
|
||||
|
||||
super().__init__(train_x=train, val=val, test=test)
|
||||
|
||||
def read_data(self, cname2lab, text_file):
|
||||
text_file = os.path.join(self.dataset_dir, text_file)
|
||||
items = []
|
||||
|
||||
with open(text_file, "r") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
line = line.strip().split(" ")[0] # trainlist: filename, label
|
||||
action, filename = line.split("/")
|
||||
label = cname2lab[action]
|
||||
|
||||
elements = re.findall("[A-Z][^A-Z]*", action)
|
||||
renamed_action = "_".join(elements)
|
||||
|
||||
filename = filename.replace(".avi", ".jpg")
|
||||
impath = os.path.join(self.image_dir, renamed_action, filename)
|
||||
|
||||
item = Datum(impath=impath, label=label, classname=renamed_action)
|
||||
items.append(item)
|
||||
|
||||
return items
|
||||
@@ -0,0 +1 @@
|
||||
# __init__.py
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user