Skip to content

Commit

Permalink
[luci/pass] Extend ExpandBroadcastConstPass (#14571)
Browse files Browse the repository at this point in the history
Let's extend the pass to support broadcasting along any single dimension, not only the last dimension.

ONE-DCO-Signed-off-by: Dayoung Lee <dayoung.lee@samsung.com>
  • Loading branch information
dayo09 authored Jan 23, 2025
1 parent 18da2eb commit d712744
Showing 1 changed file with 36 additions and 31 deletions.
67 changes: 36 additions & 31 deletions compiler/luci/pass/src/ExpandBroadcastConstPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#include "luci/Pass/ExpandBroadcastConstPass.h"
#include "helpers/ArrayIndex.h"

#include <luci/IR/CircleNodes.h>
#include <luci/Log.h>
Expand All @@ -41,13 +42,13 @@ luci::CircleConst *create_expanded_constant(luci::CircleConst *node, luci::Circl
broadcast_dims.push_back(dim);
}

if (broadcast_dims.size() != 1 || broadcast_dims.back() != node->rank() - 1)
if (broadcast_dims.size() != 1)
{
WARN(l) << "NYI: Only depth broadcast removal is supported";
WARN(l) << "NYI: Only single dimension broadcast is supported";
return nullptr;
}

if (successor->rank() == 1 || successor->rank() > 4)
if (successor->rank() < 2 || successor->rank() > 4)
{
WARN(l) << "NYI: Only 2D/3D/4D tensor broadcast removal is supported";
return nullptr;
Expand All @@ -70,34 +71,38 @@ luci::CircleConst *create_expanded_constant(luci::CircleConst *node, luci::Circl
auto const node_data = &node->at<loco::DataType::FLOAT32>(0);
auto const constant_data = &constant->at<loco::DataType::FLOAT32>(0);

if (successor->rank() == 2)
{
auto const N = successor->dim(successor->rank() - 2).value();
auto const D = successor->dim(successor->rank() - 1).value();
for (uint32_t n = 0; n < N; ++n)
std::fill_n(constant_data + n * D, D, node_data[n]);
}
else if (successor->rank() == 3)
{
auto const H = successor->dim(successor->rank() - 3).value();
auto const W = successor->dim(successor->rank() - 2).value();
auto const D = successor->dim(successor->rank() - 1).value();
for (uint32_t h = 0; h < H; ++h)
for (uint32_t w = 0; w < W; ++w)
std::fill_n(constant_data + h * W * D + w * D, D, node_data[h * W + w]);
}
else if (successor->rank() == 4)
{
auto const N = successor->dim(successor->rank() - 4).value();
auto const H = successor->dim(successor->rank() - 3).value();
auto const W = successor->dim(successor->rank() - 2).value();
auto const D = successor->dim(successor->rank() - 1).value();
for (uint32_t n = 0; n < N; ++n)
for (uint32_t h = 0; h < H; ++h)
for (uint32_t w = 0; w < W; ++w)
std::fill_n(constant_data + n * H * W * D + h * W * D + w * D, D,
node_data[n * H * W + h * W + w]);
}
assert(successor->rank() >= 2 && successor->rank() <= 4);

// Virtually extend the constant node to 4D to support all cases
// (Only for index calculation)
// Example. (2, 4) -> (1, 1, 2, 4)
auto const D0 = (successor->rank() < 4) ? 1 : successor->dim(successor->rank() - 4).value();
auto const D1 = (successor->rank() < 3) ? 1 : successor->dim(successor->rank() - 3).value();
auto const D2 = (successor->rank() < 2) ? 1 : successor->dim(successor->rank() - 2).value();
auto const D3 = successor->dim(successor->rank() - 1).value();

auto idx = luci::Array4DIndex(D0, D1, D2, D3);

auto const D0_orig = (node->rank() < 4) ? 1 : node->dim(node->rank() - 4).value();
auto const D1_orig = (node->rank() < 3) ? 1 : node->dim(node->rank() - 3).value();
auto const D2_orig = (node->rank() < 2) ? 1 : node->dim(node->rank() - 2).value();
auto const D3_orig = node->dim(node->rank() - 1).value();

auto idx_orig = luci::Array4DIndex(D0_orig, D1_orig, D2_orig, D3_orig);

for (uint32_t d0 = 0; d0 < D0; ++d0)
for (uint32_t d1 = 0; d1 < D1; ++d1)
for (uint32_t d2 = 0; d2 < D2; ++d2)
for (uint32_t d3 = 0; d3 < D3; ++d3)
{
auto const d0_orig = (D0_orig == 1) ? 0 : d0;
auto const d1_orig = (D1_orig == 1) ? 0 : d1;
auto const d2_orig = (D2_orig == 1) ? 0 : d2;
auto const d3_orig = (D3_orig == 1) ? 0 : d3;
constant_data[idx(d0, d1, d2, d3)] =
node_data[idx_orig(d0_orig, d1_orig, d2_orig, d3_orig)];
}

return constant;
}

Expand Down

0 comments on commit d712744

Please sign in to comment.