Skip to content

Commit

Permalink
Merge pull request #7 from gitter-lab/huggingface_support
Browse files Browse the repository at this point in the history
Huggingface support and action added
  • Loading branch information
samgelman authored Aug 29, 2024
2 parents 850833a + 6fc00e5 commit f6ceac5
Show file tree
Hide file tree
Showing 8 changed files with 2,602 additions and 0 deletions.
24 changes: 24 additions & 0 deletions .github/workflows/compile_huggingface.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Builds the single-file Hugging Face wrapper from the repository sources,
# formats it, and pushes it to the Hugging Face Hub.
name: Compiling Huggingface Wrapper
# Runs on every push (any branch) and on manual dispatch.
on: [push, workflow_dispatch]
jobs:
  Combine-File:
    runs-on: ubuntu-latest
    env:
      # Exposed to every step of this job from the repository secret of the same name.
      HF_TOKEN: ${{ secrets.HF_TOKEN }}
    steps:
      # The checkout is pinned to 'main', so that branch is what gets built
      # regardless of which ref triggered the run.
      - uses: actions/checkout@v4
        with:
          ref: 'main'
      - name: installing deps
        run: pip install -r huggingface/requirements.txt
      # Installed from the PyTorch CPU wheel index.
      - name: installing torch cpu only
        run: pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
      # Writes the combined module to huggingface/huggingface_wrapper.py.
      - name: Combining Files
        run: python huggingface/combine_files.py -o huggingface/huggingface_wrapper.py
      # isort and black are run on the generated file before it is pushed.
      - name: Formatting generated code
        run: |
          python -m isort huggingface/huggingface_wrapper.py
          python -m black huggingface/huggingface_wrapper.py
      - name: Push to hub
        run: python huggingface/push_to_hub.py

5 changes: 5 additions & 0 deletions huggingface/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
This directory is to maintain the 🤗 support of METL.

Herein are a few files to facilitate uploading the wrapper to 🤗. First, `combine_files.py` takes all of the files in the `metl` directory, barring files whose names contain `test` or `_.py` (for example, `__init__.py`), and combines them into a single file. `combine_files.py` also appends the Hugging Face wrapper code itself (stored in `huggingface_code.py`) to the bottom of the generated script.

After it has been formatted, the generated script is uploaded to 🤗 automatically by running the `push_to_hub.py` script. A short comment at the top of each file restates that file's responsibility.
59 changes: 59 additions & 0 deletions huggingface/combine_files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""
This script combines all of the files in the metl directory into one file so that it can be uploaded automatically to huggingface.
Files whose names contain _.py (for example __init__.py) or contain test will not be included. This script also automatically collects the required imports from the files.
Regardless of changes to metl, as long as necessary files that may be added don't contain test or _.py, this should work as intended.
"""

import argparse
import os

def main(output_path: str):
    """Combine the ``./metl`` sources and the Hugging Face wrapper into one file.

    Every ``*.py`` file in ``./metl`` whose name contains neither ``_.py``
    (e.g. ``__init__.py``) nor ``test`` is read line by line:

    * import statements that do not mention ``metl.`` are collected
      (de-duplicated) and written near the top of the output;
    * import statements that mention ``metl.`` are dropped, because the
      combined file contains those definitions itself. Any ``as <alias>``
      they declare is remembered so that ``<alias>.`` prefixes can be removed
      from the combined code;
    * every other line is kept verbatim as code.

    The text that follows the ``$>`` delimiter in
    ``./huggingface/huggingface_code.py`` is appended at the end.

    Paths are resolved relative to the current working directory, so the
    script must be run from the repository root.

    Limitation: import statements are recognised one line at a time, so
    parenthesised multi-line imports are not handled.

    Args:
        output_path: Path of the combined file to write.

    Raises:
        ValueError: If the delimiter is missing from ``huggingface_code.py``.
    """
    import re  # function-scope import keeps this block self-contained

    # A line is an import only if it *starts* like one; a plain substring test
    # would also pull prose such as "this will import ..." out of docstrings.
    import_stmt = re.compile(r'^(?:import\s+\S|from\s+\S+\s+import\b)')
    # `as` must be a whole word: a substring test also matches names such as
    # "base" and would then strip an unrelated prefix from the whole code base.
    alias_pattern = re.compile(r'\bas\s+(\w+)')

    imports = set()
    code = []
    metl_aliases = set()
    # Sorted so the generated file does not depend on directory-listing order.
    for file in sorted(os.listdir('./metl')):
        # endswith() rejects look-alikes such as ".pyc" or ".py.bak".
        if not file.endswith('.py') or '_.py' in file or 'test' in file:
            continue
        with open(f'./metl/{file}', 'r', encoding='utf-8') as f:
            for line in f:
                stripped = line.strip()
                if not import_stmt.match(stripped):
                    # Only remove the newline; the last line of a file may
                    # not have one, and slicing would eat a real character.
                    code.append(line.rstrip('\n'))
                    continue
                statement = stripped.split('#', 1)[0]
                if 'metl.' in statement:
                    metl_aliases.update(alias_pattern.findall(statement))
                else:
                    imports.add(stripped)

    code_text = '\n'.join(code)
    imports_text = '\n'.join(sorted(imports))

    for alias in sorted(metl_aliases):
        # Strip "alias." only where it is a standalone module reference, not
        # the tail of a longer identifier or of an attribute chain.
        code_text = re.sub(rf'(?<![\w.]){re.escape(alias)}\.', '', code_text)

    huggingface_import = 'from transformers import PretrainedConfig, PreTrainedModel'
    delimiter = '$>'

    with open('./huggingface/huggingface_code.py', 'r', encoding='utf-8') as f:
        contents = f.read()
    delim_location = contents.find(delimiter)
    if delim_location == -1:
        # Without this check find() returns -1 and the slice below would
        # silently keep almost the whole file, linter stubs included.
        raise ValueError(
            f'Delimiter {delimiter!r} not found in ./huggingface/huggingface_code.py'
        )
    cut_contents = contents[delim_location + len(delimiter):]

    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(f'{huggingface_import}\n{imports_text}\n{code_text}\n{cut_contents}')

def parse_args():
    """Parse the command line for the wrapper compiler.

    Returns:
        The parsed ``argparse.Namespace``; its ``o`` attribute holds the
        output file path converted to an absolute path.
    """
    cli = argparse.ArgumentParser(description="Compile huggingface wrapper")
    cli.add_argument(
        "-o",
        default='./huggingface_wrapper.py',
        type=str,
        help="Output filepath",
    )

    namespace = cli.parse_args()
    # Resolve against the current working directory.
    namespace.o = os.path.abspath(namespace.o)
    return namespace

# Script entry point: parse the command line, then write the combined file
# to the path given by -o.
if __name__ == "__main__":
    args = parse_args()
    main(args.o)
69 changes: 69 additions & 0 deletions huggingface/huggingface_code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""
This file contains the actual wrapper for METL.
Above the delimiter for this file: #$> we have included imports and shell functions
which prevent Python (and linters) from complaining that this file has errors.
"""


from transformers import PretrainedConfig, PreTrainedModel

def get_from_uuid(*args, **kwargs):
    """Placeholder so this file can reference ``get_from_uuid``.

    This definition sits above the ``#$>`` delimiter and is chopped off when
    the wrapper is generated. It accepts any arguments because the wrapper
    code below calls it with one, which a zero-parameter stub would reject.
    """
    return None

def get_from_ident(*args, **kwargs):
    """Placeholder so this file can reference ``get_from_ident``.

    This definition sits above the ``#$>`` delimiter and is chopped off when
    the wrapper is generated. It accepts any arguments because the wrapper
    code below calls it with one, which a zero-parameter stub would reject.
    """
    return None

def get_from_checkpoint(*args, **kwargs):
    """Placeholder so this file can reference ``get_from_checkpoint``.

    This definition sits above the ``#$>`` delimiter and is chopped off when
    the wrapper is generated. It accepts any arguments because the wrapper
    code below calls it with one, which a zero-parameter stub would reject.
    """
    return None

# Placeholder values so the names referenced by the wrapper classes below
# resolve in this file; like everything above the `#$>` delimiter they are
# chopped off when the wrapper is generated.
IDENT_UUID_MAP = ""
UUID_URL_MAP = ""

# Chop The above off.

#$>
# Huggingface code

class METLConfig(PretrainedConfig):
    """Configuration object for the METL Hugging Face wrapper.

    The two lookup maps are exposed as class attributes so that the model
    class can validate identifiers through its config.
    """

    # Re-exported module-level lookup tables.
    IDENT_UUID_MAP = IDENT_UUID_MAP
    UUID_URL_MAP = UUID_URL_MAP
    model_type = "METL"

    def __init__(self, id: str = None, **kwargs):
        """Store the model identifier and pass everything else to the base class.

        Args:
            id: Identifier of the METL model this config refers to, or None
                when no model has been selected yet.
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
        """
        # Assigned before the base initialiser runs, as in the original order.
        self.id = id
        super().__init__(**kwargs)

class METLModel(PreTrainedModel):
    """Hugging Face wrapper around a METL model.

    A new instance holds no underlying network: ``model`` and ``encoder`` are
    ``None`` until one of ``load_from_uuid``, ``load_from_ident`` or
    ``get_from_checkpoint`` has been called. Calling ``forward`` before that
    fails because ``self.model`` is still ``None``.
    """

    config_class = METLConfig

    def __init__(self, config: METLConfig):
        """Create an empty wrapper bound to ``config``."""
        super().__init__(config)
        # Populated later by one of the loader methods.
        self.model = None
        self.encoder = None
        self.config = config

    def forward(self, X, pdb_fn=None):
        """Run the loaded model on ``X``.

        Args:
            X: Input passed straight through to the underlying model.
            pdb_fn: Optional value forwarded as the ``pdb_fn`` keyword
                argument; it is only passed on when truthy.

        Returns:
            Whatever the underlying model returns.
        """
        if pdb_fn:
            return self.model(X, pdb_fn=pdb_fn)
        return self.model(X)

    def load_from_uuid(self, id):
        """Load ``model`` and ``encoder`` for a model UUID.

        Args:
            id: UUID to load. When falsy, the id already stored on the
                config is used instead and no validation is performed.

        Raises:
            AssertionError: If ``id`` is given but is not in ``UUID_URL_MAP``.
        """
        if id:
            # The message names the map that is actually checked here.
            assert id in self.config.UUID_URL_MAP, "ID given does not reference a valid METL model in the UUID_URL_MAP"
            self.config.id = id

        self.model, self.encoder = get_from_uuid(self.config.id)

    def load_from_ident(self, id):
        """Load ``model`` and ``encoder`` for a model identifier.

        Args:
            id: Identifier to load; it is lower-cased before the lookup.
                When falsy, the id already stored on the config is used
                instead and no validation is performed.

        Raises:
            AssertionError: If ``id`` is given but is not in ``IDENT_UUID_MAP``.
        """
        if id:
            id = id.lower()
            assert id in self.config.IDENT_UUID_MAP, "ID given does not reference a valid METL model in the IDENT_UUID_MAP"
            self.config.id = id

        self.model, self.encoder = get_from_ident(self.config.id)

    def get_from_checkpoint(self, checkpoint_path):
        """Load ``model`` and ``encoder`` from ``checkpoint_path``.

        Delegates to the module-level ``get_from_checkpoint`` function; the
        method sharing that name does not shadow it inside this body.
        """
        self.model, self.encoder = get_from_checkpoint(checkpoint_path)
Loading

0 comments on commit f6ceac5

Please sign in to comment.