From 8149b7f1ca7ec4e23ebd0c03ff813fce571e6d92 Mon Sep 17 00:00:00 2001 From: Kye Gomez Date: Tue, 27 Aug 2024 17:34:44 -0400 Subject: [PATCH] [CLEANUP][Tests][Non Vector DB Memory] --- LICENSE | 2 +- README.md | 2 - examples/faiss_example.py | 2 +- examples/pinecome_wrapper_example.py | 2 +- pyproject.toml | 6 +- requirements.txt | 1 + swarms_memory/__init__.py | 7 +- swarms_memory/dbs/__init__.py | 0 swarms_memory/utils/__init__.py | 14 ++ swarms_memory/utils/action_subtask.py | 15 ++ swarms_memory/utils/dict_internal_memory.py | 86 ++++++++ swarms_memory/utils/dict_shared_memory.py | 96 +++++++++ swarms_memory/utils/short_term_memory.py | 186 +++++++++++++++++ swarms_memory/utils/visual_memory.py | 118 +++++++++++ swarms_memory/vector_dbs/__init__.py | 6 + swarms_memory/vector_dbs/base_db.py | 139 +++++++++++++ swarms_memory/vector_dbs/base_vectordb.py | 148 ++++++++++++++ .../{ => vector_dbs}/chroma_db_wrapper.py | 16 +- .../{ => vector_dbs}/faiss_wrapper.py | 0 .../{ => vector_dbs}/pinecone_wrapper.py | 0 tests/utils/test_dictinternalmemory.py | 71 +++++++ tests/utils/test_dictsharedmemory.py | 88 ++++++++ .../utils/test_langchainchromavectormemory.py | 92 +++++++++ tests/utils/test_pinecone.py | 82 ++++++++ tests/utils/test_pq_db.py | 80 ++++++++ tests/utils/test_short_term_memory.py | 130 ++++++++++++ tests/utils/test_sqlite.py | 100 +++++++++ tests/{ => vector_dbs}/test_chromadb.py | 2 +- tests/{ => vector_dbs}/test_pinecone.py | 2 +- tests/vector_dbs/test_qdrant.py | 52 +++++ tests/vector_dbs/test_weaviate.py | 192 ++++++++++++++++++ 31 files changed, 1712 insertions(+), 25 deletions(-) create mode 100644 swarms_memory/dbs/__init__.py create mode 100644 swarms_memory/utils/__init__.py create mode 100644 swarms_memory/utils/action_subtask.py create mode 100644 swarms_memory/utils/dict_internal_memory.py create mode 100644 swarms_memory/utils/dict_shared_memory.py create mode 100644 swarms_memory/utils/short_term_memory.py create mode 100644 
swarms_memory/utils/visual_memory.py create mode 100644 swarms_memory/vector_dbs/__init__.py create mode 100644 swarms_memory/vector_dbs/base_db.py create mode 100644 swarms_memory/vector_dbs/base_vectordb.py rename swarms_memory/{ => vector_dbs}/chroma_db_wrapper.py (92%) rename swarms_memory/{ => vector_dbs}/faiss_wrapper.py (100%) rename swarms_memory/{ => vector_dbs}/pinecone_wrapper.py (100%) create mode 100644 tests/utils/test_dictinternalmemory.py create mode 100644 tests/utils/test_dictsharedmemory.py create mode 100644 tests/utils/test_langchainchromavectormemory.py create mode 100644 tests/utils/test_pinecone.py create mode 100644 tests/utils/test_pq_db.py create mode 100644 tests/utils/test_short_term_memory.py create mode 100644 tests/utils/test_sqlite.py rename tests/{ => vector_dbs}/test_chromadb.py (96%) rename tests/{ => vector_dbs}/test_pinecone.py (96%) create mode 100644 tests/vector_dbs/test_qdrant.py create mode 100644 tests/vector_dbs/test_weaviate.py diff --git a/LICENSE b/LICENSE index ca69c7e..1df1043 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2023 Eternal Reclaimer +Copyright (c) 2023 The Galactic Swarm Corporation Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index f632af2..bbba3da 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,3 @@ -[![Multi-Modality](agorabanner.png)](https://discord.gg/qUtxnK2NMf) -
diff --git a/examples/faiss_example.py b/examples/faiss_example.py index 7eb4d3d..010e39f 100644 --- a/examples/faiss_example.py +++ b/examples/faiss_example.py @@ -1,5 +1,5 @@ from typing import List, Dict, Any -from swarms_memory.faiss_wrapper import FAISSDB +from swarms_memory.vector_dbs.faiss_wrapper import FAISSDB from transformers import AutoTokenizer, AutoModel diff --git a/examples/pinecome_wrapper_example.py b/examples/pinecome_wrapper_example.py index 5d135e7..a837802 100644 --- a/examples/pinecome_wrapper_example.py +++ b/examples/pinecome_wrapper_example.py @@ -1,5 +1,5 @@ from typing import List, Dict, Any -from swarms_memory.pinecone_wrapper import PineconeMemory +from swarms_memory.vector_dbs.pinecone_wrapper import PineconeMemory # Example usage diff --git a/pyproject.toml b/pyproject.toml index 77224a3..01f3c91 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "swarms-memory" -version = "0.0.5" +version = "0.0.7" description = "Swarms Memory - Pytorch" license = "MIT" authors = ["Kye Gomez "] @@ -28,7 +28,9 @@ chromadb = "*" loguru = "*" sentence-transformers = "*" pinecone = "*" -faiss = "*" +faiss-cpu = "*" +pydantic = "*" + diff --git a/requirements.txt b/requirements.txt index 782e092..0be3138 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ sentence-transformers pinecone faiss-cpu torch +pydantic diff --git a/swarms_memory/__init__.py b/swarms_memory/__init__.py index 6e917b2..a2b9c39 100644 --- a/swarms_memory/__init__.py +++ b/swarms_memory/__init__.py @@ -1,5 +1,2 @@ -from swarms_memory.chroma_db_wrapper import ChromaDB -from swarms_memory.pinecone_wrapper import PineconeMemory -from swarms_memory.faiss_wrapper import FAISSDB - -__all__ = ["ChromaDB", "PineconeMemory", "FAISSDB"] +from swarms_memory.vector_dbs import * # noqa: F401, F403 +from swarms_memory.utils import * # noqa: F401, F403 \ No newline at end of file diff --git 
a/swarms_memory/dbs/__init__.py b/swarms_memory/dbs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/swarms_memory/utils/__init__.py b/swarms_memory/utils/__init__.py new file mode 100644 index 0000000..c38a3ba --- /dev/null +++ b/swarms_memory/utils/__init__.py @@ -0,0 +1,14 @@ +from swarms_memory.utils.action_subtask import ActionSubtaskEntry + +from swarms_memory.utils.dict_internal_memory import DictInternalMemory +from swarms_memory.utils.dict_shared_memory import DictSharedMemory +from swarms_memory.utils.short_term_memory import ShortTermMemory +from swarms_memory.utils.visual_memory import VisualShortTermMemory + +__all__ = [ + "ActionSubtaskEntry", + "DictInternalMemory", + "DictSharedMemory", + "ShortTermMemory", + "VisualShortTermMemory", +] diff --git a/swarms_memory/utils/action_subtask.py b/swarms_memory/utils/action_subtask.py new file mode 100644 index 0000000..3c1d7d9 --- /dev/null +++ b/swarms_memory/utils/action_subtask.py @@ -0,0 +1,15 @@ +from pydantic import BaseModel + + +class ActionSubtaskEntry(BaseModel): + """Used to store ActionSubtask data to preserve TaskMemory pointers and context in the form of thought and action. + + Attributes: + thought: CoT thought string from the LLM. + action: ReAct action JSON string from the LLM. + answer: tool-generated and memory-processed response from Griptape. + """ + + thought: str + action: str + answer: str diff --git a/swarms_memory/utils/dict_internal_memory.py b/swarms_memory/utils/dict_internal_memory.py new file mode 100644 index 0000000..daba0b0 --- /dev/null +++ b/swarms_memory/utils/dict_internal_memory.py @@ -0,0 +1,86 @@ +import uuid +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Tuple + + +class InternalMemoryBase(ABC): + """Abstract base class for internal memory of agents in the swarm.""" + + def __init__(self, n_entries): + """Initialize the internal memory. 
In the current architecture the memory always consists of a set of solutions or evaluations. + During the operation, the agent should retrieve best solutions from its internal memory based on the score. + + Moreover, the project is designed around LLMs for the proof of concepts, so we treat all entry content as a string. + """ + self.n_entries = n_entries + + @abstractmethod + def add(self, score, entry): + """Add an entry to the internal memory.""" + raise NotImplementedError + + @abstractmethod + def get_top_n(self, n): + """Get the top n entries from the internal memory.""" + raise NotImplementedError + + +class DictInternalMemory(InternalMemoryBase): + def __init__(self, n_entries: int): + """ + Initialize the internal memory. In the current architecture the memory always consists of a set of solutions or evaluations. + Simple key-value store for now. + + Args: + n_entries (int): The maximum number of entries to keep in the internal memory. + """ + super().__init__(n_entries) + self.data: Dict[str, Dict[str, Any]] = {} + + def add(self, score: float, content: Any) -> None: + """ + Add an entry to the internal memory. + + Args: + score (float): The score or fitness value associated with the entry. + content (Any): The content of the entry. + + Returns: + None + """ + random_key: str = str(uuid.uuid4()) + self.data[random_key] = {"score": score, "content": content} + + # keep only the best n entries + sorted_data: List[Tuple[str, Dict[str, Any]]] = sorted( + self.data.items(), + key=lambda x: x[1]["score"], + reverse=True, + ) + self.data = dict(sorted_data[: self.n_entries]) + + def get_top_n(self, n: int) -> List[Tuple[str, Dict[str, Any]]]: + """ + Get the top n entries from the internal memory. + + Args: + n (int): The number of top entries to retrieve. + + Returns: + List[Tuple[str, Dict[str, Any]]]: A list of tuples containing the random keys and corresponding entry data. 
+ """ + sorted_data: List[Tuple[str, Dict[str, Any]]] = sorted( + self.data.items(), + key=lambda x: x[1]["score"], + reverse=True, + ) + return sorted_data[:n] + + def len(self) -> int: + """ + Get the number of entries in the internal memory. + + Returns: + int: The number of entries in the internal memory. + """ + return len(self.data) diff --git a/swarms_memory/utils/dict_shared_memory.py b/swarms_memory/utils/dict_shared_memory.py new file mode 100644 index 0000000..8ac92d1 --- /dev/null +++ b/swarms_memory/utils/dict_shared_memory.py @@ -0,0 +1,96 @@ +import datetime +import json +import os +import threading +import uuid +from pathlib import Path +from typing import Any, Dict + + +class DictSharedMemory: + """A class representing a shared memory that stores entries as a dictionary. + + Attributes: + file_loc (Path): The file location where the memory is stored. + lock (threading.Lock): A lock used for thread synchronization. + + Methods: + __init__(self, file_loc: str = None) -> None: Initializes the shared memory. + add_entry(self, score: float, agent_id: str, agent_cycle: int, entry: Any) -> bool: Adds an entry to the internal memory. + get_top_n(self, n: int) -> None: Gets the top n entries from the internal memory. + write_to_file(self, data: Dict[str, Dict[str, Any]]) -> bool: Writes the internal memory to a file. + """ + + def __init__(self, file_loc: str = None) -> None: + """Initialize the shared memory. In the current architecture the memory always consists of a set of solutions or evaluations. + Moreover, the project is designed around LLMs for the proof of concepts, so we treat all entry content as a string. 
+ """ + if file_loc is not None: + self.file_loc = Path(file_loc) + if not self.file_loc.exists(): + self.file_loc.touch() + + self.lock = threading.Lock() + + def add( + self, + score: float, + agent_id: str, + agent_cycle: int, + entry: Any, + ) -> bool: + """Add an entry to the internal memory.""" + with self.lock: + entry_id = str(uuid.uuid4()) + data = {} + epoch = datetime.datetime.utcfromtimestamp(0) + epoch = (datetime.datetime.utcnow() - epoch).total_seconds() + data[entry_id] = { + "agent": agent_id, + "epoch": epoch, + "score": score, + "cycle": agent_cycle, + "content": entry, + } + status = self.write_to_file(data) + self.plot_performance() + return status + + def get_top_n(self, n: int) -> None: + """Get the top n entries from the internal memory.""" + with self.lock: + with open(self.file_loc) as f: + try: + file_data = json.load(f) + except Exception as e: + file_data = {} + raise e + + sorted_data = dict( + sorted( + file_data.items(), + key=lambda item: item[1]["score"], + reverse=True, + ) + ) + top_n = dict(list(sorted_data.items())[:n]) + return top_n + + def write_to_file(self, data: Dict[str, Dict[str, Any]]) -> bool: + """Write the internal memory to a file.""" + if self.file_loc is not None: + with open(self.file_loc) as f: + try: + file_data = json.load(f) + except Exception as e: + file_data = {} + raise e + + file_data = file_data | data + with open(self.file_loc, "w") as f: + json.dump(file_data, f, indent=4) + + f.flush() + os.fsync(f.fileno()) + + return True diff --git a/swarms_memory/utils/short_term_memory.py b/swarms_memory/utils/short_term_memory.py new file mode 100644 index 0000000..1b8056c --- /dev/null +++ b/swarms_memory/utils/short_term_memory.py @@ -0,0 +1,186 @@ +import json +import logging +import threading + +class ShortTermMemory: + """Short term memory. + + Args: + return_str (bool, optional): _description_. Defaults to True. + autosave (bool, optional): _description_. Defaults to True. 
+ *args: _description_ + **kwargs: _description_ + + + Example: + >>> from swarms.memory.short_term_memory import ShortTermMemory + >>> stm = ShortTermMemory() + >>> stm.add(role="agent", message="Hello world!") + >>> stm.add(role="agent", message="How are you?") + >>> stm.add(role="agent", message="I am fine.") + >>> stm.add(role="agent", message="How are you?") + >>> stm.add(role="agent", message="I am fine.") + + + """ + + def __init__( + self, + return_str: bool = True, + autosave: bool = True, + *args, + **kwargs, + ): + self.return_str = return_str + self.autosave = autosave + self.short_term_memory = [] + self.medium_term_memory = [] + self.lock = threading.Lock() + + def add(self, role: str = None, message: str = None, *args, **kwargs): + """Add a message to the short term memory. + + Args: + role (str, optional): _description_. Defaults to None. + message (str, optional): _description_. Defaults to None. + + Returns: + _type_: _description_ + """ + try: + memory = self.short_term_memory.append( + {"role": role, "message": message} + ) + + return memory + except Exception as error: + print(f"Add to short term memory failed: {error}") + raise error + + def get_short_term(self): + """Get the short term memory. + + Returns: + _type_: _description_ + """ + return self.short_term_memory + + def get_medium_term(self): + """Get the medium term memory. + + Returns: + _type_: _description_ + """ + return self.medium_term_memory + + def clear_medium_term(self): + """Clear the medium term memory.""" + self.medium_term_memory = [] + + def get_short_term_memory_str(self, *args, **kwargs): + """Get the short term memory as a string.""" + return str(self.short_term_memory) + + def update_short_term( + self, index, role: str, message: str, *args, **kwargs + ): + """Update the short term memory. 
+ + Args: + index (_type_): _description_ + role (str): _description_ + message (str): _description_ + + """ + self.short_term_memory[index] = { + "role": role, + "message": message, + } + + def clear(self): + """Clear the short term memory.""" + self.short_term_memory = [] + + def search_memory(self, term): + """Search the memory for a term. + + Args: + term (_type_): _description_ + + Returns: + _type_: _description_ + """ + results = {"short_term": [], "medium_term": []} + for i, message in enumerate(self.short_term_memory): + if term in message["message"]: + results["short_term"].append((i, message)) + for i, message in enumerate(self.medium_term_memory): + if term in message["message"]: + results["medium_term"].append((i, message)) + return results + + def return_shortmemory_as_str(self): + """Return the memory as a string. + + Returns: + _type_: _description_ + """ + return str(self.short_term_memory) + + def move_to_medium_term(self, index): + """Move a message from the short term memory to the medium term memory. + + Args: + index (_type_): _description_ + """ + message = self.short_term_memory.pop(index) + self.medium_term_memory.append(message) + + def return_medium_memory_as_str(self): + """Return the medium term memory as a string. + + Returns: + _type_: _description_ + """ + return str(self.medium_term_memory) + + def save_to_file(self, filename: str): + """Save the memory to a file. + + Args: + filename (str): _description_ + """ + try: + with self.lock: + with open(filename, "w") as f: + json.dump( + { + "short_term_memory": (self.short_term_memory), + "medium_term_memory": ( + self.medium_term_memory + ), + }, + f, + ) + + logging.info(f"Saved memory to {filename}") + except Exception as error: + print(f"Error saving memory to {filename}: {error}") + + def load_from_file(self, filename: str, *args, **kwargs): + """Load the memory from a file. 
+ + Args: + filename (str): _description_ + """ + try: + with self.lock: + with open(filename) as f: + data = json.load(f) + self.short_term_memory = data.get("short_term_memory", []) + self.medium_term_memory = data.get( + "medium_term_memory", [] + ) + logging.info(f"Loaded memory from {filename}") + except Exception as error: + print(f"Error loading memory from {filename}: {error}") diff --git a/swarms_memory/utils/visual_memory.py b/swarms_memory/utils/visual_memory.py new file mode 100644 index 0000000..1361d6e --- /dev/null +++ b/swarms_memory/utils/visual_memory.py @@ -0,0 +1,118 @@ +from datetime import datetime +from typing import List + + +class VisualShortTermMemory: + """ + A class representing visual short-term memory. + + Attributes: + memory (list): A list to store images and their descriptions. + + Examples: + example = VisualShortTermMemory() + example.add( + images=["image1.jpg", "image2.jpg"], + description=["description1", "description2"], + timestamps=[1.0, 2.0], + locations=["location1", "location2"], + ) + print(example.return_as_string()) + # print(example.get_images()) + """ + + def __init__(self): + self.memory = [] + + def add( + self, + images: List[str] = None, + description: List[str] = None, + timestamps: List[float] = None, + locations: List[str] = None, + ): + """ + Add images and their descriptions to the memory. + + Args: + images (list): A list of image paths. + description (list): A list of corresponding descriptions. + timestamps (list): A list of timestamps for each image. + locations (list): A list of locations where the images were captured. 
+ """ + current_time = datetime.now() + + # Create a dictionary of each image and description + # and append it to the memory + for image, description, timestamp, location in zip( + images, description, timestamps, locations + ): + self.memory.append( + { + "image": image, + "description": description, + "timestamp": timestamp, + "location": location, + "added_at": current_time, + } + ) + + def get_images(self): + """ + Get a list of all images in the memory. + + Returns: + list: A list of image paths. + """ + return [item["image"] for item in self.memory] + + def get_descriptions(self): + """ + Get a list of all descriptions in the memory. + + Returns: + list: A list of descriptions. + """ + return [item["description"] for item in self.memory] + + def search_by_location(self, location: str): + """ + Search for images captured at a specific location. + + Args: + location (str): The location to search for. + + Returns: + list: A list of images captured at the specified location. + """ + return [ + item["image"] + for item in self.memory + if item["location"] == location + ] + + def search_by_timestamp(self, start_time: float, end_time: float): + """ + Search for images captured within a specific time range. + + Args: + start_time (float): The start time of the range. + end_time (float): The end time of the range. + + Returns: + list: A list of images captured within the specified time range. + """ + return [ + item["image"] + for item in self.memory + if start_time <= item["timestamp"] <= end_time + ] + + def return_as_string(self): + """ + Return the memory as a string. + + Returns: + str: A string representation of the memory. 
+ """ + return str(self.memory) diff --git a/swarms_memory/vector_dbs/__init__.py b/swarms_memory/vector_dbs/__init__.py new file mode 100644 index 0000000..eed6900 --- /dev/null +++ b/swarms_memory/vector_dbs/__init__.py @@ -0,0 +1,6 @@ +from swarms_memory.vector_dbs.chroma_db_wrapper import ChromaDB +from swarms_memory.vector_dbs.pinecone_wrapper import PineconeMemory +from swarms_memory.vector_dbs.faiss_wrapper import FAISSDB +from swarms_memory.vector_dbs.base_vectordb import BaseVectorDatabase + +__all__ = ["ChromaDB", "PineconeMemory", "FAISSDB", "BaseVectorDatabase"] \ No newline at end of file diff --git a/swarms_memory/vector_dbs/base_db.py b/swarms_memory/vector_dbs/base_db.py new file mode 100644 index 0000000..eb3e6f0 --- /dev/null +++ b/swarms_memory/vector_dbs/base_db.py @@ -0,0 +1,139 @@ +from abc import ABC, abstractmethod + + +class AbstractDatabase(ABC): + """ + Abstract base class for a database. + + This class defines the interface for interacting with a database. + Subclasses must implement the abstract methods to provide the + specific implementation details for connecting to a database, + executing queries, and performing CRUD operations. + + """ + + @abstractmethod + def connect(self): + """ + Connect to the database. + + This method establishes a connection to the database. + + """ + + @abstractmethod + def close(self): + """ + Close the database connection. + + This method closes the connection to the database. + + """ + + @abstractmethod + def execute_query(self, query): + """ + Execute a database query. + + This method executes the given query on the database. + + Parameters: + query (str): The query to be executed. + + """ + + @abstractmethod + def fetch_all(self): + """ + Fetch all rows from the result set. + + This method retrieves all rows from the result set of a query. + + Returns: + list: A list of dictionaries representing the rows. + + """ + + @abstractmethod + def fetch_one(self): + """ + Fetch one row from the result set. 
+ + This method retrieves one row from the result set of a query. + + Returns: + dict: A dictionary representing the row. + + """ + + @abstractmethod + def add(self, table, data): + """ + Add a new record to the database. + + This method adds a new record to the specified table in the database. + + Parameters: + table (str): The name of the table. + data (dict): A dictionary representing the data to be added. + + """ + + @abstractmethod + def query(self, table, condition): + """ + Query the database. + + This method queries the specified table in the database based on the given condition. + + Parameters: + table (str): The name of the table. + condition (str): The condition to be applied in the query. + + Returns: + list: A list of dictionaries representing the query results. + + """ + + @abstractmethod + def get(self, table, id): + """ + Get a record from the database. + + This method retrieves a record from the specified table in the database based on the given ID. + + Parameters: + table (str): The name of the table. + id (int): The ID of the record to be retrieved. + + Returns: + dict: A dictionary representing the retrieved record. + + """ + + @abstractmethod + def update(self, table, id, data): + """ + Update a record in the database. + + This method updates a record in the specified table in the database based on the given ID. + + Parameters: + table (str): The name of the table. + id (int): The ID of the record to be updated. + data (dict): A dictionary representing the updated data. + + """ + + @abstractmethod + def delete(self, table, id): + """ + Delete a record from the database. + + This method deletes a record from the specified table in the database based on the given ID. + + Parameters: + table (str): The name of the table. + id (int): The ID of the record to be deleted. 
+ + """ diff --git a/swarms_memory/vector_dbs/base_vectordb.py b/swarms_memory/vector_dbs/base_vectordb.py new file mode 100644 index 0000000..da16506 --- /dev/null +++ b/swarms_memory/vector_dbs/base_vectordb.py @@ -0,0 +1,148 @@ +from abc import ABC +from loguru import logger + +class BaseVectorDatabase(ABC): + """ + Abstract base class for a database. + + This class defines the interface for interacting with a database. + Subclasses must implement the abstract methods to provide the + specific implementation details for connecting to a database, + executing queries, and performing CRUD operations. + + """ + + def connect(self): + """ + Connect to the database. + + This method establishes a connection to the database. + + """ + + def close(self): + """ + Close the database connection. + + This method closes the connection to the database. + + """ + + def query(self, query: str): + """ + Execute a database query. + + This method executes the given query on the database. + + Parameters: + query (str): The query to be executed. + + """ + + def fetch_all(self): + """ + Fetch all rows from the result set. + + This method retrieves all rows from the result set of a query. + + Returns: + list: A list of dictionaries representing the rows. + + """ + + def fetch_one(self): + """ + Fetch one row from the result set. + + This method retrieves one row from the result set of a query. + + Returns: + dict: A dictionary representing the row. + + """ + + def add(self, doc: str): + """ + Add a new record to the database. + + This method adds a new record to the specified table in the database. + + Parameters: + table (str): The name of the table. + data (dict): A dictionary representing the data to be added. + + """ + + def get(self, query: str): + """ + Get a record from the database. + + This method retrieves a record from the specified table in the database based on the given ID. + + Parameters: + table (str): The name of the table. 
+ id (int): The ID of the record to be retrieved. + + Returns: + dict: A dictionary representing the retrieved record. + + """ + + def update(self, doc): + """ + Update a record in the database. + + This method updates a record in the specified table in the database based on the given ID. + + Parameters: + table (str): The name of the table. + id (int): The ID of the record to be updated. + data (dict): A dictionary representing the updated data. + + """ + + def delete(self, message): + """ + Delete a record from the database. + + This method deletes a record from the specified table in the database based on the given ID. + + Parameters: + table (str): The name of the table. + id (int): The ID of the record to be deleted. + + """ + + def print_all(self): + """ + Print all records in the database. + + This method prints all records in the specified table in the database. + + """ + pass + + def log_query(self, query: str = None): + """ + Log the query. + + This method logs the query that was executed on the database. + + Parameters: + query (str): The query that was executed. + + """ + logger.info(f"Query: {query}") + + def log_retrieved_data(self, data: list = None): + """ + Log the retrieved data. + + This method logs the data that was retrieved from the database. + + Parameters: + data (dict): The data that was retrieved. 
+ + """ + for d in data: + logger.info(f"Retrieved Data: {d}") diff --git a/swarms_memory/chroma_db_wrapper.py b/swarms_memory/vector_dbs/chroma_db_wrapper.py similarity index 92% rename from swarms_memory/chroma_db_wrapper.py rename to swarms_memory/vector_dbs/chroma_db_wrapper.py index 6645730..6521374 100644 --- a/swarms_memory/chroma_db_wrapper.py +++ b/swarms_memory/vector_dbs/chroma_db_wrapper.py @@ -1,4 +1,3 @@ -import logging import os import uuid from typing import Optional @@ -8,7 +7,6 @@ from loguru import logger from swarms.memory.base_vectordb import BaseVectorDatabase from swarms.utils.data_to_text import data_to_text -from swarms.utils.markdown_message import display_markdown_message # Load environment variables load_dotenv() @@ -59,10 +57,6 @@ def __init__( self.docs_folder = docs_folder self.verbose = verbose - # Disable ChromaDB logging - if verbose: - logging.getLogger("chromadb").setLevel(logging.INFO) - # Create Chroma collection chroma_persist_dir = "chroma" chroma_client = chromadb.PersistentClient( @@ -83,7 +77,7 @@ def __init__( *args, **kwargs, ) - display_markdown_message( + logger.info( "ChromaDB collection created:" f" {self.collection.name} with metric: {self.metric} and" f" output directory: {self.output_dir}" @@ -91,7 +85,7 @@ def __init__( # If docs if docs_folder: - display_markdown_message( + logger.info( f"Traversing directory: {docs_folder}" ) self.traverse_directory() @@ -144,7 +138,7 @@ def query( dict: The retrieved documents. 
""" try: - logging.info(f"Querying documents for: {query_text}") + logger.info(f"Querying documents for: {query_text}") docs = self.collection.query( query_texts=[query_text], n_results=self.n_results, @@ -158,8 +152,8 @@ def query( out += f"{doc}\n" # Display the retrieved document - display_markdown_message(f"Query: {query_text}") - display_markdown_message(f"Retrieved Document: {out}") + logger.info(f"Query: {query_text}") + logger.info(f"Retrieved Document: {out}") return out except Exception as e: diff --git a/swarms_memory/faiss_wrapper.py b/swarms_memory/vector_dbs/faiss_wrapper.py similarity index 100% rename from swarms_memory/faiss_wrapper.py rename to swarms_memory/vector_dbs/faiss_wrapper.py diff --git a/swarms_memory/pinecone_wrapper.py b/swarms_memory/vector_dbs/pinecone_wrapper.py similarity index 100% rename from swarms_memory/pinecone_wrapper.py rename to swarms_memory/vector_dbs/pinecone_wrapper.py diff --git a/tests/utils/test_dictinternalmemory.py b/tests/utils/test_dictinternalmemory.py new file mode 100644 index 0000000..bbbfad2 --- /dev/null +++ b/tests/utils/test_dictinternalmemory.py @@ -0,0 +1,71 @@ +# DictInternalMemory + +from uuid import uuid4 + +import pytest + +from swarms_memory import DictInternalMemory + +# Example of an extensive suite of tests for DictInternalMemory. + + +# Fixture for repeatedly initializing the class with different numbers of entries. 
+@pytest.fixture(params=[1, 5, 10, 100]) +def memory(request): + return DictInternalMemory(n_entries=request.param) + + +# Basic Tests +def test_initialization(memory): + assert memory.len() == 0 + + +def test_single_add(memory): + memory.add(10, {"data": "test"}) + assert memory.len() == 1 + + +def test_memory_limit_enforced(memory): + entries_to_add = memory.n_entries + 10 + for i in range(entries_to_add): + memory.add(i, {"data": f"test{i}"}) + assert memory.len() == memory.n_entries + + +# Parameterized Tests +@pytest.mark.parametrize( + "scores, best_score", [([10, 5, 3], 10), ([1, 2, 3], 3)] +) +def test_get_top_n(scores, best_score, memory): + for score in scores: + memory.add(score, {"data": f"test{score}"}) + top_entry = memory.get_top_n(1) + assert top_entry[0][1]["score"] == best_score + + +# Exception Testing +@pytest.mark.parametrize("invalid_n", [-1, 0]) +def test_invalid_n_entries_raises_exception(invalid_n): + with pytest.raises(ValueError): + DictInternalMemory(invalid_n) + + +# Mocks and Monkeypatching +def test_add_with_mocked_uuid4(monkeypatch, memory): + # Mock the uuid4 function to return a known value + class MockUUID: + hex = "1234abcd" + + monkeypatch.setattr(uuid4, "__str__", lambda: MockUUID.hex) + memory.add(20, {"data": "mock_uuid"}) + assert MockUUID.hex in memory.data + + +# Test using Mocks to simulate I/O or external interactions here +# ... + +# More tests to hit edge cases, concurrency issues, etc. +# ... + +# Tests for concurrency issues, if relevant +# ... 
diff --git a/tests/utils/test_dictsharedmemory.py b/tests/utils/test_dictsharedmemory.py new file mode 100644 index 0000000..ee945e3 --- /dev/null +++ b/tests/utils/test_dictsharedmemory.py @@ -0,0 +1,88 @@ +import os +import tempfile + +import pytest + +from swarms_memory import DictSharedMemory + +# Utility functions or fixtures might come first + + +@pytest.fixture +def memory_file(): + with tempfile.NamedTemporaryFile("w+", delete=False) as tmp_file: + yield tmp_file.name + os.unlink(tmp_file.name) + + +@pytest.fixture +def memory_instance(memory_file): + return DictSharedMemory(file_loc=memory_file) + + +# Basic tests + + +def test_init(memory_file): + memory = DictSharedMemory(file_loc=memory_file) + assert os.path.exists( + memory.file_loc + ), "Memory file should be created if non-existent" + + +def test_add_entry(memory_instance): + success = memory_instance.add(9.5, "agent123", 1, "Test Entry") + assert success, "add_entry should return True on success" + + +def test_add_entry_thread_safety(memory_instance): + # We could create multiple threads to test the thread safety of the add_entry method + pass + + +def test_get_top_n(memory_instance): + memory_instance.add(9.5, "agent123", 1, "Entry A") + memory_instance.add(8.5, "agent124", 1, "Entry B") + top_1 = memory_instance.get_top_n(1) + assert ( + len(top_1) == 1 + ), "get_top_n should return the correct number of top entries" + + +# Parameterized tests + + +@pytest.mark.parametrize( + "scores, agent_ids, expected_top_score", + [ + ([1.0, 2.0, 3.0], ["agent1", "agent2", "agent3"], 3.0), + # add more test cases + ], +) +def test_parametrized_get_top_n( + memory_instance, scores, agent_ids, expected_top_score +): + for score, agent_id in zip(scores, agent_ids): + memory_instance.add(score, agent_id, 1, f"Entry by {agent_id}") + top_1 = memory_instance.get_top_n(1) + top_score = next(iter(top_1.values()))["score"] + assert ( + top_score == expected_top_score + ), "get_top_n should return the entry with top 
score" + + +# Exception testing + + +def test_add_entry_invalid_input(memory_instance): + with pytest.raises(ValueError): + memory_instance.add("invalid_score", "agent123", 1, "Test Entry") + + +# Mocks and monkey-patching + + +def test_write_fails_due_to_permissions(memory_instance, mocker): + mocker.patch("builtins.open", side_effect=PermissionError) + with pytest.raises(PermissionError): + memory_instance.add(9.5, "agent123", 1, "Test Entry") diff --git a/tests/utils/test_langchainchromavectormemory.py b/tests/utils/test_langchainchromavectormemory.py new file mode 100644 index 0000000..35c12c2 --- /dev/null +++ b/tests/utils/test_langchainchromavectormemory.py @@ -0,0 +1,92 @@ +# LangchainChromaVectorMemory + +from unittest.mock import MagicMock, patch + +import pytest + +from swarms_memory import LangchainChromaVectorMemory + + +# Fixtures for setting up the memory and mocks +@pytest.fixture() +def vector_memory(tmp_path): + loc = tmp_path / "vector_memory" + return LangchainChromaVectorMemory(loc=loc) + + +@pytest.fixture() +def embeddings_mock(): + with patch("swarms_memory.OpenAIEmbeddings") as mock: + yield mock + + +@pytest.fixture() +def chroma_mock(): + with patch("swarms_memory.Chroma") as mock: + yield mock + + +@pytest.fixture() +def qa_mock(): + with patch("swarms_memory.RetrievalQA") as mock: + yield mock + + +# Example test cases +def test_initialization_default_settings(vector_memory): + assert vector_memory.chunk_size == 1000 + assert ( + vector_memory.chunk_overlap == 100 + ) # assuming default overlap of 0.1 + assert vector_memory.loc.exists() + + +def test_add_entry(vector_memory, embeddings_mock): + with patch.object(vector_memory.db, "add_texts") as add_texts_mock: + vector_memory.add("Example text") + add_texts_mock.assert_called() + + +def test_search_memory_returns_list(vector_memory): + result = vector_memory.search_memory("example query", k=5) + assert isinstance(result, list) + + +def test_ask_question_returns_string(vector_memory, 
qa_mock): + result = vector_memory.query("What is the color of the sky?") + assert isinstance(result, str) + + +@pytest.mark.parametrize( + "query,k,type,expected", + [ + ("example query", 5, "mmr", [MagicMock()]), + ( + "example query", + 0, + "mmr", + None, + ), # Expected none when k is 0 or negative + ( + "example query", + 3, + "cos", + [MagicMock()], + ), # Mocked object as a placeholder + ], +) +def test_search_memory_different_params( + vector_memory, query, k, type, expected +): + with patch.object( + vector_memory.db, + "max_marginal_relevance_search", + return_value=expected, + ): + with patch.object( + vector_memory.db, + "similarity_search_with_score", + return_value=expected, + ): + result = vector_memory.search_memory(query, k=k, type=type) + assert len(result) == (k if k > 0 else 0) diff --git a/tests/utils/test_pinecone.py b/tests/utils/test_pinecone.py new file mode 100644 index 0000000..b9f2cf8 --- /dev/null +++ b/tests/utils/test_pinecone.py @@ -0,0 +1,82 @@ +import os +from unittest.mock import patch + +from examples.memory.pinecone import PineconeDB + +api_key = os.getenv("PINECONE_API_KEY") or "" + + +def test_init(): + with patch("pinecone.init") as MockInit, patch( + "pinecone.Index" + ) as MockIndex: + store = PineconeDB( + api_key=api_key, + index_name="test_index", + environment="test_env", + ) + MockInit.assert_called_once() + MockIndex.assert_called_once() + assert store.index == MockIndex.return_value + + +def test_upsert_vector(): + with patch("pinecone.init"), patch("pinecone.Index") as MockIndex: + store = PineconeDB( + api_key=api_key, + index_name="test_index", + environment="test_env", + ) + store.upsert_vector( + [1.0, 2.0, 3.0], + "test_id", + "test_namespace", + {"meta": "data"}, + ) + MockIndex.return_value.upsert.assert_called() + + +def test_load_entry(): + with patch("pinecone.init"), patch("pinecone.Index") as MockIndex: + store = PineconeDB( + api_key=api_key, + index_name="test_index", + environment="test_env", + ) + 
store.load_entry("test_id", "test_namespace") + MockIndex.return_value.fetch.assert_called() + + +def test_load_entries(): + with patch("pinecone.init"), patch("pinecone.Index") as MockIndex: + store = PineconeDB( + api_key=api_key, + index_name="test_index", + environment="test_env", + ) + store.load_entries("test_namespace") + MockIndex.return_value.query.assert_called() + + +def test_query(): + with patch("pinecone.init"), patch("pinecone.Index") as MockIndex: + store = PineconeDB( + api_key=api_key, + index_name="test_index", + environment="test_env", + ) + store.query("test_query", 10, "test_namespace") + MockIndex.return_value.query.assert_called() + + +def test_create_index(): + with patch("pinecone.init"), patch("pinecone.Index"), patch( + "pinecone.create_index" + ) as MockCreateIndex: + store = PineconeDB( + api_key=api_key, + index_name="test_index", + environment="test_env", + ) + store.create_index("test_index") + MockCreateIndex.assert_called() diff --git a/tests/utils/test_pq_db.py b/tests/utils/test_pq_db.py new file mode 100644 index 0000000..2fbffb6 --- /dev/null +++ b/tests/utils/test_pq_db.py @@ -0,0 +1,80 @@ +import os +from unittest.mock import patch + +from dotenv import load_dotenv + +from examples.memory.pg import PostgresDB + +load_dotenv() + +PSG_CONNECTION_STRING = os.getenv("PSG_CONNECTION_STRING") + + +def test_init(): + with patch("sqlalchemy.create_engine") as MockEngine: + db = PostgresDB( + connection_string=PSG_CONNECTION_STRING, + table_name="test", + ) + MockEngine.assert_called_once() + assert db.engine == MockEngine.return_value + + +def test_create_vector_model(): + with patch("sqlalchemy.create_engine"): + db = PostgresDB( + connection_string=PSG_CONNECTION_STRING, + table_name="test", + ) + model = db._create_vector_model() + assert model.__tablename__ == "test" + + +def test_add_or_update_vector(): + with patch("sqlalchemy.create_engine"), patch( + "sqlalchemy.orm.Session" + ) as MockSession: + db = PostgresDB( + 
connection_string=PSG_CONNECTION_STRING, + table_name="test", + ) + db.add_or_update_vector( + "test_vector", + "test_id", + "test_namespace", + {"meta": "data"}, + ) + MockSession.assert_called() + MockSession.return_value.merge.assert_called() + MockSession.return_value.commit.assert_called() + + +def test_query_vectors(): + with patch("sqlalchemy.create_engine"), patch( + "sqlalchemy.orm.Session" + ) as MockSession: + db = PostgresDB( + connection_string=PSG_CONNECTION_STRING, + table_name="test", + ) + db.query_vectors("test_query", "test_namespace") + MockSession.assert_called() + MockSession.return_value.query.assert_called() + MockSession.return_value.query.return_value.filter_by.assert_called() + MockSession.return_value.query.return_value.filter.assert_called() + MockSession.return_value.query.return_value.all.assert_called() + + +def test_delete_vector(): + with patch("sqlalchemy.create_engine"), patch( + "sqlalchemy.orm.Session" + ) as MockSession: + db = PostgresDB( + connection_string=PSG_CONNECTION_STRING, + table_name="test", + ) + db.delete_vector("test_id") + MockSession.assert_called() + MockSession.return_value.get.assert_called() + MockSession.return_value.delete.assert_called() + MockSession.return_value.commit.assert_called() diff --git a/tests/utils/test_short_term_memory.py b/tests/utils/test_short_term_memory.py new file mode 100644 index 0000000..39e78e5 --- /dev/null +++ b/tests/utils/test_short_term_memory.py @@ -0,0 +1,130 @@ +import threading + +from swarms_memory import ShortTermMemory + + +def test_init(): + memory = ShortTermMemory() + assert memory.short_term_memory == [] + assert memory.medium_term_memory == [] + + +def test_add(): + memory = ShortTermMemory() + memory.add("user", "Hello, world!") + assert memory.short_term_memory == [ + {"role": "user", "message": "Hello, world!"} + ] + + +def test_get_short_term(): + memory = ShortTermMemory() + memory.add("user", "Hello, world!") + assert memory.get_short_term() == [ + {"role": 
"user", "message": "Hello, world!"} + ] + + +def test_get_medium_term(): + memory = ShortTermMemory() + memory.add("user", "Hello, world!") + memory.move_to_medium_term(0) + assert memory.get_medium_term() == [ + {"role": "user", "message": "Hello, world!"} + ] + + +def test_clear_medium_term(): + memory = ShortTermMemory() + memory.add("user", "Hello, world!") + memory.move_to_medium_term(0) + memory.clear_medium_term() + assert memory.get_medium_term() == [] + + +def test_get_short_term_memory_str(): + memory = ShortTermMemory() + memory.add("user", "Hello, world!") + assert ( + memory.get_short_term_memory_str() + == "[{'role': 'user', 'message': 'Hello, world!'}]" + ) + + +def test_update_short_term(): + memory = ShortTermMemory() + memory.add("user", "Hello, world!") + memory.update_short_term(0, "user", "Goodbye, world!") + assert memory.get_short_term() == [ + {"role": "user", "message": "Goodbye, world!"} + ] + + +def test_clear(): + memory = ShortTermMemory() + memory.add("user", "Hello, world!") + memory.clear() + assert memory.get_short_term() == [] + + +def test_search_memory(): + memory = ShortTermMemory() + memory.add("user", "Hello, world!") + assert memory.search_memory("Hello") == { + "short_term": [(0, {"role": "user", "message": "Hello, world!"})], + "medium_term": [], + } + + +def test_return_shortmemory_as_str(): + memory = ShortTermMemory() + memory.add("user", "Hello, world!") + assert ( + memory.return_shortmemory_as_str() + == "[{'role': 'user', 'message': 'Hello, world!'}]" + ) + + +def test_move_to_medium_term(): + memory = ShortTermMemory() + memory.add("user", "Hello, world!") + memory.move_to_medium_term(0) + assert memory.get_medium_term() == [ + {"role": "user", "message": "Hello, world!"} + ] + assert memory.get_short_term() == [] + + +def test_return_medium_memory_as_str(): + memory = ShortTermMemory() + memory.add("user", "Hello, world!") + memory.move_to_medium_term(0) + assert ( + memory.return_medium_memory_as_str() + == 
"[{'role': 'user', 'message': 'Hello, world!'}]" + ) + + +def test_thread_safety(): + memory = ShortTermMemory() + + def add_messages(): + for _ in range(1000): + memory.add("user", "Hello, world!") + + threads = [threading.Thread(target=add_messages) for _ in range(10)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + assert len(memory.get_short_term()) == 10000 + + +def test_save_and_load(): + memory1 = ShortTermMemory() + memory1.add("user", "Hello, world!") + memory1.save_to_file("memory.json") + memory2 = ShortTermMemory() + memory2.load_from_file("memory.json") + assert memory1.get_short_term() == memory2.get_short_term() + assert memory1.get_medium_term() == memory2.get_medium_term() diff --git a/tests/utils/test_sqlite.py b/tests/utils/test_sqlite.py new file mode 100644 index 0000000..5f36ca3 --- /dev/null +++ b/tests/utils/test_sqlite.py @@ -0,0 +1,100 @@ +import sqlite3 + +import pytest + +from examples import SQLiteDB + + +@pytest.fixture +def db(): + conn = sqlite3.connect(":memory:") + conn.execute("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)") + conn.commit() + return SQLiteDB(":memory:") + + +def test_add(db): + db.add("INSERT INTO test (name) VALUES (?)", ("test",)) + result = db.query("SELECT * FROM test") + assert result == [(1, "test")] + + +def test_delete(db): + db.add("INSERT INTO test (name) VALUES (?)", ("test",)) + db.delete("DELETE FROM test WHERE name = ?", ("test",)) + result = db.query("SELECT * FROM test") + assert result == [] + + +def test_update(db): + db.add("INSERT INTO test (name) VALUES (?)", ("test",)) + db.update("UPDATE test SET name = ? 
WHERE name = ?", ("new", "test")) + result = db.query("SELECT * FROM test") + assert result == [(1, "new")] + + +def test_query(db): + db.add("INSERT INTO test (name) VALUES (?)", ("test",)) + result = db.query("SELECT * FROM test WHERE name = ?", ("test",)) + assert result == [(1, "test")] + + +def test_execute_query(db): + db.add("INSERT INTO test (name) VALUES (?)", ("test",)) + result = db.execute_query( + "SELECT * FROM test WHERE name = ?", ("test",) + ) + assert result == [(1, "test")] + + +def test_add_without_params(db): + with pytest.raises(sqlite3.ProgrammingError): + db.add("INSERT INTO test (name) VALUES (?)") + + +def test_delete_without_params(db): + with pytest.raises(sqlite3.ProgrammingError): + db.delete("DELETE FROM test WHERE name = ?") + + +def test_update_without_params(db): + with pytest.raises(sqlite3.ProgrammingError): + db.update("UPDATE test SET name = ? WHERE name = ?") + + +def test_query_without_params(db): + with pytest.raises(sqlite3.ProgrammingError): + db.query("SELECT * FROM test WHERE name = ?") + + +def test_execute_query_without_params(db): + with pytest.raises(sqlite3.ProgrammingError): + db.execute_query("SELECT * FROM test WHERE name = ?") + + +def test_add_with_wrong_query(db): + with pytest.raises(sqlite3.OperationalError): + db.add("INSERT INTO wrong (name) VALUES (?)", ("test",)) + + +def test_delete_with_wrong_query(db): + with pytest.raises(sqlite3.OperationalError): + db.delete("DELETE FROM wrong WHERE name = ?", ("test",)) + + +def test_update_with_wrong_query(db): + with pytest.raises(sqlite3.OperationalError): + db.update( + "UPDATE wrong SET name = ? 
WHERE name = ?", + ("new", "test"), + ) + + +def test_query_with_wrong_query(db): + with pytest.raises(sqlite3.OperationalError): + db.query("SELECT * FROM wrong WHERE name = ?", ("test",)) + + +def test_execute_query_with_wrong_query(db): + with pytest.raises(sqlite3.OperationalError): + db.execute_query("SELECT * FROM wrong WHERE name = ?", ("test",)) diff --git a/tests/test_chromadb.py b/tests/vector_dbs/test_chromadb.py similarity index 96% rename from tests/test_chromadb.py rename to tests/vector_dbs/test_chromadb.py index 7504305..1cf6f5e 100644 --- a/tests/test_chromadb.py +++ b/tests/vector_dbs/test_chromadb.py @@ -1,5 +1,5 @@ from unittest.mock import patch, MagicMock -from swarms_memory.chroma_db_wrapper import ChromaDB +from swarms_memory.vector_dbs.chroma_db_wrapper import ChromaDB @patch("chromadb.PersistentClient") diff --git a/tests/test_pinecone.py b/tests/vector_dbs/test_pinecone.py similarity index 96% rename from tests/test_pinecone.py rename to tests/vector_dbs/test_pinecone.py index dd80da8..1617078 100644 --- a/tests/test_pinecone.py +++ b/tests/vector_dbs/test_pinecone.py @@ -1,5 +1,5 @@ from unittest.mock import patch -from swarms_memory.pinecone_wrapper import PineconeMemory +from swarms_memory.vector_dbs.pinecone_wrapper import PineconeMemory @patch("pinecone.init") diff --git a/tests/vector_dbs/test_qdrant.py b/tests/vector_dbs/test_qdrant.py new file mode 100644 index 0000000..caa940e --- /dev/null +++ b/tests/vector_dbs/test_qdrant.py @@ -0,0 +1,52 @@ +from unittest.mock import Mock, patch + +import pytest + +from swarms_memory import Qdrant + + +@pytest.fixture +def mock_qdrant_client(): + with patch("swarms_memory.Qdrant") as MockQdrantClient: + yield MockQdrantClient() + + +@pytest.fixture +def mock_sentence_transformer(): + with patch( + "sentence_transformers.SentenceTransformer" + ) as MockSentenceTransformer: + yield MockSentenceTransformer() + + +@pytest.fixture +def qdrant_client(mock_qdrant_client, mock_sentence_transformer): 
+ client = Qdrant(api_key="your_api_key", host="your_host") + yield client + + +def test_qdrant_init(qdrant_client, mock_qdrant_client): + assert qdrant_client.client is not None + + +def test_load_embedding_model(qdrant_client, mock_sentence_transformer): + qdrant_client._load_embedding_model("model_name") + mock_sentence_transformer.assert_called_once_with("model_name") + + +def test_setup_collection(qdrant_client, mock_qdrant_client): + qdrant_client._setup_collection() + mock_qdrant_client.get_collection.assert_called_once_with( + qdrant_client.collection_name + ) + + +def test_add_vectors(qdrant_client, mock_qdrant_client): + mock_doc = Mock(page_content="Sample text") + qdrant_client.add_vectors([mock_doc]) + mock_qdrant_client.upsert.assert_called_once() + + +def test_search_vectors(qdrant_client, mock_qdrant_client): + qdrant_client.search_vectors("test query") + mock_qdrant_client.search.assert_called_once() diff --git a/tests/vector_dbs/test_weaviate.py b/tests/vector_dbs/test_weaviate.py new file mode 100644 index 0000000..93e9aaf --- /dev/null +++ b/tests/vector_dbs/test_weaviate.py @@ -0,0 +1,192 @@ +from unittest.mock import Mock, patch + +import pytest + +from swarms_memory import WeaviateDB + + +# Define fixture for a WeaviateDB instance with mocked methods +@pytest.fixture +def weaviate_client_mock(): + client = WeaviateDB( + http_host="mock_host", + http_port="mock_port", + http_secure=False, + grpc_host="mock_grpc_host", + grpc_port="mock_grpc_port", + grpc_secure=False, + auth_client_secret="mock_api_key", + additional_headers={"X-OpenAI-Api-Key": "mock_openai_api_key"}, + additional_config=Mock(), + ) + + # Mock the methods + client.client.collections.create = Mock() + client.client.collections.get = Mock() + client.client.collections.query = Mock() + client.client.collections.data.insert = Mock() + client.client.collections.data.update = Mock() + client.client.collections.data.delete_by_id = Mock() + + return client + + +# Define tests for the 
WeaviateDB class +def test_create_collection(weaviate_client_mock): + # Test creating a collection + weaviate_client_mock.create_collection( + "test_collection", [{"name": "property"}] + ) + weaviate_client_mock.client.collections.create.assert_called_with( + name="test_collection", + vectorizer_config=None, + properties=[{"name": "property"}], + ) + + +def test_add_object(weaviate_client_mock): + # Test adding an object + properties = {"name": "John"} + weaviate_client_mock.add("test_collection", properties) + weaviate_client_mock.client.collections.get.assert_called_with( + "test_collection" + ) + weaviate_client_mock.client.collections.data.insert.assert_called_with( + properties + ) + + +def test_query_objects(weaviate_client_mock): + # Test querying objects + query = "name:John" + weaviate_client_mock.query("test_collection", query) + weaviate_client_mock.client.collections.get.assert_called_with( + "test_collection" + ) + weaviate_client_mock.client.collections.query.bm25.assert_called_with( + query=query, limit=10 + ) + + +def test_update_object(weaviate_client_mock): + # Test updating an object + object_id = "12345" + properties = {"name": "Jane"} + weaviate_client_mock.update("test_collection", object_id, properties) + weaviate_client_mock.client.collections.get.assert_called_with( + "test_collection" + ) + weaviate_client_mock.client.collections.data.update.assert_called_with( + object_id, properties + ) + + +def test_delete_object(weaviate_client_mock): + # Test deleting an object + object_id = "12345" + weaviate_client_mock.delete("test_collection", object_id) + weaviate_client_mock.client.collections.get.assert_called_with( + "test_collection" + ) + weaviate_client_mock.client.collections.data.delete_by_id.assert_called_with( + object_id + ) + + +def test_create_collection_with_vectorizer_config( + weaviate_client_mock, +): + # Test creating a collection with vectorizer configuration + vectorizer_config = {"config_key": "config_value"} + 
weaviate_client_mock.create_collection( + "test_collection", [{"name": "property"}], vectorizer_config + ) + weaviate_client_mock.client.collections.create.assert_called_with( + name="test_collection", + vectorizer_config=vectorizer_config, + properties=[{"name": "property"}], + ) + + +def test_query_objects_with_limit(weaviate_client_mock): + # Test querying objects with a specified limit + query = "name:John" + limit = 20 + weaviate_client_mock.query("test_collection", query, limit) + weaviate_client_mock.client.collections.get.assert_called_with( + "test_collection" + ) + weaviate_client_mock.client.collections.query.bm25.assert_called_with( + query=query, limit=limit + ) + + +def test_query_objects_without_limit(weaviate_client_mock): + # Test querying objects without specifying a limit + query = "name:John" + weaviate_client_mock.query("test_collection", query) + weaviate_client_mock.client.collections.get.assert_called_with( + "test_collection" + ) + weaviate_client_mock.client.collections.query.bm25.assert_called_with( + query=query, limit=10 + ) + + +def test_create_collection_failure(weaviate_client_mock): + # Test failure when creating a collection + with patch( + "weaviate_client.weaviate.collections.create", + side_effect=Exception("Create error"), + ): + with pytest.raises(Exception, match="Error creating collection"): + weaviate_client_mock.create_collection( + "test_collection", [{"name": "property"}] + ) + + +def test_add_object_failure(weaviate_client_mock): + # Test failure when adding an object + properties = {"name": "John"} + with patch( + "weaviate_client.weaviate.collections.data.insert", + side_effect=Exception("Insert error"), + ): + with pytest.raises(Exception, match="Error adding object"): + weaviate_client_mock.add("test_collection", properties) + + +def test_query_objects_failure(weaviate_client_mock): + # Test failure when querying objects + query = "name:John" + with patch( + "weaviate_client.weaviate.collections.query.bm25", + 
side_effect=Exception("Query error"), + ): + with pytest.raises(Exception, match="Error querying objects"): + weaviate_client_mock.query("test_collection", query) + + +def test_update_object_failure(weaviate_client_mock): + # Test failure when updating an object + object_id = "12345" + properties = {"name": "Jane"} + with patch( + "weaviate_client.weaviate.collections.data.update", + side_effect=Exception("Update error"), + ): + with pytest.raises(Exception, match="Error updating object"): + weaviate_client_mock.update( + "test_collection", object_id, properties + ) + + +def test_delete_object_failure(weaviate_client_mock): + # Test failure when deleting an object + object_id = "12345" + with patch( + "weaviate_client.weaviate.collections.data.delete_by_id", + side_effect=Exception("Delete error"), + ): + with pytest.raises(Exception, match="Error deleting object"): + weaviate_client_mock.delete("test_collection", object_id)