Skip to content

Commit

Permalink
Merge pull request #2507 from spectre-ns/convolve
Browse files Browse the repository at this point in the history
Added xt::convolve
  • Loading branch information
JohanMabille authored Apr 15, 2022
2 parents 0548fa8 + 935c649 commit fd7a610
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 0 deletions.
88 changes: 88 additions & 0 deletions include/xtensor/xmath.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3013,6 +3013,94 @@ namespace detail {
return cov(eval(stack(xtuple(x, y))));
}
}



/*
* convolution mode placeholders for selecting the algorithm
* used in computing a 1D convolution.
* Same as NumPy's mode parameter.
*/
namespace convolve_mode
{
struct valid{};
struct full{};
}

namespace detail {
template <class E1, class E2>
inline auto convolve_impl(E1&& e1, E2&& e2, convolve_mode::valid)
{
using value_type = typename std::decay<E1>::type::value_type;

size_t const na = e1.size();
size_t const nv = e2.size();
size_t const n = na - nv + 1;
xt::xtensor<value_type, 1> out = xt::zeros<value_type>({ n });
for (size_t i = 0; i < n; i++)
{
for (int j = 0; j < nv; j++)
{
out(i) += e1(j) * e2(j + i);
}
}
return out;
}

template <class E1, class E2>
inline auto convolve_impl(E1&& e1, E2&& e2, convolve_mode::full mode)
{
using value_type = typename std::decay<E1>::type::value_type;

size_t const na = e1.size();
size_t const nv = e2.size();
size_t const n = na + nv - 1;
xt::xtensor<value_type, 1> out = xt::zeros<value_type>({ n });
for (size_t i = 0; i < n; i++)
{
size_t const jmn = (i >= nv - 1) ? i - (nv - 1) : 0;
size_t const jmx = (i < na - 1) ? i : na - 1;
for (size_t j = jmn; j <= jmx; ++j)
{
out(i) += e1(j) * e2(i - j);
}
}
return out;
}
}

/*
* @brief computes the 1D convolution between two 1D expressions
*
* @param a 1D expression
* @param v 1D expression
* @param mode placeholder Select algorithm #convolve_mode
*
* @detail the algorithm convolves a with v and will incur a copy overhead
* should v be longer than a.
*/
template <class E1, class E2, class E3>
inline auto convolve(E1&& a, E2&& v, E3 mode)
{

if (a.dimension() != 1 || v.dimension() != 1)
{
XTENSOR_THROW(std::runtime_error, "Invalid dimentions convolution arguments must be 1D expressions");
}

XTENSOR_ASSERT(a.size() > 0 && v.size() > 0);

//swap them so a is always the longest one
if (a.size() < v.size())
{
return detail::convolve_impl(std::forward<E2>(v), std::forward<E1>(a), mode);
}
else
{
return detail::convolve_impl(std::forward<E1>(a), std::forward<E2>(v), mode);
}
}
}


#endif
23 changes: 23 additions & 0 deletions test/test_xmath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -915,4 +915,27 @@ namespace xt

EXPECT_EQ(expected, xt::cov(x, y));
}


TEST(xmath, convolve_full)
{
xt::xarray<double> x = { 1.0, 3.0, 1.0 };
xt::xarray<double> y = { 1.0, 1.0, 1.0 };
xt::xarray<double> expected = { 1, 4, 5, 4, 1 };

auto result = xt::convolve(x, y, xt::convolve_mode::full());

EXPECT_EQ(result, expected);
}

TEST(xmath, convolve_valid)
{
xt::xarray<double> x = { 3.0, 1.0, 1.0 };
xt::xarray<double> y = { 1.0, 1.0, 1.0 };
xt::xarray<double> expected = { 5 };

auto result = xt::convolve(x, y, xt::convolve_mode::valid());

EXPECT_EQ(result, expected);
}
}

0 comments on commit fd7a610

Please sign in to comment.