import os import numpy as np import torch from clip import clip from tqdm.notebook import tqdm import torchvision import torchvision.transforms as transforms import torch.nn.functional as F import torch.nn as nn from collections import defaultdict import random from tqdm import tqdm import argparse print("Torch version:", torch.__version__) # assert torch.__version__.split(".") >= ["1", "7", "1"], "PyTorch 1.7.1 or later is required" imagenet_classes = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", "box turtle", "banded gecko", "green iguana", "Carolina anole", "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", "American alligator", "triceratops", "worm snake", "ring-necked snake", "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", "coucal", "bee eater", "hornbill", 
"hummingbird", "jacamar", "toucan", "duck", "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", "English 
Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", "hamster", "porcupine", "fox squirrel", "marmot", 
"beaver", "guinea pig", "common sorrel horse", "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box 
/ carton", "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", "freight car", "French horn", "frying pan", "fur coat", "garbage truck", "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab 
coat", "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", 
"remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", 
"hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"] # Prompt Ensembling imagenet_templates = [ "itap of a {}.", "a bad photo of the {}.", "a origami {}.", "a photo of the large {}.", "a {} in a video game.", "art of the {}.", "a photo of the small {}.", ] template = ['a photo of a {}.'] # Single Prompt # imagenet_templates = ['a photo of a {}.',] #计算准确率 def accuracy(output, target, topk=(1,)): pred = output.topk(max(topk), 1, True, True)[1].t() correct = pred.eq(target.view(1, -1).expand_as(pred)) return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] #制作提示模版 def zeroshot_classifier(classnames, templates, model): with torch.no_grad(): zeroshot_weights = [] for classname in classnames: texts = 
[template.format(classname) for template in templates] # format with class texts = clip.tokenize(texts).cuda() # tokenizeclip.tokenize向量化文字 class_embeddings = model.encode_text(texts) # embed with text encoder class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) class_embedding = class_embeddings.mean(dim=0) class_embedding /= class_embedding.norm() zeroshot_weights.append(class_embedding) zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda() return zeroshot_weights #Adapter线性层 class Weight_Adapter(nn.Module): def __init__(self, clip_model, train_features_path, cls_num, shots): super().__init__() self.linear1 = nn.Linear(512, cls_num * shots, bias=False).to(clip_model.dtype) self.linear1.weight = nn.Parameter(torch.load(train_features_path).t())#所以训练网络的时候,可以使用nn.Parameter()来转换一个固定的权重数值,使的其可以跟着网络训练一直调优下去,学习到一个最适合的权重值。 print("111") def main(): print(imagenet_classes[475]) # Path for ImageNet data_path = "E:/imagenet" train_features_path = "./imagenet_f_train.pt" train_targets_path = "./imagenet_t_train.pt" test_features_path = "./imagenet_f_test.pt" test_targets_path = "./imagenet_t_test.pt" # load_train = False # load_test = False load_train = False load_test = False search = False # ~~~~~~~~~~~~~~~~~~ k_shot = 16 # ~~~~~~~~~~~~~~~~~~ parser = argparse.ArgumentParser() parser.add_argument('--lr', type=float, default=0.001, help='lr') parser.add_argument('--alpha', type=float, default=1) parser.add_argument('--beta', type=float, default=1.17) parser.add_argument('--train_epoch', type=int, default=20) parser.add_argument('--augment_epoch', type=int, default=10) args = parser.parse_args() print(args) clip.available_models() name = 'RN50x16' model, preprocess = clip.load(name) zeroshot_classifier(imagenet_classes, imagenet_templates, model) model.eval()#测试模式 batchNorm层,dropout层等用于优化训练而添加的网络层会被关闭,从而使得评估时不会发生偏移。 input_resolution = model.visual.input_resolution #图像224*224 context_length = model.context_length #77 #文本长度 vocab_size = model.vocab_size 
#包含训练集和测试集的所有词。 print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}") print("Input resolution:", input_resolution) print("Context length:", context_length) print("Vocab size:", vocab_size) #随机种子 random.seed(1) torch.manual_seed(1)#为CPU中设置种子,生成随机数 print(f"{len(imagenet_classes)} classes, {len(imagenet_templates)} templates") # adapter = Weight_Adapter(model, train_features_path, len(imagenet_classes), k_shot).cuda() images = torchvision.datasets.ImageNet(data_path, split='val', transform=preprocess) #torchvision.datasets这个包中包含MNIST、FakeData、COCO、LSUN、ImageFolder、DatasetFolder、ImageNet、CIFAR等一些常用的数据集,并且提供了数据集设置的一些重要参数设置,可以通过简单数据集设置来进行数据集的调用。从 #制作存储了图片的路径和标签信息的txt 将这些信息转化为list,该list每一个元素对应一个样本 通过getitem函数,读取数据和标签,并返回数据和标签 loader = torch.utils.data.DataLoader(images, batch_size=64, num_workers=8, shuffle=False) #shuffle=True可以对数据进行随机读取,可以对数据进行洗牌操作(shuffling),打乱数据集内数据分布的顺序 #num_workers=2可以并行加载数据(利用多核处理器加快载入数据的效率) #batch :可以分批次读取:batch-size train_tranform = transforms.Compose([ transforms.RandomResizedCrop(size=224, scale=(0.5, 1), interpolation=transforms.InterpolationMode.BICUBIC), #随机裁剪 但是保留纵横比 transforms.RandomHorizontalFlip(p=0.5), transforms.ToTensor(), transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) ]) train_images = torchvision.datasets.ImageNet(data_path, split='train', transform=train_tranform) split_by_label_dict = defaultdict(list) print('Load data finished.') for i in range(len(train_images.imgs)): split_by_label_dict[train_images.targets[i]].append(train_images.imgs[i]) #train_images.imgs[i] 0000000=('/private/rocky/dataset/train/n01440764/n01440764_10026.JPEG', 0) #train_images.targets[i] 0,0,0....0,1,1....1 #invalid syntax (, line 1) imgs = [] targets = [] for label, items in split_by_label_dict.items(): imgs = imgs + random.sample(items, k_shot) #random.sample的 用法,多用于截取列表的指定长度的随机数,但是不会改变列表本身的排序 targets = targets + [label for i in range(k_shot)] 
train_images.imgs = imgs #[('/private/rocky/dataset/train/n01440764/n01440764_13161.JPEG', 0), ('/private/rocky/dataset/train/n01440764/n01440764_8600.JPEG', 0), ('/private/rocky/dataset/train/n01440764/n01440764_11547.JPEG', 0), ('/private/rocky/dataset/train/n01440764/n01440764_2271.JPEG', 0), ('/private/rocky/dataset/train/n01440764/n01440764_12659.JPEG', 0), ('/private/rocky/dataset/train/n01440764/n01440764_7324.JPEG', 0), ('/private/rocky/dataset/train/n01440764/n01440764_6395.JPEG', 0), ('/private/rocky/dataset/train/n01440764/n01440764_6870.JPEG', 0), ('/private/rocky/dataset/train/n01440764/n01440764_4681.JPEG', 0), ('/private/rocky/dataset/train/n01440764/n01440764_1703.JPEG', 0), ('/private/rocky/dataset/train/n01440764/n01440764_12182.JPEG', 0), ('/private/rocky/dataset/train/n01440764/n01440764_7173.JPEG', 0), ('/private/rocky/dataset/train/n01440764/n01440764_10548.JPEG', 0), ('/private/rocky/dataset/train/n01440764/n01440764_5003.JPEG', 0), ('/private/rocky/dataset/train/n01440764/n01440764_600.JPEG', 0), ('/private/rocky/dataset/train/n01440764/n01440764_9543.JPEG', 0), ('/private/rocky/dataset/train/n01443537/n01443537_10035.JPEG', 1), ('/private/rocky/dataset/train/n01443537/n01443537_2906.JPEG', 1), ('/private/rocky/dataset/train/n01443537/n01443537_1895.JPEG', 1), ('/private/rocky/dataset/train/n01443537/n01443537_17646.JPEG', 1), ('/private/rocky/dataset/train/n01443537/n01443537_7947.JPEG', 1), ('/private/rocky/dataset/train/n01443537/n01443537_13218.JPEG', 1), ('/private/rocky/dataset/train/n01443537/n01443537_20735.JPEG', 1), ('/private/rocky/dataset/train/n01443537/n01443537_11063.JPEG', 1), ('/private/rocky/dataset/train/n01443537/n01443537_10763.JPEG', 1), ('/private/rocky/dataset/train/n01443537/n01443537_1087.JPEG', 1), ('/private/rocky/dataset/train/n01443537/n01443537_5697.JPEG', 1), ('/private/rocky/dataset/train/n01443537/n01443537_10242.JPEG', 1), ('/private/rocky/dataset/train/n01443537/n01443537_22942.JPEG', 1), 
('/private/rocky/dataset/train/n01443537/n01443537_17307.JPEG', 1), ('/private/rocky/dataset/train/n01443537/n01443537_24539.JPEG', 1), ('/private/rocky/dataset/train/n01443537/n01443537_1103.JPEG', 1), ('/private/rocky/dataset/train/n01484850/n01484850_663.JPEG', 2), ('/private/rocky/dataset/train/n01484850/n01484850_20733.JPEG', 2), ('/private/rocky/dataset/train/n01484850/n01484850_3405.JPEG', 2), ('/private/rocky/dataset/train/n01484850/n01484850_5240.JPEG', 2), ('/private/rocky/dataset/train/n01484850/n01484850_7415.JPEG', 2), ('/private/rocky/dataset/train/n01484850/n01484850_21143.JPEG', 2), ('/private/rocky/dataset/train/n01484850/n01484850_26157.JPEG', 2), ('/private/rocky/dataset/train/n01484850/n01484850_21044.JPEG', 2), ('/private/rocky/dataset/train/n01484850/n01484850_20650.JPEG', 2), ('/private/rocky/dataset/train/n01484850/n01484850_4104.JPEG', 2), ('/private/rocky/dataset/train/n01484850/n01484850_23515.JPEG', 2), ('/private/rocky/dataset/train/n01484850/n01484850_11466.JPEG', 2), ('/private/rocky/dataset/train/n01484850/n01484850_32242.JPEG', 2), ('/private/rocky/dataset/train/n01484850/n01484850_7543.JPEG', 2), ('/private/rocky/dataset/train/n01484850/n01484850_15960.JPEG', 2), ('/private/rocky/dataset/train/n01484850/n01484850_18872.JPEG', 2), ('/private/rocky/dataset/train/n01491361/n01491361_984.JPEG', 3), ('/private/rocky/dataset/train/n01491361/n01491361_462.JPEG', 3), ('/private/rocky/dataset/train/n01491361/n01491361_14786.JPEG', 3), ('/private/rocky/dataset/train/n01491361/n01491361_5187.JPEG', 3), ('/private/rocky/dataset/train/n01491361/n01491361_7360.JPEG', 3), ('/private/rocky/dataset/train/n01491361/n01491361_6401.JPEG', 3), ('/private/rocky/dataset/train/n01491361/n01491361_7459.JPEG', 3), ('/private/rocky/dataset/train/n01491361/n01491361_2911.JPEG', 3), ('/private/rocky/dataset/train/n01491361/n01491361_4716.JPEG', 3), ('/private/rocky/dataset/train/n01491361/n01491361_4424.JPEG', 3), 
('/private/rocky/dataset/train/n01491361/n01491361_8919.JPEG', 3), ('/private/rocky/dataset/train/n01491361/n01491361_733.JPEG', 3), ('/private/rocky/dataset/train/n01491361/n01491361_7416.JPEG', 3), ('/private/rocky/dataset/train/n01491361/n01491361_590.JPEG', 3), ('/private/rocky/dataset/train/n01491361/n01491361_895.JPEG', 3), ('/private/rocky/dataset/train/n01491361/n01491361_1105.JPEG', 3), ('/private/rocky/dataset/train/n01494475/n01494475_6324.JPEG', 4), ('/private/rocky/dataset/train/n01494475/n01494475_21813.JPEG', 4), ('/private/rocky/dataset/train/n01494475/n01494475_4713.JPEG', 4), ('/private/rocky/dataset/train/n01494475/n01494475_4925.JPEG', 4), ('/private/rocky/dataset/train/n01494475/n01494475_17164.JPEG', 4), ('/private/rocky/dataset/train/n01494475/n01494475_3936.JPEG', 4), ('/private/rocky/dataset/train/n01494475/n01494475_7557.JPEG', 4), ('/private/rocky/dataset/train/n01494475/n01494475_4044.JPEG', 4), ('/private/rocky/dataset/train/n01494475/n01494475_1397.JPEG', 4), ('/private/rocky/dataset/train/n01494475/n01494475_5406.JPEG', 4), ('/private/rocky/dataset/train/n01494475/n01494475_6875.JPEG', 4), ('/private/rocky/dataset/train/n01494475/n01494475_14918.JPEG', 4), ('/private/rocky/dataset/train/n01494475/n01494475_16903.JPEG', 4), ('/private/rocky/dataset/train/n01494475/n01494475_7066.JPEG', 4), ('/private/rocky/dataset/train/n01494475/n01494475_4401.JPEG', 4), ('/private/rocky/dataset/train/n01494475/n01494475_3964.JPEG', 4), ('/private/rocky/dataset/train/n01496331/n01496331_4080.JPEG', 5), ('/private/rocky/dataset/train/n01496331/n01496331_11164.JPEG', 5), ('/private/rocky/dataset/train/n01496331/n01496331_3534.JPEG', 5), ('/private/rocky/dataset/train/n01496331/n01496331_11534.JPEG', 5), ('/private/rocky/dataset/train/n01496331/n01496331_26582.JPEG', 5), ('/private/rocky/dataset/train/n01496331/n01496331_9048.JPEG', 5), ('/private/rocky/dataset/train/n01496331/n01496331_8049.JPEG', 5), 
('/private/rocky/dataset/train/n01496331/n01496331_7517.JPEG', 5), ('/private/rocky/dataset/train/n01496331/n01496331_30562.JPEG', 5), ('/private/rocky/dataset/train/n01496331/n01496331_17713.JPEG', 5), ('/private/rocky/dataset/train/n01496331/n01496331_17548.JPEG', 5), ('/private/rocky/dataset/train/n01496331/n01496331_4612.JPEG', 5), ('/private/rocky/dataset/train/n01496331/n01496331_2189.JPEG', 5), ('/private/rocky/dataset/train/n01496331/n01496331_10444.JPEG', 5), ('/private/rocky/dataset/train/n01496331/n01496331_19596.JPEG', 5), ('/private/rocky/dataset/train/n01496331/n01496331_6064.JPEG', 5), ('/private/rocky/dataset/train/n01498041/n01498041_8215.JPEG', 6), ('/private/rocky/dataset/train/n01498041/n01498041_20637.JPEG', 6), ('/private/rocky/dataset/train/n01498041/n01498041_5338.JPEG', 6), ('/private/rocky/dataset/train/n01498041/n01498041_7627.JPEG', 6... train_images.targets = targets #[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6... 
train_images.samples = imgs #[('/private/rocky/dataset/train/n01440764/n01440764_13161.JPEG', 0), ('/private/rocky/dataset/train/n01440764/n01440764_8600.JPEG', 0), ('/private/rocky/dataset/train/n01440764/n01440764_11547.JPEG', 0), ('/private/rocky/dataset/train/n01440764/n01440764_2271.JPEG', 0), ('/private/rocky/dataset/train/n01440764/n01440764_12659.JPEG', 0), ('/private/rocky/dataset/train/n01440764/n01440764_7324.JPEG', 0), ('/private/rocky/dataset/train/n01440764/n01440764_6395.JPEG', 0), ('/private/rocky/dataset/train/n01440764/n01440764_6870.JPEG', 0), ('/private/rocky/dataset/train/n01440764/n01440764_4681.JPEG', 0), ('/private/rocky/dataset/train/n01440764/n01440764_1703.JPEG', 0), ('/private/rocky/dataset/train/n01440764/n01440764_12182.JPEG', 0), ('/private/rocky/dataset/train/n01440764/n01440764_7173.JPEG', 0), ('/private/rocky/dataset/train/n01440764/n01440764_10548.JPEG', 0), ('/private/rocky/dataset/train/n01440764/n01440764_5003.JPEG', 0), ('/private/rocky/dataset/train/n01440764/n01440764_600.JPEG', 0), ('/private/rocky/dataset/train/n01440764/n01440764_9543.JPEG', 0), ('/private/rocky/dataset/train/n01443537/n01443537_10035.JPEG', 1), ('/private/rocky/dataset/train/n01443537/n01443537_2906.JPEG', 1), ('/private/rocky/dataset/train/n01443537/n01443537_1895.JPEG', 1), ('/private/rocky/dataset/train/n01443537/n01443537_17646.JPEG', 1), ('/private/rocky/dataset/train/n01443537/n01443537_7947.JPEG', 1), ('/private/rocky/dataset/train/n01443537/n01443537_13218.JPEG', 1), ('/private/rocky/dataset/train/n01443537/n01443537_20735.JPEG', 1), ('/private/rocky/dataset/train/n01443537/n01443537_11063.JPEG', 1), ('/private/rocky/dataset/train/n01443537/n01443537_10763.JPEG', 1), ('/private/rocky/dataset/train/n01443537/n01443537_1087.JPEG', 1), ('/private/rocky/dataset/train/n01443537/n01443537_5697.JPEG', 1), ('/private/rocky/dataset/train/n01443537/n01443537_10242.JPEG', 1), ('/private/rocky/dataset/train/n01443537/n01443537_22942.JPEG', 1), 
('/private/rocky/dataset/train/n01443537/n01443537_17307.JPEG', 1), ('/private/rocky/dataset/train/n01443537/n01443537_24539.JPEG', 1), ('/private/rocky/dataset/train/n01443537/n01443537_1103.JPEG', 1), ('/private/rocky/dataset/train/n01484850/n01484850_663.JPEG', 2), ('/private/rocky/dataset/train/n01484850/n01484850_20733.JPEG', 2), ('/private/rocky/dataset/train/n01484850/n01484850_3405.JPEG', 2), ('/private/rocky/dataset/train/n01484850/n01484850_5240.JPEG', 2), ('/private/rocky/dataset/train/n01484850/n01484850_7415.JPEG', 2), ('/private/rocky/dataset/train/n01484850/n01484850_21143.JPEG', 2), ('/private/rocky/dataset/train/n01484850/n01484850_26157.JPEG', 2), ('/private/rocky/dataset/train/n01484850/n01484850_21044.JPEG', 2), ('/private/rocky/dataset/train/n01484850/n01484850_20650.JPEG', 2), ('/private/rocky/dataset/train/n01484850/n01484850_4104.JPEG', 2), ('/private/rocky/dataset/train/n01484850/n01484850_23515.JPEG', 2), ('/private/rocky/dataset/train/n01484850/n01484850_11466.JPEG', 2), ('/private/rocky/dataset/train/n01484850/n01484850_32242.JPEG', 2), ('/private/rocky/dataset/train/n01484850/n01484850_7543.JPEG', 2), ('/private/rocky/dataset/train/n01484850/n01484850_15960.JPEG', 2), ('/private/rocky/dataset/train/n01484850/n01484850_18872.JPEG', 2), ('/private/rocky/dataset/train/n01491361/n01491361_984.JPEG', 3), ('/private/rocky/dataset/train/n01491361/n01491361_462.JPEG', 3), ('/private/rocky/dataset/train/n01491361/n01491361_14786.JPEG', 3), ('/private/rocky/dataset/train/n01491361/n01491361_5187.JPEG', 3), ('/private/rocky/dataset/train/n01491361/n01491361_7360.JPEG', 3), ('/private/rocky/dataset/train/n01491361/n01491361_6401.JPEG', 3), ('/private/rocky/dataset/train/n01491361/n01491361_7459.JPEG', 3), ('/private/rocky/dataset/train/n01491361/n01491361_2911.JPEG', 3), ('/private/rocky/dataset/train/n01491361/n01491361_4716.JPEG', 3), ('/private/rocky/dataset/train/n01491361/n01491361_4424.JPEG', 3), 
('/private/rocky/dataset/train/n01491361/n01491361_8919.JPEG', 3), ('/private/rocky/dataset/train/n01491361/n01491361_733.JPEG', 3), ('/private/rocky/dataset/train/n01491361/n01491361_7416.JPEG', 3), ('/private/rocky/dataset/train/n01491361/n01491361_590.JPEG', 3), ('/private/rocky/dataset/train/n01491361/n01491361_895.JPEG', 3), ('/private/rocky/dataset/train/n01491361/n01491361_1105.JPEG', 3), ('/private/rocky/dataset/train/n01494475/n01494475_6324.JPEG', 4), ('/private/rocky/dataset/train/n01494475/n01494475_21813.JPEG', 4), ('/private/rocky/dataset/train/n01494475/n01494475_4713.JPEG', 4), ('/private/rocky/dataset/train/n01494475/n01494475_4925.JPEG', 4), ('/private/rocky/dataset/train/n01494475/n01494475_17164.JPEG', 4), ('/private/rocky/dataset/train/n01494475/n01494475_3936.JPEG', 4), ('/private/rocky/dataset/train/n01494475/n01494475_7557.JPEG', 4), ('/private/rocky/dataset/train/n01494475/n01494475_4044.JPEG', 4), ('/private/rocky/dataset/train/n01494475/n01494475_1397.JPEG', 4), ('/private/rocky/dataset/train/n01494475/n01494475_5406.JPEG', 4), ('/private/rocky/dataset/train/n01494475/n01494475_6875.JPEG', 4), ('/private/rocky/dataset/train/n01494475/n01494475_14918.JPEG', 4), ('/private/rocky/dataset/train/n01494475/n01494475_16903.JPEG', 4), ('/private/rocky/dataset/train/n01494475/n01494475_7066.JPEG', 4), ('/private/rocky/dataset/train/n01494475/n01494475_4401.JPEG', 4), ('/private/rocky/dataset/train/n01494475/n01494475_3964.JPEG', 4), ('/private/rocky/dataset/train/n01496331/n01496331_4080.JPEG', 5), ('/private/rocky/dataset/train/n01496331/n01496331_11164.JPEG', 5), ('/private/rocky/dataset/train/n01496331/n01496331_3534.JPEG', 5), ('/private/rocky/dataset/train/n01496331/n01496331_11534.JPEG', 5), ('/private/rocky/dataset/train/n01496331/n01496331_26582.JPEG', 5), ('/private/rocky/dataset/train/n01496331/n01496331_9048.JPEG', 5), ('/private/rocky/dataset/train/n01496331/n01496331_8049.JPEG', 5), 
# NOTE(review): this region was a collapsed/mangled dump. Reconstructed into
# valid Python; the huge commented-out debugger data dumps (imgs/class_to_idx
# listings) and the dead commented-out training-free Tip-Adapter + search code
# that previously lived here have been removed.

# ---------------------------------------- data loaders ----------------------------------------
# Deterministic loader for building the cache features (order must be stable so
# targets collected on augment epoch 0 line up with features of later epochs);
# shuffled loader for training the adapter.
train_loader = torch.utils.data.DataLoader(train_images, batch_size=64, num_workers=8, shuffle=False)
train_loader_shuffle = torch.utils.data.DataLoader(train_images, batch_size=64, num_workers=8, shuffle=True)

# ------------------------------------------getting text feature------------------------------------------
print('start getting text features.')
# Class-name prompts -> CLIP text embeddings, one weight column per class.
zeroshot_weights = zeroshot_classifier(imagenet_classes, imagenet_templates, model)
print('finish getting text features. start getting image features')

# ------------------------------------------saving training features------------------------------------------
print('start saving training image features')
if not load_train:
    train_images_targets = []
    train_images_features_agg = []
    with torch.no_grad():
        for augment_idx in range(args.augment_epoch):
            train_images_features = []
            print('Augment time: {:} / {:}'.format(augment_idx, args.augment_epoch))
            for i, (images, target) in enumerate(tqdm(train_loader)):
                images = images.cuda()
                image_features = model.encode_image(images)
                train_images_features.append(image_features)
                # The loader is unshuffled, so targets are identical across
                # augment epochs -- collect them only once.
                if augment_idx == 0:
                    target = target.cuda()
                    train_images_targets.append(target)
            images_features_cat = torch.cat(train_images_features, dim=0).unsqueeze(0)
            train_images_features_agg.append(images_features_cat)
    # Average the features over augmentation epochs, then L2-normalise each row.
    train_images_features_agg = torch.cat(train_images_features_agg, dim=0).mean(dim=0)
    train_images_features_agg /= train_images_features_agg.norm(dim=-1, keepdim=True)
    # (num_samples, dim) -> (dim, num_samples) so `test_features @ agg` yields
    # per-sample similarities against the cache.
    train_images_features_agg = train_images_features_agg.permute(1, 0)
    # One-hot cache values, stored in half precision to match CLIP features.
    train_images_targets = F.one_hot(torch.cat(train_images_targets, dim=0)).half()
    torch.save(train_images_features_agg, train_features_path)
    torch.save(train_images_targets, train_targets_path)
else:
    train_images_features_agg = torch.load(train_features_path)
    train_images_targets = torch.load(train_targets_path)

# ------------------------------------------saving testing features------------------------------------------
# NOTE(review): assumes `loader` iterates the test split -- confirm with caller.
print('start saving testing image features')
if not load_test:
    test_features = []
    test_labels = []
    with torch.no_grad():
        for i, (images, target) in enumerate(tqdm(loader)):
            images = images.cuda()
            target = target.cuda()
            image_features = model.encode_image(images)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            test_features.append(image_features)
            test_labels.append(target)
    test_features = torch.cat(test_features)
    test_labels = torch.cat(test_labels)
    torch.save(test_features, test_features_path)
    torch.save(test_labels, test_targets_path)
else:
    test_features = torch.load(test_features_path)
    test_labels = torch.load(test_targets_path)

# ------------------------------------------ CLIP zero-shot baseline ------------------------------------------
top1, top5, n = 0., 0., 0.
# Cosine similarity of test features against the class text embeddings,
# scaled by CLIP's logit scale of 100.
logits = 100. * test_features @ zeroshot_weights
acc1, acc5 = accuracy(logits, test_labels, topk=(1, 5))
top1 += acc1
top5 += acc5
n += test_features.size(0)
top1 = (top1 / n) * 100
top5 = (top5 / n) * 100
print()
print(f"CLIP Top-1 accuracy: {top1:.2f}, with zero-shot learning")

# ------------------------------------------ Tip-Adapter-F ------------------------------------------
# The adapter's linear1 layer is initialised from the cached training features;
# fine-tuning it is the only training that happens (CLIP stays frozen).
adapter = Weight_Adapter(model, train_features_path, len(imagenet_classes), k_shot).cuda()
optimizer = torch.optim.AdamW(adapter.parameters(), lr=args.lr, eps=1e-4)
# Cosine-annealed LR, stepped once per batch, so T_max is measured in batches.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.train_epoch * len(train_loader_shuffle))

best_top1 = 0
best_epoch = 0
for train_idx in range(args.train_epoch):
    # ---- train one epoch ----
    adapter.train()
    correct_all = 0
    n = 0
    loss_list = []
    print('Train time: {:} / {:}'.format(train_idx, args.train_epoch))
    alpha = args.alpha
    beta = args.beta
    for i, (images, target) in enumerate(tqdm(train_loader_shuffle)):
        images = images.cuda()
        target = target.cuda()
        # Frozen backbone: no gradients through CLIP's image encoder.
        with torch.no_grad():
            image_features = model.encode_image(images)
            image_features /= image_features.norm(dim=-1, keepdim=True)
        new_knowledge = adapter.linear1(image_features)
        # Cache-model logits: A = exp(-alpha * (1 - similarity)) @ one-hot labels.
        new_logits = ((-1) * (alpha - alpha * new_knowledge)).exp() @ (train_images_targets)
        # Blend with the zero-shot CLIP logits via the residual ratio beta.
        logits = 100. * image_features @ zeroshot_weights
        logits = logits + new_logits * beta
        loss = F.cross_entropy(logits, target)
        loss_value = loss.item()
        correct = accuracy(logits, target)
        correct_all += correct[0]
        n += len(logits)
        loss_list.append(loss_value)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()    # updates the adapter weights
        scheduler.step()    # per-batch LR schedule (see T_max above)
    current_lr = scheduler.get_last_lr()[0]
    text = 'LR: {:.6f}, Acc: {:.4f} ({:}/{:}), Loss: {:.4f}'.format(current_lr, correct_all / n, correct_all, n, sum(loss_list)/len(loss_list))
    print(text)

    # ---- eval on the cached test features ----
    adapter.eval()
    top1, top5, n = 0., 0., 0.
    with torch.no_grad():
        test_features = torch.load(test_features_path)
        test_labels = torch.load(test_targets_path)
        test_features_new = test_features
        new_knowledge = adapter.linear1(test_features_new)
        new_logits = ((-1) * (alpha - alpha * new_knowledge)).exp() @ (train_images_targets)
        logits = 100. * test_features_new @ zeroshot_weights
        logits = logits + new_logits * beta
        acc1, acc5 = accuracy(logits, test_labels, topk=(1, 5))
        top1 += acc1
        top5 += acc5
        n += test_features.size(0)
    top1 = (top1 / n) * 100
    top5 = (top5 / n) * 100
    text = f"Testing Top-1 Accuracy: {top1:.2f}"
    print(text)
    print()
    if top1 > best_top1:
        best_top1 = top1
        best_epoch = train_idx
    print(f"Best Testing Top-1 Accuracy: {best_top1:.2f}, at Epoch: {best_epoch}")
    print()

print("Begin to search")
# Grid search over the cache sharpness (alpha) and residual ratio (beta).
alpha_list = [i * (6.0 - 1.0) / 20 + 1 for i in range(20)]
beta_list = [i * (7 - 0.1) / 200 + 0.1 for i in range(200)]
best_top1 = 0

# ------------------------------------------ Search ------------------------------------------
adapter.eval()
# Loop-invariant work hoisted out of the 20x200 grid: the cached test features
# and the adapter's similarities do not depend on alpha/beta, so load and
# compute them once instead of once per grid cell.
with torch.no_grad():
    test_features = torch.load(test_features_path)
    test_labels = torch.load(test_targets_path)
    test_features_new = test_features
    new_knowledge = adapter.linear1(test_features_new)
for alpha in alpha_list:
    for beta in beta_list:
        top1, top5, n = 0., 0., 0.
        # predict
        with torch.no_grad():
            new_logits = ((-1) * (alpha - alpha * new_knowledge)).exp() @ (train_images_targets)
            logits = 100. * test_features_new @ zeroshot_weights
            logits = logits + new_logits * beta
            # measure accuracy
            acc1, acc5 = accuracy(logits, test_labels, topk=(1, 5))
            top1 += acc1
            top5 += acc5
            n += test_features_new.size(0)
        top1 = (top1 / n) * 100
        top5 = (top5 / n) * 100
        if top1 > best_top1:
            text = 'New best setting, alpha: {:.2f}, beta: {:.2f}; Top-1 acc: {:.2f}'.format(alpha, beta, top1)
            print(text)
            best_top1 = top1

print(f"{name}, {k_shot} shot. Best Top-1 {best_top1:.2f}")


if __name__ == '__main__':
    main()