|
- import mindspore as ms
- import mindspore.nn as nn
- from mindspore import context
- from mindspore.context import ParallelMode
- from mindspore.parallel._auto_parallel_context import auto_parallel_context
- from mindspore.communication.management import get_group_size
- import mindspore.ops as ops
- from discriminator_model import Discriminator
- from generator_model import Generator
-
class Pix2Pix(nn.Cell):
    """Container cell pairing the discriminator and generator one-step cells.

    Each call forwards the same ``(x, y)`` pair through both provided
    cells and reduces each cell's flattened output to its mean, returning
    the two scalars as ``(netD_loss, netG_loss)``.

    Args:
        TrainOneStepCellDis (Cell): one-step cell for the discriminator side.
        TrainOneStepCellGen (Cell): one-step cell for the generator side.
    """

    def __init__(self, TrainOneStepCellDis, TrainOneStepCellGen):
        super(Pix2Pix, self).__init__(auto_prefix=True)
        self.TrainOneStepCellDis = TrainOneStepCellDis
        self.TrainOneStepCellGen = TrainOneStepCellGen

    def construct(self, x, y):
        """Run both cells on (x, y); return their mean outputs as (d, g)."""
        d_flat = self.TrainOneStepCellDis(x, y).view(-1)
        g_flat = self.TrainOneStepCellGen(x, y).view(-1)
        netD_loss = d_flat.mean()
        netG_loss = g_flat.mean()
        return netD_loss, netG_loss
-
-
-
-
-
-
-
-
-
-
-
- # class DisWithLossCell(nn.Cell):
- # def __init__(self, dis_loss_fn):
- # super(DisWithLossCell, self).__init__(auto_prefix=False)
- # self.dis_loss_fn = dis_loss_fn
- #
- # def construct(self, x, y):
- # dis_loss = self.dis_loss_fn(x, y)
- #
- # return dis_loss
-
- #
- # class WithLossCell(nn.Cell):
- # def __init__(self, network):  # network is the GeneratorLoss cell
- # super(WithLossCell, self).__init__(auto_prefix=False)
- # self.network = network
- #
- # def construct(self, x, y):
- # fake_image,gen_loss = self.network(x, y)
- #
- # return gen_loss
- #
- #
- # class TrainOneStepDis(nn.Cell):
- # """
- # Encapsulation class of Cycle GAN discriminator network training.
- #
- # Append an optimizer to the training network after that the construct
- # function can be called to create the backward graph.
- #
- # Args:
- # D (Cell): Discriminator with loss Cell. Note that loss function should have been added.
- # optimizer (Optimizer): Optimizer for updating the weights.
- # sens (Number): The adjust parameter. Default: 1.0.
- # """
- # def __init__(self,D,optimizer,sens=1.0): # D is the DiscriminatorLoss cell
- # super(TrainOneStepDis, self).__init__(auto_prefix=False)
- # self.optimizer=optimizer
- # self.D=D
- # self.D.set_grad()
- # self.D.set_train()
- # self.grad=ops.GradOperation(get_by_list=True,sens_param=True)
- # self.sens=sens
- # self.weights=ms.ParameterTuple(D.trainable_params())
- # self.reducer_flag=False
- # self.grad_reducer=None
- # self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
- # if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
- # self.reducer_flag = True
- # if self.reducer_flag:
- # mean = context.get_auto_parallel_context("gradients_mean")
- # if auto_parallel_context().get_device_num_is_set():
- # degree = context.get_auto_parallel_context("device_num")
- # else:
- # degree = get_group_size()
- # self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
- #
- # def construct(self, x, y):
- # weights = self.weights
- # ld = self.D(x,y) #discriminator_loss
- # sens_d = ops.Fill()(ops.DType()(ld), ops.Shape()(ld), self.sens)
- # grads_d = self.grad(self.D, weights)(x,y,sens_d)
- # if self.reducer_flag:
- # # apply grad reducer on grads
- # grads_d = self.grad_reducer(grads_d)
- # return ops.depend(ld, self.optimizer(grads_d))
- #
- #
- # class TrainOneStepGen(nn.Cell):
- # """
- # Encapsulation class of Cycle GAN generator network training.
- #
- # Append an optimizer to the training network after that the construct
- # function can be called to create the backward graph.
- #
- # Args:
- # G (Cell): Generator with loss Cell. Note that loss function should have been added.
- # generator (Cell): Generator of CycleGAN.
- # optimizer (Optimizer): Optimizer for updating the weights.
- # sens (Number): The adjust parameter. Default: 1.0.
- # """
- # def __init__(self,G,gen,optimizer,sens=1.0): # G is the GeneratorLoss cell, gen is the generator network
- # super(TrainOneStepGen, self).__init__(auto_prefix=False)
- # self.optimizer = optimizer
- # self.G = G
- # self.G.set_grad()
- # self.G.set_train()
- # self.grad=ops.GradOperation(get_by_list=True,sens_param=True)
- # self.sens=sens
- # self.weights = ms.ParameterTuple(gen.trainable_params())
- # self.net = WithLossCell(G) #
- # self.reducer_flag = False
- # self.grad_reducer = None
- # self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
- # if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
- # self.reducer_flag = True
- # if self.reducer_flag:
- # mean = context.get_auto_parallel_context("gradients_mean")
- # if auto_parallel_context().get_device_num_is_set():
- # degree = context.get_auto_parallel_context("device_num")
- # else:
- # degree = get_group_size()
- # self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
- #
- # def construct(self, x, y):
- # weights = self.weights
- # fake_image,lg= self.G(x, y) # G is the GeneratorLoss cell
- # sens = ops.Fill()(ops.DType()(lg), ops.Shape()(lg), self.sens)
- # grads_g = self.grad(self.net, weights)(x, y, sens)
- # if self.reducer_flag:
- # # apply grad reducer on grads
- # grads_g = self.grad_reducer(grads_g)
- #
- # return fake_image,ops.depend(lg, self.optimizer(grads_g))
|