diff --git a/test/correctness/shuffle.cpp b/test/correctness/shuffle.cpp
index 6381ddacb929..887565093d55 100644
--- a/test/correctness/shuffle.cpp
+++ b/test/correctness/shuffle.cpp
@@ -17,21 +17,27 @@ int main(int argc, char **argv) {
     f1(x, y) = x * (y + 3);
     Expr vec1 = Internal::Shuffle::make_concat({f0(x, 0), f0(x, 1), f0(x, 2), f0(x, 3)});
     Expr vec2 = Internal::Shuffle::make_concat({f1(x, 4), f1(x, 5), f1(x, 6), f1(x, 7)});
-    std::vector<int> indices0 = {3, 1, 6, 7, 2, 4, 0, 5};
-    std::vector<int> indices1 = {1, 0, 3, 4, 7, 0, 5, 2};
+    std::vector<int> indices0;
+    std::vector<int> indices1;
+    if (!target.has_gpu_feature() || target.has_feature(Target::Feature::OpenCL) || target.has_feature(Target::Feature::CUDA)) {
+        indices0 = {3, 1, 6, 7, 2, 4, 0, 5};
+        indices1 = {1, 0, 3, 4, 7, 0, 5, 2};
+    } else {
+        indices0 = {3, 1, 6, 7};
+        indices1 = {1, 0, 3, 4};
+    }
     Expr shuffle1 = Internal::Shuffle::make({vec1, vec2}, indices0);
     Expr shuffle2 = Internal::Shuffle::make({vec1, vec2}, indices1);
     Expr result = shuffle1 * shuffle2;
 
     // Manual logarithmic reduce.
-    Expr a_half1 = Halide::Internal::Shuffle::make_slice(result, 0, 1, 4);
-    Expr a_half2 = Halide::Internal::Shuffle::make_slice(result, 4, 1, 4);
-    Expr a_sumhalves = a_half1 + a_half2;
-    Expr b_half1 = Halide::Internal::Shuffle::make_slice(a_sumhalves, 0, 1, 2);
-    Expr b_half2 = Halide::Internal::Shuffle::make_slice(a_sumhalves, 2, 1, 2);
-    Expr b_sumhalves = b_half1 + b_half2;
-    g(x) = Internal::Shuffle::make_extract_element(b_sumhalves, 0) +
-           Internal::Shuffle::make_extract_element(b_sumhalves, 1);
+    while (result.type().lanes() > 1) {
+        int half_lanes = result.type().lanes() / 2;
+        Expr half1 = Halide::Internal::Shuffle::make_slice(result, 0, 1, half_lanes);
+        Expr half2 = Halide::Internal::Shuffle::make_slice(result, half_lanes, 1, half_lanes);
+        result = half1 + half2;
+    }
+    g(x) = result;
 
     f0.compute_root();
     f1.compute_root();
@@ -43,14 +49,13 @@ int main(int argc, char **argv) {
     Buffer<int> im = g.realize({32}, target);
     im.copy_to_host();
     for (int x = 0; x < 32; x++) {
-        int fv0[8], fv1[8];
-        for (int i = 0; i < 8; ++i) {
-            fv0[i] = x * (indices0[i] + (indices0[i] >= 4 ? 3 : 1));
-            fv1[i] = x * (indices1[i] + (indices1[i] >= 4 ? 3 : 1));
-        }
         int exp = 0;
-        for (int i = 0; i < 8; ++i) {
-            exp += fv0[i] * fv1[i];
+        // Concat lanes 0-3 come from vec1 and lanes 4-7 from vec2, however many indices are shuffled.
+        const int vec1_lanes = 4;
+        for (size_t i = 0; i < indices0.size(); ++i) {
+            int v0 = x * (indices0[i] + (indices0[i] >= vec1_lanes ? 3 : 1));
+            int v1 = x * (indices1[i] + (indices1[i] >= vec1_lanes ? 3 : 1));
+            exp += v0 * v1;
         }
         if (im(x) != exp) {
             printf("im[%d] = %d (expected %d)\n", x, im(x), exp);