✨ Have a killer question on optimisation strategies #2297
joocer committed Jan 28, 2025
1 parent 04c70cc commit 91ded4d
Showing 20 changed files with 69,397 additions and 59 deletions.
5 changes: 1 addition & 4 deletions opteryx/functions/string_functions.py
@@ -357,16 +357,13 @@ def match_against(arr, val):


def regex_replace(array, _pattern, _replacement):
import re

from opteryx.third_party.mrabarnett import regex as re
import pyarrow

pattern = _pattern[0]
replacement = _replacement[0]
compiled_pattern = re.compile(pattern)

return compute.replace_substring_regex(array, pattern, replacement)

# Apply the regex replacement to each element in the array
vectorized_replace = numpy.vectorize(lambda x: compiled_pattern.sub(replacement, x))
return pyarrow.array(vectorized_replace(array), type=pyarrow.string())
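
For context, a runnable sketch of the vectorised replacement path shown in this hunk, using the stdlib re module in place of the vendored opteryx.third_party.mrabarnett regex (which exposes the same compile/sub interface); the function name is mine, not the commit's:

import re

import numpy
import pyarrow


def regex_replace_sketch(array, pattern, replacement):
    # Compile once, then apply the substitution element-wise.
    # numpy.vectorize is a convenience loop, not a true vectorised kernel:
    # it trades speed for the richer regex dialect that
    # pyarrow.compute.replace_substring_regex does not support.
    compiled_pattern = re.compile(pattern)
    vectorized_replace = numpy.vectorize(lambda x: compiled_pattern.sub(replacement, x))
    return pyarrow.array(vectorized_replace(array), type=pyarrow.string())


print(regex_replace_sketch(["a1c", "a2c"], r"a(\d)c", r"\1"))  # -> ["1", "2"]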
99 changes: 55 additions & 44 deletions opteryx/operators/simple_aggregate_and_group_node.py
@@ -12,6 +12,8 @@
and are collected as a single value.
"""

import time

import numpy
import pyarrow
from orso.types import OrsoTypes
@@ -21,13 +23,39 @@
from opteryx.managers.expression import evaluate_and_append
from opteryx.managers.expression import get_all_nodes_of_type
from opteryx.models import QueryProperties
from opteryx.operators.aggregate_node import AGGREGATORS
from opteryx.operators.aggregate_node import build_aggregations
from opteryx.operators.aggregate_node import extract_evaluations
from opteryx.operators.aggregate_node import project

from . import BasePlanNode


def build_finalizer_aggregations(aggregators):
column_map = {}
aggs = []

if not isinstance(aggregators, list):
aggregators = [aggregators]

for root in aggregators:
for aggregator in get_all_nodes_of_type(root, select_nodes=(NodeType.AGGREGATOR,)):
count_options = None

field_name = aggregator.schema_column.identity
if aggregator.value == "COUNT":
function = AGGREGATORS["SUM"]
else:
function = AGGREGATORS[aggregator.value]
# if the array agg is distinct, base off that function instead
aggs.append((field_name, function, count_options))
column_map[aggregator.schema_column.identity] = f"{field_name}_{function}".replace(
"_hash_", "_"
)

return column_map, aggs
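
A note on why COUNT maps to SUM here (my gloss on the two-phase scheme, not text from the commit): each morsel is pre-aggregated independently, so the finalizer combines partial results, and partial counts combine by summing. A minimal pyarrow illustration:

import pyarrow

morsel_1 = pyarrow.table({"g": ["a", "a", "b"], "x": [1, 2, 3]})
morsel_2 = pyarrow.table({"g": ["a", "b", "b"], "x": [4, 5, 6]})

# Phase 1: per-morsel COUNT
partials = [m.group_by(["g"]).aggregate([("x", "count")]) for m in (morsel_1, morsel_2)]

# Phase 2 (finalizer): SUM the partial counts to obtain the global COUNT
combined = pyarrow.concat_tables(partials, promote_options="permissive")
final = combined.group_by(["g"]).aggregate([("x_count", "sum")])
print(final)  # g="a" -> 3, g="b" -> 3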


class SimpleAggregateAndGroupNode(BasePlanNode):
SIMPLE_AGGREGATES = {"SUM", "MIN", "MAX", "COUNT"}

@@ -63,7 +91,11 @@ def __init__(self, properties: QueryProperties, **parameters):
self.group_by_columns = list({node.schema_column.identity for node in self.groups})
self.column_map, self.aggregate_functions = build_aggregations(self.aggregates)

self.accumulator = {}
self.finalizer_map, self.finalizer_aggregations = build_finalizer_aggregations(
self.aggregates
)

self.buffer = []

@property
def config(self): # pragma: no cover
@@ -76,20 +108,23 @@ def name(self): # pragma: no cover
return "Group By Simple"

def execute(self, morsel: pyarrow.Table, **kwargs):
internal_names = list(self.column_map.values()) + self.group_by_columns
column_names = list(self.column_map.keys()) + self.group_by_columns

if morsel == EOS:
py_dict = {}
for k, v in self.accumulator.items():
for i, group in enumerate(self.group_by_columns):
if group not in py_dict:
py_dict[group] = [k[i]]
else:
py_dict[group].append(k[i])
for column_id, value in v.items():
if column_id not in py_dict:
py_dict[column_id] = [value]
else:
py_dict[column_id].append(value)
yield pyarrow.Table.from_pydict(py_dict)
start = time.monotonic_ns()

internal_names = list(self.finalizer_map.values()) + self.group_by_columns
column_names = list(self.finalizer_map.keys()) + self.group_by_columns

groups = pyarrow.concat_tables(self.buffer, promote_options="permissive")
groups = groups.group_by(self.group_by_columns)
groups = groups.aggregate(self.finalizer_aggregations)
groups = groups.select(internal_names)
groups = groups.rename_columns(column_names)

self.statistics.time_groupby_finalize += time.monotonic_ns() - start
yield groups
yield EOS
return
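
The select/rename pair is needed because pyarrow names an aggregate output "<column>_<function>"; the node selects those internal names and renames them back to the planner's column identities. A small round-trip sketch (the column names are hypothetical):

import pyarrow

t = pyarrow.table({"g": ["a", "b", "a"], "col_123": [1, 2, 3]})
grouped = t.group_by(["g"]).aggregate([("col_123", "sum")])
assert "col_123_sum" in grouped.column_names  # pyarrow's internal name
renamed = grouped.select(["col_123_sum", "g"]).rename_columns(["col_123", "g"])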

@@ -108,37 +143,13 @@ def execute(self, morsel: pyarrow.Table, **kwargs):
)

# use pyarrow to do phase 1 of the group by
st = time.monotonic_ns()
groups = morsel.group_by(self.group_by_columns)
groups = groups.aggregate(self.aggregate_functions)
# project to the desired column names from the pyarrow names
groups = groups.select(list(self.column_map.values()) + self.group_by_columns)
groups = groups.rename_columns(list(self.column_map.keys()) + self.group_by_columns)

# we now merge the results into the accumulator
for row in groups.to_pylist():
for aggregate in self.aggregates:
column_id = aggregate.schema_column.identity
value = row[column_id]
groups = tuple(row[group] for group in self.group_by_columns)

if groups not in self.accumulator:
self.accumulator[groups] = {}

if self.accumulator[groups].get(column_id) is None:
self.accumulator[groups][column_id] = value
elif aggregate.value == "COUNT" or aggregate.value == "SUM":
self.accumulator[groups][column_id] += value
elif aggregate.value == "MIN":
self.accumulator[groups][column_id] = min(
self.accumulator[groups][column_id], value
)
elif aggregate.value == "MAX":
self.accumulator[groups][column_id] = max(
self.accumulator[groups][column_id], value
)
else:
raise NotImplementedError(
f"SimpleAggregateAndGroupNode does not support {aggregate.value}"
)
groups = groups.select(internal_names)
groups = groups.rename_columns(column_names)
self.statistics.time_pregrouping += time.monotonic_ns() - st

self.buffer.append(groups)

yield None
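
Taken together, the rework swaps the old row-at-a-time Python accumulator for buffered, fully vectorised pyarrow aggregation. The operator's streaming shape, as I read the diff (a condensed sketch; statistics plumbing omitted):

def execute_sketch(self, morsel):
    if morsel == EOS:
        # final call: merge the buffered partials and re-aggregate
        combined = pyarrow.concat_tables(self.buffer, promote_options="permissive")
        yield combined.group_by(self.group_by_columns).aggregate(self.finalizer_aggregations)
        yield EOS
        return
    # per-morsel call: pre-aggregate, buffer, emit nothing yet
    self.buffer.append(
        morsel.group_by(self.group_by_columns).aggregate(self.aggregate_functions)
    )
    yield None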
3 changes: 3 additions & 0 deletions opteryx/planner/optimizer/strategies/constant_folding.py
@@ -259,6 +259,8 @@ def fold_constants(root: Node, statistics: QueryStatistics) -> Node:

# fold constants in function parameters - this is generally aggregations we're affecting here
if root.parameters:
if isinstance(root.parameters, tuple):
root.parameters = list(root.parameters)
for i, param in enumerate(root.parameters):
root.parameters[i] = fold_constants(param, statistics)
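
The tuple-to-list coercion added above is required because tuples are immutable, so writing a folded parameter back in place would raise. Illustrative only:

params = ("a", "b")
try:
    params[0] = "folded"      # TypeError: 'tuple' object does not support item assignment
except TypeError:
    params = list(params)     # coerce first, as the hunk does
    params[0] = "folded"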

@@ -322,6 +324,7 @@ def visit(self, node: LogicalPlanNode, context: OptimizerContext) -> OptimizerContext:

if node.node_type == LogicalPlanStepType.AggregateAndGroup:
node.groups = [g.centre if g.node_type == NodeType.NESTED else g for g in node.groups]
node.groups = [fold_constants(g, self.statistics) for g in node.groups]
context.optimized_plan[context.node_id] = node

return context
4 changes: 4 additions & 0 deletions opteryx/planner/optimizer/strategies/optimization_strategy.py
@@ -31,6 +31,7 @@ def __init__(self, tree: LogicalPlan):
self.seen_projections: int = 0
self.seen_unions: int = 0
self.seen_distincts: int = 0
self.seen_projects_since_distinct: int = 0

self.collected_predicates: list = []
"""We collect predicates we should be able to push to reads and joins"""
@@ -44,6 +45,9 @@ def __init__(self, tree: LogicalPlan):
self.collected_limits: list = []
"""We collect limits to to to eliminate rows earlier"""

self.distincted_indentities: set = set()
"""The columns that implicitly exist in the plan because of a distinct"""


class OptimizationStrategy:
def __init__(self, statistics):
9 changes: 1 addition & 8 deletions opteryx/planner/optimizer/strategies/projection_pushdown.py
@@ -48,20 +48,13 @@ def visit(self, node: LogicalPlanNode, context: OptimizerContext) -> OptimizerContext:
# If we're at something other than the top project (e.g. in a subquery) in a plan we
# may be able to remove some columns (and potentially some evaluations) if the columns
# aren't referenced in the outer query.
if node.node_type == LogicalPlanStepType.Distinct:
context.seen_distincts += 1
if node.node_type == LogicalPlanStepType.Union:
context.seen_unions += 1
if node.node_type == LogicalPlanStepType.Project:
if (
context.seen_distincts == 0
and context.seen_unions == 0
and context.seen_projections > 0
):
if context.seen_unions == 0 and context.seen_projections > 0:
node.columns = [
n for n in node.columns if n.schema_column.identity in node.pre_update_columns
]
self.seen_distincts = 0
if context.seen_unions == 0:
context.seen_projections += 1

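For background on why DISTINCT used to block this pushdown (my explanation, not the commit's): DISTINCT de-duplicates whole rows, so pruning columns beneath it can change which rows survive; the new seen_projects_since_distinct and distincted_indentities state in OptimizerContext presumably lets the optimizer reason about this more precisely. A toy illustration of the hazard:

rows = [("a", 1), ("a", 2)]
assert len(set(rows)) == 2             # DISTINCT over full rows keeps both
assert len({r[0] for r in rows}) == 1  # prune a column first and they collapse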
2 changes: 1 addition & 1 deletion opteryx/planner/physical_planner.py
@@ -26,7 +26,7 @@ def create_physical_plan(logical_plan, query_properties) -> PhysicalPlan:
else:
node = operators.AggregateNode(query_properties, **{k:v for k,v in node_config.items() if k in ("aggregates", "all_relations")})
elif node_type == LogicalPlanStepType.AggregateAndGroup:
if False and all(agg.value in operators.SimpleAggregateAndGroupNode.SIMPLE_AGGREGATES for agg in node_config["aggregates"]):
if all(agg.value in operators.SimpleAggregateAndGroupNode.SIMPLE_AGGREGATES and agg.duplicate_treatment != "Distinct" for agg in node_config["aggregates"]):
node = operators.SimpleAggregateAndGroupNode(query_properties, **{k:v for k,v in node_config.items() if k in ("aggregates", "groups", "projection", "all_relations")})
else:
node = operators.AggregateAndGroupNode(query_properties, **{k:v for k,v in node_config.items() if k in ("aggregates", "groups", "projection", "all_relations")})
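
The old guard began with "False and", which permanently disabled the simple path; this commit switches it on and additionally excludes DISTINCT-qualified aggregates. The dispatch rule, restated as a standalone sketch (attribute names taken from the diff):

SIMPLE_AGGREGATES = {"SUM", "MIN", "MAX", "COUNT"}

def use_simple_path(aggregates) -> bool:
    # Fast path only when every aggregate is simple and not DISTINCT
    return all(
        agg.value in SIMPLE_AGGREGATES and agg.duplicate_treatment != "Distinct"
        for agg in aggregates
    )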
11 changes: 10 additions & 1 deletion opteryx/planner/sql_rewriter.py
@@ -143,13 +143,22 @@ def sql_parts(string):
)
# Match ", ', b", b', r", r', `
# We match b and r prefixes separately after the non-prefix versions
quoted_strings = re.compile(r'("[^"]*"|\'[^\']*\'|\b[bB]"[^"]*"|\b[bB]\'[^\']*\'|`[^`]*`)')
quoted_strings = re.compile(
r'("[^"]*"|\'[^\']*\'|\b[bB]"[^"]*"|\b[bB]\'[^\']*\'|\b[rR]"[^"]*"|\b[rR]\'[^\']*\'|`[^`]*`)'
)

parts = []
for part in quoted_strings.split(string):
if part and part[-1] in ("'", '"', "`"):
if part[0] in ("b", "B"):
parts.append(f"blob({part[1:]})")
elif part[0] in ("r", "R"):
# We take the raw string and encode it, pass it into the
# plan as the encoded string and let the engine decode it
from base64 import b85encode

encoded_part = b85encode(part[2:-1].encode()).decode()
parts.append(f"BASE85_DECODE('{encoded_part}')")
else:
parts.append(part)
else:
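A round-trip of the raw-string handling added above (illustrative; BASE85_DECODE is the engine-side function the rewriter targets): the raw literal is base85-encoded so its backslashes survive parsing untouched, and decoding restores it at execution time.

from base64 import b85decode, b85encode

part = r"r'\d+\.\d{2}'"  # token as matched: r-prefix and quotes intact
encoded = b85encode(part[2:-1].encode()).decode()
rewritten = f"BASE85_DECODE('{encoded}')"
assert b85decode(encoded).decode() == r"\d+\.\d{2}"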
3 changes: 3 additions & 0 deletions opteryx/third_party/mrabarnett/regex/__init__.py
@@ -0,0 +1,3 @@
from .regex import *
from . import regex
__all__ = regex.__all__