-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #7 from gitter-lab/huggingface_support
Huggingface support and action added
- Loading branch information
Showing
8 changed files
with
2,602 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
name: Compiling Huggingface Wrapper | ||
on: [push, workflow_dispatch] | ||
jobs: | ||
Combine-File: | ||
runs-on: ubuntu-latest | ||
env: | ||
HF_TOKEN: ${{ secrets.HF_TOKEN }} | ||
steps: | ||
- uses: actions/checkout@v4 | ||
with: | ||
ref: 'main' | ||
- name: installing deps | ||
run: pip install -r huggingface/requirements.txt | ||
- name: installing torch cpu only | ||
run: pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu | ||
- name: Combining Files | ||
run: python huggingface/combine_files.py -o huggingface/huggingface_wrapper.py | ||
- name: Formatting generated code | ||
run: | | ||
python -m isort huggingface/huggingface_wrapper.py | ||
python -m black huggingface/huggingface_wrapper.py | ||
- name: Push to hub | ||
run: python huggingface/push_to_hub.py | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
This directory is to maintain the 🤗 support of METL. | ||
|
||
Herein are a few files to facilitate uploading the wrapper to 🤗. First, combine_files.py takes all of the files in the METL directory, barring files that have test or _.py (think, innit.py here) and combines them into a single file. combine_files.py also appends the huggingface wrapper code itself (stored in huggingface_code.py) onto the bottom of the script. | ||
|
||
This script then gets auto-updated to 🤗 after formatting it by running the push_to_hub.py script. Some additional small comments are included in the top of each file repeating these responsibilities. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
""" | ||
This script combines all of the files in the metl directory into one file so that it can be uploaded automatically to huggingface. | ||
Files ending with _.py and that contain test in the filename will not be included. This script automatically generates the required imports from the files as well. | ||
Regardless of changes to metl, as long as necessary files that may be added don't contain test or _.py, this should work as intended. | ||
""" | ||
|
||
import argparse | ||
import os | ||
|
||
def main(output_path: str): | ||
imports = set() | ||
code = [] | ||
metl_imports = set() | ||
for file in os.listdir('./metl'): | ||
if '.py' in file and '_.py' not in file and 'test' not in file: | ||
with open(f'./metl/{file}', 'r') as f: | ||
file_text = f.readlines() | ||
for line in file_text: | ||
line_for_compare = line.strip() | ||
if 'import ' in line_for_compare and 'metl.' not in line_for_compare: | ||
imports.add(line_for_compare) | ||
elif 'import ' in line_for_compare and 'metl.' in line_for_compare: | ||
if 'as' in line_for_compare: | ||
metl_imports.add(line_for_compare) | ||
else: | ||
code.append(line[:-1]) | ||
|
||
code = '\n'.join(code) | ||
imports = '\n'.join(imports) | ||
|
||
for line in metl_imports: | ||
import_name = line.split('as')[-1].strip() | ||
code = code.replace(f'{import_name}.', '') | ||
|
||
huggingface_import = 'from transformers import PretrainedConfig, PreTrainedModel' | ||
delimiter = '$>' | ||
|
||
with open('./huggingface/huggingface_code.py', 'r') as f: | ||
contents = f.read() | ||
delim_location = contents.find(delimiter) | ||
cut_contents = contents[delim_location+len(delimiter):] | ||
|
||
with open(output_path, 'w') as f: | ||
f.write(f'{huggingface_import}\n{imports}\n{code}\n{cut_contents}') | ||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser(description="Compile huggingface wrapper") | ||
parser.add_argument("-o", type=str, help="Output filepath", default='./huggingface_wrapper.py') | ||
|
||
args = parser.parse_args() | ||
|
||
args.o = os.path.abspath(args.o) | ||
return args | ||
|
||
if __name__ == "__main__": | ||
args = parse_args() | ||
main(args.o) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
""" | ||
This file contains the actual wrapper for METL. | ||
Above the delimiter for this file: #$> we have included imports and shell functions | ||
which prevent python (and other linters) from complaining this file has erros. | ||
""" | ||
|
||
|
||
from transformers import PretrainedConfig, PreTrainedModel | ||
|
||
def get_from_uuid(): | ||
pass | ||
|
||
def get_from_ident(): | ||
pass | ||
|
||
def get_from_checkpoint(): | ||
pass | ||
|
||
IDENT_UUID_MAP = "" | ||
UUID_URL_MAP = "" | ||
|
||
# Chop The above off. | ||
|
||
#$> | ||
# Huggingface code | ||
|
||
class METLConfig(PretrainedConfig): | ||
IDENT_UUID_MAP = IDENT_UUID_MAP | ||
UUID_URL_MAP = UUID_URL_MAP | ||
model_type = "METL" | ||
|
||
def __init__( | ||
self, | ||
id:str = None, | ||
**kwargs, | ||
): | ||
self.id = id | ||
super().__init__(**kwargs) | ||
|
||
class METLModel(PreTrainedModel): | ||
config_class = METLConfig | ||
def __init__(self, config:METLConfig): | ||
super().__init__(config) | ||
self.model = None | ||
self.encoder = None | ||
self.config = config | ||
|
||
def forward(self, X, pdb_fn=None): | ||
if pdb_fn: | ||
return self.model(X, pdb_fn=pdb_fn) | ||
return self.model(X) | ||
|
||
def load_from_uuid(self, id): | ||
if id: | ||
assert id in self.config.UUID_URL_MAP, "ID given does not reference a valid METL model in the IDENT_UUID_MAP" | ||
self.config.id = id | ||
|
||
self.model, self.encoder = get_from_uuid(self.config.id) | ||
|
||
def load_from_ident(self, id): | ||
if id: | ||
id = id.lower() | ||
assert id in self.config.IDENT_UUID_MAP, "ID given does not reference a valid METL model in the IDENT_UUID_MAP" | ||
self.config.id = id | ||
|
||
self.model, self.encoder = get_from_ident(self.config.id) | ||
|
||
def get_from_checkpoint(self, checkpoint_path): | ||
self.model, self.encoder = get_from_checkpoint(checkpoint_path) |
Oops, something went wrong.