# -*- coding: UTF-8 -*-
"""
-----------------------------------
@Author : Encore
@Date : 2022/6/15
-----------------------------------
"""
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

from config import args, logger
from model import Classification, SvdLinear, NoResViTLayer
from data import Cifar100Dataset

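# Parse the per-layer rank spec from args.rank: groups are separated by '-',
# and "n*r" expands to n layers of rank r (e.g. "3*0-9*64" -> three layers
# left intact, nine layers factorized at rank 64). Rank 0 means "do not
# modify this layer"; ViT-Base has 12 encoder layers.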
layer_rank = []
rank_group = args.rank.split('-')
for group in rank_group:
    group = group.split('*')
    if len(group) == 1:
        layer_rank.append(int(group[0]))
    else:
        layer_rank.extend([int(group[1])] * int(group[0]))
logger.info(layer_rank)
assert len(layer_rank) == 12

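# Pin training to the GPU selected by args.cuda; the commented line below
# switches everything to CPU for debugging.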
device = torch.device(f"cuda:{args.cuda}")
# device = torch.device("cpu")

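# CIFAR-100 is preprocessed once with read_file() (commented out below) and
# cached to train.dat / test.dat; load() restores the cached examples.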
train_data = Cifar100Dataset(args.pretrain_dir)
# train_data.read_file(args.train_path)
train_data.load("train.dat")
train_loader = DataLoader(train_data, batch_size=args.train_batch_size, collate_fn=train_data.collate, shuffle=True)

dev_data = Cifar100Dataset(args.pretrain_dir)
# dev_data.read_file(args.test_path)
dev_data.load("test.dat")
dev_loader = DataLoader(dev_data, batch_size=args.dev_batch_size, collate_fn=dev_data.collate)

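# Total optimizer steps and the warmup budget for
# get_linear_schedule_with_warmup below: the LR ramps up for the first
# warm_up_pct of training, then decays linearly to zero.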
train_steps = len(train_loader) * args.epoch
warmup_steps = int(train_steps * args.warm_up_pct)

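# Build the ViT classifier, then swap encoder layers for their low-rank
# counterparts wherever the parsed rank is nonzero. NoResViTLayer replaces
# the whole layer; the commented SvdLinear variant factorizes a single
# projection instead.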
classifier = Classification(args.pretrain_dir, train_data.label_nums)
classifier.to(device)

# Alternative low-rank scheme: factorize only the FFN output projection.
# for i, layer in enumerate(classifier.vit.encoder.layer):
#     if layer_rank[i] != 0:
#         linear = layer.output.dense
#         layer.output.dense = SvdLinear(linear, rank=layer_rank[i])

for i, layer in enumerate(classifier.vit.encoder.layer):
    if layer_rank[i] != 0:
        classifier.vit.encoder.layer[i] = NoResViTLayer(layer, rank=layer_rank[i])

criterion = nn.CrossEntropyLoss()

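# Standard fine-tuning recipe: apply weight decay to everything except biases
# and LayerNorm weights. HF ViT names its norms "layernorm_before" /
# "layernorm_after" (lowercase), so the match is done case-insensitively.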
no_decay = ['bias', 'layernorm']
optimizer_grouped_parameters = [
    {'params': [p for n, p in classifier.named_parameters() if not any(nd in n.lower() for nd in no_decay)],
     'weight_decay': 0.01},
    {'params': [p for n, p in classifier.named_parameters() if any(nd in n.lower() for nd in no_decay)],
     'weight_decay': 0.0}]

opt = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=1e-8)

scheduler = get_linear_schedule_with_warmup(opt, warmup_steps, train_steps)

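# One fine-tuning epoch per iteration: forward pass on pixel values,
# cross-entropy loss, AdamW update, and a scheduler step after every batch.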
for e in range(args.epoch):
    classifier.train()
    logger.info(f"epoch {e} start")
    for batch in tqdm(train_loader, desc="training", leave=False):
        batch_pixel_values = batch["batch_pixel_values"].to(device)
        y = batch["batch_labels"].to(device)
        y_hat = classifier(batch_pixel_values)
        loss = criterion(y_hat, y)

        opt.zero_grad()
        loss.backward()
        # nn.utils.clip_grad_norm_(classifier.parameters(), max_norm=1.0, norm_type=2)
        opt.step()
        scheduler.step()
    logger.info(f"epoch {e} end")

    classifier.eval()

    # Optional: evaluate on the training set as well, using the same loop as
    # the dev evaluation below with train_loader and len(train_data).

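    # Per-epoch dev evaluation: accumulate the sum of per-example losses and
    # the number of correct predictions, then normalize by the dataset size.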
    dev_loss = 0.0
    right_num = 0
    for batch in tqdm(dev_loader, desc="eval dev data", leave=False):
        batch_pixel_values = batch["batch_pixel_values"].to(device)
        y = batch["batch_labels"].to(device)

        with torch.no_grad():
            y_hat = classifier(batch_pixel_values)
            loss = criterion(y_hat, y)
        pred = torch.argmax(y_hat, dim=-1)
        # "batch" is a dict, so len(batch) would count its keys; weight the
        # mean loss by the actual number of examples in the batch instead.
        dev_loss += loss.item() * len(y)
        right_num += torch.eq(y, pred).sum().item()

    all_num = len(dev_data)
    logger.info(f"epoch: {e} dev set loss: {dev_loss / all_num} acc: {right_num / all_num}")

# torch.save(classifier.state_dict(), args.save_path)