From 07ffd70d6deb0f76e788fad78e7875fc25a99b66 Mon Sep 17 00:00:00 2001 From: joocer Date: Tue, 21 Jan 2025 22:33:31 +0000 Subject: [PATCH] #2264 --- opteryx/operators/filter_node.py | 70 ++++++++------------------------ 1 file changed, 16 insertions(+), 54 deletions(-) diff --git a/opteryx/operators/filter_node.py b/opteryx/operators/filter_node.py index e95c6c692..f4fad1a72 100644 --- a/opteryx/operators/filter_node.py +++ b/opteryx/operators/filter_node.py @@ -11,13 +11,10 @@ This node is responsible for applying filters to datasets. """ -import multiprocessing - import numpy import pyarrow from opteryx import EOS -from opteryx.config import CONCURRENT_WORKERS from opteryx.exceptions import SqlError from opteryx.managers.expression import NodeType from opteryx.managers.expression import evaluate @@ -28,32 +25,6 @@ from . import BasePlanNode -multiprocessing.set_start_method("fork", force=True) - - -def _parallel_filter(queue, morsel, function_evaluations, filters): - if function_evaluations: - morsel = evaluate_and_append(function_evaluations, morsel) - mask = evaluate(filters, morsel) - - if not isinstance(mask, pyarrow.lib.BooleanArray): - try: - mask = pyarrow.array(mask, type=pyarrow.bool_()) - except Exception as err: # nosec - raise SqlError(f"Unable to filter on expression '{format_expression(filters)} {err}'.") - - mask = numpy.nonzero(mask)[0] - # if there's no matching rows, don't return anything - if mask.size > 0 and not numpy.all(mask is None): - morsel = morsel.take(pyarrow.array(mask)) - else: - morsel = morsel.slice(0, 0) - - if queue is not None: - queue.put(morsel) - else: - return morsel - class FilterNode(BasePlanNode): def __init__(self, properties: QueryProperties, **parameters): @@ -82,30 +53,21 @@ def execute(self, morsel: pyarrow.Table, **kwargs) -> pyarrow.Table: yield morsel return - if morsel.num_rows <= 10000 or CONCURRENT_WORKERS <= 2: - yield _parallel_filter(None, morsel, self.function_evaluations, self.filter) - else: - workers = [] - queue = multiprocessing.Queue() - - for block in morsel.to_batches(((morsel.num_rows + 1) // CONCURRENT_WORKERS)): - block = pyarrow.Table.from_batches([block]) - p = multiprocessing.Process( - target=_parallel_filter, - args=(queue, block, self.function_evaluations, self.filter), - ) - p.start() - workers.append(p) - - # Collect all results from the queue - results = [] - for _ in workers: # Expecting one result per worker - results.append(queue.get()) # This will block until a result is available + if self.function_evaluations: + morsel = evaluate_and_append(self.function_evaluations, morsel) + mask = evaluate(self.filter, morsel) - # Merge all results and return them - if results: - yield pyarrow.concat_tables(results) + if not isinstance(mask, pyarrow.lib.BooleanArray): + try: + mask = pyarrow.array(mask, type=pyarrow.bool_()) + except Exception as err: # nosec + raise SqlError( + f"Unable to filter on expression '{format_expression(self.filter)} {err}'." + ) + mask = numpy.nonzero(mask)[0] - # Ensure all workers have finished before exiting - for p in workers: - p.join() + # if there's no matching rows, don't return anything + if mask.size > 0 and not numpy.all(mask is None): + yield morsel.take(pyarrow.array(mask)) + else: + yield morsel.slice(0, 0)