diff --git a/CHANGES.md b/CHANGES.md
index 799d26dc05e..fde00b9da4c 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -86,6 +86,7 @@
 * Support the Process Environment for execution in Prism ([#33651](https://github.com/apache/beam/pull/33651))
 * Support the AnyOf Environment for execution in Prism ([#33705](https://github.com/apache/beam/pull/33705))
   * This improves support for developing Xlang pipelines, when using a compatible cross language service.
+* Partitions are now configurable for the DaskRunner in the Python SDK ([#33805](https://github.com/apache/beam/pull/33805)).

 ## Breaking Changes

diff --git a/sdks/python/apache_beam/runners/dask/dask_runner.py b/sdks/python/apache_beam/runners/dask/dask_runner.py
index cc17d9919b8..8975fcf1e13 100644
--- a/sdks/python/apache_beam/runners/dask/dask_runner.py
+++ b/sdks/python/apache_beam/runners/dask/dask_runner.py
@@ -58,6 +58,18 @@ def _parse_timeout(candidate):
     import dask
     return dask.config.no_default

+  @staticmethod
+  def _extract_bag_kwargs(dask_options: t.Dict) -> t.Dict:
+    """Parse keyword arguments for `dask.Bag`s; used in graph translation."""
+    out = {}
+
+    if npartitions := dask_options.pop('npartitions', None):
+      out['npartitions'] = npartitions
+    if partition_size := dask_options.pop('partition_size', None):
+      out['partition_size'] = partition_size
+
+    return out
+
   @classmethod
   def _add_argparse_args(cls, parser: argparse.ArgumentParser) -> None:
     parser.add_argument(
@@ -93,6 +105,21 @@ def _add_argparse_args(cls, parser: argparse.ArgumentParser) -> None:
         default=512,
         help='The number of open comms to maintain at once in the connection '
         'pool.')
+    partitions_parser = parser.add_mutually_exclusive_group()
+    partitions_parser.add_argument(
+        '--dask_npartitions',
+        dest='npartitions',
+        type=int,
+        default=None,
+        help='The desired number of `dask.Bag` partitions. When unspecified, '
+        'an educated guess is made.')
+    partitions_parser.add_argument(
+        '--dask_partition_size',
+        dest='partition_size',
+        type=int,
+        default=None,
+        help='The length of each `dask.Bag` partition. When unspecified, '
+        'an educated guess is made.')


 @dataclasses.dataclass
@@ -139,9 +166,12 @@ def metrics(self):
 class DaskRunner(BundleBasedDirectRunner):
   """Executes a pipeline on a Dask distributed client."""
   @staticmethod
-  def to_dask_bag_visitor() -> PipelineVisitor:
+  def to_dask_bag_visitor(bag_kwargs=None) -> PipelineVisitor:
     from dask import bag as db

+    if bag_kwargs is None:
+      bag_kwargs = {}
+
     @dataclasses.dataclass
     class DaskBagVisitor(PipelineVisitor):
       bags: t.Dict[AppliedPTransform, db.Bag] = dataclasses.field(
@@ -149,7 +179,7 @@ class DaskBagVisitor(PipelineVisitor):
       def visit_transform(self, transform_node: AppliedPTransform) -> None:
         op_class = TRANSLATIONS.get(transform_node.transform.__class__, NoOp)
-        op = op_class(transform_node)
+        op = op_class(transform_node, bag_kwargs=bag_kwargs)

         op_kws = {"input_bag": None, "side_inputs": None}
         inputs = list(transform_node.inputs)
@@ -195,7 +225,7 @@ def is_fnapi_compatible():
   def run_pipeline(self, pipeline, options):
     import dask

-    # TODO(alxr): Create interactive notebook support.
+    # TODO(alxmrs): Create interactive notebook support.
     if is_in_notebook():
       raise NotImplementedError('interactive support will come later!')

@@ -207,11 +237,12 @@ def run_pipeline(self, pipeline, options):

     dask_options = options.view_as(DaskOptions).get_all_options(
         drop_default=True)
+    bag_kwargs = DaskOptions._extract_bag_kwargs(dask_options)
     client = ddist.Client(**dask_options)

     pipeline.replace_all(dask_overrides())

-    dask_visitor = self.to_dask_bag_visitor()
+    dask_visitor = self.to_dask_bag_visitor(bag_kwargs)
     pipeline.visit(dask_visitor)
     # The dictionary in this visitor keeps a mapping of every Beam
     # PTransform to the equivalent Bag operation. This is highly
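For context, a minimal usage sketch of the new flags (not part of this diff): the pipeline body and option values below are purely illustrative, and running it assumes `dask[distributed]` is installed so the runner can create a client.

```python
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.runners.dask.dask_runner import DaskRunner

# The two partition flags are mutually exclusive; pass at most one.
options = PipelineOptions(['--dask_npartitions=8'])
# or: options = PipelineOptions(['--dask_partition_size=128'])

with beam.Pipeline(runner=DaskRunner(), options=options) as p:
  _ = (
      p
      | beam.Create(range(1_000))
      | beam.Map(lambda x: x * x))
```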
diff --git a/sdks/python/apache_beam/runners/dask/dask_runner_test.py b/sdks/python/apache_beam/runners/dask/dask_runner_test.py
index 66dda4a984f..afe363ba3ee 100644
--- a/sdks/python/apache_beam/runners/dask/dask_runner_test.py
+++ b/sdks/python/apache_beam/runners/dask/dask_runner_test.py
@@ -66,6 +66,25 @@ def test_parser_destinations__agree_with_dask_client(self):
       with self.subTest(f'{opt_name} in dask.distributed.Client constructor'):
         self.assertIn(opt_name, client_args)

+  def test_parser_extract_bag_kwargs__deletes_dask_kwargs(self):
+    options = PipelineOptions('--dask_npartitions 8'.split())
+    dask_options = options.view_as(DaskOptions).get_all_options()
+
+    self.assertIn('npartitions', dask_options)
+    bag_kwargs = DaskOptions._extract_bag_kwargs(dask_options)
+    self.assertNotIn('npartitions', dask_options)
+    self.assertEqual(bag_kwargs, {'npartitions': 8})
+
+  def test_parser_extract_bag_kwargs__unconfigured(self):
+    options = PipelineOptions()
+    dask_options = options.view_as(DaskOptions).get_all_options()
+
+    # It's present as a default option.
+    self.assertIn('npartitions', dask_options)
+    bag_kwargs = DaskOptions._extract_bag_kwargs(dask_options)
+    self.assertNotIn('npartitions', dask_options)
+    self.assertEqual(bag_kwargs, {})
+

 class DaskRunnerRunPipelineTest(unittest.TestCase):
   """Test class used to introspect the dask runner via a debugger."""
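Background for the evaluator change that follows: `dask.bag.from_sequence` accepts either `npartitions` or `partition_size`, which are two views of the same knob. A quick illustration of the semantics (not part of this diff; assumes a local `dask` install):

```python
import dask.bag as db

items = list(range(1_000))

# Fix the number of partitions; Dask derives each partition's length.
bag_a = db.from_sequence(items, npartitions=8)
print(bag_a.npartitions)  # 8

# Fix each partition's length; Dask derives the number of partitions.
bag_b = db.from_sequence(items, partition_size=128)
print(bag_b.npartitions)  # ceil(1000 / 128) == 8
```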
diff --git a/sdks/python/apache_beam/runners/dask/transform_evaluator.py b/sdks/python/apache_beam/runners/dask/transform_evaluator.py
index e72ebcce8b1..7cad1fe4045 100644
--- a/sdks/python/apache_beam/runners/dask/transform_evaluator.py
+++ b/sdks/python/apache_beam/runners/dask/transform_evaluator.py
@@ -22,6 +22,7 @@
 """
 import abc
 import dataclasses
+import logging
 import math
 import typing as t
 from dataclasses import field
@@ -52,6 +53,8 @@
 # Value types for PCollections (possibly Windowed Values).
 PCollVal = t.Union[WindowedValue, t.Any]

+_LOGGER = logging.getLogger(__name__)
+

 def get_windowed_value(item: t.Any, window_fn: WindowFn) -> WindowedValue:
   """Wraps a value (item) inside a Window."""
@@ -127,8 +130,11 @@ class DaskBagOp(abc.ABC):
   Attributes
     applied: The underlying `AppliedPTransform` which holds the code for the
       target operation.
+    bag_kwargs: (optional) Keyword arguments applied to input bags, usually
+      from the pipeline's `DaskOptions`.
   """
   applied: AppliedPTransform
+  bag_kwargs: t.Dict = dataclasses.field(default_factory=dict)

   @property
   def transform(self):
@@ -151,10 +157,28 @@ def apply(self, input_bag: OpInput, side_inputs: OpSide = None) -> db.Bag:
     assert input_bag is None, 'Create expects no input!'
     original_transform = t.cast(_Create, self.transform)
     items = original_transform.values
+
+    npartitions = self.bag_kwargs.get('npartitions')
+    partition_size = self.bag_kwargs.get('partition_size')
+    if npartitions and partition_size:
+      raise ValueError(
+          f'Please specify either `dask_npartitions` or '
+          f'`dask_partition_size` but not both: '
+          f'{npartitions=}, {partition_size=}.')
+    if not npartitions and not partition_size:
+      # `partition_size` is inversely related to `npartitions`.
+      # Ideal "chunk sizes" in Dask are around 10-100 MB.
+      # Let's hope ~128 items per partition is around this
+      # memory overhead.
+      default_size = 128
+      partition_size = max(default_size, math.ceil(math.sqrt(len(items)) / 10))
+      if partition_size == default_size:
+        _LOGGER.warning(
+            'The new default partition size is %d; it used to be 1 in '
+            'previous DaskRunner versions.', default_size)
+
     return db.from_sequence(
-        items,
-        partition_size=max(
-            1, math.ceil(math.sqrt(len(items)) / math.sqrt(100))))
+        items, npartitions=npartitions, partition_size=partition_size)

   def apply_dofn_to_bundle(
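A sanity check on the fallback heuristic above (not part of this diff; `default_partition_size` is an illustrative helper that mirrors the new code): the `sqrt`-based term only exceeds the 128-item floor once a `Create` holds more than ~1.64 million items, so most pipelines will land on the new default and log the warning once.

```python
import math

def default_partition_size(n_items: int, default_size: int = 128) -> int:
  """Mirrors the fallback in Create.apply when neither flag is set."""
  return max(default_size, math.ceil(math.sqrt(n_items) / 10))

print(default_partition_size(1_000))      # 128 (sqrt term: 4)
print(default_partition_size(1_000_000))  # 128 (sqrt term: 100)
print(default_partition_size(2_000_000))  # 142 (sqrt term: ceil(1414.2 / 10))
```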