protein bert uniref90 dataset #1
I joined it with the fasta file (on the uniprot_name field, which is the RepID in the fasta headers).

```python
import dask.dataframe as dd
from dask.distributed import Client
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

if __name__ == '__main__':
    client = Client()  # start a local dask cluster for the dask steps


def read_db_save_to_parquet():
    # Dump the sqlite table produced by create_uniref_db to parquet.
    daskDF = dd.read_sql_table('protein_annotations',
                               "sqlite:///uniref_proteins_and_annotations.db",
                               index_col='index')
    daskDF.to_parquet('the-parquet', engine='pyarrow')


def read_fasta_save_to_parquet():
    # Each ">"-delimited fasta record becomes one row; "$" is used as a
    # separator that never occurs in the file, so the whole record stays
    # in a single column.
    fasta_df = dd.read_csv("uniref90.fasta", lineterminator=">", sep="$", header=None)

    def process_fasta_row(row):
        lines = row[0].split("\n")
        # The last whitespace-separated field of the header is "RepID=<id>".
        id = lines[0].split(" ")[-1].split("=")[-1]
        seq = "".join(lines[1:])
        return (id, seq)

    new_df = fasta_df.apply(process_fasta_row, axis=1, result_type='expand',
                            meta={0: str, 1: str})
    new_df.columns = ['uniprot_id', 'seq']
    new_df.to_parquet('fasta_parquet', engine='pyarrow')


def spark_join():
    # executor and driver settings, not sure if they're needed
    spark = SparkSession.builder \
        .config("spark.executor.cores", "16") \
        .config("spark.executor.memory", "16G") \
        .config("spark.driver.cores", "16") \
        .config("spark.driver.memory", "16G") \
        .master("local[16]").appName('spark-merge').getOrCreate()
    a = spark.read.parquet("the-parquet")
    b = spark.read.parquet("fasta_parquet")
    c = a.join(b, col("uniprot_name") == col("uniprot_id"))
    c.write.mode("overwrite").parquet("uniref90_with_annotations")
```

(It's best to run the two dask functions first, then run the spark join in a second run.)

For reference, this would be the code to merge in dask, but it is too slow:

```python
def merge_without_indexing():
    db_df = dd.read_parquet('the-parquet', engine='pyarrow')
    fasta_index_df = dd.read_parquet('fasta_parquet', engine='pyarrow')
    merged = db_df.merge(fasta_index_df, left_on="uniprot_name", right_on="uniprot_id")
    merged.to_parquet('merged', engine='pyarrow')
```

(I also tried to set the index and save it beforehand, but that is also slow: hours.) I wanted to do it all in dask, but it turned out dask is slow at joining big collections compared to spark.
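The "index and save beforehand" variant mentioned above would look roughly like this (a sketch, not the exact code that was run; `merged_indexed` is just an illustrative output path):

```python
import dask.dataframe as dd

def merge_with_indexing():
    # Set the join key as the index on both sides so dask can do an
    # index-aligned merge. The set_index shuffle itself is expensive on
    # ~135M rows, which is why this approach was also slow.
    db_df = dd.read_parquet('the-parquet', engine='pyarrow').set_index('uniprot_name')
    fasta_df = dd.read_parquet('fasta_parquet', engine='pyarrow').set_index('uniprot_id')
    merged = db_df.merge(fasta_df, left_index=True, right_index=True)
    merged.to_parquet('merged_indexed', engine='pyarrow')
```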
The resulting collection has 135301051 records.
Example of code to read it in dask:

```python
import dask.dataframe as dd
from dask.distributed import Client

client = Client()
merged = dd.read_parquet('uniref90_with_annotations', engine='pyarrow')
merged.head()
```
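To sanity-check the result, the record count and columns can also be read back with dask (a small sketch; the column names are the ones produced by the conversion code above):

```python
import dask.dataframe as dd

merged = dd.read_parquet('uniref90_with_annotations', engine='pyarrow')
print(list(merged.columns))  # should include e.g. uniprot_name, uniprot_id and seq
print(len(merged))           # expected to be 135301051 (triggers a full scan)
```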
That's a simple way to read parquet as a torch dataset:

```python
import pyarrow.dataset as ds
from torch.utils.data import IterableDataset
from torch.utils.data import get_worker_info


class IterableManualParquetDataset(IterableDataset):
    def __init__(self, path, process_func, batch_size=64):
        super().__init__()
        self.dataset = ds.dataset(path)
        self.batch_size = batch_size
        self.process_func = process_func

    def __iter__(self):
        worker_info = get_worker_info()

        # Only divide up batches when using multiple worker processes
        if worker_info is not None:
            batches = list(self.dataset.to_batches(batch_size=self.batch_size))
            worker_load = len(batches) // worker_info.num_workers

            # If more workers than batches exist, some won't be used
            if worker_load == 0:
                if worker_info.id < len(batches):
                    self.batches = [batches[worker_info.id]]
                else:
                    return
            else:
                start = worker_load * worker_info.id
                end = min(start + worker_load, len(batches))
                self.batches = batches[start:end]
        else:
            self.batches = self.dataset.to_batches(batch_size=self.batch_size)

        # Process and yield each batch as a plain python dict of columns
        for batch in self.batches:
            batch = batch.to_pydict()
            batch.update(self.process_func(batch))
            yield batch


a = IterableManualParquetDataset("uniref90_parquet/uniref90_with_annotations",
                                 lambda x: x, batch_size=64)
u = next(iter(a))
```

(Adapted from https://github.com/KamWithK/PyParquetLoaders/blob/master/PyTorchLoader.py.) The same could likely be done with tf.data if that's better for jax.
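To feed this into training with multiple workers, a minimal usage sketch (assuming the `a` dataset object from above; since each yielded item is already a full batch dict, the DataLoader is given `batch_size=None` so it doesn't try to re-batch):

```python
from torch.utils.data import DataLoader

# batch_size=None: each item from the dataset is already a batch dict, so the
# DataLoader should not collate further. num_workers > 0 exercises the
# get_worker_info() branch that splits the parquet record batches across workers.
loader = DataLoader(a, batch_size=None, num_workers=4)

for batch in loader:
    seqs = batch["seq"]  # list of sequence strings in this batch (column from the fasta side)
    print(len(seqs))
    break
```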
The dataset can be downloaded (it's 50GB) by running the example on colab: https://colab.research.google.com/drive/1Zcns30b1H3IcxMJ-A-wQDF6pUcyNL5ei?usp=sharing
Works perfectly!
(Discussed in discord.)

After running the first step (create_uniref_db) of https://github.com/nadavbra/protein_bert I got a 24GB file, "uniref_proteins_and_annotations.db". It seems it could be useful for generating sequences for this project, so I'm sharing the links here.

There are 135301051 records in the db.
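A quick way to peek at the db before converting it (a sketch; the table name `protein_annotations` is taken from the conversion code in the comment above):

```python
import sqlite3

# Inspect the sqlite db produced by create_uniref_db.
con = sqlite3.connect("uniref_proteins_and_annotations.db")
cur = con.cursor()

# Column names of the annotations table
cur.execute("PRAGMA table_info(protein_annotations)")
print([row[1] for row in cur.fetchall()])

# First few records
cur.execute("SELECT * FROM protein_annotations LIMIT 3")
for row in cur.fetchall():
    print(row)

con.close()
```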