# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np

import mindspore
from mindspore import nn
from mindspore.context import ParallelMode
from mindspore.nn import DistributedGradReducer
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
                                       _get_parallel_mode)

np.random.seed(0)
_GCONST_ = -0.9189385332046727  # -ln(sqrt(2*pi)), per-dimension constant of the standard normal log-density

theta = nn.Sigmoid()
log_theta = nn.LogSigmoid()


class Score_Observer:
    """Tracks the latest and best value of a named metric across epochs."""

    def __init__(self, name):
        self.name = name
        self.max_epoch = 0
        self.max_score = 0.0
        self.last = 0.0

    def update(self, score, epoch, print_score=True):
        """Record `score` for `epoch`; return True when weights should be saved."""
        self.last = score
        save_weights = False
        if epoch == 0 or score > self.max_score:
            self.max_score = score
            self.max_epoch = epoch
            save_weights = True
        if print_score:
            self.print_score()
        return save_weights

    def print_score(self):
        print('(Device ) {:s}: \t last: {:.2f} \t max: {:.2f} \t epoch_max: {:d}'.format(
            self.name, self.last, self.max_score, self.max_epoch))

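# Usage sketch (assumption: higher scores are better, e.g. AUROC). update()
# returns True on a new best, which callers can use to gate checkpointing:
#
#     det_obs = Score_Observer('DET_AUROC')
#     if det_obs.update(score, epoch):
#         mindspore.save_checkpoint(net, 'best_det.ckpt')
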
def t2np(tensor):
    '''MindSpore tensor -> numpy array'''
    return tensor.asnumpy() if tensor is not None else None

-
- def get_logp(C, z, logdet_J):
- logp = C * _GCONST_ - 0.5 * mindspore.numpy.sum(z ** 2, 1) + logdet_J
- return logp
-
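# For a C-dimensional latent z under a standard normal base distribution,
#     log p(x) = -C/2 * ln(2*pi) - ||z||^2 / 2 + log|det J|,
# which is what get_logp computes, with _GCONST_ = -ln(sqrt(2*pi)).
# Minimal call sketch (shapes are illustrative assumptions):
#
#     z = mindspore.numpy.ones((8, 128))   # batch of 8 latent fibers
#     logdet = mindspore.numpy.zeros(8)    # per-sample log|det J|
#     logp = get_logp(128, z, logdet)      # shape (8,)
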

def rescale(x):
    """Min-max normalize x to [0, 1]."""
    return (x - x.min()) / (x.max() - x.min())


class DecoderEval(nn.Cell):
    """Thin wrapper that runs the decoder forward pass for evaluation."""

    def __init__(self, net):
        super(DecoderEval, self).__init__()
        self.net = net

    def construct(self, x, c):
        z, log_jac_det = self.net(x, c)
        return z, log_jac_det


class DecoderTrain(nn.Cell):
    """Computes the training loss from the flow decoder's outputs."""

    def __init__(self, net, C):
        super(DecoderTrain, self).__init__()
        self.log_theta = nn.LogSigmoid()
        self.net = net
        self.C = C

    def get_logp(self, C, z, logdet_J):
        logp = C * _GCONST_ - 0.5 * mindspore.numpy.sum(z ** 2, 1) + logdet_J
        return logp

    def construct(self, x, c):
        z, log_jac_det = self.net(x, c)
        decoder_log_prob = self.get_logp(self.C, z, log_jac_det)
        log_prob = decoder_log_prob / self.C  # log-likelihood per dimension
        loss = -self.log_theta(log_prob)
        loss = loss.mean()
        return loss

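# The training objective maximizes log(sigmoid(log p(x) / C)), a bounded
# transform of the per-dimension log-likelihood, hence loss = -LogSigmoid(...).
# Minimal sketch (decoder_net is an assumed flow returning (z, log_jac_det)):
#
#     train_cell = DecoderTrain(decoder_net, C=128)
#     loss = train_cell(x, c)   # scalar mean loss over the batch
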
class DecoderOneStep(nn.Cell):
    """Forward, backward, and optimizer update for the decoder in one cell."""

    def __init__(self, network, optimizer, Cls, cflow=True, sens=1.0):
        super(DecoderOneStep, self).__init__()
        self.network = network
        self.network.set_grad()
        self.Cls = Cls
        self.cflow = cflow
        self.train_net = DecoderTrain(self.network, self.Cls)
        self.optimizer = optimizer
        self.weights = self.optimizer.parameters
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.grad_reducer = F.identity
        self.parallel_mode = _get_parallel_mode()
        self.reducer_flag = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL)
        if self.reducer_flag:
            self.mean = _get_gradients_mean()
            self.degree = _get_device_num()
            self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree)

    def construct(self, *inputs):
        loss = self.train_net(*inputs)
        sens = F.fill(loss.dtype, loss.shape, self.sens)
        grads = self.grad(self.train_net, self.weights)(*inputs, sens)
        grads = self.grad_reducer(grads)
        self.optimizer(grads)
        return loss

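# Typical wiring (sketch; decoder_net and the data pipeline are assumed to
# exist elsewhere in this repo):
#
#     opt = nn.Adam(decoder_net.trainable_params(), learning_rate=2e-4)
#     one_step = DecoderOneStep(decoder_net, opt, Cls=128)
#     loss = one_step(x, c)   # forward + backward + optimizer update
#
# Under DATA_PARALLEL or HYBRID_PARALLEL, DistributedGradReducer all-reduces
# gradients across devices before the optimizer step.
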
class PositionalEncoding2D(mindspore.nn.Cell):
    """Flattens a BxCxHxW feature map into fibers and draws a random fiber order.

    Despite the name, this cell does not apply a positional encoding itself;
    it prepares the flattened features and a permutation for fiber batching.
    """

    def __init__(self, P, N):
        super(PositionalEncoding2D, self).__init__()
        self.P = P  # positional-encoding channels (kept for the caller)
        self.N = N  # fiber batch size

    def construct(self, e):
        """ construct """
        B, C, H, W = e.shape
        S = H * W   # spatial positions per image
        E = B * S   # total number of fibers in the batch
        c_r = 0
        e_r = e.reshape(B, C, S).transpose(0, 2, 1).reshape(E, C)  # BHWxC
        perm = mindspore.ops.Randperm(E)(mindspore.Tensor([E], dtype=mindspore.int32))
        FIB = E // self.N  # number of fiber batches
        return FIB, c_r, e_r, perm

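# The returned permutation drives fiber batching: iterate over FIB batches of
# self.N fibers in random order. Consumption sketch (indexing shown is an
# illustrative assumption):
#
#     FIB, c_r, e_r, perm = pos_enc(feature_map)
#     for f in range(FIB):
#         idx = perm[f * N:(f + 1) * N]
#         e_b = e_r[idx]   # one fiber batch of features, shape (N, C)
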
class DecoderEvalOneStep(mindspore.nn.Cell):
    """Runs the decoder in inference mode; returns per-dimension log-likelihood."""

    def __init__(self, net: mindspore.nn.CellList,
                 C=3,
                 cflow=True):
        super(DecoderEvalOneStep, self).__init__()
        self.net = net
        self.net.set_train(False)
        self.net.set_grad(False)
        self.C = C
        self.cflow = cflow

    def construct(self, x, c):
        """ construct """
        if self.cflow:
            z, log_jac_det = self.net(x, c)
        else:
            z, log_jac_det = self.net(x)
        decoder_log_prob = get_logp(self.C, z, log_jac_det)
        log_prob = decoder_log_prob / self.C  # log-likelihood per dimension
        return log_prob
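
# Evaluation sketch: per-dimension log-likelihoods can be turned into anomaly
# scores by negating and min-max rescaling (variable names are assumptions):
#
#     eval_cell = DecoderEvalOneStep(decoder_net, C=128)
#     log_prob = eval_cell(x, c)
#     scores = rescale(-t2np(log_prob))   # higher = more anomalous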