Skip to content

Commit

Permalink
[mlir][Transforms][NFC] GreedyPatternRewriteDriver: Use composition…
Browse files Browse the repository at this point in the history
… instead of inheritance (llvm#92785)

This commit simplifies the design of the `GreedyPatternRewriterDriver`
class. This class used to inherit from both `PatternRewriter` and
`RewriterBase::Listener` and then attached itself as a listener.

In the new design, the class has a `PatternRewriter` field instead of
inheriting from `PatternRewriter`, which is generally perferred in
object-oriented programming.

---------

Co-authored-by: Markus Böck <markus.boeck02@gmail.com>
  • Loading branch information
matthias-springer and zero9178 authored Jun 8, 2024
1 parent 7f5d1f1 commit 6b3e000
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 18 deletions.
1 change: 1 addition & 0 deletions mlir/include/mlir/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -784,6 +784,7 @@ class IRRewriter : public RewriterBase {
/// place.
class PatternRewriter : public RewriterBase {
public:
explicit PatternRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
using RewriterBase::RewriterBase;

/// A hook used to indicate if the pattern rewriter can recover from failure
Expand Down
40 changes: 22 additions & 18 deletions mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,7 @@ class RandomizedWorklist : public Worklist {
/// This abstract class manages the worklist and contains helper methods for
/// rewriting ops on the worklist. Derived classes specify how ops are added
/// to the worklist in the beginning.
class GreedyPatternRewriteDriver : public PatternRewriter,
public RewriterBase::Listener {
class GreedyPatternRewriteDriver : public RewriterBase::Listener {
protected:
explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
const FrozenRewritePatternSet &patterns,
Expand All @@ -339,7 +338,8 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
/// Notify the driver that the specified operation was inserted. Update the
/// worklist as needed: The operation is enqueued depending on scope and
/// strict mode.
void notifyOperationInserted(Operation *op, InsertPoint previous) override;
void notifyOperationInserted(Operation *op,
OpBuilder::InsertPoint previous) override;

/// Notify the driver that the specified operation was removed. Update the
/// worklist as needed: The operation and its children are removed from the
Expand All @@ -354,6 +354,10 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
/// reached. Return `true` if any IR was changed.
bool processWorklist();

/// The pattern rewriter that is used for making IR modifications and is
/// passed to rewrite patterns.
PatternRewriter rewriter;

/// The worklist for this transformation keeps track of the operations that
/// need to be (re)visited.
#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
Expand Down Expand Up @@ -407,7 +411,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config)
: PatternRewriter(ctx), config(config), matcher(patterns)
: rewriter(ctx), config(config), matcher(patterns)
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// clang-format off
, expensiveChecks(
Expand All @@ -423,9 +427,9 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Send IR notifications to the debug handler. This handler will then forward
// all notifications to this GreedyPatternRewriteDriver.
setListener(&expensiveChecks);
rewriter.setListener(&expensiveChecks);
#else
setListener(this);
rewriter.setListener(this);
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
}

Expand Down Expand Up @@ -473,7 +477,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {

// If the operation is trivially dead - remove it.
if (isOpTriviallyDead(op)) {
eraseOp(op);
rewriter.eraseOp(op);
changed = true;

LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
Expand Down Expand Up @@ -505,8 +509,8 @@ bool GreedyPatternRewriteDriver::processWorklist() {
// Op results can be replaced with `foldResults`.
assert(foldResults.size() == op->getNumResults() &&
"folder produced incorrect number of results");
OpBuilder::InsertionGuard g(*this);
setInsertionPoint(op);
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(op);
SmallVector<Value> replacements;
bool materializationSucceeded = true;
for (auto [ofr, resultType] :
Expand All @@ -519,7 +523,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
}
// Materialize Attributes as SSA values.
Operation *constOp = op->getDialect()->materializeConstant(
*this, ofr.get<Attribute>(), resultType, op->getLoc());
rewriter, ofr.get<Attribute>(), resultType, op->getLoc());

if (!constOp) {
// If materialization fails, cleanup any operations generated for
Expand All @@ -532,7 +536,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
replacementOps.insert(replacement.getDefiningOp());
}
for (Operation *op : replacementOps) {
eraseOp(op);
rewriter.eraseOp(op);
}

materializationSucceeded = false;
Expand All @@ -547,7 +551,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
}

if (materializationSucceeded) {
replaceOp(op, replacements);
rewriter.replaceOp(op, replacements);
changed = true;
LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
Expand Down Expand Up @@ -608,7 +612,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS

LogicalResult matchResult =
matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
matcher.matchAndRewrite(op, rewriter, canApply, onFailure, onSuccess);

if (succeeded(matchResult)) {
LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
Expand Down Expand Up @@ -664,8 +668,8 @@ void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) {
config.listener->notifyBlockErased(block);
}

void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op,
InsertPoint previous) {
void GreedyPatternRewriteDriver::notifyOperationInserted(
Operation *op, OpBuilder::InsertPoint previous) {
LLVM_DEBUG({
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
<< ")\n";
Expand Down Expand Up @@ -822,7 +826,7 @@ class GreedyPatternRewriteIteration
LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
bool continueRewrites = false;
int64_t iteration = 0;
MLIRContext *ctx = getContext();
MLIRContext *ctx = rewriter.getContext();
do {
// Check if the iteration limit was reached.
if (++iteration > config.maxIterations &&
Expand All @@ -834,7 +838,7 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {

// `OperationFolder` CSE's constant ops (and may move them into parents
// regions to enable more aggressive CSE'ing).
OperationFolder folder(getContext(), this);
OperationFolder folder(ctx, this);
auto insertKnownConstant = [&](Operation *op) {
// Check for existing constants when populating the worklist. This avoids
// accidentally reversing the constant order during processing.
Expand Down Expand Up @@ -872,7 +876,7 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
// After applying patterns, make sure that the CFG of each of the
// regions is kept up to date.
if (config.enableRegionSimplification)
continueRewrites |= succeeded(simplifyRegions(*this, region));
continueRewrites |= succeeded(simplifyRegions(rewriter, region));
},
{&region}, iteration);
} while (continueRewrites);
Expand Down

0 comments on commit 6b3e000

Please sign in to comment.