
train on ModelArts

master · root · 2 months ago · commit 619223c7bb
14 changed files with 66011 additions and 49878 deletions
 1. +4      -1      .gitignore
 2. +65443  -49782  analyze_fail.dat
 3. +14     -12     eval.py
 4. +3      -3      extract_subimages.py
 5. +1      -1      src/config/config.py
 6. +2      -2      src/dataset/dataset_DIV2K.py
 7. +2      -6      src/model/RRDB_Net.py
 8. +1      -1      src/model/cell.py
 9. +4      -16     src/model/discriminator_net.py
10. +2      -2      src/model/loss.py
11. +0      -3      src/utils/eval_util.py
12. +100    -49     train.py
13. +266    -0      train_ModelArts.py
14. +169    -0      train_psnr_ModelArts.py

+4 -1   .gitignore

@@ -110,4 +110,7 @@ checkpoints/
module_test.py
somas_meta/
analyze_fail.dat
src/model/vgg19_ImageNet.ckpt
src/model/vgg19_ImageNet.ckpt
DIV2K.zip
src/model/psnr-1_31523.ckpt
analyze_fail.dat

+65443 -49782   analyze_fail.dat
File diff suppressed because it is too large.


+14 -12   eval.py

@@ -8,6 +8,7 @@ import numpy as np
import cv2
import mindspore.nn as nn
from PIL import Image
from mindspore.nn import PSNR,SSIM
from mindspore import Tensor, context
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model
@@ -91,22 +92,23 @@ def test():
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(eval_net, param_dict)
eval_net.set_train(False)

test_data_iter = dataset.create_dict_iter()
ssim = nn.SSIM()
psnr = nn.PSNR()
test_data_iter = dataset.create_dict_iterator(output_numpy=True)

for i, sample in enumerate(test_data_iter):
lr = sample['inputs']
real_hr = sample['target']
gen_hr = eval_net(lr)
bic_hr = imresize_np((lr*255).asnumpy(), 4).astype(np.uint8)
real_hr = (real_hr*255).asnumpy()
gen_hr = (gen_hr*255).asnumpy()
print(str_format.format(
calculate_psnr(rgb2ycbcr(bic_img), rgb2ycbcr(real_hr)),
calculate_ssim(rgb2ycbcr(bic_img), rgb2ycbcr(real_hr)),
calculate_psnr(rgb2ycbcr(gen_hr), rgb2ycbcr(real_hr)),
calculate_ssim(rgb2ycbcr(gen_hr), rgb2ycbcr(real_hr))))
# use MindSpore's bicubic interpolation result here
bic_hr = None
psnr_bic = psnr(gen_hr,bic_hr)
psnr_real = psnr(gen_hr,real_hr)
ssim_bic = ssim(gen_hr,bic_hr)
ssim_real = ssim(gen_hr,real_hr)
print(psnr_bic,psnr_real,ssim_bic,ssim_real)
result_img_path = os.path.join(args_opt.results_path + "DIV2K", 'Bic_SR_HR_' + str(i))
results_img = np.concatenate((bic_img, sr_img, hr_img), 1)
cv2.imwrite(result_img_path, results_img)
if i%10 == 0:
results_img = np.concatenate((bic_img.asnumpy(), sr_img.asnumpy(), hr_img.asnumpy()), 1)
cv2.imwrite(result_img_path, results_img)
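
The rewritten evaluation loop drops the hand-rolled calculate_psnr/calculate_ssim helpers in favour of mindspore.nn.PSNR and mindspore.nn.SSIM. A minimal sketch of how these metric cells are applied, not part of the commit and assuming NCHW float32 tensors scaled to [0, 1] (tensor names are illustrative):

import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor

psnr = nn.PSNR(max_val=1.0)
ssim = nn.SSIM(max_val=1.0)

gen_hr = Tensor(np.random.rand(1, 3, 128, 128), mindspore.float32)   # network output
real_hr = Tensor(np.random.rand(1, 3, 128, 128), mindspore.float32)  # ground-truth HR
print(psnr(gen_hr, real_hr), ssim(gen_hr, real_hr))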

+3 -3   extract_subimages.py

@@ -25,9 +25,9 @@ def main():
if mode == 'single':
opt['input_folder'] = './data/DIV2K/DIV2K_train_HR'
opt['save_folder'] = './data/DIV2K/DIV2K800_sub'
opt['crop_sz'] = 512 # the size of each sub-image
opt['step'] = 256 # step of the sliding crop window
opt['thres_sz'] = 52 # size threshold
opt['crop_sz'] = 128 # the size of each sub-image
opt['step'] = 64 # step of the sliding crop window
opt['thres_sz'] = 12 # size threshold
extract_signle(opt)
elif mode == 'pair':
GT_folder = './data/DIV2K/DIV2K_train_HR'


+1 -1   src/config/config.py

@@ -11,7 +11,7 @@ ESRGAN_config = {
"D_nf": 64,
# training setting
"niter": 400000,
"lr_G": [1e-4, 5e-5, 2e-5, 1e-5],
"lr_G": [2e-4, 1e-4, 5e-5, 2e-5],
"lr_D": [1e-4, 5e-5, 2e-5, 1e-5],
"lr_steps": [50000, 100000, 200000, 300000],


+2 -2   src/dataset/dataset_DIV2K.py

@@ -43,8 +43,8 @@ def augment(img_in, img_tar, flip_h=True, rot=True):
img_in = img_in.rotate(180)
img_tar = img_tar.rotate(180)
info_aug['trans'] = True
img_in = img_in.resize((128,128),Image.BILINEAR)
img_tar = img_tar.resize((512,512),Image.BILINEAR)
img_in = img_in.resize((32,32),Image.BILINEAR)
img_tar = img_tar.resize((128,128),Image.BILINEAR)
return img_in, img_tar, info_aug
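
The augment step now emits 32x32 LR inputs paired with 128x128 HR targets, keeping the x4 scale and matching the 128-pixel sub-images produced by extract_subimages.py above. A tiny illustration, not part of the commit:

from PIL import Image

hr = Image.new("RGB", (128, 128))          # HR target patch
lr = hr.resize((32, 32), Image.BILINEAR)   # x4-smaller LR input
print(lr.size, hr.size)                    # (32, 32) (128, 128)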


+2 -6   src/model/RRDB_Net.py

@@ -46,9 +46,8 @@ class ResidualDenseBlock_5C(nn.Cell):
class RRDB(nn.Cell):
"""Residual in Residual Dense Block"""
def __init__(self, nf, gc=32, res_beta=0.2):
def __init__(self, nf, gc=32):
super(RRDB, self).__init__()
self.res_beta = res_beta
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
@@ -66,10 +65,7 @@ class RRDBNet(nn.Cell):
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
self.conv_first = nn.Conv2d(
in_nc,
nf,
3,
1,
in_nc,nf,3,1,
padding=1,
has_bias=True,
pad_mode="pad",


+1 -1   src/model/cell.py

@@ -67,7 +67,7 @@ class GeneratorLossCell(nn.Cell):
generator_loss = (
5e-3 * adversarial_loss
+ 1.0 * perceptual_loss
+ 1e-2 * content_loss
+ 1e-1 * content_loss
)
return (fake_hr, generator_loss, content_loss, perceptual_loss, adversarial_loss)



+4 -16   src/model/discriminator_net.py

@@ -1,11 +1,9 @@
import mindspore.nn as nn
import mindspore
import numpy as np
from mindspore import Tensor
class VGGStyleDiscriminator512(nn.Cell):
"""VGG style discriminator with input size 512 x 512.
class VGGStyleDiscriminator128(nn.Cell):
"""VGG style discriminator with input size 128 x 128.
Args:
num_in_ch (int): Channel number of inputs. Default: 3.
num_feat (int): Channel number of base intermediate features.
@@ -13,7 +11,7 @@ class VGGStyleDiscriminator512(nn.Cell):
"""
def __init__(self, num_in_ch, num_feat):
super(VGGStyleDiscriminator512, self).__init__()
super(VGGStyleDiscriminator128, self).__init__()
self.conv0_0 = nn.Conv2d(
num_in_ch, num_feat, 3, 1, padding=1, has_bias=True, pad_mode="pad"
@@ -50,15 +48,8 @@ class VGGStyleDiscriminator512(nn.Cell):
)
self.bn3_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
self.conv4_0 = nn.Conv2d(
num_feat * 8, num_feat * 8, 3, 1, padding=1, has_bias=False, pad_mode="pad"
)
self.bn4_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
self.conv4_1 = nn.Conv2d(
num_feat * 8, num_feat * 8, 4, 2, padding=1, has_bias=False, pad_mode="pad"
)
self.bn4_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
self.linear1 = nn.Dense(num_feat * 8 * 4 * 4* 16, 100)
self.linear1 = nn.Dense(num_feat * 8 * 4 * 4 * 4, 100)
self.linear2 = nn.Dense(100, 1)
self.lrelu = nn.LeakyReLU(0.2)
self.flatten = nn.Flatten()
@@ -83,12 +74,9 @@ class VGGStyleDiscriminator512(nn.Cell):
feat = self.lrelu(self.bn3_0(self.conv3_0(feat)))
feat = self.lrelu(self.bn3_1(self.conv3_1(feat))) # output spatial size: (8, 8)
feat = self.lrelu(self.bn4_0(self.conv4_0(feat)))
feat = self.lrelu(self.bn4_1(self.conv4_1(feat))) # output spatial size: (4, 4)
feat = self.flatten(feat)
feat = self.lrelu(self.linear1(feat))
out = self.linear2(feat)
return out
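
Renaming the discriminator to VGGStyleDiscriminator128 removes the conv4_0/conv4_1 stage and shrinks the first Dense layer accordingly. A quick sanity check of the new flatten size, not part of the commit, assuming the four remaining stride-2 layers (conv0_1 through conv3_1) each halve the spatial size (128 -> 64 -> 32 -> 16 -> 8):

num_feat = 64                      # "D_nf" in src/config/config.py
spatial = 128 // (2 ** 4)          # four stride-2 stages -> 8
flatten_dim = num_feat * 8 * spatial * spatial
assert flatten_dim == num_feat * 8 * 4 * 4 * 4   # the new linear1 input above
print(flatten_dim)                               # 32768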

+2 -2   src/model/loss.py

@@ -2,7 +2,7 @@ import mindspore
from mindspore import nn as nn
from src.model.VGG import vgg19
import mindspore.ops.functional as F
from mindspore.train.serialization import load_checkpoint, load_param_into_net
class PerceptualLoss(nn.Cell):
# content loss
def __init__(self,pretrained_path):
@@ -10,7 +10,7 @@ class PerceptualLoss(nn.Cell):
vgg = vgg19()
loss_network = vgg.layers[:35]
param_dict = load_checkpoint(pretrained_path)
load_param_into_net(vgg_model,param_dict)
load_param_into_net(vgg,param_dict)
for l in loss_network:
l.requires_grad = False
self.loss_network = loss_network


+0 -3   src/utils/eval_util.py

@@ -5,9 +5,6 @@ import numpy as np
from absl import logging
import numpy as np
from src.utils.matlab_functions import bgr2ycbcr
def reorder_image(img, input_order="HWC"):
"""Reorder images to 'HWC' order.


+100 -49   train.py

@@ -3,9 +3,11 @@ import os
import argparse
import ast
import numpy as np
from PIL import Image
import mindspore
import mindspore.common.dtype as mstype
from mindspore import nn
from mindspore.train.serialization import load_param_into_net, load_checkpoint
from mindspore.ops import operations as ops
from mindspore import Tensor, context
from mindspore.common import set_seed
@@ -19,41 +21,69 @@ from mindspore.train.callback import (
)

from src.model.RRDB_Net import RRDBNet
from src.model.discriminator_net import VGGStyleDiscriminator512
from src.model.discriminator_net import VGGStyleDiscriminator128
from src.model.cell import GeneratorLossCell, DiscriminatorLossCell, TrainOneStepCellDis, TrainOneStepCellGen
from src.config.config import ESRGAN_config
from src.dataset.dataset_DIV2K import get_dataset_DIV2K

# save image
def save_image(img, img_path):
mul = ops.Mul()
add = ops.Add()
if isinstance(img, Tensor):
img = mul(img, 0.5)
img = add(img, 0.5)
img = img.asnumpy().astype(np.uint8).transpose((0, 2, 3, 1))

elif not isinstance(img, np.ndarray):
raise ValueError("img should be Tensor or numpy array, but get {}".format(type(img)))

IMAGE_SIZE = 64 # Image size
IMAGE_ROW = 8 # Row num
IMAGE_COLUMN = 8 # Column num
PADDING = 2 #Interval of small pictures
to_image = Image.new('RGB', (IMAGE_COLUMN * IMAGE_SIZE + PADDING * (IMAGE_COLUMN + 1),
IMAGE_ROW * IMAGE_SIZE + PADDING * (IMAGE_ROW + 1))) # create a new picture
# cycle
i = 0
for y in range(1, IMAGE_ROW + 1):
for x in range(1, IMAGE_COLUMN + 1):
from_image = Image.fromarray(img[i])
to_image.paste(from_image, ((x - 1) * IMAGE_SIZE + PADDING * x, (y - 1) * IMAGE_SIZE + PADDING * y))
i = i + 1

to_image.save(img_path) #save


def parse_args():
parser = argparse.ArgumentParser("ESRGAN")
parser.add_argument('--device_target', type=str,
default="Ascend", help='Platform')
parser.add_argument('--device_id', type=int,
default=6, help='device_id')
default=3, help='device_id')
parser.add_argument(
"--aug", type=bool, default=True, help="Use augement for dataset"
)
parser.add_argument('--data_dir', type=str,
default=None, help='Dataset path')
parser.add_argument("--batch_size", type=int, default=4, help="batch_size")
parser.add_argument("--batch_size", type=int, default=16, help="batch_size")
parser.add_argument("--epoch_size", type=int,
default=20, help="epoch_size")
parser.add_argument("--rank", type=int, default=0,
parser.add_argument('--Giters', type=int, default=5, help='number of G iters per each D iter')
parser.add_argument("--rank", type=int, default=1,
help="local rank of distributed")
parser.add_argument(
"--group_size", type=int, default=1, help="world size of distributed"
"--group_size", type=int, default=0, help="world size of distributed"
)
parser.add_argument(
"--save_steps", type=int, default=3000, help="steps interval for saving"
"--keep_checkpoint_max", type=int, default=30, help="max checkpoint for saving"
)
parser.add_argument(
"--keep_checkpoint_max", type=int, default=20, help="max checkpoint for saving"
)
parser.add_argument(
"--model_save_step", type=int, default=2000, help="step num for saving"
"--model_save_step", type=int, default=3000, help="step num for saving"
)
parser.add_argument('--snapshots', type=int, default=3, help='Snapshots')
parser.add_argument('--Gpretrained_path', type=str, default="src/model/psnr-1_31523.ckpt")
parser.add_argument('--experiment', default="./images", help='Where to store samples and models')
parser.add_argument("--run_distribute", type=ast.literal_eval,
default=False, help="Run distribute, default: false.")
# Modelarts
@@ -84,15 +114,17 @@ def train():
rank = 0
device_num = 1
dataset, dataset_len = get_dataset_DIV2K(
base_dir="./data", downsample_factor=config["down_factor"], mode="train", aug=args_opt.aug, repeat=1, batch_size=args_opt.batch_size)
base_dir="./data", downsample_factor=config["down_factor"], mode="train", aug=args_opt.aug, repeat=1, batch_size=args_opt.batch_size,shard_id=args_opt.group_size,shard_num=args_opt.rank,num_readers=4)
generator = RRDBNet(
in_nc=config["ch_size"],
out_nc=config["ch_size"],
nf=config["G_nf"],
nb=config["G_nb"],
)
discriminator = VGGStyleDiscriminator512(
discriminator = VGGStyleDiscriminator128(
num_in_ch=config["ch_size"], num_feat=config["D_nf"])
param_dict = load_checkpoint(args_opt.Gpretrained_path)
load_param_into_net(generator, param_dict)
# Define network with loss
G_loss_cell = GeneratorLossCell(generator, discriminator,config["vgg_pretrain_path"])
D_loss_cell = DiscriminatorLossCell(discriminator)
@@ -118,7 +150,7 @@ def train():
print('Start Training')

ckpt_config = CheckpointConfig(
save_checkpoint_steps=args_opt.model_save_step)
save_checkpoint_steps=args_opt.model_save_step,keep_checkpoint_max=args_opt.keep_checkpoint_max)
ckpt_cb_g = ModelCheckpoint(
config=ckpt_config, directory="./checkpoints", prefix='Generator')
ckpt_cb_d = ModelCheckpoint(
@@ -139,41 +171,60 @@ def train():
ckpt_cb_g.begin(run_context_g)
ckpt_cb_d.begin(run_context_d)
start = time()

minibatch = args_opt.batch_size
ones = ops.Ones()
zeros = ops.Zeros()
real_labels = ones((minibatch, 1), mindspore.float32)
fake_labels = zeros((minibatch, 1), mindspore.float32)+Tensor(np.random.random(size=(minibatch,1)),dtype=mindspore.float32)*0.1
dis_iterations = 0
for epoch in range(args_opt.epoch_size):
G_epoch_loss = 0
D_epoch_loss = 0
G_content_epoch_loss = 0
G_perception_epoch_loss = 0
G_adversarial_epoch_loss = 0

for iteration, batch in enumerate(dataset.create_dict_iterator(), 1):
inputs = Tensor(batch["inputs"],dtype=mindspore.float32)
target = Tensor(batch["target"],dtype=mindspore.float32)
minibatch = inputs.shape[0]
ones = ops.Ones()
zeros = ops.Zeros()
real_labels = ones((minibatch, 1), mindspore.float32)
fake_labels = zeros((minibatch, 1), mindspore.float32) # torch.rand(minibatch,1)*0.3
generator_loss_all = G_trainOneStep(
inputs, target, fake_labels, real_labels)
fake_hr = generator_loss_all[0]
generator_loss = generator_loss_all[1]
data_iter = dataset.create_dict_iterator()
length = dataset_len
i = 0
while i < length:
############################
# (1) Update G network
###########################
for p in generator.trainable_params(): # reset requires_grad
p.requires_grad = True # they are set to False below in netG update

# train the discriminator Diters times
if dis_iterations < 25 or dis_iterations % 500 == 0:
Giters = 100
else:
Giters = args_opt.Giters
j = 0
while j < Giters and i < length:
j += 1

# clamp parameters to a cube
# for p in netD.trainable_params():
# p.data.clamp_(args_opt.clamp_lower, args_opt.clamp_upper)

data = data_iter.__next__()
i += 1

# train with real and fake
inputs = Tensor(data["inputs"],dtype=mindspore.float32)
target = Tensor(data["target"],dtype=mindspore.float32)
generator_loss_all = G_trainOneStep(inputs, target, fake_labels, real_labels)
fake_hr = generator_loss_all[0]
generator_loss = generator_loss_all[1]

############################
# (2) Update G network
###########################
for p in generator.trainable_params():
p.requires_grad = False # to avoid computation

discriminator_loss = D_trainOneStep(fake_hr,target)
G_epoch_loss += generator_loss
D_epoch_loss += discriminator_loss
print(epoch,iteration, dataset_len, generator_loss.asnumpy(), discriminator_loss.asnumpy())

print(
"===> Epoch: [%5d] Complete: Avg. Loss G: %.4f D: %.4f" %(
epoch, np.true_divide(G_epoch_loss.asnumpy(), dataset_len), np.true_divide(D_epoch_loss.asnumpy(), dataset_len)))
if (epoch+1) % (opt.snapshots) == 0:
print('===> Saving model')
cb_params_d.cur_step_num = epoch + 1
cb_params_g.cur_step_num = epoch + 1
ckpt_cb_g.step_end(run_context_g)
ckpt_cb_d.step_end(run_context_d)


if __name__ == '__main__':
train()
dis_iterations += 1

print('[%d/%d][%d/%d][%d] Loss_D: %10f Loss_G: %10f'
% (epoch, args_opt.epoch_size, i, length, dis_iterations,
np.sum(discriminator_loss.asnumpy()), generator_loss.asnumpy()))
if dis_iterations % 10 == 0:
save_image(target, '{0}/real_samples.png'.format(args_opt.experiment))
save_image(fake_hr, '{0}/fake_samples_{1}.png'.format(args_opt.experiment, dis_iterations))
if __name__ == "__main__":
train()
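
The rewritten loop steps the generator several times per discriminator step, with a long warm-up ratio borrowed from the WGAN-style training loop the commented-out clamp code refers to. A standalone sketch of just that ratio logic, not part of the commit (names are illustrative):

def generator_iters(dis_iterations, default_giters=5):
    # 100 generator steps per discriminator step for the first 25 D steps and
    # on every 500th D step; otherwise fall back to --Giters (default 5).
    if dis_iterations < 25 or dis_iterations % 500 == 0:
        return 100
    return default_giters

print(generator_iters(0), generator_iters(30), generator_iters(500))  # 100 5 100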

+266 -0   train_ModelArts.py

@@ -0,0 +1,266 @@
from time import time
import os
import argparse
import ast
import numpy as np
from PIL import Image
import mindspore
import mindspore.common.dtype as mstype
from mindspore import nn
from mindspore.train.serialization import load_param_into_net, load_checkpoint
from mindspore.ops import operations as ops
from mindspore import Tensor, context
from mindspore.common import set_seed
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.parallel import set_algo_parameters
from mindspore.train.callback import (
CheckpointConfig,
ModelCheckpoint,
_InternalCallbackParam,
RunContext,
)
import moxing as mox
from src.model.RRDB_Net import RRDBNet
from src.model.discriminator_net import VGGStyleDiscriminator128
from src.model.cell import GeneratorLossCell, DiscriminatorLossCell, TrainOneStepCellDis, TrainOneStepCellGen
from src.config.config import ESRGAN_config
from src.dataset.dataset_DIV2K import get_dataset_DIV2K

# save image
def save_image(img, img_path):
mul = ops.Mul()
add = ops.Add()
if isinstance(img, Tensor):
img = mul(img, 0.5)
img = add(img, 0.5)
img = img.asnumpy().astype(np.uint8).transpose((0, 2, 3, 1))

elif not isinstance(img, np.ndarray):
raise ValueError("img should be Tensor or numpy array, but get {}".format(type(img)))

IMAGE_SIZE = 64 # Image size
IMAGE_ROW = 8 # Row num
IMAGE_COLUMN = 8 # Column num
PADDING = 2 #Interval of small pictures
to_image = Image.new('RGB', (IMAGE_COLUMN * IMAGE_SIZE + PADDING * (IMAGE_COLUMN + 1),
IMAGE_ROW * IMAGE_SIZE + PADDING * (IMAGE_ROW + 1))) # create a new picture
# cycle
i = 0
for y in range(1, IMAGE_ROW + 1):
for x in range(1, IMAGE_COLUMN + 1):
from_image = Image.fromarray(img[i])
to_image.paste(from_image, ((x - 1) * IMAGE_SIZE + PADDING * x, (y - 1) * IMAGE_SIZE + PADDING * y))
i = i + 1

to_image.save(img_path) #save


def parse_args():
parser = argparse.ArgumentParser("ESRGAN")
parser.add_argument("--data_url", type=str, default=None, help="Dataset path")
parser.add_argument("--train_url", type=str, default=None, help="Train output path")
parser.add_argument("--modelArts_mode", type=bool, default=True)
#
parser.add_argument('--device_target', type=str,
default="Ascend", help='Platform')
parser.add_argument('--device_id', type=int,
default=3, help='device_id')
parser.add_argument(
"--aug", type=bool, default=True, help="Use augement for dataset"
)
parser.add_argument('--data_dir', type=str,
default=None, help='Dataset path')
parser.add_argument("--batch_size", type=int, default=16, help="batch_size")
parser.add_argument("--epoch_size", type=int,
default=20, help="epoch_size")
parser.add_argument('--Giters', type=int, default=5, help='number of G iters per each D iter')
#
parser.add_argument("--rank", type=int, default=1,
help="local rank of distributed")
parser.add_argument(
"--group_size", type=int, default=0, help="world size of distributed"
)
#
parser.add_argument(
"--keep_checkpoint_max", type=int, default=30, help="max checkpoint for saving"
)
parser.add_argument(
"--model_save_step", type=int, default=3000, help="step num for saving"
)
parser.add_argument('--snapshots', type=int, default=3, help='Snapshots')
parser.add_argument('--experiment', default="./images", help='Where to store samples and models')
#
parser.add_argument("--run_distribute", type=ast.literal_eval,
default=False, help="Run distribute, default: false.")
args, _ = parser.parse_known_args()
return args


def train():
args_opt = parse_args()
config = ESRGAN_config
device_num = int(os.getenv("RANK_SIZE"))
device_id = int(os.getenv("DEVICE_ID"))
rank_id = int(os.getenv('RANK_ID'))
local_data_url = "/cache/data"
local_train_url = "/cache/lwESRGAN"
local_zipfolder_url = "/cache/tarzip"
local_pretrain_url = "/cache/pretrain"
local_image_url = "/cache/ESRGANimage"
obs_res_path = "obs://heu-535/pretrain"
pretrain_filename = "psnr-X_XXXXX.ckpt"
vgg_filename = ""
filename = "DIV2K.zip"
mox.file.make_dirs(local_train_url)
mox.file.make_dirs(local_image_url)
context.set_context(mode=context.GRAPH_MODE,save_graphs=False,device_target="Ascend")
# init multicards training
if args_opt.modelArts_mode:
device_num = int(os.getenv("RANK_SIZE"))
device_id = int(os.getenv("DEVICE_ID"))
rank_id = int(os.getenv('RANK_ID'))
parallel_mode = ParallelMode.DATA_PARALLEL
context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
context.set_auto_parallel_context(device_num=device_num,parallel_mode=parallel_mode, gradients_mean=True)
set_algo_parameters(elementwise_op_strategy_follow=True)
context.set_auto_parallel_context(all_reduce_fusion_config=[85, 160])
init()

local_data_url = os.path.join(local_data_url, str(device_id))
mox.file.make_dirs(local_data_url)
local_zip_path = os.path.join(local_zipfolder_url, str(device_id), filename)
print("device:%d, local_zip_path: %s" % (device_id, local_zip_path))
obs_zip_path = os.path.join(args_opt.data_url, filename)
mox.file.copy(obs_zip_path, local_zip_path)
print(
"====================== device %d copy end =================================\n"
% (device_id)
)
unzip_command = "unzip -o %s -d %s" % (local_zip_path, local_data_url)
os.system(unzip_command)
print(
"======================= device %d unzip end =================================\n"
% (device_id)
)
# transfer dataset
local_pretrain_url = os.path.join(local_zipfolder_url,pretrain_filename)
local_pretrain_url_vgg = os.path.join(local_zipfolder_url,vgg_filename)
obs_pretrain_url = os.path.join(obs_res_path,pretrain_filename)
mox.file.copy(obs_pretrain_url, local_pretrain_url)
dataset, dataset_len = get_dataset_DIV2K(base_dir=local_data_url, downsample_factor=config["down_factor"], mode="train", aug=args_opt.aug, repeat=1, batch_size=args_opt.batch_size,shard_id=args_opt.group_size,shard_num=args_opt.rank,num_readers=4)
generator = RRDBNet(
in_nc=config["ch_size"],
out_nc=config["ch_size"],
nf=config["G_nf"],
nb=config["G_nb"],
)
discriminator = VGGStyleDiscriminator128(num_in_ch=config["ch_size"], num_feat=config["D_nf"])
param_dict = load_checkpoint(local_pretrain_url)
load_param_into_net(generator, param_dict)
# Define network with loss
G_loss_cell = GeneratorLossCell(generator, discriminator,local_pretrain_url_vgg)
D_loss_cell = DiscriminatorLossCell(discriminator)
lr_G = nn.piecewise_constant_lr(
milestone=config["lr_steps"], learning_rates=config["lr_G"]
)
lr_D = nn.piecewise_constant_lr(
milestone=config["lr_steps"], learning_rates=config["lr_D"]
)
optimizerD = nn.Adam(discriminator.trainable_params(
), learning_rate=lr_D, beta1=0.5, beta2=0.999)
optimizerG = nn.Adam(generator.trainable_params(
), learning_rate=lr_G, beta1=0.5, beta2=0.999)

# Define One step train
G_trainOneStep = TrainOneStepCellGen(G_loss_cell, optimizerG)
D_trainOneStep = TrainOneStepCellDis(D_loss_cell, optimizerD)

# Train
G_trainOneStep.set_train()
D_trainOneStep.set_train()

print('Start Training')

ckpt_config = CheckpointConfig(
save_checkpoint_steps=args_opt.model_save_step,keep_checkpoint_max=args_opt.keep_checkpoint_max)
ckpt_cb_g = ModelCheckpoint(
config=ckpt_config, directory=local_train_url, prefix='Generator')
ckpt_cb_d = ModelCheckpoint(
config=ckpt_config, directory=local_train_url, prefix='Discriminator')

cb_params_g = _InternalCallbackParam()
cb_params_g.train_network = generator
cb_params_g.cur_step_num = 0
cb_params_g.batch_num = args_opt.batch_size
cb_params_g.cur_epoch_num = 0
cb_params_d = _InternalCallbackParam()
cb_params_d.train_network = discriminator
cb_params_d.cur_step_num = 0
cb_params_d.batch_num = args_opt.batch_size
cb_params_d.cur_epoch_num = 0
run_context_g = RunContext(cb_params_g)
run_context_d = RunContext(cb_params_d)
if device_id==0:
ckpt_cb_g.begin(run_context_g)
ckpt_cb_d.begin(run_context_d)
start = time()
minibatch = args_opt.batch_size
ones = ops.Ones()
zeros = ops.Zeros()
real_labels = ones((minibatch, 1), mindspore.float32)
fake_labels = zeros((minibatch, 1), mindspore.float32)+Tensor(np.random.random(size=(minibatch,1)),dtype=mindspore.float32)*0.1
dis_iterations = 0
for epoch in range(args_opt.epoch_size):
data_iter = dataset.create_dict_iterator()
length = dataset_len
i = 0
while i < length:
############################
# (1) Update G network
###########################
for p in generator.trainable_params(): # reset requires_grad
p.requires_grad = True # they are set to False below in netG update

# train the discriminator Diters times
if dis_iterations < 25 or dis_iterations % 500 == 0:
Giters = 100
else:
Giters = args_opt.Giters
j = 0
while j < Giters and i < length:
j += 1

# clamp parameters to a cube
# for p in netD.trainable_params():
# p.data.clamp_(args_opt.clamp_lower, args_opt.clamp_upper)

data = data_iter.__next__()
i += 1

# train with real and fake
inputs = Tensor(data["inputs"],dtype=mindspore.float32)
target = Tensor(data["target"],dtype=mindspore.float32)
generator_loss_all = G_trainOneStep(inputs, target, fake_labels, real_labels)
fake_hr = generator_loss_all[0]
generator_loss = generator_loss_all[1]

############################
# (2) Update G network
###########################
for p in generator.trainable_params():
p.requires_grad = False # to avoid computation

discriminator_loss = D_trainOneStep(fake_hr,target)
dis_iterations += 1
if device_id==0:
print('[%d/%d][%d/%d][%d] Loss_D: %10f Loss_G: %10f'
% (epoch, args_opt.epoch_size, i, length, dis_iterations,
np.sum(discriminator_loss.asnumpy()), generator_loss.asnumpy()))
if dis_iterations % 10 == 0:
save_image(target, '{0}/real_samples.png'.format(local_image_url))
save_image(fake_hr, '{0}/fake_samples_{1}.png'.format(local_image_url, dis_iterations))
if device_id == 0:
mox.file.copy_parallel(local_train_url, args_opt.train_url)
if __name__ == "__main__":
train()
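
train_ModelArts.py follows the usual ModelArts staging pattern: inputs are copied from OBS to the node-local /cache with moxing, unzipped, trained on, and the outputs are copied back to OBS at the end. A condensed sketch of that flow, not part of the commit (bucket paths are illustrative; moxing is only available inside ModelArts):

import os
import moxing as mox

data_url = "obs://some-bucket/data"        # --data_url (illustrative)
train_url = "obs://some-bucket/output"     # --train_url (illustrative)
local_data = "/cache/data"
local_out = "/cache/output"

mox.file.make_dirs(local_data)
mox.file.make_dirs(local_out)
mox.file.copy(os.path.join(data_url, "DIV2K.zip"), "/cache/DIV2K.zip")   # stage in
os.system("unzip -o /cache/DIV2K.zip -d %s" % local_data)
# ... run training, writing checkpoints into local_out ...
mox.file.copy_parallel(local_out, train_url)                             # stage out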

+169 -0   train_psnr_ModelArts.py

@@ -0,0 +1,169 @@
import mindspore
from mindspore import nn
from src.dataset.dataset_DIV2K import get_dataset_DIV2K
from src.model.RRDB_Net import RRDBNet
from src.config import config
from mindspore.parallel import set_algo_parameters
from mindspore.train.model import Model
from mindspore.train.callback import LossMonitor, TimeMonitor
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore import context
import argparse
import os
import moxing as mox
from mindspore.context import ParallelMode
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.communication.management import init


class BuildTrainNetwork(nn.Cell):
def __init__(self, network, criterion):
super(BuildTrainNetwork, self).__init__()
self.network = network
self.criterion = criterion

def construct(self, input_data, label):
output = self.network(input_data)
net_loss = self.criterion(output, label)
return net_loss


def parse_args():
parser = argparse.ArgumentParser("Generator Pretrain")
parser.add_argument("--data_url", type=str, default=None, help="Dataset path")
parser.add_argument("--train_url", type=str, default=None, help="Train output path")
parser.add_argument("--modelArts_mode", type=bool, default=True)
parser.add_argument(
"--device_id",
type=int,
default=0,
help="device id of GPU or Ascend. (Default: None)",
)
parser.add_argument(
"--aug", type=bool, default=True, help="Use augement for dataset"
)
parser.add_argument("--batch_size", type=int, default=4, help="batch_size")
parser.add_argument("--epoch_size", type=int, default=20, help="epoch_size")
parser.add_argument("--rank", type=int, default=0, help="local rank of distributed")
parser.add_argument(
"--group_size", type=int, default=1, help="world size of distributed"
)
parser.add_argument(
"--save_steps", type=int, default=1000, help="steps interval for saving"
)
parser.add_argument(
"--keep_checkpoint_max", type=int, default=20, help="max checkpoint for saving"
)
# distributed training

parser.add_argument("--distribute", type=bool, default=False, help="run distribute")
args, _ = parser.parse_known_args()
return args


def train(config):
args_opt = parse_args()
config_psnr = config.PSNR_config
# the ModelArts-specific setup starts here
device_num = int(os.getenv("RANK_SIZE"))
device_id = int(os.getenv("DEVICE_ID"))
rank_id = int(os.getenv('RANK_ID'))
local_data_url = "/cache/data"
local_train_url = "/cache/lwESRGAN"
local_zipfolder_url = "/cache/tarzip"
local_pretrain_url = "/cache/pretrain"
obs_res_path = "obs://heu-535/pretrain"
pretrain_filename = "vgg19_ImageNet.ckpt"
filename = "DIV2K.zip"
mox.file.make_dirs(local_train_url)
context.set_context(mode=context.GRAPH_MODE,save_graphs=False,device_target="Ascend")
# init multicards training
if args_opt.modelArts_mode:
device_num = int(os.getenv("RANK_SIZE"))
device_id = int(os.getenv("DEVICE_ID"))
rank_id = int(os.getenv('RANK_ID'))
parallel_mode = ParallelMode.DATA_PARALLEL
context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
context.set_auto_parallel_context(device_num=device_num,parallel_mode=parallel_mode, gradients_mean=True)
set_algo_parameters(elementwise_op_strategy_follow=True)
context.set_auto_parallel_context(all_reduce_fusion_config=[85, 160])
init()

local_data_url = os.path.join(local_data_url, str(device_id))
mox.file.make_dirs(local_data_url)
local_zip_path = os.path.join(local_zipfolder_url, str(device_id), filename)
print("device:%d, local_zip_path: %s" % (device_id, local_zip_path))
obs_zip_path = os.path.join(args_opt.data_url, filename)
mox.file.copy(obs_zip_path, local_zip_path)
print(
"====================== device %d copy end =================================\n"
% (device_id)
)
unzip_command = "unzip -o %s -d %s" % (local_zip_path, local_data_url)
os.system(unzip_command)
print(
"======================= device %d unzip end =================================\n"
% (device_id)
)
# transfer dataset
local_pretrain_url = os.path.join(local_zipfolder_url,pretrain_filename)
obs_pretrain_url = os.path.join(obs_res_path,pretrain_filename)
mox.file.copy(obs_pretrain_url, local_pretrain_url)

model_psnr = RRDBNet(
in_nc=config_psnr["ch_size"],
out_nc=config_psnr["ch_size"],
nf=config_psnr["G_nf"],
nb=config_psnr["G_nb"],
)
dataset,dataset_len = get_dataset_DIV2K(
base_dir=local_data_url,
downsample_factor=config_psnr["down_factor"],
mode="train",
aug=args_opt.aug,
repeat=1,
num_readers=4,
shard_id=args_opt.rank,
shard_num=args_opt.group_size,
batch_size=args_opt.batch_size,
)

lr = nn.piecewise_constant_lr(
milestone=config_psnr["lr_steps"], learning_rates=config_psnr["lr"]
)
opt = nn.Adam(
params=model_psnr.trainable_params(), learning_rate=lr, beta1=0.9, beta2=0.99
)
loss = nn.L1Loss()
loss.add_flags_recursive(fp32=True)
# loss scale
manager_loss_scale = FixedLossScaleManager(args_opt.loss_scale, drop_overflow_update=False)
amp_level = "O2"
train_net = BuildTrainNetwork(model_psnr, loss)
iters_per_check = dataset_len
model = Model(train_net, optimizer=opt)
# callback for saving ckpts
time_cb = TimeMonitor(data_size=dataset_len)
loss_cb = LossMonitor()
cbs = [time_cb, loss_cb]

config_ck = CheckpointConfig(
save_checkpoint_steps=args_opt.save_steps,
keep_checkpoint_max=args_opt.keep_checkpoint_max,
)
ckpoint_cb = ModelCheckpoint(
prefix="psnr", directory=local_train_url, config=config_ck
)
if device_id ==0:
cbs.append(ckpoint_cb)

model.train(
args_opt.epoch_size, dataset, callbacks=cbs, dataset_sink_mode=False,
)
if device_id == 0:
mox.file.copy_parallel(local_train_url, args_opt.train_url)


if __name__ == "__main__":
train(config)
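
The pretraining script builds a FixedLossScaleManager and an amp_level, but the Model above is constructed with only the optimizer. For reference, a minimal sketch of how those are usually wired into MindSpore's Model, assuming mixed-precision pretraining was the intent (the 1024.0 loss scale is an illustrative value):

from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import FixedLossScaleManager

def build_model(train_net, opt):
    manager_loss_scale = FixedLossScaleManager(1024.0, drop_overflow_update=False)
    return Model(train_net, optimizer=opt,
                 loss_scale_manager=manager_loss_scale, amp_level="O2")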
