|
- from data_provider.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_Pred, Dataset_2003_NEPOOL_hour,Dataset_GEFCOM2014_minute,Dataset_Energy_minute,Dataset_Exchange_day,Dataset_Stock_day,Dataset_donggua_day
- from torch.utils.data import DataLoader
-
# Registry from a dataset name to the dataset class that loads it;
# data_provider() looks classes up here via ``args.data``.
# The two hourly ETT variants share one class, as do the two minute-level ones.
data_dict = {
    **dict.fromkeys(('ETTh1', 'ETTh2'), Dataset_ETT_hour),
    **dict.fromkeys(('ETTm1', 'ETTm2'), Dataset_ETT_minute),
    'custom': Dataset_Custom,
    "2003_NEPOOL": Dataset_2003_NEPOOL_hour,
    "GEFCOM2014": Dataset_GEFCOM2014_minute,
    "energy": Dataset_Energy_minute,
    "exchange": Dataset_Exchange_day,
    "stock": Dataset_Stock_day,
    "donggua": Dataset_donggua_day,
}
-
-
def data_provider(args, flag):
    """Build the dataset and its DataLoader for one split.

    Args:
        args: Experiment configuration object. Attributes read here:
            ``data``, ``embed``, ``batch_size``, ``freq``, ``root_path``,
            ``data_path``, ``seq_len``, ``label_len``, ``pred_len``,
            ``features`` and ``target``.
        flag: Split name, also forwarded to the dataset constructor.
            ``'test'`` gives an un-shuffled loader; ``'pred'`` uses
            ``Dataset_Pred`` with batch size 1; any other value is treated
            as a training-style split (shuffled, last partial batch dropped).

    Returns:
        ``(data_set, data_loader)``: the dataset instance and a
        ``torch.utils.data.DataLoader`` over it.

    Raises:
        KeyError: if ``flag != 'pred'`` and ``args.data`` is not a key of
            ``data_dict``. The message lists the supported names.
    """
    # 'timeF' embeddings get timeenc=1; every other embed type gets timeenc=0.
    timeenc = 1 if args.embed == 'timeF' else 0

    if flag == 'pred':
        # Prediction always uses Dataset_Pred regardless of args.data, and
        # serves one un-shuffled sample per batch.
        Data = Dataset_Pred
        shuffle_flag = False
        drop_last = False
        batch_size = 1
    else:
        # Resolve the registered class only when it is actually used, so an
        # unregistered args.data does not break 'pred' runs.
        try:
            Data = data_dict[args.data]
        except KeyError:
            raise KeyError(
                f"Unknown dataset {args.data!r}; "
                f"expected one of {sorted(data_dict)}"
            ) from None
        # The test split keeps a fixed order; every other split is shuffled.
        shuffle_flag = flag != 'test'
        # NOTE(review): drop_last=True on the test split discards the final
        # len(data_set) % batch_size samples from evaluation. This is kept
        # for compatibility; confirm it is intended.
        drop_last = True
        batch_size = args.batch_size
    freq = args.freq

    data_set = Data(
        root_path=args.root_path,
        data_path=args.data_path,
        flag=flag,
        # Window lengths passed to the dataset as [seq, label, pred].
        size=[args.seq_len, args.label_len, args.pred_len],
        features=args.features,
        target=args.target,
        timeenc=timeenc,
        freq=freq
    )
    print(flag, len(data_set))
    data_loader = DataLoader(
        data_set,
        batch_size=batch_size,
        shuffle=shuffle_flag,
        # Load batches in the main process (no worker subprocesses).
        num_workers=0,
        drop_last=drop_last)
    return data_set, data_loader
|