Compare commits
2 Commits
984ce9f4bb
...
multi
| Author | SHA1 | Date | |
|---|---|---|---|
| 0b6eb7ce5e | |||
| fa3afbcae1 |
@@ -38,7 +38,7 @@ TRAINER:
|
||||
PROMPT_DEPTH_VISION: 9
|
||||
PROMPT_DEPTH_TEXT: 9
|
||||
IMAGE_LOSS_WEIGHT: 8
|
||||
TEXT_LOSS_WEIGHT_STRONG: 8
|
||||
TEXT_LOSS_WEIGHT_WEAK: 24
|
||||
TEXT_LOSS_WEIGHT_STRONG: 24
|
||||
TEXT_LOSS_WEIGHT_WEAK: 8
|
||||
EWA_MEAN: 15
|
||||
EWA_STD: 1
|
||||
|
||||
@@ -354,14 +354,14 @@ class DZGCoOp(TrainerX):
|
||||
L_sg_strong = F.l1_loss(text_features_strong, semantic_embeddings.cuda(), reduction='mean') * lambda2
|
||||
L_sg_weak = F.l1_loss(text_features_weak, semantic_embeddings.cuda(), reduction='mean') * lambda3
|
||||
|
||||
L_zlg = F.kl_div(
|
||||
L_zpg = F.kl_div(
|
||||
F.log_softmax(logits_final / 1, dim=1),
|
||||
F.log_softmax(zero_shot_logits / 1, dim=1),
|
||||
reduction='sum',
|
||||
log_target=True
|
||||
) * (1 * 1) / logits_final.numel()
|
||||
|
||||
L_zg = (L_zlg + L_sg_strong + L_sg_weak + L_zvg)
|
||||
L_zg = (L_zpg + L_sg_strong + L_sg_weak + L_zvg)
|
||||
loss = (loss_ce + L_zg)
|
||||
optim.zero_grad()
|
||||
loss.backward()
|
||||
|
||||
Reference in New Issue
Block a user