diff --git a/include/xtensor/xmath.hpp b/include/xtensor/xmath.hpp index 1fe6c9f54..b85d485b8 100644 --- a/include/xtensor/xmath.hpp +++ b/include/xtensor/xmath.hpp @@ -3088,6 +3088,88 @@ namespace xt return f; } + namespace detail + { + template + auto calculate_discontinuity(E1&& discontinuity, E2&&) + { + return discontinuity; + } + + template + auto calculate_discontinuity(xt::placeholders::xtuph, E2&& period) + { + return 0.5 * period; + } + + template + auto + calculate_interval(E2&& period, typename std::enable_if::value, E1>::type* = 0) + { + auto interval_high = 0.5 * period; + uint64_t remainder = static_cast(period) % 2; + auto boundary_ambiguous = (remainder == 0); + return std::make_tuple(interval_high, boundary_ambiguous); + } + + template + auto + calculate_interval(E2&& period, typename std::enable_if::value, E1>::type* = 0) + { + auto interval_high = 0.5 * period; + auto boundary_ambiguous = true; + return std::make_tuple(interval_high, boundary_ambiguous); + } + } + + /** + * @ingroup basic_functions + * @brief Unwrap by taking the complement of large deltas with respect to the period + * @details https://numpy.org/doc/stable/reference/generated/numpy.unwrap.html + * @param p Input array. + * @param discontinuity + * Maximum discontinuity between values, default is `period / 2`. + * Values below `period / 2` are treated as if they were `period / 2`. + * To have an effect different from the default, use `discontinuity > period / 2`. + * @param axis Axis along which unwrap will operate, default: the last axis. + * @param period Size of the range over which the input wraps. Default: \f$ 2 \pi \f$. + */ + + template + inline auto unwrap( + E1&& p, + E2 discontinuity = xnone(), + std::ptrdiff_t axis = -1, + E3 period = 2.0 * xt::numeric_constants::PI + ) + { + auto discont = detail::calculate_discontinuity(discontinuity, period); + using value_type = typename std::decay_t::value_type; + std::size_t saxis = normalize_axis(p.dimension(), axis); + auto dd = diff(p, 1, axis); + xstrided_slice_vector slice(p.dimension(), all()); + slice[saxis] = range(1, xnone()); + auto interval_tuple = detail::calculate_interval(period); + auto interval_high = std::get<0>(interval_tuple); + auto boundary_ambiguous = std::get<1>(interval_tuple); + auto interval_low = -interval_high; + auto ddmod = xt::eval(xt::fmod(xt::fmod(dd - interval_low, period) + period, period) + interval_low); + if (boundary_ambiguous) + { + // for `mask = (abs(dd) == period/2)`, the above line made + //`ddmod[mask] == -period/2`. correct these such that + //`ddmod[mask] == sign(dd[mask])*period/2`. + auto boolmap = xt::equal(ddmod, interval_low) && (xt::greater(dd, 0.0)); + ddmod = xt::where(boolmap, interval_high, ddmod); + } + auto ph_correct = xt::eval(ddmod - dd); + ph_correct = xt::where(xt::abs(dd) < discont, 0.0, ph_correct); + E1 up(p); + strided_view(up, slice) = strided_view(p, slice) + + xt::cumsum(ph_correct, static_cast(saxis)); + return up; + } + /** * @ingroup basic_functions * @brief Returns the one-dimensional piecewise linear interpolant to a function with given discrete data diff --git a/test/test_xmath.cpp b/test/test_xmath.cpp index 568c2c354..2eea8cc43 100644 --- a/test/test_xmath.cpp +++ b/test/test_xmath.cpp @@ -923,4 +923,48 @@ namespace xt EXPECT_EQ(result, expected); } + + TEST(xmath, unwrap) + { + { + // {0, pi / 4, pi / 2, -pi / 4, 0} + xt::xarray expected = {0., 0.78539816, 1.57079633, -0.78539816, 0}; + auto pi = xt::numeric_constants::PI; + xt::xarray phase = xt::linspace(0, pi, 5); + xt::view(phase, xt::range(3, xt::xnone())) += pi; + auto unwrapped = xt::unwrap(phase); + EXPECT_TRUE(xt::allclose(expected, unwrapped)); + } + { + xt::xarray expected = { + -180., + -140., + -100., + -60., + -20., + 20., + 60., + 100., + 140., + 180., + 220., + 260., + 300., + 340., + 380., + 420., + 460., + 500., + 540.}; + xt::xarray phase_deg = xt::fmod(xt::linspace(0, 720, 19), 360) - 180; + auto unwrapped = xt::unwrap(phase_deg, xnone(), -1, 360.0); + EXPECT_TRUE(xt::allclose(expected, unwrapped)); + } + { + xt::xarray expected = {2, 3, 4, 5, 6, 7, 8, 9}; + xt::xarray phase = {2, 3, 4, 5, 2, 3, 4, 5}; + auto unwrapped = xt::unwrap(phase, xnone(), -1, 4); + EXPECT_TRUE(xt::allclose(expected, unwrapped)); + } + } }