
Commit 515b519

[CLEANUP]
Kye Gomez authored and Kye Gomez committed Jul 8, 2024
1 parent e9dd002 commit 515b519
Showing 2 changed files with 207 additions and 45 deletions.
80 changes: 80 additions & 0 deletions pinecome_wrapper_example.py
@@ -0,0 +1,80 @@
from typing import List, Dict, Any
from swarms_memory.pinecone_wrapper import PineconeMemory


# Example usage
if __name__ == "__main__":
    from transformers import AutoTokenizer, AutoModel
    import torch

    # Custom embedding function using a HuggingFace model
    def custom_embedding_function(text: str) -> List[float]:
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        model = AutoModel.from_pretrained("bert-base-uncased")
        inputs = tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        with torch.no_grad():
            outputs = model(**inputs)
        embeddings = (
            outputs.last_hidden_state.mean(dim=1).squeeze().tolist()
        )
        return embeddings

    # Custom preprocessing function
    def custom_preprocess(text: str) -> str:
        return text.lower().strip()

    # Custom postprocessing function
    def custom_postprocess(
        results: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        for result in results:
            result["custom_score"] = (
                result["score"] * 2
            )  # Example modification
        return results

    # Initialize the wrapper with custom functions
    wrapper = PineconeMemory(
        api_key="your-api-key",
        environment="your-environment",
        index_name="your-index-name",
        embedding_function=custom_embedding_function,
        preprocess_function=custom_preprocess,
        postprocess_function=custom_postprocess,
        logger_config={
            "handlers": [
                {
                    "sink": "custom_rag_wrapper.log",
                    "rotation": "1 GB",
                },
                {
                    "sink": lambda msg: print(
                        f"Custom log: {msg}", end=""
                    )
                },
            ],
        },
    )

    # Adding documents
    wrapper.add(
        "This is a sample document about artificial intelligence.",
        {"category": "AI"},
    )
    wrapper.add(
        "Python is a popular programming language for data science.",
        {"category": "Programming"},
    )

    # Querying
    results = wrapper.query("What is AI?", filter={"category": "AI"})
    for result in results:
        print(
            f"Score: {result['score']}, Custom Score: {result['custom_score']}, Text: {result['metadata']['text']}"
        )
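
A practical note on the example above: custom_embedding_function reloads the tokenizer and model from the Hugging Face hub on every call, which dominates latency once more than a handful of documents are added or queried. A minimal sketch of a cached variant, loading the same bert-base-uncased checkpoint once at module scope (a suggestion, not part of the commit):

from typing import List

import torch
from transformers import AutoTokenizer, AutoModel

# Load the checkpoint once at import time instead of on every call.
_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
_model = AutoModel.from_pretrained("bert-base-uncased")
_model.eval()


def cached_embedding_function(text: str) -> List[float]:
    """Mean-pooled BERT embedding that reuses the preloaded model."""
    inputs = _tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    with torch.no_grad():
        outputs = _model(**inputs)
    # Average the last hidden states over the token dimension; for
    # bert-base-uncased this yields a 768-dim vector, matching the
    # wrapper's default dimension.
    return outputs.last_hidden_state.mean(dim=1).squeeze().tolist()

Passing cached_embedding_function as embedding_function keeps the rest of the example unchanged.
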
172 changes: 127 additions & 45 deletions swarms_memory/pinecone_wrapper.py
@@ -1,18 +1,37 @@
-from typing import List, Dict, Any
+from typing import List, Dict, Any, Callable, Optional
 import pinecone
 from loguru import logger
 from sentence_transformers import SentenceTransformer
-from swarms.memory.base_vectordb import BaseVectorDatabase
 
-class PineconeMemory(BaseVectorDatabase):
+
+class PineconeMemory:
     """
-    A wrapper class for Pinecone-based Retrieval-Augmented Generation (RAG) system.
+    A highly customizable wrapper class for Pinecone-based Retrieval-Augmented Generation (RAG) system.
     This class provides methods to add documents to the Pinecone index and query the index
-    for similar documents.
+    for similar documents. It allows for custom embedding models, preprocessing functions,
+    and other customizations.
     """
 
-    def __init__(self, api_key: str, environment: str, index_name: str, dimension: int = 768):
+    def __init__(
+        self,
+        api_key: str,
+        environment: str,
+        index_name: str,
+        dimension: int = 768,
+        embedding_model: Optional[Any] = None,
+        embedding_function: Optional[
+            Callable[[str], List[float]]
+        ] = None,
+        preprocess_function: Optional[Callable[[str], str]] = None,
+        postprocess_function: Optional[
+            Callable[[List[Dict[str, Any]]], List[Dict[str, Any]]]
+        ] = None,
+        metric: str = "cosine",
+        pod_type: str = "p1",
+        namespace: str = "",
+        logger_config: Optional[Dict[str, Any]] = None,
+    ):
         """
         Initialize the PineconeMemory.
@@ -21,75 +40,138 @@ def __init__(self, api_key: str, environment: str, index_name: str, dimension: i
             environment (str): Pinecone environment.
             index_name (str): Name of the Pinecone index to use.
             dimension (int): Dimension of the document embeddings. Defaults to 768.
+            embedding_model (Optional[Any]): Custom embedding model. Defaults to None.
+            embedding_function (Optional[Callable]): Custom embedding function. Defaults to None.
+            preprocess_function (Optional[Callable]): Custom preprocessing function. Defaults to None.
+            postprocess_function (Optional[Callable]): Custom postprocessing function. Defaults to None.
+            metric (str): Distance metric for Pinecone index. Defaults to 'cosine'.
+            pod_type (str): Pinecone pod type. Defaults to 'p1'.
+            namespace (str): Pinecone namespace. Defaults to ''.
+            logger_config (Optional[Dict]): Configuration for the logger. Defaults to None.
         """
+        self._setup_logger(logger_config)
         logger.info("Initializing PineconeMemory")
 
        pinecone.init(api_key=api_key, environment=environment)
 
        if index_name not in pinecone.list_indexes():
             logger.info(f"Creating new Pinecone index: {index_name}")
-            pinecone.create_index(index_name, dimension=dimension)
-
+            pinecone.create_index(
+                index_name,
+                dimension=dimension,
+                metric=metric,
+                pod_type=pod_type,
+            )
+
         self.index = pinecone.Index(index_name)
-        self.model = SentenceTransformer('all-MiniLM-L6-v2')
+        self.namespace = namespace
 
+        self.embedding_model = embedding_model or SentenceTransformer(
+            "all-MiniLM-L6-v2"
+        )
+        self.embedding_function = (
+            embedding_function or self._default_embedding_function
+        )
+        self.preprocess_function = (
+            preprocess_function or self._default_preprocess_function
+        )
+        self.postprocess_function = (
+            postprocess_function or self._default_postprocess_function
+        )
+
         logger.info("PineconeMemory initialized successfully")
 
-    def add(self, doc: str) -> None:
+    def _setup_logger(self, config: Optional[Dict[str, Any]] = None):
+        """Set up the logger with the given configuration."""
+        default_config = {
+            "handlers": [
+                {"sink": "rag_wrapper.log", "rotation": "500 MB"},
+                {"sink": lambda msg: print(msg, end="")},
+            ],
+        }
+        logger.configure(**(config or default_config))
+
+    def _default_embedding_function(self, text: str) -> List[float]:
+        """Default embedding function using the SentenceTransformer model."""
+        return self.embedding_model.encode(text).tolist()
+
+    def _default_preprocess_function(self, text: str) -> str:
+        """Default preprocessing function."""
+        return text.strip()
+
+    def _default_postprocess_function(
+        self, results: List[Dict[str, Any]]
+    ) -> List[Dict[str, Any]]:
+        """Default postprocessing function."""
+        return results
+
+    def add(
+        self, doc: str, metadata: Optional[Dict[str, Any]] = None
+    ) -> None:
         """
         Add a document to the Pinecone index.
         Args:
             doc (str): The document to be added.
+            metadata (Optional[Dict[str, Any]]): Additional metadata for the document.
         Returns:
             None
         """
         logger.info(f"Adding document: {doc[:50]}...")
-        embedding = self.model.encode(doc).tolist()
+        processed_doc = self.preprocess_function(doc)
+        embedding = self.embedding_function(processed_doc)
         id = str(abs(hash(doc)))
-        self.index.upsert([(id, embedding, {"text": doc})])
+        metadata = metadata or {}
+        metadata["text"] = processed_doc
+        self.index.upsert(
+            vectors=[(id, embedding, metadata)],
+            namespace=self.namespace,
+        )
         logger.success(f"Document added successfully with ID: {id}")
 
-    def query(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
+    def query(
+        self,
+        query: str,
+        top_k: int = 5,
+        filter: Optional[Dict[str, Any]] = None,
+    ) -> List[Dict[str, Any]]:
         """
         Query the Pinecone index for similar documents.
         Args:
             query (str): The query string.
             top_k (int): The number of top results to return. Defaults to 5.
+            filter (Optional[Dict[str, Any]]): Metadata filter for the query.
         Returns:
             List[Dict[str, Any]]: A list of dictionaries containing the top_k most similar documents.
         """
         logger.info(f"Querying with: {query}")
-        query_embedding = self.model.encode(query).tolist()
-        results = self.index.query(query_embedding, top_k=top_k, include_metadata=True)
-
+        processed_query = self.preprocess_function(query)
+        query_embedding = self.embedding_function(processed_query)
+        results = self.index.query(
+            vector=query_embedding,
+            top_k=top_k,
+            include_metadata=True,
+            namespace=self.namespace,
+            filter=filter,
+        )
 
         formatted_results = []
         for match in results.matches:
-            formatted_results.append({
-                "id": match.id,
-                "score": match.score,
-                "text": match.metadata["text"]
-            })
-
-        logger.success(f"Query completed. Found {len(formatted_results)} results.")
-        return formatted_results
-
-# # Example usage
-# if __name__ == "__main__":
-#     logger.add("rag_wrapper.log", rotation="500 MB")
-
-#     wrapper = PineconeMemory(
-#         api_key="your-api-key",
-#         environment="your-environment",
-#         index_name="your-index-name"
-#     )
-
-#     # Adding documents
-#     wrapper.add("This is a sample document about artificial intelligence.")
-#     wrapper.add("Python is a popular programming language for data science.")
-
-#     # Querying
-#     results = wrapper.query("What is AI?")
-#     for result in results:
-#         print(f"Score: {result['score']}, Text: {result['text']}")
+            formatted_results.append(
+                {
+                    "id": match.id,
+                    "score": match.score,
+                    "metadata": match.metadata,
+                }
+            )
+
+        processed_results = self.postprocess_function(
+            formatted_results
+        )
+        logger.success(
+            f"Query completed. Found {len(processed_results)} results."
+        )
+        return processed_results
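
One caveat worth flagging in the new add method: id = str(abs(hash(doc))) relies on Python's built-in hash, which is salted per process for strings (see PYTHONHASHSEED), so the same document upserted from two different runs can receive two different IDs and end up stored twice. A deterministic alternative, sketched here as a suggestion rather than as part of the commit, is a content hash:

import hashlib


def stable_doc_id(doc: str) -> str:
    """Deterministic document ID: identical text always maps to the same ID."""
    # sha256 of the UTF-8 bytes is stable across processes and machines,
    # unlike the per-process salted built-in hash() used in the commit.
    return hashlib.sha256(doc.encode("utf-8")).hexdigest()

Swapping this in would make repeated add calls idempotent for identical text. Note also that pinecone.init, pinecone.list_indexes, and create_index(..., pod_type=...) follow the pinecone-client 2.x interface; the 3.x client replaced these with methods on a Pinecone object, so this wrapper assumes the older client.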
