-
Notifications
You must be signed in to change notification settings - Fork 299
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
mfles model reference and experiment (#853)
- Loading branch information
Showing
8 changed files
with
920 additions
and
564 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from functools import partial | ||
from pathlib import Path | ||
|
||
import pandas as pd | ||
from utilsforecast.evaluation import evaluate | ||
from utilsforecast.losses import smape, mape, rmse, mae, mase | ||
|
||
|
||
def generate_metrics(path: Path) -> str:
    """Render a markdown table of mean per-metric scores for one dataset group.

    Reads ``train.parquet`` and ``valid.parquet`` from *path* (whose directory
    name selects the seasonal period for MASE), evaluates every model column,
    averages each metric over series, bolds the best (lowest) model per metric,
    and returns the table as markdown with the AutoARIMA column moved last.
    """
    # Seasonal period per dataset group; used only by MASE.
    season_by_group = {
        'hourly': 24,
        'daily': 7,
        'weekly': 52,
        'monthly': 12,
        'quarterly': 4,
        'yearly': 1,
    }
    # Display format per metric (percentages for the relative errors).
    metric_formats = {
        'mase': '{:.2f}',
        'rmse': '{:,.1f}',
        'smape': '{:.1%}',
        'mape': '{:.1%}',
        'mae': '{:,.1f}',
    }
    seasonal_mase = partial(mase, seasonality=season_by_group[path.name])

    train_df = pd.read_parquet(path / 'train.parquet')
    valid_df = pd.read_parquet(path / 'valid.parquet')
    scores = evaluate(
        valid_df,
        train_df=train_df,
        metrics=[smape, mape, rmse, mae, seasonal_mase],
    )
    # One row per metric, one column per model, averaged across all series.
    summary = scores.drop(columns='unique_id').groupby('metric').mean()

    rendered = {}
    for metric_name, raw_row in summary.iterrows():
        winner = raw_row.idxmin()  # lower is better for every metric here
        pretty = raw_row.map(metric_formats[metric_name].format)
        pretty[winner] = f'**{pretty[winner]}**'
        rendered[metric_name] = pretty

    # Push the AutoARIMA baseline to the last column for readability.
    ordered = [c for c in summary.columns if c != 'AutoARIMA']
    if 'AutoARIMA' in summary.columns:
        ordered.append('AutoARIMA')
    return pd.DataFrame(rendered).T[ordered].to_markdown()
|
||
|
||
def generate_times(path: Path) -> str: | ||
df = pd.read_csv(path / 'times.csv') | ||
df['time'] /= 60 | ||
df = df.sort_values('time') | ||
df = df.rename(columns={'time': 'CPU time (min)'}) | ||
return df.to_markdown(index=False, floatfmt=',.0f') | ||
|
||
|
||
if __name__ == '__main__':
    # Assemble one markdown report: a metrics table and a timing table
    # per dataset group found under results/.
    with open('summary.md', 'wt') as out:
        for group_dir in Path('results').iterdir():
            sections = [
                f'# {group_dir.name.capitalize()}',
                '\n## Metrics\n',
                generate_metrics(group_dir),
                '\n\n## Times\n',
                generate_times(group_dir),
                '\n',
            ]
            out.write(''.join(sections))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import argparse | ||
import os | ||
import random | ||
from functools import partial | ||
from pathlib import Path | ||
|
||
os.environ['NIXTLA_NUMBA_CACHE'] = '1' | ||
os.environ['NIXTLA_ID_AS_COL'] = '1' | ||
|
||
import pandas as pd | ||
from datasetsforecast.m4 import M4, M4Info | ||
|
||
from statsforecast import StatsForecast | ||
from statsforecast.models import ( | ||
AutoARIMA, | ||
AutoETS, | ||
AutoMFLES, | ||
AutoTBATS, | ||
DynamicOptimizedTheta, | ||
SeasonalNaive, | ||
) | ||
|
||
|
||
if __name__ == '__main__':
    # Run every auto model on one M4 group and persist predictions,
    # the training split, and per-model wall-clock times under results/.
    parser = argparse.ArgumentParser()
    parser.add_argument('group')
    args = parser.parse_args()

    # data
    # Groups where a non-default (or multiple) seasonal period works better;
    # all other groups fall back to the M4Info seasonality.
    seasonality_overrides = {
        "Daily": [7],
        "Hourly": [24, 24*7],
        "Weekly": [52],
    }
    group = args.group.capitalize()
    df, *_ = M4.load("data", group)
    df['ds'] = df['ds'].astype('int64')
    info = M4Info[group]
    h = info.horizon
    season_length = seasonality_overrides.get(group, [info.seasonality])
    # Hold out the last h observations of every series for validation.
    valid = df.groupby("unique_id").tail(h)
    train = df.drop(valid.index)
    print(f'Running {group}. season_length: {season_length}')

    # forecast
    # AutoMFLES/AutoTBATS accept a list of seasonal periods; the remaining
    # models take a single period, so they get the primary one.
    sf = StatsForecast(
        models=[
            AutoARIMA(season_length=season_length[0]),
            AutoETS(season_length=season_length[0]),
            AutoMFLES(test_size=h, season_length=season_length, n_windows=3),
            AutoTBATS(season_length=season_length),
            DynamicOptimizedTheta(season_length=season_length[0]),
            SeasonalNaive(season_length=season_length[0]),
        ],
        freq=1,
        n_jobs=-1,
        verbose=True,
    )
    preds = sf.forecast(df=train, h=h)
    # Attach actuals to the predictions (computed once; a duplicate merge
    # was removed here).
    res = preds.merge(valid, on=['unique_id', 'ds'])

    # save results
    results_path = Path('results') / group.lower()
    # parents=True so the first run works even if results/ doesn't exist yet.
    results_path.mkdir(parents=True, exist_ok=True)
    res.to_parquet(results_path / 'valid.parquet', index=False)
    train.to_parquet(results_path / 'train.parquet', index=False)
    times = pd.Series(sf.forecast_times_).reset_index()
    times.columns = ['model', 'time']
    times.to_csv(results_path / 'times.csv', index=False)
Oops, something went wrong.