Skip to content

Commit

Permalink
Adding free function xt::strides
Browse files Browse the repository at this point in the history
  • Loading branch information
tdegeus committed Mar 16, 2022
1 parent 5ef5a19 commit bcf9bb8
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 0 deletions.
16 changes: 16 additions & 0 deletions docs/source/api/xstrides.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
.. Copyright (c) 2016, Johan Mabille, Sylvain Corlay and Wolf Vollprecht
Distributed under the terms of the BSD 3-Clause License.
The full license is in the file LICENSE, distributed with this software.
xshape
======

Defined in ``xtensor/xstride.hpp``

.. doxygenfunction:: auto strides(const E& e, xt::stride_type type)
:project: xtensor

.. doxygenfunction:: auto strides(const E& e, S axis, xt::stride_type type)
:project: xtensor
14 changes: 14 additions & 0 deletions docs/source/numpy-differences.rst
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,20 @@ Strides
Strided containers of xtensor and numpy having the same exact memory layout may have different strides when accessing them through the ``strides`` attribute.
The reason is an optimization in xtensor, which is to set the strides to ``0`` in dimensions of length ``1``, which simplifies the implementation of broadcasting of universal functions.

.. tip::

Use the free function ``xt::strides`` to switch between representations.

.. code-block:: cpp
xt::strides(a); // strides of ``a`` corresponding to storage
xt::strides(a, xt::stride_type::normal); // same
xt::strides(a, xt::stride_type::internal); // ``== a.strides()``
xt::strides(a, xt::stride_type::bytes) // strides in bytes, as in numpy
Array indices
-------------

Expand Down
121 changes: 121 additions & 0 deletions include/xtensor/xstrides.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,127 @@ namespace xt
return begin;
}


/***********
* strides *
***********/

namespace detail
{
template <class return_type, class S, class T, class D>
inline return_type compute_stride_impl(layout_type layout, const S& shape, T axis, D default_stride)
{
if (layout == layout_type::row_major)
{
return std::accumulate(
shape.cbegin() + axis + 1,
shape.cend(),
static_cast<return_type>(1),
std::multiplies<return_type>()
);
}
if (layout == layout_type::column_major)
{
return std::accumulate(
shape.cbegin(),
shape.cbegin() + axis,
static_cast<return_type>(1),
std::multiplies<return_type>()
);
}
return default_stride;
}
}

/**
* @ingroup strides
* @brief strides_type
*
* Choose stride type
*/
enum class stride_type
{
internal = 0, ///< As used internally (with `stride(axis) == 0` if `shape(axis) == 1`)
normal = 1, ///< Normal stride corresponding to storage.
bytes = 2, ///< Normal stride in bytes.
};

/**
* @ingroup strides
* @brief strides
*
* Get strides of an object.
* @param a an array
* @return array
*/
template <class E>
inline auto strides(const E& e, stride_type type = stride_type::normal) noexcept
{
using strides_type = typename E::strides_type;
using return_type = typename strides_type::value_type;
strides_type ret = e.strides();
auto shape = e.shape();

if (type == stride_type::internal)
{
return ret;
}

for (std::size_t i = 0; i < ret.size(); ++i)
{
if (shape[i] == 1)
{
ret[i] = detail::compute_stride_impl<return_type>(e.layout(), shape, i, ret[i]);
}
}

if (type == stride_type::bytes)
{
return_type f = static_cast<return_type>(sizeof(typename E::value_type));
std::for_each(ret.begin(), ret.end(), [f](auto& c){ c *= f; });
}

return ret;
}

/**
* @ingroup strides
* @brief strides
*
* Get stride of an object along an axis.
* @param a an array
* @return integer
*/
template <class E>
inline auto strides(const E& e, std::size_t axis, stride_type type = stride_type::normal) noexcept
{
using strides_type = typename E::strides_type;
using return_type = typename strides_type::value_type;

return_type ret = e.strides()[axis];

if (type == stride_type::internal)
{
return ret;
}

if (ret == 0)
{
if (e.shape(axis) == 1)
{
ret = detail::compute_stride_impl<return_type>(e.layout(), e.shape(), axis, ret);
}
}

if (type == stride_type::bytes)
{
return_type f = static_cast<return_type>(sizeof(typename E::value_type));
ret *= f;
}

return ret;
}

/******************
* Implementation *
******************/
Expand Down
86 changes: 86 additions & 0 deletions test/test_xstrides.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,92 @@ namespace xt
EXPECT_TRUE(t5);
}

TEST(xstrides, free_function_2d_row_major)
{
xt::xarray<int, xt::layout_type::row_major> a = xt::ones<int>({1, 3});
using stype = std::vector<std::ptrdiff_t>;
std::ptrdiff_t sof = sizeof(int);

EXPECT_EQ(xt::strides(a), stype({3, 1}));
EXPECT_EQ(xt::strides(a, xt::stride_type::normal), stype({3, 1}));
EXPECT_EQ(xt::strides(a, xt::stride_type::internal), stype({0, 1}));
EXPECT_EQ(xt::strides(a, xt::stride_type::bytes), stype({3 * sof, sof}));

EXPECT_TRUE(xt::strides(a, 0) == 3);
EXPECT_TRUE(xt::strides(a, 0, xt::stride_type::normal) == 3);
EXPECT_TRUE(xt::strides(a, 0, xt::stride_type::internal) == 0);
EXPECT_TRUE(xt::strides(a, 0, xt::stride_type::bytes) == 3 * sof);

EXPECT_TRUE(xt::strides(a, 1) == 1);
EXPECT_TRUE(xt::strides(a, 1, xt::stride_type::normal) == 1);
EXPECT_TRUE(xt::strides(a, 1, xt::stride_type::internal) == 1);
EXPECT_TRUE(xt::strides(a, 1, xt::stride_type::bytes) == sof);
}

TEST(xstrides, free_function_4d_row_major)
{
xt::xarray<int, xt::layout_type::row_major> a = xt::ones<int>({5, 4, 1, 4});
using stype = std::vector<std::ptrdiff_t>;
std::ptrdiff_t sof = sizeof(int);

EXPECT_EQ(xt::strides(a), stype({16, 4, 4, 1}));
EXPECT_EQ(xt::strides(a, xt::stride_type::normal), stype({16, 4, 4, 1}));
EXPECT_EQ(xt::strides(a, xt::stride_type::internal), stype({16, 4, 0, 1}));
EXPECT_EQ(xt::strides(a, xt::stride_type::bytes), stype({16 * sof, 4 * sof, 4 * sof, 1 * sof}));

EXPECT_TRUE(xt::strides(a, 0) == 16);
EXPECT_TRUE(xt::strides(a, 0, xt::stride_type::normal) == 16);
EXPECT_TRUE(xt::strides(a, 0, xt::stride_type::internal) == 16);
EXPECT_TRUE(xt::strides(a, 0, xt::stride_type::bytes) == 16 * sof);

EXPECT_TRUE(xt::strides(a, 1) == 4);
EXPECT_TRUE(xt::strides(a, 1, xt::stride_type::normal) == 4);
EXPECT_TRUE(xt::strides(a, 1, xt::stride_type::internal) == 4);
EXPECT_TRUE(xt::strides(a, 1, xt::stride_type::bytes) == 4 * sof);

EXPECT_TRUE(xt::strides(a, 2) == 4);
EXPECT_TRUE(xt::strides(a, 2, xt::stride_type::normal) == 4);
EXPECT_TRUE(xt::strides(a, 2, xt::stride_type::internal) == 0);
EXPECT_TRUE(xt::strides(a, 2, xt::stride_type::bytes) == 4 * sof);

EXPECT_TRUE(xt::strides(a, 3) == 1);
EXPECT_TRUE(xt::strides(a, 3, xt::stride_type::normal) == 1);
EXPECT_TRUE(xt::strides(a, 3, xt::stride_type::internal) == 1);
EXPECT_TRUE(xt::strides(a, 3, xt::stride_type::bytes) == sof);
}

TEST(xstrides, free_function_4d_column_major)
{
xt::xarray<int, xt::layout_type::column_major> a = xt::ones<int>({5, 4, 1, 4});
using stype = std::vector<std::ptrdiff_t>;
std::ptrdiff_t sof = sizeof(int);

EXPECT_EQ(xt::strides(a), stype({1, 5, 20, 20}));
EXPECT_EQ(xt::strides(a, xt::stride_type::normal), stype({1, 5, 20, 20}));
EXPECT_EQ(xt::strides(a, xt::stride_type::internal), stype({1, 5, 0, 20}));
EXPECT_EQ(xt::strides(a, xt::stride_type::bytes), stype({sof, 5 * sof, 20 * sof, 20 * sof}));

EXPECT_TRUE(xt::strides(a, 0) == 1);
EXPECT_TRUE(xt::strides(a, 0, xt::stride_type::normal) == 1);
EXPECT_TRUE(xt::strides(a, 0, xt::stride_type::internal) == 1);
EXPECT_TRUE(xt::strides(a, 0, xt::stride_type::bytes) == sof);

EXPECT_TRUE(xt::strides(a, 1) == 5);
EXPECT_TRUE(xt::strides(a, 1, xt::stride_type::normal) == 5);
EXPECT_TRUE(xt::strides(a, 1, xt::stride_type::internal) == 5);
EXPECT_TRUE(xt::strides(a, 1, xt::stride_type::bytes) == 5 * sof);

EXPECT_TRUE(xt::strides(a, 2) == 20);
EXPECT_TRUE(xt::strides(a, 2, xt::stride_type::normal) == 20);
EXPECT_TRUE(xt::strides(a, 2, xt::stride_type::internal) == 0);
EXPECT_TRUE(xt::strides(a, 2, xt::stride_type::bytes) == 20 * sof);

EXPECT_TRUE(xt::strides(a, 3) == 20);
EXPECT_TRUE(xt::strides(a, 3, xt::stride_type::normal) == 20);
EXPECT_TRUE(xt::strides(a, 3, xt::stride_type::internal) == 20);
EXPECT_TRUE(xt::strides(a, 3, xt::stride_type::bytes) == 20 * sof);
}

TEST(xstrides, unravel_from_strides)
{
SUBCASE("row_major strides")
Expand Down

0 comments on commit bcf9bb8

Please sign in to comment.