Browse Source

更新 'train_cload.py'

master
lwj 1 month ago
parent
commit
26125f79ca
1 changed files with 1 additions and 2 deletions
  1. +1
    -2
      train_cload.py

+ 1
- 2
train_cload.py View File

@@ -265,8 +265,7 @@ def main(args):
print("开始prepro-------------")
preprocess(args)

tr_dataset = DatasetGenerator(args.train, args.data_batch_size,
sample_rate=args.sample_rate, segment=args.segment)
tr_dataset = DatasetGenerator(args.train, args.data_batch_size, sample_rate=args.sample_rate, segment=args.segment)
# distributed_sampler = DistributedSampler(14240)
# tr_loader = ds.GeneratorDataset(tr_dataset, ["mixture", "lens", "sources"], shuffle=True, num_parallel_workers=8, num_shards=rank_size, shard_id=rank_id, sampler=distributed_sampler)
tr_loader = ds.GeneratorDataset(tr_dataset, ["mixture", "lens", "sources"], shuffle=True, num_parallel_workers=8, num_shards=rank_size, shard_id=rank_id)


Loading…
Cancel
Save