|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """
- model train
- """
- import os
- import sys
- import time
- from src.options.train_options import TrainOptions
- from src.data.ade20k_dataset import Ade20kDataset, get_transform
- from src.models.netG import SPADEGenerator
- from src.utils.lr_schedule import dynamic_lr
- from src.models.netD import MultiscaleDiscriminator
- from src.models.loss import GANLoss, VGGLoss
- from src.utils.adam import Adam
- from src.models.cells import GenTrainOneStepCell, DisTrainOneStepCell, GenWithLossCell, DisWithLossCell
- import mindspore.dataset as ds
- from mindspore import context
- from mindspore import Tensor
- import mindspore as ms
- import mindspore.nn as nn
- from mindspore.train.serialization import save_checkpoint
- from mindspore.communication.management import init
- from mindspore.context import ParallelMode
- import cv2
-
-
def unzip(zip_file, save_dir):
    """Extract *zip_file* into *save_dir*, printing progress roughly per percent.

    Args:
        zip_file (str): path to the ``.zip`` archive.
        save_dir (str): directory the archive members are extracted into.

    If *zip_file* is not a valid zip archive, prints a message and returns
    without raising.
    """
    import zipfile
    s_time = time.time()
    # Guard clause: bail out early on non-zip input (matches original message).
    if not zipfile.is_zipfile(zip_file):
        print("This is not zip.")
        return
    # Context manager guarantees the archive handle is closed even if
    # extraction raises — the original leaked the ZipFile object.
    with zipfile.ZipFile(zip_file, 'r') as fz:
        names = fz.namelist()
        data_num = len(names)
        print("Extract Start...")
        print("unzip file num: {}".format(data_num))
        # Report about once per percent; for small archives, every file.
        data_print = int(data_num / 100) if data_num > 100 else 1
        for index, file in enumerate(names):
            if index % data_print == 0:
                print("unzip percent: {}%".format(int(index * 100 / data_num)), flush=True)
            fz.extract(file, save_dir)
    print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
                                         int(int(time.time() - s_time) % 60)))
    print("Extract Done.")
-
-
-
# ---------------------------------------------------------------------------
# Option parsing and execution-context setup.
# ---------------------------------------------------------------------------
opt = TrainOptions().parse()
# Echo the full command line for log traceability.
print(' '.join(sys.argv))
# Local cache directory that ModelArts jobs stage their data into.
local_data_path = '/cache/data'
if opt.distribute:
    # Distributed (multi-device) Ascend training: device identity comes from
    # the environment variables exported by the launch script.
    device_id = int(os.getenv('DEVICE_ID'))
    device_num = int(os.getenv('RANK_SIZE'))
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
    context.set_context(device_id=device_id)
    context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                      gradients_mean=True, parameter_broadcast=True)
    init()
    if opt.run_modelarts:
        import moxing as mox

        # Give each device its own staging sub-directory and rebase the VGG
        # checkpoint path and dataset root onto it.
        local_data_path = os.path.join(local_data_path, str(device_id))
        opt.vgg_ckpt_path = os.path.join(local_data_path, opt.vgg_ckpt_path)
        opt.dataroot = os.path.join(local_data_path, opt.dataroot)
        # NOTE(review): the archive is extracted here, but copy_parallel only
        # downloads the data a few lines below — confirm the zip already
        # exists locally at this point (the ordering looks suspicious).
        unzip(os.path.join(opt.dataroot+'.zip'), local_data_path)
        print('local_data_path:', local_data_path)
        print('vgg_ckpt_path:', opt.vgg_ckpt_path)
        print('dataroot:', opt.dataroot)
        print('Download data.')
        mox.file.copy_parallel(src_url=opt.data_url, dst_url=local_data_path)
else:
    if opt.run_modelarts:
        # Single-device ModelArts run: graph mode on the Ascend device id
        # provided by the platform.
        context.set_context(
            mode=context.GRAPH_MODE,
            # mode=context.PYNATIVE_MODE,
            device_target="Ascend", device_id=int(os.getenv("DEVICE_ID")))
        import moxing as mox

        local_data_path = os.path.join(local_data_path, str(opt.id))
        opt.vgg_ckpt_path = os.path.join(local_data_path, opt.vgg_ckpt_path)
        opt.dataroot = os.path.join(local_data_path, opt.dataroot)
        print('local_data_path:', local_data_path)
        print('vgg_ckpt_path:', opt.vgg_ckpt_path)
        print('dataroot:', opt.dataroot)
        print('Download data.')
        mox.file.copy_parallel(src_url=opt.data_url, dst_url=local_data_path)
    else:
        # Plain single-device run outside ModelArts.
        context.set_context(device_target="Ascend", mode=context.GRAPH_MODE, device_id=opt.id)

# NOTE(review): when run_modelarts is set, this block repeats the path
# rebasing and download already performed in both branches above, so
# vgg_ckpt_path/dataroot get the cache prefix applied twice and the data is
# downloaded a second time — looks like duplicated code; verify and remove.
if opt.run_modelarts:
    import moxing as mox

    local_data_path = os.path.join(local_data_path, str(opt.id))
    opt.vgg_ckpt_path = os.path.join(local_data_path, opt.vgg_ckpt_path)
    opt.dataroot = os.path.join(local_data_path, opt.dataroot)
    print('local_data_path:', local_data_path)
    print('vgg_ckpt_path:', opt.vgg_ckpt_path)
    print('dataroot:', opt.dataroot)
    print('Download data.')
    mox.file.copy_parallel(src_url=opt.data_url, dst_url=local_data_path)
-
# ---------------------------------------------------------------------------
# Dataset pipeline: wrap the ADE20K source in a GeneratorDataset, apply the
# label/image transforms, and batch.
# ---------------------------------------------------------------------------
data = Ade20kDataset(opt)
print("dataset [%s] of size %d was created" %
      (type(data).__name__, len(data)))
if opt.distribute:
    # Shard the dataset across devices for data-parallel training; device_num
    # and device_id were read from the environment in the setup block above.
    dataset = ds.GeneratorDataset(data,
                                  ['label', 'image', 'flip_label', 'flip_image', 'crop_pos_label', 'crop_pos_image'],
                                  shuffle=not opt.serial_batches, num_parallel_workers=device_num,
                                  num_shards=device_num, shard_id=device_id)
else:
    dataset = ds.GeneratorDataset(data,
                                  ['label', 'image', 'flip_label', 'flip_image', 'crop_pos_label', 'crop_pos_image'],
                                  shuffle=not opt.serial_batches)
    #shuffle=False)
# Segmentation labels use nearest-neighbour resizing (class ids must not be
# interpolated) without normalization, and are converted to one-hot maps.
transform_label = get_transform(opt, method=cv2.INTER_NEAREST, normalize=False, onehot=True)
transform_image = get_transform(opt)
# NOTE(review): each map() consumes the base + flip + crop columns and emits
# fewer output columns — presumably the transform folds the augmentation
# inputs into one output; verify against get_transform's return signature.
dataset = dataset.map(operations=transform_label, input_columns=["label", "flip_label", 'crop_pos_label'],
                      output_columns=["label", "seg_label"],
                      column_order=["label", "seg_label", 'image', 'flip_image', 'crop_pos_image'])
dataset = dataset.map(operations=transform_image, input_columns=["image", "flip_image", 'crop_pos_image'],
                      output_columns=["image"],
                      column_order=["label", "image", "seg_label"])
# Drop the last partial batch during training so every step has a full batch.
dataset = dataset.batch(opt.batchSize, drop_remainder=opt.isTrain)
batch_dataset_size = dataset.get_dataset_size()
# ---------------------------------------------------------------------------
# Networks, learning-rate schedules, optimizers, and loss/training cells.
# ---------------------------------------------------------------------------
netG = SPADEGenerator(opt)
netD = MultiscaleDiscriminator(opt)
# Per-step learning-rate schedules for generator and discriminator.
G_lr = Tensor(dynamic_lr(opt, batch_dataset_size, opt.G_lr),
              ms.float32)
D_lr = Tensor(dynamic_lr(opt, batch_dataset_size, opt.D_lr),
              ms.float32)
# NOTE(review): parameters whose names contain 'param_free_norm' (G) or 'bn'
# (D) are excluded from the optimizers entirely — presumably these
# normalization parameters should not receive gradient updates; confirm.
netG_params = list(filter(lambda x: 'param_free_norm' not in x.name, netG.trainable_params()))
netD_params = list(filter(lambda x: 'bn' not in x.name, netD.trainable_params()))
# Project-local Adam wrapper; beta1=0.0 as configured for this GAN.
optimizer_G = Adam(netG_params, learning_rate=G_lr, beta1=0.0, beta2=0.999)
optimizer_D = Adam(netD_params, learning_rate=D_lr, beta1=0.0, beta2=0.999)
# Losses: adversarial loss, an L1 loss (used by GenWithLossCell — presumably
# for feature matching; confirm there), and the VGG perceptual loss.
GLoss = GANLoss(opt.gan_mode, opt=opt)
FLoss = nn.L1Loss()
VggLoss = VGGLoss(opt)
# Bundle each network with its criterion, then wrap in one-step train cells.
netG_with_criterion = GenWithLossCell(opt, netG, netD, GLoss, FLoss, VggLoss)
netD_with_criterion = DisWithLossCell(netG, netD, GLoss)
netG_train = GenTrainOneStepCell(netG_with_criterion, optimizer_G)
netD_train = DisTrainOneStepCell(netD_with_criterion, optimizer_D)
netG_train.set_train()
netD_train.set_train()
# ---------------------------------------------------------------------------
# Training loop: one generator step and one discriminator step per batch,
# then export the final checkpoints (and upload them on ModelArts).
# ---------------------------------------------------------------------------
for epoch in range(opt.total_epoch):
    start_epoch_time = time.time()
    # `batch` (renamed from `data`) avoids shadowing the module-level
    # Ade20kDataset instance created above.
    for i, batch in enumerate(dataset):
        # Column order after the map() pipeline is [label, image, seg_label].
        input_semantics, real_image = batch[0], batch[1]
        loss_G = netG_train(input_semantics, real_image)
        loss_D = netD_train(input_semantics, real_image)
        print('[%d/%d][%d/%d]: Loss_D: %f Loss_G: %f'
              % (epoch, opt.total_epoch, i + 1, batch_dataset_size,
                 loss_D.asnumpy(), loss_G.asnumpy()))
    # batch_dataset_size already caches dataset.get_dataset_size().
    speed = (time.time() - start_epoch_time) * 1000 / batch_dataset_size
    print('average speed: {}ms/step'.format(speed))

if opt.run_modelarts:
    # exist_ok avoids the FileExistsError that os.mkdir raised on re-runs.
    ckpt_dir = '{0}/ckpt'.format(local_data_path)
    os.makedirs(ckpt_dir, exist_ok=True)
    save_checkpoint(netG, '{0}/netG_epoch_{1}.ckpt'.format(ckpt_dir, opt.total_epoch))
    save_checkpoint(netD, '{0}/netD_epoch_{1}.ckpt'.format(ckpt_dir, opt.total_epoch))
    mox.file.copy_parallel(os.path.join(local_data_path, 'ckpt'), opt.train_url)
else:
    # Create the local checkpoint directory instead of crashing when absent.
    os.makedirs('./checkpoints', exist_ok=True)
    save_checkpoint(netG, './checkpoints/netG_epoch_{0}.ckpt'.format(opt.total_epoch))
    save_checkpoint(netD, './checkpoints/netD_epoch_{0}.ckpt'.format(opt.total_epoch))
|