Skip to content

Commit

Permalink
fix: reduce overhead in parallel forecast (#901)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored Sep 11, 2024
1 parent ba88f08 commit 12f6654
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 72 deletions.
75 changes: 39 additions & 36 deletions nbs/src/core/core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
"import pandas as pd\n",
"import utilsforecast.processing as ufp\n",
"from fugue.execution.factory import make_execution_engine, try_get_context_execution_engine\n",
"from threadpoolctl import threadpool_limits\n",
"from threadpoolctl import ThreadpoolController, threadpool_limits\n",
"from tqdm.auto import tqdm\n",
"from triad import conditional_dispatcher\n",
"from utilsforecast.compat import DataFrame, pl_DataFrame, pl_Series\n",
Expand All @@ -113,7 +113,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "6d8d5b82-2be9-41f5-8cd0-3903d0761e09",
"id": "46f43744-f0c9-4d6a-b897-acc9df8f59b5",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -123,7 +123,40 @@
" format='%(asctime)s %(name)s %(levelname)s: %(message)s',\n",
" datefmt='%Y-%m-%d %H:%M:%S',\n",
" )\n",
"logger = logging.getLogger(__name__)"
"logger = logging.getLogger(__name__)\n",
"\n",
"_controller = ThreadpoolController()\n",
"\n",
"@_controller.wrap(limits=1)\n",
"def _forecast_serie(h, y, X, X_future, models, fallback_model, level, fitted):\n",
" forecast_res = {}\n",
" fitted_res = {}\n",
" times = {}\n",
" for model in models:\n",
" start = time.perf_counter()\n",
" model_kwargs = dict(h=h, y=y, X=X, X_future=X_future, fitted=fitted)\n",
" if \"level\" in inspect.signature(model.forecast).parameters and level:\n",
" model_kwargs[\"level\"] = level\n",
" try:\n",
" model_res = model.forecast(**model_kwargs)\n",
" except Exception as e:\n",
" if fallback_model is None:\n",
" raise e\n",
" model_res = fallback_model.forecast(**model_kwargs)\n",
" model_name = repr(model)\n",
" times[model_name] = time.perf_counter() - start\n",
" for k, v in model_res.items():\n",
" if k == \"mean\":\n",
" forecast_res[model_name] = v\n",
" elif k.startswith((\"lo\", \"hi\")):\n",
" col_name = f\"{model_name}-{k}\"\n",
" forecast_res[col_name] = v\n",
" elif k == \"fitted\":\n",
" fitted_res[model_name] = v\n",
" elif k.startswith((\"fitted-lo\", \"fitted-hi\")):\n",
" col_name = f'{model_name}-{k.replace(\"fitted-\", \"\")}'\n",
" fitted_res[col_name] = v\n",
" return forecast_res, fitted_res, times"
]
},
{
Expand Down Expand Up @@ -1701,38 +1734,6 @@
" cols = cols[0]\n",
" return fm, fcsts, cols\n",
"\n",
" def _forecast_serie(self, level, **forecast_kwargs):\n",
" forecast_res = {}\n",
" fitted_res = {}\n",
" times = {}\n",
" with threadpool_limits(limits=1):\n",
" for model in self.models:\n",
" if 'level' in inspect.signature(model.forecast).parameters and level:\n",
" model_kwargs = {**forecast_kwargs, 'level': level}\n",
" else:\n",
" model_kwargs = forecast_kwargs\n",
" start = time.perf_counter()\n",
" try:\n",
" model_res = model.forecast(**model_kwargs)\n",
" except Exception as e:\n",
" if self.fallback_model is None:\n",
" raise e\n",
" model_res = self.fallback_model.forecast(**model_kwargs)\n",
" model_name = repr(model)\n",
" times[model_name] = time.perf_counter() - start\n",
" for k, v in model_res.items(): \n",
" if k == 'mean':\n",
" forecast_res[model_name] = v\n",
" elif k.startswith(('lo', 'hi')):\n",
" col_name = f'{model_name}-{k}'\n",
" forecast_res[col_name] = v\n",
" elif k == 'fitted':\n",
" fitted_res[model_name] = v\n",
" elif k.startswith(('fitted-lo', 'fitted-hi')):\n",
" col_name = f'{model_name}-{k.replace(\"fitted-\", \"\")}'\n",
" fitted_res[col_name] = v\n",
" return forecast_res, fitted_res, times\n",
"\n",
" def _forecast_parallel(self, h, fitted, X, level, target_col):\n",
" n_series = self.ga.n_groups\n",
" forecast_res = defaultdict(lambda: np.empty(n_series * h, dtype=np.float32))\n",
Expand All @@ -1751,11 +1752,13 @@
" else:\n",
" X_future = X[i]\n",
" future = executor.submit(\n",
" self._forecast_serie,\n",
" _forecast_serie,\n",
" h=h,\n",
" y=y_train,\n",
" X=X_train,\n",
" X_future=X_future,\n",
" models=self.models,\n",
" fallback_model=self.fallback_model,\n",
" fitted=fitted,\n",
" level=level,\n",
" )\n",
Expand Down
3 changes: 1 addition & 2 deletions python/statsforecast/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,6 @@
'statsforecast/core.py'),
'statsforecast.core._StatsForecast._forecast_parallel': ( 'src/core/core.html#_statsforecast._forecast_parallel',
'statsforecast/core.py'),
'statsforecast.core._StatsForecast._forecast_serie': ( 'src/core/core.html#_statsforecast._forecast_serie',
'statsforecast/core.py'),
'statsforecast.core._StatsForecast._get_cap_size': ( 'src/core/core.html#_statsforecast._get_cap_size',
'statsforecast/core.py'),
'statsforecast.core._StatsForecast._get_gas_Xs': ( 'src/core/core.html#_statsforecast._get_gas_xs',
Expand Down Expand Up @@ -187,6 +185,7 @@
'statsforecast/core.py'),
'statsforecast.core._StatsForecast.save': ( 'src/core/core.html#_statsforecast.save',
'statsforecast/core.py'),
'statsforecast.core._forecast_serie': ('src/core/core.html#_forecast_serie', 'statsforecast/core.py'),
'statsforecast.core._get_n_jobs': ('src/core/core.html#_get_n_jobs', 'statsforecast/core.py'),
'statsforecast.core._id_as_idx': ('src/core/core.html#_id_as_idx', 'statsforecast/core.py'),
'statsforecast.core._maybe_warn_sort_df': ( 'src/core/core.html#_maybe_warn_sort_df',
Expand Down
72 changes: 38 additions & 34 deletions python/statsforecast/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
make_execution_engine,
try_get_context_execution_engine,
)
from threadpoolctl import threadpool_limits
from threadpoolctl import ThreadpoolController, threadpool_limits
from tqdm.auto import tqdm
from triad import conditional_dispatcher
from utilsforecast.compat import DataFrame, pl_DataFrame, pl_Series
Expand All @@ -43,6 +43,40 @@
)
logger = logging.getLogger(__name__)

_controller = ThreadpoolController()


@_controller.wrap(limits=1)
def _forecast_serie(h, y, X, X_future, models, fallback_model, level, fitted):
    """Forecast one series with every model and collect the outputs.

    The whole call runs under ``_controller.wrap(limits=1)``.

    Parameters
    ----------
    h, y, X, X_future, fitted
        Forwarded unchanged as keyword arguments to each ``model.forecast``.
    models
        Iterable of objects exposing ``forecast(**kwargs)`` that returns a
        mapping of output name to values.
    fallback_model
        Object with the same ``forecast`` interface, used when a model raises.
        If ``None`` the original exception propagates.
    level
        Passed as ``level=`` only to models whose ``forecast`` signature
        declares that parameter, and only when truthy.

    Returns
    -------
    tuple of (dict, dict, dict)
        ``forecast_res`` maps ``repr(model)`` to the ``"mean"`` output and
        ``"{repr(model)}-{key}"`` to every ``lo*``/``hi*`` output.
        ``fitted_res`` holds the ``"fitted"`` output and the
        ``fitted-lo*``/``fitted-hi*`` outputs with ``"fitted-"`` stripped from
        the key. ``times`` maps ``repr(model)`` to elapsed seconds, including
        any time spent in the fallback model.
    """
    forecast_res = {}
    fitted_res = {}
    times = {}
    # These arguments are identical for every model, so build them once.
    base_kwargs = dict(h=h, y=y, X=X, X_future=X_future, fitted=fitted)
    for model in models:
        start = time.perf_counter()
        # Shallow copy so adding "level" for one model never leaks into the next.
        model_kwargs = dict(base_kwargs)
        # Test `level` first: signature introspection is comparatively costly
        # and is pointless when no level was requested.
        # NOTE(review): assumes `level` is None or a plain sequence - confirm.
        if level and "level" in inspect.signature(model.forecast).parameters:
            model_kwargs["level"] = level
        try:
            model_res = model.forecast(**model_kwargs)
        except Exception:
            if fallback_model is None:
                # Bare raise keeps the original traceback untouched.
                raise
            # The fallback receives exactly the kwargs the failed model got.
            model_res = fallback_model.forecast(**model_kwargs)
        model_name = repr(model)
        times[model_name] = time.perf_counter() - start
        for k, v in model_res.items():
            if k == "mean":
                forecast_res[model_name] = v
            elif k.startswith(("lo", "hi")):
                col_name = f"{model_name}-{k}"
                forecast_res[col_name] = v
            elif k == "fitted":
                fitted_res[model_name] = v
            elif k.startswith(("fitted-lo", "fitted-hi")):
                col_name = f'{model_name}-{k.replace("fitted-", "")}'
                fitted_res[col_name] = v
    return forecast_res, fitted_res, times

# %% ../../nbs/src/core/core.ipynb 10
class GroupedArray(BaseGroupedArray):

Expand Down Expand Up @@ -1252,38 +1286,6 @@ def _fit_predict_parallel(self, h, X, level):
cols = cols[0]
return fm, fcsts, cols

def _forecast_serie(self, level, **forecast_kwargs):
    """Run every model in ``self.models`` on one series and gather results.

    ``forecast_kwargs`` is forwarded to each ``model.forecast``; ``level`` is
    added only when the model's ``forecast`` signature declares it and
    ``level`` is truthy. If a model raises and ``self.fallback_model`` is not
    ``None``, the fallback is called with the same keyword arguments;
    otherwise the exception is re-raised.

    Returns a tuple ``(forecast_res, fitted_res, times)`` of dicts keyed by
    ``repr(model)`` (plus a ``-<interval key>`` suffix for interval outputs).
    """
    point_outputs, insample_outputs, elapsed = {}, {}, {}
    # All model calls happen inside the threadpool_limits(limits=1) context.
    with threadpool_limits(limits=1):
        for mdl in self.models:
            accepts_level = "level" in inspect.signature(mdl.forecast).parameters
            call_kwargs = (
                dict(forecast_kwargs, level=level)
                if accepts_level and level
                else forecast_kwargs
            )
            tic = time.perf_counter()
            try:
                output = mdl.forecast(**call_kwargs)
            except Exception as exc:
                backup = self.fallback_model
                if backup is None:
                    raise exc
                output = backup.forecast(**call_kwargs)
            label = repr(mdl)
            # Elapsed time covers the fallback call too when it was used.
            elapsed[label] = time.perf_counter() - tic
            for key, values in output.items():
                if key == "mean":
                    point_outputs[label] = values
                    continue
                if key.startswith(("lo", "hi")):
                    point_outputs[f"{label}-{key}"] = values
                    continue
                if key == "fitted":
                    insample_outputs[label] = values
                    continue
                if key.startswith(("fitted-lo", "fitted-hi")):
                    suffix = key.replace("fitted-", "")
                    insample_outputs[f"{label}-{suffix}"] = values
    return point_outputs, insample_outputs, elapsed

def _forecast_parallel(self, h, fitted, X, level, target_col):
n_series = self.ga.n_groups
forecast_res = defaultdict(lambda: np.empty(n_series * h, dtype=np.float32))
Expand All @@ -1302,11 +1304,13 @@ def _forecast_parallel(self, h, fitted, X, level, target_col):
else:
X_future = X[i]
future = executor.submit(
self._forecast_serie,
_forecast_serie,
h=h,
y=y_train,
X=X_train,
X_future=X_future,
models=self.models,
fallback_model=self.fallback_model,
fitted=fitted,
level=level,
)
Expand Down

0 comments on commit 12f6654

Please sign in to comment.