from __future__ import division
import os
import sys
import time
import argparse
from tqdm import tqdm
import random
import numpy as np

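# Pin the visible GPUs before importing paddle: CUDA_VISIBLE_DEVICES is only
# honored if it is set before the CUDA context is created.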
GPU = [0, 1, 2, 3]
gpus = ','.join([str(i) for i in GPU])
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpus

import paddle
import paddle.nn as nn
import paddle.distributed as dist
from visualdl import LogWriter

from config import config
from city_dataloader import CityScape, get_train_loader
from network import Network

from utils.init_func import init_weight, group_weight
from utils.dir_utils import mkdir
import utils.initializer as init
from engine.lr_policy import WarmUpPolyLR
from engine.engine import Engine
from seg_opr.loss_opr import ProbOhemCrossEntropy2D


# Shorten each epoch to a few iterations when the `debug` env var is set.
is_debug = os.getenv('debug') is not None
-
- '''
- For CutMix
- '''
- import mask_gen
- from custom_collate import SegCollate
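
# BoxMaskGenerator samples rectangular CutMix boxes per image; most of the
# config flags are phrased as negations, hence the `not` on each argument.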
mask_generator = mask_gen.BoxMaskGenerator(
    prop_range=config.cutmix_mask_prop_range,
    n_boxes=config.cutmix_boxmask_n_boxes,
    random_aspect_ratio=not config.cutmix_boxmask_fixed_aspect_ratio,
    prop_by_area=not config.cutmix_boxmask_by_size,
    within_bounds=not config.cutmix_boxmask_outside_bounds,
    invert=not config.cutmix_boxmask_no_invert)

add_mask_params_to_batch = mask_gen.AddMaskParamsToBatch(mask_generator)
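# Plain collate for labeled data and the second unlabeled stream; the masked
# variant attaches CutMix box parameters to batches of the first stream.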
collate_fn = SegCollate()
mask_collate_fn = SegCollate(batch_aug_fn=add_mask_params_to_batch)

parser = argparse.ArgumentParser()
engine = Engine(custom_parser=parser)
args = parser.parse_args()
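
# Seed every RNG. In distributed mode each rank seeds with its own rank so
# that augmentations and CutMix boxes differ across workers.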
seed = config.seed
if engine.distributed:
    seed = engine.local_rank

random.seed(seed)
np.random.seed(seed)
paddle.seed(seed)


# data loader + unsupervised data loader
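# Two independent unlabeled loaders draw from the same source; each pair of
# images they yield is blended with a CutMix box mask below. Only loader 0's
# collate function carries the mask parameters.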
train_loader, train_sampler = get_train_loader(
    engine, CityScape, train_source=config.train_source,
    unsupervised=False, collate_fn=collate_fn)
unsupervised_train_loader_0, unsupervised_train_sampler_0 = get_train_loader(
    engine, CityScape, train_source=config.unsup_source,
    unsupervised=True, collate_fn=mask_collate_fn)
unsupervised_train_loader_1, unsupervised_train_sampler_1 = get_train_loader(
    engine, CityScape, train_source=config.unsup_source,
    unsupervised=True, collate_fn=collate_fn)

if engine.local_rank == 0:
    mkdir(config.log_dir)
    mkdir(config.log_dir_link)
    mkdir(config.tb_dir)
    mkdir(config.snapshot_dir)

# config network and criterion
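# OHEM keeps the hardest pixels (predicted probability below `thresh`) but at
# least `min_kept` per process. The CPS criterion is plain cross entropy over
# the channel axis (axis=1 for NCHW logits) against the other branch's
# pseudo-labels.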
pixel_num = 50000 * config.batch_size // engine.world_size
criterion = ProbOhemCrossEntropy2D(ignore_index=255, thresh=0.7, min_kept=pixel_num)
criterion_cps = nn.CrossEntropyLoss(reduction='mean', ignore_index=255, axis=1)


if engine.distributed:
    BatchNorm2D = nn.SyncBatchNorm
else:
    BatchNorm2D = nn.BatchNorm2D

model = Network(config.num_classes, pretrained_model=config.pretrained_model)
init_weight(model.branch1.head, init.kaiming_normal_,
            BatchNorm2D, config.bn_eps, config.bn_momentum,
            mode='fan_in', nonlinearity='relu')
init_weight(model.branch2.head, init.kaiming_normal_,
            BatchNorm2D, config.bn_eps, config.bn_momentum,
            mode='fan_in', nonlinearity='relu')

base_lr = config.lr

params_list_l = []
params_list_l = group_weight(params_list_l, model.branch1.backbone,
                             BatchNorm2D, base_lr)
params_list_l = group_weight(params_list_l, model.branch1.head,
                             BatchNorm2D, base_lr)

optimizer_l = paddle.optimizer.Momentum(parameters=params_list_l,
                                        learning_rate=base_lr,
                                        momentum=config.momentum,
                                        weight_decay=config.weight_decay)

params_list_r = []
params_list_r = group_weight(params_list_r, model.branch2.backbone,
                             BatchNorm2D, base_lr)
params_list_r = group_weight(params_list_r, model.branch2.head,
                             BatchNorm2D, base_lr)

optimizer_r = paddle.optimizer.Momentum(parameters=params_list_r,
                                        learning_rate=base_lr,
                                        momentum=config.momentum,
                                        weight_decay=config.weight_decay)

# config lr policy
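# WarmUpPolyLR presumably ramps the lr linearly over the first
# `warm_up_epoch` epochs, then decays it as
# lr = base_lr * (1 - iter / total_iteration) ** lr_power.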
total_iteration = config.nepochs * config.niters_per_epoch
lr_policy = WarmUpPolyLR(base_lr, config.lr_power, total_iteration,
                         config.niters_per_epoch * config.warm_up_epoch)

engine.register_state(dataloader=train_loader, model=model,
                      optimizer_l=optimizer_l, optimizer_r=optimizer_r)
if engine.continue_state_object:
    engine.restore_checkpoint()  # also restores both optimizers' state dicts

model.train()

if engine.distributed:
    print('distributed !!')
    paddle.distributed.fleet.init(is_collective=True)
    # fleet returns wrapped (Fleet) optimizers and model for collective mode
    optimizer_l = paddle.distributed.fleet.distributed_optimizer(optimizer_l)
    optimizer_r = paddle.distributed.fleet.distributed_optimizer(optimizer_r)
    model = paddle.distributed.fleet.distributed_model(model)

print('begin train')

for epoch in range(engine.state.epoch, config.nepochs):
    if engine.distributed:
        # keep all three samplers' shuffling in step with the epoch
        train_sampler.set_epoch(epoch)
        unsupervised_train_sampler_0.set_epoch(epoch)
        unsupervised_train_sampler_1.set_epoch(epoch)
    bar_format = '{desc}[{elapsed}<{remaining},{rate_fmt}]'

    if is_debug:
        pbar = tqdm(range(10), file=sys.stdout, bar_format=bar_format)
    else:
        pbar = tqdm(range(config.niters_per_epoch), file=sys.stdout,
                    bar_format=bar_format)

    dataloader = iter(train_loader)
    unsupervised_dataloader_0 = iter(unsupervised_train_loader_0)
    unsupervised_dataloader_1 = iter(unsupervised_train_loader_1)

    sum_loss_sup = 0
    sum_loss_sup_r = 0
    sum_cps = 0

    ''' supervised + cross pseudo supervision training loop '''
    for idx in pbar:
        optimizer_l.clear_grad()
        optimizer_r.clear_grad()
        engine.update_iteration(epoch, idx)
        start_time = time.time()

        minibatch = next(dataloader)
        unsup_minibatch_0 = next(unsupervised_dataloader_0)
        unsup_minibatch_1 = next(unsupervised_dataloader_1)

        imgs = minibatch['data']
        gts = minibatch['label']
        unsup_imgs_0 = unsup_minibatch_0['data']
        unsup_imgs_1 = unsup_minibatch_1['data']
        mask_params = unsup_minibatch_0['mask_params']

        # unsupervised part: CutMix + cross pseudo supervision
        batch_mix_masks = mask_params
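        # Blend the two unlabeled batches: box regions (mask == 1) come from
        # image 1, the rest from image 0. The masks are assumed to broadcast
        # over channels (shape [N, 1, H, W] with values in {0, 1}).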
        unsup_imgs_mixed = unsup_imgs_0 * (1 - batch_mix_masks) + unsup_imgs_1 * batch_mix_masks
        with paddle.no_grad():
            # Teacher step: the network returns both branches' logits in a
            # single call, so one forward per unlabeled batch provides the
            # pseudo-label sources for branch#1 and branch#2 alike.
            # Branch#1's predictions will supervise branch#2 and vice versa.
            logits_u0_tea_1, logits_u0_tea_2 = model(unsup_imgs_0)
            logits_u1_tea_1, logits_u1_tea_2 = model(unsup_imgs_1)
            logits_u0_tea_1 = logits_u0_tea_1.detach()
            logits_u1_tea_1 = logits_u1_tea_1.detach()
            logits_u0_tea_2 = logits_u0_tea_2.detach()
            logits_u1_tea_2 = logits_u1_tea_2.detach()

            # Mix the teacher predictions with the same masks. Whether this
            # is done on logits or probabilities makes no difference, since
            # the mask pixels are exactly 0 or 1.
            logits_cons_tea_1 = logits_u0_tea_1 * (1 - batch_mix_masks) + logits_u1_tea_1 * batch_mix_masks
            ps_label_1 = paddle.argmax(logits_cons_tea_1, axis=1)
            logits_cons_tea_2 = logits_u0_tea_2 * (1 - batch_mix_masks) + logits_u1_tea_2 * batch_mix_masks
            ps_label_2 = paddle.argmax(logits_cons_tea_2, axis=1)
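            # ps_label_1 / ps_label_2 are hard argmax labels with no
            # gradient, so each student is trained against the other
            # branch's frozen prediction.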

        # Get student#1 and student#2 predictions for the mixed image
        logits_cons_stu_1, logits_cons_stu_2 = model(unsup_imgs_mixed)

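        # Cross pseudo supervision: student#1 learns from branch#2's
        # pseudo-labels and vice versa; the summed loss is averaged across
        # ranks and scaled by config.cps_weight.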
        cps_loss = criterion_cps(logits_cons_stu_1, ps_label_2) + criterion_cps(logits_cons_stu_2, ps_label_1)
        dist.all_reduce(cps_loss, dist.ReduceOp.SUM)
        cps_loss = cps_loss / engine.world_size
        cps_loss = cps_loss * config.cps_weight

        # supervised loss on both models
        sup_pred_l, sup_pred_r = model(imgs)

        loss_sup = criterion(sup_pred_l, gts)
        dist.all_reduce(loss_sup, dist.ReduceOp.SUM)
        loss_sup = loss_sup / engine.world_size

        loss_sup_r = criterion(sup_pred_r, gts)
        dist.all_reduce(loss_sup_r, dist.ReduceOp.SUM)
        loss_sup_r = loss_sup_r / engine.world_size

        current_idx = epoch * config.niters_per_epoch + idx
        lr = lr_policy.get_lr(current_idx)

        optimizer_l.set_lr(lr)
        optimizer_r.set_lr(lr)
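
        # One backward pass on the combined loss fills gradients for both
        # branches; each optimizer then updates only its own parameter list.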
        loss = loss_sup + loss_sup_r + cps_loss
        loss.backward()

        optimizer_l.step()
        optimizer_r.step()

        print_str = 'Epoch{}/{}'.format(epoch, config.nepochs) \
                    + ' Iter{}/{}:'.format(idx + 1, config.niters_per_epoch) \
                    + ' lr=%.2e' % optimizer_l.get_lr() \
                    + ' loss_sup=%.2f' % loss_sup.item() \
                    + ' loss_sup_r=%.2f' % loss_sup_r.item() \
                    + ' loss_cps=%.4f' % cps_loss.item()

        sum_loss_sup += loss_sup.item()
        sum_loss_sup_r += loss_sup_r.item()
        sum_cps += cps_loss.item()
        pbar.set_description(print_str, refresh=False)

        end_time = time.time()

    # if engine.local_rank == 0:
    #     logger.add_scalar('train_loss_sup', sum_loss_sup / len(pbar), epoch)
    #     logger.add_scalar('train_loss_sup_r', sum_loss_sup_r / len(pbar), epoch)
    #     logger.add_scalar('train_loss_cps', sum_cps / len(pbar), epoch)

    # Rank 0 saves periodic snapshots after the first third of training and
    # always saves a final checkpoint at the last epoch.
    if (engine.local_rank == 0) and (
            ((epoch > config.nepochs // 3) and (epoch % config.snapshot_iter == 0))
            or (epoch == config.nepochs - 1)):
        engine.save_and_link_checkpoint(config.snapshot_dir,
                                        config.log_dir,
                                        config.log_dir_link)