Skip to content

Commit

Permalink
Updated repo with changes from PR before Labor Day
Browse files Browse the repository at this point in the history
  • Loading branch information
John-Peters-UW committed Sep 3, 2024
1 parent ccccc69 commit ba7e603
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 9 deletions.
27 changes: 27 additions & 0 deletions metl/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ def _get_data_encoding(hparams):


def load_model_and_data_encoder(state_dict, hparams, strict, raw, indexing):
"""
Passing raw to this function dictates that the unwrapped and unsafe versions of the model and data encoder should be returned.
These are the versions saved on Zenodo. Strict decides how certain errors are handled. See ModelEncoder.validate_pdb for more detail.
Indexing sets the expected index style of mutations passed into METL for predictions.
"""
model = models.Model[hparams["model_name"]].cls(**hparams)
model.load_state_dict(state_dict)

Expand All @@ -120,6 +125,13 @@ def load_model_and_data_encoder(state_dict, hparams, strict, raw, indexing):


def get_from_uuid(uuid, strict=True, raw=False, indexing=0):
"""
Strict is True here as models loaded from zenodo have matching PDB and wild types.
Indexing is 0, as that is the default for METL. raw is False by default, as the preferred way to
use METL (as an end-user) is with error handling enabled.
"""

if uuid in UUID_URL_MAP:
state_dict, hparams = download_checkpoint(uuid)
return load_model_and_data_encoder(state_dict, hparams, strict, raw, indexing)
Expand All @@ -128,6 +140,13 @@ def get_from_uuid(uuid, strict=True, raw=False, indexing=0):


def get_from_ident(ident, strict=True, raw=False, indexing=0):
"""
Strict is True here as models loaded from zenodo have matching PDB and wild types.
Indexing is 0, as that is the default for METL. raw is False by default, as the preferred way to
use METL (as an end-user) is with error handling enabled.
"""

ident = ident.lower()
if ident in IDENT_UUID_MAP:
state_dict, hparams = download_checkpoint(IDENT_UUID_MAP[ident])
Expand All @@ -137,6 +156,14 @@ def get_from_ident(ident, strict=True, raw=False, indexing=0):


def get_from_checkpoint(ckpt_fn, strict=False, raw=False, indexing=0):
"""
Strict is False here as checkpoint models are assumed to be fine-tuned,
and custom PDBs that don't exactly match the wild type may be more common here.
Indexing is 0, as that is the default for METL. raw is False by default, as the preferred way to
use METL (as an end-user) is with error handling enabled.
"""

ckpt = torch.load(ckpt_fn, map_location="cpu")
state_dict = ckpt["state_dict"]
hyper_parameters = ckpt["hyper_parameters"]
Expand Down
15 changes: 11 additions & 4 deletions metl/model_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ def check_if_pdb_needed(self, model):
return False

def validate_pdb(self, pdb_file, wt):
"""
When validating a PDB, it is possible that the PDB file and wild type (wt) passed will differ.
In strict mode an exception is raised if this occurs; otherwise the mismatch is not checked.
Strict is off by default when loading from a checkpoint file, and on when loading models from Zenodo.
"""
try:
ppdb = PandasPdb().read_pdb(pdb_file)
except Exception as e:
Expand All @@ -41,11 +46,14 @@ def validate_pdb(self, pdb_file, wt):
wildtype = ''.join(wt_seq)

if self.strict:
err_str = "Strict mode is on because a METL model that we trained was used. Wildtype and PDB sequeunces must match."
err_str += " If this is expected behavior, pass strict=False to the load function you used."
err_str = "Strict mode is on because a METL model that we trained was used. Wildtype and PDB sequences must match."
err_str += " To ignore the sequence mismatch, pass strict=False to the load function you used."
assert wildtype == wt, err_str

def validate_variants(self, variants, wt):
"""
Variants must be validated only after conversion to 0-based indexing!
"""
wt_len = len(wt)
for index, variant in enumerate(variants):
split = variant.split(',')
Expand All @@ -60,8 +68,7 @@ def validate_variants(self, variants, wt):
error_str = f"The position for the mutation is {location} but it needs to be between 0 "
error_str += f"and {len(wt)-1} if 0-based and 1 and {len(wt)} if 1-based."
errors.append(error_str)

if wt[location] != from_amino_acid:
elif wt[location] != from_amino_acid:
errors.append(f"Wildtype at position {location} is {wt[location]} but variant had {from_amino_acid}. Check the variant input.")

if len(errors) != 0:
Expand Down
2 changes: 0 additions & 2 deletions metl/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import enum
from os.path import isfile
from typing import List, Tuple, Optional
from biopandas.pdb import PandasPdb
import os

import torch
import torch.nn as nn
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
setuptools
numpy=1.23.2
numpy>=1.23.2,<2
networkx>=2.6.3
scipy>=1.9.1
biopandas>=0.2.7
Expand Down
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ license = MIT
packages=find:
install_requires =
torch
numpy==1.23.2
numpy>=1.23.2,<2
scipy
biopandas
networkx
biopython==1.84
biopython>=1.84

0 comments on commit ba7e603

Please sign in to comment.