From 192461df4af04e0f61f66d945975f9ad5674d314 Mon Sep 17 00:00:00 2001 From: Harvey Devereux Date: Tue, 21 Jan 2025 15:31:55 +0000 Subject: [PATCH 1/6] Add correlation_kwargs parser for CLI --- janus_core/calculations/md.py | 2 +- janus_core/cli/md.py | 9 ++- janus_core/cli/types.py | 15 +++++ janus_core/cli/utils.py | 85 +++++++++++++++++++++++++++- janus_core/helpers/janus_types.py | 4 +- janus_core/processing/observables.py | 10 ++-- tests/test_md_cli.py | 10 ++++ 7 files changed, 125 insertions(+), 10 deletions(-) diff --git a/janus_core/calculations/md.py b/janus_core/calculations/md.py index a5bfde26..576008bf 100644 --- a/janus_core/calculations/md.py +++ b/janus_core/calculations/md.py @@ -153,7 +153,7 @@ class MolecularDynamics(BaseCalculation): files. Default is {}. post_process_kwargs : PostProcessKwargs | None Keyword arguments to control post-processing operations. - correlation_kwargs : CorrelationKwargs | None + correlation_kwargs : list[CorrelationKwargs] | None Keyword arguments to control on-the-fly correlations. seed : int | None Random seed used by numpy.random and random functions, such as in Langevin. diff --git a/janus_core/cli/md.py b/janus_core/cli/md.py index eb435ace..2a39884c 100644 --- a/janus_core/cli/md.py +++ b/janus_core/cli/md.py @@ -11,6 +11,7 @@ from janus_core.cli.types import ( Architecture, CalcKwargs, + CorrelationKwargs, Device, EnsembleKwargs, LogPath, @@ -22,7 +23,7 @@ Summary, WriteKwargs, ) -from janus_core.cli.utils import yaml_converter_callback +from janus_core.cli.utils import parse_correlation_kwargs, yaml_converter_callback app = Typer() @@ -173,6 +174,7 @@ def md( ] = None, write_kwargs: WriteKwargs = None, post_process_kwargs: PostProcessKwargs = None, + correlation_kwargs: CorrelationKwargs = None, seed: Annotated[ int | None, Option(help="Random seed for numpy.random and random functions."), @@ -289,6 +291,8 @@ def md( files. Default is {}. post_process_kwargs : Optional[PostProcessKwargs] Kwargs to pass to post-processing. + correlation_kwargs : Optional[CorrelationKwargs] + Kwrag to pass for on-the-fly correlations. seed : Optional[int] Random seed used by numpy.random and random functions, such as in Langevin. Default is None. @@ -317,7 +321,6 @@ def md( # Check options from configuration file are all valid check_config(ctx) - [ read_kwargs, calc_kwargs, @@ -335,6 +338,7 @@ def md( post_process_kwargs, ] ) + correlation_kwargs = parse_correlation_kwargs(correlation_kwargs) if ensemble not in get_args(Ensembles): raise ValueError(f"ensemble must be in {get_args(Ensembles)}") @@ -393,6 +397,7 @@ def md( "temp_time": temp_time, "write_kwargs": write_kwargs, "post_process_kwargs": post_process_kwargs, + "correlation_kwargs": correlation_kwargs, "seed": seed, } diff --git a/janus_core/cli/types.py b/janus_core/cli/types.py index c8e73842..25416b9e 100644 --- a/janus_core/cli/types.py +++ b/janus_core/cli/types.py @@ -229,6 +229,21 @@ def __str__(self) -> str: ), ] +CorrelationKwargs = Annotated[ + TyperDict | None, + Option( + parser=parse_dict_class, + help=( + """ + Keyword arguments to pass to md for on-the-fly correlations. Must be + passed as a list of dictionaries wrapped in quotes, e.g. + "[{'key' : values}]". + """ + ), + metavar="DICT", + ), +] + LogPath = Annotated[ Path | None, Option( diff --git a/janus_core/cli/utils.py b/janus_core/cli/utils.py index c92718f9..d7e632f2 100644 --- a/janus_core/cli/utils.py +++ b/janus_core/cli/utils.py @@ -15,7 +15,7 @@ from ase import Atoms from typer import Context - from janus_core.cli.types import TyperDict + from janus_core.cli.types import CorrelationKwargs, TyperDict from janus_core.helpers.janus_types import ( Architectures, ASEReadArgs, @@ -23,6 +23,8 @@ MaybeSequence, ) +from janus_core.processing.observables import Observable, Stress, Velocity + def dict_paths_to_strs(dictionary: dict) -> None: """ @@ -309,3 +311,84 @@ def check_config(ctx: Context) -> None: # Check options individually so can inform user of specific issue if option not in ctx.params: raise ValueError(f"'{option}' in configuration file is not a valid option") + + +def _select_observable(name: str, kwargs: dict) -> Observable: + """ + Select an Observable from a string. + + Parameters + ---------- + name : str + The name of an Observable to convert. + kwargs : dict + A list of kwargs of the Observables init. + + Returns + ------- + Observable + The selected observable. + """ + if name.lower() == "velocity": + return Velocity(**kwargs) + if name.lower() == "stress": + return Stress(**kwargs) + raise ValueError(f"Observable {name} is not valid") + + +def parse_correlation_kwargs(kwargs: CorrelationKwargs) -> list[dict]: + """ + Parse CLI CorrelationKwargs to md correlation_kwargs. + + Parameters + ---------- + kwargs : CorrelationKwargs + CLI correlation keyword options. + + Returns + ------- + List[dict] + The parsed correlation_kwargs for md. + """ + parsed_kwargs = [] + for name, kwarg in kwargs.value.items(): + if "a" not in kwarg and "b" not in kwarg: + raise ValueError("At least on observable must be supplied as 'a' or 'b'") + + if "b" not in kwarg: + a = kwarg["a"] + b = a + elif "a" not in kwarg: + a = kwarg["b"] + b = a + else: + a = kwarg["a"] + b = kwarg["b"] + + a_kwargs = kwarg["a_kwargs"] if "a_kwargs" in kwarg else {} + b_kwargs = kwarg["b_kwargs"] if "b_kwargs" in kwarg else {} + + if a_kwargs and b_kwargs == ".": + b_kwargs = a_kwargs + elif b_kwargs and a_kwargs == ".": + a_kwargs = b_kwargs + + blocks = kwarg["blocks"] if "blocks" in kwarg else 1 + points = kwarg["points"] if "blocks" in kwarg else 1 + averaging = kwarg["blocks"] if "blocks" in kwarg else 1 + update_frequency = ( + kwarg["update_frequency"] if "update_frequency" in kwarg else 1 + ) + + parsed_kwargs.append( + { + "name": name, + "a": _select_observable(a, a_kwargs), + "b": _select_observable(b, b_kwargs), + "blocks": blocks, + "points": points, + "averaging": averaging, + "update_frequency": update_frequency, + } + ) + return parsed_kwargs diff --git a/janus_core/helpers/janus_types.py b/janus_core/helpers/janus_types.py index f31a4b8b..98d3b8ce 100644 --- a/janus_core/helpers/janus_types.py +++ b/janus_core/helpers/janus_types.py @@ -80,7 +80,7 @@ class PostProcessKwargs(TypedDict, total=False): vaf_output_file: PathLike | None -class CorrelationKwargs(TypedDict, total=True): +class Correlation(TypedDict, total=True): """Arguments for on-the-fly correlations .""" #: observable a in , with optional args and kwargs @@ -99,6 +99,8 @@ class CorrelationKwargs(TypedDict, total=True): update_frequency: int +CorrelationKwargs = list[Correlation] + # eos_names from ase.eos EoSNames = Literal[ "sj", diff --git a/janus_core/processing/observables.py b/janus_core/processing/observables.py index 1110a9d5..348892d6 100644 --- a/janus_core/processing/observables.py +++ b/janus_core/processing/observables.py @@ -237,7 +237,7 @@ class Velocity(Observable, ComponentMixin): def __init__( self, *, - components: list[str], + components: list[str] | None = None, atoms_slice: list[int] | SliceLike | None = None, ): """ @@ -245,13 +245,13 @@ def __init__( Parameters ---------- - components : list[str] - Symbols for tensor components, x, y, and z. - atoms_slice : Union[list[int], SliceLike] + components : list[str] | None + Symbols for returned velocity components, x, y, and z (default is all). + atoms_slice : Union[list[int], SliceLike, None] List or slice of atoms to observe velocities from. """ ComponentMixin.__init__(self, components={"x": 0, "y": 1, "z": 2}) - self.components = components + self.components = components if components else ["x", "y", "z"] Observable.__init__(self, atoms_slice) diff --git a/tests/test_md_cli.py b/tests/test_md_cli.py index 970c84cc..eaae1430 100644 --- a/tests/test_md_cli.py +++ b/tests/test_md_cli.py @@ -65,6 +65,7 @@ def test_md(ensemble): traj_path = Path(f"{file_prefix[ensemble]}traj.extxyz").absolute() rdf_path = Path(f"{file_prefix[ensemble]}rdf.dat").absolute() vaf_path = Path(f"{file_prefix[ensemble]}vaf.dat").absolute() + cor_path = Path(f"{file_prefix[ensemble]}cor.dat").absolute() log_path = Path(f"{file_prefix[ensemble]}md-log.yml").absolute() summary_path = Path(f"{file_prefix[ensemble]}md-summary.yml").absolute() @@ -74,6 +75,7 @@ def test_md(ensemble): assert not traj_path.exists() assert not rdf_path.exists() assert not vaf_path.exists() + assert not cor_path.exists() assert not log_path.exists() assert not summary_path.exists() @@ -96,6 +98,12 @@ def test_md(ensemble): 2, "--post-process-kwargs", "{'rdf_compute': True, 'vaf_compute': True}", + "--correlation-kwargs", + ( + "[{'a': Velocity(), 'b': Velocity(), 'name':" + "'vaf', 'blocks': 1, 'points': 1, 'averaging'" + ": 1, 'update_frequency': 1}]" + ), ], ) @@ -107,6 +115,7 @@ def test_md(ensemble): assert traj_path.exists() assert rdf_path.exists() assert vaf_path.exists() + assert cor_path.exists() assert log_path.exists() assert summary_path.exists() @@ -144,6 +153,7 @@ def test_md(ensemble): traj_path.unlink(missing_ok=True) rdf_path.unlink(missing_ok=True) vaf_path.unlink(missing_ok=True) + cor_path.unlink(missing_ok=True) log_path.unlink(missing_ok=True) summary_path.unlink(missing_ok=True) clear_log_handlers() From d188d54d8cddd4960b45369be4dc7d5b15c73d70 Mon Sep 17 00:00:00 2001 From: Harvey Devereux Date: Tue, 21 Jan 2025 15:42:21 +0000 Subject: [PATCH 2/6] Test ditto --- janus_core/cli/utils.py | 4 ++++ tests/test_md_cli.py | 6 +++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/janus_core/cli/utils.py b/janus_core/cli/utils.py index d7e632f2..6089a6a9 100644 --- a/janus_core/cli/utils.py +++ b/janus_core/cli/utils.py @@ -355,6 +355,7 @@ def parse_correlation_kwargs(kwargs: CorrelationKwargs) -> list[dict]: if "a" not in kwarg and "b" not in kwarg: raise ValueError("At least on observable must be supplied as 'a' or 'b'") + # Accept on Observable to be replicated. if "b" not in kwarg: a = kwarg["a"] b = a @@ -368,6 +369,9 @@ def parse_correlation_kwargs(kwargs: CorrelationKwargs) -> list[dict]: a_kwargs = kwarg["a_kwargs"] if "a_kwargs" in kwarg else {} b_kwargs = kwarg["b_kwargs"] if "b_kwargs" in kwarg else {} + # Accept "." in place of one kwargs to repeat. + if a_kwargs == "." and b_kwargs == ".": + raise ValueError("a_kwargs and b_kwargs cannot 'ditto' eachother") if a_kwargs and b_kwargs == ".": b_kwargs = a_kwargs elif b_kwargs and a_kwargs == ".": diff --git a/tests/test_md_cli.py b/tests/test_md_cli.py index eaae1430..a203cd84 100644 --- a/tests/test_md_cli.py +++ b/tests/test_md_cli.py @@ -100,9 +100,9 @@ def test_md(ensemble): "{'rdf_compute': True, 'vaf_compute': True}", "--correlation-kwargs", ( - "[{'a': Velocity(), 'b': Velocity(), 'name':" - "'vaf', 'blocks': 1, 'points': 1, 'averaging'" - ": 1, 'update_frequency': 1}]" + "{'vaf': {'a': 'Velocity'}," + " 'vaf_x': {'a': 'velocity'," + "'a_kwargs': {'components': ['x']}, 'b_kwargs': '.'}}" ), ], ) From 4c9cc2d13f46aa2f5e05207a39b47c8ccebceb6c Mon Sep 17 00:00:00 2001 From: Harvey Devereux Date: Tue, 21 Jan 2025 15:53:51 +0000 Subject: [PATCH 3/6] Fix bug --- janus_core/cli/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/janus_core/cli/utils.py b/janus_core/cli/utils.py index 6089a6a9..322bd5ed 100644 --- a/janus_core/cli/utils.py +++ b/janus_core/cli/utils.py @@ -378,8 +378,8 @@ def parse_correlation_kwargs(kwargs: CorrelationKwargs) -> list[dict]: a_kwargs = b_kwargs blocks = kwarg["blocks"] if "blocks" in kwarg else 1 - points = kwarg["points"] if "blocks" in kwarg else 1 - averaging = kwarg["blocks"] if "blocks" in kwarg else 1 + points = kwarg["points"] if "points" in kwarg else 1 + averaging = kwarg["averaging"] if "averaging" in kwarg else 1 update_frequency = ( kwarg["update_frequency"] if "update_frequency" in kwarg else 1 ) From c9fbd3bebcf69e7819e346afe49d7b9f7651ec15 Mon Sep 17 00:00:00 2001 From: Harvey Devereux Date: Wed, 22 Jan 2025 09:07:53 +0000 Subject: [PATCH 4/6] Use getattr --- janus_core/cli/utils.py | 29 +++-------------------------- 1 file changed, 3 insertions(+), 26 deletions(-) diff --git a/janus_core/cli/utils.py b/janus_core/cli/utils.py index 322bd5ed..f4609017 100644 --- a/janus_core/cli/utils.py +++ b/janus_core/cli/utils.py @@ -23,7 +23,7 @@ MaybeSequence, ) -from janus_core.processing.observables import Observable, Stress, Velocity +from janus_core.processing import observables def dict_paths_to_strs(dictionary: dict) -> None: @@ -313,29 +313,6 @@ def check_config(ctx: Context) -> None: raise ValueError(f"'{option}' in configuration file is not a valid option") -def _select_observable(name: str, kwargs: dict) -> Observable: - """ - Select an Observable from a string. - - Parameters - ---------- - name : str - The name of an Observable to convert. - kwargs : dict - A list of kwargs of the Observables init. - - Returns - ------- - Observable - The selected observable. - """ - if name.lower() == "velocity": - return Velocity(**kwargs) - if name.lower() == "stress": - return Stress(**kwargs) - raise ValueError(f"Observable {name} is not valid") - - def parse_correlation_kwargs(kwargs: CorrelationKwargs) -> list[dict]: """ Parse CLI CorrelationKwargs to md correlation_kwargs. @@ -387,8 +364,8 @@ def parse_correlation_kwargs(kwargs: CorrelationKwargs) -> list[dict]: parsed_kwargs.append( { "name": name, - "a": _select_observable(a, a_kwargs), - "b": _select_observable(b, b_kwargs), + "a": getattr(observables, a)(**a_kwargs), + "b": getattr(observables, b)(**b_kwargs), "blocks": blocks, "points": points, "averaging": averaging, From a29fa35b34f079ff36669abee5c19cda2abc64ce Mon Sep 17 00:00:00 2001 From: Harvey Devereux Date: Fri, 24 Jan 2025 10:38:25 +0000 Subject: [PATCH 5/6] Set defaults, test in test_correlator --- janus_core/cli/utils.py | 47 ++++++++++++----------------- janus_core/processing/correlator.py | 24 +++++++-------- tests/test_correlator.py | 11 +++++++ tests/test_md_cli.py | 2 +- 4 files changed, 44 insertions(+), 40 deletions(-) diff --git a/janus_core/cli/utils.py b/janus_core/cli/utils.py index f4609017..dcbe9a5e 100644 --- a/janus_core/cli/utils.py +++ b/janus_core/cli/utils.py @@ -328,23 +328,23 @@ def parse_correlation_kwargs(kwargs: CorrelationKwargs) -> list[dict]: The parsed correlation_kwargs for md. """ parsed_kwargs = [] - for name, kwarg in kwargs.value.items(): - if "a" not in kwarg and "b" not in kwarg: + for name, cli_kwargs in kwargs.value.items(): + if "a" not in cli_kwargs and "b" not in cli_kwargs: raise ValueError("At least on observable must be supplied as 'a' or 'b'") # Accept on Observable to be replicated. - if "b" not in kwarg: - a = kwarg["a"] + if "b" not in cli_kwargs: + a = cli_kwargs["a"] b = a - elif "a" not in kwarg: - a = kwarg["b"] + elif "a" not in cli_kwargs: + a = cli_kwargs["b"] b = a else: - a = kwarg["a"] - b = kwarg["b"] + a = cli_kwargs["a"] + b = cli_kwargs["b"] - a_kwargs = kwarg["a_kwargs"] if "a_kwargs" in kwarg else {} - b_kwargs = kwarg["b_kwargs"] if "b_kwargs" in kwarg else {} + a_kwargs = cli_kwargs["a_kwargs"] if "a_kwargs" in cli_kwargs else {} + b_kwargs = cli_kwargs["b_kwargs"] if "b_kwargs" in cli_kwargs else {} # Accept "." in place of one kwargs to repeat. if a_kwargs == "." and b_kwargs == ".": @@ -354,22 +354,15 @@ def parse_correlation_kwargs(kwargs: CorrelationKwargs) -> list[dict]: elif b_kwargs and a_kwargs == ".": a_kwargs = b_kwargs - blocks = kwarg["blocks"] if "blocks" in kwarg else 1 - points = kwarg["points"] if "points" in kwarg else 1 - averaging = kwarg["averaging"] if "averaging" in kwarg else 1 - update_frequency = ( - kwarg["update_frequency"] if "update_frequency" in kwarg else 1 - ) + cor_kwargs = { + "name": name, + "a": getattr(observables, a)(**a_kwargs), + "b": getattr(observables, b)(**b_kwargs), + } - parsed_kwargs.append( - { - "name": name, - "a": getattr(observables, a)(**a_kwargs), - "b": getattr(observables, b)(**b_kwargs), - "blocks": blocks, - "points": points, - "averaging": averaging, - "update_frequency": update_frequency, - } - ) + for optional in ["blocks", "points", "averaging", "update_frequency"]: + if optional in cli_kwargs: + cor_kwargs[optional] = cli_kwargs[optional] + + parsed_kwargs.append(cor_kwargs) return parsed_kwargs diff --git a/janus_core/processing/correlator.py b/janus_core/processing/correlator.py index 6461c4c8..b7591bb2 100644 --- a/janus_core/processing/correlator.py +++ b/janus_core/processing/correlator.py @@ -253,13 +253,13 @@ class Correlation: Observable for b. name : str Name of correlation. - blocks : int + blocks : int, default 1 Number of correlation blocks. - points : int + points : int, default 1 Number of points per block. - averaging : int + averaging : int, default 1 Averaging window per block level. - update_frequency : int + update_frequency : int, default 1 Frequency to update the correlation, md steps. """ @@ -269,10 +269,10 @@ def __init__( a: Observable, b: Observable, name: str, - blocks: int, - points: int, - averaging: int, - update_frequency: int, + blocks: int = 1, + points: int = 1, + averaging: int = 1, + update_frequency: int = 1, ) -> None: """ Initialise a correlation. @@ -285,13 +285,13 @@ def __init__( Observable for b. name : str Name of correlation. - blocks : int + blocks : int, default 1 Number of correlation blocks. - points : int + points : int, default 1 Number of points per block. - averaging : int + averaging : int, default 1 Averaging window per block level. - update_frequency : int + update_frequency : int, default 1 Frequency to update the correlation, md steps. """ self.name = name diff --git a/tests/test_correlator.py b/tests/test_correlator.py index ded1aba0..ee61dce0 100644 --- a/tests/test_correlator.py +++ b/tests/test_correlator.py @@ -113,6 +113,11 @@ def test_vaf(tmp_path): "averaging": 1, "update_frequency": 1, }, + { + "a": Velocity(), + "b": Velocity(), + "name": "vaf_default", + }, ], write_kwargs={"invalidate_calc": False}, ) @@ -133,6 +138,12 @@ def test_vaf(tmp_path): assert vaf_na * 3 == approx(vaf_post[1][0], rel=1e-5) assert vaf_cl * 3 == approx(vaf_post[1][1], rel=1e-5) + # Default arguments are equivalent to mean square velocities. + v = np.mean([np.mean(atoms.get_velocities() ** 2) for atoms in traj]) + vaf_default = vaf["vaf_default"] + assert len(vaf_default["value"]) == 1 + assert v == approx(vaf_default["value"][0], rel=1e-5) + def test_md_correlations(tmp_path): """Test correlations as part of MD cycle.""" diff --git a/tests/test_md_cli.py b/tests/test_md_cli.py index a203cd84..683cc9e1 100644 --- a/tests/test_md_cli.py +++ b/tests/test_md_cli.py @@ -101,7 +101,7 @@ def test_md(ensemble): "--correlation-kwargs", ( "{'vaf': {'a': 'Velocity'}," - " 'vaf_x': {'a': 'velocity'," + " 'vaf_x': {'a': 'Velocity'," "'a_kwargs': {'components': ['x']}, 'b_kwargs': '.'}}" ), ], From 3703c087eb97091880edf7cfb0c56a6391046966 Mon Sep 17 00:00:00 2001 From: Harvey Devereux Date: Fri, 24 Jan 2025 11:03:32 +0000 Subject: [PATCH 6/6] Check none --- janus_core/cli/md.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/janus_core/cli/md.py b/janus_core/cli/md.py index ed3af31a..6400a10e 100644 --- a/janus_core/cli/md.py +++ b/janus_core/cli/md.py @@ -338,7 +338,9 @@ def md( post_process_kwargs, ] ) - correlation_kwargs = parse_correlation_kwargs(correlation_kwargs) + + if correlation_kwargs: + correlation_kwargs = parse_correlation_kwargs(correlation_kwargs) if ensemble not in get_args(Ensembles): raise ValueError(f"ensemble must be in {get_args(Ensembles)}")