diff --git a/include/xtensor/xmath.hpp b/include/xtensor/xmath.hpp index badd624d0..d31adbdf7 100644 --- a/include/xtensor/xmath.hpp +++ b/include/xtensor/xmath.hpp @@ -3013,6 +3013,94 @@ namespace detail { return cov(eval(stack(xtuple(x, y)))); } } + + + + /* + * convolution mode placeholders for selecting the algorithm + * used in computing a 1D convolution. + * Same as NumPy's mode parameter. + */ + namespace convolve_mode + { + struct valid{}; + struct full{}; + } + + namespace detail { + template + inline auto convolve_impl(E1&& e1, E2&& e2, convolve_mode::valid) + { + using value_type = typename std::decay::type::value_type; + + size_t const na = e1.size(); + size_t const nv = e2.size(); + size_t const n = na - nv + 1; + xt::xtensor out = xt::zeros({ n }); + for (size_t i = 0; i < n; i++) + { + for (int j = 0; j < nv; j++) + { + out(i) += e1(j) * e2(j + i); + } + } + return out; + } + + template + inline auto convolve_impl(E1&& e1, E2&& e2, convolve_mode::full mode) + { + using value_type = typename std::decay::type::value_type; + + size_t const na = e1.size(); + size_t const nv = e2.size(); + size_t const n = na + nv - 1; + xt::xtensor out = xt::zeros({ n }); + for (size_t i = 0; i < n; i++) + { + size_t const jmn = (i >= nv - 1) ? i - (nv - 1) : 0; + size_t const jmx = (i < na - 1) ? i : na - 1; + for (size_t j = jmn; j <= jmx; ++j) + { + out(i) += e1(j) * e2(i - j); + } + } + return out; + } + } + + /* + * @brief computes the 1D convolution between two 1D expressions + * + * @param a 1D expression + * @param v 1D expression + * @param mode placeholder Select algorithm #convolve_mode + * + * @detail the algorithm convolves a with v and will incur a copy overhead + * should v be longer than a. + */ + template + inline auto convolve(E1&& a, E2&& v, E3 mode) + { + + if (a.dimension() != 1 || v.dimension() != 1) + { + XTENSOR_THROW(std::runtime_error, "Invalid dimentions convolution arguments must be 1D expressions"); + } + + XTENSOR_ASSERT(a.size() > 0 && v.size() > 0); + + //swap them so a is always the longest one + if (a.size() < v.size()) + { + return detail::convolve_impl(std::forward(v), std::forward(a), mode); + } + else + { + return detail::convolve_impl(std::forward(a), std::forward(v), mode); + } + } } + #endif diff --git a/test/test_xmath.cpp b/test/test_xmath.cpp index 279f0b5e7..30c189b55 100644 --- a/test/test_xmath.cpp +++ b/test/test_xmath.cpp @@ -915,4 +915,27 @@ namespace xt EXPECT_EQ(expected, xt::cov(x, y)); } + + + TEST(xmath, convolve_full) + { + xt::xarray x = { 1.0, 3.0, 1.0 }; + xt::xarray y = { 1.0, 1.0, 1.0 }; + xt::xarray expected = { 1, 4, 5, 4, 1 }; + + auto result = xt::convolve(x, y, xt::convolve_mode::full()); + + EXPECT_EQ(result, expected); + } + + TEST(xmath, convolve_valid) + { + xt::xarray x = { 3.0, 1.0, 1.0 }; + xt::xarray y = { 1.0, 1.0, 1.0 }; + xt::xarray expected = { 5 }; + + auto result = xt::convolve(x, y, xt::convolve_mode::valid()); + + EXPECT_EQ(result, expected); + } }