Browse Source

explore loss scale

master
root 1 month ago
parent
commit
3880178272
6 changed files with 86 additions and 77 deletions
  1. +2
    -3
      .gitignore
  2. +23
    -14
      eval.py
  3. +36
    -0
      export.py
  4. +4
    -4
      train.py
  5. +20
    -55
      train_ModelArts.py
  6. +1
    -1
      train_psnr_ModelArts.py

+ 2
- 3
.gitignore View File

@@ -110,8 +110,7 @@ checkpoints/
module_test.py
somas_meta/
analyze_fail.dat
src/model/vgg19_ImageNet.ckpt
DIV2K.zip
psnr.ckpt
analyze_fail.dat
*.ckpt
*.dat
images/

+ 23
- 14
eval.py View File

@@ -6,6 +6,7 @@ import datetime
import glob
import numpy as np
import cv2
from collections import OrderedDict
import mindspore.nn as nn
from mindspore.nn import PSNR,SSIM
from mindspore import Tensor, context
@@ -15,8 +16,7 @@ from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from src.config.config import ESRGAN_config,PSNR_config
from src.utils.eval_util import imresize_np, rgb2ycbcr, calculate_psnr, calculate_ssim

from src.model.RRDB_Net import RRDBNet

class BuildEvalNetwork(nn.Cell):
def __init__(self, network):
@@ -36,14 +36,12 @@ def parse_args(cloud_args=None):
# dataset related
parser.add_argument('--data_path', type=str,
default='', help='eval data dir')
parser.add_argument('--batch_size', default=1,
parser.add_argument('--ganckpt_path', type=str,
default='', help='gan ckpt file')
parser.add_argument('--psnrckpt_path', type=str,
default='', help='psnr ckpt file')
parser.add_argument('--batch_size', default=16,
type=int, help='batch size for per npu')
# network related
parser.add_argument('--graph_ckpt', type=int, default=1,
help='graph ckpt or feed ckpt')
parser.add_argument('--pre_trained', default='', type=str, help='fully path of pretrained model to load. '
'If it is a direction, it will test all ckpt')

# logging related
parser.add_argument('--log_path', type=str,
default='outputs/', help='path to save log')
@@ -51,7 +49,8 @@ def parse_args(cloud_args=None):
help='local rank of distributed')
parser.add_argument('--group_size', type=int, default=1,
help='world size of distributed')

parser.add_argument('--alpha', type=float, default=0.4,
help='weight factor of psnr model in eval')
args_opt = parser.parse_args()
return args_opt

@@ -60,7 +59,7 @@ set_seed(1)

def test():
args_opt = parse_args()
config = PSNR_config
config_psnr = PSNR_config
print(f"test args: {args_opt}\ncfg: {config}")
context.set_context(
mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=1
@@ -72,6 +71,8 @@ def test():
nf=config_psnr["G_nf"],
nb=config_psnr["G_nb"],
)

# 需要对每个参数进行单独计算
dataset,dataset_len = get_dataset_DIV2K(
base_dir="./data",
downsample_factor=config_psnr["down_factor"],
@@ -86,8 +87,16 @@ def test():

eval_net = BuildEvalNetwork(model_psnr)

# load model
param_dict = load_checkpoint(args.ckpt_path)
# load model and Interpolating
param_dict_gan = load_checkpoint(args_opt.ganckpt_path)
param_dict_psnr = load_checkpoint(args_opt.psnrckpt_path)
param_dict = OrderedDict()
alpha = args_opt.alpha
print('Interpolating with alpha = ', alpha)

for name,cell_PSNR in net_PSNR.cells_and_names():
cell_ESRGAN = param_dict_gan[name]
net_interp[name] = (1 - alpha) * cell_PSNR + alpha * cell_ESRGAN
load_param_into_net(eval_net, param_dict)
eval_net.set_train(False)
ssim = nn.SSIM()
@@ -113,7 +122,7 @@ def test():
ssim_real_all = ssim_real
print(psnr_bic,psnr_real,ssim_bic,ssim_real)
result_img_path = os.path.join(args_opt.results_path + "DIV2K", 'Bic_SR_HR_' + str(i))
if i%10 == 0:
if i % 50 == 0:
results_img = np.concatenate((bic_img[0].asnumpy(), sr_img[0].asnumpy(), hr_img[0].asnumpy()), 1)
cv2.imwrite(result_img_path, results_img)
psnr_bic_all += psnr_bic_all/dataset_len


+ 36
- 0
export.py View File

@@ -0,0 +1,36 @@
"""Export an interpolated ESRGAN generator to an AIR model file.

Loads a PSNR-oriented checkpoint and a GAN-oriented checkpoint, linearly
blends their parameters (network interpolation, as in the ESRGAN paper),
loads the blended weights into an RRDBNet generator, and exports the
result in AIR format.
"""
import argparse
from collections import OrderedDict

import numpy as np

from mindspore import context, Tensor, Parameter
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net
from src.config.config import PSNR_config
from src.model.RRDB_Net import RRDBNet


def parse_args():
    """Parse checkpoint paths, blend factor and device for the export."""
    parser = argparse.ArgumentParser(description="ESRGAN export")
    parser.add_argument('--ganckpt_path', type=str,
                        default='', help='gan ckpt file')
    parser.add_argument('--psnrckpt_path', type=str,
                        default='', help='psnr ckpt file')
    parser.add_argument('--alpha', type=float, default=0.4,
                        help='weight of the GAN model in the interpolation')
    parser.add_argument('--device_target', type=str, default='Ascend',
                        help='device target: Ascend / GPU / CPU')
    parser.add_argument('--file_name', type=str, default='ESRGAN_Generator-300_11',
                        help='output file name (extension added by export)')
    return parser.parse_args()


if __name__ == '__main__':
    args_opt = parse_args()
    config_psnr = PSNR_config
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target)

    model_psnr = RRDBNet(
        in_nc=config_psnr["ch_size"],
        out_nc=config_psnr["ch_size"],
        nf=config_psnr["G_nf"],
        nb=config_psnr["G_nb"],
    )
    # Export must run the graph in inference mode, not training mode.
    model_psnr.set_train(False)

    param_dict_gan = load_checkpoint(args_opt.ganckpt_path)
    param_dict_psnr = load_checkpoint(args_opt.psnrckpt_path)
    param_dict = OrderedDict()
    alpha = args_opt.alpha
    print('Interpolating with alpha = ', alpha)

    # Network interpolation: blend each parameter of the two checkpoints.
    # Assumes both checkpoints come from the same RRDBNet topology, so the
    # parameter names match one-to-one -- TODO confirm against training code.
    for name, param_psnr in param_dict_psnr.items():
        param_esrgan = param_dict_gan[name]
        blended = ((1 - alpha) * param_psnr.data.asnumpy()
                   + alpha * param_esrgan.data.asnumpy())
        param_dict[name] = Parameter(Tensor(blended), name=name)
    load_param_into_net(model_psnr, param_dict)

    # Dummy low-resolution input; export traces the graph with this shape.
    input_array = Tensor(np.random.uniform(
        -1.0, 1.0, size=(1, 3, 32, 32)).astype(np.float32))
    export(model_psnr, input_array,
           file_name=args_opt.file_name, file_format='AIR')

+ 4
- 4
train.py View File

@@ -117,9 +117,9 @@ def train():
milestone=config["lr_steps"], learning_rates=config["lr_D"]
)
optimizerD = nn.Adam(discriminator.trainable_params(
), learning_rate=lr_D, beta1=0.5, beta2=0.999)
), learning_rate=lr_D, beta1=0.5, beta2=0.999,loss_scale=args_opt.loss_scale)
optimizerG = nn.Adam(generator.trainable_params(
), learning_rate=lr_G, beta1=0.5, beta2=0.999)
), learning_rate=lr_G, beta1=0.5, beta2=0.999,loss_scale=args_opt.loss_scale)

# Define One step train
G_trainOneStep = TrainOneStepCellGen(G_loss_cell, optimizerG)
@@ -168,9 +168,9 @@ def train():
generator_loss = generator_loss_all[1]
if (iterator + 1) % args_opt.Giters == 0:
discriminator_loss = D_trainOneStep(fake_hr,real_hr)
if (iterator + 1) % 100 == 0:
if (iterator + 1) % 500 == 0:
print('%d:[%d/%d]Loss_D: %10f Loss_G: %10f'
% (iterator//dataset_len,iterator,num_iters,
% ((iterator+1)//dataset_len,iterator,num_iters,
np.sum(discriminator_loss.asnumpy()), generator_loss.asnumpy()))
save_img(real_hr[0], 'real_samples_{0}.png'.format(iterator + 1),args_opt.experiment)
save_img(fake_hr[0], 'fake_samples_{0}.png'.format(iterator + 1),args_opt.experiment)


+ 20
- 55
train_ModelArts.py View File

@@ -89,7 +89,6 @@ def parse_args():
parser.add_argument(
"--model_save_step", type=int, default=3000, help="step num for saving"
)
parser.add_argument('--snapshots', type=int, default=3, help='Snapshots')
parser.add_argument('--experiment', default="./images", help='Where to store samples and models')
#
parser.add_argument("--run_distribute", type=ast.literal_eval,
@@ -203,66 +202,32 @@ def train():
cb_params_d.cur_epoch_num = 0
run_context_g = RunContext(cb_params_g)
run_context_d = RunContext(cb_params_d)
if device_id==0:
ckpt_cb_g.begin(run_context_g)
ckpt_cb_d.begin(run_context_d)
ckpt_cb_g.begin(run_context_g)
ckpt_cb_d.begin(run_context_d)
start = time()
minibatch = args_opt.batch_size
ones = ops.Ones()
zeros = ops.Zeros()
real_labels = ones((minibatch, 1), mindspore.float32)
fake_labels = zeros((minibatch, 1), mindspore.float32)+Tensor(np.random.random(size=(minibatch,1)),dtype=mindspore.float32)*0.1
dis_iterations = 0
for epoch in range(args_opt.epoch_size):
data_iter = dataset.create_dict_iterator()
length = dataset_len
i = 0
while i < length:
############################
# (1) Update G network
###########################
for p in generator.trainable_params(): # reset requires_grad
p.requires_grad = True # they are set to False below in netG update

# train the discriminator Diters times
if dis_iterations < 25 or dis_iterations % 500 == 0:
Giters = 100
else:
Giters = args_opt.Giters
j = 0
while j < Giters and i < length:
j += 1

# clamp parameters to a cube
# for p in netD.trainable_params():
# p.data.clamp_(args_opt.clamp_lower, args_opt.clamp_upper)

data = data_iter.__next__()
i += 1

# train with real and fake
inputs = Tensor(data["inputs"],dtype=mindspore.float32)
target = Tensor(data["target"],dtype=mindspore.float32)
generator_loss_all = G_trainOneStep(inputs, target, fake_labels, real_labels)
fake_hr = generator_loss_all[0]
generator_loss = generator_loss_all[1]

############################
# (2) Update G network
###########################
for p in generator.trainable_params():
p.requires_grad = False # to avoid computation

discriminator_loss = D_trainOneStep(fake_hr,target)
dis_iterations += 1
if device_id==0:
print('[%d/%d][%d/%d][%d] Loss_D: %10f Loss_G: %10f'
% (epoch, args_opt.epoch_size, i, length, dis_iterations,
np.sum(discriminator_loss.asnumpy()), generator_loss.asnumpy()))
if dis_iterations % 10 == 0:
save_image(target, '{0}/real_samples.png'.format(local_image_url))
save_image(fake_hr, '{0}/fake_samples_{1}.png'.format(local_image_url, dis_iterations))
fake_labels = zeros((minibatch, 1), mindspore.float32)+Tensor(np.random.random(size=(minibatch,1)),dtype=mindspore.float32)*0.05
num_iters = config["niter"]
for iterator in range(num_iters):
data = next(dataset_iter)
inputs = data["inputs"]
real_hr = data["target"]
generator_loss_all = G_trainOneStep(inputs, real_hr, fake_labels, real_labels)
fake_hr = generator_loss_all[0]
generator_loss = generator_loss_all[1]
if (iterator + 1) % args_opt.Giters == 0:
discriminator_loss = D_trainOneStep(fake_hr,real_hr)
if (iterator + 1) % 5000 == 0:
print('%d:[%d/%d]Loss_D: %10f Loss_G: %10f'
% ((iterator+1)//dataset_len,iterator,num_iters,
np.sum(discriminator_loss.asnumpy()), generator_loss.asnumpy()))
save_img(real_hr[0], 'real_samples_{0}.png'.format(iterator + 1),local_image_url)
save_img(fake_hr[0], 'fake_samples_{0}.png'.format(iterator + 1),local_image_url)
if device_id == 0:
mox.file.copy_parallel(local_train_url, args_opt.train_url)
mox.file.copy_parallel(local_image_url, args_opt.train_url)
if __name__ == "__main__":
train(config)

+ 1
- 1
train_psnr_ModelArts.py View File

@@ -132,7 +132,7 @@ def train(config):
milestone=config_psnr["lr_steps"], learning_rates=config_psnr["lr"]
)
opt = nn.Adam(
params=model_psnr.trainable_params(), learning_rate=lr, beta1=0.9, beta2=0.99
params=model_psnr.trainable_params(), learning_rate=lr, beta1=0.9, beta2=0.99,loss_scale=args_opt.loss_scale
)
loss = nn.L1Loss()
loss.add_flags_recursive(fp32=True)


Loading…
Cancel
Save