simplify global usage and remove embed tests
tylertitsworth committed Dec 31, 2023
1 parent ba96284 commit 66c8197
Showing 6 changed files with 46 additions and 93 deletions.
7 changes: 3 additions & 4 deletions README.md
@@ -155,11 +155,12 @@ Choose a new [Filetype Document Loader](https://python.langchain.com/docs/module
 ```python
 Document = namedtuple("Document", ["page_content", "metadata"])
 merged_documents = []
-for wiki, _ in mediawikis.items():
+
+for dump, _ in wiki.mediawikis.items():
     # https://python.langchain.com/docs/integrations/document_loaders/mediawikidump
     loader = MWDumpLoader(
         encoding="utf-8",
-        file_path=f"{source}/{wiki}_pages_current.xml",
+        file_path=f"{wiki.source}/{dump}_pages_current.xml",
         # https://www.mediawiki.org/wiki/Help:Namespaces
         namespaces=[0],
         skip_redirects=True,
@@ -195,8 +196,6 @@ Access the Chatbot GUI at `http://localhost:8000`.
 pip install pytest
 # Basic Testing
 pytest test/test.py -W ignore::DeprecationWarning
-# With Embedding
-pytest test/test.py -W ignore::DeprecationWarning --embed
 # With Ollama Model Backend
 pytest test/test.py -W ignore::DeprecationWarning --ollama
 ```
1 change: 1 addition & 0 deletions config.yaml
@@ -1,6 +1,7 @@
 data_dir: ./data
 # Huggingface
 embeddings_model: sentence-transformers/all-mpnet-base-v2
+introduction: Ah my good fellow!
 # Sources
 mediawikis:
   # - dnd4e
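Note: the new `introduction` key is picked up by `MultiWiki` with no constructor change visible in this diff — test/test.py below asserts `wiki.introduction == "Ah my good fellow!"`. A minimal sketch of the assumed config-to-attribute mapping (the real `__init__` is collapsed out of the diff):

```python
import yaml

class MultiWiki:
    def __init__(self, config_path="config.yaml"):
        with open(config_path, encoding="utf-8") as f:
            config = yaml.safe_load(f)
        # Assumed pattern: every top-level config.yaml key becomes an
        # attribute, so adding `introduction` requires no code change.
        for key, value in config.items():
            setattr(self, key, value)
```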
72 changes: 32 additions & 40 deletions main.py
@@ -18,6 +18,10 @@
 import chainlit as cl
 import yaml
 
+import torch
+if not torch.cuda.is_available():
+    torch.set_num_threads(torch.get_num_threads() * 2)
+
 
 class MultiWiki:
     def __init__(self):
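Aside: the new torch block doubles the intra-op thread count when no GPU is present, which can speed up the CPU-bound HuggingFace embedding step; whether oversubscribing threads actually helps is workload-dependent. A quick illustrative check of what it does on a given machine (not part of the commit):

```python
import torch

print(torch.cuda.is_available())  # False on a CPU-only host
print(torch.get_num_threads())    # torch's default intra-op thread count
torch.set_num_threads(torch.get_num_threads() * 2)
print(torch.get_num_threads())    # doubled, as main.py now does at import time
```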
@@ -40,16 +44,18 @@ def __init__(self):
     def set_args(self, args):
         self.args = args
 
+    # https://github.com/jmorganca/ollama/blob/main/docs/api.md#show-model-information
     def set_chat_settings(self, settings):
+        # Update global wiki, not local self
+        global wiki
         if not settings:
-            self.inputs = wiki.settings
+            wiki.inputs = wiki.settings
         else:
-            self.inputs = settings
+            wiki.inputs = settings
 
 
 ### Globals
 wiki = MultiWiki()
-wiki.set_chat_settings(None)
def rename_duplicates(documents):
@@ -67,19 +73,15 @@ def rename_duplicates(documents):
     return documents
 
 
-def create_vector_db(data_dir, embeddings_model, source, mediawikis):
-    if not source:
-        print("No data sources found")
-        exit(1)
-
+def create_vector_db():
     Document = namedtuple("Document", ["page_content", "metadata"])
     merged_documents = []
 
-    for wiki, _ in mediawikis.items():
+    for dump, _ in wiki.mediawikis.items():
         # https://python.langchain.com/docs/integrations/document_loaders/mediawikidump
         loader = MWDumpLoader(
             encoding="utf-8",
-            file_path=f"{source}/{wiki}_pages_current.xml",
+            file_path=f"{wiki.source}/{dump}_pages_current.xml",
             # https://www.mediawiki.org/wiki/Help:Namespaces
             namespaces=[0],
             skip_redirects=True,
@@ -89,24 +91,24 @@ def create_vector_db(data_dir, embeddings_model, source, mediawikis):
         # Modify the source metadata by accounting for duplicates (<name>_n)
         # And add the mediawiki title (<name>_n - <wikiname>)
         merged_documents.extend(
-            Document(doc.page_content, {"source": doc.metadata["source"] + f" - {wiki}"})
+            Document(doc.page_content, {"source": doc.metadata["source"] + f" - {dump}"})
             for doc in rename_duplicates(loader.load())
         )
     print(f"Embedding {len(merged_documents)} Pages, this may take a while.")
     # https://python.langchain.com/docs/integrations/text_embedding/huggingfacehub
     embeddings = HuggingFaceEmbeddings(
-        model_name=embeddings_model, cache_folder="./model"
+        model_name=wiki.embeddings_model, cache_folder="./model"
     )
     # https://python.langchain.com/docs/integrations/vectorstores/chroma
     vectordb = Chroma.from_documents(
         documents=merged_documents,
         embedding=embeddings,
-        persist_directory=data_dir,
+        persist_directory=wiki.data_dir,
     )
     vectordb.persist()
 
 
-def create_chain(embeddings_model, model):
+def create_chain():
@@ -119,22 +121,24 @@ def create_chain(embeddings_model, model):
     # https://python.langchain.com/docs/integrations/text_embedding/huggingfacehub
     embeddings = HuggingFaceEmbeddings(
         cache_folder="./model",
-        model_name=embeddings_model,
+        model_name=wiki.embeddings_model,
     )
     vectordb = Chroma(persist_directory=wiki.data_dir, embedding_function=embeddings)
     callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
     # https://python.langchain.com/docs/integrations/llms/llm_caching
     set_llm_cache(SQLiteCache(database_path="memory/cache.db"))
+    wiki.set_chat_settings(None)
     # https://python.langchain.com/docs/integrations/llms/ollama
+    # Wiki loses its global scope when used as a parameter in ChatOllama()
+    # Unsure as to why
+    inputs = wiki.inputs
     model = ChatOllama(
         cache=True,
         callback_manager=callback_manager,
-        model=model,
-        repeat_penalty=wiki.inputs["repeat_penalty"],
-        temperature=wiki.inputs["temperature"],
-        top_k=wiki.inputs["top_k"],
-        top_p=wiki.inputs["top_p"],
+        model=wiki.model,
+        repeat_penalty=inputs["repeat_penalty"],
+        temperature=inputs["temperature"],
+        top_k=inputs["top_k"],
+        top_p=inputs["top_p"],
     )
     # https://api.python.langchain.com/en/latest/chains/langchain.chains.conversational_retrieval.base.ConversationalRetrievalChain.html
     chain = ConversationalRetrievalChain.from_llm(
@@ -147,19 +151,14 @@ def create_chain(embeddings_model, model):
         return_source_documents=True,
     )
 
-    return chain, model
+    return chain
 
 
 # https://docs.chainlit.io/integrations/langchain
 # https://docs.chainlit.io/examples/qa
 @cl.on_chat_start
 async def on_chat_start():
-    chain, llm = create_chain(
-        wiki.embeddings_model,
-        wiki.model,
-    )
-    wiki.set_chat_settings(None)
-
+    chain = create_chain()
     # https://docs.chainlit.io/api-reference/chat-settings
     inputs = [
         TextInput(
@@ -209,16 +208,17 @@ async def on_chat_start():
     # https://docs.chainlit.io/observability-iteration/prompt-playground/llm-providers#langchain-provider
     add_llm_provider(
         LangchainGenericProvider(
-            id=llm._llm_type,
+            id=chain.combine_docs_chain.llm_chain.llm._llm_type,
             name="Ollama",
-            llm=llm,
+            llm=chain.combine_docs_chain.llm_chain.llm,
             is_chat=True,
             # Not enough context to LangchainGenericProvider
             # https://github.com/Chainlit/chainlit/blob/main/backend/chainlit/playground/providers/langchain.py#L27
             inputs=wiki.inputs,
         )
     )
     await cl.ChatSettings(inputs).send()
+    # await cl.Message(content=wiki.introduction, disable_human_feedback=True).send()
 
     cl.user_session.set("chain", chain)
@@ -267,17 +267,9 @@ async def setup_agent(settings):
     wiki.set_args(parser.parse_args())
 
     if wiki.args.embed:
-        create_vector_db(
-            wiki.data_dir,
-            wiki.embeddings_model,
-            wiki.source,
-            wiki.mediawikis,
-        )
+        create_vector_db()
 
-    chain, llm = create_chain(
-        wiki.embeddings_model,
-        wiki.model,
-    )
+    chain = create_chain()
 
     if not wiki.question:
         print("No Prompt for Chatbot found")
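Note: the net effect in main.py is that every entry point now reads the module-level `wiki` singleton instead of threading configuration through parameters. Side by side, with signatures taken from this diff:

```python
# Before: configuration passed explicitly, ChatOllama instance returned too
chain, llm = create_chain(wiki.embeddings_model, wiki.model)
create_vector_db(wiki.data_dir, wiki.embeddings_model, wiki.source, wiki.mediawikis)

# After: zero-argument calls that read the global `wiki`
chain = create_chain()
create_vector_db()
```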
1 change: 1 addition & 0 deletions provider.py
@@ -65,6 +65,7 @@ async def create_completion(self, request):
         await super().create_completion(request)
 
         self.require_settings(request.prompt.settings)
+        del request.prompt.settings["num_sources"]
 
         messages = self.create_prompt(request)
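Note: `num_sources` is presumably an app-level retrieval setting surfaced in the chat-settings UI, so it must be dropped before the remaining settings are forwarded to the model. A generalized sketch of that pattern (everything except "num_sources" is illustrative):

```python
# Keys consumed by the app itself rather than by the LLM backend.
APP_ONLY_SETTINGS = {"num_sources"}

def llm_settings(settings: dict) -> dict:
    """Return a copy of the chat settings that is safe to send to the model."""
    return {k: v for k, v in settings.items() if k not in APP_ONLY_SETTINGS}
```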
5 changes: 1 addition & 4 deletions test/conftest.py
@@ -1,9 +1,6 @@
 import pytest
 
 optional_markers = {
-    "embed": {"help": "Test VectorDB Embeddings",
-              "marker-descr": "Enable embed tests",
-              "skip-reason": "Test only runs with the --{} option."},
     "ollama": {"help": "Test ollama backend with Langchain",
                "marker-descr": "Enable langchain tests with ollana",
                "skip-reason": "Test only runs with the --{} option."},
@@ -30,4 +27,4 @@ def pytest_collection_modifyitems(config, items):
         )
     for item in items:
         if marker in item.keywords:
-            item.add_marker(skip_test)
+            item.add_marker(skip_test)
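Note: with the embed marker gone, `optional_markers` only drives the `--ollama` flag. The `pytest_addoption`/`pytest_configure` hooks are collapsed out of this diff; a self-contained sketch of the pattern the visible fragments imply (the hook bodies are assumptions):

```python
import pytest

optional_markers = {
    "ollama": {"help": "Test ollama backend with Langchain",
               "marker-descr": "Enable langchain tests with ollama",
               "skip-reason": "Test only runs with the --{} option."},
}

def pytest_addoption(parser):
    # One opt-in flag per optional marker, e.g. --ollama.
    for marker, info in optional_markers.items():
        parser.addoption(f"--{marker}", action="store_true",
                         default=False, help=info["help"])

def pytest_configure(config):
    # Register markers so `pytest --markers` documents them.
    for marker, info in optional_markers.items():
        config.addinivalue_line("markers", f"{marker}: {info['marker-descr']}")

def pytest_collection_modifyitems(config, items):
    # Skip marked tests unless their flag was passed on the command line.
    for marker, info in optional_markers.items():
        if config.getoption(f"--{marker}"):
            continue
        skip_test = pytest.mark.skip(reason=info["skip-reason"].format(marker))
        for item in items:
            if marker in item.keywords:
                item.add_marker(skip_test)
```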
53 changes: 8 additions & 45 deletions test/test.py
@@ -1,17 +1,18 @@
 from collections import Counter
-from langchain.document_loaders import MWDumpLoader
-from langchain.document_loaders.merge import MergedDataLoader
-from main import MultiWiki, create_chain, create_vector_db, rename_duplicates
+from main import MultiWiki, create_chain
 
 import argparse
 import pytest
-import shutil
 
+import torch
+torch.set_num_threads(22)
+
+wiki = MultiWiki()
 
 
 def test_multiwiki():
-    wiki = MultiWiki()
     assert wiki.data_dir == "./data"
     assert wiki.embeddings_model == "sentence-transformers/all-mpnet-base-v2"
+    assert wiki.introduction == "Ah my good fellow!"
     assert wiki.model == "volo"
     assert wiki.question == "How many eyestalks does a Beholder have?"
     assert wiki.source == "./sources"
@@ -31,54 +32,16 @@ def test_multiwiki():
 
 
 def test_multiwiki_set_args():
-    wiki = MultiWiki()
     parser = argparse.ArgumentParser()
     parser.add_argument("--no-embed", dest="embed", action="store_false")
     wiki.set_args(parser.parse_args([]))
     print(wiki.args)
     assert wiki.args.embed == True
 
 
-@pytest.mark.embed
-def test_rename_duplicates():
-    wiki = MultiWiki()
-    source = wiki.source
-    mediawikis = wiki.mediawikis
-    for wiki in mediawikis.keys():
-        mediawikis[wiki] = MWDumpLoader(
-            encoding="utf-8",
-            file_path=f"{source}/{wiki}_pages_current.xml",
-            namespaces=[0],
-            skip_redirects=True,
-            stop_on_error=False,
-        )
-    loader_all = MergedDataLoader(loaders=mediawikis.values())
-    documents = loader_all.load()
-
-    doc_counter = Counter([doc.metadata["source"] for doc in documents])
-    duplicates = {source: count for source, count in doc_counter.items() if count > 1}
-
-    if len(duplicates) > 1:
-        documents_renamed = rename_duplicates(documents)
-        doc_names = [getattr(item, "metadata")["source"] for item in documents_renamed]
-        for dup in duplicates.items():
-            for i in range(1, dup[1]):
-                assert f"{dup[0]}_{i}" in doc_names
-        assert len(documents) == len(documents_renamed)
-
-
-@pytest.mark.embed
-def test_create_vector_db():
-    create_vector_db(
-        "test_data", "sentence-transformers/all-mpnet-base-v2", "./sources", {"dnd5e": ""}
-    )
-    shutil.rmtree("test_data")
-
-
 @pytest.mark.ollama
 def test_create_chain():
-    wiki = MultiWiki()
-    chain, llm = create_chain(wiki.embeddings_model, wiki.model)
+    chain = create_chain()
     res = chain(wiki.question)
     assert res["answer"] != ""
     assert res["source_documents"] != []

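Note: with the embed tests removed, the suite runs with at most the `--ollama` flag. The same invocation from Python, if driving pytest programmatically is preferred over the shell commands in the README:

```python
import pytest

# Mirrors `pytest test/test.py -W ignore::DeprecationWarning --ollama`.
raise SystemExit(pytest.main(
    ["test/test.py", "-W", "ignore::DeprecationWarning", "--ollama"]
))
```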