|
- from __future__ import division
- import numpy as np
- import torch
- import os
- import logging
- from torch.utils.data import DataLoader, Dataset, Sampler
-
- logger = logging.getLogger('DeepAR.Data')
-
class TrainDataset(Dataset):
    """Training windows and labels loaded from ``.npy`` files.

    Reads ``train_data_{data_name}.npy`` and ``train_label_{data_name}.npy``
    from ``data_path``. Per the original author's notes the data array is
    laid out as [T, W, F] (W being the sliding-window axis) and the labels
    as [T, W]; the indexing in ``__getitem__`` requires the data array to be
    3-D. ``num_class`` is accepted but not used by this class.
    """

    def __init__(self, data_path, data_name, num_class):
        feature_file = os.path.join(data_path, f'train_data_{data_name}.npy')
        target_file = os.path.join(data_path, f'train_label_{data_name}.npy')

        self.data = np.load(feature_file)
        self.label = np.load(target_file)
        # Number of training samples = size of the leading axis of the data.
        self.train_len = self.data.shape[0]

        logger.info(f'train_len: {self.train_len}')
        logger.info(f'building datasets from {data_path}...')

    def __len__(self):
        """Return the number of training samples (enables ``len(ds)``)."""
        return self.train_len

    def __getitem__(self, index):
        """Return ``(features, series_id, label)`` for sample ``index``.

        The last feature column is split off from the inputs: the original
        author notes that it holds the ID of the user each window belongs
        to. The ID is read from the first row of the window and cast to
        ``int``.
        """
        features = self.data[index, :, :-1]
        series_id = int(self.data[index, 0, -1])
        target = self.label[index]
        return (features, series_id, target)
-
class TestDataset(Dataset):
    """Test windows, per-sample ``v`` values and labels loaded from ``.npy``.

    Reads ``test_data_{data_name}.npy``, ``test_v_{data_name}.npy`` and
    ``test_label_{data_name}.npy`` from ``data_path``. Per the original
    author's notes the arrays are laid out as [T, W, F], [T, 2] and [T, W]
    respectively; the indexing in ``__getitem__`` requires the data array to
    be 3-D. ``num_class`` is accepted but not used by this class.
    """

    def __init__(self, data_path, data_name, num_class):
        feature_file = os.path.join(data_path, f'test_data_{data_name}.npy')
        scale_file = os.path.join(data_path, f'test_v_{data_name}.npy')
        target_file = os.path.join(data_path, f'test_label_{data_name}.npy')

        self.data = np.load(feature_file)
        self.v = np.load(scale_file)
        self.label = np.load(target_file)
        # Number of test samples = size of the leading axis of the data.
        self.test_len = self.data.shape[0]

        logger.info(f'test_len: {self.test_len}')
        logger.info(f'building datasets from {data_path}...')

    def __len__(self):
        """Return the number of test samples."""
        return self.test_len

    def __getitem__(self, index):
        """Return ``(features, series_id, v, label)`` for sample ``index``.

        The last feature column is split off from the inputs and its value
        in the first row of the window is returned separately as an ``int``.
        """
        features = self.data[index, :, :-1]
        series_id = int(self.data[index, 0, -1])
        scale = self.v[index]
        target = self.label[index]
        return (features, series_id, scale, target)
-
class WeightedSampler(Sampler):
    """Sampler that draws indices with probability proportional to ``|v[:, 0]|``.

    The weights come from the first column of ``train_v_{data_name}.npy``
    found in ``data_path``. Each pass over the sampler yields as many
    indices as there are weights, drawn with ``torch.multinomial``.

    Args:
        data_path: Directory containing the ``train_v_*.npy`` file.
        data_name: Dataset name used to build the file name.
        replacement: Whether indices are drawn with replacement.

    Raises:
        ValueError: If the weight column is non-empty but its absolute
            values do not sum to a finite, positive number (for example an
            all-zero column, or one containing NaN). Without this check the
            normalisation would silently produce NaN weights and the error
            would only surface later, during iteration.
    """

    def __init__(self, data_path, data_name, replacement=True):
        v = np.load(os.path.join(data_path, f'train_v_{data_name}.npy'))
        self.weights = self._normalized_weights(v)
        logger.info(f'weights: {self.weights}')
        self.num_samples = self.weights.shape[0]
        logger.info(f'num samples: {self.num_samples}')
        self.replacement = replacement

    @staticmethod
    def _normalized_weights(v):
        """Turn the first column of ``v`` into a double tensor summing to 1."""
        # Compute the magnitudes once and reuse them for the total.
        magnitudes = np.abs(v[:, 0])
        if magnitudes.size == 0:
            # Nothing to normalise; an empty input gives an empty tensor.
            return torch.as_tensor(magnitudes, dtype=torch.double)
        total = np.sum(magnitudes)
        if not np.isfinite(total) or total <= 0:
            raise ValueError(
                'cannot build sampling weights: sum of |v[:, 0]| must be '
                f'finite and positive, got {total}'
            )
        return torch.as_tensor(magnitudes / total, dtype=torch.double)

    def __iter__(self):
        """Yield ``num_samples`` indices drawn according to ``weights``."""
        drawn = torch.multinomial(self.weights, self.num_samples, self.replacement)
        return iter(drawn.tolist())

    def __len__(self):
        """Return the number of indices produced per pass."""
        return self.num_samples
|