diff --git a/pyproject.toml b/pyproject.toml index 15099db..4a87fc8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,12 +1,13 @@ [project] name = "tackleberry" -version = "0.1.0.dev1" +version = "0.1.0" description = "Tackleberry (or TB) is helping you tackle the access to AI" authors = [ { name = "Torsten Raudßus", email = "torsten@raudssus.de" }, ] dependencies = [ "pydantic>=2.0.0", + "instructor>=1.7.0", "typing-extensions>=4.0.0", "pyyaml>=6.0.0", ] @@ -27,7 +28,9 @@ dev = [ "ollama", "transformers", "groq", + "instructor[groq]", "anthropic", + "instructor[anthropic]", ] [build-system] diff --git a/tackleberry/__init__.py b/tackleberry/__init__.py index 121dfda..a5e7382 100644 --- a/tackleberry/__init__.py +++ b/tackleberry/__init__.py @@ -6,4 +6,4 @@ try: __version__ = version("tackleberry") except ImportError: - __version__ = "0.0.1.dev0" + __version__ = "0.0.1.dev1" diff --git a/tackleberry/chat.py b/tackleberry/chat.py new file mode 100644 index 0000000..eae97a6 --- /dev/null +++ b/tackleberry/chat.py @@ -0,0 +1,42 @@ +from typing import Union + +from .context import TBContext +from .model import TBModel + +from pydantic import BaseModel + +class TBChat: + count = 0 + + def __init__(self, + model_name_or_model: Union[str, TBModel], + context: TBContext = None, + system_prompt: str = None, + struct: BaseModel = None, + name: str = None, + **kwargs, + ): + TBChat.count += 1 + self.name = name or f'TBChat-{TBChat.count}' + if isinstance(model_name_or_model, TBModel): + self.model = model_name_or_model + else: + from . import TB + self.model = TB.model(model_name_or_model) + self.struct = struct + self.context = context if context is not None else TBContext() + if system_prompt is not None: + self.context.add_system(system_prompt) + self.model_name = self.model.name + self.runtime = self.model.runtime + + def get_messages(self): + return self.runtime.get_messages_from_context(self.context) + + def query(self, + query_or_context: Union[str, TBContext], + struct: BaseModel = None, + **kwargs, + ): + context = query_or_context if isinstance(query_or_context, TBContext) else self.context.copy_with_query(query_or_context) + return self.runtime.chat_context(self, context, struct=struct, **kwargs) diff --git a/tackleberry/context.py b/tackleberry/context.py index 5b8a077..3ab297d 100644 --- a/tackleberry/context.py +++ b/tackleberry/context.py @@ -3,7 +3,7 @@ import copy class TBMessage: - + """Base class for message""" def __init__(self, content: str, role: str): self.role = role self.content = content @@ -29,7 +29,7 @@ def __init__(self, user_message: str, role: Optional[str] = None): super().__init__(user_message, role if role is not None else "user") class TBContext: - + """Combined messages as context""" def __init__(self, system_prompt: Optional[str] = None, user_query: Optional[str] = None): self.messages = [] self.query = user_query @@ -44,8 +44,11 @@ def add_system(self, system_prompt: str): self.messages.append(TBMessageSystem(system_prompt)) return self - def add_assistant(self, assistant_context: str): - self.messages.append(TBMessageAssistant(assistant_context)) + def has_system(self): + return any(isinstance(message, TBMessageSystem) for message in self.messages) + + def add_assistant(self, assistant_message: str): + self.messages.append(TBMessageAssistant(assistant_message)) return self def add_user(self, user_message: str): @@ -56,11 +59,11 @@ def add_query(self, user_query: str): self.query = user_query return self - def spawn(self): + def copy(self): return copy.deepcopy(self) - def spawn_with(self, message: Union[TBMessage, 'TBContext']): - clone = self.spawn() + def copy_with(self, message: Union[TBMessage, 'TBContext']): + clone = self.copy() if isinstance(message, TBMessage): clone.add(message) elif isinstance(message, TBContext): @@ -68,19 +71,19 @@ def spawn_with(self, message: Union[TBMessage, 'TBContext']): clone.add(cmessage) return clone - def spawn_with_system(self, system_prompt: str): - return self.spawn_with(TBMessageSystem(system_prompt)) + def copy_with_system(self, system_prompt: str): + return self.copy_with(TBMessageSystem(system_prompt)) - def spawn_with_assistant(self, assistant_context: str): - return self.spawn_with(TBMessageAssistant(assistant_context)) + def copy_with_assistant(self, assistant_context: str): + return self.copy_with(TBMessageAssistant(assistant_context)) - def spawn_with_user(self, user_message: str): - return self.spawn_with(TBMessageUser(user_message)) + def copy_with_user(self, user_message: str): + return self.copy_with(TBMessageUser(user_message)) - def spawn_with_query(self, user_query: str): - return self.spawn().add_query(user_query) + def copy_with_query(self, user_query: str): + return self.copy().add_query(user_query) - def all_messages(self): + def get_messages(self): messages = self.messages if not self.query is None: messages.append(TBMessageUser(self.query)) @@ -88,7 +91,7 @@ def all_messages(self): def to_messages(self): message_list = [] - for message in self.all_messages(): + for message in self.get_messages(): message_list.append({ "content": message.content, "role": message.role, diff --git a/tackleberry/engine/anthropic.py b/tackleberry/engine/anthropic.py deleted file mode 100644 index 173c5f4..0000000 --- a/tackleberry/engine/anthropic.py +++ /dev/null @@ -1,31 +0,0 @@ -import os - -from .base import TBEngine - -class TBEngineAnthropic(TBEngine): - default_max_tokens = 256 - - def __init__(self, - api_key: str = None, - max_tokens: int = None, - **kwargs, - ): - self.api_key = api_key or os.environ.get("ANTHROPIC_API_KEY") - if not isinstance(self.api_key, str) or len(self.api_key) < 51: - raise Exception("Anthropic needs api_key (ANTHROPIC_API_KEY)") - from anthropic import Anthropic - self.client = Anthropic( - api_key=self.api_key, - **kwargs, - ) - self.max_tokens = max_tokens or TBEngineAnthropic.default_max_tokens - - def get_models(self): - models = [] - for model in self.client.models.list().data: - models.append(model.id) - models.sort() - return models - - def __str__(self): - return f"TB Engine Anthropic {hex(id(self))}" diff --git a/tackleberry/engine/groq.py b/tackleberry/engine/groq.py deleted file mode 100644 index d8359d3..0000000 --- a/tackleberry/engine/groq.py +++ /dev/null @@ -1,28 +0,0 @@ -import os - -from .base import TBEngine - -class TBEngineGroq(TBEngine): - - def __init__(self, - api_key: str = None, - **kwargs, - ): - self.api_key = api_key or os.environ.get("GROQ_API_KEY") - if not isinstance(self.api_key, str) or len(self.api_key) < 51: - raise Exception("Groq needs api_key (GROQ_API_KEY)") - from groq import Groq - self.client = Groq( - api_key=self.api_key, - **kwargs, - ) - - def get_models(self): - models = [] - for model in self.client.models.list().data: - models.append(model.id) - models.sort() - return models - - def __str__(self): - return f"TB Engine Groq {hex(id(self))}" diff --git a/tackleberry/engine/openai.py b/tackleberry/engine/openai.py deleted file mode 100644 index 740a50c..0000000 --- a/tackleberry/engine/openai.py +++ /dev/null @@ -1,28 +0,0 @@ -import os - -from .base import TBEngine - -class TBEngineOpenai(TBEngine): - - def __init__(self, - api_key: str = None, - **kwargs, - ): - self.api_key = api_key or os.environ.get("OPENAI_API_KEY") - if not isinstance(self.api_key, str) or len(self.api_key) < 51: - raise Exception("OpenAI needs api_key (OPENAI_API_KEY)") - from openai import OpenAI - self.client = OpenAI( - api_key=self.api_key, - **kwargs, - ) - - def get_models(self): - models = [] - for model in self.client.models.list().data: - models.append(model.id) - models.sort() - return models - - def __str__(self): - return f"TB Engine OpenAI {hex(id(self))}" diff --git a/tackleberry/main.py b/tackleberry/main.py index 7eee525..b313be7 100644 --- a/tackleberry/main.py +++ b/tackleberry/main.py @@ -2,7 +2,7 @@ from importlib import import_module from .registry import TBRegistry -from .engine import TBEngine +from .runtime import TBRuntime from .context import TBContext class TBMain: @@ -16,7 +16,7 @@ def __init__(self, TBMain.count += 1 self.name = name or f'TB-{TBMain.count}' self.registry = registry if registry else TBMain.registry - self.engines = {} + self.runtimes = {} def __str__(self): return f"TBMain instance {self.name}" @@ -32,16 +32,16 @@ def model(self, ): model_parts = model.split('/') if len(model_parts) > 1: - engine_class = model_parts.pop(0) + runtime_class = model_parts.pop(0) model = '/'.join(model_parts) else: - engine_class = self.registry.get_engine_by_model(model) - if engine_class is None: - raise Exception(f"Can't find engine for model '{model}'") - engine = self.engine(engine_class, **kwargs) - if engine is None: - raise Exception(f"Can't find engine for engine class '{engine_class}'") - return engine.model(model) + runtime_class = self.registry.get_runtime_by_model(model) + if runtime_class is None: + raise Exception(f"Can't find runtime for model '{model}'") + runtime = self.runtime(runtime_class) + if runtime is None: + raise Exception(f"Can't find runtime for runtime class '{runtime_class}'") + return runtime.model(model, **kwargs) def chat(self, model: str, @@ -49,24 +49,24 @@ def chat(self, ): return self.model(model).chat(**kwargs) - def engine(self, - engine_class: str, + def runtime(self, + runtime_class: str, **kwargs, ): - if engine_class in self.engines: - return self.engines[engine_class] + if runtime_class in self.runtimes: + return self.runtimes[runtime_class] try: from importlib import import_module - from_list = [f"TBEngine{engine_class.title()}"] - mod = import_module(f".engine.{engine_class}", package=__package__) - self.engines[engine_class] = getattr(mod, from_list[0])(**kwargs) + from_list = [f"TBRuntime{runtime_class.title()}"] + mod = import_module(f".runtime.{runtime_class}", package=__package__) + self.runtimes[runtime_class] = getattr(mod, from_list[0])(**kwargs) except ImportError: - mod = import_module(f"tackleberry.engine.{engine_class}") - self.engines[engine_class] = getattr(mod, f"TBEngine{engine_class.title()}")(**kwargs) - if isinstance(self.engines[engine_class], TBEngine): - return self.engines[engine_class] + mod = import_module(f"tackleberry.runtime.{runtime_class}") + self.runtimes[runtime_class] = getattr(mod, f"TBRuntime{runtime_class.title()}")(**kwargs) + if isinstance(self.runtimes[runtime_class], TBRuntime): + return self.runtimes[runtime_class] else: - raise Exception(f"Can't find engine '{engine_class}'") + raise Exception(f"Can't find runtime '{runtime_class}'") TB = TBMain() diff --git a/tackleberry/model.py b/tackleberry/model.py new file mode 100644 index 0000000..ed70ad9 --- /dev/null +++ b/tackleberry/model.py @@ -0,0 +1,12 @@ +from .runtime import TBRuntime + +class TBModel: + + def __init__(self, runtime: TBRuntime, name: str, **kwargs): + self.runtime = runtime + self.name = name + self.options = kwargs + + def chat(self, **kwargs): + from .chat import TBChat + return TBChat(self, **kwargs) diff --git a/tackleberry/model/__init__.py b/tackleberry/model/__init__.py deleted file mode 100644 index 3586a3c..0000000 --- a/tackleberry/model/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -import os -import importlib - -# Special case: Explicitly import TBModel from base.py -from .base import TBModel - -# Automatically detect all Python files in the current directory -current_dir = os.path.dirname(__file__) -module_files = [ - f for f in os.listdir(current_dir) - if f.endswith(".py") and f not in ("__init__.py", "base.py") -] - -# Create __all__ with the module names and the special case -__all__ = ["TBModel"] + [os.path.splitext(f)[0] for f in module_files] - -# Dynamically import the modules and add them to the global namespace -for module_name in __all__[1:]: # Skip "TBModel" as it's already imported - module = importlib.import_module(f".{module_name}", package=__name__) - globals()[module_name] = module diff --git a/tackleberry/model/base.py b/tackleberry/model/base.py deleted file mode 100644 index 719d624..0000000 --- a/tackleberry/model/base.py +++ /dev/null @@ -1,11 +0,0 @@ -from ..engine import TBEngine - -class TBModel: - - def __init__(self, engine: TBEngine, name: str): - self.engine = engine - self.name = name - - def chat(self, **kwargs): - from .chat import TBModelChat - return TBModelChat(self.engine, self, **kwargs) diff --git a/tackleberry/model/chat.py b/tackleberry/model/chat.py deleted file mode 100644 index 1c3c748..0000000 --- a/tackleberry/model/chat.py +++ /dev/null @@ -1,25 +0,0 @@ -from typing import Union - -from .base import TBModel -from ..engine import TBEngine -from ..context import TBContext, TBSystemPromptError - -class TBModelChat(TBModel): - - def __init__(self, - engine: TBEngine, - model_name_or_model: Union[str, TBModel], - context: TBContext = None, - system_prompt: str = None, - **kwargs, - ): - self.engine = engine - if context is not None and system_prompt is not None: - raise TBSystemPromptError("A TBModelChat can't handle system_prompt and context at once.") - self.system_prompt = system_prompt - self.context = context - if isinstance(model_name_or_model, TBModel): - self.model = model_name_or_model - else: - self.model = self.engine.model(model_name_or_model) - self.name = model_name_or_model.name diff --git a/tackleberry/registry.py b/tackleberry/registry.py index 3df6d54..faca797 100644 --- a/tackleberry/registry.py +++ b/tackleberry/registry.py @@ -3,14 +3,14 @@ import yaml import os -from .engine import TBEngine +from .runtime import TBRuntime class TBRegistry: def __init__(self, name: Optional[str] = None): if name is None: name = str(uuid.uuid4()) - self._engines = {} + self._runtimes = {} self._update_models() def load_registry(self): @@ -22,22 +22,22 @@ def load_registry(self): def _update_models(self): self._models = {} registry = self.load_registry() - for engine_name in self._engines: - # If the engine is in registry, then we delete it from there to not collide with the specific version - if engine_name in registry: - del registry[engine_name] - hasattr(self._engines[engine_name], 'get_models') - for model in self._engines[engine_name].get_models: - self._models[model] = engine_name - for registry_engine in registry: - for model in registry[registry_engine]: - self._models[model] = registry_engine - - def get_engine_by_model(self, model: str): + for runtime_name in self._runtimes: + # If the runtime is in registry, then we delete it from there to not collide with the specific version + if runtime_name in registry: + del registry[runtime_name] + hasattr(self._runtimes[runtime_name], 'get_models') + for model in self._runtimes[runtime_name].get_models: + self._models[model] = runtime_name + for registry_runtime in registry: + for model in registry[registry_runtime]: + self._models[model] = registry_runtime + + def get_runtime_by_model(self, model: str): return self._models[model] - def add_engine(self, name: str, engine: TBEngine = None): - self._engines[name] = engine + def add_runtime(self, name: str, runtime: TBRuntime = None): + self._runtimes[name] = runtime self._update_models() return self diff --git a/tackleberry/engine/__init__.py b/tackleberry/runtime/__init__.py similarity index 67% rename from tackleberry/engine/__init__.py rename to tackleberry/runtime/__init__.py index bef2fa9..c9be60a 100644 --- a/tackleberry/engine/__init__.py +++ b/tackleberry/runtime/__init__.py @@ -1,8 +1,8 @@ import os import importlib -# Special case: Explicitly import TBEngine from base.py -from .base import TBEngine +# Special case: Explicitly import TBRuntime from base.py +from .base import TBRuntime # Automatically detect all Python files in the current directory current_dir = os.path.dirname(__file__) @@ -12,9 +12,9 @@ ] # Create __all__ with the module names and the special case -__all__ = ["TBEngine"] + [os.path.splitext(f)[0] for f in module_files] +__all__ = ["TBRuntime"] + [os.path.splitext(f)[0] for f in module_files] # Dynamically import the modules and add them to the global namespace -for module_name in __all__[1:]: # Skip "TBEngine" as it's already imported +for module_name in __all__[1:]: # Skip "TBRuntime" as it's already imported module = importlib.import_module(f".{module_name}", package=__name__) globals()[module_name] = module \ No newline at end of file diff --git a/tackleberry/runtime/anthropic.py b/tackleberry/runtime/anthropic.py new file mode 100644 index 0000000..fb6ce72 --- /dev/null +++ b/tackleberry/runtime/anthropic.py @@ -0,0 +1,54 @@ +import os + +from .base import TBRuntime +from ..context import TBContext +from ..chat import TBChat + +import instructor +from pydantic import BaseModel + +class TBRuntimeAnthropic(TBRuntime): + default_max_tokens = 512 + + def __init__(self, + api_key: str = None, + max_tokens: int = None, + **kwargs, + ): + self.api_key = api_key or os.environ.get("ANTHROPIC_API_KEY") + if not isinstance(self.api_key, str) or len(self.api_key) < 51: + raise Exception(str(self)+" needs api_key (ANTHROPIC_API_KEY)") + from anthropic import Anthropic + self.client = Anthropic( + api_key=self.api_key, + **kwargs, + ) + self.max_tokens = max_tokens or TBRuntimeAnthropic.default_max_tokens + + def get_models(self): + models = [] + for model in self.client.models.list().data: + models.append(model.id) + models.sort() + return models + + def chat_context(self, chat: TBChat, context: TBContext, struct: BaseModel = None, **kwargs): + if struct is not None: + client = instructor.from_anthropic(self.client) + response = client.messages.create( + model=chat.model.name, + max_tokens=self.max_tokens, + messages=self.get_messages_from_context(context), + response_model=struct, + ) + return response + else: + response = self.client.messages.create( + model=chat.model.name, + max_tokens=self.max_tokens, + messages=self.get_messages_from_context(context), + ) + return response.content + + def __str__(self): + return f"TB Runtime Anthropic {hex(id(self))}" diff --git a/tackleberry/engine/base.py b/tackleberry/runtime/base.py similarity index 51% rename from tackleberry/engine/base.py rename to tackleberry/runtime/base.py index ba59632..5358a8a 100644 --- a/tackleberry/engine/base.py +++ b/tackleberry/runtime/base.py @@ -1,4 +1,6 @@ -class TBEngine: +from ..context import TBContext + +class TBRuntime: def __init__(self): pass @@ -14,5 +16,8 @@ def chat(self, model: str, **kwargs, ): - from ..model import TBModelChat - return TBModelChat(self, model, **kwargs) + from ..chat import TBChat + return TBChat(self.model(model), **kwargs) + + def get_messages_from_context(self, context: TBContext): + return context.to_messages() diff --git a/tackleberry/runtime/groq.py b/tackleberry/runtime/groq.py new file mode 100644 index 0000000..9dfb810 --- /dev/null +++ b/tackleberry/runtime/groq.py @@ -0,0 +1,49 @@ +import os + +from .base import TBRuntime +from ..context import TBContext +from ..chat import TBChat + +import instructor +from pydantic import BaseModel + +class TBRuntimeGroq(TBRuntime): + + def __init__(self, + api_key: str = None, + **kwargs, + ): + self.api_key = api_key or os.environ.get("GROQ_API_KEY") + if not isinstance(self.api_key, str) or len(self.api_key) < 51: + raise Exception("Groq needs api_key (GROQ_API_KEY)") + from groq import Groq + self.client = Groq( + api_key=self.api_key, + **kwargs, + ) + + def get_models(self): + models = [] + for model in self.client.models.list().data: + models.append(model.id) + models.sort() + return models + + def chat_context(self, chat: TBChat, context: TBContext, struct: BaseModel = None, **kwargs): + if struct is not None: + client = instructor.from_groq(self.client) + response = client.chat.completions.create( + model=chat.model.name, + messages=self.get_messages_from_context(context), + response_model=struct, + ) + return response + else: + response = self.client.chat.completions.create( + model=chat.model.name, + messages=self.get_messages_from_context(context), + ) + return response.content + + def __str__(self): + return f"TB Runtime Groq {hex(id(self))}" diff --git a/tackleberry/engine/ollama.py b/tackleberry/runtime/ollama.py similarity index 60% rename from tackleberry/engine/ollama.py rename to tackleberry/runtime/ollama.py index 9ff2102..41fe850 100644 --- a/tackleberry/engine/ollama.py +++ b/tackleberry/runtime/ollama.py @@ -2,12 +2,19 @@ from urllib.parse import urlparse import base64 -from .base import TBEngine +from typing import Union -class TBEngineOllama(TBEngine): +from .base import TBRuntime +from ..context import TBContext +from ..chat import TBChat + +from pydantic import BaseModel + +class TBRuntimeOllama(TBRuntime): def __init__(self, url: str = None, + keep_alive: Union[float, str] = None, **kwargs, ): if url is None: @@ -38,6 +45,8 @@ def __init__(self, auth_bytes = userinfo.encode("utf-8") auth_base64 = base64.b64encode(auth_bytes).decode("utf-8") kwargs['headers']['Authorization'] = 'Basic '+auth_base64 + if not keep_alive is None: + self.keep_alive = keep_alive from ollama import Client as Ollama self.client = Ollama( **kwargs, @@ -50,5 +59,24 @@ def get_models(self): models.sort() return models + def chat_context(self, chat: TBChat, context: TBContext, struct: BaseModel = None, **kwargs): + if struct is not None: + chat_kwargs = { + "model": chat.model.name, + "messages": self.get_messages_from_context(context), + "format": struct.model_json_schema(), + } + if hasattr(self, 'keep_alive'): + chat_kwargs["keep_alive"] = self.keep_alive + response = self.client.chat(**chat_kwargs, **kwargs) + return struct.model_validate_json(response.message.content) + else: + response = self.client.chat( + model=chat.model.name, + messages=self.get_messages_from_context(context), + **kwargs, + ) + return response.message.content + def __str__(self): - return f"TB Engine Ollama {hex(id(self))}" + return f"TB Runtime Ollama {hex(id(self))}" diff --git a/tackleberry/runtime/openai.py b/tackleberry/runtime/openai.py new file mode 100644 index 0000000..5374fb7 --- /dev/null +++ b/tackleberry/runtime/openai.py @@ -0,0 +1,48 @@ +import os + +from .base import TBRuntime +from ..context import TBContext +from ..chat import TBChat + +from pydantic import BaseModel + +class TBRuntimeOpenai(TBRuntime): + + def __init__(self, + api_key: str = None, + **kwargs, + ): + self.api_key = api_key or os.environ.get("OPENAI_API_KEY") + if not isinstance(self.api_key, str) or len(self.api_key) < 51: + raise Exception("OpenAI needs api_key (OPENAI_API_KEY)") + from openai import OpenAI + self.client = OpenAI( + api_key=self.api_key, + **kwargs, + ) + + def get_models(self): + models = [] + for model in self.client.models.list().data: + models.append(model.id) + models.sort() + return models + + def chat_context(self, chat: TBChat, context: TBContext, struct: BaseModel = None, **kwargs): + if struct is not None: + messages = self.get_messages_from_context(context) + response = self.client.beta.chat.completions.parse( + model=chat.model.name, + messages=self.get_messages_from_context(context), + response_format=struct, + ) + return response.choices[0].message.parsed + else: + response = self.client.chat.completions.create( + model=chat.model.name, + messages=self.get_messages_from_context(context), + ) + return response.content + + def __str__(self): + return f"TB Runtime OpenAI {hex(id(self))}" diff --git a/tackleberry/engine/trf.py b/tackleberry/runtime/trf.py similarity index 51% rename from tackleberry/engine/trf.py rename to tackleberry/runtime/trf.py index 74338fc..b6b0650 100644 --- a/tackleberry/engine/trf.py +++ b/tackleberry/runtime/trf.py @@ -1,6 +1,6 @@ -from .base import TBEngine +from .base import TBRuntime -class TBEngineTrf(TBEngine): +class TBRuntimeTrf(TBRuntime): def __init__(self, hf_token: str = None, @@ -9,4 +9,4 @@ def __init__(self, self.hf_token = hf_token def __str__(self): - return f"TB Engine HuggingFace transformers {hex(id(self))}" + return f"TB Runtime HuggingFace transformers {hex(id(self))}" diff --git a/tests/test_tackleberry.py b/tests/test_100_tackleberry.py similarity index 50% rename from tests/test_tackleberry.py rename to tests/test_100_tackleberry.py index aa7c37a..7a5a343 100644 --- a/tests/test_tackleberry.py +++ b/tests/test_100_tackleberry.py @@ -4,124 +4,126 @@ from unittest.mock import patch import requests from tackleberry import TB -from tackleberry.engine import TBEngine +from tackleberry.runtime import TBRuntime from tackleberry.model import TBModel +from tackleberry.chat import TBChat from tackleberry.context import TBContext, TBMessage -from tackleberry.engine.ollama import TBEngineOllama class TestTB(unittest.TestCase): def test_000_unknown(self): - """Test not existing Model and Engine""" + """Test not existing Model and Runtime""" with self.assertRaises(ModuleNotFoundError): - engine = TB.engine('xxxxx') + runtime = TB.runtime('xxxxx') with self.assertRaises(KeyError): model = TB.model('xxxxx') with self.assertRaises(KeyError): - modelchat = TB.chat('xxxxx') + chat = TB.chat('xxxxx') def test_010_openai(self): """Test OpenAI""" if os.environ.get("OPENAI_API_KEY"): - engine = TB.engine('openai') - self.assertIsInstance(engine, TBEngine) - self.assertEqual(type(engine).__name__, "TBEngineOpenai") - engine_model = engine.model('gpt-4o') - self.assertIsInstance(engine_model, TBModel) - self.assertEqual(type(engine_model).__name__, "TBModel") - engine_slash_model = TB.model('openai/gpt-4o') - self.assertIsInstance(engine_slash_model, TBModel) - self.assertEqual(type(engine_slash_model).__name__, "TBModel") + runtime = TB.runtime('openai') + self.assertIsInstance(runtime, TBRuntime) + self.assertEqual(type(runtime).__name__, "TBRuntimeOpenai") + runtime_model = runtime.model('gpt-4o') + self.assertIsInstance(runtime_model, TBModel) + self.assertEqual(type(runtime_model).__name__, "TBModel") + runtime_slash_model = TB.model('openai/gpt-4o') + self.assertIsInstance(runtime_slash_model, TBModel) + self.assertEqual(type(runtime_slash_model).__name__, "TBModel") model = TB.model('gpt-4o') self.assertIsInstance(model, TBModel) self.assertEqual(type(model).__name__, "TBModel") - self.assertIsInstance(model.engine, TBEngine) - self.assertEqual(type(model.engine).__name__, "TBEngineOpenai") - modelchat = TB.chat('gpt-4o') - self.assertIsInstance(modelchat, TBModel) - self.assertEqual(type(modelchat).__name__, "TBModelChat") - models = engine.get_models() + self.assertIsInstance(model.runtime, TBRuntime) + self.assertEqual(type(model.runtime).__name__, "TBRuntimeOpenai") + chat = TB.chat('gpt-4o') + self.assertIsInstance(chat, TBChat) + self.assertEqual(type(chat).__name__, "TBChat") + models = runtime.get_models() self.assertTrue(len(models) > 20) else: - warnings.warn("Can't test OpenAI engine without OPENAI_API_KEY", UserWarning) + warnings.warn("Can't test OpenAI runtime without OPENAI_API_KEY", UserWarning) def test_020_anthropic(self): """Test Anthropic""" if os.environ.get("ANTHROPIC_API_KEY"): - engine = TB.engine('anthropic') - self.assertIsInstance(engine, TBEngine) - self.assertEqual(type(engine).__name__, "TBEngineAnthropic") - engine_model = engine.model('claude-2.1') - self.assertIsInstance(engine_model, TBModel) - self.assertEqual(type(engine_model).__name__, "TBModel") - engine_slash_model = TB.model('anthropic/claude-2.1') - self.assertIsInstance(engine_slash_model, TBModel) - self.assertEqual(type(engine_slash_model).__name__, "TBModel") + runtime = TB.runtime('anthropic') + self.assertIsInstance(runtime, TBRuntime) + self.assertEqual(type(runtime).__name__, "TBRuntimeAnthropic") + runtime_model = runtime.model('claude-2.1') + self.assertIsInstance(runtime_model, TBModel) + self.assertEqual(type(runtime_model).__name__, "TBModel") + runtime_slash_model = TB.model('anthropic/claude-2.1') + self.assertIsInstance(runtime_slash_model, TBModel) + self.assertEqual(type(runtime_slash_model).__name__, "TBModel") model = TB.model('claude-2.1') self.assertIsInstance(model, TBModel) self.assertEqual(type(model).__name__, "TBModel") - self.assertIsInstance(model.engine, TBEngine) - self.assertEqual(type(model.engine).__name__, "TBEngineAnthropic") - modelchat = TB.chat('claude-2.1') - self.assertIsInstance(modelchat, TBModel) - self.assertEqual(type(modelchat).__name__, "TBModelChat") - models = engine.get_models() + self.assertIsInstance(model.runtime, TBRuntime) + self.assertEqual(type(model.runtime).__name__, "TBRuntimeAnthropic") + chat = TB.chat('claude-2.1') + self.assertIsInstance(chat, TBChat) + self.assertEqual(type(chat).__name__, "TBChat") + models = runtime.get_models() self.assertTrue(len(models) > 3) else: - warnings.warn("Can't test Anthropic engine without ANTHROPIC_API_KEY", UserWarning) + warnings.warn("Can't test Anthropic runtime without ANTHROPIC_API_KEY", UserWarning) def test_030_groq(self): """Test Groq""" if os.environ.get("GROQ_API_KEY"): - engine = TB.engine('groq') - self.assertIsInstance(engine, TBEngine) - self.assertEqual(type(engine).__name__, "TBEngineGroq") - engine_model = engine.model('llama3-8b-8192') - self.assertIsInstance(engine_model, TBModel) - self.assertEqual(type(engine_model).__name__, "TBModel") - engine_slash_model = TB.model('groq/llama3-8b-8192') - self.assertIsInstance(engine_slash_model, TBModel) - self.assertEqual(type(engine_slash_model).__name__, "TBModel") + runtime = TB.runtime('groq') + self.assertIsInstance(runtime, TBRuntime) + self.assertEqual(type(runtime).__name__, "TBRuntimeGroq") + runtime_model = runtime.model('llama3-8b-8192') + self.assertIsInstance(runtime_model, TBModel) + self.assertEqual(type(runtime_model).__name__, "TBModel") + runtime_slash_model = TB.model('groq/llama3-8b-8192') + self.assertIsInstance(runtime_slash_model, TBModel) + self.assertEqual(type(runtime_slash_model).__name__, "TBModel") model = TB.model('llama3-8b-8192') self.assertIsInstance(model, TBModel) self.assertEqual(type(model).__name__, "TBModel") - self.assertIsInstance(model.engine, TBEngine) - self.assertEqual(type(model.engine).__name__, "TBEngineGroq") - modelchat = TB.chat('llama3-8b-8192') - self.assertIsInstance(modelchat, TBModel) - self.assertEqual(type(modelchat).__name__, "TBModelChat") - models = engine.get_models() + self.assertIsInstance(model.runtime, TBRuntime) + self.assertEqual(type(model.runtime).__name__, "TBRuntimeGroq") + chat = TB.chat('llama3-8b-8192') + self.assertIsInstance(chat, TBChat) + self.assertEqual(type(chat).__name__, "TBChat") + models = runtime.get_models() self.assertTrue(len(models) > 10) else: - warnings.warn("Can't test Groq engine without GROQ_API_KEY", UserWarning) + warnings.warn("Can't test Groq runtime without GROQ_API_KEY", UserWarning) def test_040_ollama(self): """Test Ollama""" if os.environ.get("OLLAMA_HOST") or os.environ.get("OLLAMA_PROXY_URL"): - engine = TB.engine('ollama') - self.assertIsInstance(engine, TBEngine) - self.assertEqual(type(engine).__name__, "TBEngineOllama") - models = engine.get_models() + from tackleberry.runtime.ollama import TBRuntimeOllama + runtime = TB.runtime('ollama') + self.assertIsInstance(runtime, TBRuntime) + self.assertEqual(type(runtime).__name__, "TBRuntimeOllama") + models = runtime.get_models() self.assertTrue(len(models) > 0) else: - warnings.warn("Can't test Ollama engine without explicit setting OLLAMA_HOST or OLLAMA_PROXY_URL", UserWarning) + warnings.warn("Can't test Ollama runtime without explicit setting OLLAMA_HOST or OLLAMA_PROXY_URL", UserWarning) @patch('httpx.Client.send') def test_041_ollama_userpass(self, mock_send): """Test Ollama user pass to basic auth conversion""" if os.environ.get("OLLAMA_HOST") or os.environ.get("OLLAMA_PROXY_URL"): + from tackleberry.runtime.ollama import TBRuntimeOllama mock_response = unittest.mock.Mock() mock_response.status_code = 200 mock_response.json.return_value = {"models": []} mock_send.return_value = mock_response - engine = TBEngineOllama( + runtime = TBRuntimeOllama( url = 'https://user:pass@domain.com:5000', ) - self.assertEqual(type(engine).__name__, "TBEngineOllama") + self.assertEqual(type(runtime).__name__, "TBRuntimeOllama") - models = engine.get_models() + models = runtime.get_models() # Assert: Verify the request details mock_send.assert_called_once() @@ -131,7 +133,7 @@ def test_041_ollama_userpass(self, mock_send): self.assertEqual(request[0].url, 'https://domain.com:5000/api/tags') self.assertEqual(request[0].headers['authorization'], 'Basic dXNlcjpwYXNz') else: - warnings.warn("Can't test Ollama engine without explicit setting OLLAMA_HOST or OLLAMA_PROXY_URL", UserWarning) + warnings.warn("Can't test Ollama runtime without explicit setting OLLAMA_HOST or OLLAMA_PROXY_URL", UserWarning) def test_100_registry(self): """Test registry""" diff --git a/tests/test_200_structured_output.py b/tests/test_200_structured_output.py new file mode 100644 index 0000000..b3270d5 --- /dev/null +++ b/tests/test_200_structured_output.py @@ -0,0 +1,62 @@ +import unittest +import warnings +import os +from unittest.mock import patch +import requests +from tackleberry import TB +from pydantic import BaseModel + +import sys + +class UserInfo(BaseModel): + name: str + age: int + +class TestTB(unittest.TestCase): + + def test_010_openai(self): + """Test OpenAI""" + if os.environ.get("OPENAI_API_KEY"): + chat = TB.chat('openai/gpt-4o-mini') + user_info = chat.query("Extract the name and the age: 'John is 20 years old'", UserInfo) + self.assertIsInstance(user_info, UserInfo) + self.assertEqual(user_info.name, "John") + self.assertEqual(user_info.age, 20) + else: + warnings.warn("Can't test OpenAI runtime without OPENAI_API_KEY", UserWarning) + + def test_020_anthropic(self): + """Test Anthropic""" + if os.environ.get("ANTHROPIC_API_KEY"): + chat = TB.chat('anthropic/claude-3-5-sonnet-20241022') + user_info = chat.query("Extract the name and the age: 'John is 20 years old'", UserInfo) + self.assertIsInstance(user_info, UserInfo) + self.assertEqual(user_info.name, "John") + self.assertEqual(user_info.age, 20) + else: + warnings.warn("Can't test Anthropic runtime without ANTHROPIC_API_KEY", UserWarning) + + def test_030_groq(self): + """Test Groq""" + if os.environ.get("GROQ_API_KEY"): + chat = TB.chat('groq/llama3-8b-8192') + user_info = chat.query("Extract the name and the age: 'John is 20 years old'", UserInfo) + self.assertIsInstance(user_info, UserInfo) + self.assertEqual(user_info.name, "John") + self.assertEqual(user_info.age, 20) + else: + warnings.warn("Can't test Groq runtime without GROQ_API_KEY", UserWarning) + + def test_040_ollama(self): + """Test Ollama""" + if (os.environ.get("OLLAMA_HOST") or os.environ.get("OLLAMA_PROXY_URL")) and (os.environ.get("TACKLEBERRY_OLLAMA_TEST_MODEL") or 'gemma2:2b'): + chat = TB.chat('ollama/gemma2:2b') + user_info = chat.query("Extract the name and the age: 'John is 20 years old'", UserInfo) + self.assertIsInstance(user_info, UserInfo) + self.assertEqual(user_info.name, "John") + self.assertEqual(user_info.age, 20) + else: + warnings.warn("Can't test Ollama runtime without explicit setting OLLAMA_HOST or OLLAMA_PROXY_URL", UserWarning) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file