|
- # -*- coding: utf-8 -*-
- # Created by wangyingde on 2021/5/20
- from easy_forecast.batch.simple_data_process_strategies import SimpleBatchDataProcess
- from easy_forecast.strategies.ts_lark.forecast_validation import CrossValidation
- from easy_forecast.validation.basic_acc_metric import BasicAccMetric
- from easy_forecast.strategies.ts_cross_validation.single_frcst_model import ArimaStrategy, CrostonStrategy, HWStrategy, SMAStrategy, SAStrategy
- from easy_forecast.batch.batch_forecast import BatchForecastTs
- from easy_forecast.config import naming_specification
- from easy_forecast.ts.ts_basic_models import CommonFrcst
- from easy_forecast.ts.rawdata_process import RawDataProcess
- from easy_forecast.log.global_logger import log
- import pandas as pd
- from sklearn import linear_model
-
-
class LassoRegression:
    """Combine per-model forecasts into a single forecast with ``LassoCV``.

    Inputs are long-format frames with one row per (``obj_no``, ``ds``,
    ``mod_n``), where ``mod_n`` names a base model and ``yhat`` holds that
    model's forecast. The frames are pivoted to one column per model, with
    missing (row, model) entries filled with 0. A cross-validated Lasso then
    maps those columns to the target ``y``.
    """

    def __init__(self):
        pass

    @staticmethod
    def _to_wide(data, index_cols):
        """Pivot a long-format frame to one ``yhat`` column per ``mod_n``.

        :param data: long-format frame containing ``index_cols``, ``mod_n`` and ``yhat``.
        :param index_cols: list of columns that identify an output row.
        :return: frame whose first columns are ``index_cols``, followed by one
            column per model (name converted to ``str``, sorted by the original
            ``mod_n`` value). Duplicate entries are averaged (the
            ``pivot_table`` default) and missing entries are 0.
        """
        # A scalar ``values`` keeps the result columns single-level (just the
        # model names), so no droplevel / in-place Index renaming is needed.
        wide = data.pivot_table(index=index_cols, values='yhat', columns='mod_n',
                                fill_value=0)
        wide.columns = [str(col) for col in wide.columns]
        return wide.reset_index()

    def train(self, train_data: pd.DataFrame):
        """Fit a ``LassoCV`` that regresses ``y`` on the per-model forecasts.

        :param train_data: long-format frame with obj_no, ds, y, mod_n, yhat.
        :return: the fitted ``LassoCV``. It is fitted on a DataFrame whose
            columns are the (stringified) model names. scikit-learn >= 1.0
            therefore records them in ``feature_names_in_``, which ``predict``
            uses to line its input up with the training columns.
        """
        index_cols = ['obj_no', 'ds', 'y']
        wide = self._to_wide(train_data, index_cols)
        feature_names = list(wide.columns[len(index_cols):])
        model = linear_model.LassoCV(n_jobs=1, max_iter=1000)
        model.fit(wide.loc[:, feature_names], wide['y'].values)
        return model

    def predict(self, predict_data, predict_model):
        """Apply a model returned by ``train`` to new per-model forecasts.

        :param predict_data: long-format frame with obj_no, ds, mod_n, yhat.
        :param predict_model: fitted regressor, normally the result of ``train``.
        :return: DataFrame with columns ``obj_no``, ``ds``, ``yhat``, one row
            per (obj_no, ds).
        """
        index_cols = ['obj_no', 'ds']
        wide = self._to_wide(predict_data, index_cols)
        feature_names = getattr(predict_model, 'feature_names_in_', None)
        if feature_names is None:
            # No recorded training features (e.g. old scikit-learn): fall back
            # to the models present in this frame, as the original code did.
            feature_names = list(wide.columns[len(index_cols):])
        # Align with the training columns: models unseen at training time are
        # dropped, and training models absent here contribute 0 (matching the
        # pivot's fill value).
        x_data = wide.reindex(columns=list(feature_names), fill_value=0)
        wide['yhat'] = predict_model.predict(x_data)
        return wide.loc[:, ['obj_no', 'ds', 'yhat']].copy()
-
-
-
class Larktrategy(BatchForecastTs):
    """Per-object forecaster that stacks several base strategies with a Lasso.

    For each object, the strategies in ``md_dict`` are run through
    ``CrossValidation``. A ``LassoRegression`` is trained on the back-test
    output and applied to the forward output. The result is clipped to
    ``[0, 3 * mean(last 60 y)]`` via ``adjust_y``. If any step raises, the
    object falls back to ``CommonFrcst.sa()`` (simple average).

    NOTE(review): the class name looks like a typo of ``LarkStrategy``. It is
    kept as-is because callers import it by this name.
    """

    # Upper clipping bound = _CAP_MULTIPLIER * mean of the last _CAP_WINDOW rows of y.
    _CAP_WINDOW = 60
    _CAP_MULTIPLIER = 3

    def __init__(self, df, freq, cutoff_date, md_dict=None, forecast_periods=1,
                 test_period_n=None, loop=False, df_x=None, interval_width=0.8,
                 top_k=4):
        """
        :param df: input data, passed to ``BatchForecastTs``.
        :param freq: forecast frequency. The day / week / month constants of
            ``naming_specification`` get built-in defaults.
        :param cutoff_date: cutoff date used by ``CrossValidation`` and by the
            simple-average fallback.
        :param md_dict: base strategy config ``{name: {'class_name': cls, **params}}``.
            If None, a default set (ArimaStrategy, CrostonStrategy, SMAStrategy,
            HWStrategy) is built for ``freq``. ``'sa_strategy'`` (SAStrategy) is
            always added. The caller's dict is not modified.
        :param forecast_periods: number of periods to forecast.
        :param test_period_n: back-test length. If None, it defaults to
            ``min(forecast_periods, 7 / 8 / 6)`` for day / week / month data.
        :param loop: forwarded to ``CrossValidation.fit``.
        :param df_x: extra input data, passed to ``BatchForecastTs``.
        :param interval_width: passed to ``BatchForecastTs`` and to the fallback forecaster.
        :param top_k: forwarded to ``CrossValidation.fit`` / ``predict``.
        :raises ValueError: if ``md_dict`` is None and ``freq`` is not day / week / month.
        """
        # NOTE(review): cutoff_date is deliberately not handed to the base class
        # (None is passed); it is kept on self._cutoff_date for CrossValidation
        # and the fallback. Confirm this is intended.
        super().__init__(df=df, cutoff_date=None, freq=freq, df_x=df_x,
                         forecast_periods=forecast_periods,
                         interval_width=interval_width)
        self._cutoff_date = cutoff_date
        self._test_period_n = test_period_n
        self._top_k = top_k
        self._loop = loop
        self._md_dict = md_dict
        self.initialization_parameter()

    def initialization_parameter(self):
        """Derive frequency-dependent defaults and build the strategy configs.

        Sets ``self._test_period_n`` (if unset), ``self.dp_dict``,
        ``self._md_dict`` and ``self.acc_metric``.
        """
        freq = self._freq
        if freq in (naming_specification.FREQ_DAY, naming_specification.FREQ_DAY_ABBR):
            season_period, mv_steps, default_test_n = 7, 7, 7
        elif freq in (naming_specification.FREQ_WEEK, naming_specification.FREQ_WEEK_ABBR):
            season_period, mv_steps, default_test_n = 4, 4, 8
        elif freq in (naming_specification.FREQ_MONTH, naming_specification.FREQ_MONTH_ABBR):
            season_period, mv_steps, default_test_n = 12, 12, 6
        else:
            # Unsupported frequency: only usable with an explicit md_dict.
            season_period = mv_steps = default_test_n = None

        if self._test_period_n is None and default_test_n is not None:
            self._test_period_n = min(self._forecast_periods, default_test_n)

        self.dp_dict = {
            'dp_simple': {
                'class_name': SimpleBatchDataProcess
            }
        }

        # TODO: strategy parameters are hard-coded; make them configurable.
        if self._md_dict is None:
            if season_period is None:
                raise ValueError(
                    f"unsupported forecast frequency {freq!r}; pass md_dict explicitly")
            self._md_dict = {
                'arima': {
                    'class_name': ArimaStrategy,
                    'season_period': season_period
                },
                'croston_strategy': {
                    'class_name': CrostonStrategy,
                },
                'sma_strategy': {
                    'class_name': SMAStrategy,
                    'mv_steps': mv_steps
                },
                'hw_strategy': {
                    'class_name': HWStrategy,
                    'season_period': season_period
                },
            }
        else:
            # Shallow copy so adding the mandatory model below does not mutate
            # the dict the caller passed in.
            self._md_dict = dict(self._md_dict)

        # Mandatory model, always part of the candidate set.
        self._md_dict.update({'sa_strategy': {'class_name': SAStrategy}})

        self.acc_metric = [BasicAccMetric.RMSE]

    def process(self, obj_data, obj_data_x, obj_no):
        """Forecast a single object.

        :param obj_data: history of one object; must contain a ``y`` column.
        :param obj_data_x: not used by this strategy.
        :param obj_no: object identifier, used in log messages and the fallback output.
        :return: forecast frame. The ensemble path returns columns
            ``obj_no``, ``ds``, ``yhat``. The fallback returns
            ``naming_specification`` OBJ_NO / DATETIME_STAMP / Y_HAT.
        """
        lines = obj_data.copy()
        try:
            rs = self._ensemble_forecast(lines)
        except Exception as e:
            log.warning(f"[{obj_no}] use simple average!!! error message")
            log.error(e, exc_info=True)
            rs = self._simple_average_forecast(lines, obj_no)
        return rs

    def _ensemble_forecast(self, lines):
        """Back-test the base strategies, stack them with a Lasso and clip the result."""
        cv = CrossValidation(df=lines, data_process_dict=self.dp_dict,
                             forecast_strategy_dict=self._md_dict,
                             cutoff_date=self._cutoff_date,
                             acc_metric=self.acc_metric,
                             ms_cutoff_date=self._cutoff_date,
                             forecast_freq=self._freq,
                             ms_test_period_n=self._test_period_n)

        # Back-test output trains the Lasso; forward output is what it combines.
        back_test_rs, model_selected = cv.fit(acc_metric=BasicAccMetric.RMSE,
                                              forecast_period=self._forecast_periods,
                                              use_pool=False, loop=self._loop,
                                              top_k=self._top_k)
        cv_rs = cv.predict(acc_metric=BasicAccMetric.RMSE,
                           forecast_period=self._forecast_periods,
                           model_selection=model_selected,
                           use_pool=False, top_k=self._top_k)

        lasso = LassoRegression()
        lasso_model = lasso.train(train_data=back_test_rs)
        rs = lasso.predict(predict_data=cv_rs, predict_model=lasso_model)

        # NOTE(review): assumes `lines` is ordered by date so tail() is the most
        # recent history — confirm.
        y_upper = lines.tail(self._CAP_WINDOW)['y'].mean() * self._CAP_MULTIPLIER
        rs.loc[:, 'yhat'] = rs['yhat'].apply(
            lambda y: self.adjust_y(y, upper=y_upper, lower=0))
        return rs

    def _simple_average_forecast(self, lines, obj_no):
        """Fallback: ``CommonFrcst.sa()`` forecast on ``RawDataProcess`` sample data."""
        sample = RawDataProcess(lines, data_freq=self._freq,
                                cutoff_date=self._cutoff_date).get_sample_data()
        sample[naming_specification.OBJ_NO] = obj_no
        forecast = CommonFrcst(sample, forecasting_freq=self._freq,
                               forecasting_periods=self._forecast_periods,
                               interval_width=self._interval_width)
        # copy() so the column assignment below acts on an independent frame,
        # not a view of the tail slice.
        rs = forecast.sa().tail(self._forecast_periods).copy()
        rs[naming_specification.OBJ_NO] = obj_no
        return rs.loc[:, [naming_specification.OBJ_NO,
                          naming_specification.DATETIME_STAMP,
                          naming_specification.Y_HAT]].copy()

    @staticmethod
    def adjust_y(y, upper, lower=0):
        """Clamp a single forecast value.

        Returns ``y`` when ``0 <= y <= upper`` and ``upper`` when ``y > upper``.
        Otherwise (negative values, or NaN) it returns ``lower``.
        """
        if y >= 0 and y <= upper:
            return y
        elif y > upper:
            return upper
        else:
            return lower
-
-
if __name__ == '__main__':
    # Import-only module: running it as a script intentionally does nothing.
    ...
|