import os
import time

import numpy as np

import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import context, save_checkpoint, Tensor
from mindspore.communication.management import get_group_size
from mindspore.context import ParallelMode
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.parallel._auto_parallel_context import auto_parallel_context


class TrainOneStep(nn.TrainOneStepCell):
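    """Single training step with optional distributed gradient reduction and global-norm clipping.

    Wraps a network that returns a scalar loss: runs the forward pass, backpropagates with a
    constant sensitivity ``sens``, reduces gradients across devices in data/hybrid parallel
    mode, optionally clips them by global norm, and applies the optimizer update.
    """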
    def __init__(self, network, optimizer, sens=1.0, use_global_norm=True, clip_global_norm_value=5.0):
        super(TrainOneStep, self).__init__(network, optimizer, sens)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = float(sens)
        self.reducer_flag = False
        self.grad_reducer = None
        self.use_global_norm = use_global_norm
        self.clip_global_norm_value = clip_global_norm_value
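        # In data or hybrid parallel mode, gradients must be all-reduced (and, if
        # "gradients_mean" is set, averaged) across devices before the update.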
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("gradients_mean")
            if auto_parallel_context().get_device_num_is_set():
                degree = context.get_auto_parallel_context("device_num")
            else:
                degree = get_group_size()
            self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)

    def construct(self, padded_mixture, mixture_lengths, padded_source):
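        """Run forward, backward, optional gradient reduction and clipping, then the optimizer update."""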
        loss = self.network(padded_mixture, mixture_lengths, padded_source)
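        # Seed backpropagation with a constant tensor of value `sens` matching the loss,
        # then take gradients of the loss with respect to the trainable weights.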
        sens = P.Fill()(loss.dtype, loss.shape, self.sens)
        grads = self.grad(self.network, self.weights)(padded_mixture, mixture_lengths, padded_source, sens)
        if self.reducer_flag:
            grads = self.grad_reducer(grads)
        if self.use_global_norm:
            grads = C.clip_by_global_norm(grads, clip_norm=self.clip_global_norm_value)
        loss = ops.depend(loss, self.optimizer(grads))
        return loss
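

# A minimal usage sketch (illustrative only, not part of the original script): the toy loss
# network, tensor shapes, and optimizer below are assumptions made just to show how
# TrainOneStep is wired to a loss-returning network; the real project would pass its own
# network and data.
def _toy_train_step_demo():
    class _ToyLossNet(nn.Cell):
        """Hypothetical stand-in that maps (mixture, lengths, source) to a scalar loss."""
        def __init__(self):
            super(_ToyLossNet, self).__init__()
            self.dense = nn.Dense(16, 16)
            self.loss = nn.MSELoss()

        def construct(self, padded_mixture, mixture_lengths, padded_source):
            return self.loss(self.dense(padded_mixture), padded_source)

    net = _ToyLossNet()
    opt = nn.Adam(net.trainable_params(), learning_rate=1e-3)
    train_cell = TrainOneStep(net, opt, sens=1.0, use_global_norm=True, clip_global_norm_value=5.0)
    train_cell.set_train()

    mixture = Tensor(np.random.randn(4, 16).astype(np.float32))
    lengths = Tensor(np.full(4, 16, dtype=np.int32))
    source = Tensor(np.random.randn(4, 16).astype(np.float32))
    print("loss:", train_cell(mixture, lengths, source))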