|
- # Copyright 2022 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
- """ TasNet training network wrapper. """
-
- import mindspore
- import mindspore.nn as nn
- import mindspore.ops as ops
- from mindspore import Tensor
- import numpy as np
-
class WithLossCell(nn.Cell):
    """
    Wrap the network with a loss function to compute the training loss.

    The wrapped network returns a sequence of source estimates, one per
    processing stage (presumably shaped ``[stages, B, C, T]`` — confirm
    against the TasNet implementation). Each stage's estimate is scored
    with ``loss_fn``; the per-stage losses are combined with linearly
    increasing weights (later stages count more) and then averaged over
    the number of stages.

    Args:
        net (Cell): The target network to wrap.
        loss_fn (Cell): The loss function. It must return a tuple whose
            first element is the SI-SNR loss for one estimate.
    """

    def __init__(self, net, loss_fn):
        super(WithLossCell, self).__init__(auto_prefix=False)
        self._net = net
        self._loss = loss_fn

    def construct(self, padded_mixture, mixture_lengths, padded_source):
        """Compute the weighted average loss over all stage estimates.

        Args:
            padded_mixture (Tensor): Padded mixture waveforms; cast to
                float32 before being fed to the network.
            mixture_lengths (Tensor): Valid (unpadded) length of each
                mixture, forwarded to the loss function.
            padded_source (Tensor): Padded ground-truth sources; cast to
                float32 before loss computation.

        Returns:
            Tensor: Scalar training loss.
        """
        padded_mixture = padded_mixture.astype(mindspore.float32)
        padded_source = padded_source.astype(mindspore.float32)

        estimate_source = self._net(padded_mixture)
        estimate_source = estimate_source.astype(mindspore.float32)

        loss = 0
        cnt = len(estimate_source)
        for c_idx, est_src in enumerate(estimate_source):
            # Later stages get a larger weight: (c_idx + 1) / cnt.
            coeff = (c_idx + 1) * (1.0 / cnt)
            # Only the SI-SNR loss is needed; the loss fn's other
            # outputs (snr, reordered estimates, ...) are discarded.
            sisnr_loss, _, _, _ = self._loss(padded_source, est_src, mixture_lengths)
            loss += coeff * sisnr_loss
        # Average the weighted sum over the number of stage estimates.
        loss /= cnt
        return loss
-
|