Skip to content

Commit

Permalink
feat: CohereGenerator (#6395)
Browse files Browse the repository at this point in the history
* added CohereGenerator with unit tests

Signed-off-by: sunilkumardash9 <sunilkumardash9@gmail.com>

* 1. added releasenote
2. removed commented files in test-cohere_generators
3. removed unused imports

Signed-off-by: sunilkumardash9 <sunilkumardash9@gmail.com>

* 1. move client creation to __init__
2. remove dict casting of metadata in run

Signed-off-by: sunilkumardash9 <sunilkumardash9@gmail.com>

* few fixes

Signed-off-by: sunilkumardash9 <sunilkumardash9@gmail.com>

* add cohere to git workflows

Signed-off-by: sunilkumardash9 <sunilkumardash9@gmail.com>

* 1. CohereGenerator as top level import in generators
2. small change in doc string

Signed-off-by: sunilkumardash9 <sunilkumardash9@gmail.com>

* 1. corrected git workflow files for cohere import
2. changed api key env var from CO_API_KEY to COHERE_API_KEY

Signed-off-by: sunilkumardash9 <sunilkumardash9@gmail.com>

* added cohere in missed out workflow installs

Signed-off-by: sunilkumardash9 <sunilkumardash9@gmail.com>

* 1. Removed default_streaming_callback from cohere.py and added in test.
2. Added kwargs doc strings for CohereGenerator
3. removed type hints for metadata and replies
4. use COHERE_API_URL instead of hard coded URL.

Signed-off-by: sunilkumardash9 <sunilkumardash9@gmail.com>

* Update haystack/preview/components/generators/cohere/cohere.py

Co-authored-by: Daria Fokina <daria.f93@gmail.com>

* Update haystack/preview/components/generators/cohere/cohere.py

Co-authored-by: Daria Fokina <daria.f93@gmail.com>

* Update haystack/preview/components/generators/cohere/cohere.py

Co-authored-by: Daria Fokina <daria.f93@gmail.com>

* Update haystack/preview/components/generators/cohere/cohere.py

Co-authored-by: Daria Fokina <daria.f93@gmail.com>

* Update haystack/preview/components/generators/cohere/cohere.py

Co-authored-by: Daria Fokina <daria.f93@gmail.com>

* move out of folder

* black

* fix tests

* feedback

* black

* remove api key from tests

* read api key from env var if missing

* typo

* black

* missing import

---------

Signed-off-by: sunilkumardash9 <sunilkumardash9@gmail.com>
Co-authored-by: sunilkumardash9 <sunilkumardash9@gmail.com>
Co-authored-by: Daria Fokina <daria.f93@gmail.com>
  • Loading branch information
3 people authored Nov 23, 2023
1 parent 67780a6 commit 4ec6a60
Show file tree
Hide file tree
Showing 6 changed files with 334 additions and 7 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/linting_preview.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
python-version: ${{ env.PYTHON_VERSION }}

- name: Install Haystack
run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf tika 'azure-ai-formrecognizer>=3.2.0b2'
run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf tika 'azure-ai-formrecognizer>=3.2.0b2' cohere

- name: Mypy
if: steps.files.outputs.any_changed == 'true'
Expand Down Expand Up @@ -69,7 +69,7 @@ jobs:

- name: Install Haystack
run: |
pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2'
pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' cohere
pip install ./haystack-linter
- name: Pylint
Expand Down
8 changes: 4 additions & 4 deletions .github/workflows/tests_preview.yml
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ jobs:
python-version: ${{ env.PYTHON_VERSION }}

- name: Install Haystack
run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2'
run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' cohere

- name: Run
run: pytest -m "not integration" test/preview
Expand Down Expand Up @@ -174,7 +174,7 @@ jobs:
sudo apt install ffmpeg # for local Whisper tests
- name: Install Haystack
run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2'
run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' cohere

- name: Run
run: pytest --maxfail=5 -m "integration" test/preview
Expand Down Expand Up @@ -230,7 +230,7 @@ jobs:
colima start
- name: Install Haystack
run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2'
run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' cohere

- name: Run Tika
run: docker run -d -p 9998:9998 apache/tika:2.9.0.0
Expand Down Expand Up @@ -281,7 +281,7 @@ jobs:
python-version: ${{ env.PYTHON_VERSION }}

- name: Install Haystack
run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2'
run: pip install .[dev,preview,audio] langdetect transformers[torch,sentencepiece]==4.35.2 'sentence-transformers>=2.2.0' pypdf markdown-it-py mdit_plain tika 'azure-ai-formrecognizer>=3.2.0b2' cohere

- name: Run
run: pytest --maxfail=5 -m "integration" test/preview -k 'not tika'
Expand Down
3 changes: 2 additions & 1 deletion haystack/preview/components/generators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from haystack.preview.components.generators.cohere import CohereGenerator
from haystack.preview.components.generators.hugging_face_local import HuggingFaceLocalGenerator
from haystack.preview.components.generators.hugging_face_tgi import HuggingFaceTGIGenerator
from haystack.preview.components.generators.openai import GPTGenerator

__all__ = ["HuggingFaceLocalGenerator", "HuggingFaceTGIGenerator", "GPTGenerator"]
__all__ = ["HuggingFaceLocalGenerator", "HuggingFaceTGIGenerator", "GPTGenerator", "CohereGenerator"]
154 changes: 154 additions & 0 deletions haystack/preview/components/generators/cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import logging
import os
import sys
from typing import Any, Callable, Dict, List, Optional

from haystack.lazy_imports import LazyImport
from haystack.preview import DeserializationError, component, default_from_dict, default_to_dict

with LazyImport(message="Run 'pip install cohere'") as cohere_import:
from cohere import Client, COHERE_API_URL

logger = logging.getLogger(__name__)


@component
class CohereGenerator:
"""LLM Generator compatible with Cohere's generate endpoint.
Queries the LLM using Cohere's API. Invocations are made using 'cohere' package.
See [Cohere API](https://docs.cohere.com/reference/generate) for more details.
Example usage:
```python
from haystack.preview.generators import CohereGenerator
generator = CohereGenerator(api_key="test-api-key")
generator.run(prompt="What's the capital of France?")
```
"""

def __init__(
self,
api_key: Optional[str] = None,
model: str = "command",
streaming_callback: Optional[Callable] = None,
api_base_url: str = COHERE_API_URL,
**kwargs,
):
"""
Instantiates a `CohereGenerator` component.
:param api_key: The API key for the Cohere API. If not set, it will be read from the COHERE_API_KEY env var.
:param model_name: The name of the model to use. Available models are: [command, command-light, command-nightly, command-nightly-light]. Defaults to "command".
:param streaming_callback: A callback function to be called with the streaming response. Defaults to None.
:param api_base_url: The base URL of the Cohere API. Defaults to "https://api.cohere.ai".
:param kwargs: Additional model parameters. These will be used during generation. Refer to https://docs.cohere.com/reference/generate for more details.
Some of the parameters are:
- 'max_tokens': The maximum number of tokens to be generated. Defaults to 1024.
- 'truncate': One of NONE|START|END to specify how the API will handle inputs longer than the maximum token length. Defaults to END.
- 'temperature': A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations.
- 'preset': Identifier of a custom preset. A preset is a combination of parameters, such as prompt, temperature etc. You can create presets in the playground.
- 'end_sequences': The generated text will be cut at the beginning of the earliest occurrence of an end sequence. The sequence will be excluded from the text.
- 'stop_sequences': The generated text will be cut at the end of the earliest occurrence of a stop sequence. The sequence will be included the text.
- 'k': Defaults to 0, min value of 0.01, max value of 0.99.
- 'p': Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
- 'frequency_penalty': Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens,
proportional to how many times they have already appeared in the prompt or prior generation.'
- 'presence_penalty': Defaults to 0.0, min value of 0.0, max value of 1.0. Can be used to reduce repetitiveness of generated tokens.
Similar to `frequency_penalty`, except that this penalty is applied equally to all tokens that have already appeared, regardless of their exact frequencies.
- 'return_likelihoods': One of GENERATION|ALL|NONE to specify how and if the token likelihoods are returned with the response. Defaults to NONE.
- 'logit_bias': Used to prevent the model from generating unwanted tokens or to incentivize it to include desired tokens.
The format is {token_id: bias} where bias is a float between -10 and 10.
"""
if not api_key:
api_key = os.environ.get("COHERE_API_KEY")
if not api_key:
raise ValueError(
"CohereGenerator needs an API key to run. Either provide it as init parameter or set the env var COHERE_API_KEY."
)

self.api_key = api_key
self.model = model
self.streaming_callback = streaming_callback
self.api_base_url = api_base_url
self.model_parameters = kwargs
self.client = Client(api_key=self.api_key, api_url=self.api_base_url)

def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
if self.streaming_callback:
module = self.streaming_callback.__module__
if module == "builtins":
callback_name = self.streaming_callback.__name__
else:
callback_name = f"{module}.{self.streaming_callback.__name__}"
else:
callback_name = None

return default_to_dict(
self,
model=self.model,
streaming_callback=callback_name,
api_base_url=self.api_base_url,
**self.model_parameters,
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "CohereGenerator":
"""
Deserialize this component from a dictionary.
"""
init_params = data.get("init_parameters", {})
streaming_callback = None
if "streaming_callback" in init_params and init_params["streaming_callback"]:
parts = init_params["streaming_callback"].split(".")
module_name = ".".join(parts[:-1])
function_name = parts[-1]
module = sys.modules.get(module_name, None)
if not module:
raise DeserializationError(f"Could not locate the module of the streaming callback: {module_name}")
streaming_callback = getattr(module, function_name, None)
if not streaming_callback:
raise DeserializationError(f"Could not locate the streaming callback: {function_name}")
data["init_parameters"]["streaming_callback"] = streaming_callback
return default_from_dict(cls, data)

@component.output_types(replies=List[str], metadata=List[Dict[str, Any]])
def run(self, prompt: str):
"""
Queries the LLM with the prompts to produce replies.
:param prompt: The prompt to be sent to the generative model.
"""
response = self.client.generate(
model=self.model, prompt=prompt, stream=self.streaming_callback is not None, **self.model_parameters
)
if self.streaming_callback:
metadata_dict: Dict[str, Any] = {}
for chunk in response:
self.streaming_callback(chunk)
metadata_dict["index"] = chunk.index
replies = response.texts
metadata_dict["finish_reason"] = response.finish_reason
metadata = [metadata_dict]
self._check_truncated_answers(metadata)
return {"replies": replies, "metadata": metadata}

metadata = [{"finish_reason": resp.finish_reason} for resp in response]
replies = [resp.text for resp in response]
self._check_truncated_answers(metadata)
return {"replies": replies, "metadata": metadata}

def _check_truncated_answers(self, metadata: List[Dict[str, Any]]):
"""
Check the `finish_reason` returned with the Cohere response.
If the `finish_reason` is `MAX_TOKEN`, log a warning to the user.
:param metadata: The metadata returned by the Cohere API.
"""
if metadata[0]["finish_reason"] == "MAX_TOKENS":
logger.warning(
"Responses have been truncated before reaching a natural stopping point. "
"Increase the max_tokens parameter to allow for longer completions."
)
4 changes: 4 additions & 0 deletions releasenotes/notes/add-CohereGenerator-ca55e5b8e46df754.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
preview:
- |
Add CohereGenerator compatible with Cohere generate endpoint
168 changes: 168 additions & 0 deletions test/preview/components/generators/test_cohere_generators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
import os

import pytest
import cohere

from haystack.preview.components.generators import CohereGenerator


def default_streaming_callback(chunk):
"""
Default callback function for streaming responses from Cohere API.
Prints the tokens of the first completion to stdout as soon as they are received and returns the chunk unchanged.
"""
print(chunk.text, flush=True, end="")


class TestGPTGenerator:
def test_init_default(self):
component = CohereGenerator(api_key="test-api-key")
assert component.api_key == "test-api-key"
assert component.model == "command"
assert component.streaming_callback is None
assert component.api_base_url == cohere.COHERE_API_URL
assert component.model_parameters == {}

def test_init_with_parameters(self):
callback = lambda x: x
component = CohereGenerator(
api_key="test-api-key",
model="command-light",
max_tokens=10,
some_test_param="test-params",
streaming_callback=callback,
api_base_url="test-base-url",
)
assert component.api_key == "test-api-key"
assert component.model == "command-light"
assert component.streaming_callback == callback
assert component.api_base_url == "test-base-url"
assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"}

def test_to_dict_default(self):
component = CohereGenerator(api_key="test-api-key")
data = component.to_dict()
assert data == {
"type": "haystack.preview.components.generators.cohere.CohereGenerator",
"init_parameters": {"model": "command", "streaming_callback": None, "api_base_url": cohere.COHERE_API_URL},
}

def test_to_dict_with_parameters(self):
component = CohereGenerator(
api_key="test-api-key",
model="command-light",
max_tokens=10,
some_test_param="test-params",
streaming_callback=default_streaming_callback,
api_base_url="test-base-url",
)
data = component.to_dict()
assert data == {
"type": "haystack.preview.components.generators.cohere.CohereGenerator",
"init_parameters": {
"model": "command-light",
"max_tokens": 10,
"some_test_param": "test-params",
"api_base_url": "test-base-url",
"streaming_callback": "test_cohere_generators.default_streaming_callback",
},
}

def test_to_dict_with_lambda_streaming_callback(self):
component = CohereGenerator(
api_key="test-api-key",
model="command",
max_tokens=10,
some_test_param="test-params",
streaming_callback=lambda x: x,
api_base_url="test-base-url",
)
data = component.to_dict()
assert data == {
"type": "haystack.preview.components.generators.cohere.CohereGenerator",
"init_parameters": {
"model": "command",
"streaming_callback": "test_cohere_generators.<lambda>",
"api_base_url": "test-base-url",
"max_tokens": 10,
"some_test_param": "test-params",
},
}

def test_from_dict(self, monkeypatch):
monkeypatch.setenv("COHERE_API_KEY", "test-key")
data = {
"type": "haystack.preview.components.generators.cohere.CohereGenerator",
"init_parameters": {
"model": "command",
"max_tokens": 10,
"some_test_param": "test-params",
"api_base_url": "test-base-url",
"streaming_callback": "test_cohere_generators.default_streaming_callback",
},
}
component = CohereGenerator.from_dict(data)
assert component.api_key == "test-key"
assert component.model == "command"
assert component.streaming_callback == default_streaming_callback
assert component.api_base_url == "test-base-url"
assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"}

def test_check_truncated_answers(self, caplog):
component = CohereGenerator(api_key="test-api-key")
metadata = [{"finish_reason": "MAX_TOKENS"}]
component._check_truncated_answers(metadata)
assert caplog.records[0].message == (
"Responses have been truncated before reaching a natural stopping point. "
"Increase the max_tokens parameter to allow for longer completions."
)

@pytest.mark.skipif(
not os.environ.get("COHERE_API_KEY", None),
reason="Export an env var called CO_API_KEY containing the Cohere API key to run this test.",
)
@pytest.mark.integration
def test_cohere_generator_run(self):
component = CohereGenerator(api_key=os.environ.get("COHERE_API_KEY"))
results = component.run(prompt="What's the capital of France?")
assert len(results["replies"]) == 1
assert "Paris" in results["replies"][0]
assert len(results["metadata"]) == 1
assert results["metadata"][0]["finish_reason"] == "COMPLETE"

@pytest.mark.skipif(
not os.environ.get("COHERE_API_KEY", None),
reason="Export an env var called COHERE_API_KEY containing the Cohere API key to run this test.",
)
@pytest.mark.integration
def test_cohere_generator_run_wrong_model_name(self):
component = CohereGenerator(model="something-obviously-wrong", api_key=os.environ.get("COHERE_API_KEY"))
with pytest.raises(
cohere.CohereAPIError,
match="model not found, make sure the correct model ID was used and that you have access to the model.",
):
component.run(prompt="What's the capital of France?")

@pytest.mark.skipif(
not os.environ.get("COHERE_API_KEY", None),
reason="Export an env var called COHERE_API_KEY containing the Cohere API key to run this test.",
)
@pytest.mark.integration
def test_cohere_generator_run_streaming(self):
class Callback:
def __init__(self):
self.responses = ""

def __call__(self, chunk):
self.responses += chunk.text
return chunk

callback = Callback()
component = CohereGenerator(os.environ.get("COHERE_API_KEY"), streaming_callback=callback)
results = component.run(prompt="What's the capital of France?")

assert len(results["replies"]) == 1
assert "Paris" in results["replies"][0]
assert len(results["metadata"]) == 1
assert results["metadata"][0]["finish_reason"] == "COMPLETE"
assert callback.responses == results["replies"][0]

0 comments on commit 4ec6a60

Please sign in to comment.