The Bag Partition is now configurable. (#33805)
* The Bag Partition is now configurable.

Configuring the number of partitions in the Dask runner is important for tuning performance. This CL gives users control over that parameter; see the usage sketch after this list.

* Apply formatter.

* Passing lint via the `run_pylint.sh` script.

* Implementing review feedback.

* Attempting to pass lint/fmt check.

* Fixing isort issues by reading CI output.

* More indentation.

* rm blank line for isort.
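
For illustration, a minimal end-to-end sketch of the new knob. The `--dask_npartitions` flag comes from this commit; the pipeline itself is hypothetical, and dask plus distributed must be installed:

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.runners.dask.dask_runner import DaskRunner

# Hypothetical pipeline: create 1,000 elements across 8 Bag partitions.
options = PipelineOptions(['--dask_npartitions', '8'])
with beam.Pipeline(runner=DaskRunner(), options=options) as p:
  _ = (
      p
      | beam.Create(list(range(1_000)))
      | beam.Map(lambda x: x * x))
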
alxmrs authored Feb 1, 2025
1 parent b1d5e00 commit bfa0c59
Showing 4 changed files with 82 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
@@ -86,6 +86,7 @@
* Support the Process Environment for execution in Prism ([#33651](https://github.com/apache/beam/pull/33651))
* Support the AnyOf Environment for execution in Prism ([#33705](https://github.com/apache/beam/pull/33705))
* This improves support for developing Xlang pipelines when using a compatible cross-language service.
* Partitions are now configurable for the DaskRunner in the Python SDK ([#33805](https://github.com/apache/beam/pull/33805)).

## Breaking Changes

39 changes: 35 additions & 4 deletions sdks/python/apache_beam/runners/dask/dask_runner.py
@@ -58,6 +58,18 @@ def _parse_timeout(candidate):
      import dask
      return dask.config.no_default

  @staticmethod
  def _extract_bag_kwargs(dask_options: t.Dict) -> t.Dict:
    """Parse keyword arguments for `dask.Bag`s; used in graph translation."""
    out = {}

    if npartitions := dask_options.pop('npartitions', None):
      out['npartitions'] = npartitions
    if partition_size := dask_options.pop('partition_size', None):
      out['partition_size'] = partition_size

    return out
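
The helper's contract in miniature: it mutates the dict it is given, popping the bag-related keys so the remainder can be passed straight to the `dask.distributed.Client` constructor. Note the `:=` guards treat falsy values (e.g. an explicit `0`) as unset. A minimal sketch, not part of the commit:

from apache_beam.runners.dask.dask_runner import DaskOptions

# Illustrative values only.
dask_options = {'npartitions': 8, 'timeout': 60}
bag_kwargs = DaskOptions._extract_bag_kwargs(dask_options)
assert bag_kwargs == {'npartitions': 8}
assert dask_options == {'timeout': 60}  # bag keys were popped in place
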

  @classmethod
  def _add_argparse_args(cls, parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
@@ -93,6 +105,21 @@ def _add_argparse_args(cls, parser: argparse.ArgumentParser) -> None:
        default=512,
        help='The number of open comms to maintain at once in the connection '
        'pool.')
    partitions_parser = parser.add_mutually_exclusive_group()
    partitions_parser.add_argument(
        '--dask_npartitions',
        dest='npartitions',
        type=int,
        default=None,
        help='The desired number of `dask.Bag` partitions. When unspecified, '
        'an educated guess is made.')
    partitions_parser.add_argument(
        '--dask_partition_size',
        dest='partition_size',
        type=int,
        default=None,
        help='The length of each `dask.Bag` partition. When unspecified, '
        'an educated guess is made.')
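
Because the two flags live in a mutually exclusive argparse group, supplying both should fail at option-parsing time. A hedged sketch of that expectation, assuming argparse's standard `SystemExit` surfaces through Beam's `PipelineOptions` parsing:

from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.runners.dask.dask_runner import DaskOptions

# Both flags set at once: the mutually exclusive group should reject this.
flags = '--dask_npartitions 8 --dask_partition_size 100'.split()
try:
  PipelineOptions(flags).view_as(DaskOptions)
except SystemExit:
  print('argparse rejected the mutually exclusive partition flags')
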


@dataclasses.dataclass
@@ -139,17 +166,20 @@ def metrics(self):
class DaskRunner(BundleBasedDirectRunner):
  """Executes a pipeline on a Dask distributed client."""
  @staticmethod
  def to_dask_bag_visitor() -> PipelineVisitor:
  def to_dask_bag_visitor(bag_kwargs=None) -> PipelineVisitor:
    from dask import bag as db

    if bag_kwargs is None:
      bag_kwargs = {}

    @dataclasses.dataclass
    class DaskBagVisitor(PipelineVisitor):
      bags: t.Dict[AppliedPTransform, db.Bag] = dataclasses.field(
          default_factory=collections.OrderedDict)

      def visit_transform(self, transform_node: AppliedPTransform) -> None:
        op_class = TRANSLATIONS.get(transform_node.transform.__class__, NoOp)
        op = op_class(transform_node)
        op = op_class(transform_node, bag_kwargs=bag_kwargs)

        op_kws = {"input_bag": None, "side_inputs": None}
        inputs = list(transform_node.inputs)
@@ -195,7 +225,7 @@ def is_fnapi_compatible():
  def run_pipeline(self, pipeline, options):
    import dask

    # TODO(alxr): Create interactive notebook support.
    # TODO(alxmrs): Create interactive notebook support.
    if is_in_notebook():
      raise NotImplementedError('interactive support will come later!')
@@ -207,11 +237,12 @@ def run_pipeline(self, pipeline, options):
    dask_options = options.view_as(DaskOptions).get_all_options(
        drop_default=True)
    bag_kwargs = DaskOptions._extract_bag_kwargs(dask_options)
    client = ddist.Client(**dask_options)

    pipeline.replace_all(dask_overrides())

    dask_visitor = self.to_dask_bag_visitor()
    dask_visitor = self.to_dask_bag_visitor(bag_kwargs)
    pipeline.visit(dask_visitor)
    # The dictionary in this visitor keeps a mapping of every Beam
    # PTransform to the equivalent Bag operation. This is highly
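
One detail worth noting in this flow: `drop_default=True` keeps unset options out of the dict, so the `dask.distributed.Client` constructor never sees the partition flags' `None` defaults, and `_extract_bag_kwargs` only receives values the user actually set. A small sketch of that expectation, mirroring the tests below:

from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.runners.dask.dask_runner import DaskOptions

options = PipelineOptions()  # no partition flags set
dask_options = options.view_as(DaskOptions).get_all_options(drop_default=True)
assert 'npartitions' not in dask_options  # the None default was dropped
assert DaskOptions._extract_bag_kwargs(dask_options) == {}
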
19 changes: 19 additions & 0 deletions sdks/python/apache_beam/runners/dask/dask_runner_test.py
@@ -66,6 +66,25 @@ def test_parser_destinations__agree_with_dask_client(self):
      with self.subTest(f'{opt_name} in dask.distributed.Client constructor'):
        self.assertIn(opt_name, client_args)

  def test_parser_extract_bag_kwargs__deletes_dask_kwargs(self):
    options = PipelineOptions('--dask_npartitions 8'.split())
    dask_options = options.view_as(DaskOptions).get_all_options()

    self.assertIn('npartitions', dask_options)
    bag_kwargs = DaskOptions._extract_bag_kwargs(dask_options)
    self.assertNotIn('npartitions', dask_options)
    self.assertEqual(bag_kwargs, {'npartitions': 8})

  def test_parser_extract_bag_kwargs__unconfigured(self):
    options = PipelineOptions()
    dask_options = options.view_as(DaskOptions).get_all_options()

    # It's present as a default option.
    self.assertIn('npartitions', dask_options)
    bag_kwargs = DaskOptions._extract_bag_kwargs(dask_options)
    self.assertNotIn('npartitions', dask_options)
    self.assertEqual(bag_kwargs, {})


class DaskRunnerRunPipelineTest(unittest.TestCase):
  """Test class used to introspect the dask runner via a debugger."""
30 changes: 27 additions & 3 deletions sdks/python/apache_beam/runners/dask/transform_evaluator.py
@@ -22,6 +22,7 @@
"""
import abc
import dataclasses
import logging
import math
import typing as t
from dataclasses import field
@@ -52,6 +53,8 @@
# Value types for PCollections (possibly Windowed Values).
PCollVal = t.Union[WindowedValue, t.Any]

_LOGGER = logging.getLogger(__name__)


def get_windowed_value(item: t.Any, window_fn: WindowFn) -> WindowedValue:
  """Wraps a value (item) inside a Window."""
@@ -127,8 +130,11 @@ class DaskBagOp(abc.ABC):
  Attributes
    applied: The underlying `AppliedPTransform` which holds the code for the
      target operation.
    bag_kwargs: (optional) Keyword arguments applied to input bags, usually
      from the pipeline's `DaskOptions`.
  """
  applied: AppliedPTransform
  bag_kwargs: t.Dict = dataclasses.field(default_factory=dict)

  @property
  def transform(self):
@@ -151,10 +157,28 @@ def apply(self, input_bag: OpInput, side_inputs: OpSide = None) -> db.Bag:
    assert input_bag is None, 'Create expects no input!'
    original_transform = t.cast(_Create, self.transform)
    items = original_transform.values

    npartitions = self.bag_kwargs.get('npartitions')
    partition_size = self.bag_kwargs.get('partition_size')
    if npartitions and partition_size:
      raise ValueError(
          f'Please specify either `dask_npartitions` or '
          f'`dask_partition_size` but not both: '
          f'{npartitions=}, {partition_size=}.')
    if not npartitions and not partition_size:
      # `partition_size` is inversely related to `npartitions`.
      # Ideal "chunk sizes" in Dask are around 10-100 MB; we assume
      # ~128 items per partition lands in that range.
      default_size = 128
      partition_size = max(default_size, math.ceil(math.sqrt(len(items)) / 10))
      if partition_size == default_size:
        _LOGGER.warning(
            'The new default partition size is %d; it used to be 1 '
            'in previous DaskRunner versions.', default_size)

    return db.from_sequence(
        items,
        partition_size=max(
            1, math.ceil(math.sqrt(len(items)) / math.sqrt(100))))
    return db.from_sequence(
        items, npartitions=npartitions, partition_size=partition_size)
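
To make the fallback concrete, here is the default heuristic worked through for a few input sizes: an illustrative standalone function mirroring the branch above, not an additional code path in the commit:

import math

def default_partition_size(n_items: int) -> int:
  # Mirrors the fallback above: at least 128 items per partition,
  # growing as ceil(sqrt(n) / 10) once inputs get very large.
  return max(128, math.ceil(math.sqrt(n_items) / 10))

print(default_partition_size(1_000))      # 128 -> ~8 partitions
print(default_partition_size(100_000))    # 128 -> ~782 partitions
print(default_partition_size(4_000_000))  # 200 -> ~20,000 partitions
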


def apply_dofn_to_bundle(