Skip to content

Commit

Permalink
Bugfix in SpacyRecognizer (#1221)
Browse files Browse the repository at this point in the history
  • Loading branch information
omri374 authored Dec 13, 2023
1 parent 43f6624 commit 4551b9b
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ def analyze(self, text: str, entities, nlp_artifacts=None): # noqa D102
ner_scores = nlp_artifacts.scores

for ner_entity, ner_score in zip(ner_entities, ner_scores):
if ner_entity.label_ not in entities:
if (ner_entity.label_ not in entities
or ner_entity.label_ not in self.supported_entities):
logger.debug(
f"Skipping entity {ner_entity.label_} "
f"as it is not in the supported entities list"
Expand Down
28 changes: 28 additions & 0 deletions presidio-analyzer/presidio_analyzer/recognizer_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,3 +318,31 @@ def __instantiate_recognizer(
if isinstance(inst, PatternRecognizer):
inst.global_regex_flags = self.global_regex_flags
return inst

def _get_supported_languages(self) -> List[str]:
languages = []
for rec in self.recognizers:
languages.append(rec.supported_language)

return list(set(languages))

def get_supported_entities(
    self, languages: Optional[List[str]] = None
) -> List[str]:
    """
    Return the supported entities by the set of recognizers loaded.

    :param languages: The languages to get the supported entities for.
    If languages=None, returns all entities for all languages.
    """
    # Fall back to every language any loaded recognizer supports.
    target_languages = languages if languages else self._get_supported_languages()

    entities = set()
    for lang in target_languages:
        for recognizer in self.get_recognizers(language=lang, all_fields=True):
            entities.update(recognizer.get_supported_entities())

    return list(entities)
22 changes: 22 additions & 0 deletions presidio-analyzer/tests/test_recognizer_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
PatternRecognizer,
EntityRecognizer,
Pattern,
AnalyzerEngine
)
from presidio_analyzer.predefined_recognizers import SpacyRecognizer


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -213,3 +215,23 @@ def test_predefined_pattern_recognizers_have_the_right_regex_flags():
for rec in registry.recognizers:
if isinstance(rec, PatternRecognizer):
assert rec.global_regex_flags == re.DOTALL


def test_recognizer_removed_and_returned_entities_are_correct():
    """Replacing SpacyRecognizer with a restricted one narrows supported entities.

    After swapping in a SpacyRecognizer limited to DATE_TIME and NRP, the
    registry must report DATE_TIME but not PERSON, and an AnalyzerEngine
    built on that registry must not emit PERSON results.
    """
    registry = RecognizerRegistry()
    registry.load_predefined_recognizers()
    registry.remove_recognizer("SpacyRecognizer")
    sr = SpacyRecognizer(supported_entities=["DATE_TIME", "NRP"])
    registry.add_recognizer(sr)

    supported_entities = registry.get_supported_entities(languages=["en"])

    assert "DATE_TIME" in supported_entities
    assert "PERSON" not in supported_entities

    analyzer = AnalyzerEngine(
        registry=registry,
        # supported_languages takes a list of language codes; the bare
        # string 'en' only appeared to work via substring containment.
        supported_languages=["en"],
    )

    results = analyzer.analyze("My name is David", language="en")
    # PERSON is no longer a supported entity, so it must not be returned.
    assert all(res.entity_type != "PERSON" for res in results)
31 changes: 31 additions & 0 deletions presidio-analyzer/tests/test_spacy_recognizer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@

import pytest
from spacy.tokens import Span, Doc
from spacy.util import get_lang_class

from presidio_analyzer.nlp_engine import NlpArtifacts, SpacyNlpEngine
from presidio_analyzer.predefined_recognizers import SpacyRecognizer
from tests import assert_result_within_score_range

Expand All @@ -14,6 +18,24 @@ def nlp_recognizer(nlp_recognizers):
return nlp_recognizers["spacy"]


@pytest.fixture(scope="module")
def mock_nlp_artifacts():
    """Build NlpArtifacts for a four-token doc with one PERSON span (score 0.9)."""
    vocab = get_lang_class("en")().vocab
    doc = Doc(vocab, words=["My", "name", "is", "Mitchell"])
    doc.ents = [Span(doc, 2, 3, label="PERSON")]

    return NlpArtifacts(
        entities=doc.ents,
        tokens=doc,
        tokens_indices=[tok.idx for tok in doc],
        lemmas=[tok.lemma_ for tok in doc],
        nlp_engine=None,
        language="en",
        scores=[0.9] * len(doc.ents),
    )


def prepare_and_analyze(nlp, recognizer, text, ents):
nlp_artifacts = nlp.process_text(text, "en")
results = recognizer.analyze(text, ents, nlp_artifacts)
Expand Down Expand Up @@ -84,3 +106,12 @@ def test_analyze_no_nlp_artifacts():
spacy_recognizer = SpacyRecognizer()
res = spacy_recognizer.analyze(text="text", nlp_artifacts=None, entities=["PERSON"])
assert len(res) == 0


def test_entity_not_returned_if_not_in_supported_entities(mock_nlp_artifacts):
    """A requested entity outside the recognizer's supported set yields no results."""
    recognizer = SpacyRecognizer(supported_entities=["NRP"])

    results = recognizer.analyze(
        text="text", nlp_artifacts=mock_nlp_artifacts, entities=["DATE_TIME"]
    )
    assert not results

0 comments on commit 4551b9b

Please sign in to comment.