# -*- coding: UTF-8 -*-
"""
-----------------------------------
@Author : Encore
@Date : 2022/6/15
-----------------------------------
"""
import os
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
import numpy as np

from config import args, logger
from model_sparse import Classification
from data import TextClassificationDataset

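# The disabled block below derives run-specific directory names from the
# forward/backward sparsity thresholds and creates folders for dumped
# activations ("output_*") and sparsity ratios ("ratio_*").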
# infix = f"{int(args.fw_threshold * 10)}_{int(args.bw_threshold * 100000)}"
# postfix = "compensate"
#
# logger.info(f"{args.fw_threshold}_{args.bw_threshold}_{postfix}")
# if not os.path.exists(f"output_{infix}_{postfix}"):
#     os.mkdir(f"output_{infix}_{postfix}")
# if not os.path.exists(f"ratio_{infix}_{postfix}"):
#     os.mkdir(f"ratio_{infix}_{postfix}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

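# Both datasets load the tokenizer from args.pretrain_dir; build_label_index()
# runs before read_file(), presumably so the label-to-id map exists when the
# examples are parsed (both methods come from data.TextClassificationDataset).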
train_data = TextClassificationDataset(args.pretrain_dir)
train_data.build_label_index()
train_data.read_file(args.train_path, max_length=args.max_length)
train_loader = DataLoader(train_data, batch_size=args.train_batch_size, collate_fn=train_data.collate, shuffle=True)

dev_data = TextClassificationDataset(args.pretrain_dir)
dev_data.build_label_index()
dev_data.read_file(args.dev_path, max_length=args.max_length)
dev_loader = DataLoader(dev_data, batch_size=args.dev_batch_size, collate_fn=dev_data.collate)

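# One optimizer step per batch: the scheduler raises the LR linearly from 0
# over the first warm_up_pct of steps, then decays it linearly to 0 at train_steps.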
train_steps = len(train_loader) * args.epoch
warmup_steps = int(train_steps * args.warm_up_pct)

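# NOTE: the commented-out svd block below references SvdLinear, which is not
# imported in this file (it presumably lives in model_sparse). The class below
# is a minimal sketch of such a layer, an assumption rather than the repo's
# actual implementation: it factorizes a trained nn.Linear weight W (out x in)
# with a truncated SVD, W ~= U_r diag(S_r) V_r^T, replacing one dense matmul
# with two rank-r ones.
class SvdLinear(nn.Module):
    def __init__(self, linear: nn.Linear, rank: int):
        super().__init__()
        u, s, vh = torch.linalg.svd(linear.weight.data, full_matrices=False)
        self.vh = nn.Parameter(vh[:rank].contiguous())      # V_r^T: (rank, in_features)
        self.us = nn.Parameter(u[:, :rank] * s[:rank])      # U_r diag(S_r): (out_features, rank)
        self.bias = linear.bias

    def forward(self, x):
        # x @ V_r, then @ (U_r diag(S_r))^T + b  ==  x @ W_r^T + b at rank r
        return nn.functional.linear(nn.functional.linear(x, self.vh), self.us, self.bias)
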
classifier = Classification(args.pretrain_dir, train_data.label_nums, len(train_data))
# -------------------------------- svd --------------------------------
# layer = classifier.bert.encoder.layer[5].output.dense
# classifier.bert.encoder.layer[5].output.dense = SvdLinear(layer, rank=args.rank)
# ----------------------------------------------------------------------
classifier.to(device)

# hook = Hook()
# for i, layer in enumerate(classifier.bert.encoder.layer):
#     layer.output.LayerNorm.register_forward_hook(hook.fw_compensation(name=f"block{i}"))
#     layer.output.LayerNorm.register_full_backward_hook(hook.sparse_bw(name=f"block{i}"))
#     layer.output.LayerNorm.register_full_backward_hook(hook.save_bw(name=f"block{i}"))
#     layer.attention.register_forward_pre_hook(hook.mod_pre(name=f"block{i}"))

# for n, p in classifier.named_parameters():
#     print(n)

criterion = nn.CrossEntropyLoss()

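# Standard BERT fine-tuning recipe: L2 weight decay is applied only to weight
# matrices; biases and LayerNorm parameters are excluded (the substring
# 'LayerNorm' matches both LayerNorm.weight and LayerNorm.bias in HuggingFace
# parameter names).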
no_decay = ['bias', 'LayerNorm']
optimizer_grouped_parameters = [
    {'params': [p for n, p in classifier.named_parameters() if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},
    {'params': [p for n, p in classifier.named_parameters() if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0}]

opt = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=1e-8)

scheduler = get_linear_schedule_with_warmup(opt, warmup_steps, train_steps)

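# Training: forward, cross-entropy loss, backward, optimizer step, then a
# scheduler step (the LR schedule advances per batch, not per epoch).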
for e in range(args.epoch):
    classifier.train()
    for i, batch in tqdm(enumerate(train_loader), total=len(train_loader), desc="training", leave=False):
        batch_input_ids = batch["batch_input_ids"].to(device)
        batch_line_ids = batch["batch_line_ids"]  # only used by the disabled line_ids path below
        batch_attention_mask = batch["batch_attention_mask"].to(device)
        y = batch["batch_labels"].to(device)

        y_hat = classifier(input_ids=batch_input_ids,
                           # line_ids=batch_line_ids,
                           attention_mask=batch_attention_mask)

        loss = criterion(y_hat, y)

        opt.zero_grad()
        loss.backward()
        # nn.utils.clip_grad_norm_(classifier.parameters(), max_norm=1.0, norm_type=2)
        opt.step()
        scheduler.step()

        # if i % 400 == 0 and i > 0:
        #     for k, v in hook.fw.items():
        #         np.save(f"output_{infix}_{postfix}/e{e}_{k}_s{i}_fw", v.cpu().numpy())
        #     for k, v in hook.bw.items():
        #         np.save(f"output_{infix}_{postfix}/e{e}_{k}_s{i}_bw", v.cpu().numpy())
        #     with open(f"ratio_{infix}_{postfix}/e{e}_s{i}.txt", 'w', encoding="utf-8") as f:
        #         for k, v in hook.ratio.items():
        #             f.write(f"{k}\t{v}\n")

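    # Per-epoch dev evaluation: accumulate summed loss and correct predictions,
    # then report per-example averages.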
    classifier.eval()
    dev_loss = 0.0
    right_num = 0
    for batch in tqdm(dev_loader, desc="eval dev data", leave=False):
        batch_input_ids = batch["batch_input_ids"].to(device)
        # batch_token_type_ids = batch["batch_token_type_ids"].to(device)
        batch_attention_mask = batch["batch_attention_mask"].to(device)
        y = batch["batch_labels"].to(device)

        with torch.no_grad():
            y_hat = classifier(input_ids=batch_input_ids,
                               attention_mask=batch_attention_mask)
            loss = criterion(y_hat, y)
            pred = torch.argmax(y_hat, dim=-1)
        # `batch` is a dict, so weight the summed loss by the true batch size
        dev_loss += loss.item() * y.size(0)
        right_num += torch.eq(y, pred).sum().item()

    all_num = len(dev_data)
    logger.info(f"epoch: {e} dev loss: {dev_loss / all_num} dev acc: {right_num / all_num}")

# torch.save(classifier.state_dict(), args.save_path)