Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
joocer committed Jan 14, 2025
1 parent 8c49b93 commit 5da28d2
Show file tree
Hide file tree
Showing 15 changed files with 277 additions and 6 deletions.
10 changes: 7 additions & 3 deletions opteryx/connectors/disk_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,19 +190,23 @@ def read_dataset(
blob_name=blob_name,
statistics=self.statistics,
decoder=decoder,
just_schema=just_schema,
just_schema=False,
projection=columns,
selection=predicates,
)
self.statistics.rows_seen += num_rows
yield decoded
else:
yield read_blob(
schema = read_blob(
blob_name=blob_name,
statistics=self.statistics,
decoder=decoder,
just_schema=just_schema,
just_schema=True,
)
if schema.row_count_metric:
schema.row_count_metric *= len(blob_names)
self.statistics.estimated_row_count += schema.row_count_metric
yield schema

except UnsupportedFileTypeError:
pass # Skip unsupported file types
Expand Down
12 changes: 11 additions & 1 deletion opteryx/connectors/gcp_cloudstorage_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from opteryx.connectors.capabilities import Cacheable
from opteryx.connectors.capabilities import Partitionable
from opteryx.connectors.capabilities import PredicatePushable
from opteryx.connectors.capabilities import Statistics
from opteryx.exceptions import DatasetNotFoundError
from opteryx.exceptions import DatasetReadError
from opteryx.exceptions import MissingDependencyError
Expand Down Expand Up @@ -51,7 +52,7 @@ def get_storage_credentials():


class GcpCloudStorageConnector(
BaseConnector, Cacheable, Partitionable, PredicatePushable, Asynchronous
BaseConnector, Cacheable, Partitionable, PredicatePushable, Asynchronous, Statistics
):
__mode__ = "Blob"
__type__ = "GCS"
Expand Down Expand Up @@ -87,6 +88,7 @@ def __init__(self, credentials=None, **kwargs):
Cacheable.__init__(self, **kwargs)
PredicatePushable.__init__(self, **kwargs)
Asynchronous.__init__(self, **kwargs)
Statistics.__init__(self, **kwargs)

self.dataset = self.dataset.replace(".", OS_SEP)
self.credentials = credentials
Expand Down Expand Up @@ -253,6 +255,8 @@ def read_dataset(
selection=predicates,
just_schema=just_schema,
)
if len(blob_names) == 1:
self.relation_statistics = decoder(blob_bytes, just_statistics=True)
except Exception as err:
raise DatasetReadError(f"Unable to read file {blob_name} ({err})") from err

Expand All @@ -268,10 +272,16 @@ def get_dataset_schema(self) -> RelationSchema:
if self.schema:
return self.schema

number_of_blobs = sum(len(b) for b in self.blob_list.values())

# Read first blob for schema inference and cache it
self.schema = next(self.read_dataset(just_schema=True), None)

if self.schema is None:
raise DatasetNotFoundError(dataset=self.dataset)

if self.schema.row_count_metric:
self.schema.row_count_metric *= number_of_blobs
self.statistics.estimated_row_count += self.schema.row_count_metric

return self.schema
6 changes: 6 additions & 0 deletions opteryx/managers/execution/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from opteryx.exceptions import InvalidInternalStateError

from .parallel_engine import execute as parallel_execute
from .serial_engine import execute as serial_execute


Expand All @@ -11,4 +12,9 @@ def execute(plan, statistics):
# Label the join legs to ensure left/right ordering
plan.label_join_legs()

"""
If we have 1 CPU, or less than 1Gb/CPU we use the serial engine.
"""

# yield from parallel_execute(plan, statistics=statistics)
yield from serial_execute(plan, statistics=statistics)
124 changes: 124 additions & 0 deletions opteryx/managers/execution/parallel_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import multiprocessing as mp
from queue import Empty
from typing import Any
from typing import Generator
from typing import Tuple

import pyarrow

from opteryx import EOS
from opteryx.constants import ResultType
from opteryx.exceptions import InvalidInternalStateError
from opteryx.models import PhysicalPlan
from opteryx.models import QueryStatistics

WORKERS = 4
kill = object()


def execute(
plan: PhysicalPlan, statistics: QueryStatistics = None, num_workers: int = WORKERS
) -> Tuple[Generator[pyarrow.Table, Any, Any], ResultType]:
"""
Execute the physical plan with morsel-level parallelism.
Parameters:
plan: PhysicalPlan
The physical plan to execute.
statistics: QueryStatistics, optional
Object to collect query statistics, defaults to None.
num_workers: int, optional
Number of parallel workers for processing morsels, defaults to 4.
Returns:
Tuple[Generator[pyarrow.Table, Any, Any], ResultType]
A generator producing pyarrow tables and the result type.
"""
try:
mp.set_start_method("fork", force=True)

# Ensure there's a single head node
head_nodes = list(set(plan.get_exit_points()))
if len(head_nodes) != 1:
raise InvalidInternalStateError(
f"Query plan has {len(head_nodes)} heads, expected exactly 1."
)

head_node = plan[head_nodes[0]]

# Queue for incoming morsels and a queue for results
work_queue = mp.Queue()
result_queue = mp.Queue()

# Create a worker pool for processing morsels
pool = mp.Pool(num_workers, _worker_init, (plan, work_queue, result_queue))

def inner_execute(plan: PhysicalPlan) -> Generator:
# Get the pump nodes from the plan and execute them in order
pump_nodes = [
(nid, node) for nid, node in plan.depth_first_search_flat() if node.is_scan
]
for pump_nid, pump_instance in pump_nodes:
work_queue.put((pump_nid, None, None))
work_queue.put((pump_nid, EOS, None))
while True:
try:
result = result_queue.get(timeout=0.1)
print("got final result", type(result))
if result == EOS:
continue
return result
except Empty:
pass

result_generator = inner_execute(plan)

print("I'm done here")

# pool.close()
# pool.join()

return result_generator, ResultType.TABULAR

finally:
# Close and join the pool after execution
pass


def _worker_init(plan: PhysicalPlan, work_queue: mp.Queue, completion_queue: mp.Queue):
"""
Initialize the worker process for morsel-level parallelism.
Parameters:
plan: PhysicalPlan
The overall physical plan.
morsel_queue: mp.Queue
Queue from which morsels are fetched.
result_queue: mp.Queue
Queue to which processed morsels are pushed.
"""
while True:
try:
work = work_queue.get(timeout=0.1)
except Empty:
continue

nid, morsel, join_leg = work

operator = plan[nid]

results = operator(morsel, join_leg)

if results is None:
continue

print("Worker got work for", operator.name, type(morsel), "results")

for result in (result for result in results if result is not None):
children = plan.outgoing_edges(nid)
print("results", type(result), children)
if len(children) == 0:
print("done")
completion_queue.put(result)
for _, child, leg in children:
work_queue.put((child, result, leg))
1 change: 1 addition & 0 deletions opteryx/models/logical_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def copy(self):
return LogicalColumn(
node_type=self.node_type,
source_column=self.source_column,
source_connector=self.source_connector,
source=self.source,
alias=self.alias,
schema_column=None
Expand Down
4 changes: 4 additions & 0 deletions opteryx/planner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
"""

import datetime
import decimal
import time
from typing import Any
from typing import Dict
Expand Down Expand Up @@ -93,6 +94,8 @@ def build_literal_node(
datetime.datetime: OrsoTypes.TIMESTAMP,
datetime.time: OrsoTypes.TIME,
datetime.date: OrsoTypes.DATE,
decimal.Decimal: OrsoTypes.DECIMAL,
list: OrsoTypes.ARRAY,
}

value_type = type(value)
Expand All @@ -107,6 +110,7 @@ def build_literal_node(
)
root.left = None
root.right = None
root.schema_column.type = root.type

# DEBUG:log (f"Unable to create literal node for {value}, of type {value_type}")
return root
Expand Down
1 change: 1 addition & 0 deletions opteryx/planner/binder/binder_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,7 @@ def visit_function_dataset(
node.schema = schema
else:
raise UnsupportedSyntaxError(f"{node.function} cannot be used in place of a table.")
node.connector = None
return node, context

def visit_join(self, node: Node, context: BindingContext) -> Tuple[Node, BindingContext]:
Expand Down
2 changes: 2 additions & 0 deletions opteryx/planner/logical_planner/logical_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1140,6 +1140,8 @@ def build_expression_tree(relation, dnf_list):
and_node = None
for predicate in dnf_list:
identifier, operator, value = predicate
if operator.startswith("AnyOp"):
identifier, value = value, identifier
comparison_node = Node(
node_type=NodeType.COMPARISON_OPERATOR,
value=operator,
Expand Down
1 change: 1 addition & 0 deletions opteryx/planner/optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def __init__(self, statistics: QueryStatistics):
ConstantFoldingStrategy(statistics),
BooleanSimplificationStrategy(statistics),
SplitConjunctivePredicatesStrategy(statistics),
CorrelatedFiltersStrategy(statistics),
PredicateRewriteStrategy(statistics),
PredicatePushdownStrategy(statistics),
ProjectionPushdownStrategy(statistics),
Expand Down
2 changes: 2 additions & 0 deletions opteryx/planner/optimizer/strategies/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .boolean_simplication import BooleanSimplificationStrategy
from .constant_folding import ConstantFoldingStrategy
from .correlated_filters import CorrelatedFiltersStrategy
from .distinct_pushdown import DistinctPushdownStrategy
from .limit_pushdown import LimitPushdownStrategy
from .operator_fusion import OperatorFusionStrategy
Expand All @@ -12,6 +13,7 @@
__all__ = [
"BooleanSimplificationStrategy",
"ConstantFoldingStrategy",
"CorrelatedFiltersStrategy",
"DistinctPushdownStrategy",
"LimitPushdownStrategy",
"OperatorFusionStrategy",
Expand Down
95 changes: 95 additions & 0 deletions opteryx/planner/optimizer/strategies/correlated_filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# See the License at http://www.apache.org/licenses/LICENSE-2.0
# Distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND.

"""
Optimization Rule - Correlated Filters
Type: Heuristic
Goal: Reduce Rows
When fields are joined on, we can infer ranges of values based on statistics
or filters. This can be used to reduce the number of rows that need to be read
and processed.
"""

from orso.tools import random_string

from opteryx.managers.expression import NodeType
from opteryx.models import Node
from opteryx.planner import build_literal_node
from opteryx.planner.logical_planner import LogicalPlan
from opteryx.planner.logical_planner import LogicalPlanNode
from opteryx.planner.logical_planner import LogicalPlanStepType

from .optimization_strategy import OptimizationStrategy
from .optimization_strategy import OptimizerContext


def _write_filters(left_column, right_column):
new_filters = []
if left_column.schema_column.highest_value is not None:
a_side = right_column
b_side = build_literal_node(left_column.schema_column.highest_value)
new_filter = Node(
LogicalPlanStepType.Filter,
condition=Node(NodeType.COMPARISON_OPERATOR, value="LtEq", left=a_side, right=b_side),
columns=[right_column],
relations={right_column.source},
all_relations={right_column.source},
)
new_filters.append(new_filter)

a_side = right_column
b_side = build_literal_node(left_column.schema_column.lowest_value)
new_filter = Node(
LogicalPlanStepType.Filter,
condition=Node(NodeType.COMPARISON_OPERATOR, value="GtEq", left=a_side, right=b_side),
columns=[right_column],
relations={right_column.source},
all_relations={right_column.source},
)
new_filters.append(new_filter)
return new_filters


class CorrelatedFiltersStrategy(OptimizationStrategy):
def visit(self, node: LogicalPlanNode, context: OptimizerContext) -> OptimizerContext:
if not context.optimized_plan:
context.optimized_plan = context.pre_optimized_tree.copy() # type: ignore

if (
node.node_type == LogicalPlanStepType.Join
and node.type == "inner"
and len(node.all_relations) == 2
):
left_column = node.on.left
right_column = node.on.right
new_filters = []

# Empty connectors are FUNCTION datasets, we could push filters down and create
# statistics for them, but there are other issues this creates
if (
left_column.node_type == NodeType.IDENTIFIER
and right_column.node_type == NodeType.IDENTIFIER
and left_column.source_connector != set()
):
new_filters = _write_filters(left_column, right_column)
if (
left_column.node_type == NodeType.IDENTIFIER
and right_column.node_type == NodeType.IDENTIFIER
and right_column.source_connector != set()
):
new_filters.extend(_write_filters(right_column, left_column))
for new_filter in new_filters:
context.optimized_plan.insert_node_before(
random_string(), new_filter, context.node_id
)
self.statistics.optimization_inner_join_correlated_filter += 1

return context

def complete(self, plan: LogicalPlan, context: OptimizerContext) -> LogicalPlan:
# No finalization needed for this strategy
return plan
Loading

0 comments on commit 5da28d2

Please sign in to comment.