-
Notifications
You must be signed in to change notification settings - Fork 12.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir] Handle arith.const expr in dispatchIndexOpFoldResult func #122432
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir Author: None (rutkoor) ChangesThis PR addresses the handling of arith.constant expressions in the dispatchIndexOpFoldResult helper function. Previously, the helper function dispatched an OpFoldResult into staticVec only if it was an IntegerAttr. The changes in this PR now enable the evaluation of arith.constant expressions, extraction of the integer value, and dispatch into staticVec. Full diff: https://github.com/llvm/llvm-project/pull/122432.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 5c8f6ded39ba4e..7ad4c982af2aae 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/APSInt.h"
@@ -54,6 +55,18 @@ void dispatchIndexOpFoldResult(OpFoldResult ofr,
staticVec.push_back(apInt.getSExtValue());
return;
}
+
+ Operation *definingOp = v.getDefiningOp();
+ if (definingOp) {
+ // Check if definingOp is an arith.constant
+ if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
+ if (auto intAttr = mlir::dyn_cast<IntegerAttr>(constantOp.getValue())) {
+ staticVec.push_back(intAttr.getValue().getSExtValue());
+ return;
+ }
+ }
+ }
+
dynamicVec.push_back(v);
staticVec.push_back(ShapedType::kDynamic);
}
diff --git a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
index cf6b12852bcd39..15bc9b0435f6e6 100644
--- a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
+++ b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
@@ -20,6 +20,26 @@ func.func @bubble_parallel_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1:
// -----
+func.func @bubble_parallel_reshapes2(%arg0: tensor<?x2x2x6xf32>, %s0: index, %s1: index) -> tensor<?x4x2x3xf32> {
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x2x2x6xf32> into tensor<?x4x6xf32>
+ %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
+ output_shape [%s0, %s1, %c2, %c3] : tensor<?x4x6xf32> into tensor<?x4x2x3xf32>
+ return %expand : tensor<?x4x2x3xf32>
+}
+// CHECK: func @bubble_parallel_reshapes2
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x2x2x6xf32>
+// CHECK-SAME: %[[S0:.+]]: index, %[[S1:.+]]: index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2], [3, 4]]
+// CHECK-SAME: output_shape [%[[S0]], 2, 2, %[[C2]], %[[C3]]] : tensor<?x2x2x6xf32> into tensor<?x2x2x2x3xf32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]] {{\[}}[0], [1, 2], [3], [4]] : tensor<?x2x2x2x3xf32> into tensor<?x4x2x3xf32>
+// CHECK: return %[[COLLAPSE]]
+
+// -----
+
func.func @no_bubble_full_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
%collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%expand = tensor.expand_shape %collapse [[0], [1, 2], [3]]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values);
implements the functionality that you are looking for. You can use it with in combination with the existing dispatchIndexOpFoldResults
.
@@ -7,6 +7,7 @@ | |||
//===----------------------------------------------------------------------===// | |||
|
|||
#include "mlir/Dialect/Utils/StaticValueUtils.h" | |||
#include "mlir/Dialect/Arith/IR/Arith.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This file should not depend on any dialect.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed. Thanks a lot for the suggestion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh I mean you can use dispatchIndexOpFoldResult(getAsOpFoldResult(v))
wherever you need it. I wouldn't call getAsOpFoldResult
from dispatchIndexOpFoldResult
because it does not fit with the name of the function. This function is just a switch that populates two vectors, it's not meant to analyze any IR.
Why do you need this functionality?
Change-Id: I15280932f88d8ff638f5d0f964a1c03ce7a7881a
%c3 = arith.constant 3 : index | ||
%collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x2x2x6xf32> into tensor<?x4x6xf32> | ||
%expand = tensor.expand_shape %collapse [[0], [1], [2, 3]] | ||
output_shape [%s0, %s1, %c2, %c3] : tensor<?x4x6xf32> into tensor<?x4x2x3xf32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not that familiar with this op anymore but I expected output_shape [%s0, 4, 2, 3]
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, I'm surprised that the verifier allows this op. @MaheshRavishankar to clarify.
@@ -54,6 +54,14 @@ void dispatchIndexOpFoldResult(OpFoldResult ofr, | |||
staticVec.push_back(apInt.getSExtValue()); | |||
return; | |||
} | |||
|
|||
OpFoldResult result = getAsOpFoldResult(v); | |||
if (auto attr = result.dyn_cast<Attribute>()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you could dyn_cast the OpFoldResult directly as IntegerAttr.
saves forced cast on next line.
some thing like if (auto iattr = dyn_cast<IntegerAttr>(result)
This PR addresses the handling of arith.constant expressions in the dispatchIndexOpFoldResult helper function. Previously, the helper function dispatched an OpFoldResult into staticVec, only if it was an IntegerAttr. The changes in this PR now enable the evaluation of arith.constant expressions, extraction of the integer value, and dispatch into staticVec.