44 lines
1.0 KiB
Bash
44 lines
1.0 KiB
Bash
#!/bin/bash
|
|
|
|
# 定义种子列表
|
|
seeds=(1 2 3)
|
|
|
|
# 定义数据集列表
|
|
datasets=(
|
|
"ucf101"
|
|
"eurosat"
|
|
"oxford_pets"
|
|
"food101"
|
|
"oxford_flowers"
|
|
"dtd"
|
|
"caltech101"
|
|
"fgvc_aircraft"
|
|
"stanford_cars"
|
|
# "sun397"
|
|
# "imagenet"
|
|
)
|
|
|
|
# 对于每个种子,遍历所有数据集
|
|
for seed in "${seeds[@]}"; do
|
|
for dataset in "${datasets[@]}"; do
|
|
echo "正在运行训练: 数据集=${dataset}, 种子=${seed}"
|
|
|
|
# 运行训练命令
|
|
CUDA_VISIBLE_DEVICES=0 python train.py \
|
|
--root ~/Datasets/CoOp \
|
|
--seed "$seed" \
|
|
--trainer MaPLe \
|
|
--dataset-config-file "configs/datasets/${dataset}.yaml" \
|
|
--config-file configs/trainers/MaPLe/vit_b16_t.yaml \
|
|
--output-dir "output/DAPT_${dataset}_seed${seed}" \
|
|
--mode dapt-g \
|
|
DATASET.NUM_SHOTS ${SHOTS}
|
|
|
|
echo "完成: 数据集=${dataset}, 种子=${seed}"
|
|
echo "----------------------------------------"
|
|
done
|
|
done
|
|
|
|
echo "所有训练任务完成!"
|
|
|