Skip to content

Commit

Permalink
Improve shuffle test to support 4-wide vectors.
Browse files (browse the repository at this point in the history)
  • Loading branch information
mcourteaux committed Jan 16, 2025
1 parent 53f6869 commit eae9647
Showing 1 changed file with 21 additions and 17 deletions.
38 changes: 21 additions & 17 deletions test/correctness/shuffle.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -17,21 +17,27 @@ int main(int argc, char **argv) {
f1(x, y) = x * (y + 3);
Expr vec1 = Internal::Shuffle::make_concat({f0(x, 0), f0(x, 1), f0(x, 2), f0(x, 3)});
Expr vec2 = Internal::Shuffle::make_concat({f1(x, 4), f1(x, 5), f1(x, 6), f1(x, 7)});
std::vector<int> indices0 = {3, 1, 6, 7, 2, 4, 0, 5};
std::vector<int> indices1 = {1, 0, 3, 4, 7, 0, 5, 2};
std::vector<int> indices0;
std::vector<int> indices1;
if (!target.has_gpu_feature() || target.has_feature(Target::Feature::OpenCL) || target.has_feature(Target::Feature::CUDA)) {
indices0 = {3, 1, 6, 7, 2, 4, 0, 5};
indices1 = {1, 0, 3, 4, 7, 0, 5, 2};
} else {
indices0 = {3, 1, 6, 7};
indices1 = {1, 0, 3, 4};
}
Expr shuffle1 = Internal::Shuffle::make({vec1, vec2}, indices0);
Expr shuffle2 = Internal::Shuffle::make({vec1, vec2}, indices1);
Expr result = shuffle1 * shuffle2;

// Manual logarithmic reduce.
Expr a_half1 = Halide::Internal::Shuffle::make_slice(result, 0, 1, 4);
Expr a_half2 = Halide::Internal::Shuffle::make_slice(result, 4, 1, 4);
Expr a_sumhalves = a_half1 + a_half2;
Expr b_half1 = Halide::Internal::Shuffle::make_slice(a_sumhalves, 0, 1, 2);
Expr b_half2 = Halide::Internal::Shuffle::make_slice(a_sumhalves, 2, 1, 2);
Expr b_sumhalves = b_half1 + b_half2;
g(x) = Internal::Shuffle::make_extract_element(b_sumhalves, 0) +
Internal::Shuffle::make_extract_element(b_sumhalves, 1);
while (result.type().lanes() > 1) {
int half_lanes = result.type().lanes() / 2;
Expr half1 = Halide::Internal::Shuffle::make_slice(result, 0, 1, half_lanes);
Expr half2 = Halide::Internal::Shuffle::make_slice(result, half_lanes, 1, half_lanes);
result = half1 + half2;
}
g(x) = result;

f0.compute_root();
f1.compute_root();
Expand All @@ -43,14 +49,12 @@ int main(int argc, char **argv) {
Buffer<int> im = g.realize({32}, target);
im.copy_to_host();
for (int x = 0; x < 32; x++) {
int fv0[8], fv1[8];
for (int i = 0; i < 8; ++i) {
fv0[i] = x * (indices0[i] + (indices0[i] >= 4 ? 3 : 1));
fv1[i] = x * (indices1[i] + (indices1[i] >= 4 ? 3 : 1));
}
int exp = 0;
for (int i = 0; i < 8; ++i) {
exp += fv0[i] * fv1[i];
int halfway = int(indices0.size() / 2);
for (size_t i = 0; i < indices0.size(); ++i) {
int v0 = x * (indices0[i] + (indices0[i] >= halfway ? 3 : 1));
int v1 = x * (indices1[i] + (indices1[i] >= halfway ? 3 : 1));
exp += v0 * v1;
}
if (im(x) != exp) {
printf("im[%d] = %d (expected %d)\n", x, im(x), exp);
Expand Down

0 comments on commit eae9647

Please sign in to comment.