|
- # -*- coding: UTF-8 -*-
- """
- -----------------------------------
- @Author : Encore
- @Date : 2023/3/24
- -----------------------------------
- """
-
- from tqdm import tqdm
-
- import torch
- import torch.nn as nn
- from torch.utils.data.dataloader import DataLoader
- from torch.optim import AdamW
- from transformers import get_linear_schedule_with_warmup
- from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
- import sacrebleu
-
- from config import args, logger
- from data import TranslationDataset
-
# Target GPU is chosen by the --cuda index from the project config.
device = torch.device(f"cuda:{args.cuda}")

# Load the pretrained mBART-50 checkpoint and its tokenizer, configured
# for Chinese source text and English targets.
tokenizer = MBart50TokenizerFast.from_pretrained(
    args.pretrain_dir, src_lang="zh_CN", tgt_lang="en_XX"
)
model = MBartForConditionalGeneration.from_pretrained(args.pretrain_dir)

# Parallel zh/en training corpus; the dataset also supplies the collator
# that batches/pads examples for the DataLoader.
train_data = TranslationDataset(tokenizer)
train_data.read_file(args.src_path, args.tgt_path)
train_loader = DataLoader(
    train_data,
    batch_size=args.train_batch_size,
    shuffle=True,
    collate_fn=train_data.data_collator,
)
-
# Total number of optimisation steps across all epochs; the LR schedule
# warms up linearly for the first `warm_up_pct` fraction of them, then
# decays linearly to zero.
train_steps = len(train_loader) * args.epoch
# FIX: get_linear_schedule_with_warmup documents an integer step count;
# the product with a float percentage must be truncated to int.
warmup_steps = int(train_steps * args.warm_up_pct)

# Standard transformer fine-tuning recipe: exempt biases and LayerNorm
# weights from weight decay. Build both groups in a single pass over the
# parameters instead of iterating named_parameters() twice.
no_decay = ['bias', 'LayerNorm']
decay_params = []
no_decay_params = []
for name, param in model.named_parameters():
    if any(nd in name for nd in no_decay):
        no_decay_params.append(param)
    else:
        decay_params.append(param)
optimizer_grouped_parameters = [
    {'params': decay_params, 'weight_decay': 0.01},
    {'params': no_decay_params, 'weight_decay': 0.0},
]

opt = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=1e-8)
scheduler = get_linear_schedule_with_warmup(opt, warmup_steps, train_steps)
-
def _read_nonempty_lines(path):
    """Return the stripped, non-empty lines of a UTF-8 text file as a list."""
    with open(path, encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


model.to(device)
for e in range(args.epoch):
    # ---- training pass over the parallel corpus ----
    model.train()
    # Iterate the loader directly: the old index from tqdm(enumerate(...))
    # was unused and hid the progress-bar total.
    for batch in tqdm(train_loader, desc="training", leave=False):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        # The model computes the seq2seq cross-entropy internally when
        # `labels` is supplied.
        loss = model(input_ids=input_ids,
                     attention_mask=attention_mask,
                     labels=labels).loss

        opt.zero_grad()
        loss.backward()
        opt.step()
        scheduler.step()

    # ---- dev-set evaluation: generate a translation per source line ----
    model.eval()
    src = _read_nonempty_lines(args.dev_src_path)
    tgt = _read_nonempty_lines(args.dev_tgt_path)
    hyp = []

    for sentence in tqdm(src, desc="eval", leave=False):
        encoded_zh = tokenizer(sentence, return_tensors="pt")
        # BUG FIX: move inputs to the model's device, not hard-coded
        # "cuda" (device 0) — the old code crashed whenever args.cuda != 0.
        encoded_zh = {k: v.to(device) for k, v in encoded_zh.items()}

        with torch.no_grad():
            generated_tokens = model.generate(
                **encoded_zh,
                # Force English as the first generated token, per the
                # mBART-50 many-to-many decoding convention.
                forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"],
                max_new_tokens=1024,
            )
        hyp.append(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0])

    # Corpus-level BLEU against the single dev reference.
    bleu = sacrebleu.corpus_bleu(hyp, [tgt]).score

    logger.info(f"epoch:{e} bleu:{bleu}")

    # NOTE(review): args.epoch is a constant, so every epoch overwrites the
    # same "ft_<epoch>" directory — only the last epoch survives. If
    # per-epoch checkpoints are wanted, this should use `e` instead; kept
    # as-is to preserve the existing output path.
    model.save_pretrained(f"ft_{args.epoch}")
|