|
- import argparse
- import matplotlib.pyplot as plt
- import os
- import datetime
- from mindspore import context
- import mindspore.nn as nn
- from mindspore import Tensor
- import config
- from generator_model import Generator
- from discriminator_model import Discriminator
- from mindspore.train.callback import CheckpointConfig,ModelCheckpoint,_InternalCallbackParam
- from losses import WithLossCellGen,WithLossCellDis
- from pix2pix import Pix2Pix
- from dataset.pix2pix_dataset import MapDataset,create_dataset
-
def save_losses(G_losses, D_losses, idx):
    """Plot the generator/discriminator loss curves and save them as a PNG.

    The figure is written to ``./loss_show/<idx>.png``.

    Args:
        G_losses: sequence of generator loss values (one entry per logged step).
        D_losses: sequence of discriminator loss values (one entry per logged step).
        idx: identifier used in the output filename (the caller passes the
            1-based epoch number).
    """
    # Create the output directory up front so savefig does not fail
    # on a fresh checkout where ./loss_show does not exist yet.
    os.makedirs("./loss_show", exist_ok=True)
    plt.figure(figsize=(10, 5))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(G_losses, label="G")
    plt.plot(D_losses, label="D")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig("./loss_show/{}.png".format(idx))
    # This function is called once per epoch; without closing, matplotlib
    # keeps every figure alive and memory grows for the whole training run.
    plt.close()
-
-
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='MindSpore dcgan training')
    parser.add_argument('--device_target', type=str, default='Ascend', choices=('Ascend', 'GPU'),
                        help='device where the code will be implemented (default: Ascend)')
    # parser.add_argument('--device_id', type=int, default=4, help='device id of GPU or Ascend. (Default: 0)')
    args = parser.parse_args()

    # PyNative (eager) mode: the loop below drives the two TrainOneStepCell
    # networks manually, one batch at a time.
    context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target)
    # context.set_context(device_id=args.device_id)

    # Data: build the training dataset from the local maps/train directory.
    dataset = MapDataset(root_dir="./maps/train/")
    ds=create_dataset(dataset)
    print("ds:", ds.get_dataset_size())
    print("ds:", ds.get_col_names())
    print("ds.shape:", ds.output_shapes())
    # print("ds.dtype:", ds)
    steps_per_epoch = ds.get_dataset_size()

    # Network definitions: pix2pix discriminator and generator.
    netD=Discriminator()
    netG=Generator()

    # Adversarial loss (BCE on logits) is shared by both networks; the
    # generator additionally uses an L1 term (see WithLossCellGen below).
    loss_fn1=nn.BCEWithLogitsLoss()
    loss_fn2=nn.L1Loss()

    # Wrap each network with its loss computation; both wrappers hold
    # references to BOTH networks, since each loss needs the other's output.
    netD_With_Loss=WithLossCellDis(netD,netG,loss_fn1)
    netG_With_Loss=WithLossCellGen(netD,netG,loss_fn1,loss_fn2)

    # Separate Adam optimizers so D and G update only their own parameters.
    optimizerD = nn.Adam(netD.trainable_params(), learning_rate=0.0002, beta1=0.5, beta2=0.999)
    optimizerG = nn.Adam(netG.trainable_params(), learning_rate=0.0002, beta1=0.5, beta2=0.999)

    # TrainOneStepCell fuses loss computation, backprop and the optimizer
    # update into a single callable step per network.
    TrainOneStepDis=nn.TrainOneStepCell(netD_With_Loss,optimizerD,sens=1.0)
    TrainOneStepGen=nn.TrainOneStepCell(netG_With_Loss,optimizerG,sens=1.0)

    # Pix2Pix runs one D step and one G step per call and returns both losses.
    pix2pix_model=Pix2Pix(TrainOneStepDis,TrainOneStepGen)
    pix2pix_model.set_train()

    # save checkpoints: one checkpoint per epoch, keeping at most the 10 newest.
    ckpt_config=CheckpointConfig(save_checkpoint_steps=steps_per_epoch,keep_checkpoint_max=10)
    ckpt_cb=ModelCheckpoint(config=ckpt_config,directory='./ckpt',prefix='pix2pix_model')

    # NOTE(review): _InternalCallbackParam and ckpt_cb._save_ckpt are private
    # MindSpore APIs, hand-fed here because there is no Model.train() driving
    # the callbacks — this may break across MindSpore versions; verify.
    cb_params = _InternalCallbackParam()
    cb_params.train_network = pix2pix_model
    cb_params.batch_num = steps_per_epoch
    cb_params.epoch_num = config.NUM_EPOCHS #100
    # For each epoch
    cb_params.cur_epoch_num = 0
    cb_params.cur_step_num = 0

    # Training Loop: losses are appended every 20 steps for plotting.
    G_losses = []
    D_losses = []

    data_loader = ds.create_dict_iterator(output_numpy=True, num_epochs=config.NUM_EPOCHS)
    print("Starting Training Loop...")
    for epoch in range(config.NUM_EPOCHS):
        # for each batch in the data_loader
        for i,data in enumerate(data_loader):
            # presumably input_images/target_images are the paired pix2pix
            # halves produced by create_dataset — confirm against the dataset
            # module's column definitions.
            input_image=Tensor(data["input_images"])
            target_image=Tensor(data["target_images"])
            # One combined step: updates D then G, returns both loss values.
            netD_loss, netG_loss = pix2pix_model(input_image, target_image)
            # Log and record losses every 20 steps.
            if i % 20 == 0:
                print("================start===================")
                print("Date time: ", datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
                print("epoch: ", epoch, "/", config.NUM_EPOCHS)
                print("step: ", i, "/", steps_per_epoch)
                print("Dloss: ", netD_loss)
                print("Gloss: ", netG_loss)
                print("=================end====================")
                D_losses.append(netD_loss.asnumpy())
                G_losses.append(netG_loss.asnumpy())
            cb_params.cur_step_num = cb_params.cur_step_num + 1

        # End of epoch: bump the callback counter, checkpoint, and dump the
        # cumulative loss curves collected so far.
        cb_params.cur_epoch_num = cb_params.cur_epoch_num + 1
        ckpt_cb._save_ckpt(cb_params, True)
        print("epoch-",epoch+1," saved")
        save_losses(G_losses,D_losses,epoch+1)
        print("epoch-",epoch+1," D&G_Losses saved")
-
-
-
-
-
-
-
|