
debug 0425

master
root · 1 month ago
commit 9837297e59
5 changed files with 32149 additions and 86080 deletions
  1. .gitignore  (+1, -1)
  2. analyze_fail.dat  (+32116, -86011)
  3. src/dataset/dataset_DIV2K.py  (+4, -4)
  4. train.py  (+26, -59)
  5. train_psnr_ModelArts.py  (+2, -5)

.gitignore  (+1, -1)

@@ -112,6 +112,6 @@ somas_meta/
 analyze_fail.dat
 src/model/vgg19_ImageNet.ckpt
 DIV2K.zip
-src/model/psnr-1_31523.ckpt
+psnr.ckpt
 analyze_fail.dat
 images/

analyze_fail.dat  (+32116, -86011)
File diff suppressed because it is too large.


src/dataset/dataset_DIV2K.py  (+4, -4)

@@ -80,8 +80,8 @@ class Dataset_DIV2K():
         if self.data_augmentation:
             inputs, target, _ = augment(inputs, target)
-        inputs = np.array(inputs)
-        target = np.array(target)
+        inputs = np.array(inputs).astype(np.float32)
+        target = np.array(target).astype(np.float32)
         inputs = np.transpose(inputs / 255.0, (2, 0, 1))
         target = np.transpose(target / 255.0, (2, 0, 1))
         return inputs, target
@@ -118,8 +118,8 @@ class Dataset_Flickr():
         if self.data_augmentation:
             inputs, target, _ = augment(inputs, target)
-        inputs = np.array(inputs)
-        target = np.array(target)
+        inputs = np.array(inputs).astype(np.float32)
+        target = np.array(target).astype(np.float32)
         inputs = np.transpose(inputs / 255.0, (2, 0, 1))
         target = np.transpose(target / 255.0, (2, 0, 1))
         return inputs, target
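
Both hunks make the same change: an explicit float32 cast before the images are scaled to [0, 1] and transposed from HWC to CHW. A minimal NumPy-only sketch of that preprocessing path (the preprocess name and the dummy 4x4 patch are illustrative, not from the repo):

    import numpy as np

    def preprocess(img):
        # img: H x W x C uint8 patch, e.g. a decoded PNG crop
        img = np.array(img).astype(np.float32)       # explicit float32 cast, as in the diff
        img = np.transpose(img / 255.0, (2, 0, 1))   # scale to [0, 1], reorder HWC -> CHW
        return img

    dummy = np.random.randint(0, 256, size=(4, 4, 3), dtype=np.uint8)  # stand-in LR/HR patch
    out = preprocess(dummy)
    print(out.shape, out.dtype)   # (4, 4, 3) -> (3, 4, 4), float32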


train.py  (+26, -59)

@@ -33,7 +33,7 @@ def parse_args():
     parser.add_argument('--device_target', type=str,
                         default="Ascend", help='Platform')
     parser.add_argument('--device_id', type=int,
-                        default=3, help='device_id')
+                        default=7, help='device_id')
     parser.add_argument(
         "--aug", type=bool, default=True, help="Use augement for dataset"
     )
@@ -44,20 +44,19 @@ def parse_args():
     parser.add_argument("--batch_size", type=int, default=16, help="batch_size")
     parser.add_argument("--epoch_size", type=int,
                         default=20, help="epoch_size")
-    parser.add_argument('--Giters', type=int, default=5, help='number of G iters per each D iter')
+    parser.add_argument('--Giters', type=int, default=2, help='number of G iters per each D iter')
     parser.add_argument("--rank", type=int, default=1,
                         help="local rank of distributed")
     parser.add_argument(
         "--group_size", type=int, default=0, help="world size of distributed"
     )
     parser.add_argument(
-        "--keep_checkpoint_max", type=int, default=30, help="max checkpoint for saving"
+        "--keep_checkpoint_max", type=int, default=40, help="max checkpoint for saving"
     )
     parser.add_argument(
-        "--model_save_step", type=int, default=3000, help="step num for saving"
+        "--model_save_step", type=int, default=5000, help="step num for saving"
     )
     parser.add_argument('--snapshots', type=int, default=3, help='Snapshots')
-    parser.add_argument('--Gpretrained_path', type=str, default="src/model/psnr-1_31523.ckpt")
+    parser.add_argument('--Gpretrained_path', type=str, default="psnr.ckpt")
     parser.add_argument('--experiment', default="./images", help='Where to store samples and models')
     parser.add_argument("--run_distribute", type=ast.literal_eval,
                         default=False, help="Run distribute, default: false.")
@@ -96,7 +95,8 @@ def train():
         rank = 0
         device_num = 1
     dataset, dataset_len = get_dataset_DIV2K(
-        base_dir="./data", downsample_factor=config["down_factor"], mode="train", aug=args_opt.aug, repeat=1, batch_size=args_opt.batch_size,shard_id=args_opt.group_size,shard_num=args_opt.rank,num_readers=4)
+        base_dir="./data/lw", downsample_factor=config["down_factor"], mode="train", aug=args_opt.aug, repeat=10, batch_size=args_opt.batch_size,shard_id=args_opt.group_size,shard_num=args_opt.rank,num_readers=4)
+    dataset_iter = dataset.create_dict_iterator()
     generator = RRDBNet(
         in_nc=config["ch_size"],
         out_nc=config["ch_size"],
@@ -117,9 +117,9 @@ def train():
         milestone=config["lr_steps"], learning_rates=config["lr_D"]
     )
     optimizerD = nn.Adam(discriminator.trainable_params(
-    ), learning_rate=lr_D, beta1=0.5, beta2=0.999,loss_scale=args_opt.loss_scale)
+    ), learning_rate=lr_D, beta1=0.5, beta2=0.999)
     optimizerG = nn.Adam(generator.trainable_params(
-    ), learning_rate=lr_G, beta1=0.5, beta2=0.999,loss_scale=args_opt.loss_scale)
+    ), learning_rate=lr_G, beta1=0.5, beta2=0.999)

     # Define One step train
     G_trainOneStep = TrainOneStepCellGen(G_loss_cell, optimizerG)
@@ -157,56 +157,23 @@ def train():
     ones = ops.Ones()
     zeros = ops.Zeros()
     real_labels = ones((minibatch, 1), mindspore.float32)
-    fake_labels = zeros((minibatch, 1), mindspore.float32)+Tensor(np.random.random(size=(minibatch,1)),dtype=mindspore.float32)*0.1
-    dis_iterations = 0
-    for epoch in range(args_opt.epoch_size):
-        data_iter = dataset.create_dict_iterator()
-        length = dataset_len
-        i = 0
-        while i < length:
-            ############################
-            # (1) Update G network
-            ###########################
-            for p in generator.trainable_params():  # reset requires_grad
-                p.requires_grad = True  # they are set to False below in netG update
-
-            # train the discriminator Diters times
-            if dis_iterations < 25 or dis_iterations % 500 == 0:
-                Giters = 100
-            else:
-                Giters = args_opt.Giters
-            j = 0
-            while j < Giters and i < length:
-                j += 1
-
-                # clamp parameters to a cube
-                # for p in netD.trainable_params():
-                #     p.data.clamp_(args_opt.clamp_lower, args_opt.clamp_upper)
-
-                data = data_iter.__next__()
-                i += 1
-
-                # train with real and fake
-                inputs = Tensor(data["inputs"],dtype=mindspore.float32)
-                target = Tensor(data["target"],dtype=mindspore.float32)
-                generator_loss_all = G_trainOneStep(inputs, target, fake_labels, real_labels)
-                fake_hr = generator_loss_all[0]
-                generator_loss = generator_loss_all[1]
-
-            ############################
-            # (2) Update G network
-            ###########################
-            for p in generator.trainable_params():
-                p.requires_grad = False  # to avoid computation
-
-            discriminator_loss = D_trainOneStep(fake_hr,target)
-            dis_iterations += 1
-
-            print('[%d/%d][%d/%d][%d] Loss_D: %10f Loss_G: %10f'
-                % (epoch, args_opt.epoch_size, i, length//args_opt.batch_size, dis_iterations,
+    fake_labels = zeros((minibatch, 1), mindspore.float32)+Tensor(np.random.random(size=(minibatch,1)),dtype=mindspore.float32)*0.05
+    num_iters = config["niter"]
+    for iterator in range(num_iters):
+        data = next(dataset_iter)
+        inputs = data["inputs"]
+        real_hr = data["target"]
+        generator_loss_all = G_trainOneStep(inputs, real_hr, fake_labels, real_labels)
+        fake_hr = generator_loss_all[0]
+        generator_loss = generator_loss_all[1]
+        if (iterator + 1) % args_opt.Giters == 0:
+            discriminator_loss = D_trainOneStep(fake_hr,real_hr)
+        if (iterator + 1) % 100 == 0:
+            print('%d:[%d/%d]Loss_D: %10f Loss_G: %10f'
+                % (iterator//dataset_len,iterator,num_iters,
                 np.sum(discriminator_loss.asnumpy()), generator_loss.asnumpy()))
-            if dis_iterations % 5 == 0:
-                save_img(target[0], 'real_samples_{0}.png'.format(dis_iterations),args_opt.experiment)
-                save_img(fake_hr[0], 'fake_samples_{0}.png'.format(dis_iterations),args_opt.experiment)
+            save_img(real_hr[0], 'real_samples_{0}.png'.format(iterator + 1),args_opt.experiment)
+            save_img(fake_hr[0], 'fake_samples_{0}.png'.format(iterator + 1),args_opt.experiment)
 if __name__ == "__main__":
     train()
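
The rewritten loop drops the WGAN-style nesting (epoch loop, per-epoch iterator, the early burst of 100 generator steps, and the requires_grad toggling) in favor of a flat loop over config["niter"] iterations: the generator is updated every step, the discriminator every Giters-th step (Giters now defaults to 2), and logging plus sample saving happen every 100 steps; the fake-label noise is also scaled down from 0.1 to 0.05. A control-flow sketch of that schedule, with placeholder step functions standing in for G_trainOneStep/D_trainOneStep (run, g_step, d_step and the random "losses" are illustrative, not the repo's code):

    import numpy as np

    def run(num_iters, dataset_iter, g_step, d_step, giters=2, log_every=100):
        d_loss = 0.0
        for it in range(num_iters):
            batch = next(dataset_iter)                 # dict with "inputs" / "target"
            fake_hr, g_loss = g_step(batch["inputs"], batch["target"])
            if (it + 1) % giters == 0:                 # one D update per `giters` G updates
                d_loss = d_step(fake_hr, batch["target"])
            if (it + 1) % log_every == 0:              # periodic logging / sample saving
                print("[%d/%d] Loss_D: %.6f Loss_G: %.6f" % (it + 1, num_iters, d_loss, g_loss))

    # Smoke test with random "losses" instead of real networks.
    fake_data = iter([{"inputs": np.zeros(1), "target": np.zeros(1)}] * 200)
    run(200, fake_data,
        g_step=lambda lr, hr: (hr, float(np.random.rand())),
        d_step=lambda fake, real: float(np.random.rand()))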

train_psnr_ModelArts.py  (+2, -5)

@@ -132,19 +132,16 @@ def train(config):
         milestone=config_psnr["lr_steps"], learning_rates=config_psnr["lr"]
     )
     opt = nn.Adam(
-        params=model_psnr.trainable_params(), learning_rate=lr, beta1=0.9, beta2=0.99,loss_scale=args_opt.loss_scale
+        params=model_psnr.trainable_params(), learning_rate=lr, beta1=0.9, beta2=0.99
     )
     loss = nn.L1Loss()
     loss.add_flags_recursive(fp32=True)
-    # loss scale
-    manager_loss_scale = FixedLossScaleManager(args_opt.loss_scale, drop_overflow_update=False)
-    amp_level = "O2"
     train_net = BuildTrainNetwork(model_psnr, loss)
     iters_per_check = dataset_len
     model = Model(train_net, optimizer=opt)
     # callback for saving ckpts
-    time_cb = TimeMonitor(data_size=)
+    time_cb = TimeMonitor()
     loss_cb = LossMonitor()
     cbs = [time_cb, loss_cb]
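
With loss_scale removed from nn.Adam and the FixedLossScaleManager / amp_level lines dropped, the PSNR pre-training ends up with plain FP32 Adam, and TimeMonitor is now built without a data_size argument. A hedged, self-contained sketch of that Model/callback wiring (assumes MindSpore is installed; the tiny Dense network and random data stand in for the repo's PSNR model and DIV2K pipeline):

    import numpy as np
    import mindspore.nn as nn
    import mindspore.dataset as ds
    from mindspore.train.model import Model
    from mindspore.train.callback import TimeMonitor, LossMonitor

    net = nn.Dense(8, 1)                                  # stand-in for the PSNR network
    loss = nn.L1Loss()
    opt = nn.Adam(params=net.trainable_params(), learning_rate=1e-3,
                  beta1=0.9, beta2=0.99)                  # no loss_scale, as in the diff
    train_net = nn.WithLossCell(net, loss)                # analogous to BuildTrainNetwork
    model = Model(train_net, optimizer=opt)

    x = np.random.randn(64, 8).astype(np.float32)
    y = np.random.randn(64, 1).astype(np.float32)
    dataset = ds.NumpySlicesDataset({"data": x, "label": y}, shuffle=False).batch(16)

    cbs = [TimeMonitor(), LossMonitor()]                  # TimeMonitor() without data_size
    model.train(2, dataset, callbacks=cbs, dataset_sink_mode=False)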


