Skip to content

Commit

Permalink
Added unwrap (#2710)
Browse files Browse the repository at this point in the history
  • Loading branch information
spectre-ns authored Jun 8, 2023
1 parent 4c71ce3 commit 0e6b11a
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 0 deletions.
82 changes: 82 additions & 0 deletions include/xtensor/xmath.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3088,6 +3088,88 @@ namespace xt
return f;
}

namespace detail
{
template <class E1, class E2>
auto calculate_discontinuity(E1&& discontinuity, E2&&)
{
return discontinuity;
}

template <class E2>
auto calculate_discontinuity(xt::placeholders::xtuph, E2&& period)
{
return 0.5 * period;
}

template <class E1, class E2>
auto
calculate_interval(E2&& period, typename std::enable_if<std::is_integral<E1>::value, E1>::type* = 0)
{
auto interval_high = 0.5 * period;
uint64_t remainder = static_cast<uint64_t>(period) % 2;
auto boundary_ambiguous = (remainder == 0);
return std::make_tuple(interval_high, boundary_ambiguous);
}

template <class E1, class E2>
auto
calculate_interval(E2&& period, typename std::enable_if<std::is_floating_point<E1>::value, E1>::type* = 0)
{
auto interval_high = 0.5 * period;
auto boundary_ambiguous = true;
return std::make_tuple(interval_high, boundary_ambiguous);
}
}

/**
* @ingroup basic_functions
* @brief Unwrap by taking the complement of large deltas with respect to the period
* @details https://numpy.org/doc/stable/reference/generated/numpy.unwrap.html
* @param p Input array.
* @param discontinuity
* Maximum discontinuity between values, default is `period / 2`.
* Values below `period / 2` are treated as if they were `period / 2`.
* To have an effect different from the default, use `discontinuity > period / 2`.
* @param axis Axis along which unwrap will operate, default: the last axis.
* @param period Size of the range over which the input wraps. Default: \f$ 2 \pi \f$.
*/

template <class E1, class E2 = xt::placeholders::xtuph, class E3 = double>
inline auto unwrap(
E1&& p,
E2 discontinuity = xnone(),
std::ptrdiff_t axis = -1,
E3 period = 2.0 * xt::numeric_constants<double>::PI
)
{
auto discont = detail::calculate_discontinuity(discontinuity, period);
using value_type = typename std::decay_t<E1>::value_type;
std::size_t saxis = normalize_axis(p.dimension(), axis);
auto dd = diff(p, 1, axis);
xstrided_slice_vector slice(p.dimension(), all());
slice[saxis] = range(1, xnone());
auto interval_tuple = detail::calculate_interval<value_type>(period);
auto interval_high = std::get<0>(interval_tuple);
auto boundary_ambiguous = std::get<1>(interval_tuple);
auto interval_low = -interval_high;
auto ddmod = xt::eval(xt::fmod(xt::fmod(dd - interval_low, period) + period, period) + interval_low);
if (boundary_ambiguous)
{
// for `mask = (abs(dd) == period/2)`, the above line made
//`ddmod[mask] == -period/2`. correct these such that
//`ddmod[mask] == sign(dd[mask])*period/2`.
auto boolmap = xt::equal(ddmod, interval_low) && (xt::greater(dd, 0.0));
ddmod = xt::where(boolmap, interval_high, ddmod);
}
auto ph_correct = xt::eval(ddmod - dd);
ph_correct = xt::where(xt::abs(dd) < discont, 0.0, ph_correct);
E1 up(p);
strided_view(up, slice) = strided_view(p, slice)
+ xt::cumsum(ph_correct, static_cast<std::ptrdiff_t>(saxis));
return up;
}

/**
* @ingroup basic_functions
* @brief Returns the one-dimensional piecewise linear interpolant to a function with given discrete data
Expand Down
44 changes: 44 additions & 0 deletions test/test_xmath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -923,4 +923,48 @@ namespace xt

EXPECT_EQ(result, expected);
}

TEST(xmath, unwrap)
{
{
// {0, pi / 4, pi / 2, -pi / 4, 0}
xt::xarray<double> expected = {0., 0.78539816, 1.57079633, -0.78539816, 0};
auto pi = xt::numeric_constants<double>::PI;
xt::xarray<double> phase = xt::linspace<double>(0, pi, 5);
xt::view(phase, xt::range(3, xt::xnone())) += pi;
auto unwrapped = xt::unwrap(phase);
EXPECT_TRUE(xt::allclose(expected, unwrapped));
}
{
xt::xarray<double> expected = {
-180.,
-140.,
-100.,
-60.,
-20.,
20.,
60.,
100.,
140.,
180.,
220.,
260.,
300.,
340.,
380.,
420.,
460.,
500.,
540.};
xt::xarray<double> phase_deg = xt::fmod(xt::linspace<double>(0, 720, 19), 360) - 180;
auto unwrapped = xt::unwrap(phase_deg, xnone(), -1, 360.0);
EXPECT_TRUE(xt::allclose(expected, unwrapped));
}
{
xt::xarray<int> expected = {2, 3, 4, 5, 6, 7, 8, 9};
xt::xarray<int> phase = {2, 3, 4, 5, 2, 3, 4, 5};
auto unwrapped = xt::unwrap(phase, xnone(), -1, 4);
EXPECT_TRUE(xt::allclose(expected, unwrapped));
}
}
}

0 comments on commit 0e6b11a

Please sign in to comment.