From c2aa671b130ef9b4dfe99244e87ba08ac8306961 Mon Sep 17 00:00:00 2001 From: Torsten Raudssus Date: Fri, 3 Jan 2025 04:32:46 +0100 Subject: [PATCH] Fixing some details, adding tests for basic auth feature on ollama --- README.md | 4 ++-- pyproject.toml | 7 +++++- tackleberry/__init__.py | 2 +- tackleberry/engine/hf.py | 15 +++++++++++++ tackleberry/engine/ollama.py | 19 ++++++++-------- tests/test_tackleberry.py | 42 ++++++++++++++++++++++++++++++------ 6 files changed, 70 insertions(+), 19 deletions(-) create mode 100644 tackleberry/engine/hf.py diff --git a/README.md b/README.md index 739fa35..5103ae7 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,3 @@ -# Tackleberry +[![Tackleberry Factory](https://raw.githubusercontent.com/Getty/tackleberry/main/tackleberry.jpg)](https://github.com/Getty/tackleberry) -![Tackleberry Factory](tackleberry.jpg) +# Tackleberry diff --git a/pyproject.toml b/pyproject.toml index 49fd9c5..5673faf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "tackleberry" -version = "0.1.0.dev0" +version = "0.1.0.dev1" description = "Tackleberry (or TB) is helping you tackle the access to AI" authors = [ { name = "Torsten Raudßus", email = "torsten@raudssus.de" }, @@ -20,6 +20,11 @@ dev = [ "black>=22.0.0", "isort>=5.0.0", "mypy>=1.0.0", + "openai", + "ollama", + "transformers", + "groq", + "anthropic", ] [build-system] diff --git a/tackleberry/__init__.py b/tackleberry/__init__.py index 727b108..881ebff 100644 --- a/tackleberry/__init__.py +++ b/tackleberry/__init__.py @@ -6,4 +6,4 @@ try: __version__ = version("tackleberry") except ImportError: - __version__ = "0.0.0" + __version__ = "0.0.1.dev0" diff --git a/tackleberry/engine/hf.py b/tackleberry/engine/hf.py new file mode 100644 index 0000000..08f70c4 --- /dev/null +++ b/tackleberry/engine/hf.py @@ -0,0 +1,15 @@ +from typing import Any, Union, Dict, List, Optional +import os + +from . import TBEngine + +class TBEngineHf(TBEngine): + + def __init__(self, + hf_token: str = None, + **kwargs, + ): + self.hf_token = hf_token + + def __str__(self): + return f"TB Engine HuggingFace {hex(id(self))}" diff --git a/tackleberry/engine/ollama.py b/tackleberry/engine/ollama.py index ca891c2..d8cca7e 100644 --- a/tackleberry/engine/ollama.py +++ b/tackleberry/engine/ollama.py @@ -11,15 +11,16 @@ def __init__(self, url: str = None, **kwargs, ): - url = os.environ.get("OLLAMA_HOST") - userinfo = None - if os.environ.get("OLLAMA_PROXY_URL"): - if not url is None: - raise Exception("OLLAMA_PROXY_URL and OLLAMA_HOST set, please just use one") - else: - url = os.environ.get("OLLAMA_PROXY_URL") + if url is None: + url = os.environ.get("OLLAMA_HOST") + userinfo = None + if os.environ.get("OLLAMA_PROXY_URL"): + if not url is None: + raise Exception("OLLAMA_PROXY_URL and OLLAMA_HOST set, please just use one") + else: + url = os.environ.get("OLLAMA_PROXY_URL") if url: - parsed_url = urlparse(os.environ.get("OLLAMA_HOST")) + parsed_url = urlparse(url) if parsed_url.scheme in ["http", "https"] and parsed_url.netloc: if "@" in parsed_url.netloc: userinfo = parsed_url.netloc.split("@")[0] @@ -30,7 +31,7 @@ def __init__(self, parsed_url = parsed_url._replace(netloc=netloc) url = parsed_url.geturl() elif parsed_url.path: - url = 'http://'+parsed_url.path+'/' + url = parsed_url.scheme+'://'+parsed_url.path+'/' kwargs['host'] = url if userinfo: if not 'headers' in kwargs: diff --git a/tests/test_tackleberry.py b/tests/test_tackleberry.py index afa622a..6b59ee2 100644 --- a/tests/test_tackleberry.py +++ b/tests/test_tackleberry.py @@ -1,10 +1,13 @@ import unittest import warnings import os +from unittest.mock import patch +import requests from tackleberry import TB from tackleberry.engine import TBEngine from tackleberry.model import TBModel from tackleberry.context import TBContext, TBMessage +from tackleberry.engine.ollama import TBEngineOllama class TestTB(unittest.TestCase): @@ -15,7 +18,7 @@ def test_000_unknown(self): with self.assertRaises(KeyError): model = TB.model('xxxxx') - def test_001_openai(self): + def test_010_openai(self): """Test OpenAI""" if os.environ.get("OPENAI_API_KEY"): engine = TB.engine('openai') @@ -37,7 +40,7 @@ def test_001_openai(self): else: warnings.warn("Can't test OpenAI engine without OPENAI_API_KEY", UserWarning) - def test_002_anthropic(self): + def test_020_anthropic(self): """Test Anthropic""" if os.environ.get("ANTHROPIC_API_KEY"): engine = TB.engine('anthropic') @@ -59,7 +62,7 @@ def test_002_anthropic(self): else: warnings.warn("Can't test Anthropic engine without ANTHROPIC_API_KEY", UserWarning) - def test_003_groq(self): + def test_030_groq(self): """Test Groq""" if os.environ.get("GROQ_API_KEY"): engine = TB.engine('groq') @@ -81,7 +84,7 @@ def test_003_groq(self): else: warnings.warn("Can't test Groq engine without GROQ_API_KEY", UserWarning) - def test_004_ollama(self): + def test_040_ollama(self): """Test Ollama""" if os.environ.get("OLLAMA_HOST") or os.environ.get("OLLAMA_PROXY_URL"): engine = TB.engine('ollama') @@ -92,11 +95,38 @@ def test_004_ollama(self): else: warnings.warn("Can't test Ollama engine without explicit setting OLLAMA_HOST or OLLAMA_PROXY_URL", UserWarning) - def test_010_registry(self): + @patch('httpx.Client.send') + def test_041_ollama_userpass(self, mock_send): + """Test Ollama user pass to basic auth conversion""" + if os.environ.get("OLLAMA_HOST") or os.environ.get("OLLAMA_PROXY_URL"): + mock_response = unittest.mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"models": []} + + mock_send.return_value = mock_response + + engine = TBEngineOllama( + url = 'https://user:pass@domain.com:5000', + ) + self.assertEqual(type(engine).__name__, "TBEngineOllama") + + models = engine.get_models() + + # Assert: Verify the request details + mock_send.assert_called_once() + request, kwargs = mock_send.call_args + + self.assertEqual(request[0].method, 'GET') + self.assertEqual(request[0].url, 'https://domain.com:5000/api/tags') + self.assertEqual(request[0].headers['authorization'], 'Basic dXNlcjpwYXNz') + else: + warnings.warn("Can't test Ollama engine without explicit setting OLLAMA_HOST or OLLAMA_PROXY_URL", UserWarning) + + def test_100_registry(self): """Test registry""" self.assertEqual(TB.count, 1) - def test_020_context(self): + def test_200_context(self): """Test context""" nosys_context = TB.context() self.assertIsInstance(nosys_context, TBContext)