|
- import os
- from dataclasses import dataclass, field
- from typing import Optional, Dict, Sequence
-
- import torch
- import transformers
- from torch.utils.data import Dataset
-
-
class D1(Dataset):
    """Toy map-style dataset yielding the integers 0..4 as ``input_ids``.

    Used below to demo per-rank dataset selection; despite the docstring
    boilerplate elsewhere in this file, it is not a real SFT dataset.
    """

    def __init__(self):
        super().__init__()
        # list(range(...)) is the idiomatic form of the identity comprehension.
        self.input_ids = list(range(5))

    def __len__(self) -> int:
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, int]:
        # Values are plain ints (not tensors), so the annotation says so.
        return dict(input_ids=self.input_ids[i])
-
-
class D2(Dataset):
    """Toy map-style dataset yielding the integers 6..8 as ``input_ids``.

    Counterpart to ``D1`` for non-zero ranks in the demo below; not a real
    supervised fine-tuning dataset.
    """

    def __init__(self):
        super().__init__()
        # list(range(...)) is the idiomatic form of the identity comprehension.
        self.input_ids = list(range(6, 9))

    def __len__(self) -> int:
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, int]:
        # Values are plain ints (not tensors), so the annotation says so.
        return dict(input_ids=self.input_ids[i])
-
-
- # @dataclass
- # class DataCollatorForSupervisedDataset(object):
- # """Collate examples for supervised fine-tuning."""
-
- # tokenizer: transformers.PreTrainedTokenizer
-
- # def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
- # input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
- # input_ids = torch.nn.utils.rnn.pad_sequence(
- # input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
- # )
- # labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
- # return dict(
- # input_ids=input_ids,
- # labels=labels,
- # # attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
- # )
-
def run():
    """Print this process's distributed-rank env vars, then iterate a
    rank-dependent toy dataset (``D1`` on rank 0, ``D2`` otherwise).

    Raises KeyError if ``LOCAL_RANK`` / ``RANK`` are unset — presumably this
    is launched under torchrun, which exports both (TODO confirm).
    """
    print(os.environ["LOCAL_RANK"])
    rank = os.environ["RANK"]
    print("rank", rank)

    # Rank 0 gets the 0..4 dataset; every other rank gets 6..8.
    dataset = D1() if int(rank) == 0 else D2()

    for item in dataset:
        print(item)
-
-
if __name__ == "__main__":
    # Script entry point: run the rank-dependent dataset demo.
    run()
|