Showing 9 changed files with 328 additions and 8 deletions.
@@ -0,0 +1,154 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# See the License at http://www.apache.org/licenses/LICENSE-2.0
# Distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND.

""" | ||
Simple Aggregation Node | ||
This is a SQL Query Execution Plan Node. | ||
This node performs aggregates without performing groupings, this is a specialized version | ||
which focuses on building aggregates which don't require seeing the entire dataset at a time. | ||
We avoid doing some work by not creating entire columns of data where possible. | ||
""" | ||
|
||
import pyarrow

from opteryx import EOS
from opteryx.managers.expression import NodeType
from opteryx.managers.expression import evaluate_and_append
from opteryx.models import QueryProperties
from opteryx.operators.aggregate_node import extract_evaluations

from . import BasePlanNode


class SimpleAggregateCollector:
    """
    Accumulates a single ungrouped aggregate (SUM, AVG, MIN, MAX or COUNT)
    incrementally, one morsel at a time, without retaining the raw values.
    """

    def __init__(self, aggregate_type, column_id, count_nulls=False):
        self.aggregate_type = aggregate_type
        self.current_value = None
        # nulls are only ever counted for COUNT aggregates
        self.count_nulls = count_nulls if aggregate_type == "COUNT" else False
        self.counter = 0
        self.column_id = column_id

    def collect(self, values):
        if self.count_nulls:
            # count every row, including nulls (COUNT(*) semantics)
            self.counter += pyarrow.compute.count(values, mode="all").as_py()
        else:
            # count only the non-null values
            self.counter += pyarrow.compute.count(values, mode="only_valid").as_py()

        if self.current_value is None:
            if self.aggregate_type in ("SUM", "AVG"):
                self.current_value = pyarrow.compute.sum(values).as_py()
            elif self.aggregate_type == "MIN":
                self.current_value = pyarrow.compute.min(values).as_py()
            elif self.aggregate_type == "MAX":
                self.current_value = pyarrow.compute.max(values).as_py()
            elif self.aggregate_type != "COUNT":
                raise ValueError(f"Unsupported aggregate type: {self.aggregate_type}")
        else:
            if self.aggregate_type in ("SUM", "AVG"):
                self.current_value += pyarrow.compute.sum(values).as_py()
            elif self.aggregate_type == "MIN":
                self.current_value = min(self.current_value, pyarrow.compute.min(values).as_py())
            elif self.aggregate_type == "MAX":
                self.current_value = max(self.current_value, pyarrow.compute.max(values).as_py())
            elif self.aggregate_type != "COUNT":
                raise ValueError(f"Unsupported aggregate type: {self.aggregate_type}")

    def collect_literal(self, literal, count):
        self.counter += count

        if self.current_value is None:
            if self.aggregate_type in ("SUM", "AVG"):
                self.current_value = literal * count
            elif self.aggregate_type in ("MIN", "MAX"):
                self.current_value = literal
            elif self.aggregate_type != "COUNT":
                raise ValueError(f"Unsupported aggregate type: {self.aggregate_type}")
        else:
            if self.aggregate_type in ("SUM", "AVG"):
                self.current_value += literal * count
            elif self.aggregate_type == "MIN":
                self.current_value = min(self.current_value, literal)
            elif self.aggregate_type == "MAX":
                self.current_value = max(self.current_value, literal)
            elif self.aggregate_type != "COUNT":
                raise ValueError(f"Unsupported aggregate type: {self.aggregate_type}")

    def get_result(self):
        if self.aggregate_type == "AVG":
            if self.counter == 0 or self.current_value is None:
                return None
            return self.current_value / self.counter
        if self.aggregate_type == "COUNT":
            return self.counter
        return self.current_value


class SimpleAggregateNode(BasePlanNode):
    """Aggregation without GROUP BY, accumulated incrementally per morsel."""

    SIMPLE_AGGREGATES = {"SUM", "MIN", "MAX", "AVG", "COUNT"}

    def __init__(self, properties: QueryProperties, **parameters):
        BasePlanNode.__init__(self, properties=properties, **parameters)

        self.aggregates = parameters.get("aggregates", [])

        # Get any functions we need to execute before aggregating
        self.evaluatable_nodes = extract_evaluations(self.aggregates)

        # Create a collector for each aggregate
        self.accumulator = {}
        for aggregate in self.aggregates:
            aggregate_type = aggregate.value
            final_column_id = aggregate.schema_column.identity

            self.accumulator[final_column_id] = SimpleAggregateCollector(
                aggregate_type, final_column_id
            )

    @property
    def config(self):  # pragma: no cover
        return str(self.aggregates)

    @property
    def name(self):  # pragma: no cover
        return "Aggregation (Simple)"

    def execute(self, morsel: pyarrow.Table, **kwargs) -> pyarrow.Table:
        if morsel == EOS:
            # end of stream - emit a single-row table with the final aggregate values
            names = []
            values = []
            for k, v in self.accumulator.items():
                names.append(k)
                values.append([v.get_result()])
            yield pyarrow.Table.from_arrays(values, names=names)
            yield EOS
            return

        # Evaluate any functions the aggregates depend on before accumulating
        if self.evaluatable_nodes:
            morsel = evaluate_and_append(self.evaluatable_nodes, morsel)

        for aggregate in self.aggregates:
            if aggregate.node_type in (NodeType.AGGREGATOR,):
                column_node = aggregate.parameters[0]

                if column_node.node_type == NodeType.LITERAL:
                    # aggregating a literal - only the row count is needed
                    self.accumulator[aggregate.schema_column.identity].collect_literal(
                        column_node.value, morsel.num_rows
                    )
                elif column_node.node_type == NodeType.WILDCARD:
                    # COUNT(*) - use the precomputed count column if present
                    if "$COUNT(*)" in morsel.column_names:
                        self.accumulator[aggregate.schema_column.identity].collect_literal(
                            1, morsel["$COUNT(*)"][0].as_py()
                        )
                    else:
                        self.accumulator[aggregate.schema_column.identity].collect_literal(
                            1, morsel.num_rows
                        )
                else:
                    raw_column_values = morsel[column_node.schema_column.identity]
                    self.accumulator[aggregate.schema_column.identity].collect(raw_column_values)
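To make the incremental approach concrete, here is a minimal usage sketch (not part of the commit) showing how a collector folds successive morsels into a running aggregate; it assumes the SimpleAggregateCollector class defined above is in scope.

    import pyarrow

    # assuming SimpleAggregateCollector (defined above) is importable or in scope
    collector = SimpleAggregateCollector("AVG", column_id="followers")

    # each morsel updates a running sum and a running non-null count;
    # no column data is retained between morsels
    collector.collect(pyarrow.array([10, 20, None]))  # running sum 30, count 2
    collector.collect(pyarrow.array([30]))            # running sum 60, count 3

    print(collector.get_result())  # 20.0 (running sum / running count)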
@@ -0,0 +1,41 @@
import os
import sys

sys.path.insert(1, os.path.join(sys.path[0], "../.."))

import opteryx


def test_avg_parquet():
    cur = opteryx.query("SELECT AVG(followers) FROM testdata.flat.formats.parquet")
    stats = cur.stats
    assert stats["columns_read"] == 1, stats["columns_read"]
    assert stats["rows_read"] == 100000, stats["rows_read"]
    assert stats["rows_seen"] == 100000, stats["rows_seen"]
    first = cur.fetchone()[0]
    assert first == 3081.258, first


def test_avg_non_parquet():
    cur = opteryx.query("SELECT AVG(followers) FROM testdata.flat.ten_files;")
    stats = cur.stats
    assert stats["columns_read"] == 1, stats["columns_read"]
    assert stats["rows_read"] == 250, stats["rows_read"]
    assert stats["rows_seen"] == 250, stats["rows_seen"]
    first = cur.fetchone()[0]
    assert first == 7500362.668, first


def test_avg_group_by():
    """with a GROUP BY we don't take the simple aggregate path, so the grouping column is read too"""
    cur = opteryx.query(
        "SELECT AVG(followers) FROM testdata.flat.formats.parquet GROUP BY tweet_id ORDER BY tweet_id;"
    )
    stats = cur.stats
    assert stats["columns_read"] == 2, stats["columns_read"]
    assert stats["rows_read"] == 100000, stats["rows_read"]
    assert stats["rows_seen"] == 100000, stats["rows_seen"]
    first = cur.fetchone()[0]
    assert first == 6.0, first


if __name__ == "__main__":  # pragma: no cover
    from tests.tools import run_tests

    run_tests()
@@ -0,0 +1,41 @@
import os
import sys

sys.path.insert(1, os.path.join(sys.path[0], "../.."))

import opteryx


def test_max_parquet():
    cur = opteryx.query("SELECT MAX(followers) FROM testdata.flat.formats.parquet")
    stats = cur.stats
    assert stats["columns_read"] == 1, stats["columns_read"]
    assert stats["rows_read"] == 100000, stats["rows_read"]
    assert stats["rows_seen"] == 100000, stats["rows_seen"]
    first = cur.fetchone()[0]
    assert first == 8266250, first


def test_max_non_parquet():
    cur = opteryx.query("SELECT MAX(followers) FROM testdata.flat.ten_files;")
    stats = cur.stats
    assert stats["columns_read"] == 1, stats["columns_read"]
    assert stats["rows_read"] == 250, stats["rows_read"]
    assert stats["rows_seen"] == 250, stats["rows_seen"]
    first = cur.fetchone()[0]
    assert first == 10415949, first


def test_max_group_by():
    """with a GROUP BY we don't take the simple aggregate path, so the grouping column is read too"""
    cur = opteryx.query(
        "SELECT MAX(followers) FROM testdata.flat.formats.parquet GROUP BY tweet_id ORDER BY tweet_id;"
    )
    stats = cur.stats
    assert stats["columns_read"] == 2, stats["columns_read"]
    assert stats["rows_read"] == 100000, stats["rows_read"]
    assert stats["rows_seen"] == 100000, stats["rows_seen"]
    first = cur.fetchone()[0]
    assert first == 6.0, first


if __name__ == "__main__":  # pragma: no cover
    from tests.tools import run_tests

    run_tests()
@@ -0,0 +1,41 @@
import os
import sys

sys.path.insert(1, os.path.join(sys.path[0], "../.."))

import opteryx


def test_min_parquet():
    cur = opteryx.query("SELECT MIN(followers) FROM testdata.flat.formats.parquet")
    stats = cur.stats
    assert stats["columns_read"] == 1, stats["columns_read"]
    assert stats["rows_read"] == 100000, stats["rows_read"]
    assert stats["rows_seen"] == 100000, stats["rows_seen"]
    first = cur.fetchone()[0]
    assert first == 0, first


def test_min_non_parquet():
    cur = opteryx.query("SELECT MIN(followers) FROM testdata.flat.ten_files;")
    stats = cur.stats
    assert stats["columns_read"] == 1, stats["columns_read"]
    assert stats["rows_read"] == 250, stats["rows_read"]
    assert stats["rows_seen"] == 250, stats["rows_seen"]
    first = cur.fetchone()[0]
    assert first == 100, first


def test_min_group_by():
    """with a GROUP BY we don't take the simple aggregate path, so the grouping column is read too"""
    cur = opteryx.query(
        "SELECT MIN(followers) FROM testdata.flat.formats.parquet GROUP BY tweet_id ORDER BY tweet_id;"
    )
    stats = cur.stats
    assert stats["columns_read"] == 2, stats["columns_read"]
    assert stats["rows_read"] == 100000, stats["rows_read"]
    assert stats["rows_seen"] == 100000, stats["rows_seen"]
    first = cur.fetchone()[0]
    assert first == 6.0, first


if __name__ == "__main__":  # pragma: no cover
    from tests.tools import run_tests

    run_tests()
@@ -0,0 +1,41 @@
import os
import sys

sys.path.insert(1, os.path.join(sys.path[0], "../.."))

import opteryx


def test_sum_parquet():
    cur = opteryx.query("SELECT SUM(followers) FROM testdata.flat.formats.parquet")
    stats = cur.stats
    assert stats["columns_read"] == 1, stats["columns_read"]
    assert stats["rows_read"] == 100000, stats["rows_read"]
    assert stats["rows_seen"] == 100000, stats["rows_seen"]
    first = cur.fetchone()[0]
    assert first == 308125800, first


def test_sum_non_parquet():
    cur = opteryx.query("SELECT SUM(followers) FROM testdata.flat.ten_files;")
    stats = cur.stats
    assert stats["columns_read"] == 1, stats["columns_read"]
    assert stats["rows_read"] == 250, stats["rows_read"]
    assert stats["rows_seen"] == 250, stats["rows_seen"]
    first = cur.fetchone()[0]
    assert first == 1875090667, first


def test_sum_group_by():
    """with a GROUP BY we don't take the simple aggregate path, so the grouping column is read too"""
    cur = opteryx.query(
        "SELECT SUM(followers) FROM testdata.flat.formats.parquet GROUP BY tweet_id ORDER BY tweet_id;"
    )
    stats = cur.stats
    assert stats["columns_read"] == 2, stats["columns_read"]
    assert stats["rows_read"] == 100000, stats["rows_read"]
    assert stats["rows_seen"] == 100000, stats["rows_seen"]
    first = cur.fetchone()[0]
    assert first == 6.0, first


if __name__ == "__main__":  # pragma: no cover
    from tests.tools import run_tests

    run_tests()
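As a quick cross-check (not part of the committed tests), the SUM expectations above are arithmetically consistent with the AVG expectations over the same datasets:

    # SUM / row count should equal the value asserted in the corresponding AVG test
    assert 308125800 / 100000 == 3081.258    # testdata.flat.formats.parquet
    assert 1875090667 / 250 == 7500362.668   # testdata.flat.ten_files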