Skip to content

Commit

Permalink
✨ Dihedral angles are now handled explicitly via 'dihedral' keyword/h…
Browse files Browse the repository at this point in the history
…andler.
  • Loading branch information
Baharis committed Oct 21, 2024
1 parent 85e04e2 commit 9364084
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 31 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ The following instructions are currently supported by picometer:
- **Evaluation instructions**
- measure `distance` between 2 selected objects; if the selection includes
groups of atoms, measure closes distance to the group of atoms.
- measure `angle` between 2–6 selected objects; if the selection includes
(ordered) atoms, calculate direct or dihedral angle between presumed bonds.
- measure `angle` between 2–3 selected objects: planes, lines, or (ordered) atoms.
- measure `dihedral` andle between 4 individually-selected ordered centroids/atoms.


## Contributing
Expand Down
24 changes: 14 additions & 10 deletions picometer/atom.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,16 +182,8 @@ def _angle(self, *others: 'Shape') -> float:
assert all(o.kind is o.Kind.spatial for o in [self, *others])
combined = sum(others, self)
xyz = combined.cart_xyz.T
if len(combined) == 3: # interior angle
return degrees_between(xyz[0] - xyz[1], xyz[2] - xyz[1])
elif 4 <= len(combined) <= 6: # dihedral angle
plane1_dir = np.cross(xyz[0] - xyz[1], xyz[2] - xyz[1])
plane2_dir = np.cross(xyz[-3] - xyz[-2], xyz[-1] - xyz[-2])
twist_dir = np.cross(plane1_dir, plane2_dir)
sign = +1 if are_synparallel(twist_dir, xyz[2] - xyz[1]) else -1
return sign * degrees_between(plane1_dir, plane2_dir, normalize=False)
else:
return 'Input AtomSet must contain between 3 and 6 atoms'
assert len(combined) == 3, 'Input AtomSet must contain exactly 3 atoms'
return degrees_between(xyz[0] - xyz[1], xyz[2] - xyz[1])

def _distance(self, other: 'Shape') -> float:
if other.kind is self.Kind.spatial:
Expand All @@ -209,3 +201,15 @@ def _distance(self, other: 'Shape') -> float:
norms = norm(deltas, axis=1)
along = np.abs(np.dot(deltas, other.direction))
return min(norms ** 2 - along ** 2)

def dihedral(self, *others: 'AtomSet'):
assert all(o.kind is o.Kind.spatial for o in [self, *others])
combined = sum(others, self)
xyz = combined.cart_xyz.T
assert len(combined) == 4, 'Input AtomSet must contain exactly 4 atoms'
plane1_dir = np.cross(xyz[0] - xyz[1], xyz[2] - xyz[1])
plane2_dir = np.cross(xyz[-3] - xyz[-2], xyz[-1] - xyz[-2])
twist_dir = np.cross(plane1_dir, plane2_dir)
sign = +1 if are_synparallel(twist_dir, xyz[2] - xyz[1]) else -1
return sign * degrees_between(plane1_dir, plane2_dir, normalize=False)

38 changes: 25 additions & 13 deletions picometer/instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from picometer.atom import group_registry, AtomSet, Locator
from picometer.models import ModelState, ModelStates

from picometer.shapes import ExplicitShape

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -186,6 +186,15 @@ def handle(self, instruction: Instruction) -> None:
def handle_one(self, instruction: Instruction, ms_key: str, ms: ModelState) -> None:
"""Abstract function to handle a process a single model state"""

def _collect_shapes(self, ms: ModelState) -> list[ExplicitShape]:
shapes = []
for locator in self.processor.selection:
if (shape_label := locator.label) in ms.shapes:
shapes.append(ms.shapes[shape_label])
else:
shapes.append(ms.nodes.locate([locator]))
return shapes


# ~~~~~~~~~~~~~~~~~~~~ CONCRETE INSTRUCTIONS DECLARATIONS ~~~~~~~~~~~~~~~~~~~~ #

Expand Down Expand Up @@ -296,12 +305,7 @@ class DistanceInstructionHandler(SerialInstructionHandler):

def handle_one(self, instruction: Instruction, ms_key: str, ms: ModelState) -> None:
label = instruction.kwargs['label']
shapes = []
for locator in self.processor.selection:
if (shape_label := locator.label) in ms.shapes:
shapes.append(ms.shapes[shape_label])
else:
shapes.append(ms.nodes.locate([locator]))
shapes = self._collect_shapes(ms)
assert len(shapes) == 2
distance = shapes[0].distance(shapes[1])
self.processor.evaluation_table.loc[ms_key, label] = distance
Expand All @@ -314,18 +318,26 @@ class AngleInstructionHandler(SerialInstructionHandler):

def handle_one(self, instruction: Instruction, ms_key: str, ms: ModelState) -> None:
label = instruction.kwargs['label']
shapes = []
for locator in self.processor.selection:
if (shape_label := locator.label) in ms.shapes:
shapes.append(ms.shapes[shape_label])
else:
shapes.append(ms.nodes.locate([locator]))
shapes = self._collect_shapes(ms)
assert len(shapes)
angle = shapes[0].angle(*shapes[1:])
self.processor.evaluation_table.loc[ms_key, label] = angle
logger.info(f'Evaluated angle {label}: {angle} for model state {ms_key}')


class DihedralInstructionHandler(SerialInstructionHandler):
name = 'dihedral'
kwargs = dict(label=str)

def handle_one(self, instruction: Instruction, ms_key: str, ms: ModelState) -> None:
label = instruction.kwargs['label']
shapes = self._collect_shapes(ms)
assert len(shapes) == 4 and all(s.kind is s.Kind.spatial for s in shapes)
dihedral = shapes[0].dihedral(*shapes[1:]) # noqa: shapes: list[AtomSet]
self.processor.evaluation_table.loc[ms_key, label] = dihedral
logger.info(f'Evaluated dihedral {label}: {dihedral} for model state {ms_key}')


class WriteInstructionHandler(BaseInstructionHandler):
name = 'write'
kwargs = dict(path=Path)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_ferrocene.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -80,5 +80,5 @@ instructions:
- select: C(12)
- select: C(13)
- select: C(14)
- angle: C(11)-C(12)-C(13)-C(14)
- dihedral: C(11)-C(12)-C(13)-C(14)
- write: $ferrocene_results # tests will substitute it with ./ferrocene_results.csv
27 changes: 22 additions & 5 deletions tests/test_instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def test_angle_line_line(self):
43.63394452, 43.31590821, 43.52811746])
self.assertTrue(np.allclose(results, correct))

def test_angle_interior_nodes(self):
def test_angle_nodes(self):
self.routine_text += ' - select: C(11)\n'
self.routine_text += ' - select: C(12)\n'
self.routine_text += ' - select: C(13)\n'
Expand All @@ -354,30 +354,47 @@ def test_angle_interior_nodes(self):
108.17779184, 108.12639300, 107.63799568])
self.assertTrue(np.allclose(results, correct))

def test_angle_positive_dihedral_nodes(self):
def test_angle_fails_on_4_atoms(self):
self.routine_text += ' - select: C(11)\n'
self.routine_text += ' - select: C(12)\n'
self.routine_text += ' - select: C(13)\n'
self.routine_text += ' - select: C(14)\n'
self.routine_text += ' - angle: C(11)-C(12)-C(13)-C(14)'
with self.assertRaises(AssertionError):
_ = process(Routine.from_string(self.routine_text))

def test_dihedral_positive(self):
self.routine_text += ' - select: H(11)\n'
self.routine_text += ' - select: C(11)\n'
self.routine_text += ' - select: C(15)\n'
self.routine_text += ' - select: Fe\n'
self.routine_text += ' - angle: H(11)-C(11)-C(15)-Fe'
self.routine_text += ' - dihedral: H(11)-C(11)-C(15)-Fe'
p = process(Routine.from_string(self.routine_text))
results = p.evaluation_table['H(11)-C(11)-C(15)-Fe'].to_numpy()
correct = np.array([117.48054368, 118.56063847, 118.81095746,
118.03459677, 122.13488005, 120.58628219])
self.assertTrue(np.allclose(results, correct))

def test_angle_mixed_dihedral_nodes(self):
def test_dihedral_mixed(self):
self.routine_text += ' - select: C(11)\n'
self.routine_text += ' - select: C(12)\n'
self.routine_text += ' - select: C(13)\n'
self.routine_text += ' - select: C(14)\n'
self.routine_text += ' - angle: C(11)-C(12)-C(13)-C(14)'
self.routine_text += ' - dihedral: C(11)-C(12)-C(13)-C(14)'
p = process(Routine.from_string(self.routine_text))
results = p.evaluation_table['C(11)-C(12)-C(13)-C(14)'].to_numpy()
correct = np.array([+0.03373221, -0.00041385, +0.02161362,
+0.11565318, -0.03754215, -0.37636209])
self.assertTrue(np.allclose(results, correct))

def test_dihedral_fails_on_3_atoms(self):
self.routine_text += ' - select: C(11)\n'
self.routine_text += ' - select: C(12)\n'
self.routine_text += ' - select: C(13)\n'
self.routine_text += ' - dihedral: C(11)-C(12)-C(13)'
with self.assertRaises(AssertionError):
_ = process(Routine.from_string(self.routine_text))

def test_write(self):
routine_text = get_yaml('test_ferrocene.yaml')
_ = process(Routine.from_string(routine_text))
Expand Down

0 comments on commit 9364084

Please sign in to comment.