Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Correlator CLI and defaults #388

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion janus_core/calculations/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ class MolecularDynamics(BaseCalculation):
files. Default is {}.
post_process_kwargs
Keyword arguments to control post-processing operations.
correlation_kwargs
correlation_kwargs : list[CorrelationKwargs] | None
Keyword arguments to control on-the-fly correlations.
seed
Random seed used by numpy.random and random functions, such as in Langevin.
Expand Down
13 changes: 10 additions & 3 deletions janus_core/cli/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from janus_core.cli.types import (
Architecture,
CalcKwargs,
CorrelationKwargs,
Device,
EnsembleKwargs,
LogPath,
Expand All @@ -22,7 +23,7 @@
Summary,
WriteKwargs,
)
from janus_core.cli.utils import yaml_converter_callback
from janus_core.cli.utils import parse_correlation_kwargs, yaml_converter_callback

app = Typer()

Expand Down Expand Up @@ -173,6 +174,7 @@ def md(
] = None,
write_kwargs: WriteKwargs = None,
post_process_kwargs: PostProcessKwargs = None,
correlation_kwargs: CorrelationKwargs = None,
seed: Annotated[
int | None,
Option(help="Random seed for numpy.random and random functions."),
Expand Down Expand Up @@ -289,7 +291,9 @@ def md(
files. Default is {}.
post_process_kwargs
Kwargs to pass to post-processing.
seed
correlation_kwargs : Optional[CorrelationKwargs]
Kwrag to pass for on-the-fly correlations.
seed : Optional[int]
Random seed used by numpy.random and random functions, such as in Langevin.
Default is None.
log
Expand Down Expand Up @@ -317,7 +321,6 @@ def md(

# Check options from configuration file are all valid
check_config(ctx)

[
read_kwargs,
calc_kwargs,
Expand All @@ -336,6 +339,9 @@ def md(
]
)

if correlation_kwargs:
correlation_kwargs = parse_correlation_kwargs(correlation_kwargs)

if ensemble not in get_args(Ensembles):
raise ValueError(f"ensemble must be in {get_args(Ensembles)}")

Expand Down Expand Up @@ -393,6 +399,7 @@ def md(
"temp_time": temp_time,
"write_kwargs": write_kwargs,
"post_process_kwargs": post_process_kwargs,
"correlation_kwargs": correlation_kwargs,
"seed": seed,
}

Expand Down
15 changes: 15 additions & 0 deletions janus_core/cli/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,21 @@ def __str__(self) -> str:
),
]

CorrelationKwargs = Annotated[
TyperDict | None,
Option(
parser=parse_dict_class,
help=(
"""
Keyword arguments to pass to md for on-the-fly correlations. Must be
passed as a list of dictionaries wrapped in quotes, e.g.
"[{'key' : values}]".
"""
),
metavar="DICT",
),
]

LogPath = Annotated[
Path | None,
Option(
Expand Down
59 changes: 58 additions & 1 deletion janus_core/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,16 @@
from ase import Atoms
from typer import Context

from janus_core.cli.types import TyperDict
from janus_core.cli.types import CorrelationKwargs, TyperDict
from janus_core.helpers.janus_types import (
Architectures,
ASEReadArgs,
Devices,
MaybeSequence,
)

from janus_core.processing import observables


def dict_paths_to_strs(dictionary: dict) -> None:
"""
Expand Down Expand Up @@ -309,3 +311,58 @@ def check_config(ctx: Context) -> None:
# Check options individually so can inform user of specific issue
if option not in ctx.params:
raise ValueError(f"'{option}' in configuration file is not a valid option")


def parse_correlation_kwargs(kwargs: CorrelationKwargs) -> list[dict]:
"""
Parse CLI CorrelationKwargs to md correlation_kwargs.

Parameters
----------
kwargs : CorrelationKwargs
CLI correlation keyword options.

Returns
-------
List[dict]
The parsed correlation_kwargs for md.
"""
parsed_kwargs = []
for name, cli_kwargs in kwargs.value.items():
if "a" not in cli_kwargs and "b" not in cli_kwargs:
raise ValueError("At least on observable must be supplied as 'a' or 'b'")

# Accept on Observable to be replicated.
if "b" not in cli_kwargs:
a = cli_kwargs["a"]
b = a
elif "a" not in cli_kwargs:
a = cli_kwargs["b"]
b = a
else:
a = cli_kwargs["a"]
b = cli_kwargs["b"]

a_kwargs = cli_kwargs["a_kwargs"] if "a_kwargs" in cli_kwargs else {}
b_kwargs = cli_kwargs["b_kwargs"] if "b_kwargs" in cli_kwargs else {}

# Accept "." in place of one kwargs to repeat.
if a_kwargs == "." and b_kwargs == ".":
raise ValueError("a_kwargs and b_kwargs cannot 'ditto' eachother")
if a_kwargs and b_kwargs == ".":
b_kwargs = a_kwargs
elif b_kwargs and a_kwargs == ".":
a_kwargs = b_kwargs

cor_kwargs = {
"name": name,
"a": getattr(observables, a)(**a_kwargs),
"b": getattr(observables, b)(**b_kwargs),
}

for optional in ["blocks", "points", "averaging", "update_frequency"]:
if optional in cli_kwargs:
cor_kwargs[optional] = cli_kwargs[optional]

parsed_kwargs.append(cor_kwargs)
return parsed_kwargs
4 changes: 3 additions & 1 deletion janus_core/helpers/janus_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class PostProcessKwargs(TypedDict, total=False):
vaf_output_file: PathLike | None


class CorrelationKwargs(TypedDict, total=True):
class Correlation(TypedDict, total=True):
"""Arguments for on-the-fly correlations <ab>."""

#: observable a in <ab>, with optional args and kwargs
Expand All @@ -99,6 +99,8 @@ class CorrelationKwargs(TypedDict, total=True):
update_frequency: int


CorrelationKwargs = list[Correlation]

# eos_names from ase.eos
EoSNames = Literal[
"sj",
Expand Down
24 changes: 12 additions & 12 deletions janus_core/processing/correlator.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,13 +253,13 @@ class Correlation:
Observable for b.
name
Name of correlation.
blocks
blocks : int, default 1
Number of correlation blocks.
points
points : int, default 1
Number of points per block.
averaging
averaging : int, default 1
Averaging window per block level.
update_frequency
update_frequency : int, default 1
Frequency to update the correlation, md steps.
"""

Expand All @@ -269,10 +269,10 @@ def __init__(
a: Observable,
b: Observable,
name: str,
blocks: int,
points: int,
averaging: int,
update_frequency: int,
blocks: int = 1,
points: int = 1,
averaging: int = 1,
update_frequency: int = 1,
) -> None:
"""
Initialise a correlation.
Expand All @@ -285,13 +285,13 @@ def __init__(
Observable for b.
name
Name of correlation.
blocks
blocks : int, default 1
Number of correlation blocks.
points
points : int, default 1
Number of points per block.
averaging
averaging : int, default 1
Averaging window per block level.
update_frequency
update_frequency : int, default 1
Frequency to update the correlation, md steps.
"""
self.name = name
Expand Down
10 changes: 5 additions & 5 deletions janus_core/processing/observables.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,21 +237,21 @@ class Velocity(Observable, ComponentMixin):
def __init__(
self,
*,
components: list[str],
components: list[str] | None = None,
atoms_slice: list[int] | SliceLike | None = None,
):
"""
Initialise the observable from a symbolic str component and atom index.

Parameters
----------
components
Symbols for tensor components, x, y, and z.
atoms_slice
components : list[str] | None
Symbols for returned velocity components, x, y, and z (default is all).
atoms_slice : Union[list[int], SliceLike, None]
List or slice of atoms to observe velocities from.
"""
ComponentMixin.__init__(self, components={"x": 0, "y": 1, "z": 2})
self.components = components
self.components = components if components else ["x", "y", "z"]

Observable.__init__(self, atoms_slice)

Expand Down
11 changes: 11 additions & 0 deletions tests/test_correlator.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ def test_vaf(tmp_path):
"averaging": 1,
"update_frequency": 1,
},
{
"a": Velocity(),
"b": Velocity(),
"name": "vaf_default",
},
],
write_kwargs={"invalidate_calc": False},
)
Expand All @@ -133,6 +138,12 @@ def test_vaf(tmp_path):
assert vaf_na * 3 == approx(vaf_post[1][0], rel=1e-5)
assert vaf_cl * 3 == approx(vaf_post[1][1], rel=1e-5)

# Default arguments are equivalent to mean square velocities.
v = np.mean([np.mean(atoms.get_velocities() ** 2) for atoms in traj])
vaf_default = vaf["vaf_default"]
assert len(vaf_default["value"]) == 1
assert v == approx(vaf_default["value"][0], rel=1e-5)


def test_md_correlations(tmp_path):
"""Test correlations as part of MD cycle."""
Expand Down
10 changes: 10 additions & 0 deletions tests/test_md_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def test_md(ensemble):
traj_path = Path(f"{file_prefix[ensemble]}traj.extxyz").absolute()
rdf_path = Path(f"{file_prefix[ensemble]}rdf.dat").absolute()
vaf_path = Path(f"{file_prefix[ensemble]}vaf.dat").absolute()
cor_path = Path(f"{file_prefix[ensemble]}cor.dat").absolute()
log_path = Path(f"{file_prefix[ensemble]}md-log.yml").absolute()
summary_path = Path(f"{file_prefix[ensemble]}md-summary.yml").absolute()

Expand All @@ -64,6 +65,7 @@ def test_md(ensemble):
assert not traj_path.exists()
assert not rdf_path.exists()
assert not vaf_path.exists()
assert not cor_path.exists()
assert not log_path.exists()
assert not summary_path.exists()

Expand All @@ -86,6 +88,12 @@ def test_md(ensemble):
2,
"--post-process-kwargs",
"{'rdf_compute': True, 'vaf_compute': True}",
"--correlation-kwargs",
(
"{'vaf': {'a': 'Velocity'},"
" 'vaf_x': {'a': 'Velocity',"
"'a_kwargs': {'components': ['x']}, 'b_kwargs': '.'}}"
),
],
)

Expand All @@ -97,6 +105,7 @@ def test_md(ensemble):
assert traj_path.exists()
assert rdf_path.exists()
assert vaf_path.exists()
assert cor_path.exists()
assert log_path.exists()
assert summary_path.exists()

Expand Down Expand Up @@ -134,6 +143,7 @@ def test_md(ensemble):
traj_path.unlink(missing_ok=True)
rdf_path.unlink(missing_ok=True)
vaf_path.unlink(missing_ok=True)
cor_path.unlink(missing_ok=True)
log_path.unlink(missing_ok=True)
summary_path.unlink(missing_ok=True)
clear_log_handlers()
Expand Down
Loading