Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] Handle arith.const expr in dispatchIndexOpFoldResult func #122432

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

rutkoor
Copy link
Contributor

@rutkoor rutkoor commented Jan 10, 2025

This PR addresses the handling of arith.constant expressions in the dispatchIndexOpFoldResult helper function. Previously, the helper function dispatched an OpFoldResult into staticVec, only if it was an IntegerAttr. The changes in this PR now enable the evaluation of arith.constant expressions, extraction of the integer value, and dispatch into staticVec.

@llvmbot
Copy link
Member

llvmbot commented Jan 10, 2025

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: None (rutkoor)

Changes

This PR addresses the handling of arith.constant expressions in the dispatchIndexOpFoldResult helper function. Previously, the helper function dispatched an OpFoldResult into staticVec only if it was an IntegerAttr. The changes in this PR now enable the evaluation of arith.constant expressions, extraction of the integer value, and dispatch into staticVec.


Full diff: https://github.com/llvm/llvm-project/pull/122432.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Utils/StaticValueUtils.cpp (+13)
  • (modified) mlir/test/Dialect/Tensor/bubble-reshapes.mlir (+20)
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 5c8f6ded39ba4e..7ad4c982af2aae 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/APSInt.h"
@@ -54,6 +55,18 @@ void dispatchIndexOpFoldResult(OpFoldResult ofr,
     staticVec.push_back(apInt.getSExtValue());
     return;
   }
+
+  Operation *definingOp = v.getDefiningOp();
+  if (definingOp) {
+    // Check if definingOp is an arith.constant
+    if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
+      if (auto intAttr = mlir::dyn_cast<IntegerAttr>(constantOp.getValue())) {
+        staticVec.push_back(intAttr.getValue().getSExtValue());
+        return;
+      }
+    }
+  }
+
   dynamicVec.push_back(v);
   staticVec.push_back(ShapedType::kDynamic);
 }
diff --git a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
index cf6b12852bcd39..15bc9b0435f6e6 100644
--- a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
+++ b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
@@ -20,6 +20,26 @@ func.func @bubble_parallel_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1:
 
 // -----
 
+func.func @bubble_parallel_reshapes2(%arg0: tensor<?x2x2x6xf32>, %s0: index, %s1: index) -> tensor<?x4x2x3xf32> {
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x2x2x6xf32> into tensor<?x4x6xf32>
+  %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
+              output_shape [%s0, %s1, %c2, %c3] : tensor<?x4x6xf32> into tensor<?x4x2x3xf32>
+  return %expand : tensor<?x4x2x3xf32>
+}
+//      CHECK: func @bubble_parallel_reshapes2
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x2x2x6xf32>
+// CHECK-SAME:   %[[S0:.+]]: index, %[[S1:.+]]: index
+//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//  CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
+//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2], [3, 4]]
+// CHECK-SAME:       output_shape [%[[S0]], 2, 2, %[[C2]], %[[C3]]] : tensor<?x2x2x6xf32> into tensor<?x2x2x2x3xf32>
+//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]] {{\[}}[0], [1, 2], [3], [4]] : tensor<?x2x2x2x3xf32> into tensor<?x4x2x3xf32>
+//      CHECK:   return %[[COLLAPSE]]
+
+// -----
+
 func.func @no_bubble_full_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
   %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
   %expand = tensor.expand_shape %collapse [[0], [1, 2], [3]]

@rutkoor
Copy link
Contributor Author

rutkoor commented Jan 10, 2025

cc: @javedabsar , @MaheshRavishankar

Copy link
Member

@matthias-springer matthias-springer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values); implements the functionality that you are looking for. You can use it with in combination with the existing dispatchIndexOpFoldResults.

@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file should not depend on any dialect.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. Thanks a lot for the suggestion.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I mean you can use dispatchIndexOpFoldResult(getAsOpFoldResult(v)) wherever you need it. I wouldn't call getAsOpFoldResult from dispatchIndexOpFoldResult because it does not fit with the name of the function. This function is just a switch that populates two vectors, it's not meant to analyze any IR.

Why do you need this functionality?

Change-Id: I15280932f88d8ff638f5d0f964a1c03ce7a7881a
%c3 = arith.constant 3 : index
%collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x2x2x6xf32> into tensor<?x4x6xf32>
%expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
output_shape [%s0, %s1, %c2, %c3] : tensor<?x4x6xf32> into tensor<?x4x2x3xf32>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not that familiar with this op anymore but I expected output_shape [%s0, 4, 2, 3].

Copy link
Member

@matthias-springer matthias-springer Jan 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I'm surprised that the verifier allows this op. @MaheshRavishankar to clarify.

@@ -54,6 +54,14 @@ void dispatchIndexOpFoldResult(OpFoldResult ofr,
staticVec.push_back(apInt.getSExtValue());
return;
}

OpFoldResult result = getAsOpFoldResult(v);
if (auto attr = result.dyn_cast<Attribute>()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you could dyn_cast the OpFoldResult directly as IntegerAttr.
saves forced cast on next line.
some thing like if (auto iattr = dyn_cast<IntegerAttr>(result)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants