Skip to content

Commit

Permalink
Named entity extractor private models (#8658)
Browse files Browse the repository at this point in the history
* add 'token' support to NamedEntityExtractor to enable using private models on HF backend

* fix existing error message format

* add release note

* add HF_API_TOKEN to e2e workflow

* add informative comment

* Updated to_dict / from_dict to handle 'token' correctly ; Added tests

* Fix lint

* Revert unwanted change
  • Loading branch information
mpangrazzi authored Dec 20, 2024
1 parent 286061f commit c192488
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 6 deletions.
1 change: 1 addition & 0 deletions .github/workflows/e2e.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ env:
PYTHON_VERSION: "3.9"
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
HATCH_VERSION: "1.13.0"
HF_API_TOKEN: ${{ secrets.HUGGINGFACE_API_KEY }}

jobs:
run:
Expand Down
13 changes: 13 additions & 0 deletions e2e/pipelines/test_named_entity_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

import os
import pytest

from haystack import Document, Pipeline
Expand Down Expand Up @@ -65,6 +66,18 @@ def test_ner_extractor_hf_backend(raw_texts, hf_annotations, batch_size):
_extract_and_check_predictions(extractor, raw_texts, hf_annotations, batch_size)


@pytest.mark.parametrize("batch_size", [1, 3])
@pytest.mark.skipif(
not os.environ.get("HF_API_TOKEN", None),
reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
)
def test_ner_extractor_hf_backend_private_models(raw_texts, hf_annotations, batch_size):
extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="deepset/bert-base-NER")
extractor.warm_up()

_extract_and_check_predictions(extractor, raw_texts, hf_annotations, batch_size)


@pytest.mark.parametrize("batch_size", [1, 3])
def test_ner_extractor_spacy_backend(raw_texts, spacy_annotations, batch_size):
extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.SPACY, model="en_core_web_trf")
Expand Down
37 changes: 32 additions & 5 deletions haystack/components/extractors/named_entity_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@

from haystack import ComponentError, DeserializationError, Document, component, default_from_dict, default_to_dict
from haystack.lazy_imports import LazyImport
from haystack.utils.auth import Secret, deserialize_secrets_inplace
from haystack.utils.device import ComponentDevice
from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_pipeline_kwargs, serialize_hf_model_kwargs

with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transformers_import:
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
Expand Down Expand Up @@ -110,6 +112,7 @@ def __init__(
model: str,
pipeline_kwargs: Optional[Dict[str, Any]] = None,
device: Optional[ComponentDevice] = None,
token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
) -> None:
"""
Create a Named Entity extractor component.
Expand All @@ -128,16 +131,28 @@ def __init__(
device/device map is specified in `pipeline_kwargs`,
it overrides this parameter (only applicable to the
HuggingFace backend).
:param token:
The API token to download private models from Hugging Face.
"""

if isinstance(backend, str):
backend = NamedEntityExtractorBackend.from_str(backend)

self._backend: _NerBackend
self._warmed_up: bool = False
self.token = token
device = ComponentDevice.resolve_device(device)

if backend == NamedEntityExtractorBackend.HUGGING_FACE:
pipeline_kwargs = resolve_hf_pipeline_kwargs(
huggingface_pipeline_kwargs=pipeline_kwargs or {},
model=model,
task="ner",
supported_tasks=["ner"],
device=device,
token=token,
)

self._backend = _HfBackend(model_name_or_path=model, device=device, pipeline_kwargs=pipeline_kwargs)
elif backend == NamedEntityExtractorBackend.SPACY:
self._backend = _SpacyBackend(model_name_or_path=model, device=device, pipeline_kwargs=pipeline_kwargs)
Expand All @@ -159,7 +174,7 @@ def warm_up(self):
self._warmed_up = True
except Exception as e:
raise ComponentError(
f"Named entity extractor with backend '{self._backend.type} failed to initialize."
f"Named entity extractor with backend '{self._backend.type}' failed to initialize."
) from e

@component.output_types(documents=List[Document])
Expand Down Expand Up @@ -201,14 +216,21 @@ def to_dict(self) -> Dict[str, Any]:
:returns:
Dictionary with serialized data.
"""
return default_to_dict(
serialization_dict = default_to_dict(
self,
backend=self._backend.type.name,
model=self._backend.model_name,
device=self._backend.device.to_dict(),
pipeline_kwargs=self._backend._pipeline_kwargs,
token=self.token.to_dict() if self.token else None,
)

hf_pipeline_kwargs = serialization_dict["init_parameters"]["pipeline_kwargs"]
hf_pipeline_kwargs.pop("token", None)

serialize_hf_model_kwargs(hf_pipeline_kwargs)
return serialization_dict

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "NamedEntityExtractor":
"""
Expand All @@ -220,10 +242,14 @@ def from_dict(cls, data: Dict[str, Any]) -> "NamedEntityExtractor":
Deserialized component.
"""
try:
init_params = data["init_parameters"]
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
init_params = data.get("init_parameters", {})
if init_params.get("device") is not None:
init_params["device"] = ComponentDevice.from_dict(init_params["device"])
init_params["backend"] = NamedEntityExtractorBackend[init_params["backend"]]

hf_pipeline_kwargs = init_params.get("pipeline_kwargs", {})
deserialize_hf_model_kwargs(hf_pipeline_kwargs)
return default_from_dict(cls, data)
except Exception as e:
raise DeserializationError(f"Couldn't deserialize {cls.__name__} instance") from e
Expand Down Expand Up @@ -352,8 +378,9 @@ def __init__(
self.pipeline: Optional[HfPipeline] = None

def initialize(self):
self.tokenizer = AutoTokenizer.from_pretrained(self._model_name_or_path)
self.model = AutoModelForTokenClassification.from_pretrained(self._model_name_or_path)
token = self._pipeline_kwargs.get("token", None)
self.tokenizer = AutoTokenizer.from_pretrained(self._model_name_or_path, token=token)
self.model = AutoModelForTokenClassification.from_pretrained(self._model_name_or_path, token=token)

pipeline_params = {
"task": "ner",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
enhancements:
- |
Add `token` argument to `NamedEntityExtractor` to allow usage of private Hugging Face models.
57 changes: 56 additions & 1 deletion test/components/extractors/test_named_entity_extractor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from haystack.utils.auth import Secret
import pytest

from haystack import ComponentError, DeserializationError, Pipeline
Expand All @@ -11,6 +12,9 @@
def test_named_entity_extractor_backend():
_ = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")

# private model
_ = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="deepset/bert-base-NER")

_ = NamedEntityExtractor(backend="hugging_face", model="dslim/bert-base-NER")

_ = NamedEntityExtractor(backend=NamedEntityExtractorBackend.SPACY, model="en_core_web_sm")
Expand Down Expand Up @@ -40,7 +44,58 @@ def test_named_entity_extractor_serde():
_ = NamedEntityExtractor.from_dict(serde_data)


def test_named_entity_extractor_from_dict_no_default_parameters_hf():
def test_to_dict_default(monkeypatch):
monkeypatch.delenv("HF_API_TOKEN", raising=False)

component = NamedEntityExtractor(
backend=NamedEntityExtractorBackend.HUGGING_FACE,
model="dslim/bert-base-NER",
device=ComponentDevice.from_str("mps"),
)
data = component.to_dict()

assert data == {
"type": "haystack.components.extractors.named_entity_extractor.NamedEntityExtractor",
"init_parameters": {
"backend": "HUGGING_FACE",
"model": "dslim/bert-base-NER",
"device": {"type": "single", "device": "mps"},
"pipeline_kwargs": {"model": "dslim/bert-base-NER", "device": "mps", "task": "ner"},
"token": {"type": "env_var", "env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False},
},
}


def test_to_dict_with_parameters():
component = NamedEntityExtractor(
backend=NamedEntityExtractorBackend.HUGGING_FACE,
model="dslim/bert-base-NER",
device=ComponentDevice.from_str("mps"),
pipeline_kwargs={"model_kwargs": {"load_in_4bit": True}},
token=Secret.from_env_var("ENV_VAR", strict=False),
)
data = component.to_dict()

assert data == {
"type": "haystack.components.extractors.named_entity_extractor.NamedEntityExtractor",
"init_parameters": {
"backend": "HUGGING_FACE",
"model": "dslim/bert-base-NER",
"device": {"type": "single", "device": "mps"},
"pipeline_kwargs": {
"model": "dslim/bert-base-NER",
"device": "mps",
"task": "ner",
"model_kwargs": {"load_in_4bit": True},
},
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
},
}


def test_named_entity_extractor_from_dict_no_default_parameters_hf(monkeypatch):
monkeypatch.delenv("HF_API_TOKEN", raising=False)

data = {
"type": "haystack.components.extractors.named_entity_extractor.NamedEntityExtractor",
"init_parameters": {"backend": "HUGGING_FACE", "model": "dslim/bert-base-NER"},
Expand Down

0 comments on commit c192488

Please sign in to comment.