Skip to content

Commit

Permalink
Dequalify names when constructing RVars in rfactor
Browse files Browse the repository at this point in the history
  • Loading branch information
alexreinking committed Jan 28, 2025
1 parent 56198ad commit caca78e
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 6 deletions.
33 changes: 27 additions & 6 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,25 @@ void add_let(SubstitutionMap &subst, const string &name, const Expr &value) {
subst.emplace(name, value);
}

string dequalify(string name) {
if (const auto it = name.rfind('.'); it != string::npos) {
return name.substr(it + 1);
}
return name;
}

vector<Dim> subst_dims(const SubstitutionMap &substitution_map, const vector<Dim> &dims) {
auto new_dims = dims;
for (auto &dim : new_dims) {
if (const auto it = substitution_map.find(dim.var); it != substitution_map.end()) {
const Variable *new_var = it->second.as<Variable>();
internal_assert(new_var);
dim.var = new_var->name;
}
}
return new_dims;
}

pair<ReductionDomain, SubstitutionMap> project_rdom(const vector<Dim> &dims, const ReductionDomain &rdom, const vector<Split> &splits) {
// The bounds projections maps expressions that reference the old RDom
// bounds to expressions that reference the new RDom bounds (from dims).
Expand All @@ -694,7 +713,7 @@ pair<ReductionDomain, SubstitutionMap> project_rdom(const vector<Dim> &dims, con
for (const Dim &dim : dims) {
const Expr new_min = simplify(bounds_projection.at(dim.var + ".loop_min"));
const Expr new_extent = simplify(bounds_projection.at(dim.var + ".loop_extent"));
new_rvars.push_back(ReductionVariable{dim.var, new_min, new_extent});
new_rvars.push_back(ReductionVariable{dequalify(dim.var), new_min, new_extent});
}
ReductionDomain new_rdom{new_rvars};
new_rdom.where(rdom.predicate());
Expand Down Expand Up @@ -730,8 +749,8 @@ pair<ReductionDomain, SubstitutionMap> project_rdom(const vector<Dim> &dims, con
}
}
}
for (const auto &rv : new_rdom.domain()) {
add_let(dim_projection, rv.var, Variable::make(Int(32), rv.var, new_rdom));
for (size_t i = 0; i < new_rdom.domain().size(); i++) {
add_let(dim_projection, dims[i].var, RVar(new_rdom, i));
}
return {new_rdom, dim_projection};
}
Expand Down Expand Up @@ -902,7 +921,9 @@ Func Stage::rfactor(const vector<pair<RVar, Var>> &preserved) {
// Preserved
std::tie(preserved_rdom, preserved_map) = project_rdom(preserved_rdims, rdom, rvar_splits);
Scope<Interval> intm_rdom;
for (const auto &[var, min, extent] : intermediate_rdom.domain()) {
for (size_t i = 0; i < intermediate_rdom.domain().size(); i++) {
const auto &var = intermediate_rdims[i].var;
const auto &[_, min, extent] = intermediate_rdom.domain()[i];
intm_rdom.push(var, Interval{min, min + extent - 1});
}
preserved_rdom.set_predicate(or_condition_over_domain(substitute(preserved_map, preserved_rdom.predicate()), intm_rdom));
Expand Down Expand Up @@ -960,7 +981,7 @@ Func Stage::rfactor(const vector<pair<RVar, Var>> &preserved) {
}

intm.function().update(0).schedule() = definition.schedule().get_copy();
intm.function().update(0).schedule().dims() = std::move(intm_dims);
intm.function().update(0).schedule().dims() = subst_dims(intermediate_map, intm_dims);
intm.function().update(0).schedule().rvars() = intermediate_rdom.domain();
intm.function().update(0).schedule().splits() = var_splits;
}
Expand Down Expand Up @@ -1022,7 +1043,7 @@ Func Stage::rfactor(const vector<pair<RVar, Var>> &preserved) {
definition.args() = dim_vars_exprs;
definition.values() = substitute(preserved_map, prover_result.pattern.ops);
definition.predicate() = preserved_rdom.predicate();
definition.schedule().dims() = std::move(reducing_dims);
definition.schedule().dims() = subst_dims(preserved_map, reducing_dims);
definition.schedule().rvars() = preserved_rdom.domain();
definition.schedule().splits() = var_splits;
}
Expand Down
5 changes: 5 additions & 0 deletions src/Reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@ class IRMutator;

/** A single named dimension of a reduction domain */
struct ReductionVariable {
/**
* A variable name for the reduction variable. This name must be a
* valid Var name, i.e. it must not contain a <tt>.</tt> character.
*/
std::string var;

Expr min, extent;

/** This lets you use a ReductionVariable as a key in a map of the form
Expand Down

0 comments on commit caca78e

Please sign in to comment.