# import torch
# import config
# from torchvision.utils import save_image

import mindspore
import config
from mindspore.ops import functional as F
import matplotlib.image as mpimg
import matplotlib.pyplot as plt


def save_some_examples(gen, val_loader, epoch, folder):
    def to_image(tensor):
        # Helper for matplotlib: take the first sample of the NCHW batch,
        # move channels last and convert to a numpy array clipped to [0, 1]
        # (the inputs are assumed to be normalized to [-1, 1] upstream).
        return tensor[0].transpose(1, 2, 0).asnumpy().clip(0, 1)

    x, y = next(iter(val_loader))
    # x, y = x.to(config.DEVICE), y.to(config.DEVICE)
    gen.set_train(False)  # MindSpore Cells use set_train() instead of eval()/train()
    # with torch.no_grad():
    y_fake = gen(x)
    y_fake = F.stop_gradient(y_fake)  # stop_gradient() stands in for the removed torch.no_grad()
    y_fake = y_fake * 0.5 + 0.5  # remove normalization

    # save_image(y_fake, folder + f"/y_gen_{epoch}.png")
    # save_image(x * 0.5 + 0.5, folder + f"/input_{epoch}.png")
    # if epoch == 1:
    #     save_image(y * 0.5 + 0.5, folder + f"/label_{epoch}.png")
    mpimg.imsave(folder + f"/y_gen_{epoch}.png", to_image(y_fake))
    mpimg.imsave(folder + f"/input_{epoch}.png", to_image(x * 0.5 + 0.5))
    if epoch == 1:
        mpimg.imsave(folder + f"/label_{epoch}.png", to_image(y * 0.5 + 0.5))
    gen.set_train(True)


def save_checkpoint(model, optimizer, filename="my_checkpoint.ckpt"):
- print("=> Saving checkpoint")
- checkpoint = [
- {"state_dict": model.state_dict(),
- "optimizer": optimizer.state_dict(), }
- ]
- mindspore.save_checkpoint(checkpoint, filename)
-
-
- def load_checkpoint(checkpoint_file, model, optimizer, lr):
- print("=> Loading checkpoint")
- checkpoint = mindspore.load_checkpoint(checkpoint_file)
- model.load_param_into_net(checkpoint["state_dict"])
- optimizer.load_param_into_net(checkpoint["optimizer"])
-
- # If we don't do this then it will just have learning rate of old checkpoint
- # and it will lead to many hours of debugging \:
- for param_group in optimizer.param_groups:
- param_group["lr"] = lr
-
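
# A minimal usage sketch of these helpers in a training script. All names below
# (`gen`, `opt_gen`, `val_loader`, the config.* constants and the "evaluation"
# folder) are hypothetical and assumed to be defined elsewhere, not in this file:
#
#   if config.LOAD_MODEL:
#       load_checkpoint(config.CHECKPOINT_GEN, gen, opt_gen, config.LEARNING_RATE)
#   for epoch in range(config.NUM_EPOCHS):
#       ...  # run one training epoch
#       if config.SAVE_MODEL:
#           save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
#       save_some_examples(gen, val_loader, epoch, folder="evaluation")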