From 4ec6a60a76263538efef0f1a828c1c037993d734 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Thu, 23 Nov 2023 16:21:07 +0000 Subject: [PATCH] feat: `CohereGenerator` (#6395) * added CohereGenerator with unit tests Signed-off-by: sunilkumardash9 * 1. added releasenote 2. removed commented files in test-cohere_generators 3. removed unused imports Signed-off-by: sunilkumardash9 * 1. move client creation to __init__ 2. remove dict casting of metadata in run Signed-off-by: sunilkumardash9 * few fixes Signed-off-by: sunilkumardash9 * add cohere to git workflows Signed-off-by: sunilkumardash9 * 1. CohereGenerator as top level import in generators 2. small change in doc string Signed-off-by: sunilkumardash9 * 1. corrected git workflow files for cohere import 2. changed api key env var from CO_API_KEY to COHERE_API_KEY Signed-off-by: sunilkumardash9 * added cohere in missed out workflow installs Signed-off-by: sunilkumardash9 * 1. Removed default_streaming_callback from cohere.py and added in test. 2. Added kwargs doc strings for CohereGenerator 3. removed type hints for metadata and replies 4. use COHERE_API_URL instead of hard coded URL. Signed-off-by: sunilkumardash9 * Update haystack/preview/components/generators/cohere/cohere.py Co-authored-by: Daria Fokina * Update haystack/preview/components/generators/cohere/cohere.py Co-authored-by: Daria Fokina * Update haystack/preview/components/generators/cohere/cohere.py Co-authored-by: Daria Fokina * Update haystack/preview/components/generators/cohere/cohere.py Co-authored-by: Daria Fokina * Update haystack/preview/components/generators/cohere/cohere.py Co-authored-by: Daria Fokina * move out of folder * black * fix tests * feedback * black * remove api key from tests * read api key from env var if missing * typo * black * missing import --------- Signed-off-by: sunilkumardash9 Co-authored-by: sunilkumardash9 Co-authored-by: Daria Fokina --- .github/workflows/linting_preview.yml | 4 +- .github/workflows/tests_preview.yml | 8 +- .../preview/components/generators/__init__.py | 3 +- .../preview/components/generators/cohere.py | 154 ++++++++++++++++ .../add-CohereGenerator-ca55e5b8e46df754.yaml | 4 + .../generators/test_cohere_generators.py | 168 ++++++++++++++++++ 6 files changed, 334 insertions(+), 7 deletions(-) create mode 100644 haystack/preview/components/generators/cohere.py create mode 100644 releasenotes/notes/add-CohereGenerator-ca55e5b8e46df754.yaml create mode 100644 test/preview/components/generators/test_cohere_generators.py diff --git a/.github/workflows/linting_preview.yml b/.github/workflows/linting_preview.yml index 1c29984436..1c7209d138 100644 --- a/.github/workflows/linting_preview.yml +++ b/.github/workflows/linting_preview.yml @@ -38,7 +38,7 @@ jobs: python-version: ${{ env.PYTHON_VERSION }} - name: Install Haystack - run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf tika 'azure-ai-formrecognizer>=3.2.0b2' + run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf tika 'azure-ai-formrecognizer>=3.2.0b2' cohere - name: Mypy if: steps.files.outputs.any_changed == 'true' @@ -69,7 +69,7 @@ jobs: - name: Install Haystack run: | - pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' + pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' cohere pip install ./haystack-linter - name: Pylint diff --git a/.github/workflows/tests_preview.yml b/.github/workflows/tests_preview.yml index a69bb369eb..f53ed289a7 100644 --- a/.github/workflows/tests_preview.yml +++ b/.github/workflows/tests_preview.yml @@ -116,7 +116,7 @@ jobs: python-version: ${{ env.PYTHON_VERSION }} - name: Install Haystack - run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' + run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' cohere - name: Run run: pytest -m "not integration" test/preview @@ -174,7 +174,7 @@ jobs: sudo apt install ffmpeg # for local Whisper tests - name: Install Haystack - run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' + run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' cohere - name: Run run: pytest --maxfail=5 -m "integration" test/preview @@ -230,7 +230,7 @@ jobs: colima start - name: Install Haystack - run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' + run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' cohere - name: Run Tika run: docker run -d -p 9998:9998 apache/tika:2.9.0.0 @@ -281,7 +281,7 @@ jobs: python-version: ${{ env.PYTHON_VERSION }} - name: Install Haystack - run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' + run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' cohere - name: Run run: pytest --maxfail=5 -m "integration" test/preview -k 'not tika' diff --git a/haystack/preview/components/generators/__init__.py b/haystack/preview/components/generators/__init__.py index bc81975a3f..037ca7b7a5 100644 --- a/haystack/preview/components/generators/__init__.py +++ b/haystack/preview/components/generators/__init__.py @@ -1,5 +1,6 @@ +from haystack.preview.components.generators.cohere import CohereGenerator from haystack.preview.components.generators.hugging_face_local import HuggingFaceLocalGenerator from haystack.preview.components.generators.hugging_face_tgi import HuggingFaceTGIGenerator from haystack.preview.components.generators.openai import GPTGenerator -__all__ = ["HuggingFaceLocalGenerator", "HuggingFaceTGIGenerator", "GPTGenerator"] +__all__ = ["HuggingFaceLocalGenerator", "HuggingFaceTGIGenerator", "GPTGenerator", "CohereGenerator"] diff --git a/haystack/preview/components/generators/cohere.py b/haystack/preview/components/generators/cohere.py new file mode 100644 index 0000000000..9c35d6047f --- /dev/null +++ b/haystack/preview/components/generators/cohere.py @@ -0,0 +1,154 @@ +import logging +import os +import sys +from typing import Any, Callable, Dict, List, Optional + +from haystack.lazy_imports import LazyImport +from haystack.preview import DeserializationError, component, default_from_dict, default_to_dict + +with LazyImport(message="Run 'pip install cohere'") as cohere_import: + from cohere import Client, COHERE_API_URL + +logger = logging.getLogger(__name__) + + +@component +class CohereGenerator: + """LLM Generator compatible with Cohere's generate endpoint. + + Queries the LLM using Cohere's API. Invocations are made using 'cohere' package. + See [Cohere API](https://docs.cohere.com/reference/generate) for more details. + + Example usage: + + ```python + from haystack.preview.generators import CohereGenerator + generator = CohereGenerator(api_key="test-api-key") + generator.run(prompt="What's the capital of France?") + ``` + """ + + def __init__( + self, + api_key: Optional[str] = None, + model: str = "command", + streaming_callback: Optional[Callable] = None, + api_base_url: str = COHERE_API_URL, + **kwargs, + ): + """ + Instantiates a `CohereGenerator` component. + :param api_key: The API key for the Cohere API. If not set, it will be read from the COHERE_API_KEY env var. + :param model_name: The name of the model to use. Available models are: [command, command-light, command-nightly, command-nightly-light]. Defaults to "command". + :param streaming_callback: A callback function to be called with the streaming response. Defaults to None. + :param api_base_url: The base URL of the Cohere API. Defaults to "https://api.cohere.ai". + :param kwargs: Additional model parameters. These will be used during generation. Refer to https://docs.cohere.com/reference/generate for more details. + Some of the parameters are: + - 'max_tokens': The maximum number of tokens to be generated. Defaults to 1024. + - 'truncate': One of NONE|START|END to specify how the API will handle inputs longer than the maximum token length. Defaults to END. + - 'temperature': A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations. + - 'preset': Identifier of a custom preset. A preset is a combination of parameters, such as prompt, temperature etc. You can create presets in the playground. + - 'end_sequences': The generated text will be cut at the beginning of the earliest occurrence of an end sequence. The sequence will be excluded from the text. + - 'stop_sequences': The generated text will be cut at the end of the earliest occurrence of a stop sequence. The sequence will be included the text. + - 'k': Defaults to 0, min value of 0.01, max value of 0.99. + - 'p': Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`. + - 'frequency_penalty': Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, + proportional to how many times they have already appeared in the prompt or prior generation.' + - 'presence_penalty': Defaults to 0.0, min value of 0.0, max value of 1.0. Can be used to reduce repetitiveness of generated tokens. + Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies. + - 'return_likelihoods': One of GENERATION|ALL|NONE to specify how and if the token likelihoods are returned with the response. Defaults to NONE. + - 'logit_bias': Used to prevent the model from generating unwanted tokens or to incentivize it to include desired tokens. + The format is {token_id: bias} where bias is a float between -10 and 10. + + """ + if not api_key: + api_key = os.environ.get("COHERE_API_KEY") + if not api_key: + raise ValueError( + "CohereGenerator needs an API key to run. Either provide it as init parameter or set the env var COHERE_API_KEY." + ) + + self.api_key = api_key + self.model = model + self.streaming_callback = streaming_callback + self.api_base_url = api_base_url + self.model_parameters = kwargs + self.client = Client(api_key=self.api_key, api_url=self.api_base_url) + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + """ + if self.streaming_callback: + module = self.streaming_callback.__module__ + if module == "builtins": + callback_name = self.streaming_callback.__name__ + else: + callback_name = f"{module}.{self.streaming_callback.__name__}" + else: + callback_name = None + + return default_to_dict( + self, + model=self.model, + streaming_callback=callback_name, + api_base_url=self.api_base_url, + **self.model_parameters, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "CohereGenerator": + """ + Deserialize this component from a dictionary. + """ + init_params = data.get("init_parameters", {}) + streaming_callback = None + if "streaming_callback" in init_params and init_params["streaming_callback"]: + parts = init_params["streaming_callback"].split(".") + module_name = ".".join(parts[:-1]) + function_name = parts[-1] + module = sys.modules.get(module_name, None) + if not module: + raise DeserializationError(f"Could not locate the module of the streaming callback: {module_name}") + streaming_callback = getattr(module, function_name, None) + if not streaming_callback: + raise DeserializationError(f"Could not locate the streaming callback: {function_name}") + data["init_parameters"]["streaming_callback"] = streaming_callback + return default_from_dict(cls, data) + + @component.output_types(replies=List[str], metadata=List[Dict[str, Any]]) + def run(self, prompt: str): + """ + Queries the LLM with the prompts to produce replies. + :param prompt: The prompt to be sent to the generative model. + """ + response = self.client.generate( + model=self.model, prompt=prompt, stream=self.streaming_callback is not None, **self.model_parameters + ) + if self.streaming_callback: + metadata_dict: Dict[str, Any] = {} + for chunk in response: + self.streaming_callback(chunk) + metadata_dict["index"] = chunk.index + replies = response.texts + metadata_dict["finish_reason"] = response.finish_reason + metadata = [metadata_dict] + self._check_truncated_answers(metadata) + return {"replies": replies, "metadata": metadata} + + metadata = [{"finish_reason": resp.finish_reason} for resp in response] + replies = [resp.text for resp in response] + self._check_truncated_answers(metadata) + return {"replies": replies, "metadata": metadata} + + def _check_truncated_answers(self, metadata: List[Dict[str, Any]]): + """ + Check the `finish_reason` returned with the Cohere response. + If the `finish_reason` is `MAX_TOKEN`, log a warning to the user. + :param metadata: The metadata returned by the Cohere API. + """ + if metadata[0]["finish_reason"] == "MAX_TOKENS": + logger.warning( + "Responses have been truncated before reaching a natural stopping point. " + "Increase the max_tokens parameter to allow for longer completions." + ) diff --git a/releasenotes/notes/add-CohereGenerator-ca55e5b8e46df754.yaml b/releasenotes/notes/add-CohereGenerator-ca55e5b8e46df754.yaml new file mode 100644 index 0000000000..bd639cceb2 --- /dev/null +++ b/releasenotes/notes/add-CohereGenerator-ca55e5b8e46df754.yaml @@ -0,0 +1,4 @@ +--- +preview: + - | + Add CohereGenerator compatible with Cohere generate endpoint diff --git a/test/preview/components/generators/test_cohere_generators.py b/test/preview/components/generators/test_cohere_generators.py new file mode 100644 index 0000000000..5202afec5b --- /dev/null +++ b/test/preview/components/generators/test_cohere_generators.py @@ -0,0 +1,168 @@ +import os + +import pytest +import cohere + +from haystack.preview.components.generators import CohereGenerator + + +def default_streaming_callback(chunk): + """ + Default callback function for streaming responses from Cohere API. + Prints the tokens of the first completion to stdout as soon as they are received and returns the chunk unchanged. + """ + print(chunk.text, flush=True, end="") + + +class TestGPTGenerator: + def test_init_default(self): + component = CohereGenerator(api_key="test-api-key") + assert component.api_key == "test-api-key" + assert component.model == "command" + assert component.streaming_callback is None + assert component.api_base_url == cohere.COHERE_API_URL + assert component.model_parameters == {} + + def test_init_with_parameters(self): + callback = lambda x: x + component = CohereGenerator( + api_key="test-api-key", + model="command-light", + max_tokens=10, + some_test_param="test-params", + streaming_callback=callback, + api_base_url="test-base-url", + ) + assert component.api_key == "test-api-key" + assert component.model == "command-light" + assert component.streaming_callback == callback + assert component.api_base_url == "test-base-url" + assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"} + + def test_to_dict_default(self): + component = CohereGenerator(api_key="test-api-key") + data = component.to_dict() + assert data == { + "type": "haystack.preview.components.generators.cohere.CohereGenerator", + "init_parameters": {"model": "command", "streaming_callback": None, "api_base_url": cohere.COHERE_API_URL}, + } + + def test_to_dict_with_parameters(self): + component = CohereGenerator( + api_key="test-api-key", + model="command-light", + max_tokens=10, + some_test_param="test-params", + streaming_callback=default_streaming_callback, + api_base_url="test-base-url", + ) + data = component.to_dict() + assert data == { + "type": "haystack.preview.components.generators.cohere.CohereGenerator", + "init_parameters": { + "model": "command-light", + "max_tokens": 10, + "some_test_param": "test-params", + "api_base_url": "test-base-url", + "streaming_callback": "test_cohere_generators.default_streaming_callback", + }, + } + + def test_to_dict_with_lambda_streaming_callback(self): + component = CohereGenerator( + api_key="test-api-key", + model="command", + max_tokens=10, + some_test_param="test-params", + streaming_callback=lambda x: x, + api_base_url="test-base-url", + ) + data = component.to_dict() + assert data == { + "type": "haystack.preview.components.generators.cohere.CohereGenerator", + "init_parameters": { + "model": "command", + "streaming_callback": "test_cohere_generators.", + "api_base_url": "test-base-url", + "max_tokens": 10, + "some_test_param": "test-params", + }, + } + + def test_from_dict(self, monkeypatch): + monkeypatch.setenv("COHERE_API_KEY", "test-key") + data = { + "type": "haystack.preview.components.generators.cohere.CohereGenerator", + "init_parameters": { + "model": "command", + "max_tokens": 10, + "some_test_param": "test-params", + "api_base_url": "test-base-url", + "streaming_callback": "test_cohere_generators.default_streaming_callback", + }, + } + component = CohereGenerator.from_dict(data) + assert component.api_key == "test-key" + assert component.model == "command" + assert component.streaming_callback == default_streaming_callback + assert component.api_base_url == "test-base-url" + assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"} + + def test_check_truncated_answers(self, caplog): + component = CohereGenerator(api_key="test-api-key") + metadata = [{"finish_reason": "MAX_TOKENS"}] + component._check_truncated_answers(metadata) + assert caplog.records[0].message == ( + "Responses have been truncated before reaching a natural stopping point. " + "Increase the max_tokens parameter to allow for longer completions." + ) + + @pytest.mark.skipif( + not os.environ.get("COHERE_API_KEY", None), + reason="Export an env var called CO_API_KEY containing the Cohere API key to run this test.", + ) + @pytest.mark.integration + def test_cohere_generator_run(self): + component = CohereGenerator(api_key=os.environ.get("COHERE_API_KEY")) + results = component.run(prompt="What's the capital of France?") + assert len(results["replies"]) == 1 + assert "Paris" in results["replies"][0] + assert len(results["metadata"]) == 1 + assert results["metadata"][0]["finish_reason"] == "COMPLETE" + + @pytest.mark.skipif( + not os.environ.get("COHERE_API_KEY", None), + reason="Export an env var called COHERE_API_KEY containing the Cohere API key to run this test.", + ) + @pytest.mark.integration + def test_cohere_generator_run_wrong_model_name(self): + component = CohereGenerator(model="something-obviously-wrong", api_key=os.environ.get("COHERE_API_KEY")) + with pytest.raises( + cohere.CohereAPIError, + match="model not found, make sure the correct model ID was used and that you have access to the model.", + ): + component.run(prompt="What's the capital of France?") + + @pytest.mark.skipif( + not os.environ.get("COHERE_API_KEY", None), + reason="Export an env var called COHERE_API_KEY containing the Cohere API key to run this test.", + ) + @pytest.mark.integration + def test_cohere_generator_run_streaming(self): + class Callback: + def __init__(self): + self.responses = "" + + def __call__(self, chunk): + self.responses += chunk.text + return chunk + + callback = Callback() + component = CohereGenerator(os.environ.get("COHERE_API_KEY"), streaming_callback=callback) + results = component.run(prompt="What's the capital of France?") + + assert len(results["replies"]) == 1 + assert "Paris" in results["replies"][0] + assert len(results["metadata"]) == 1 + assert results["metadata"][0]["finish_reason"] == "COMPLETE" + assert callback.responses == results["replies"][0]