diff --git a/nbs/src/core/core.ipynb b/nbs/src/core/core.ipynb index 8bc47c870..11746bb25 100644 --- a/nbs/src/core/core.ipynb +++ b/nbs/src/core/core.ipynb @@ -100,7 +100,7 @@ "import pandas as pd\n", "import utilsforecast.processing as ufp\n", "from fugue.execution.factory import make_execution_engine, try_get_context_execution_engine\n", - "from threadpoolctl import threadpool_limits\n", + "from threadpoolctl import ThreadpoolController, threadpool_limits\n", "from tqdm.auto import tqdm\n", "from triad import conditional_dispatcher\n", "from utilsforecast.compat import DataFrame, pl_DataFrame, pl_Series\n", @@ -113,7 +113,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6d8d5b82-2be9-41f5-8cd0-3903d0761e09", + "id": "46f43744-f0c9-4d6a-b897-acc9df8f59b5", "metadata": {}, "outputs": [], "source": [ @@ -123,7 +123,40 @@ " format='%(asctime)s %(name)s %(levelname)s: %(message)s',\n", " datefmt='%Y-%m-%d %H:%M:%S',\n", " )\n", - "logger = logging.getLogger(__name__)" + "logger = logging.getLogger(__name__)\n", + "\n", + "_controller = ThreadpoolController()\n", + "\n", + "@_controller.wrap(limits=1)\n", + "def _forecast_serie(h, y, X, X_future, models, fallback_model, level, fitted):\n", + " forecast_res = {}\n", + " fitted_res = {}\n", + " times = {}\n", + " for model in models:\n", + " start = time.perf_counter()\n", + " model_kwargs = dict(h=h, y=y, X=X, X_future=X_future, fitted=fitted)\n", + " if \"level\" in inspect.signature(model.forecast).parameters and level:\n", + " model_kwargs[\"level\"] = level\n", + " try:\n", + " model_res = model.forecast(**model_kwargs)\n", + " except Exception as e:\n", + " if fallback_model is None:\n", + " raise e\n", + " model_res = fallback_model.forecast(**model_kwargs)\n", + " model_name = repr(model)\n", + " times[model_name] = time.perf_counter() - start\n", + " for k, v in model_res.items():\n", + " if k == \"mean\":\n", + " forecast_res[model_name] = v\n", + " elif k.startswith((\"lo\", \"hi\")):\n", + " col_name = f\"{model_name}-{k}\"\n", + " forecast_res[col_name] = v\n", + " elif k == \"fitted\":\n", + " fitted_res[model_name] = v\n", + " elif k.startswith((\"fitted-lo\", \"fitted-hi\")):\n", + " col_name = f'{model_name}-{k.replace(\"fitted-\", \"\")}'\n", + " fitted_res[col_name] = v\n", + " return forecast_res, fitted_res, times" ] }, { @@ -1701,38 +1734,6 @@ " cols = cols[0]\n", " return fm, fcsts, cols\n", "\n", - " def _forecast_serie(self, level, **forecast_kwargs):\n", - " forecast_res = {}\n", - " fitted_res = {}\n", - " times = {}\n", - " with threadpool_limits(limits=1):\n", - " for model in self.models:\n", - " if 'level' in inspect.signature(model.forecast).parameters and level:\n", - " model_kwargs = {**forecast_kwargs, 'level': level}\n", - " else:\n", - " model_kwargs = forecast_kwargs\n", - " start = time.perf_counter()\n", - " try:\n", - " model_res = model.forecast(**model_kwargs)\n", - " except Exception as e:\n", - " if self.fallback_model is None:\n", - " raise e\n", - " model_res = self.fallback_model.forecast(**model_kwargs)\n", - " model_name = repr(model)\n", - " times[model_name] = time.perf_counter() - start\n", - " for k, v in model_res.items(): \n", - " if k == 'mean':\n", - " forecast_res[model_name] = v\n", - " elif k.startswith(('lo', 'hi')):\n", - " col_name = f'{model_name}-{k}'\n", - " forecast_res[col_name] = v\n", - " elif k == 'fitted':\n", - " fitted_res[model_name] = v\n", - " elif k.startswith(('fitted-lo', 'fitted-hi')):\n", - " col_name = f'{model_name}-{k.replace(\"fitted-\", \"\")}'\n", - " fitted_res[col_name] = v\n", - " return forecast_res, fitted_res, times\n", - "\n", " def _forecast_parallel(self, h, fitted, X, level, target_col):\n", " n_series = self.ga.n_groups\n", " forecast_res = defaultdict(lambda: np.empty(n_series * h, dtype=np.float32))\n", @@ -1751,11 +1752,13 @@ " else:\n", " X_future = X[i]\n", " future = executor.submit(\n", - " self._forecast_serie,\n", + " _forecast_serie,\n", " h=h,\n", " y=y_train,\n", " X=X_train,\n", " X_future=X_future,\n", + " models=self.models,\n", + " fallback_model=self.fallback_model,\n", " fitted=fitted,\n", " level=level,\n", " )\n", diff --git a/python/statsforecast/_modidx.py b/python/statsforecast/_modidx.py index 2a6176659..cfcac72aa 100644 --- a/python/statsforecast/_modidx.py +++ b/python/statsforecast/_modidx.py @@ -143,8 +143,6 @@ 'statsforecast/core.py'), 'statsforecast.core._StatsForecast._forecast_parallel': ( 'src/core/core.html#_statsforecast._forecast_parallel', 'statsforecast/core.py'), - 'statsforecast.core._StatsForecast._forecast_serie': ( 'src/core/core.html#_statsforecast._forecast_serie', - 'statsforecast/core.py'), 'statsforecast.core._StatsForecast._get_cap_size': ( 'src/core/core.html#_statsforecast._get_cap_size', 'statsforecast/core.py'), 'statsforecast.core._StatsForecast._get_gas_Xs': ( 'src/core/core.html#_statsforecast._get_gas_xs', @@ -187,6 +185,7 @@ 'statsforecast/core.py'), 'statsforecast.core._StatsForecast.save': ( 'src/core/core.html#_statsforecast.save', 'statsforecast/core.py'), + 'statsforecast.core._forecast_serie': ('src/core/core.html#_forecast_serie', 'statsforecast/core.py'), 'statsforecast.core._get_n_jobs': ('src/core/core.html#_get_n_jobs', 'statsforecast/core.py'), 'statsforecast.core._id_as_idx': ('src/core/core.html#_id_as_idx', 'statsforecast/core.py'), 'statsforecast.core._maybe_warn_sort_df': ( 'src/core/core.html#_maybe_warn_sort_df', diff --git a/python/statsforecast/core.py b/python/statsforecast/core.py index a02355e2a..a8eee6358 100644 --- a/python/statsforecast/core.py +++ b/python/statsforecast/core.py @@ -26,7 +26,7 @@ make_execution_engine, try_get_context_execution_engine, ) -from threadpoolctl import threadpool_limits +from threadpoolctl import ThreadpoolController, threadpool_limits from tqdm.auto import tqdm from triad import conditional_dispatcher from utilsforecast.compat import DataFrame, pl_DataFrame, pl_Series @@ -43,6 +43,40 @@ ) logger = logging.getLogger(__name__) +_controller = ThreadpoolController() + + +@_controller.wrap(limits=1) +def _forecast_serie(h, y, X, X_future, models, fallback_model, level, fitted): + forecast_res = {} + fitted_res = {} + times = {} + for model in models: + start = time.perf_counter() + model_kwargs = dict(h=h, y=y, X=X, X_future=X_future, fitted=fitted) + if "level" in inspect.signature(model.forecast).parameters and level: + model_kwargs["level"] = level + try: + model_res = model.forecast(**model_kwargs) + except Exception as e: + if fallback_model is None: + raise e + model_res = fallback_model.forecast(**model_kwargs) + model_name = repr(model) + times[model_name] = time.perf_counter() - start + for k, v in model_res.items(): + if k == "mean": + forecast_res[model_name] = v + elif k.startswith(("lo", "hi")): + col_name = f"{model_name}-{k}" + forecast_res[col_name] = v + elif k == "fitted": + fitted_res[model_name] = v + elif k.startswith(("fitted-lo", "fitted-hi")): + col_name = f'{model_name}-{k.replace("fitted-", "")}' + fitted_res[col_name] = v + return forecast_res, fitted_res, times + # %% ../../nbs/src/core/core.ipynb 10 class GroupedArray(BaseGroupedArray): @@ -1252,38 +1286,6 @@ def _fit_predict_parallel(self, h, X, level): cols = cols[0] return fm, fcsts, cols - def _forecast_serie(self, level, **forecast_kwargs): - forecast_res = {} - fitted_res = {} - times = {} - with threadpool_limits(limits=1): - for model in self.models: - if "level" in inspect.signature(model.forecast).parameters and level: - model_kwargs = {**forecast_kwargs, "level": level} - else: - model_kwargs = forecast_kwargs - start = time.perf_counter() - try: - model_res = model.forecast(**model_kwargs) - except Exception as e: - if self.fallback_model is None: - raise e - model_res = self.fallback_model.forecast(**model_kwargs) - model_name = repr(model) - times[model_name] = time.perf_counter() - start - for k, v in model_res.items(): - if k == "mean": - forecast_res[model_name] = v - elif k.startswith(("lo", "hi")): - col_name = f"{model_name}-{k}" - forecast_res[col_name] = v - elif k == "fitted": - fitted_res[model_name] = v - elif k.startswith(("fitted-lo", "fitted-hi")): - col_name = f'{model_name}-{k.replace("fitted-", "")}' - fitted_res[col_name] = v - return forecast_res, fitted_res, times - def _forecast_parallel(self, h, fitted, X, level, target_col): n_series = self.ga.n_groups forecast_res = defaultdict(lambda: np.empty(n_series * h, dtype=np.float32)) @@ -1302,11 +1304,13 @@ def _forecast_parallel(self, h, fitted, X, level, target_col): else: X_future = X[i] future = executor.submit( - self._forecast_serie, + _forecast_serie, h=h, y=y_train, X=X_train, X_future=X_future, + models=self.models, + fallback_model=self.fallback_model, fitted=fitted, level=level, )