From caca78e4393b6bf7b246b630a5964949d7784651 Mon Sep 17 00:00:00 2001 From: Alex Reinking Date: Mon, 27 Jan 2025 17:31:36 -0500 Subject: [PATCH] Dequalify names when constructing RVars in rfactor --- src/Func.cpp | 33 +++++++++++++++++++++++++++------ src/Reduction.h | 5 +++++ 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/src/Func.cpp b/src/Func.cpp index e87b14c89c9a..9d9db3da9cab 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -671,6 +671,25 @@ void add_let(SubstitutionMap &subst, const string &name, const Expr &value) { subst.emplace(name, value); } +string dequalify(string name) { + if (const auto it = name.rfind('.'); it != string::npos) { + return name.substr(it + 1); + } + return name; +} + +vector subst_dims(const SubstitutionMap &substitution_map, const vector &dims) { + auto new_dims = dims; + for (auto &dim : new_dims) { + if (const auto it = substitution_map.find(dim.var); it != substitution_map.end()) { + const Variable *new_var = it->second.as(); + internal_assert(new_var); + dim.var = new_var->name; + } + } + return new_dims; +} + pair project_rdom(const vector &dims, const ReductionDomain &rdom, const vector &splits) { // The bounds projections maps expressions that reference the old RDom // bounds to expressions that reference the new RDom bounds (from dims). @@ -694,7 +713,7 @@ pair project_rdom(const vector &dims, con for (const Dim &dim : dims) { const Expr new_min = simplify(bounds_projection.at(dim.var + ".loop_min")); const Expr new_extent = simplify(bounds_projection.at(dim.var + ".loop_extent")); - new_rvars.push_back(ReductionVariable{dim.var, new_min, new_extent}); + new_rvars.push_back(ReductionVariable{dequalify(dim.var), new_min, new_extent}); } ReductionDomain new_rdom{new_rvars}; new_rdom.where(rdom.predicate()); @@ -730,8 +749,8 @@ pair project_rdom(const vector &dims, con } } } - for (const auto &rv : new_rdom.domain()) { - add_let(dim_projection, rv.var, Variable::make(Int(32), rv.var, new_rdom)); + for (size_t i = 0; i < new_rdom.domain().size(); i++) { + add_let(dim_projection, dims[i].var, RVar(new_rdom, i)); } return {new_rdom, dim_projection}; } @@ -902,7 +921,9 @@ Func Stage::rfactor(const vector> &preserved) { // Preserved std::tie(preserved_rdom, preserved_map) = project_rdom(preserved_rdims, rdom, rvar_splits); Scope intm_rdom; - for (const auto &[var, min, extent] : intermediate_rdom.domain()) { + for (size_t i = 0; i < intermediate_rdom.domain().size(); i++) { + const auto &var = intermediate_rdims[i].var; + const auto &[_, min, extent] = intermediate_rdom.domain()[i]; intm_rdom.push(var, Interval{min, min + extent - 1}); } preserved_rdom.set_predicate(or_condition_over_domain(substitute(preserved_map, preserved_rdom.predicate()), intm_rdom)); @@ -960,7 +981,7 @@ Func Stage::rfactor(const vector> &preserved) { } intm.function().update(0).schedule() = definition.schedule().get_copy(); - intm.function().update(0).schedule().dims() = std::move(intm_dims); + intm.function().update(0).schedule().dims() = subst_dims(intermediate_map, intm_dims); intm.function().update(0).schedule().rvars() = intermediate_rdom.domain(); intm.function().update(0).schedule().splits() = var_splits; } @@ -1022,7 +1043,7 @@ Func Stage::rfactor(const vector> &preserved) { definition.args() = dim_vars_exprs; definition.values() = substitute(preserved_map, prover_result.pattern.ops); definition.predicate() = preserved_rdom.predicate(); - definition.schedule().dims() = std::move(reducing_dims); + definition.schedule().dims() = subst_dims(preserved_map, reducing_dims); definition.schedule().rvars() = preserved_rdom.domain(); definition.schedule().splits() = var_splits; } diff --git a/src/Reduction.h b/src/Reduction.h index d36e803b9f5a..d93bf741cd09 100644 --- a/src/Reduction.h +++ b/src/Reduction.h @@ -14,7 +14,12 @@ class IRMutator; /** A single named dimension of a reduction domain */ struct ReductionVariable { + /** + * A variable name for the reduction variable. This name must be a + * valid Var name, i.e. it must not contain a . character. + */ std::string var; + Expr min, extent; /** This lets you use a ReductionVariable as a key in a map of the form