Skip to content

Commit

Permalink
Fix wrong import in cohere.py and change model to model_name fo…
Browse files Browse the repository at this point in the history
…r consistency (#6405)

* Fix wrong import in `cohere.py`

* model -> model_name

* fix tests too

* black

* typo

* typo
  • Loading branch information
ZanSara authored Nov 23, 2023
1 parent fdae81e commit f3b7303
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 16 deletions.
10 changes: 5 additions & 5 deletions haystack/preview/components/generators/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
from typing import Any, Callable, Dict, List, Optional

from haystack.lazy_imports import LazyImport
from haystack.preview.lazy_imports import LazyImport
from haystack.preview import DeserializationError, component, default_from_dict, default_to_dict

with LazyImport(message="Run 'pip install cohere'") as cohere_import:
Expand Down Expand Up @@ -31,7 +31,7 @@ class CohereGenerator:
def __init__(
self,
api_key: Optional[str] = None,
model: str = "command",
model_name: str = "command",
streaming_callback: Optional[Callable] = None,
api_base_url: str = COHERE_API_URL,
**kwargs,
Expand Down Expand Up @@ -69,7 +69,7 @@ def __init__(
)

self.api_key = api_key
self.model = model
self.model_name = model_name
self.streaming_callback = streaming_callback
self.api_base_url = api_base_url
self.model_parameters = kwargs
Expand All @@ -90,7 +90,7 @@ def to_dict(self) -> Dict[str, Any]:

return default_to_dict(
self,
model=self.model,
model_name=self.model_name,
streaming_callback=callback_name,
api_base_url=self.api_base_url,
**self.model_parameters,
Expand Down Expand Up @@ -123,7 +123,7 @@ def run(self, prompt: str):
:param prompt: The prompt to be sent to the generative model.
"""
response = self.client.generate(
model=self.model, prompt=prompt, stream=self.streaming_callback is not None, **self.model_parameters
model=self.model_name, prompt=prompt, stream=self.streaming_callback is not None, **self.model_parameters
)
if self.streaming_callback:
metadata_dict: Dict[str, Any] = {}
Expand Down
26 changes: 15 additions & 11 deletions test/preview/components/generators/test_cohere_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class TestGPTGenerator:
def test_init_default(self):
component = CohereGenerator(api_key="test-api-key")
assert component.api_key == "test-api-key"
assert component.model == "command"
assert component.model_name == "command"
assert component.streaming_callback is None
assert component.api_base_url == cohere.COHERE_API_URL
assert component.model_parameters == {}
Expand All @@ -27,14 +27,14 @@ def test_init_with_parameters(self):
callback = lambda x: x
component = CohereGenerator(
api_key="test-api-key",
model="command-light",
model_name="command-light",
max_tokens=10,
some_test_param="test-params",
streaming_callback=callback,
api_base_url="test-base-url",
)
assert component.api_key == "test-api-key"
assert component.model == "command-light"
assert component.model_name == "command-light"
assert component.streaming_callback == callback
assert component.api_base_url == "test-base-url"
assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"}
Expand All @@ -44,13 +44,17 @@ def test_to_dict_default(self):
data = component.to_dict()
assert data == {
"type": "haystack.preview.components.generators.cohere.CohereGenerator",
"init_parameters": {"model": "command", "streaming_callback": None, "api_base_url": cohere.COHERE_API_URL},
"init_parameters": {
"model_name": "command",
"streaming_callback": None,
"api_base_url": cohere.COHERE_API_URL,
},
}

def test_to_dict_with_parameters(self):
component = CohereGenerator(
api_key="test-api-key",
model="command-light",
model_name="command-light",
max_tokens=10,
some_test_param="test-params",
streaming_callback=default_streaming_callback,
Expand All @@ -60,7 +64,7 @@ def test_to_dict_with_parameters(self):
assert data == {
"type": "haystack.preview.components.generators.cohere.CohereGenerator",
"init_parameters": {
"model": "command-light",
"model_name": "command-light",
"max_tokens": 10,
"some_test_param": "test-params",
"api_base_url": "test-base-url",
Expand All @@ -71,7 +75,7 @@ def test_to_dict_with_parameters(self):
def test_to_dict_with_lambda_streaming_callback(self):
component = CohereGenerator(
api_key="test-api-key",
model="command",
model_name="command",
max_tokens=10,
some_test_param="test-params",
streaming_callback=lambda x: x,
Expand All @@ -81,7 +85,7 @@ def test_to_dict_with_lambda_streaming_callback(self):
assert data == {
"type": "haystack.preview.components.generators.cohere.CohereGenerator",
"init_parameters": {
"model": "command",
"model_name": "command",
"streaming_callback": "test_cohere_generators.<lambda>",
"api_base_url": "test-base-url",
"max_tokens": 10,
Expand All @@ -94,7 +98,7 @@ def test_from_dict(self, monkeypatch):
data = {
"type": "haystack.preview.components.generators.cohere.CohereGenerator",
"init_parameters": {
"model": "command",
"model_name": "command",
"max_tokens": 10,
"some_test_param": "test-params",
"api_base_url": "test-base-url",
Expand All @@ -103,7 +107,7 @@ def test_from_dict(self, monkeypatch):
}
component = CohereGenerator.from_dict(data)
assert component.api_key == "test-key"
assert component.model == "command"
assert component.model_name == "command"
assert component.streaming_callback == default_streaming_callback
assert component.api_base_url == "test-base-url"
assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"}
Expand Down Expand Up @@ -136,7 +140,7 @@ def test_cohere_generator_run(self):
)
@pytest.mark.integration
def test_cohere_generator_run_wrong_model_name(self):
component = CohereGenerator(model="something-obviously-wrong", api_key=os.environ.get("COHERE_API_KEY"))
component = CohereGenerator(model_name="something-obviously-wrong", api_key=os.environ.get("COHERE_API_KEY"))
with pytest.raises(
cohere.CohereAPIError,
match="model not found, make sure the correct model ID was used and that you have access to the model.",
Expand Down

0 comments on commit f3b7303

Please sign in to comment.