From 9a1d68bd6d6a4402e046e0228984351b85d4707c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Morales?= Date: Thu, 19 Dec 2024 11:20:06 -0600 Subject: [PATCH] fix(arima): initialize residuals to zero --- nbs/src/arima.ipynb | 6 +++--- python/statsforecast/arima.py | 2 +- src/arima.cpp | 28 ++++++++++++++++------------ 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/nbs/src/arima.ipynb b/nbs/src/arima.ipynb index db37c8ea..7e5c9c1c 100644 --- a/nbs/src/arima.ipynb +++ b/nbs/src/arima.ipynb @@ -644,12 +644,12 @@ " if len(seasonal['order']) != 3 or any(o < 0 or not isinstance(o, int) for o in seasonal['order']):\n", " raise ValueError('order must be 3 non-negative integers')\n", " \n", - " if seasonal['period'] is None or seasonal['period'] == 0:\n", + " if seasonal['period'] is None or seasonal['period'] <= 0:\n", " warnings.warn('Setting seasonal period to 1')\n", " seasonal['period'] = 1\n", " \n", " #fixed\n", - " #mask \n", + " #mask\n", " arma = np.array(\n", " [\n", " *order[::2], \n", @@ -661,7 +661,7 @@ " dtype=np.int32,\n", " )\n", " narma = arma[:4].sum().item()\n", - " \n", + "\n", " # xtsp = init x, end x and frequency\n", " # tsp(x) = None\n", " Delta = np.array([1.]) \n", diff --git a/python/statsforecast/arima.py b/python/statsforecast/arima.py index c149cea8..f37231a8 100644 --- a/python/statsforecast/arima.py +++ b/python/statsforecast/arima.py @@ -299,7 +299,7 @@ def maInvert(ma): ): raise ValueError("order must be 3 non-negative integers") - if seasonal["period"] is None or seasonal["period"] == 0: + if seasonal["period"] is None or seasonal["period"] <= 0: warnings.warn("Setting seasonal period to 1") seasonal["period"] = 1 diff --git a/src/arima.cpp b/src/arima.cpp index 698a8fa1..b687dd6f 100644 --- a/src/arima.cpp +++ b/src/arima.cpp @@ -53,11 +53,11 @@ arima_transpar(const py::array_t params_inv, assert(arma[0] >= 0 && arma[1] >= 0 && arma[2] >= 0 && arma[3] >= 0 && arma[4] >= 0 && arma[5] >= 0 && arma[6] >= 0); - const int32_t mp = arma[0]; - const int32_t mq = arma[1]; - const int32_t msp = arma[2]; - const int32_t msq = arma[3]; - const int32_t ns = arma[4]; + const auto mp = static_cast(arma[0]); + const auto mq = static_cast(arma[1]); + const auto msp = static_cast(arma[2]); + const auto msq = static_cast(arma[3]); + const auto ns = static_cast(arma[4]); const uint32_t p = mp + ns * msp; const uint32_t q = mq + ns * msq; @@ -127,22 +127,26 @@ arima_css(const py::array_t yv, const py::array_t armav, assert(arma[0] >= 0 && arma[1] >= 0 && arma[2] >= 0 && arma[3] >= 0 && arma[4] >= 0 && arma[5] >= 0 && arma[6] >= 0); - const uint32_t ncond = arma[0] + arma[5] + arma[4] * (arma[2] + arma[6]); + const auto d = static_cast(arma[5]); + const auto D = static_cast(arma[6]); + const auto ncond = + static_cast(arma[0] + arma[5] + arma[4] * (arma[2] + arma[6])); uint32_t nu = 0; double ssq = 0.0; py::array_t residv(n); const auto resid = make_span(residv); + std::fill_n(resid.begin(), ncond, 0.0); std::vector w(y.begin(), y.end()); - for (size_t _ = 0; _ < arma[5]; ++_) { + for (size_t _ = 0; _ < d; ++_) { for (size_t l = n - 1; l > 0; --l) { w[l] -= w[l - 1]; } } - const uint32_t ns = arma[4]; - for (size_t _ = 0; _ < arma[6]; ++_) { + const auto ns = static_cast(arma[4]); + for (size_t _ = 0; _ < D; ++_) { for (size_t l = n - 1; l >= ns; --l) { w[l] -= w[l - ns]; } @@ -622,9 +626,9 @@ py::array_t arima_gradtrans(const py::array_t xv, assert(arma[0] >= 0 && arma[1] >= 0 && arma[2] >= 0 && arma[3] >= 0 && arma[4] >= 0 && arma[5] >= 0 && arma[6] >= 0); - const int32_t mp = arma[0]; - const int32_t mq = arma[1]; - const int32_t msp = arma[2]; + const auto mp = static_cast(arma[0]); + const auto mq = static_cast(arma[1]); + const auto msp = static_cast(arma[2]); std::array w1; std::array w2;