|
- # import mindspore.nn as nn
- # import mindspore.ops as ops
- # import mindspore
- # from mindspore import save_checkpoint
- # import time
- # import os
- # from mindspore import context
- # from mindspore.communication.management import get_group_size
- # from mindspore.context import ParallelMode
- # from mindspore.ops import composite as C
- # from mindspore.ops import operations as P
- # from mindspore.ops import functional as F
- # from mindspore.parallel._auto_parallel_context import auto_parallel_context
- # from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
- # _get_parallel_mode, _get_enable_parallel_optimizer)
-
-
- # class TrainOneStepCell(nn.Cell):
- # def __init__(self, network, optimizer, sens=1.0):
- # super(TrainOneStepCell, self).__init__(auto_prefix=False)
- # self.network = network
- # self.network.set_grad()
- # self.optimizer = optimizer
- # self.weights = self.optimizer.parameters
- # self.grad = C.GradOperation(get_by_list=True, sens_param=True)
- # self.sens = sens
- # self.reducer_flag = False
- # self.grad_reducer = F.identity
- # self.parallel_mode = _get_parallel_mode()
- # self.reducer_flag = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL)
- # if self.reducer_flag:
- # self.mean = _get_gradients_mean()
- # self.degree = _get_device_num()
- # self.grad_reducer = nn.DistributedGradReducer(self.weights, self.mean, self.degree)
-
- # def construct(self, *inputs):
- # loss = self.network(*inputs)
- # sens = F.fill(loss.dtype, loss.shape, self.sens)
- # grads = self.grad(self.network, self.weights)(*inputs, sens)
- # grads = self.grad_reducer(grads)
- # # grads = ops.clip_by_global_norm(grads, clip_norm=2.0)
- # loss = F.depend(loss, self.optimizer(grads))
- # return loss
-
- # from mindspore import nn
- # from mindspore import ops
- # import time
- # from mindspore.ops import composite as C
- # from mindspore.ops import functional as F
- # from mindspore.ops import operations as P
-
- # GRADIENT_CLIP_TYPE = 0
- # GRADIENT_CLIP_VALUE = 5.0
-
- # clip_grad = C.MultitypeFuncGraph("clip_grad")
-
-
- # @clip_grad.register("Number", "Number", "Tensor")
- # def _clip_grad(clip_type, clip_value, grad):
- # """
- # Clip gradients.
- # Inputs:
- # clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
- # clip_value (float): Specifies how much to clip.
- # grad (tuple[Tensor]): Gradients.
-
- # Outputs:
- # tuple[Tensor], clipped gradients.
- # """
- # if clip_type not in (0, 1):
- # return grad
- # dt = F.dtype(grad)
- # if clip_type == 0:
- # new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
- # F.cast(F.tuple_to_array((clip_value,)), dt))
- # else:
- # new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
- # return new_grad
-
-
-
-
- # class TrainOnestepGen(nn.TrainOneStepCell):
- # """TrainOnestepGen
- # Encapsulation class of DBPN network training.
- # Append an optimizer to the training network after that the construct
- # function can be called to create the backward graph.
- # Args:
- # network(Cell): Generator with loss Cell. Note that loss function should have been added
- # optimizer(Cell):Optimizer for updating the weights.
- # sens (Number): The adjust parameter. Default: 1.0.
- # Outputs:
- # Tensor
- # """
-
- # def __init__(self, network, optimizer, sens=1.0, enable_clip_grad = True):
- # super(TrainOnestepGen, self).__init__(network, optimizer, sens)
- # self.cast = P.Cast()
- # self.hyper_map = C.HyperMap()
- # self.enable_clip_grad = enable_clip_grad
-
-
-
-
- # def construct(self, *inputs):
- # """Defines the computation performed."""
- # weights = self.weights
- # loss = self.network(*inputs)
- # sens = F.fill(loss.dtype, loss.shape, self.sens)
- # grads = self.grad(self.network, self.weights)(*inputs, sens)
- # grads = self.grad_reducer(grads)
- # if self.enable_clip_grad:
- # grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
-
- # return ops.depend(loss, self.optimizer(grads))
-
-
-
-
-
- # import mindspore.nn as nn
- # import mindspore.ops as ops
- # import mindspore
- # from mindspore import save_checkpoint
- # import time
- # import os
-
- # class TrainOneStep(nn.TrainOneStepCell):
- # def __init__(self, network, optimizer, sens=1.0):
- # super(TrainOneStep, self).__init__(network, optimizer, sens)
- # self.network = network
- # # self._loss = loss_fn
-
-
- # def construct(self, padded_mixture, mixture_lengths, padded_source):
- # # def construct(self, epoch, data_loader, cross_valid=False):
-
- # # data_loader = data['tr_loader'] if not cross_valid else self.cv_loader
- # # step = data_loader.get_dataset_size()
-
- # # print("1111111111111111111111111111111111111")
-
- # loss = self.network(padded_mixture, mixture_lengths, padded_source)
-
- # # mixture_lengths = mixture_lengths.astype(mindspore.int32)
- # # padded_mixture = padded_mixture.astype(mindspore.float32)
- # # padded_source = padded_source.astype(mindspore.float32)
- # # estimate_source = self.network(padded_mixture)
- # # estimate_source = estimate_source.astype(mindspore.float32)
- # # loss = 0
- # # cnt = len(estimate_source)
- # # for c_idx, est_src in enumerate(estimate_source):
- # # coeff = (c_idx+1)*(1.0/cnt)
- # # sisnr_loss, snr, est_src, _ = self._loss(padded_source, est_src, mixture_lengths)
- # # loss += (coeff * sisnr_loss)
- # # loss /= 6
-
-
- # sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
- # # print("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
- # grads = self.grad(self.network, self.weights)(padded_mixture, mixture_lengths, padded_source, sens)
- # # print("bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb")
- # grads = self.grad_reducer(grads)
- # # print("ccccccccccccccccccccccccccccccccccccccccc")
- # loss = ops.depend(loss, self.optimizer(grads))
- # # total_loss /= j
- # return loss
-
import os
import time

import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import context, save_checkpoint
from mindspore.communication.management import get_group_size
from mindspore.context import ParallelMode
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.parallel._auto_parallel_context import auto_parallel_context
-
class TrainOneStep(nn.TrainOneStepCell):
    """One training step with distributed gradient reduction and clipping.

    Wraps a network-with-loss cell: runs the forward pass, computes
    gradients seeded with ``sens``, all-reduces them under data/hybrid
    parallel modes, optionally clips them by global norm, and applies
    the optimizer update.

    Args:
        network (Cell): Network-with-loss cell; called as
            ``network(padded_mixture, mixture_lengths, padded_source)``
            and expected to return a loss Tensor.
        optimizer (Cell): Optimizer updating ``network``'s parameters.
        sens (Number): Sensitivity (scale) of the backward seed tensor.
            Default: 1.0.
        use_global_norm (bool): Whether to clip gradients by global norm
            before the optimizer step. Default: True.
        clip_global_norm_value (float): Global-norm clipping threshold.
            Default: 5.0.

    Outputs:
        Tensor, the loss value of the step.
    """

    def __init__(self, network, optimizer, sens=1.0, use_global_norm=True, clip_global_norm_value=5.0):
        super(TrainOneStep, self).__init__(network, optimizer, sens)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = float(sens)
        self.use_global_norm = use_global_norm
        self.clip_global_norm_value = clip_global_norm_value
        self.reducer_flag = False
        # Default to a no-op reducer (matches this file's earlier convention)
        # instead of None, so calling it is always safe even if the flag
        # guard in construct() ever drifts.
        self.grad_reducer = F.identity
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("gradients_mean")
            # Prefer the explicitly configured device count; fall back to the
            # communication group size when device_num was never set.
            if auto_parallel_context().get_device_num_is_set():
                degree = context.get_auto_parallel_context("device_num")
            else:
                degree = get_group_size()
            self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)

    def construct(self, padded_mixture, mixture_lengths, padded_source):
        """Run forward, backward, (reduce, clip), and optimizer update; return the loss."""
        loss = self.network(padded_mixture, mixture_lengths, padded_source)
        # Seed the backward pass with a constant tensor of value `sens`,
        # matching the loss's dtype and shape.
        sens = P.Fill()(loss.dtype, loss.shape, self.sens)
        grads = self.grad(self.network, self.weights)(padded_mixture, mixture_lengths, padded_source, sens)
        if self.reducer_flag:
            # All-reduce gradients across devices in data/hybrid parallel mode.
            grads = self.grad_reducer(grads)
        if self.use_global_norm:
            grads = C.clip_by_global_norm(grads, clip_norm=self.clip_global_norm_value)
        # depend() forces the optimizer update to execute before loss is returned.
        loss = ops.depend(loss, self.optimizer(grads))
        return loss
|