Skip to content

Commit

Permalink
fix(arima): initialize residuals to zero (#959)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored Dec 19, 2024
1 parent e6cacae commit 4a5cb06
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 16 deletions.
6 changes: 3 additions & 3 deletions nbs/src/arima.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -644,12 +644,12 @@
" if len(seasonal['order']) != 3 or any(o < 0 or not isinstance(o, int) for o in seasonal['order']):\n",
" raise ValueError('order must be 3 non-negative integers')\n",
" \n",
" if seasonal['period'] is None or seasonal['period'] == 0:\n",
" if seasonal['period'] is None or seasonal['period'] <= 0:\n",
" warnings.warn('Setting seasonal period to 1')\n",
" seasonal['period'] = 1\n",
" \n",
" #fixed\n",
" #mask \n",
" #mask\n",
" arma = np.array(\n",
" [\n",
" *order[::2], \n",
Expand All @@ -661,7 +661,7 @@
" dtype=np.int32,\n",
" )\n",
" narma = arma[:4].sum().item()\n",
" \n",
"\n",
" # xtsp = init x, end x and frequency\n",
" # tsp(x) = None\n",
" Delta = np.array([1.]) \n",
Expand Down
2 changes: 1 addition & 1 deletion python/statsforecast/arima.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def maInvert(ma):
):
raise ValueError("order must be 3 non-negative integers")

if seasonal["period"] is None or seasonal["period"] == 0:
if seasonal["period"] is None or seasonal["period"] <= 0:
warnings.warn("Setting seasonal period to 1")
seasonal["period"] = 1

Expand Down
28 changes: 16 additions & 12 deletions src/arima.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ arima_transpar(const py::array_t<double> params_inv,
assert(arma[0] >= 0 && arma[1] >= 0 && arma[2] >= 0 && arma[3] >= 0 &&
arma[4] >= 0 && arma[5] >= 0 && arma[6] >= 0);

const int32_t mp = arma[0];
const int32_t mq = arma[1];
const int32_t msp = arma[2];
const int32_t msq = arma[3];
const int32_t ns = arma[4];
const auto mp = static_cast<size_t>(arma[0]);
const auto mq = static_cast<size_t>(arma[1]);
const auto msp = static_cast<size_t>(arma[2]);
const auto msq = static_cast<size_t>(arma[3]);
const auto ns = static_cast<size_t>(arma[4]);
const uint32_t p = mp + ns * msp;
const uint32_t q = mq + ns * msq;

Expand Down Expand Up @@ -127,22 +127,26 @@ arima_css(const py::array_t<double> yv, const py::array_t<int32_t> armav,
assert(arma[0] >= 0 && arma[1] >= 0 && arma[2] >= 0 && arma[3] >= 0 &&
arma[4] >= 0 && arma[5] >= 0 && arma[6] >= 0);

const uint32_t ncond = arma[0] + arma[5] + arma[4] * (arma[2] + arma[6]);
const auto d = static_cast<size_t>(arma[5]);
const auto D = static_cast<size_t>(arma[6]);
const auto ncond =
static_cast<size_t>(arma[0] + arma[5] + arma[4] * (arma[2] + arma[6]));
uint32_t nu = 0;
double ssq = 0.0;

py::array_t<double> residv(n);
const auto resid = make_span(residv);
std::fill_n(resid.begin(), ncond, 0.0);
std::vector<double> w(y.begin(), y.end());

for (size_t _ = 0; _ < arma[5]; ++_) {
for (size_t _ = 0; _ < d; ++_) {
for (size_t l = n - 1; l > 0; --l) {
w[l] -= w[l - 1];
}
}

const uint32_t ns = arma[4];
for (size_t _ = 0; _ < arma[6]; ++_) {
const auto ns = static_cast<size_t>(arma[4]);
for (size_t _ = 0; _ < D; ++_) {
for (size_t l = n - 1; l >= ns; --l) {
w[l] -= w[l - ns];
}
Expand Down Expand Up @@ -622,9 +626,9 @@ py::array_t<double> arima_gradtrans(const py::array_t<double> xv,
assert(arma[0] >= 0 && arma[1] >= 0 && arma[2] >= 0 && arma[3] >= 0 &&
arma[4] >= 0 && arma[5] >= 0 && arma[6] >= 0);

const int32_t mp = arma[0];
const int32_t mq = arma[1];
const int32_t msp = arma[2];
const auto mp = static_cast<size_t>(arma[0]);
const auto mq = static_cast<size_t>(arma[1]);
const auto msp = static_cast<size_t>(arma[2]);

std::array<double, 100> w1;
std::array<double, 100> w2;
Expand Down

0 comments on commit 4a5cb06

Please sign in to comment.