Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
joocer committed Dec 3, 2023
1 parent 0727f28 commit e4c57ea
Show file tree
Hide file tree
Showing 10 changed files with 109 additions and 46 deletions.
2 changes: 1 addition & 1 deletion opteryx/components/heuristic_optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(self):

self.strategies = [
SplitConjunctivePredicatesStrategy(),
# PredicatePushdownStrategy(),
PredicatePushdownStrategy(),
ProjectionPushdownStrategy(),
]

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from orso.tools import random_string

from opteryx.components.binder.binder_visitor import extract_join_fields
Expand Down Expand Up @@ -29,7 +41,7 @@ def visit(
if not context.optimized_plan:
context.optimized_plan = context.pre_optimized_tree.copy()

if node.node_type in (
if False and node.node_type in (
LogicalPlanStepType.Scan,
LogicalPlanStepType.FunctionDataset,
LogicalPlanStepType.Subquery,
Expand All @@ -44,17 +56,20 @@ def visit(
elif node.node_type == LogicalPlanStepType.Filter:
# collect predicates we can probably push
if node.simple and len(node.relations) > 0:
# record where the node was, so we can put it back
node.nid = context.node_id
node.plan_path = context.optimized_plan.trace_to_root(context.node_id)
context.collected_predicates.append(node)
context.optimized_plan.remove_node(context.node_id, heal=True)

elif node.node_type == LogicalPlanStepType.Join:
elif False and node.node_type == LogicalPlanStepType.Join and context.collected_predicates:
# push predicates which reference multiple relations here

if node.type == "cross join" and node.unnest_column:
# if it's a CROSS JOIN UNNEST - don't try to push any further
# IMPROVE: we should push everything that doesn't reference the unnested column
context = self._handle_predicates(node, context)
elif node.type in ("cross join", "inner"):
elif node.type in ("cross join",): # , "inner"):
# we may be able to rewrite as an inner join
remaining_predicates = []
for predicate in context.collected_predicates:
Expand All @@ -66,16 +81,13 @@ def visit(
else:
remaining_predicates.append(predicate)

new_left_columns, new_right_columns = extract_join_fields(
node.on, node.left_relation_names, node.right_relation_names
)

if not node.right_columns:
node.right_columns = []
node.right_columns.extend(new_right_columns)
if not node.left_columns:
node.left_columns = []
node.left_columns.extend(new_left_columns)
print("LEFT", node.left_columns, node.left_relation_names)
print("RIGHT", node.right_columns, node.right_relation_names)
node.left_columns, node.right_columns = extract_join_fields(
node.on, node.left_relation_names, node.right_relation_names
)
print("LEFT", node.left_columns, node.left_relation_names)
print("RIGHT", node.right_columns, node.right_relation_names)

mismatches = get_mismatched_condition_column_types(node.on)
if mismatches:
Expand All @@ -85,7 +97,7 @@ def visit(
node.columns = get_all_nodes_of_type(node.on, (NodeType.IDENTIFIER,))
context.collected_predicates = remaining_predicates

else:
elif context.collected_predicates:
# IMPROVE, allow pushing past OUTER, SEMI, ANTI joins on one leg
for predicate in context.collected_predicates:
context.optimized_plan.insert_node_after(
Expand All @@ -104,17 +116,18 @@ def visit(
remaining_predicates.append(predicate)
context.collected_predicates = remaining_predicates

context.optimized_plan.add_node(context.node_id, node)

# DEBUG: log (context.optimized_plan.draw())
context.optimized_plan.add_node(context.node_id, node)
return context

def complete(self, plan: LogicalPlan, context: HeuristicOptimizerContext) -> LogicalPlan:
# anything we couldn't push, we need to put back
if context.collected_predicates:
context.collected_predicates.reverse()
exit_node_id = context.optimized_plan.get_exit_points()[0]
for predicate in context.collected_predicates:
context.optimized_plan.insert_node_before(random_string(), predicate, exit_node_id)
for predicate in context.collected_predicates:
for nid in predicate.plan_path:
if nid in context.optimized_plan:
context.optimized_plan.insert_node_before(predicate.nid, predicate, nid)
break
return context.optimized_plan

def _handle_predicates(
Expand All @@ -125,9 +138,7 @@ def _handle_predicates(
if len(predicate.relations) == 1 and predicate.relations.intersection(
(node.relation, node.alias)
):
context.optimized_plan.insert_node_after(
random_string(), predicate, context.node_id
)
context.optimized_plan.insert_node_after(predicate.nid, predicate, context.node_id)
continue
remaining_predicates.append(predicate)
context.collected_predicates = remaining_predicates
Expand Down
3 changes: 3 additions & 0 deletions opteryx/components/logical_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ class LogicalPlan(Graph):


class LogicalPlanNode(Node):
def copy(self) -> "Node":
    """
    Copy this node, returning the copy as a LogicalPlanNode.

    Returns:
        A new LogicalPlanNode constructed from the `properties` of the copy
        produced by the parent class's `copy()`.
    """
    # Let the parent class perform the copy, then rebuild the result as this
    # node type from the copied properties.
    duplicate = super().copy()
    return LogicalPlanNode(**duplicate.properties)

def __str__(self):
try:
# fmt:off
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import itertools


def calculate_predicate_costs(predicate_selectivity, predicate_execution_time):
# Generate all possible predicate arrangements
predicate_arrangements = list(itertools.permutations(range(len(predicate_selectivity))))
Expand All @@ -29,15 +30,19 @@ def calculate_predicate_costs(predicate_selectivity, predicate_execution_time):
execution_time = 0.0
for i in arrangement:
execution_time += predicate_execution_time[i] * cumulative_data_size
cumulative_data_size *= predicate_selectivity[i] # Reduce data size based on selectivity
cumulative_data_size *= predicate_selectivity[
i
] # Reduce data size based on selectivity
arrangement_costs[arrangement] = execution_time

return arrangement_costs


def print_arrangement_costs(arrangement_costs):
    """
    Print one line per arrangement showing its estimated execution time.

    Args:
        arrangement_costs: mapping of arrangement -> cost; entries are printed
            in the mapping's iteration order.
    """
    for arrangement in arrangement_costs:
        cost = arrangement_costs[arrangement]
        print(f"Arrangement {arrangement}: Execution Time = {cost}")


# Sample predicate selectivity and execution time estimates
predicate_selectivity = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]
predicate_execution_time = [0.8, 0.2, 0.5, 0.1, 0.1, 0.4, 0.7]
Expand All @@ -47,5 +52,3 @@ def print_arrangement_costs(arrangement_costs):

# Print the costs
print_arrangement_costs(arrangement_costs)


2 changes: 1 addition & 1 deletion opteryx/models/connection_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,4 @@ def __post_init__(self):
Initializes additional attributes after the object has been created.
"""
# The initializer is a function rather than an empty constructor so we init here
object.__setattr__(self, "variables", SystemVariables.copy(VariableOwner.USER))
object.__setattr__(self, "variables", SystemVariables.snapshot(VariableOwner.USER))
20 changes: 10 additions & 10 deletions opteryx/models/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,17 +143,17 @@ def _inner_copy(obj: Any) -> Any:
Returns:
Any: The new, independent deep copy.
"""
if isinstance(obj, list):
return [_inner_copy(item) for item in obj]
if isinstance(obj, tuple):
return tuple(_inner_copy(item) for item in obj)
if isinstance(obj, set):
return {_inner_copy(item) for item in obj}
if isinstance(obj, dict):
return {key: _inner_copy(value) for key, value in obj.items()}
if hasattr(obj, "copy"):
return obj.copy()
try:
if isinstance(obj, list):
return [_inner_copy(item) for item in obj]
if isinstance(obj, tuple):
return tuple(_inner_copy(item) for item in obj)
if isinstance(obj, set):
return {_inner_copy(item) for item in obj}
if isinstance(obj, dict):
return {key: _inner_copy(value) for key, value in obj.items()}
if hasattr(obj, "copy"):
return obj.copy()
return copy.deepcopy(obj)
except:
return obj
Expand Down
7 changes: 5 additions & 2 deletions opteryx/operators/join_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ def __init__(self, properties: QueryProperties, **config):
self._using = config.get("using")

self._left_columns = config.get("left_columns")
self._left_relation = config.get("left_relation")
self._left_relation = config.get("left_relation_names")

self._right_columns = config.get("right_columns")
self._right_relation = config.get("right_relation")
self._right_relation = config.get("right_relation_names")

@property
def name(self): # pragma: no cover
Expand All @@ -69,6 +69,9 @@ def execute(self) -> Iterable:

right_table = pyarrow.concat_tables(right_node.execute(), mode="default")

print("O_LEFT", self._left_columns, self._left_relation)
print("O_RIGHT", self._right_columns, self._right_relation)

for morsel in left_node.execute():
# do the join
new_morsel = morsel.join(
Expand Down
2 changes: 1 addition & 1 deletion opteryx/shared/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def __iter__(self):
def __len__(self):
return len(self._variables)

def copy(self, owner: VariableOwner = VariableOwner.USER) -> "SystemVariablesContainer":
def snapshot(self, owner: VariableOwner = VariableOwner.USER) -> "SystemVariablesContainer":
return SystemVariablesContainer(owner)

def as_column(self, key: str):
Expand Down
45 changes: 44 additions & 1 deletion opteryx/third_party/travers/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,9 @@ def __add__(self, other):
self._nodes.update(other._nodes)
return self

def __contains__(self, nid: str) -> bool:
    """
    Support the `nid in graph` membership test.

    Args:
        nid: identifier of the node to look for.

    Returns:
        True when `nid` is present in the graph's node collection.
    """
    known_nodes = self._nodes
    return nid in known_nodes

def draw(self):
    """
    Render the graph as a single string.

    Returns:
        The pieces produced by `print_tree_inner` for the depth-first
        traversal of this graph, joined together.
    """
    return "".join(print_tree_inner(self.depth_first_search()))
Expand Down Expand Up @@ -526,6 +529,46 @@ def _inner_copy(obj: Any) -> Any:

graph = Graph()
graph._nodes = _inner_copy(self._nodes)
graph._edges = _inner_copy(self._edges)
graph._edges = self.copy_edges()

return graph

def copy_edges(self):
    """
    Build a fresh edge mapping from the edges currently in the graph.

    Returns:
        A new dictionary keyed by source node id; each value is a list of
        (target, relationship) tuples in the order `self.edges()` yields them.
    """
    duplicate = {}
    for origin, destination, label in self.edges():
        # setdefault creates the per-source list the first time a source is seen
        duplicate.setdefault(origin, []).append((destination, label))
    return duplicate

def trace_to_root(self, nid: str) -> list:
    """
    Trace the path from a node towards the root, recording each node passed.

    Args:
        nid: identifier of the node to start from; it is not included in
            the returned route.

    Returns:
        A list of node identifiers, nearest first, reached by repeatedly
        following the first outgoing edge of the current node. The list is
        empty when `nid` has no outgoing edges.
    """
    route = []
    # Remember every node already on the path so that a cycle in the edges
    # ends the walk instead of looping forever.
    visited = {nid}
    current_node_id = nid
    while True:
        outgoing_edges = self.outgoing_edges(current_node_id)
        if not outgoing_edges:
            break  # Nothing further to follow: reached the root

        # Only the first outgoing edge is followed; index 1 of the edge
        # tuple is taken as the id of the node that edge leads to.
        next_node_id = outgoing_edges[0][1]
        if next_node_id in visited:
            break  # Cycle detected: stop rather than revisit a node

        visited.add(next_node_id)
        route.append(next_node_id)
        current_node_id = next_node_id

    return route
6 changes: 3 additions & 3 deletions tests/misc/test_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

def test_connection_variables():
# Create a clone of the system variables object
connection_vars = SystemVariables.copy()
connection_vars = SystemVariables.snapshot()

# Verify that the clone has the same values as the original
assert connection_vars["max_cache_evictions"] == 32
Expand All @@ -26,7 +26,7 @@ def test_connection_variables():

def test_variables_permissions():
# Create a clone of the system variables object
connection_vars = SystemVariables.copy()
connection_vars = SystemVariables.snapshot()

# we shouldn't be able to change the licence
with pytest.raises(PermissionsError):
Expand All @@ -44,7 +44,7 @@ def test_variables_permissions():

def test_variable_types():
# Create a clone of the system variables object
connection_vars = SystemVariables.copy()
connection_vars = SystemVariables.snapshot()

# max_cache_evictions is a numeric field, so should fail if we try
# to set to a string
Expand Down

0 comments on commit e4c57ea

Please sign in to comment.