|
- import mindspore.common.dtype as mstype
- import mindspore.dataset as ds
- import mindspore.dataset.transforms.c_transforms as C
-
-
def create_squad_train_dataset(data_file=None, do_shuffle=True, device_num=1, rank=0, batch_size=1, num=None):
    """Build the batched training dataset pipeline.

    Loads ``data_file`` through ``ds.MindDataset``, casts every selected
    column to its target dtype, and groups the rows into fixed-size batches.

    Args:
        data_file: Forwarded to ``ds.MindDataset`` as the dataset source.
        do_shuffle: Forwarded as ``shuffle``.
        device_num: Forwarded as ``num_shards``.
        rank: Forwarded as ``shard_id``.
        batch_size: Number of rows per batch; an incomplete trailing batch
            is dropped (``drop_remainder=True``).
        num: Forwarded as ``num_samples``.

    Returns:
        The batched dataset object produced by the pipeline.
    """
    pipeline = ds.MindDataset(
        data_file,
        columns_list=[
            "input_mask",
            "src_ids",
            "pos_ids",
            "sent_ids",
            "wn_concept_ids",
            "nell_concept_ids",
            "start_positions",
            "end_positions",
        ],
        shuffle=do_shuffle,
        num_shards=device_num,
        shard_id=rank,
        num_samples=num,
    )

    to_int32 = C.TypeCast(mstype.int32)
    to_float32 = C.TypeCast(mstype.float32)

    # Casts are applied one column at a time, in exactly this order.
    # Only the mask is cast to float32; every other column becomes int32.
    column_casts = (
        ("src_ids", to_int32),
        ("pos_ids", to_int32),
        ("sent_ids", to_int32),
        ("wn_concept_ids", to_int32),
        ("nell_concept_ids", to_int32),
        ("input_mask", to_float32),
        ("start_positions", to_int32),
        ("end_positions", to_int32),
    )
    for column_name, cast_op in column_casts:
        pipeline = pipeline.map(operations=cast_op, input_columns=column_name)

    return pipeline.batch(batch_size, drop_remainder=True)
-
-
def create_squad_dev_dataset(data_file=None, do_shuffle=True, batch_size=1, repeat_count=1):
    """Build the batched evaluation (dev) dataset pipeline.

    Loads ``data_file`` through ``ds.MindDataset``, casts every selected
    column to its target dtype, repeats the dataset, and then groups the
    rows into fixed-size batches.

    Args:
        data_file: Forwarded to ``ds.MindDataset`` as the dataset source.
        do_shuffle: Forwarded as ``shuffle``.
        batch_size: Number of rows per batch; an incomplete trailing batch
            is dropped (``drop_remainder=True``).
        repeat_count: Forwarded to ``repeat`` before batching.

    Returns:
        The repeated and batched dataset object produced by the pipeline.
    """
    pipeline = ds.MindDataset(
        data_file,
        columns_list=[
            "input_mask",
            "src_ids",
            "pos_ids",
            "sent_ids",
            "wn_concept_ids",
            "nell_concept_ids",
            "unique_id",
        ],
        shuffle=do_shuffle,
    )

    to_int32 = C.TypeCast(mstype.int32)
    to_float32 = C.TypeCast(mstype.float32)

    # Casts are applied one column at a time, in exactly this order.
    # Only the mask is cast to float32; every other column becomes int32.
    column_casts = (
        ("src_ids", to_int32),
        ("pos_ids", to_int32),
        ("sent_ids", to_int32),
        ("wn_concept_ids", to_int32),
        ("nell_concept_ids", to_int32),
        ("input_mask", to_float32),
        ("unique_id", to_int32),
    )
    for column_name, cast_op in column_casts:
        pipeline = pipeline.map(operations=cast_op, input_columns=column_name)

    # Repeat happens before batching, so batches are formed over the
    # repeated stream.
    # NOTE(review): drop_remainder=True discards a trailing partial batch,
    # so some dev rows may never be evaluated -- confirm this is intended.
    pipeline = pipeline.repeat(repeat_count)
    return pipeline.batch(batch_size, drop_remainder=True)
|