Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
joocer committed Jan 21, 2025
1 parent bd0bc7b commit 07ffd70
Showing 1 changed file with 16 additions and 54 deletions.
70 changes: 16 additions & 54 deletions opteryx/operators/filter_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,10 @@
This node is responsible for applying filters to datasets.
"""

import multiprocessing

import numpy
import pyarrow

from opteryx import EOS
from opteryx.config import CONCURRENT_WORKERS
from opteryx.exceptions import SqlError
from opteryx.managers.expression import NodeType
from opteryx.managers.expression import evaluate
Expand All @@ -28,32 +25,6 @@

from . import BasePlanNode

# Force the 'fork' start method so worker processes inherit the parent's
# in-memory state (imports, config) without re-pickling; force=True overrides
# any start method already chosen elsewhere in the process.
# NOTE(review): 'fork' is not available on Windows and is discouraged on
# macOS for threaded parents -- confirm supported platforms.
multiprocessing.set_start_method("fork", force=True)


def _parallel_filter(queue, morsel, function_evaluations, filters):
    """
    Apply a filter expression to a morsel (pyarrow Table).

    Runs in two modes: as a multiprocessing worker (results are put on the
    supplied queue) or inline (results are returned directly).

    Parameters:
        queue: multiprocessing.Queue or None
            When not None, the filtered morsel is placed on this queue;
            when None, the filtered morsel is returned to the caller.
        morsel: pyarrow.Table
            The table to filter.
        function_evaluations: list
            Expressions evaluated and appended as columns before filtering.
        filters:
            The filter expression tree evaluated against the morsel.

    Returns:
        pyarrow.Table when queue is None, otherwise None.

    Raises:
        SqlError: when the filter expression cannot be coerced to booleans.
    """
    if function_evaluations:
        morsel = evaluate_and_append(function_evaluations, morsel)
    mask = evaluate(filters, morsel)

    # Coerce non-boolean evaluation results to a pyarrow boolean array.
    if not isinstance(mask, pyarrow.lib.BooleanArray):
        try:
            mask = pyarrow.array(mask, type=pyarrow.bool_())
        except Exception as err:  # nosec
            # Fixed: close the quote after the expression (it previously
            # swallowed the error text) and chain the original exception.
            raise SqlError(
                f"Unable to filter on expression '{format_expression(filters)}': {err}"
            ) from err

    # Convert the boolean mask into the indices of the rows that passed.
    mask = numpy.nonzero(mask)[0]

    # If no rows matched, emit an empty slice so the schema is preserved.
    # (Dropped the original `not numpy.all(mask is None)` clause: `mask` is
    # an ndarray here, so `mask is None` was always False and the clause
    # always evaluated True -- a tautology.)
    if mask.size > 0:
        morsel = morsel.take(pyarrow.array(mask))
    else:
        morsel = morsel.slice(0, 0)

    if queue is not None:
        queue.put(morsel)
    else:
        return morsel


class FilterNode(BasePlanNode):
def __init__(self, properties: QueryProperties, **parameters):
Expand Down Expand Up @@ -82,30 +53,21 @@ def execute(self, morsel: pyarrow.Table, **kwargs) -> pyarrow.Table:
yield morsel
return

if morsel.num_rows <= 10000 or CONCURRENT_WORKERS <= 2:
yield _parallel_filter(None, morsel, self.function_evaluations, self.filter)
else:
workers = []
queue = multiprocessing.Queue()

for block in morsel.to_batches(((morsel.num_rows + 1) // CONCURRENT_WORKERS)):
block = pyarrow.Table.from_batches([block])
p = multiprocessing.Process(
target=_parallel_filter,
args=(queue, block, self.function_evaluations, self.filter),
)
p.start()
workers.append(p)

# Collect all results from the queue
results = []
for _ in workers: # Expecting one result per worker
results.append(queue.get()) # This will block until a result is available
if self.function_evaluations:
morsel = evaluate_and_append(self.function_evaluations, morsel)
mask = evaluate(self.filter, morsel)

# Merge all results and return them
if results:
yield pyarrow.concat_tables(results)
if not isinstance(mask, pyarrow.lib.BooleanArray):
try:
mask = pyarrow.array(mask, type=pyarrow.bool_())
except Exception as err: # nosec
raise SqlError(
f"Unable to filter on expression '{format_expression(self.filter)} {err}'."
)
mask = numpy.nonzero(mask)[0]

# Ensure all workers have finished before exiting
for p in workers:
p.join()
# if there's no matching rows, don't return anything
if mask.size > 0 and not numpy.all(mask is None):
yield morsel.take(pyarrow.array(mask))
else:
yield morsel.slice(0, 0)

0 comments on commit 07ffd70

Please sign in to comment.