diff --git a/tiktoken/load.py b/tiktoken/load.py
index 3f02bbf7..45729b11 100644
--- a/tiktoken/load.py
+++ b/tiktoken/load.py
@@ -6,6 +6,7 @@
 import os
 import tempfile
 import uuid
+from typing import Optional
 
 import requests
 
@@ -26,7 +27,12 @@ def read_file(blobpath: str) -> bytes:
         return resp.content
 
 
-def read_file_cached(blobpath: str) -> bytes:
+def check_hash(data: bytes, expected_hash: str) -> bool:
+    data_hash = hashlib.sha256(data).hexdigest()
+    return data_hash == expected_hash
+
+
+def read_file_cached(blobpath: str, expected_hash: Optional[str] = None) -> bytes:
     user_specified_cache = True
     if "TIKTOKEN_CACHE_DIR" in os.environ:
         cache_dir = os.environ["TIKTOKEN_CACHE_DIR"]
@@ -45,9 +51,20 @@ def read_file_cached(blobpath: str) -> bytes:
     cache_path = os.path.join(cache_dir, cache_key)
     if os.path.exists(cache_path):
         with open(cache_path, "rb") as f:
-            return f.read()
+            data = f.read()
+        if expected_hash and not check_hash(data, expected_hash):
+            raise ValueError(
+                f"Hash mismatch for cached data from {blobpath} (expected {expected_hash}). "
+                f"Please delete the cache file at {cache_path} and try again."
+            )
+        return data
 
     contents = read_file(blobpath)
+    if expected_hash and not check_hash(contents, expected_hash):
+        raise ValueError(
+            f"Hash mismatch for data downloaded from {blobpath} (expected {expected_hash}). "
+            f"This may indicate a corrupted download. Please try again."
+        )
 
     try:
         os.makedirs(cache_dir, exist_ok=True)
@@ -64,7 +81,7 @@ def read_file_cached(blobpath: str) -> bytes:
 
 
 def data_gym_to_mergeable_bpe_ranks(
-    vocab_bpe_file: str, encoder_json_file: str
+    vocab_bpe_file: str, encoder_json_file: str, vocab_bpe_hash: Optional[str] = None, encoder_json_hash: Optional[str] = None
 ) -> dict[bytes, int]:
     # NB: do not add caching to this function
     rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "]
@@ -79,7 +96,7 @@ def data_gym_to_mergeable_bpe_ranks(
     assert len(rank_to_intbyte) == 2**8
 
     # vocab_bpe contains the merges along with associated ranks
-    vocab_bpe_contents = read_file_cached(vocab_bpe_file).decode()
+    vocab_bpe_contents = read_file_cached(vocab_bpe_file, vocab_bpe_hash).decode()
    bpe_merges = [tuple(merge_str.split()) for merge_str in vocab_bpe_contents.split("\n")[1:-1]]
 
     def decode_data_gym(value: str) -> bytes:
@@ -96,7 +113,7 @@ def decode_data_gym(value: str) -> bytes:
     # check that the encoder file matches the merges file
     # this sanity check is important since tiktoken assumes that ranks are ordered the same
     # as merge priority
-    encoder_json = json.loads(read_file_cached(encoder_json_file))
+    encoder_json = json.loads(read_file_cached(encoder_json_file, encoder_json_hash))
     encoder_json_loaded = {decode_data_gym(k): v for k, v in encoder_json.items()}
     # drop these two special tokens if present, since they're not mergeable bpe tokens
     encoder_json_loaded.pop(b"<|endoftext|>", None)
@@ -118,9 +135,9 @@ def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> None:
             f.write(base64.b64encode(token) + b" " + str(rank).encode() + b"\n")
 
 
-def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]:
+def load_tiktoken_bpe(tiktoken_bpe_file: str, expected_hash: Optional[str] = None) -> dict[bytes, int]:
     # NB: do not add caching to this function
-    contents = read_file_cached(tiktoken_bpe_file)
+    contents = read_file_cached(tiktoken_bpe_file, expected_hash)
     return {
         base64.b64decode(token): int(rank)
         for token, rank in (line.split() for line in contents.splitlines() if line)
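
For context, the verification added above reduces to a plain SHA-256 hex-digest comparison over the raw bytes, applied both to cache hits and to fresh downloads. A minimal standalone sketch of that behavior; check_hash mirrors the diff, while the payload bytes are made up purely for illustration:

import hashlib

def check_hash(data: bytes, expected_hash: str) -> bool:
    # Compare the SHA-256 hex digest of the raw bytes against the pinned value.
    return hashlib.sha256(data).hexdigest() == expected_hash

payload = b"example vocab bytes"          # made-up stand-in for a downloaded file
pinned = hashlib.sha256(payload).hexdigest()
assert check_hash(payload, pinned)        # intact data passes
assert not check_hash(payload, "0" * 64)  # any corruption or tampering is caught

Because the check also runs on cache hits, a file that was truncated or tampered with after being cached is rejected rather than silently loaded.
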
diff --git a/tiktoken_ext/openai_public.py b/tiktoken_ext/openai_public.py
index ba25cbbe..4e847b0e 100644
--- a/tiktoken_ext/openai_public.py
+++ b/tiktoken_ext/openai_public.py
@@ -11,6 +11,8 @@ def gpt2():
     mergeable_ranks = data_gym_to_mergeable_bpe_ranks(
         vocab_bpe_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/vocab.bpe",
         encoder_json_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/encoder.json",
+        vocab_bpe_hash="1ce1664773c50f3e0cc8842619a93edc4624525b728b188a9e0be33b7726adc5",
+        encoder_json_hash="196139668be63f3b5d6574427317ae82f612a97c5d1cdaf36ed2256dbf636783",
     )
     return {
         "name": "gpt2",
@@ -23,7 +25,8 @@ def gpt2():
 
 def r50k_base():
     mergeable_ranks = load_tiktoken_bpe(
-        "https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken"
+        "https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken",
+        expected_hash="306cd27f03c1a714eca7108e03d66b7dc042abe8c258b44c199a7ed9838dd930",
     )
     return {
         "name": "r50k_base",
@@ -36,7 +39,8 @@ def r50k_base():
 
 def p50k_base():
     mergeable_ranks = load_tiktoken_bpe(
-        "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken"
+        "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken",
+        expected_hash="94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069",
     )
     return {
         "name": "p50k_base",
@@ -49,7 +53,8 @@ def p50k_base():
 
 def p50k_edit():
     mergeable_ranks = load_tiktoken_bpe(
-        "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken"
+        "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken",
+        expected_hash="94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069",
     )
     special_tokens = {ENDOFTEXT: 50256, FIM_PREFIX: 50281, FIM_MIDDLE: 50282, FIM_SUFFIX: 50283}
     return {
@@ -62,7 +67,8 @@ def p50k_edit():
 
 def cl100k_base():
     mergeable_ranks = load_tiktoken_bpe(
-        "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken"
+        "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken",
+        expected_hash="223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7",
     )
     special_tokens = {
         ENDOFTEXT: 100257,
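
For anyone adding or updating a pinned vocabulary, the expected_hash value is just the SHA-256 digest of the blob's bytes. A short sketch of one way to compute it, reusing the real r50k_base URL and hash from the diff above (shasum -a 256 on a downloaded copy gives the same result):

import hashlib
import requests

# Fetch the blob once and print its SHA-256 hex digest; the output is the
# value to paste into expected_hash in the encoding registry above.
url = "https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken"
resp = requests.get(url)
resp.raise_for_status()
print(hashlib.sha256(resp.content).hexdigest())
# Per the diff above, this prints:
# 306cd27f03c1a714eca7108e03d66b7dc042abe8c258b44c199a7ed9838dd930
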