diff --git a/pyproject.toml b/pyproject.toml index cead7fe..e717324 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "swarms-memory" -version = "0.0.7" +version = "0.0.8" description = "Swarms Memory - Pytorch" license = "MIT" authors = ["Kye Gomez "] @@ -30,6 +30,7 @@ sentence-transformers = "*" pinecone = "*" faiss-cpu = "*" pydantic = "*" +sqlalchemy = "*" diff --git a/server/api.py b/server/api.py index ddc190b..f63d7f6 100644 --- a/server/api.py +++ b/server/api.py @@ -58,9 +58,7 @@ def create_collection(request: CreateCollectionRequest): Creates a new collection with the specified name. """ try: - chroma_client.create_collection( - name=request.name - ) + chroma_client.create_collection(name=request.name) logger.info(f"Created collection with name: {request.name}") return { "message": f"Collection '{request.name}' created successfully." diff --git a/swarms_memory/__init__.py b/swarms_memory/__init__.py index a2b9c39..716b31f 100644 --- a/swarms_memory/__init__.py +++ b/swarms_memory/__init__.py @@ -1,2 +1,2 @@ -from swarms_memory.vector_dbs import * # noqa: F401, F403 -from swarms_memory.utils import * # noqa: F401, F403 \ No newline at end of file +from swarms_memory.vector_dbs import * # noqa: F401, F403 +from swarms_memory.utils import * # noqa: F401, F403 diff --git a/swarms_memory/dbs/__init__.py b/swarms_memory/dbs/__init__.py index e69de29..6d0f8e8 100644 --- a/swarms_memory/dbs/__init__.py +++ b/swarms_memory/dbs/__init__.py @@ -0,0 +1,5 @@ +from swarms_memory.dbs.pg import PostgresDB +from swarms_memory.dbs.sqlite import SQLiteDB + + +__all__ = ["PostgresDB", "SQLiteDB"] diff --git a/swarms_memory/dbs/pg.py b/swarms_memory/dbs/pg.py new file mode 100644 index 0000000..957f5a1 --- /dev/null +++ b/swarms_memory/dbs/pg.py @@ -0,0 +1,142 @@ +import uuid +from typing import Any, List, Optional + +from sqlalchemy import JSON, Column, String, create_engine +from 
sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import Session +from swarms_memory import BaseVectorDatabase + + +class PostgresDB(BaseVectorDatabase): + """ + A class representing a Postgres database. + + Args: + connection_string (str): The connection string for the Postgres database. + table_name (str): The name of the table in the database. + + Attributes: + engine: The SQLAlchemy engine for connecting to the database. + table_name (str): The name of the table in the database. + VectorModel: The SQLAlchemy model representing the vector table. + + """ + + def __init__( + self, connection_string: str, table_name: str, *args, **kwargs + ): + """ + Initializes a new instance of the PostgresDB class. + + Args: + connection_string (str): The connection string for the Postgres database. + table_name (str): The name of the table in the database. + + """ + self.engine = create_engine( + connection_string, *args, **kwargs + ) + self.table_name = table_name + self.VectorModel = self._create_vector_model() + + def _create_vector_model(self): + """ + Creates the SQLAlchemy model for the vector table. + + Returns: + The SQLAlchemy model representing the vector table. + + """ + Base = declarative_base() + + class VectorModel(Base): + __tablename__ = self.table_name + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid4, + unique=True, + nullable=False, + ) + vector = Column( + String + ) # Assuming vector is stored as a string + namespace = Column(String) + meta = Column(JSON) + + return VectorModel + + def add( + self, + vector: str, + vector_id: Optional[str] = None, + namespace: Optional[str] = None, + meta: Optional[dict] = None, + ) -> None: + """ + Adds or updates a vector in the database. + + Args: + vector (str): The vector to be added or updated. + vector_id (str, optional): The ID of the vector. If not provided, a new ID will be generated. 
+ namespace (str, optional): The namespace of the vector. + meta (dict, optional): Additional metadata associated with the vector. + + """ + try: + with Session(self.engine) as session: + obj = self.VectorModel( + id=vector_id, + vector=vector, + namespace=namespace, + meta=meta, + ) + session.merge(obj) + session.commit() + except Exception as e: + print(f"Error adding or updating vector: {e}") + + def query( + self, query: Any, namespace: Optional[str] = None + ) -> List[Any]: + """ + Queries vectors from the database based on the given query and namespace. + + Args: + query (Any): The query or condition to filter the vectors. + namespace (str, optional): The namespace of the vectors to be queried. + + Returns: + List[Any]: A list of vectors that match the query and namespace. + + """ + try: + with Session(self.engine) as session: + q = session.query(self.VectorModel) + if namespace: + q = q.filter_by(namespace=namespace) + # Assuming 'query' is a condition or filter + q = q.filter(query) + return q.all() + except Exception as e: + print(f"Error querying vectors: {e}") + return [] + + def delete_vector(self, vector_id): + """ + Deletes a vector from the database based on the given vector ID. + + Args: + vector_id: The ID of the vector to be deleted. + + """ + try: + with Session(self.engine) as session: + obj = session.get(self.VectorModel, vector_id) + if obj: + session.delete(obj) + session.commit() + except Exception as e: + print(f"Error deleting vector: {e}") diff --git a/swarms_memory/dbs/pinecone.py b/swarms_memory/dbs/pinecone.py new file mode 100644 index 0000000..534c603 --- /dev/null +++ b/swarms_memory/dbs/pinecone.py @@ -0,0 +1,219 @@ +from typing import Optional + +import pinecone +from attr import define, field + +from swarms_memory import BaseVectorDatabase +from swarms.utils import str_to_hash + + +@define +class PineconeDB(BaseVectorDatabase): + """ + PineconeDB is a vector storage driver that uses Pinecone as the underlying storage engine. 
+ + Pinecone is a vector database that allows you to store, search, and retrieve high-dimensional vectors with + blazing speed and low latency. It is a managed service that is easy to use and scales effortlessly, so you can + focus on building your applications instead of managing your infrastructure. + + Args: + api_key (str): The API key for your Pinecone account. + index_name (str): The name of the index to use. + environment (str): The environment to use. Either "us-west1-gcp" or "us-east1-gcp". + project_name (str, optional): The name of the project to use. Defaults to None. + index (pinecone.Index, optional): The Pinecone index to use. Defaults to None. + + Methods: + upsert_vector(vector: list[float], vector_id: Optional[str] = None, namespace: Optional[str] = None, meta: Optional[dict] = None, **kwargs) -> str: + Upserts a vector into the index. + load_entry(vector_id: str, namespace: Optional[str] = None) -> Optional[BaseVectorStore.Entry]: + Loads a single vector from the index. + load_entries(namespace: Optional[str] = None) -> list[BaseVectorStore.Entry]: + Loads all vectors from the index. + query(query: str, count: Optional[int] = None, namespace: Optional[str] = None, include_vectors: bool = False, include_metadata=True, **kwargs) -> list[BaseVectorStore.QueryResult]: + Queries the index for vectors similar to the given query string. + create_index(name: str, **kwargs) -> None: + Creates a new index. 
+ + Usage: + >>> from swarms_memory.vector_stores.pinecone import PineconeDB + >>> from swarms.utils.embeddings import USEEmbedding + >>> from swarms.utils.hash import str_to_hash + >>> from swarms.utils.dataframe import dataframe_to_hash + >>> import pandas as pd + >>> + >>> # Create a new PineconeDB instance: + >>> pv = PineconeDB( + >>> api_key="your-api-key", + >>> index_name="your-index-name", + >>> environment="us-west1-gcp", + >>> project_name="your-project-name" + >>> ) + >>> # Create a new index: + >>> pv.create_index("your-index-name") + >>> # Create a new USEEmbedding instance: + >>> use = USEEmbedding() + >>> # Create a new dataframe: + >>> df = pd.DataFrame({ + >>> "text": [ + >>> "This is a test", + >>> "This is another test", + >>> "This is a third test" + >>> ] + >>> }) + >>> # Embed the dataframe: + >>> df["embedding"] = df["text"].apply(use.embed_string) + >>> # Upsert the dataframe into the index: + >>> pv.upsert_vector( + >>> vector=df["embedding"].tolist(), + >>> vector_id=dataframe_to_hash(df), + >>> namespace="your-namespace" + >>> ) + >>> # Query the index: + >>> pv.query( + >>> query="This is a test", + >>> count=10, + >>> namespace="your-namespace" + >>> ) + >>> # Load a single entry from the index: + >>> pv.load_entry( + >>> vector_id=dataframe_to_hash(df), + >>> namespace="your-namespace" + >>> ) + >>> # Load all entries from the index: + >>> pv.load_entries( + >>> namespace="your-namespace" + >>> ) + + + """ + + api_key: str = field(kw_only=True) + index_name: str = field(kw_only=True) + environment: str = field(kw_only=True) + project_name: Optional[str] = field(default=None, kw_only=True) + index: pinecone.Index = field(init=False) + + def __attrs_post_init__(self) -> None: + """Post init""" + pinecone.init( + api_key=self.api_key, + environment=self.environment, + project_name=self.project_name, + ) + + self.index = pinecone.Index(self.index_name) + + def add( + self, + vector: list[float], + vector_id: Optional[str] = None, + 
namespace: Optional[str] = None, + meta: Optional[dict] = None, + **kwargs, + ) -> str: + """Add a vector to the index. + + Args: + vector (list[float]): _description_ + vector_id (Optional[str], optional): _description_. Defaults to None. + namespace (Optional[str], optional): _description_. Defaults to None. + meta (Optional[dict], optional): _description_. Defaults to None. + + Returns: + str: _description_ + """ + vector_id = ( + vector_id if vector_id else str_to_hash(str(vector)) + ) + + params = {"namespace": namespace} | kwargs + + self.index.upsert([(vector_id, vector, meta)], **params) + + return vector_id + + def load_entries(self, namespace: Optional[str] = None): + """Load all entries from the index. + + Args: + namespace (Optional[str], optional): _description_. Defaults to None. + + Returns: + _type_: _description_ + """ + # This is a hacky way to query up to 10,000 values from Pinecone. Waiting on an official API for fetching + # all values from a namespace: + # https://community.pinecone.io/t/is-there-a-way-to-query-all-the-vectors-and-or-metadata-from-a-namespace/797/5 + + results = self.index.query( + self.embedding_driver.embed_string(""), + top_k=10000, + include_metadata=True, + namespace=namespace, + ) + + for result in results["matches"]: + entry = { + "id": result["id"], + "vector": result["values"], + "meta": result["metadata"], + "namespace": result["namespace"], + } + return entry + + def query( + self, + query: str, + count: Optional[int] = None, + namespace: Optional[str] = None, + include_vectors: bool = False, + # PineconeDBStorageDriver-specific params: + include_metadata=True, + **kwargs, + ): + """Query the index for vectors similar to the given query string. + + Args: + query (str): _description_ + count (Optional[int], optional): _description_. Defaults to None. + namespace (Optional[str], optional): _description_. Defaults to None. + include_vectors (bool, optional): _description_. Defaults to False. 
+ include_metadata (bool, optional): _description_. Defaults to True. + + Returns: + _type_: _description_ + """ + vector = self.embedding_driver.embed_string(query) + + params = { + "top_k": count, + "namespace": namespace, + "include_values": include_vectors, + "include_metadata": include_metadata, + } | kwargs + + results = self.index.query(vector, **params) + + for r in results["matches"]: + entry = { + "id": r["id"], + "vector": r["values"], + "score": r["score"], + "meta": r["metadata"], + "namespace": results["namespace"], + } + return entry + + def create_index(self, name: str, **kwargs) -> None: + """Create a new index. + + Args: + name (str): _description_ + """ + params = { + "name": name, + "dimension": self.embedding_driver.dimensions, + } | kwargs + + pinecone.create_index(**params) diff --git a/swarms_memory/dbs/sqlite.py b/swarms_memory/dbs/sqlite.py new file mode 100644 index 0000000..02c1348 --- /dev/null +++ b/swarms_memory/dbs/sqlite.py @@ -0,0 +1,121 @@ +from typing import Any, List, Optional, Tuple + +from swarms_memory import BaseVectorDatabase + +try: + import sqlite3 +except ImportError: + raise ImportError( + "Please install sqlite3 to use the SQLiteDB class." + ) + + +class SQLiteDB(BaseVectorDatabase): + """ + A reusable class for SQLite database operations with methods for adding, + deleting, updating, and querying data. + + Attributes: + db_path (str): The file path to the SQLite database. + """ + + def __init__(self, db_path: str): + """ + Initializes the SQLiteDB class with the given database path. + + Args: + db_path (str): The file path to the SQLite database. + """ + self.db_path = db_path + + def execute_query( + self, query: str, params: Optional[Tuple[Any, ...]] = None + ) -> List[Tuple]: + """ + Executes a SQL query and returns fetched results. + + Args: + query (str): The SQL query to execute. + params (Tuple[Any, ...], optional): The parameters to substitute into the query. 
+ + Returns: + List[Tuple]: The results fetched from the database. + """ + try: + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute(query, params or ()) + return cursor.fetchall() + except Exception as error: + print(f"Error executing query: {error}") + raise error + + def add(self, query: str, params: Tuple[Any, ...]) -> None: + """ + Adds a new entry to the database. + + Args: + query (str): The SQL query for insertion. + params (Tuple[Any, ...]): The parameters to substitute into the query. + """ + try: + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute(query, params) + conn.commit() + except Exception as error: + print(f"Error adding new entry: {error}") + raise error + + def delete(self, query: str, params: Tuple[Any, ...]) -> None: + """ + Deletes an entry from the database. + + Args: + query (str): The SQL query for deletion. + params (Tuple[Any, ...]): The parameters to substitute into the query. + """ + try: + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute(query, params) + conn.commit() + except Exception as error: + print(f"Error deleting entry: {error}") + raise error + + def update(self, query: str, params: Tuple[Any, ...]) -> None: + """ + Updates an entry in the database. + + Args: + query (str): The SQL query for updating. + params (Tuple[Any, ...]): The parameters to substitute into the query. + """ + try: + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute(query, params) + conn.commit() + except Exception as error: + print(f"Error updating entry: {error}") + raise error + + def query( + self, query: str, params: Optional[Tuple[Any, ...]] = None + ) -> List[Tuple]: + """ + Fetches data from the database based on a query. + + Args: + query (str): The SQL query to execute. + params (Tuple[Any, ...], optional): The parameters to substitute into the query. 
+ + Returns: + List[Tuple]: The results fetched from the database. + """ + try: + return self.execute_query(query, params) + except Exception as error: + print(f"Error querying database: {error}") + raise error diff --git a/swarms_memory/utils/__init__.py b/swarms_memory/utils/__init__.py index c38a3ba..a20a717 100644 --- a/swarms_memory/utils/__init__.py +++ b/swarms_memory/utils/__init__.py @@ -1,6 +1,8 @@ from swarms_memory.utils.action_subtask import ActionSubtaskEntry -from swarms_memory.utils.dict_internal_memory import DictInternalMemory +from swarms_memory.utils.dict_internal_memory import ( + DictInternalMemory, +) from swarms_memory.utils.dict_shared_memory import DictSharedMemory from swarms_memory.utils.short_term_memory import ShortTermMemory from swarms_memory.utils.visual_memory import VisualShortTermMemory diff --git a/swarms_memory/utils/dict_shared_memory.py b/swarms_memory/utils/dict_shared_memory.py index 8ac92d1..f81e2fd 100644 --- a/swarms_memory/utils/dict_shared_memory.py +++ b/swarms_memory/utils/dict_shared_memory.py @@ -44,7 +44,9 @@ def add( entry_id = str(uuid.uuid4()) data = {} epoch = datetime.datetime.utcfromtimestamp(0) - epoch = (datetime.datetime.utcnow() - epoch).total_seconds() + epoch = ( + datetime.datetime.utcnow() - epoch + ).total_seconds() data[entry_id] = { "agent": agent_id, "epoch": epoch, diff --git a/swarms_memory/utils/short_term_memory.py b/swarms_memory/utils/short_term_memory.py index 1b8056c..5768957 100644 --- a/swarms_memory/utils/short_term_memory.py +++ b/swarms_memory/utils/short_term_memory.py @@ -2,6 +2,7 @@ import logging import threading + class ShortTermMemory: """Short term memory. @@ -37,7 +38,9 @@ def __init__( self.medium_term_memory = [] self.lock = threading.Lock() - def add(self, role: str = None, message: str = None, *args, **kwargs): + def add( + self, role: str = None, message: str = None, *args, **kwargs + ): """Add a message to the short term memory. 
Args: @@ -155,7 +158,9 @@ def save_to_file(self, filename: str): with open(filename, "w") as f: json.dump( { - "short_term_memory": (self.short_term_memory), + "short_term_memory": ( + self.short_term_memory + ), "medium_term_memory": ( self.medium_term_memory ), @@ -177,7 +182,9 @@ def load_from_file(self, filename: str, *args, **kwargs): with self.lock: with open(filename) as f: data = json.load(f) - self.short_term_memory = data.get("short_term_memory", []) + self.short_term_memory = data.get( + "short_term_memory", [] + ) self.medium_term_memory = data.get( "medium_term_memory", [] ) diff --git a/swarms_memory/vector_dbs/__init__.py b/swarms_memory/vector_dbs/__init__.py index eed6900..ea24111 100644 --- a/swarms_memory/vector_dbs/__init__.py +++ b/swarms_memory/vector_dbs/__init__.py @@ -3,4 +3,9 @@ from swarms_memory.vector_dbs.faiss_wrapper import FAISSDB from swarms_memory.vector_dbs.base_vectordb import BaseVectorDatabase -__all__ = ["ChromaDB", "PineconeMemory", "FAISSDB", "BaseVectorDatabase"] \ No newline at end of file +__all__ = [ + "ChromaDB", + "PineconeMemory", + "FAISSDB", + "BaseVectorDatabase", +] diff --git a/swarms_memory/vector_dbs/base_vectordb.py b/swarms_memory/vector_dbs/base_vectordb.py index da16506..3576852 100644 --- a/swarms_memory/vector_dbs/base_vectordb.py +++ b/swarms_memory/vector_dbs/base_vectordb.py @@ -1,6 +1,7 @@ from abc import ABC from loguru import logger + class BaseVectorDatabase(ABC): """ Abstract base class for a database. 
diff --git a/swarms_memory/vector_dbs/chroma_db_wrapper.py b/swarms_memory/vector_dbs/chroma_db_wrapper.py index 6521374..b33d5dc 100644 --- a/swarms_memory/vector_dbs/chroma_db_wrapper.py +++ b/swarms_memory/vector_dbs/chroma_db_wrapper.py @@ -85,9 +85,7 @@ def __init__( # If docs if docs_folder: - logger.info( - f"Traversing directory: {docs_folder}" - ) + logger.info(f"Traversing directory: {docs_folder}") self.traverse_directory() def add( diff --git a/tests/utils/test_dictsharedmemory.py b/tests/utils/test_dictsharedmemory.py index ee945e3..83ba186 100644 --- a/tests/utils/test_dictsharedmemory.py +++ b/tests/utils/test_dictsharedmemory.py @@ -63,7 +63,9 @@ def test_parametrized_get_top_n( memory_instance, scores, agent_ids, expected_top_score ): for score, agent_id in zip(scores, agent_ids): - memory_instance.add(score, agent_id, 1, f"Entry by {agent_id}") + memory_instance.add( + score, agent_id, 1, f"Entry by {agent_id}" + ) top_1 = memory_instance.get_top_n(1) top_score = next(iter(top_1.values()))["score"] assert ( @@ -76,7 +78,9 @@ def test_parametrized_get_top_n( def test_add_entry_invalid_input(memory_instance): with pytest.raises(ValueError): - memory_instance.add("invalid_score", "agent123", 1, "Test Entry") + memory_instance.add( + "invalid_score", "agent123", 1, "Test Entry" + ) # Mocks and monkey-patching diff --git a/tests/utils/test_langchainchromavectormemory.py b/tests/utils/test_langchainchromavectormemory.py index 35c12c2..0b0d3ba 100644 --- a/tests/utils/test_langchainchromavectormemory.py +++ b/tests/utils/test_langchainchromavectormemory.py @@ -42,7 +42,9 @@ def test_initialization_default_settings(vector_memory): def test_add_entry(vector_memory, embeddings_mock): - with patch.object(vector_memory.db, "add_texts") as add_texts_mock: + with patch.object( + vector_memory.db, "add_texts" + ) as add_texts_mock: vector_memory.add("Example text") add_texts_mock.assert_called() @@ -88,5 +90,7 @@ def test_search_memory_different_params( 
"similarity_search_with_score", return_value=expected, ): - result = vector_memory.search_memory(query, k=k, type=type) + result = vector_memory.search_memory( + query, k=k, type=type + ) assert len(result) == (k if k > 0 else 0) diff --git a/tests/utils/test_pinecone.py b/tests/utils/test_pinecone.py index b9f2cf8..1aeb7c0 100644 --- a/tests/utils/test_pinecone.py +++ b/tests/utils/test_pinecone.py @@ -1,7 +1,7 @@ import os from unittest.mock import patch -from examples.memory.pinecone import PineconeDB +from swarms_memory.dbs.pinecone import PineconeDB api_key = os.getenv("PINECONE_API_KEY") or "" diff --git a/tests/utils/test_pq_db.py b/tests/utils/test_pq_db.py index 2fbffb6..be941fe 100644 --- a/tests/utils/test_pq_db.py +++ b/tests/utils/test_pq_db.py @@ -3,7 +3,7 @@ from dotenv import load_dotenv -from examples.memory.pg import PostgresDB +from swarms_memory.dbs.pg import PostgresDB load_dotenv() diff --git a/tests/utils/test_short_term_memory.py b/tests/utils/test_short_term_memory.py index 39e78e5..32a5233 100644 --- a/tests/utils/test_short_term_memory.py +++ b/tests/utils/test_short_term_memory.py @@ -71,7 +71,9 @@ def test_search_memory(): memory = ShortTermMemory() memory.add("user", "Hello, world!") assert memory.search_memory("Hello") == { - "short_term": [(0, {"role": "user", "message": "Hello, world!"})], + "short_term": [ + (0, {"role": "user", "message": "Hello, world!"}) + ], "medium_term": [], } @@ -112,7 +114,9 @@ def add_messages(): for _ in range(1000): memory.add("user", "Hello, world!") - threads = [threading.Thread(target=add_messages) for _ in range(10)] + threads = [ + threading.Thread(target=add_messages) for _ in range(10) + ] for thread in threads: thread.start() for thread in threads: diff --git a/tests/utils/test_sqlite.py b/tests/utils/test_sqlite.py index 5f36ca3..ace01de 100644 --- a/tests/utils/test_sqlite.py +++ b/tests/utils/test_sqlite.py @@ -8,7 +8,9 @@ @pytest.fixture def db(): conn = sqlite3.connect(":memory:") - 
conn.execute("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)") + conn.execute( + "CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)" + ) conn.commit() return SQLiteDB(":memory:") @@ -28,7 +30,9 @@ def test_delete(db): def test_update(db): db.add("INSERT INTO test (name) VALUES (?)", ("test",)) - db.update("UPDATE test SET name = ? WHERE name = ?", ("new", "test")) + db.update( + "UPDATE test SET name = ? WHERE name = ?", ("new", "test") + ) result = db.query("SELECT * FROM test") assert result == [(1, "new")] @@ -97,4 +101,6 @@ def test_query_with_wrong_query(db): def test_execute_query_with_wrong_query(db): with pytest.raises(sqlite3.OperationalError): - db.execute_query("SELECT * FROM wrong WHERE name = ?", ("test",)) + db.execute_query( + "SELECT * FROM wrong WHERE name = ?", ("test",) + ) diff --git a/tests/vector_dbs/test_qdrant.py b/tests/vector_dbs/test_qdrant.py index caa940e..0750b63 100644 --- a/tests/vector_dbs/test_qdrant.py +++ b/tests/vector_dbs/test_qdrant.py @@ -29,7 +29,9 @@ def test_qdrant_init(qdrant_client, mock_qdrant_client): assert qdrant_client.client is not None -def test_load_embedding_model(qdrant_client, mock_sentence_transformer): +def test_load_embedding_model( + qdrant_client, mock_sentence_transformer +): qdrant_client._load_embedding_model("model_name") mock_sentence_transformer.assert_called_once_with("model_name") diff --git a/tests/vector_dbs/test_weaviate.py b/tests/vector_dbs/test_weaviate.py index 93e9aaf..5abd818 100644 --- a/tests/vector_dbs/test_weaviate.py +++ b/tests/vector_dbs/test_weaviate.py @@ -16,7 +16,9 @@ def weaviate_client_mock(): grpc_port="mock_grpc_port", grpc_secure=False, auth_client_secret="mock_api_key", - additional_headers={"X-OpenAI-Api-Key": "mock_openai_api_key"}, + additional_headers={ + "X-OpenAI-Api-Key": "mock_openai_api_key" + }, additional_config=Mock(), ) @@ -72,7 +74,9 @@ def test_update_object(weaviate_client_mock): # Test updating an object object_id = "12345" properties = 
{"name": "Jane"} - weaviate_client_mock.update("test_collection", object_id, properties) + weaviate_client_mock.update( + "test_collection", object_id, properties + ) weaviate_client_mock.client.collections.get.assert_called_with( "test_collection" ) @@ -139,7 +143,9 @@ def test_create_collection_failure(weaviate_client_mock): "weaviate_client.weaviate.collections.create", side_effect=Exception("Create error"), ): - with pytest.raises(Exception, match="Error creating collection"): + with pytest.raises( + Exception, match="Error creating collection" + ): weaviate_client_mock.create_collection( "test_collection", [{"name": "property"}] )