Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add calculator lists #346

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions janus_core/calculations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,17 @@ class BaseCalculation(FileNameMixin):
struct_path : Optional[PathLike]
Path of structure to simulate. Required if `struct` is None.
Default is None.
arch : Architectures
arch : MaybeSequence[Architectures]
MLIP architecture to use for calculations. Default is "mace_mp".
device : Devices
device : MaybeSequence[Devices]
Device to run model on. Default is "cpu".
model_path : Optional[PathLike]
model_path : Optional[MaybeSequence[PathLike]]
Path to MLIP model. Default is `None`.
read_kwargs : ASEReadArgs
Keyword arguments to pass to ase.io.read. Default is {}.
sequence_allowed : bool
Whether a sequence of Atoms objects is allowed. Default is True.
calc_kwargs : Optional[dict[str, Any]]
calc_kwargs : Optional[MaybeSequence[dict[str, Any]]]
Keyword arguments to pass to the selected calculator. Default is {}.
set_calc : Optional[bool]
Whether to set (new) calculators for structures. Default is None.
Expand Down Expand Up @@ -73,12 +73,12 @@ def __init__(
calc_name: str = "base",
struct: Optional[MaybeSequence[Atoms]] = None,
struct_path: Optional[PathLike] = None,
arch: Architectures = "mace_mp",
device: Devices = "cpu",
model_path: Optional[PathLike] = None,
arch: MaybeSequence[Architectures] = "mace_mp",
device: MaybeSequence[Devices] = "cpu",
model_path: Optional[MaybeSequence[PathLike]] = None,
read_kwargs: Optional[ASEReadArgs] = None,
sequence_allowed: bool = True,
calc_kwargs: Optional[dict[str, Any]] = None,
calc_kwargs: Optional[MaybeSequence[dict[str, Any]]] = None,
set_calc: Optional[bool] = None,
attach_logger: bool = False,
log_kwargs: Optional[dict[str, Any]] = None,
Expand All @@ -101,17 +101,17 @@ def __init__(
struct_path : Optional[PathLike]
Path of structure to simulate. Required if `struct` is None. Default is
None.
arch : Architectures
arch : MaybeSequence[Architectures]
MLIP architecture to use for calculations. Default is "mace_mp".
device : Devices
device : MaybeSequence[Devices]
Device to run MLIP model on. Default is "cpu".
model_path : Optional[PathLike]
model_path : Optional[MaybeSequence[PathLike]]
Path to MLIP model. Default is `None`.
read_kwargs : Optional[ASEReadArgs]
Keyword arguments to pass to ase.io.read. Default is {}.
sequence_allowed : bool
Whether a sequence of Atoms objects is allowed. Default is True.
calc_kwargs : Optional[dict[str, Any]]
calc_kwargs : Optional[MaybeSequence[dict[str, Any]]]
Keyword arguments to pass to the selected calculator. Default is {}.
set_calc : Optional[bool]
Whether to set (new) calculators for structures. Default is None.
Expand Down
18 changes: 9 additions & 9 deletions janus_core/calculations/single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ class SinglePoint(BaseCalculation):
struct_path : Optional[PathLike]
Path of structure to simulate. Required if `struct` is None.
Default is None.
arch : Architectures
arch : MaybeSequence[Architectures]
MLIP architecture to use for single point calculations.
Default is "mace_mp".
device : Devices
device : MaybeSequence[Devices]
Device to run model on. Default is "cpu".
model_path : Optional[PathLike]
model_path : Optional[MaybeSequence[PathLike]]
Path to MLIP model. Default is `None`.
read_kwargs : ASEReadArgs
Keyword arguments to pass to ase.io.read. By default,
Expand Down Expand Up @@ -82,9 +82,9 @@ def __init__(
*,
struct: Optional[MaybeSequence[Atoms]] = None,
struct_path: Optional[PathLike] = None,
arch: Architectures = "mace_mp",
device: Devices = "cpu",
model_path: Optional[PathLike] = None,
arch: MaybeSequence[Architectures] = "mace_mp",
device: MaybeSequence[Devices] = "cpu",
model_path: Optional[MaybeSequence[PathLike]] = None,
read_kwargs: Optional[ASEReadArgs] = None,
calc_kwargs: Optional[dict[str, Any]] = None,
set_calc: Optional[bool] = None,
Expand All @@ -107,12 +107,12 @@ def __init__(
struct_path : Optional[PathLike]
Path of structure to simulate. Required if `struct` is None.
Default is None.
arch : Architectures
arch : MaybeSequence[Architectures]
MLIP architecture to use for single point calculations.
Default is "mace_mp".
device : Devices
device : MaybeSequence[Devices]
Device to run MLIP model on. Default is "cpu".
model_path : Optional[PathLike]
model_path : Optional[MaybeSequence[PathLike]]
Path to MLIP model. Default is `None`.
read_kwargs : Optional[ASEReadArgs]
Keyword arguments to pass to ase.io.read. By default,
Expand Down
16 changes: 8 additions & 8 deletions janus_core/cli/singlepoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from typer_config import use_config

from janus_core.cli.types import (
Architecture,
CalcKwargs,
Device,
ArchitectureList,
CalcKwargsList,
DeviceList,
LogPath,
ModelPath,
ModelPathList,
ReadKwargsAll,
StructPath,
Summary,
Expand All @@ -28,9 +28,9 @@ def singlepoint(
# numpydoc ignore=PR02
ctx: Context,
struct: StructPath,
arch: Architecture = "mace_mp",
device: Device = "cpu",
model_path: ModelPath = None,
arch: ArchitectureList = ("mace_mp",),
device: DeviceList = ("cpu",),
model_path: ModelPathList = None,
properties: Annotated[
Optional[list[str]],
Option(
Expand All @@ -50,7 +50,7 @@ def singlepoint(
),
] = None,
read_kwargs: ReadKwargsAll = None,
calc_kwargs: CalcKwargs = None,
calc_kwargs: CalcKwargsList = None,
write_kwargs: WriteKwargs = None,
log: LogPath = None,
tracker: Annotated[
Expand Down
23 changes: 23 additions & 0 deletions janus_core/cli/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ def __str__(self) -> str:
Device = Annotated[Optional[str], Option(help="Device to run calculations on.")]
ModelPath = Annotated[Optional[str], Option(help="Path to MLIP model.")]

ArchitectureList = Annotated[
Optional[list[str]], Option(help="MLIP architecture to use for calculations.")
]
DeviceList = Annotated[
Optional[list[str]], Option(help="Device to run calculations on.")
]
ModelPathList = Annotated[Optional[list[str]], Option(help="Path to MLIP model.")]

ReadKwargsAll = Annotated[
Optional[TyperDict],
Option(
Expand Down Expand Up @@ -117,6 +125,21 @@ def __str__(self) -> str:
),
]

CalcKwargsList = Annotated[
Optional[list[TyperDict]],
Option(
parser=parse_dict_class,
help=(
"""
Keyword arguments to pass to selected calculator. Must be passed as a
dictionary wrapped in quotes, e.g. "{'key' : value}". For the default
architecture ('mace_mp'), "{'model':'small'}" is set unless overwritten.
"""
),
metavar="DICT",
),
]

WriteKwargs = Annotated[
Optional[TyperDict],
Option(
Expand Down
45 changes: 45 additions & 0 deletions janus_core/helpers/mlip_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,47 @@ def _set_model_path(
return model_path


def _set_torch(calculate: callable, dtype: torch.dtype):
"""
Wrap calculate function to set torch default dtype before calculations.

Parameters
----------
calculate : callable
Function to wrap.
dtype : torch.dtype
Default dtype to set.

Returns
-------
callable
Wrapped function.
"""

def wrapper(*args, **kwargs) -> callable:
"""
Wrap function to set torch default dtype.

Parameters
----------
*args
Arguments passed to calculate.
**kwargs
Additional keyword arguments passed to calculate.

Returns
-------
callable
Wrapped function.
"""
import torch

torch.set_default_dtype(dtype)
return calculate(*args, **kwargs)

return wrapper


def choose_calculator(
arch: Architectures = "mace",
device: Devices = "cpu",
Expand Down Expand Up @@ -123,12 +164,14 @@ def choose_calculator(
elif arch == "mace_mp":
from mace import __version__
from mace.calculators import mace_mp
import torch

# Default to "small" model and float64 precision
model = model_path if model_path else "small"
kwargs.setdefault("default_dtype", "float64")

calculator = mace_mp(model=model, device=device, **kwargs)
calculator.calculate = _set_torch(calculator.calculate, torch.float64)

elif arch == "mace_off":
from mace import __version__
Expand Down Expand Up @@ -164,6 +207,7 @@ def choose_calculator(
potential = load_model("M3GNet-MP-2021.2.8-DIRECT-PES")

calculator = M3GNetCalculator(potential=potential, **kwargs)
calculator.calculate = _set_torch(calculator.calculate, torch.float32)

elif arch == "chgnet":
from chgnet import __version__
Expand All @@ -186,6 +230,7 @@ def choose_calculator(
model = None

calculator = CHGNetCalculator(model=model, use_device=device, **kwargs)
calculator.calculate = _set_torch(calculator.calculate, torch.float32)

elif arch == "alignn":
from alignn import __version__
Expand Down
113 changes: 113 additions & 0 deletions janus_core/helpers/multi_calc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
"""Define MultiCalc ASE Calculator."""

from collections.abc import Sequence
from typing import Any

from ase.calculators.calculator import (
BaseCalculator,
Calculator,
CalculatorSetupError,
PropertyNotImplementedError,
)


class MultiCalc(BaseCalculator):
"""
ASE MultiCalc class.

Parameters
----------
calcs : Sequence[Calculator]
Calculators to use.
"""

def __init__(self, calcs: Sequence[Calculator]):
"""
Initialise class.

Parameters
----------
calcs : Sequence[Calculator]
Calculators to use.
"""
super().__init__()

if len(calcs) == 0:
raise CalculatorSetupError("Please provide a list of Calculators")

common_properties = set.intersection(
*(set(calc.implemented_properties) for calc in calcs)
)

self.implemented_properties = list(common_properties)
if not self.implemented_properties:
raise PropertyNotImplementedError(
"The provided Calculators have" " no properties in common!"
)

self.calcs = calcs

def __str__(self) -> str:
"""
Return string representation of the calculator.

Returns
-------
str
String representation.
"""
calcs = ", ".join(calc.__class__.__name__ for calc in self.calcs)
return f"{self.__class__.__name__}({calcs})"

def get_properties(self, properties, atoms) -> dict[str, Any]:
"""
Get properties from each listed calculator.

Parameters
----------
properties : list[str]
List of properties to be calculated.
atoms : Atoms
Atoms object to calculate properties for.

Returns
-------
dict
Dictionary of results.
"""
results = {}

def get_property(prop: str) -> None:
"""
Get property from each listed calculator.

Parameters
----------
prop : str
Property to get.
"""
contribs = [calc.get_property(prop, atoms) for calc in self.calcs]
results[prop] = contribs

for prop in properties: # get requested properties
get_property(prop)
for prop in self.implemented_properties: # cache all available props
if all(prop in calc.results for calc in self.calcs):
get_property(prop)
return results

def calculate(self, atoms, properties, system_changes) -> None:
"""
Calculate properties for each calculator and return values as list.

Parameters
----------
atoms : Atoms
Atoms object to calculate properties for.
properties : list[str]
List of properties to be calculated.
system_changes : list[str]
List of what has changed since last calculation.
"""
self.atoms = atoms.copy() # for caching of results
self.results = self.get_properties(properties, atoms)
Loading
Loading