From 3ee6c3517d3465775dc22f06a034cfcf8d06eba7 Mon Sep 17 00:00:00 2001
From: Matt Welsh
Date: Mon, 29 Jan 2024 16:51:57 -0800
Subject: [PATCH] Add support for checking hash of downloaded files before use. (#230)

We use tiktoken in various production scenarios and sometimes hit a problem
where the download of `.tiktoken` files (e.g., `cl100k_base.tiktoken`) is
interrupted or fails, leaving a corrupted file in the cache. In those cases,
the results returned by the encoder can be silently incorrect, which is
damaging to our production instances. More often, `Encoding.encode()` throws
an exception such as

```
pyo3_runtime.PanicException: no entry found for key
```

which turns out to be quite hard to track down.

To make tiktoken more robust for production use, this PR adds the sha256 hash
of each downloaded file to `openai_public.py` and augments `read_file_cached`
to verify that hash, when one is provided, both when the file is read from the
cache and when it is downloaded directly. Errors are therefore flagged at file
load time, rather than when the files are used, with a more meaningful error
message indicating what might have gone wrong. This also protects users of
tiktoken from scenarios where a network issue or a MITM attack corrupts these
files in transit.
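For reference, a minimal sketch of how the pinned hashes can be recomputed
(illustrative only, not part of this patch; it assumes direct network access
to the public blob endpoints):

```
# Recompute the expected sha256 for a published encoding file.
# Illustrative sketch; assumes direct access to the public blob endpoint.
import hashlib

import requests

url = "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken"
resp = requests.get(url)
resp.raise_for_status()
print(hashlib.sha256(resp.content).hexdigest())
# -> 223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7
```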
---
 tiktoken/load.py              | 31 ++++++++++++++++++++++++-------
 tiktoken_ext/openai_public.py | 14 ++++++++++----
 2 files changed, 34 insertions(+), 11 deletions(-)

diff --git a/tiktoken/load.py b/tiktoken/load.py
index 3f02bbf7..45729b11 100644
--- a/tiktoken/load.py
+++ b/tiktoken/load.py
@@ -6,6 +6,7 @@
 import os
 import tempfile
 import uuid
+from typing import Optional
 
 import requests
 
@@ -26,7 +27,12 @@ def read_file(blobpath: str) -> bytes:
     return resp.content
 
 
-def read_file_cached(blobpath: str) -> bytes:
+def check_hash(data: bytes, expected_hash: str) -> bool:
+    data_hash = hashlib.sha256(data).hexdigest()
+    return data_hash == expected_hash
+
+
+def read_file_cached(blobpath: str, expected_hash: Optional[str] = None) -> bytes:
     user_specified_cache = True
     if "TIKTOKEN_CACHE_DIR" in os.environ:
         cache_dir = os.environ["TIKTOKEN_CACHE_DIR"]
@@ -45,9 +51,20 @@ def read_file_cached(blobpath: str) -> bytes:
     cache_path = os.path.join(cache_dir, cache_key)
     if os.path.exists(cache_path):
         with open(cache_path, "rb") as f:
-            return f.read()
+            data = f.read()
+        if expected_hash and not check_hash(data, expected_hash):
+            raise ValueError(
+                f"Hash mismatch for cached data from {blobpath} (expected {expected_hash}). "
+                f"Please delete the cache file at {cache_path} and try again."
+            )
+        return data
 
     contents = read_file(blobpath)
+    if expected_hash and not check_hash(contents, expected_hash):
+        raise ValueError(
+            f"Hash mismatch for data downloaded from {blobpath} (expected {expected_hash}). "
+            f"This may indicate a corrupted download. Please try again."
+        )
 
     try:
         os.makedirs(cache_dir, exist_ok=True)
@@ -64,7 +81,7 @@ def read_file_cached(blobpath: str) -> bytes:
 
 
 def data_gym_to_mergeable_bpe_ranks(
-    vocab_bpe_file: str, encoder_json_file: str
+    vocab_bpe_file: str, encoder_json_file: str, vocab_bpe_hash: Optional[str] = None, encoder_json_hash: Optional[str] = None
 ) -> dict[bytes, int]:
     # NB: do not add caching to this function
     rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "]
@@ -79,7 +96,7 @@ def data_gym_to_mergeable_bpe_ranks(
     assert len(rank_to_intbyte) == 2**8
 
     # vocab_bpe contains the merges along with associated ranks
-    vocab_bpe_contents = read_file_cached(vocab_bpe_file).decode()
+    vocab_bpe_contents = read_file_cached(vocab_bpe_file, vocab_bpe_hash).decode()
     bpe_merges = [tuple(merge_str.split()) for merge_str in vocab_bpe_contents.split("\n")[1:-1]]
 
     def decode_data_gym(value: str) -> bytes:
@@ -96,7 +113,7 @@ def decode_data_gym(value: str) -> bytes:
     # check that the encoder file matches the merges file
     # this sanity check is important since tiktoken assumes that ranks are ordered the same
     # as merge priority
-    encoder_json = json.loads(read_file_cached(encoder_json_file))
+    encoder_json = json.loads(read_file_cached(encoder_json_file, encoder_json_hash))
     encoder_json_loaded = {decode_data_gym(k): v for k, v in encoder_json.items()}
     # drop these two special tokens if present, since they're not mergeable bpe tokens
     encoder_json_loaded.pop(b"<|endoftext|>", None)
@@ -118,9 +135,9 @@ def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> None:
         f.write(base64.b64encode(token) + b" " + str(rank).encode() + b"\n")
 
 
-def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]:
+def load_tiktoken_bpe(tiktoken_bpe_file: str, expected_hash: Optional[str] = None) -> dict[bytes, int]:
     # NB: do not add caching to this function
-    contents = read_file_cached(tiktoken_bpe_file)
+    contents = read_file_cached(tiktoken_bpe_file, expected_hash)
     return {
         base64.b64decode(token): int(rank)
         for token, rank in (line.split() for line in contents.splitlines() if line)

diff --git a/tiktoken_ext/openai_public.py b/tiktoken_ext/openai_public.py
index ba25cbbe..4e847b0e 100644
--- a/tiktoken_ext/openai_public.py
+++ b/tiktoken_ext/openai_public.py
@@ -11,6 +11,8 @@ def gpt2():
     mergeable_ranks = data_gym_to_mergeable_bpe_ranks(
         vocab_bpe_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/vocab.bpe",
         encoder_json_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/encoder.json",
+        vocab_bpe_hash="1ce1664773c50f3e0cc8842619a93edc4624525b728b188a9e0be33b7726adc5",
+        encoder_json_hash="196139668be63f3b5d6574427317ae82f612a97c5d1cdaf36ed2256dbf636783",
     )
     return {
         "name": "gpt2",
@@ -23,7 +25,8 @@ def gpt2():
 
 def r50k_base():
     mergeable_ranks = load_tiktoken_bpe(
-        "https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken"
+        "https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken",
+        expected_hash="306cd27f03c1a714eca7108e03d66b7dc042abe8c258b44c199a7ed9838dd930",
     )
     return {
         "name": "r50k_base",
@@ -36,7 +39,8 @@ def r50k_base():
 
 def p50k_base():
     mergeable_ranks = load_tiktoken_bpe(
-        "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken"
+        "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken",
+        expected_hash="94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069",
     )
     return {
         "name": "p50k_base",
@@ -49,7 +53,8 @@ def p50k_base():
 
 def p50k_edit():
     mergeable_ranks = load_tiktoken_bpe(
-        "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken"
+        "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken",
+        expected_hash="94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069",
     )
     special_tokens = {ENDOFTEXT: 50256, FIM_PREFIX: 50281, FIM_MIDDLE: 50282, FIM_SUFFIX: 50283}
     return {
@@ -62,7 +67,8 @@ def p50k_edit():
 
 def cl100k_base():
     mergeable_ranks = load_tiktoken_bpe(
-        "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken"
+        "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken",
+        expected_hash="223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7",
     )
     special_tokens = {
         ENDOFTEXT: 100257,