|
- """
- -------------------------------------------------
- File Name: TrafficPredictionExample
- Description :
- Author : zhangweifeng
- date: 2020/11/20
- -------------------------------------------------
- Change Activity:
- 2020/11/20:
- -------------------------------------------------
- """
- import train
- import argparse
- import test
- import yaml
- import numpy as np
-
class TrafficPredictionExample():
    """Configuration-driven entry point for the traffic-prediction model.

    Reads data paths and model settings from a YAML file and forwards them
    to ``train.main`` (training), ``test.main`` (evaluation) and
    ``test.test`` (prediction on an in-memory tensor, via ``__call__``).
    """

    def __init__(self, conf_file):
        """Initialise the predictor by loading the model configuration.

        Args:
            conf_file: Path to the YAML configuration file.

        Raises:
            ValueError: If the file does not parse to a YAML mapping.
            KeyError: If a required configuration key is missing.
        """
        # ``with`` guarantees the handle is closed. ``safe_load`` only builds
        # plain YAML types and, unlike ``yaml.load`` without an explicit
        # ``Loader`` argument, also works on PyYAML >= 6.
        with open(conf_file, encoding='utf8') as f:
            args = yaml.safe_load(f)
        print(args)
        self._apply_config(args)

    def _apply_config(self, args):
        """Populate the instance attributes from a parsed configuration mapping.

        Args:
            args: Dict obtained by parsing the YAML configuration file.

        Raises:
            ValueError: If ``args`` is not a dict (an empty file yields None).
            KeyError: If a required key is missing.
        """
        if not isinstance(args, dict):
            raise ValueError(
                'configuration must be a YAML mapping, got %s'
                % type(args).__name__)

        # Required entries.
        self.data = args['data']
        self.adjdata = args['adjdata']
        self.num_nodes = args['num_nodes']
        self.epochs = args['epochs']
        self.save = args['savedir']
        self.SE_file = args['SE_file']
        self.checkpoint = args['checkpoint']
        self.test_output_file = args['test_output_file']
        self.data_mean = args['data_mean']
        self.data_std = args['data_std']

        # Optional hyper-parameters. The defaults are the values this class
        # previously hard-coded, so config files that omit these keys behave
        # as before.
        self.seq_length = args.get('seq_length', 12)
        self.nhid = args.get('nhid', 32)
        self.in_dim = args.get('in_dim', 2)
        self.batch_size = args.get('batch_size', 32)
        self.learning_rate = args.get('learning_rate', 0.001)
        self.dropout = args.get('dropout', 0.3)
        self.weight_decay = args.get('weight_decay', 0.0001)
        self.print_every = args.get('print_every', 50)

    def __call__(self, test_data_tensor):
        """Run the model on ``test_data_tensor`` and return its predictions.

        Args:
            test_data_tensor: shape [L, 12, N, 2]. L is the length of the test
                data, 12 is the number of time steps, N is the number of road
                segments and 2 is the channel: index 0 is the speed, index 1
                is the time of day.

        Returns:
            result: shape [L, N, 12], the speed in the future 12 time steps.
        """
        # ``test`` resolves to the imported module, not the ``test`` method.
        result = test.test(
            data=self.data, adjdata=self.adjdata, seq_length=self.seq_length,
            nhid=self.nhid, in_dim=self.in_dim, num_nodes=self.num_nodes,
            batch_size=self.batch_size, learning_rate=self.learning_rate,
            dropout=self.dropout, weight_decay=self.weight_decay,
            savedcheckpoint=self.checkpoint, SE_file=self.SE_file,
            test_output_file=self.test_output_file,
            test_data_tensor=test_data_tensor,
            data_mean=self.data_mean, data_std=self.data_std)
        return result

    def train(self):
        """Train the model by forwarding the configuration to ``train.main``."""
        # ``train`` resolves to the imported module, not this method.
        train.main(
            data=self.data, adjdata=self.adjdata, seq_length=self.seq_length,
            nhid=self.nhid, in_dim=self.in_dim, num_nodes=self.num_nodes,
            batch_size=self.batch_size, learning_rate=self.learning_rate,
            dropout=self.dropout, weight_decay=self.weight_decay,
            epochs=self.epochs, print_every=self.print_every, save=self.save,
            SE_file=self.SE_file, test_output_file=self.test_output_file)

    def test(self):
        """Run ``test.main`` with the configured checkpoint and settings."""
        # ``test`` resolves to the imported module, not this method.
        test.main(
            data=self.data, adjdata=self.adjdata, seq_length=self.seq_length,
            nhid=self.nhid, in_dim=self.in_dim, num_nodes=self.num_nodes,
            batch_size=self.batch_size, learning_rate=self.learning_rate,
            dropout=self.dropout, weight_decay=self.weight_decay,
            savedcheckpoint=self.checkpoint, SE_file=self.SE_file,
            test_output_file=self.test_output_file)
-
-
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', action='store_true', help='whether train')
    parser.add_argument('--test', action='store_true', help='whether test')
    # Generalizes the previously hard-coded config path; default unchanged.
    parser.add_argument('--config', default='config.yml',
                        help='path to the YAML configuration file')
    args = parser.parse_args()
    tp = TrafficPredictionExample(args.config)

    # Honour the --train / --test flags, which were previously parsed but
    # ignored.
    if args.train:
        tp.train()
    if args.test:
        tp.test()

    # With no flag given, run the prediction demo on data/METR-LA/test.npz
    # (the behaviour of the script when invoked without arguments).
    if not (args.train or args.test):
        # Closing the NpzFile releases the underlying file handle; the 'x'
        # array has already been read into memory by the time it closes.
        with np.load('data/METR-LA/test.npz') as cat_data:
            testdata = cat_data['x']
        result = tp(testdata)
        print(result.shape)
|