Skip to content

Commit

Permalink
Avoid quadratic containsAll calls
Browse files Browse the repository at this point in the history
  • Loading branch information
wendigo committed Jan 22, 2025
1 parent b17fa4e commit e38277f
Show file tree
Hide file tree
Showing 20 changed files with 88 additions and 43 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.planner;

import com.google.common.collect.ImmutableSet;

import java.util.Collection;
import java.util.List;

public class SymbolUtils
{
private SymbolUtils() {}

public static boolean containsAll(List<Symbol> haystack, Collection<Symbol> needles)
{
return ImmutableSet.copyOf(haystack).containsAll(needles);
}

public static boolean containsNone(Collection<Symbol> values, Collection<Symbol> testValues)
{
return values.stream().noneMatch(ImmutableSet.copyOf(testValues)::contains);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand All @@ -88,6 +87,8 @@
import static io.trino.sql.ir.Comparison.Operator.LESS_THAN;
import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL;
import static io.trino.sql.planner.ExpressionNodeInliner.replaceExpression;
import static io.trino.sql.planner.SymbolUtils.containsAll;
import static io.trino.sql.planner.SymbolUtils.containsNone;
import static io.trino.sql.planner.SymbolsExtractor.extractUnique;
import static io.trino.sql.planner.plan.JoinType.INNER;
import static io.trino.sql.planner.plan.JoinType.LEFT;
Expand Down Expand Up @@ -304,7 +305,7 @@ private static Result tryCreateSpatialJoin(
// ST_Distance(a, b) <= r
radius = spatialComparison.right();
Set<Symbol> radiusSymbols = extractUnique(radius);
if (radiusSymbols.isEmpty() || (rightSymbols.containsAll(radiusSymbols) && containsNone(leftSymbols, radiusSymbols))) {
if (radiusSymbols.isEmpty() || containsAll(rightSymbols, radiusSymbols) && containsNone(leftSymbols, radiusSymbols)) {
newRadiusSymbol = newRadiusSymbol(context, radius);
newComparison = new Comparison(spatialComparison.operator(), spatialComparison.left(), toExpression(newRadiusSymbol, radius));
}
Expand All @@ -316,7 +317,7 @@ private static Result tryCreateSpatialJoin(
// r >= ST_Distance(a, b)
radius = spatialComparison.left();
Set<Symbol> radiusSymbols = extractUnique(radius);
if (radiusSymbols.isEmpty() || (rightSymbols.containsAll(radiusSymbols) && containsNone(leftSymbols, radiusSymbols))) {
if (radiusSymbols.isEmpty() || (containsAll(rightSymbols, radiusSymbols) && containsNone(leftSymbols, radiusSymbols))) {
newRadiusSymbol = newRadiusSymbol(context, radius);
newComparison = new Comparison(spatialComparison.operator().flip(), spatialComparison.right(), toExpression(newRadiusSymbol, radius));
}
Expand Down Expand Up @@ -529,16 +530,16 @@ private static int checkAlignment(JoinNode joinNode, Set<Symbol> maybeLeftSymbol
List<Symbol> leftSymbols = joinNode.getLeft().getOutputSymbols();
List<Symbol> rightSymbols = joinNode.getRight().getOutputSymbols();

if (leftSymbols.containsAll(maybeLeftSymbols)
if (containsAll(leftSymbols, maybeLeftSymbols)
&& containsNone(leftSymbols, maybeRightSymbols)
&& rightSymbols.containsAll(maybeRightSymbols)
&& containsAll(rightSymbols, maybeRightSymbols)
&& containsNone(rightSymbols, maybeLeftSymbols)) {
return 1;
}

if (leftSymbols.containsAll(maybeRightSymbols)
if (containsAll(leftSymbols, maybeRightSymbols)
&& containsNone(leftSymbols, maybeLeftSymbols)
&& rightSymbols.containsAll(maybeLeftSymbols)
&& containsAll(rightSymbols, maybeLeftSymbols)
&& containsNone(rightSymbols, maybeRightSymbols)) {
return -1;
}
Expand Down Expand Up @@ -606,9 +607,4 @@ private static PlanNode addPartitioningNodes(PlannerContext plannerContext, Cont
Optional.empty(),
INNER);
}

private static boolean containsNone(Collection<Symbol> values, Collection<Symbol> testValues)
{
return values.stream().noneMatch(ImmutableSet.copyOf(testValues)::contains);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.SystemSessionProperties.preferPartialAggregation;
import static io.trino.sql.planner.SymbolUtils.containsAll;
import static io.trino.sql.planner.plan.AggregationNode.Step.FINAL;
import static io.trino.sql.planner.plan.AggregationNode.Step.PARTIAL;
import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE;
Expand Down Expand Up @@ -123,7 +124,7 @@ public Result apply(AggregationNode aggregationNode, Captures captures, Context
.map(Partitioning.ArgumentBinding::getColumn)
.collect(Collectors.toList());

if (!aggregationNode.getGroupingKeys().containsAll(partitioningColumns)) {
if (!containsAll(aggregationNode.getGroupingKeys(), partitioningColumns)) {
return Result.empty();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Sets.intersection;
import static io.trino.SystemSessionProperties.isPushPartialAggregationThroughJoin;
import static io.trino.sql.planner.SymbolUtils.containsAll;
import static io.trino.sql.planner.iterative.rule.PushProjectionThroughJoin.pushProjectionThroughJoin;
import static io.trino.sql.planner.iterative.rule.Util.restrictOutputs;
import static io.trino.sql.planner.plan.AggregationNode.Step.INTERMEDIATE;
Expand Down Expand Up @@ -175,7 +176,7 @@ private static boolean allAggregationsOn(Map<Symbol, Aggregation> aggregations,
.map(SymbolsExtractor::extractAll)
.flatMap(List::stream)
.collect(toImmutableSet());
return symbols.containsAll(inputs);
return containsAll(symbols, inputs);
}

private Optional<PlanNode> pushPartialToLeftChild(AggregationNode node, JoinNode child, Context context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.sql.planner.SymbolUtils.containsAll;
import static io.trino.sql.planner.SymbolsExtractor.extractUnique;
import static io.trino.sql.planner.plan.JoinType.INNER;

Expand Down Expand Up @@ -69,11 +70,11 @@ public static Optional<PlanNode> pushProjectionThroughJoin(
for (Map.Entry<Symbol, Expression> assignment : projectNode.getAssignments().entrySet()) {
Expression expression = assignment.getValue();
Set<Symbol> symbols = extractUnique(expression);
if (leftChild.getOutputSymbols().containsAll(symbols)) {
if (containsAll(leftChild.getOutputSymbols(), symbols)) {
// expression is satisfied with left child symbols
leftAssignmentsBuilder.put(assignment.getKey(), expression);
}
else if (rightChild.getOutputSymbols().containsAll(symbols)) {
else if (containsAll(rightChild.getOutputSymbols(), symbols)) {
// expression is satisfied with right child symbols
rightAssignmentsBuilder.put(assignment.getKey(), expression);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
import static io.trino.sql.planner.DeterminismEvaluator.isDeterministic;
import static io.trino.sql.planner.EqualityInference.isInferenceCandidate;
import static io.trino.sql.planner.OptimizerConfig.JoinReorderingStrategy.AUTOMATIC;
import static io.trino.sql.planner.SymbolUtils.containsAll;
import static io.trino.sql.planner.SymbolsExtractor.extractAll;
import static io.trino.sql.planner.SymbolsExtractor.extractUnique;
import static io.trino.sql.planner.iterative.rule.DetermineJoinDistributionType.canReplicate;
Expand Down Expand Up @@ -518,7 +519,7 @@ static class MultiJoinNode
this.pushedProjectionThroughJoin = pushedProjectionThroughJoin;

List<Symbol> inputSymbols = sources.stream().flatMap(source -> source.getOutputSymbols().stream()).collect(toImmutableList());
checkArgument(inputSymbols.containsAll(outputSymbols), "inputs do not contain all output symbols");
checkArgument(containsAll(inputSymbols, outputSymbols), "inputs do not contain all output symbols");
}

public Expression getFilter()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
import static io.trino.SystemSessionProperties.isDistributedSortEnabled;
import static io.trino.SystemSessionProperties.isSpillEnabled;
import static io.trino.SystemSessionProperties.isTaskScaleWritersEnabled;
import static io.trino.sql.planner.SymbolUtils.containsAll;
import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION;
import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION;
import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_HASH_DISTRIBUTION;
Expand Down Expand Up @@ -609,7 +610,7 @@ private List<Symbol> pruneMarkDistinctSymbols(MarkDistinctNode node, List<LocalP
if (property instanceof ConstantProperty) {
redundantSymbolsBuilder.add(((ConstantProperty<Symbol>) property).getColumn());
}
else if (!node.getDistinctSymbols().containsAll(property.getColumns())) {
else if (!containsAll(node.getDistinctSymbols(), property.getColumns())) {
// Ran into a non-distinct symbol. There will be no more symbols that are functionally dependent on distinct symbols exclusively.
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static io.trino.metadata.OperatorNameUtil.mangleOperatorName;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.sql.planner.SymbolUtils.containsAll;
import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION;
import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_HASH_DISTRIBUTION;
import static io.trino.sql.planner.plan.ChildReplacer.replaceChildren;
Expand Down Expand Up @@ -760,7 +761,7 @@ private PlanWithProperties plan(PlanNode node, HashComputationSet parentPreferen
{
PlanWithProperties result = node.accept(this, parentPreference);
checkState(
result.getNode().getOutputSymbols().containsAll(result.getHashSymbols().values()),
containsAll(result.getNode().getOutputSymbols(), result.getHashSymbols().values()),
"Node %s declares hash symbols not in the output",
result.getNode().getClass().getSimpleName());
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import static io.trino.spi.function.FunctionKind.AGGREGATE;
import static io.trino.sql.ir.Booleans.TRUE;
import static io.trino.sql.ir.IrUtils.combineConjuncts;
import static io.trino.sql.planner.SymbolUtils.containsAll;
import static io.trino.sql.planner.plan.WindowFrameType.RANGE;
import static java.util.Objects.requireNonNull;
import static java.util.function.Function.identity;
Expand Down Expand Up @@ -288,7 +289,7 @@ private PlanNode planTableScan(TableScanNode node, Expression predicate, Context
.transformKeys(node.getAssignments()::get)
.intersect(node.getEnforcedConstraint());

checkState(node.getOutputSymbols().containsAll(context.getLookupSymbols()));
checkState(containsAll(node.getOutputSymbols(), context.getLookupSymbols()));

Set<ColumnHandle> lookupColumns = context.getLookupSymbols().stream()
.map(node.getAssignments()::get)
Expand Down Expand Up @@ -375,8 +376,7 @@ public PlanNode visitWindow(WindowNode node, RewriteContext<Context> context)
}

// Lookup symbols can only be passed through if they are part of the partitioning

if (!node.getPartitionBy().containsAll(context.get().getLookupSymbols())) {
if (!containsAll(node.getPartitionBy(), context.get().getLookupSymbols())) {
return node;
}

Expand All @@ -393,7 +393,7 @@ public PlanNode visitIndexSource(IndexSourceNode node, RewriteContext<Context> c
public PlanNode visitIndexJoin(IndexJoinNode node, RewriteContext<Context> context)
{
// Lookup symbols can only be passed through the probe side of an index join
if (!node.getProbeSource().getOutputSymbols().containsAll(context.get().getLookupSymbols())) {
if (!containsAll(node.getProbeSource().getOutputSymbols(), context.get().getLookupSymbols())) {
return node;
}

Expand All @@ -411,7 +411,7 @@ public PlanNode visitIndexJoin(IndexJoinNode node, RewriteContext<Context> conte
public PlanNode visitAggregation(AggregationNode node, RewriteContext<Context> context)
{
// Lookup symbols can only be passed through if they are part of the group by columns
if (!node.getGroupingKeys().containsAll(context.get().getLookupSymbols())) {
if (!containsAll(node.getGroupingKeys(), context.get().getLookupSymbols())) {
return node;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.Optional;

import static com.google.common.base.MoreObjects.toStringHelper;
import static io.trino.sql.planner.SymbolUtils.containsAll;
import static java.util.Objects.requireNonNull;

public class LimitPushDown
Expand Down Expand Up @@ -132,7 +133,7 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext<LimitConte
node.getAggregations().isEmpty() &&
!node.getGroupingKeys().isEmpty() &&
node.getOutputSymbols().size() == node.getGroupingKeys().size() &&
node.getOutputSymbols().containsAll(node.getGroupingKeys())) {
containsAll(node.getOutputSymbols(), node.getGroupingKeys())) {
PlanNode rewrittenSource = context.rewrite(node.getSource());
return new DistinctLimitNode(idAllocator.getNextId(), rewrittenSource, limit.getCount(), false, rewrittenSource.getOutputSymbols(), Optional.empty());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
import static io.trino.sql.planner.DeterminismEvaluator.isDeterministic;
import static io.trino.sql.planner.EqualityInference.isInferenceCandidate;
import static io.trino.sql.planner.ExpressionSymbolInliner.inlineSymbols;
import static io.trino.sql.planner.SymbolUtils.containsAll;
import static io.trino.sql.planner.SymbolsExtractor.extractUnique;
import static io.trino.sql.planner.iterative.rule.CanonicalizeExpressionRewriter.canonicalizeExpression;
import static io.trino.sql.planner.iterative.rule.UnwrapCastInComparison.unwrapCasts;
Expand Down Expand Up @@ -239,8 +240,7 @@ public PlanNode visitWindow(WindowNode node, RewriteContext<Expression> context)
// function is injective, but that's a rare case. The majority of window nodes are expected to be partitioned by
// pre-projected symbols.
Predicate<Expression> isSupported = conjunct ->
isDeterministic(conjunct) &&
partitionSymbols.containsAll(extractUnique(conjunct));
isDeterministic(conjunct) && containsAll(partitionSymbols, extractUnique(conjunct));

Map<Boolean, List<Expression>> conjuncts = extractConjuncts(context.get()).stream().collect(Collectors.partitioningBy(isSupported));

Expand All @@ -260,8 +260,7 @@ public PlanNode visitTopNRanking(TopNRankingNode node, RewriteContext<Expression

// TODO: This could be broader. See the comment in visitWindow().
Predicate<Expression> isSupported = conjunct ->
isDeterministic(conjunct) &&
partitionSymbols.containsAll(extractUnique(conjunct));
isDeterministic(conjunct) && containsAll(partitionSymbols, extractUnique(conjunct));

Map<Boolean, List<Expression>> conjuncts = extractConjuncts(context.get()).stream().collect(Collectors.partitioningBy(isSupported));

Expand Down Expand Up @@ -500,7 +499,7 @@ public PlanNode visitJoin(JoinNode node, RewriteContext<Expression> context)
if (joinEqualityExpression(conjunct, node.getLeft().getOutputSymbols(), node.getRight().getOutputSymbols())) {
Comparison equality = (Comparison) conjunct;

boolean alignedComparison = node.getLeft().getOutputSymbols().containsAll(extractUnique(equality.left()));
boolean alignedComparison = containsAll(node.getLeft().getOutputSymbols(), extractUnique(equality.left()));
Expression leftExpression = alignedComparison ? equality.left() : equality.right();
Expression rightExpression = alignedComparison ? equality.right() : equality.left();

Expand Down Expand Up @@ -625,7 +624,7 @@ private DynamicFiltersResult createDynamicFilters(
Comparison comparison = expression.getComparison();
Expression leftExpression = comparison.left();
Expression rightExpression = comparison.right();
boolean alignedComparison = node.getLeft().getOutputSymbols().containsAll(extractUnique(leftExpression));
boolean alignedComparison = containsAll(node.getLeft().getOutputSymbols(), extractUnique(leftExpression));
return new DynamicFilterExpression(
new Comparison(
alignedComparison ? comparison.operator() : comparison.operator().flip(),
Expand Down
Loading

0 comments on commit e38277f

Please sign in to comment.