Compare commits
2 Commits
984ce9f4bb
...
multi
| Author | SHA1 | Date | |
|---|---|---|---|
| 0b6eb7ce5e | |||
| fa3afbcae1 |
@@ -38,7 +38,7 @@ TRAINER:
|
|||||||
PROMPT_DEPTH_VISION: 9
|
PROMPT_DEPTH_VISION: 9
|
||||||
PROMPT_DEPTH_TEXT: 9
|
PROMPT_DEPTH_TEXT: 9
|
||||||
IMAGE_LOSS_WEIGHT: 8
|
IMAGE_LOSS_WEIGHT: 8
|
||||||
TEXT_LOSS_WEIGHT_STRONG: 8
|
TEXT_LOSS_WEIGHT_STRONG: 24
|
||||||
TEXT_LOSS_WEIGHT_WEAK: 24
|
TEXT_LOSS_WEIGHT_WEAK: 8
|
||||||
EWA_MEAN: 15
|
EWA_MEAN: 15
|
||||||
EWA_STD: 1
|
EWA_STD: 1
|
||||||
|
|||||||
@@ -354,14 +354,14 @@ class DZGCoOp(TrainerX):
|
|||||||
L_sg_strong = F.l1_loss(text_features_strong, semantic_embeddings.cuda(), reduction='mean') * lambda2
|
L_sg_strong = F.l1_loss(text_features_strong, semantic_embeddings.cuda(), reduction='mean') * lambda2
|
||||||
L_sg_weak = F.l1_loss(text_features_weak, semantic_embeddings.cuda(), reduction='mean') * lambda3
|
L_sg_weak = F.l1_loss(text_features_weak, semantic_embeddings.cuda(), reduction='mean') * lambda3
|
||||||
|
|
||||||
L_zlg = F.kl_div(
|
L_zpg = F.kl_div(
|
||||||
F.log_softmax(logits_final / 1, dim=1),
|
F.log_softmax(logits_final / 1, dim=1),
|
||||||
F.log_softmax(zero_shot_logits / 1, dim=1),
|
F.log_softmax(zero_shot_logits / 1, dim=1),
|
||||||
reduction='sum',
|
reduction='sum',
|
||||||
log_target=True
|
log_target=True
|
||||||
) * (1 * 1) / logits_final.numel()
|
) * (1 * 1) / logits_final.numel()
|
||||||
|
|
||||||
L_zg = (L_zlg + L_sg_strong + L_sg_weak + L_zvg)
|
L_zg = (L_zpg + L_sg_strong + L_sg_weak + L_zvg)
|
||||||
loss = (loss_ce + L_zg)
|
loss = (loss_ce + L_zg)
|
||||||
optim.zero_grad()
|
optim.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|||||||
Reference in New Issue
Block a user