# -*- coding: UTF-8 -*-
"""
-----------------------------------
@Author : Encore
@Date : 2022/6/15
-----------------------------------
"""
import os
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from torch.optim import AdamW
import torch.distributed.autograd as dist_autograd
from torch.distributed.optim import DistributedOptimizer
import torch.distributed.rpc as rpc
from torch.profiler import profile, ProfilerActivity

from config import args, logger
from model import DistBertDecomposition
from data import TextClassificationDataset

# Map each RPC worker name to its compute device: the edge worker runs on
# CPU, the cloud worker on GPU.
devices = {
    "edge": "cpu",
    "cloud": "cuda"
}
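
# PyTorch RPC needs a rendezvous endpoint before init_rpc can run. These
# single-machine defaults are assumptions for local testing; point
# MASTER_ADDR/MASTER_PORT at the actual master in a real edge/cloud setup.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")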

# Register this process as the "edge" worker (rank 0). init_rpc blocks
# until the "cloud" worker (rank 1) joins, so that process must be
# launched separately.
rpc.init_rpc("edge", rank=0, world_size=2)

# Build the label index, read the training file, and batch examples with
# the dataset's own collate function.
train_data = TextClassificationDataset(args.edge_pretrain_dir)
train_data.build_label_index()
train_data.read_file(args.train_path, max_length=args.max_length)
train_loader = DataLoader(train_data, batch_size=args.train_batch_size, collate_fn=train_data.collate, shuffle=True)
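
# logger is imported but otherwise unused; assuming it is a standard
# logging.Logger from config, a dataset-size sanity check is cheap here.
logger.info(f"loaded {len(train_data)} training examples, {train_data.label_nums} labels")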

# The decomposed BERT classifier is partitioned between the edge and cloud
# workers (per the devices map), with the split point and decomposition
# rank taken from args.split and args.rank.
classifier = DistBertDecomposition(args.cloud_pretrain_dir,
                                   args.edge_pretrain_dir,
                                   args.split,
                                   args.rank,
                                   train_data.label_nums,
                                   devices)

criterion = nn.CrossEntropyLoss()

# DistributedOptimizer takes RRefs to parameters living on both workers and
# applies AdamW updates to each of them remotely on step().
opt = DistributedOptimizer(AdamW, classifier.parameter_rrefs(), lr=args.learning_rate)

classifier.train()

for i, batch in tqdm(enumerate(train_loader), desc="training", leave=False):
    batch_input_ids = batch["batch_input_ids"]
    batch_token_type_ids = batch["batch_token_type_ids"]
    batch_attention_mask = batch["batch_attention_mask"]
    y = batch["batch_labels"].to(devices["edge"])

    # Profile one forward/backward/step. The forward pass must run inside
    # the dist_autograd context so backward can cross the RPC boundary;
    # gradients are accumulated in the context rather than in param.grad,
    # which is why opt.step() takes the context_id.
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        with dist_autograd.context() as context_id:
            y_hat = classifier(input_ids=batch_input_ids, token_type_ids=batch_token_type_ids,
                               attention_mask=batch_attention_mask).to(devices["edge"])
            loss = criterion(y_hat, y)
            dist_autograd.backward(context_id, [loss])
            opt.step(context_id)
    prof.export_chrome_trace("trace.json")
    break  # profile a single step only
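
# Beyond the Chrome trace, the same run can be summarized textually; this is
# a sketch using the standard torch.profiler key_averages() API.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))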

# torch.save(classifier.state_dict(), args.save_path)

# Tear down RPC; blocks until both workers reach shutdown.
rpc.shutdown()