init
This commit is contained in:
419
test_imagenet.py
Normal file
419
test_imagenet.py
Normal file
@@ -0,0 +1,419 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import os
|
||||
import random
|
||||
import clip
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import torch.optim
|
||||
from data.prepare_data import generate_dataloader # Prepare the data and dataloader
|
||||
from opts import opts # The options for the project
|
||||
from engine import partial_model
|
||||
from clip.model import ModifiedResNet, VisionTransformer
|
||||
from datasets import build_dataset
|
||||
from datasets.utils import build_data_loader
|
||||
import torchvision.transforms as transforms
|
||||
from datasets.imagenet import ImageNet
|
||||
from datasets.imagenetv2 import ImageNetV2
|
||||
import os
|
||||
import time
|
||||
from clip import clip
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import torch.optim
|
||||
from opts import opts # The options for the project
|
||||
# from trainer import validate # For the validate (test) process
|
||||
from models.DomainClassifierTarget import DClassifierForTarget
|
||||
from models.DomainClassifierSource import DClassifierForSource
|
||||
from utils.loss_utils import TargetDiscrimLoss, ConcatenatedCELoss
|
||||
from utils.utils import prepare_directories, set_seed, get_dataset_loader, configure_clip_encoders, save_model, \
|
||||
set_adapter_weights, get_text_feature, AverageMeter, accuracy, calculate_zeroshot_weights, gpt_clip_classifier, \
|
||||
calculate_zeroshot_weights_GPT,calculate_zero,all_classifier_GPT
|
||||
from Adapter import Weight_Adapter
|
||||
import logging
|
||||
import torch.nn.functional as F
|
||||
import yaml
|
||||
import json
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import glob
|
||||
|
||||
imagenet_classes = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray",
|
||||
"stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco",
|
||||
"indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper",
|
||||
"kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander",
|
||||
"smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog",
|
||||
"tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin",
|
||||
"box turtle", "banded gecko", "green iguana", "Carolina anole",
|
||||
"desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard",
|
||||
"Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile",
|
||||
"American alligator", "triceratops", "worm snake", "ring-necked snake",
|
||||
"eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake",
|
||||
"vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra",
|
||||
"green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake",
|
||||
"sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider",
|
||||
"barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider",
|
||||
"tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl",
|
||||
"quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet",
|
||||
"coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck",
|
||||
"red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby",
|
||||
"koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch",
|
||||
"snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab",
|
||||
"fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab",
|
||||
"isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron",
|
||||
"great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot",
|
||||
"bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher",
|
||||
"pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion",
|
||||
"Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel",
|
||||
"Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle",
|
||||
"Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound",
|
||||
"English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound",
|
||||
"Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound",
|
||||
"Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier",
|
||||
"Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier",
|
||||
"Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier",
|
||||
"Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier",
|
||||
"Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer",
|
||||
"Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier",
|
||||
"Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier",
|
||||
"Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever",
|
||||
"Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla",
|
||||
"English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel",
|
||||
"English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel",
|
||||
"Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard",
|
||||
"Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie",
|
||||
"Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann",
|
||||
"Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog",
|
||||
"Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff",
|
||||
"French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky",
|
||||
"Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog",
|
||||
"Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon",
|
||||
"Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle",
|
||||
"Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf",
|
||||
"red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox",
|
||||
"kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat",
|
||||
"Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger",
|
||||
"cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose",
|
||||
"meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle",
|
||||
"dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper",
|
||||
"cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper",
|
||||
"lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly",
|
||||
"monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly",
|
||||
"starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit",
|
||||
"hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse",
|
||||
"zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison",
|
||||
"ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)",
|
||||
"gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat",
|
||||
"black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan",
|
||||
"gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque",
|
||||
"langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin",
|
||||
"howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey",
|
||||
"ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda",
|
||||
"giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish",
|
||||
"sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown",
|
||||
"accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance",
|
||||
"amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle",
|
||||
"backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo",
|
||||
"baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel",
|
||||
"wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel",
|
||||
"bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)",
|
||||
"beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini",
|
||||
"ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet",
|
||||
"bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra",
|
||||
"breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest",
|
||||
"high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe",
|
||||
"can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton",
|
||||
"car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran",
|
||||
"CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw",
|
||||
"storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking",
|
||||
"church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker",
|
||||
"coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard",
|
||||
"candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot",
|
||||
"cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed",
|
||||
"Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer",
|
||||
"rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table",
|
||||
"dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig",
|
||||
"drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar",
|
||||
"electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder",
|
||||
"feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute",
|
||||
"folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed",
|
||||
"freight car", "French horn", "frying pan", "fur coat", "garbage truck",
|
||||
"gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola",
|
||||
"gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine",
|
||||
"hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer",
|
||||
"handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet",
|
||||
"holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar",
|
||||
"horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep",
|
||||
"T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat",
|
||||
"ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library",
|
||||
"lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion",
|
||||
"music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag",
|
||||
"mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask",
|
||||
"matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone",
|
||||
"microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile",
|
||||
"mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor",
|
||||
"moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa",
|
||||
"mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail",
|
||||
"neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina",
|
||||
"odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart",
|
||||
"oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush",
|
||||
"pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench",
|
||||
"parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case",
|
||||
"pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube",
|
||||
"picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball",
|
||||
"pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag",
|
||||
"plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho",
|
||||
"pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug",
|
||||
"printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill",
|
||||
"quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel",
|
||||
"recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator",
|
||||
"remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser",
|
||||
"rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal",
|
||||
"sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard",
|
||||
"CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store",
|
||||
"shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap",
|
||||
"shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door",
|
||||
"slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock",
|
||||
"solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater",
|
||||
"space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight",
|
||||
"stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf",
|
||||
"stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa",
|
||||
"submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge",
|
||||
"mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe",
|
||||
"table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball",
|
||||
"thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof",
|
||||
"toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store",
|
||||
"tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod",
|
||||
"triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard",
|
||||
"umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling",
|
||||
"velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball",
|
||||
"waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink",
|
||||
"washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle",
|
||||
"hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing",
|
||||
"wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website",
|
||||
"comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu",
|
||||
"plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette",
|
||||
"bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli",
|
||||
"cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber",
|
||||
"artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange",
|
||||
"lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate",
|
||||
"hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito",
|
||||
"red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef",
|
||||
"geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player",
|
||||
"bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn",
|
||||
"rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom",
|
||||
"earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"]
|
||||
|
||||
imagenet_templates = ["itap of a {}.",
|
||||
"a bad photo of the {}.",
|
||||
"a origami {}.",
|
||||
"a photo of the large {}.",
|
||||
"a {} in a video game.",
|
||||
"art of the {}.",
|
||||
"a photo of the small {}."]
|
||||
|
||||
def zeroshot_classifier(classname, templates, CLIP_Text):
|
||||
with torch.no_grad():
|
||||
classname = classname.replace('_', ' ')
|
||||
str_prompts = [template.format(classname) for template in templates]
|
||||
prompts = torch.cat([clip.tokenize(p) for p in str_prompts]).cuda()
|
||||
features, eot_indices = CLIP_Text(prompts)
|
||||
return features, eot_indices
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
"""Computes and stores the average and current value"""
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
def accuracy(output, target, topk=(1,)):
|
||||
"""Computes the precision@k for the specified values of k"""
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
_2, pred2 = output.topk(1, 1, True, True)
|
||||
a = target.view(1, -1)
|
||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||
# print(correct)
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
|
||||
res.append(correct_k.mul_(100.0 / batch_size))
|
||||
return res
|
||||
|
||||
|
||||
class Feature_Extractor(nn.Module):
|
||||
def __init__(self, n_input, n_output):
|
||||
super().__init__(),
|
||||
self.linear1 = nn.Linear(n_input, n_output)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear1(x.float())
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
class Weight_Adapter(nn.Module):
|
||||
def __init__(self, n_input, n_output):
|
||||
super().__init__()
|
||||
self.linear1 = nn.Linear(n_input, n_output)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear1(x.float())
|
||||
return x
|
||||
|
||||
|
||||
def all_classifier(classnames, templates, model):
|
||||
with torch.no_grad():
|
||||
zeroshot_weights = []
|
||||
for classname in classnames:
|
||||
classname = classname.replace('_', ' ')
|
||||
texts = [template.format(classname) for template in templates] # format with class
|
||||
texts = clip.tokenize(texts).cuda() # tokenizeclip.tokenize向量化文字
|
||||
class_embeddings = model.encode_text(texts) # embed with text encoder
|
||||
class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
|
||||
class_embedding = class_embeddings.mean(dim=0)
|
||||
class_embedding /= class_embedding.norm()
|
||||
zeroshot_weights.append(class_embedding)
|
||||
zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
|
||||
return zeroshot_weights
|
||||
|
||||
|
||||
def validate(classnames, templates, val_loader, model, args, zero_shots, criterion,
|
||||
alpha, beta, gama):
|
||||
global best_target_acc
|
||||
Compu1_acc = AverageMeter()
|
||||
Compu2_acc = AverageMeter()
|
||||
|
||||
losses = AverageMeter()
|
||||
|
||||
logit_scale = 4.60517
|
||||
logit_scale = math.exp(logit_scale)
|
||||
# switch to evaluate mode
|
||||
|
||||
for i, (image, label) in enumerate(val_loader):
|
||||
image = image.cuda()
|
||||
label = label.cuda()
|
||||
|
||||
input_target = image.cuda()
|
||||
target_target = label.cuda()
|
||||
target_source = label.cuda()
|
||||
|
||||
input_target_clip = model.encode_image(input_target)
|
||||
|
||||
# 2
|
||||
logits2 = 100. * input_target_clip.float() @ zero_shots.float()
|
||||
|
||||
|
||||
|
||||
compu1 = logits2
|
||||
|
||||
compu2 =logits2
|
||||
|
||||
compu1_acc = accuracy(compu1, target_target, topk=(1, 5))
|
||||
compu2_acc = accuracy(compu2, target_target, topk=(1, 5))
|
||||
|
||||
loss = criterion(compu1, target_target)
|
||||
|
||||
Compu1_acc.update(compu1_acc[0].item(), image.size(0))
|
||||
Compu2_acc.update(compu2_acc[0].item(), image.size(0))
|
||||
|
||||
losses.update(loss.item(), image.size(0))
|
||||
print('loss:', loss.item())
|
||||
print(i, '/', len(val_loader))
|
||||
print('Compu1_acc:', Compu1_acc.val,'Compu2_acc:', Compu2_acc.val, 'alpha:', alpha, 'beta:', beta, 'gama:', gama)
|
||||
loss.backward()
|
||||
print('Compu1_acc:', Compu1_acc.avg,'Compu2_acc:', Compu2_acc.avg, 'alpha:', alpha, 'beta:', beta, 'gama:', gama,
|
||||
'losses.avg', losses.avg)
|
||||
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
args = opts()
|
||||
seed = 2023
|
||||
random.seed(seed)
|
||||
os.environ['PYTHONHASHSEED'] = str(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
global best_prec1
|
||||
model, preprocess = clip.load(args.name)
|
||||
model = model.cuda()
|
||||
|
||||
# imagenet = ImageNet(args.dataset_dir, args.shot, preprocess)
|
||||
# loader = torch.utils.data.DataLoader(imagenet.test, batch_size=64, num_workers=8, shuffle=False)
|
||||
#
|
||||
classnames, templates, loader, train_loader = get_dataset_loader(args, preprocess)
|
||||
|
||||
|
||||
|
||||
# classnames = imagenet_classes
|
||||
# templates = imagenet_templates
|
||||
criterion = nn.CrossEntropyLoss().cuda()
|
||||
|
||||
# 拆分CLIP图像编码器
|
||||
# if args.name == "ViT-B/16":
|
||||
# CLIP_Text, Text_Encoder = partial_model.get_text(model, text_layer_idx=0)
|
||||
# assert type(model.visual) == VisionTransformer
|
||||
# CLIP_Image, Image_Encoder = partial_model.get_image_vit(model.visual, image_layer_idx=1)
|
||||
# elif args.name == "ViT-B/32":
|
||||
# CLIP_Text, Text_Encoder = partial_model.get_text(model, text_layer_idx=0)
|
||||
# assert type(model.visual) == VisionTransformer
|
||||
# CLIP_Image, Image_Encoder = partial_model.get_image_vit(model.visual, image_layer_idx=1)
|
||||
# elif args.name == "RN50":
|
||||
# CLIP_Text, Text_Encoder = partial_model.get_text(model, text_layer_idx=0)
|
||||
# assert type(model.visual) == ModifiedResNet
|
||||
# CLIP_Image, Image_Encoder = partial_model.get_image_resnet(model.visual, image_layer_idx=1)
|
||||
# elif args.name == "RN101":
|
||||
# CLIP_Text, Text_Encoder = partial_model.get_text(model, text_layer_idx=0)
|
||||
# assert type(model.visual) == ModifiedResNet
|
||||
# CLIP_Image, Image_Encoder = partial_model.get_image_resnet(model.visual, image_layer_idx=1)
|
||||
# elif args.name == "RN50x16":
|
||||
# CLIP_Text, Text_Encoder = partial_model.get_text(model, text_layer_idx=0)
|
||||
# assert type(model.visual) == ModifiedResNet
|
||||
# CLIP_Image, Image_Encoder = partial_model.get_image_resnet(model.visual, image_layer_idx=1)
|
||||
|
||||
criterion = nn.CrossEntropyLoss().cuda()
|
||||
# 自监督温度z
|
||||
# dir = '/root/autodl-tmp/epx/imagenet_epx/16shot/epoch_86_61.36800003051758/'
|
||||
#
|
||||
# CLIP_Text = torch.load(dir + 'CLIP_Text.pth')
|
||||
# CLIP_Image = torch.load(dir + 'CLIP_Image.pth')
|
||||
# Image_Encoder = torch.load(dir + 'Image_Encoder.pth')
|
||||
# Text_Encoder = torch.load(dir + 'Text_Encoder.pth')
|
||||
# adapter = torch.load(dir + '_adapter_extractor.pth')
|
||||
|
||||
# alpha = nn.Parameter(torch.ones([]), requires_grad=True)
|
||||
# beta = nn.Parameter(torch.ones([]), requires_grad=True) # 91.35902633202728
|
||||
# gama = nn.Parameter(torch.ones([]), requires_grad=True)
|
||||
alpha = 1
|
||||
beta = 0
|
||||
gama = 1
|
||||
|
||||
zero_weights = all_classifier(classnames, templates, model)
|
||||
validate(classnames, templates,loader, model, args, zero_weights,
|
||||
criterion, alpha, beta, gama)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user