fix: chunk series in parallel forecast #915

Merged · 4 commits · Sep 19, 2024
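The fix replaces the one-future-per-series submission in `_forecast_parallel` with chunked submissions: the grouped series are split into at most `tasks_per_job * n_jobs` chunks, each chunk is forecast by a single-threaded worker, and the per-chunk results are stacked back together afterwards. A minimal sketch of the splitting rule, assuming only what the diff shows (`split_into_chunks` is an illustrative name, not part of the library):

```python
import numpy as np

def split_into_chunks(n_groups: int, n_jobs: int, tasks_per_job: int = 100):
    """Split series indices into balanced chunks so each worker processes a
    batch of series instead of receiving one task per series."""
    n_chunks = min(tasks_per_job * n_jobs, n_groups)
    # np.array_split keeps chunk sizes within one element of each other;
    # the size guard mirrors the `if idxs.size` filter in GroupedArray.split_fm
    return [idxs for idxs in np.array_split(np.arange(n_groups), n_chunks) if idxs.size]

chunks = split_into_chunks(n_groups=1_000, n_jobs=4)
print(len(chunks))   # 400
print(chunks[0])     # [0 1 2]
```

Submitting a bounded number of chunk-level tasks rather than one future per series cuts scheduling and pickling overhead when there are many short series, which appears to be the motivation for the fix.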
182 changes: 70 additions & 112 deletions nbs/src/core/core.ipynb
@@ -91,7 +91,6 @@
"import reprlib\n",
"import time\n",
"import warnings\n",
"from collections import defaultdict\n",
"from concurrent.futures import ProcessPoolExecutor, as_completed\n",
"from pathlib import Path\n",
"from typing import Any, Dict, List, Optional, Union\n",
@@ -100,7 +99,7 @@
"import pandas as pd\n",
"import utilsforecast.processing as ufp\n",
"from fugue.execution.factory import make_execution_engine, try_get_context_execution_engine\n",
"from threadpoolctl import ThreadpoolController, threadpool_limits\n",
"from threadpoolctl import ThreadpoolController\n",
"from tqdm.auto import tqdm\n",
"from triad import conditional_dispatcher\n",
"from utilsforecast.compat import DataFrame, pl_DataFrame, pl_Series\n",
@@ -124,39 +123,7 @@
" datefmt='%Y-%m-%d %H:%M:%S',\n",
" )\n",
"logger = logging.getLogger(__name__)\n",
"\n",
"_controller = ThreadpoolController()\n",
"\n",
"@_controller.wrap(limits=1)\n",
"def _forecast_serie(h, y, X, X_future, models, fallback_model, level, fitted):\n",
" forecast_res = {}\n",
" fitted_res = {}\n",
" times = {}\n",
" for model in models:\n",
" start = time.perf_counter()\n",
" model_kwargs = dict(h=h, y=y, X=X, X_future=X_future, fitted=fitted)\n",
" if \"level\" in inspect.signature(model.forecast).parameters and level:\n",
" model_kwargs[\"level\"] = level\n",
" try:\n",
" model_res = model.forecast(**model_kwargs)\n",
" except Exception as e:\n",
" if fallback_model is None:\n",
" raise e\n",
" model_res = fallback_model.forecast(**model_kwargs)\n",
" model_name = repr(model)\n",
" times[model_name] = time.perf_counter() - start\n",
" for k, v in model_res.items():\n",
" if k == \"mean\":\n",
" forecast_res[model_name] = v\n",
" elif k.startswith((\"lo\", \"hi\")):\n",
" col_name = f\"{model_name}-{k}\"\n",
" forecast_res[col_name] = v\n",
" elif k == \"fitted\":\n",
" fitted_res[model_name] = v\n",
" elif k.startswith((\"fitted-lo\", \"fitted-hi\")):\n",
" col_name = f'{model_name}-{k.replace(\"fitted-\", \"\")}'\n",
" fitted_res[col_name] = v\n",
" return forecast_res, fitted_res, times"
"_controller = ThreadpoolController()"
]
},
{
@@ -471,19 +438,19 @@
" def split_fm(self, fm, n_chunks):\n",
" return [fm[idxs] for idxs in np.array_split(range(self.n_groups), n_chunks) if idxs.size]\n",
"\n",
"\n",
" @_controller.wrap(limits=1)\n",
" def _single_threaded_fit(self, models, fallback_model=None):\n",
" with threadpool_limits(limits=1):\n",
" return self.fit(models=models, fallback_model=fallback_model)\n",
" return self.fit(models=models, fallback_model=fallback_model)\n",
"\n",
" @_controller.wrap(limits=1)\n",
" def _single_threaded_predict(self, fm, h, X=None, level=tuple()):\n",
" with threadpool_limits(limits=1):\n",
" return self.predict(fm=fm, h=h, X=X, level=level)\n",
" return self.predict(fm=fm, h=h, X=X, level=level)\n",
"\n",
" @_controller.wrap(limits=1)\n",
" def _single_threaded_fit_predict(self, models, h, X=None, level=tuple()):\n",
" with threadpool_limits(limits=1):\n",
" return self.fit_predict(models=models, h=h, X=X, level=level)\n",
" return self.fit_predict(models=models, h=h, X=X, level=level)\n",
"\n",
" @_controller.wrap(limits=1)\n",
" def _single_threaded_forecast(\n",
" self,\n",
" models,\n",
@@ -495,18 +462,18 @@
" verbose=False,\n",
" target_col='y',\n",
" ):\n",
" with threadpool_limits(limits=1):\n",
" return self.forecast(\n",
" models=models,\n",
" h=h,\n",
" fallback_model=fallback_model,\n",
" fitted=fitted,\n",
" X=X,\n",
" level=level,\n",
" verbose=verbose,\n",
" target_col=target_col,\n",
" )\n",
" \n",
" return self.forecast(\n",
" models=models,\n",
" h=h,\n",
" fallback_model=fallback_model,\n",
" fitted=fitted,\n",
" X=X,\n",
" level=level,\n",
" verbose=verbose,\n",
" target_col=target_col,\n",
" )\n",
"\n",
" @_controller.wrap(limits=1)\n",
" def _single_threaded_cross_validation(\n",
" self,\n",
" models,\n",
@@ -521,20 +488,19 @@
" verbose=False,\n",
" target_col='y',\n",
" ):\n",
" with threadpool_limits(limits=1):\n",
" return self.cross_validation(\n",
" models=models,\n",
" h=h,\n",
" test_size=test_size,\n",
" fallback_model=fallback_model,\n",
" step_size=step_size,\n",
" input_size=input_size,\n",
" fitted=fitted,\n",
" level=level,\n",
" refit=refit,\n",
" verbose=verbose,\n",
" target_col=target_col,\n",
" )"
" return self.cross_validation(\n",
" models=models,\n",
" h=h,\n",
" test_size=test_size,\n",
" fallback_model=fallback_model,\n",
" step_size=step_size,\n",
" input_size=input_size,\n",
" fitted=fitted,\n",
" level=level,\n",
" refit=refit,\n",
" verbose=verbose,\n",
" target_col=target_col,\n",
" )"
]
},
{
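The `_single_threaded_*` helpers above drop the inner `threadpool_limits(limits=1)` context managers in favor of decorating each method with `@_controller.wrap(limits=1)`. A small standalone sketch of that threadpoolctl pattern (`heavy_numpy_work` is an illustrative function, not part of the library):

```python
import numpy as np
from threadpoolctl import ThreadpoolController

_controller = ThreadpoolController()

@_controller.wrap(limits=1)
def heavy_numpy_work(a: np.ndarray) -> np.ndarray:
    # BLAS/OpenMP thread pools are capped at one thread for the duration of the call,
    # avoiding oversubscription when the caller already runs one process per core
    return a @ a.T

print(heavy_numpy_work(np.random.rand(100, 100)).shape)  # (100, 100)
```

The decorator behaves like wrapping the body in `with threadpool_limits(limits=1):`, which is why the explicit `threadpool_limits` import can be removed.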
@@ -1685,12 +1651,14 @@
" fm = np.vstack([f.get() for f in futures])\n",
" return fm \n",
" \n",
" def _get_gas_Xs(self, X):\n",
" gas = self.ga.split(self.n_jobs)\n",
" def _get_gas_Xs(self, X, tasks_per_job=1):\n",
" n_chunks = min(tasks_per_job * self.n_jobs, self.ga.n_groups)\n",
" gas = self.ga.split(n_chunks)\n",
" if X is not None:\n",
" Xs = X.split(self.n_jobs)\n",
" Xs = X.split(n_chunks)\n",
" else:\n",
" from itertools import repeat\n",
"\n",
" Xs = repeat(None)\n",
" return gas, Xs\n",
" \n",
@@ -1735,57 +1703,47 @@
" return fm, fcsts, cols\n",
"\n",
" def _forecast_parallel(self, h, fitted, X, level, target_col):\n",
" n_series = self.ga.n_groups\n",
" forecast_res = defaultdict(lambda: np.empty(n_series * h, dtype=self.ga.data.dtype))\n",
" fitted_res = defaultdict(\n",
" lambda: np.empty(self.ga.data.shape[0], dtype=self.ga.data.dtype)\n",
" )\n",
" fitted_res[target_col] = self.ga.data[:, 0]\n",
" future2pos = {}\n",
" times = {repr(m): 0.0 for m in self.models}\n",
" gas, Xs = self._get_gas_Xs(X=X, tasks_per_job=100)\n",
" results = [None] * len(gas)\n",
" with ProcessPoolExecutor(self.n_jobs) as executor:\n",
" for i, serie in enumerate(self.ga):\n",
" y_train = serie[:, 0]\n",
" X_train = serie[:, 1:] if serie.shape[1] > 1 else None\n",
" if X is None:\n",
" X_future = None\n",
" else:\n",
" X_future = X[i]\n",
" future = executor.submit(\n",
" _forecast_serie,\n",
" future2pos = {\n",
" executor.submit(\n",
" ga._single_threaded_forecast,\n",
" h=h,\n",
" y=y_train,\n",
" X=X_train,\n",
" X_future=X_future,\n",
" models=self.models,\n",
" fallback_model=self.fallback_model,\n",
" fitted=fitted,\n",
" X=X,\n",
" level=level,\n",
" )\n",
" future2pos[future] = i\n",
" verbose=False,\n",
" target_col=target_col,\n",
" ): i\n",
" for i, (ga, X) in enumerate(zip(gas, Xs))\n",
" }\n",
" iterable = tqdm(\n",
" as_completed(future2pos), disable=not self.verbose, total=n_series, desc=\"Forecast\"\n",
" ) \n",
" as_completed(future2pos),\n",
" disable=not self.verbose,\n",
" total=len(future2pos),\n",
" desc=\"Forecast\",\n",
" bar_format=\"{l_bar}{bar}| {n_fmt}/{total_fmt} [Elapsed: {elapsed}{postfix}]\",\n",
" )\n",
" for future in iterable:\n",
" i = future2pos[future]\n",
" fcst_idxs = slice(i * h, (i + 1) * h)\n",
" serie_idxs = slice(self.ga.indptr[i], self.ga.indptr[i + 1])\n",
" serie_fcst, serie_fitted, serie_times = future.result()\n",
" for k, v in serie_fcst.items():\n",
" forecast_res[k][fcst_idxs] = v\n",
" for k, v in serie_fitted.items():\n",
" fitted_res[k][serie_idxs] = v\n",
" for model_name, model_time in serie_times.items():\n",
" times[model_name] += model_time\n",
" return {\n",
" 'cols': list(forecast_res.keys()),\n",
" 'forecasts': np.hstack([v[:, None] for v in forecast_res.values()]),\n",
" 'fitted': {\n",
" 'cols': list(fitted_res.keys()),\n",
" 'values': np.hstack([v[:, None] for v in fitted_res.values()]),\n",
" results[i] = future.result()\n",
" result = {\n",
" 'cols': results[0]['cols'],\n",
" 'forecasts': np.vstack([r['forecasts'] for r in results]),\n",
" 'times': {\n",
" m: sum(r['times'][m] for r in results)\n",
" for m in [repr(m) for m in self.models]\n",
" },\n",
" 'times': times,\n",
" } \n",
" }\n",
" if fitted:\n",
" result['fitted'] = {\n",
" 'cols': results[0]['fitted']['cols'],\n",
" 'values': np.vstack([r['fitted']['values'] for r in results]),\n",
" }\n",
" return result\n",
"\n",
" def _cross_validation_parallel(self, h, test_size, step_size, input_size, fitted, level, refit, target_col):\n",
" #create elements for each core\n",
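In the rewritten `_forecast_parallel`, each chunk is submitted as one future (`ga._single_threaded_forecast`), results are stored by chunk position, forecasts are stacked with `np.vstack`, and per-model times are summed across chunks. A simplified sketch of that submit-and-reassemble pattern, with a hypothetical `forecast_chunk` standing in for the chunk-level forecast call:

```python
import numpy as np
from concurrent.futures import ProcessPoolExecutor, as_completed

def forecast_chunk(series_ids):
    # stand-in for GroupedArray._single_threaded_forecast on one chunk of series
    return {"forecasts": np.ones((len(series_ids), 2)), "times": {"Naive": 1.0}}

if __name__ == "__main__":
    chunks = [list(range(i, i + 3)) for i in range(0, 9, 3)]
    results = [None] * len(chunks)
    with ProcessPoolExecutor(max_workers=2) as executor:
        # map each future back to its chunk index so ordering survives as_completed
        future2pos = {executor.submit(forecast_chunk, c): i for i, c in enumerate(chunks)}
        for future in as_completed(future2pos):
            results[future2pos[future]] = future.result()
    forecasts = np.vstack([r["forecasts"] for r in results])      # rows stay in series order
    times = {"Naive": sum(r["times"]["Naive"] for r in results)}  # aggregate per-model time
    print(forecasts.shape, times)  # (9, 2) {'Naive': 3.0}
```

Because each result is written into `results[i]` rather than appended in completion order, the stacked forecasts keep the original series order even though chunks finish out of order.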
2 changes: 1 addition & 1 deletion python/statsforecast/__init__.py
@@ -1,4 +1,4 @@
__version__ = "1.7.7.1"
__version__ = "1.7.7.99"
__all__ = ["StatsForecast"]
from .core import StatsForecast
from .distributed import fugue # noqa
1 change: 0 additions & 1 deletion python/statsforecast/_modidx.py
@@ -185,7 +185,6 @@
'statsforecast/core.py'),
'statsforecast.core._StatsForecast.save': ( 'src/core/core.html#_statsforecast.save',
'statsforecast/core.py'),
'statsforecast.core._forecast_serie': ('src/core/core.html#_forecast_serie', 'statsforecast/core.py'),
'statsforecast.core._get_n_jobs': ('src/core/core.html#_get_n_jobs', 'statsforecast/core.py'),
'statsforecast.core._id_as_idx': ('src/core/core.html#_id_as_idx', 'statsforecast/core.py'),
'statsforecast.core._maybe_warn_sort_df': ( 'src/core/core.html#_maybe_warn_sort_df',