|
- # Copyright (c) Facebook, Inc. and its affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- # Authors: Yossi Adi (adiyoss) and Alexandre Defossez (adefossez)
-
- import functools
- import logging
- from contextlib import contextmanager
- import inspect
- import os
- import time
- import math
- import torch
- import mindspore
- import mindspore.ops as ops
- import mindspore.numpy as np
- logger = logging.getLogger(__name__)
-
-
def capture_init(init):
    """Decorator for `__init__` that records the constructor arguments.

    After construction, the original positional and keyword arguments are
    available as `self._init_args_kwargs`, a `(args, kwargs)` tuple. Used by
    `serialize_model` to re-create the model later.
    """
    @functools.wraps(init)
    def wrapped(self, *args, **kwargs):
        # Stash the call signature before running the real constructor.
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)

    return wrapped
-
-
def deserialize_model(package, strict=False):
    """Rebuild a model from a package produced by `serialize_model`.

    Args:
        package: dict with keys 'class', 'args', 'kwargs', 'state'.
        strict: if True, pass the stored kwargs through unchanged; if False,
            silently drop any stored kwarg the class no longer accepts
            (useful when loading checkpoints across code revisions).

    Returns:
        The instantiated model with its state dict loaded.
    """
    klass = package['class']
    if strict:
        model = klass(*package['args'], **package['kwargs'])
    else:
        sig = inspect.signature(klass)
        # Work on a copy so the caller's package dict is left untouched.
        kw = dict(package['kwargs'])
        for key in list(kw):
            if key not in sig.parameters:
                logger.warning("Dropping nonexistent parameter %s", key)
                del kw[key]
        model = klass(*package['args'], **kw)
    model.load_state_dict(package['state'])
    return model
-
-
def copy_state(state):
    """Return a detached CPU copy of a state dict (every tensor cloned)."""
    return {name: tensor.cpu().clone() for name, tensor in state.items()}
-
-
def serialize_model(model):
    """Package a model for saving: its class, `capture_init` args, and a
    CPU copy of its state dict (see `deserialize_model` for the inverse)."""
    args, kwargs = model._init_args_kwargs
    # CPU copy of the state dict, detached from the live parameters.
    state = {name: tensor.cpu().clone() for name, tensor in model.state_dict().items()}
    return {"class": model.__class__, "args": args, "kwargs": kwargs, "state": state}
-
-
@contextmanager
def swap_state(model, state):
    """Temporarily load `state` into `model`; the previous state is restored
    when the context exits, even on error."""
    saved = {name: tensor.cpu().clone() for name, tensor in model.state_dict().items()}
    model.load_state_dict(state)
    try:
        yield
    finally:
        model.load_state_dict(saved)
-
-
@contextmanager
def swap_cwd(cwd):
    """Temporarily change the working directory to `cwd`, restoring the
    original directory when the context exits."""
    previous = os.getcwd()
    os.chdir(cwd)
    try:
        yield
    finally:
        os.chdir(previous)
-
-
def pull_metric(history, name):
    """Collect the values of metric `name` from a list of metric dicts,
    skipping entries where the metric is absent."""
    return [metrics[name] for metrics in history if name in metrics]
-
-
class LogProgress:
    """
    tqdm-like progress reporting via log lines instead of a live bar.

    Wraps an iterable and emits roughly `updates` log messages over the
    whole iteration, each showing index/total, iteration speed, and any
    extra metrics registered through `update()`.
    """

    def __init__(self, logger, iterable, updates=5, total=None,
                 name="LogProgress", level=logging.INFO):
        self.iterable = iterable
        # Fall back to len() when an explicit total is not given.
        self.total = total or len(iterable)
        self.updates = updates
        self.name = name
        self.logger = logger
        self.level = level

    def update(self, **infos):
        # Metrics to show on the next logged line (logging lags one step).
        self._infos = infos

    def __iter__(self):
        self._iterator = iter(self.iterable)
        self._index = -1
        self._infos = {}
        self._begin = time.time()
        return self

    def __next__(self):
        self._index += 1
        try:
            value = next(self._iterator)
        except StopIteration:
            raise
        else:
            return value
        finally:
            log_every = max(1, self.total // self.updates)
            # logging is delayed by 1 it, in order to have the metrics from update
            if self._index >= 1 and self._index % log_every == 0:
                self._log()

    def _log(self):
        self._speed = (1 + self._index) / (time.time() - self._begin)
        if self._speed < 1e-4:
            rate = "oo sec/it"
        elif self._speed < 0.1:
            rate = f"{1 / self._speed:.1f} sec/it"
        else:
            rate = f"{self._speed:.1f} it/sec"
        parts = [f"{self.name} | {self._index}/{self.total} | {rate}"]
        extra = " | ".join(f"{key.capitalize()} {value}"
                           for key, value in self._infos.items())
        if extra:
            parts.append(extra)
        self.logger.log(self.level, " | ".join(parts))
-
-
def colorize(text, color):
    """Wrap `text` in ANSI escape sequences for the given SGR `color` code
    (e.g. "1" for bold, "31" for red), resetting attributes afterwards."""
    return f"\033[{color}m{text}\033[0m"
-
-
def bold(text):
    """Return `text` wrapped in the ANSI bold escape sequence (SGR code 1)."""
    return "\033[1m" + text + "\033[0m"
-
-
def calculate_grad_norm(model):
    """Return the global L2 norm of all parameter gradients in `model`.

    Parameters without a gradient (frozen or unused) are skipped. Returns a
    zero tensor when no parameter has a gradient.

    Fixes vs. previous version: gradients live on `p.grad`, not on
    `p.data.grad` (which is always None, so the old code crashed); and the
    trailing `** (1. / 2)` took the square root of the norm instead of
    returning the norm itself.
    """
    grads = [p.grad.detach().flatten()
             for p in model.parameters() if p.grad is not None]
    if not grads:
        return torch.tensor(0.)
    return torch.cat(grads).norm(2)
-
-
def calculate_weight_norm(model):
    """Return the global L2 norm of all parameters in `model`.

    Returns a zero tensor for a model with no parameters.

    Fix vs. previous version: the trailing `** (1. / 2)` took the square
    root of the L2 norm instead of returning the norm itself.
    """
    weights = [p.data.flatten() for p in model.parameters()]
    if not weights:
        return torch.tensor(0.)
    return torch.cat(weights).norm(2)
-
-
def remove_pad(inputs, inputs_lengths):
    """
    Strip right-padding from each item of a batch.

    Args:
        inputs: torch.Tensor, [B, C, T] or [B, T], B is batch size
        inputs_lengths: torch.Tensor, [B]
    Returns:
        results: a list containing B numpy arrays, each [C, T] or [T],
        where T varies per item
    """
    dim = inputs.dim()
    # Channel count is only meaningful for the 3-D layout.
    C = inputs.size(1) if dim == 3 else None
    results = []
    for sample, length in zip(inputs, inputs_lengths):
        if dim == 3:  # [B, C, T]
            results.append(sample[:, :length].view(C, -1).cpu().numpy())
        elif dim == 2:  # [B, T]
            results.append(sample[:length].view(-1).cpu().numpy())
    return results
-
-
def overlap_and_add(signal, frame_step):
    """Reconstructs a signal from a framed representation.

    Adds potentially overlapping frames of a signal with shape
    `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
    The resulting tensor has shape `[..., output_size]` where

        output_size = (frames - 1) * frame_step + frame_length

    Args:
        signal: A [..., frames, frame_length] Tensor. All dimensions may be unknown, and rank must be at least 2.
        frame_step: An integer denoting overlap offsets. Must be less than or equal to frame_length.

    Returns:
        A Tensor with shape [..., output_size] containing the overlap-added frames of signal's inner-most two dimensions.
        output_size = (frames - 1) * frame_step + frame_length

    Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
    """
    outer_dimensions = signal.shape[:-2]
    frames, frame_length = signal.shape[-2:]

    # Split frames into subframes whose length is the GCD of frame_length and
    # frame_step, so overlapping regions align on subframe boundaries.
    subframe_length = math.gcd(frame_length, frame_step)
    subframe_step = frame_step // subframe_length
    subframes_per_frame = frame_length // subframe_length
    output_size = frame_step * (frames - 1) + frame_length
    output_subframes = output_size // subframe_length

    subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)

    # Build the destination subframe index of every source subframe (an
    # "unfold" over the output index range).
    # NOTE(review): this concatenation hard-codes exactly four strided slices
    # (offsets 0..3), i.e. it is only correct when subframes_per_frame == 4
    # -- TODO generalize to arbitrary subframes_per_frame.
    frame = mindspore.numpy.arange(0, output_subframes)
    frame = ops.Concat(-1)((
        ops.expand_dims(frame[0:-3:subframe_step], 1), ops.expand_dims(frame[1:-2:subframe_step], 1),
        ops.expand_dims(frame[2:-1:subframe_step], 1), ops.expand_dims(frame[3::subframe_step], 1)))

    zeros = ops.Zeros()
    result = zeros((*outer_dimensions, output_subframes, subframe_length), mindspore.float32)
    # ops.InplaceAdd accumulates along the first axis, so bring the subframe
    # axis to the front, add, then restore the original layout.
    # NOTE(review): the (2, 1, 0, 3) permutation assumes `result` is rank 4
    # (two outer dimensions) -- confirm against callers.
    transpose = ops.Transpose()
    result = transpose(result, (2, 1, 0, 3))

    # InplaceAdd requires plain Python ints as its index tuple.
    indices = tuple(int(val) for val in frame.asnumpy())
    result = ops.InplaceAdd(indices)(result, subframe_signal)
    result = transpose(result, (2, 1, 0, 3))
    result = result.view(*outer_dimensions, -1)
    return result
|