
Update 'evaluate.py'

master
unicorn committed 1 month ago
commit 2147eea263
1 changed file with 5 additions and 31 deletions

evaluate.py  +5  -31

@@ -27,8 +27,6 @@ import mindspore.dataset as ds
import mindspore.ops as ops
from mindspore import context, nn
from mindspore import load_checkpoint, load_param_into_net
-from mindspore.communication.management import init, get_rank, get_group_size
-from mindspore.context import ParallelMode

parser = argparse.ArgumentParser('Evaluate separation performance using DPRNN')
parser.add_argument('--train_dir', type=str, default="/home/work/user-job-dir/inputs/data_json/test",
@@ -81,7 +79,7 @@ parser.add_argument('--num_workers', default=4, type=int, #default = 8
parser.add_argument('--optimizer', default='adam', type=str,
                    choices=['sgd', 'adam'],
                    help='Optimizer (support sgd and adam now)')
-parser.add_argument('--lr', default=0.00025, type=float,
+parser.add_argument('--lr', default=1e-3, type=float,
                    help='Init learning rate')
parser.add_argument('--momentum', default=0.0, type=float,
                    help='Momentum for optimizer')
@@ -115,14 +113,8 @@ parser.add_argument(
    default="Ascend",
    choices=['Ascend', 'GPU', 'CPU'],
    help='device where the code will be implemented (default: Ascend)')
-parser.add_argument('--device_num', type=int, default=2,
-                    help='Sample rate of audio file')
-parser.add_argument('--device_id', type=int, default=0,
-                    help='Sample rate of audio file')
-parser.add_argument('--run_distribute', type=bool, default=True,
-                    help='Sample rate of audio file')

-parser.add_argument('--ckpt_path', type=str, default="DPRNN-45_445.ckpt",
+parser.add_argument('--ckpt_path', type=str, default="DPRNN-10_3560.ckpt",
                    help='Path to model file created by training')

parser.add_argument('--cal_sdr', type=int, default=1,
@@ -153,24 +145,7 @@ def preprocess(args):
                               sample_rate=args.sample_rate)
    print("preprocess done")

-def evaluate(args):
-    if args.run_distribute:
-        print("distribute")
-        device_id = int(os.getenv("DEVICE_ID"))
-        device_num = args.device_num
-        context.set_context(device_id=device_id)
-        init()
-        context.reset_auto_parallel_context()
-        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
-                                          device_num=device_num)
-        rank_id = get_rank()  # ID of the current device within the cluster
-        rank_size = get_group_size()  # number of devices in the cluster
-
-    else:
-        device_id = args.device_id
-        # device_id = int(os.getenv("DEVICE_ID"))
-        context.set_context(device_id=device_id)
+def evaluate(args):
    total_SISNRi = 0
    total_SDRi = 0
    total_cnt = 0
@@ -203,9 +178,8 @@ def evaluate(args):
    # Load data
    tt_dataset = DatasetGenerator(args.train_dir, args.batch_size,
                                  sample_rate=args.sample_rate, segment=args.segment)
-    tt_loader = ds.GeneratorDataset(tt_dataset, ["mixture", "lens", "sources"], num_parallel_workers=args.threads,
-                                    shuffle=False, num_shards=rank_size, shard_id=rank_id)
-    tt_loader = tt_loader.batch(batch_size=4)
+    tt_loader = ds.GeneratorDataset(tt_dataset, ["mixture", "lens", "sources"], shuffle=False)
+    tt_loader = tt_loader.batch(batch_size=2)

    for data in tt_loader.create_dict_iterator():
        padded_mixture = data["mixture"]
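Most of the 31 deleted lines drop the data-parallel path: the init()/get_rank()/get_group_size() and ParallelMode imports, the --device_num/--device_id/--run_distribute options, and the DATA_PARALLEL context setup at the top of evaluate(). For reference, a minimal single-device setup of the kind the simplified script implies could look like the sketch below; the setup_single_device name and the DEVICE_ID fallback of 0 are illustrative assumptions, not part of the commit.

# Minimal sketch (assumption): single-device MindSpore context setup, with no
# init(), rank queries, or auto-parallel context needed for one-device evaluation.
import os
from mindspore import context

def setup_single_device(device_target="Ascend"):
    # DEVICE_ID usually comes from the environment on Ascend jobs; fall back to 0 here.
    device_id = int(os.getenv("DEVICE_ID", "0"))
    context.set_context(device_target=device_target, device_id=device_id)
    return device_id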


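The last hunk makes the same simplification on the data side: the test pipeline is no longer sharded by rank (num_shards/shard_id are dropped) and the batch size goes from 4 to 2. The stand-alone sketch below shows the shape of the resulting pipeline; ToySource and its array sizes are hypothetical stand-ins for this repository's DatasetGenerator, used only so the snippet runs on its own.

# Sketch of the unsharded evaluation pipeline; ToySource is a hypothetical
# stand-in for DatasetGenerator that yields (mixture, lens, sources) columns.
import numpy as np
import mindspore.dataset as ds

class ToySource:
    def __init__(self, n=8):
        self.n = n
    def __len__(self):
        return self.n
    def __getitem__(self, idx):
        mixture = np.random.randn(32000).astype(np.float32)      # mixed waveform
        lens = np.array(32000, dtype=np.int32)                    # valid length
        sources = np.random.randn(2, 32000).astype(np.float32)    # two reference sources
        return mixture, lens, sources

# No num_shards/shard_id: every sample is visited exactly once on this device.
tt_loader = ds.GeneratorDataset(ToySource(), ["mixture", "lens", "sources"], shuffle=False)
tt_loader = tt_loader.batch(batch_size=2)

for data in tt_loader.create_dict_iterator():
    print(data["mixture"].shape)   # (2, 32000) for each full batch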