Add script to evaluate performance on SROIE dataset #44
```python
#!/usr/bin/env python

import tempfile
import time
from argparse import ArgumentParser
from pathlib import Path
from subprocess import run

import datasets
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import f1_score, precision_score, recall_score
from tqdm import tqdm

try:
    import pytesseract
except ImportError:
    pytesseract = None


def build_ocrs() -> None:
    run("cargo build --release -p ocrs-cli", shell=True, check=True, text=True)


def extract_text(image_path: str) -> str:
    """Extract text from an image using ocrs."""
    result = run(
        # We run the binary directly here rather than use `cargo run` as it
        # is slightly faster.
        ["target/release/ocrs", image_path],
        check=True,
        text=True,
        capture_output=True,
    )
    return result.stdout
```

> **Review comment (on `extract_text`):** This is fine for now. In the future I think we can introduce some optimizations for running ocrs on many images in succession.
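The `run(...)` call in `extract_text` follows a common subprocess pattern: invoke a CLI, capture stdout as text, and raise on a non-zero exit status. A self-contained sketch of the same pattern, using the Python interpreter as a stand-in for the ocrs binary:

```python
import subprocess
import sys


def run_tool(cmd: list[str]) -> str:
    # check=True raises CalledProcessError on a non-zero exit status,
    # text=True decodes stdout as str, capture_output=True collects it.
    result = subprocess.run(cmd, check=True, text=True, capture_output=True)
    return result.stdout


# Stand-in for `target/release/ocrs <image>`: any CLI that prints to stdout.
output = run_tool([sys.executable, "-c", "print('recognized text')"])
print(output.strip())  # recognized text
```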
```python
def run_global_retrieval_eval(max_samples: int) -> None:
    """
    Evaluate OCR performance by computing precision, recall and F1 score
    for the detected tokens globally over the whole document.

    Here we use whitespace tokenization to split the text into tokens,
    as documented in the SROIE 2019 paper.
    """
```

> **Review comment (on the docstring):** To clarify, is this intended to follow the evaluation protocol from "Scanned Receipt OCR" in the SROIE paper (https://arxiv.org/pdf/2103.10213.pdf)? Are there any differences between the tokenization mentioned there and what scikit-learn does, as far as you know?

> **Author reply:** Initially I wanted to benchmark multiple datasets, in which case matching exactly the same evaluation procedure is harder. For instance, https://huggingface.co/datasets/naver-clova-ix/cord-v2 should also be easy to add to this script using the same loader. However, since SROIE alone already takes up to 5 min (ocrs + tesseract), in the end I didn't. In SROIE they use whitespace tokenization, which is a bit less forgiving than scikit-learn's. Here is an example:
>
> ```python
> >>> import re
> >>> line = "l'b about 22.33 10/28"
> >>> re.findall(r'[^\s]+', line)  # whitespace tokenization
> ["l'b", 'about', '22.33', '10/28']
> >>> re.findall(r'(?u)\b\w\w+\b', line)  # scikit-learn tokenization
> ['about', '22', '33', '10', '28']
> ```
>
> So you are right, let's revert to whitespace tokenization as SROIE does, in which case the scores are a bit worse.

> **Review comment:** Hmm, the sensitivity to spacing is unfortunate. Using the example given in the paper, I can imagine that in some cases it might be ambiguous to humans whether spaces should appear in the transcription. If the annotators were given specific instructions which they followed, then a model could learn these conventions from the training split. This will make models that weren't trained specifically on this dataset appear worse, though.
```python
    # Evaluate the SROIE dataset
    dataset = datasets.load_dataset("rth/sroie-2019-v2", split="test")
    true_text = ["\n".join(el["objects"]["text"]) for el in dataset]
    print("Evaluating on SROIE 2019 dataset...")

    # Build the vocabulary on the ground truth.
    # Here we use a whitespace tokenizer as documented in the SROIE 2019 paper.
    vectorizer = CountVectorizer(input="content", token_pattern=r"[^\s]+", binary=True)

    X_true = vectorizer.fit_transform(true_text[:max_samples])

    # Evaluate with ocrs
    text_pred_ocrs = []
    time_ocrs = 0

    for idx, data_el in tqdm(enumerate(dataset)):
        if idx >= max_samples:
            break

        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
            data_el["image"].save(tmp_file, format="JPEG")

        t0 = time.perf_counter()
        text_pred_ocrs.append(extract_text(tmp_file.name))
        time_ocrs += time.perf_counter() - t0

    X_ocrs = vectorizer.transform(text_pred_ocrs)

    print(
        " - Ocrs: {:.2f} s / image, precision {:.2f}, recall {:.2f}, F1 {:.2f}".format(
            time_ocrs / max_samples,
            precision_score(X_true, X_ocrs, average="micro"),
            recall_score(X_true, X_ocrs, average="micro"),
            f1_score(X_true, X_ocrs, average="micro"),
        )
    )
```
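With `binary=True`, each document becomes a 0/1 indicator over the vocabulary, so the micro-averaged scores above reduce to pooled counts of shared and mismatched tokens across all documents. A dependency-free sketch of that computation on toy token sets (invented here; unlike the script, it does not drop out-of-vocabulary predictions):

```python
def micro_prf(true_sets, pred_sets):
    # Micro-averaging pools true positives and errors over all documents
    # before dividing, so frequent tokens dominate the score.
    tp = sum(len(t & p) for t, p in zip(true_sets, pred_sets))
    fp = sum(len(p - t) for t, p in zip(true_sets, pred_sets))
    fn = sum(len(t - p) for t, p in zip(true_sets, pred_sets))
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


true_sets = [{"TOTAL", "22.33"}, {"DATE", "10/28"}]
pred_sets = [{"TOTAL", "22.38"}, {"DATE", "10/28"}]  # one misread token
p, r, f = micro_prf(true_sets, pred_sets)
print(f"precision {p:.2f}, recall {r:.2f}, F1 {f:.2f}")
# precision 0.75, recall 0.75, F1 0.75
```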
```python
    if pytesseract is not None:
        # Optionally evaluate with pytesseract
        text_pred_tesseract = []
        time_tesseract = 0
        for idx, data_el in tqdm(enumerate(dataset)):
            if idx >= max_samples:
                break

            t0 = time.perf_counter()
            # Neural nets LSTM engine only. Pass the PIL image directly:
            # reusing `tmp_file.name` from the ocrs loop would run Tesseract
            # on the same (last) image at every iteration.
            text_pred_tesseract.append(
                pytesseract.image_to_string(
                    data_el["image"], lang="eng", config="--oem 1"
                )
            )
            time_tesseract += time.perf_counter() - t0

        X_tesseract = vectorizer.transform(text_pred_tesseract)

        print(
            " - Tesseract: {:.2f} s / image, precision {:.2f}, recall {:.2f}, F1 {:.2f}".format(
                time_tesseract / max_samples,
                precision_score(X_true, X_tesseract, average="micro"),
                recall_score(X_true, X_tesseract, average="micro"),
                f1_score(X_true, X_tesseract, average="micro"),
            )
        )
```
```python
parser = ArgumentParser(
    description="""
Evaluate ocrs on the benchmark datasets.

To run this script, you need to install dependencies:

    pip install scikit-learn datasets tqdm

Optionally, you can install pytesseract to compare with tesseract.
"""
)
parser.add_argument(
    "--max-samples", type=int, default=100, help="Number of samples to evaluate"
)
args = parser.parse_args()

print("Building ocrs...")
build_ocrs()
run_global_retrieval_eval(max_samples=args.max_samples)
```

> **Review comment (on the description):** It seems that this could be written as:
>
> ```python
> description="""
> Evaluate ocrs on the benchmark datasets.
> To run this script, you need to install dependencies: pip install scikit-learn datasets tqdm.
> Optionally, you can install pytesseract to compare with tesseract.
> """
> ```
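The argument handling above can be exercised without a real command line by passing an explicit argv list to `parse_args`, which is also how such entry points are usually unit-tested. A minimal sketch mirroring the `--max-samples` flag:

```python
from argparse import ArgumentParser

parser = ArgumentParser(description="Evaluate ocrs on the benchmark datasets")
parser.add_argument(
    "--max-samples", type=int, default=100, help="Number of samples to evaluate"
)

# parse_args accepts an explicit list; dashes become underscores in attributes.
args = parser.parse_args(["--max-samples", "10"])
print(args.max_samples)  # 10

# The default applies when the flag is omitted.
print(parser.parse_args([]).max_samples)  # 100
```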
> **Review comment:** Can you sort the imports using isort conventions? You can do this with `isort` (or `ruff`).