|
- import mindspore
- import numpy as np
- import mindspore.nn as nn
- import mindspore.ops as ops
- from mindspore.ops import functional as F
-
# imports used only by the commented-out smoke test below
- import config
- from discriminator_model import Discriminator
- from generator_model import Generator
-
-
# Discriminator_With_Loss
class WithLossCellDis(nn.Cell):
    """Wrap the discriminator together with its adversarial loss.

    Args:
        netD: discriminator network, called as ``netD(x, y)``.
        netG: generator network, called as ``netG(x)``; its output is
            detached so only the discriminator receives gradients.
        loss_fn1: adversarial criterion (e.g. ``nn.BCEWithLogitsLoss``).
    """

    def __init__(self, netD, netG, loss_fn1):
        super(WithLossCellDis, self).__init__(auto_prefix=True)
        self.netD = netD
        self.netG = netG
        self.loss_fn1 = loss_fn1
        # Label builders: all-ones target for "real", all-zeros for "fake".
        self.ones = ops.OnesLike()
        self.zeros = ops.ZerosLike()

    def construct(self, x, y):
        """Return the averaged real/fake discriminator loss for pair (x, y)."""
        # Produce a fake sample, then cut its gradient path so this loss
        # only updates netD, never netG.
        fake = self.netG(x)
        fake = F.stop_gradient(fake)
        real_logits = self.netD(x, y)
        loss_real = self.loss_fn1(real_logits, self.ones(real_logits))
        fake_logits = self.netD(x, fake)
        loss_fake = self.loss_fn1(fake_logits, self.zeros(fake_logits))
        # Average the two terms (pix2pix convention halves the D loss).
        return (loss_real + loss_fake) * 0.5

    @property
    def backbone_network(self):
        # Network whose parameters this loss cell is meant to train.
        return self.netD
-
-
-
-
# Generator_With_Loss
class WithLossCellGen(nn.Cell):
    """Wrap the generator together with its adversarial + L1 loss.

    Args:
        netD: discriminator network, called as ``netD(x, y_fake)``.
        netG: generator network being trained.
        loss_fn1: adversarial criterion (BCE-with-logits style).
        loss_fn2: reconstruction criterion (L1), scaled by
            ``config.L1_LAMBDA``.
    """

    def __init__(self, netD, netG, loss_fn1, loss_fn2):
        super(WithLossCellGen, self).__init__(auto_prefix=True)
        self.netD = netD
        self.netG = netG
        self.loss_fn1 = loss_fn1
        self.loss_fn2 = loss_fn2
        # All-ones target: the generator wants netD to predict "real".
        self.ones = ops.OnesLike()

    def construct(self, x, y):
        """Return adversarial + weighted-L1 generator loss for pair (x, y)."""
        fake = self.netG(x)
        fake_logits = self.netD(x, fake)
        adv_loss = self.loss_fn1(fake_logits, self.ones(fake_logits))
        # Reconstruction term pulls the generated image toward ground truth y.
        recon_loss = self.loss_fn2(fake, y) * config.L1_LAMBDA
        return adv_loss + recon_loss

    @property
    def backbone_network(self):
        # Network whose parameters this loss cell is meant to train.
        return self.netG
-
-
- # def testlossG():
- # x=mindspore.Tensor(np.random.rand(1,3,256,256), mindspore.float32)
- # y=mindspore.Tensor(np.random.randn(1,3,256,256), mindspore.float32)
- # loss_model=WithLossCellGen(netD=Discriminator(),netG=Generator(),loss_fn1=nn.BCEWithLogitsLoss(),loss_fn2=nn.L1Loss())
- # # print(loss_model)
- # loss=loss_model(x,y)
- # print(loss)
- #
- #
- # if __name__ == "__main__":
- # testlossG()
-
-
-
-
-
-
-
-
-
-
-
-
-
- # class DiscriminatorLoss(nn.Cell):
- # def __init__(self,disc,gen):
- # super(DiscriminatorLoss, self).__init__()
- # self.disc=disc
- # self.gen=gen
- # self.BWL_Loss=nn.BCEWithLogitsLoss()
- #
- # def construct(self, x, y):
- # y_fake=self.gen(x)
- # D_real=self.disc(x,y)
- # D_real_loss=self.BWL_Loss(D_real,ops.OnesLike(D_real))
- # D_fake=self.disc(x,y_fake) #.detach()处理
- # D_fake_loss=self.BWL_Loss(D_fake,ops.ZerosLike(D_fake))
- # D_loss=(D_real_loss+D_fake_loss)*0.5
- # return D_loss
- #
- #
- # class GeneratorLoss(nn.Cell):
- # def __init__(self,disc,gen):
- # super(GeneratorLoss, self).__init__()
- # self.disc=disc
- # self.gen=gen
- # self.BWL_Loss=nn.BCEWithLogitsLoss()
- # self.L1_loss=nn.L1Loss()
- #
- # def construct(self, x, y):
- # y_fake = self.gen(x)
- # D_fake = self.disc(x, y_fake)
- # G_fake_loss=self.BWL_Loss(D_fake,ops.OnesLike(D_fake))
- # L1_loss=self.L1_loss(y_fake,y)*config.L1_LAMBDA
- # G_loss=G_fake_loss+L1_loss
- # return y_fake,G_loss
- #
- #
|