From 66c819734b13bb379a7f2bcf1937b66a4e48b73b Mon Sep 17 00:00:00 2001
From: tylertitsworth
Date: Sun, 31 Dec 2023 11:01:34 -0800
Subject: [PATCH] simplify global usage and remove embed tests

---
 README.md        |  7 ++---
 config.yaml      |  1 +
 main.py          | 72 +++++++++++++++++++++---------------------
 provider.py      |  1 +
 test/conftest.py |  5 +---
 test/test.py     | 53 ++++++-----------------------------
 6 files changed, 46 insertions(+), 93 deletions(-)

diff --git a/README.md b/README.md
index 16522c4..b5d1d45 100644
--- a/README.md
+++ b/README.md
@@ -155,11 +155,12 @@ Choose a new [Filetype Document Loader](https://python.langchain.com/docs/module
 ```python
 Document = namedtuple("Document", ["page_content", "metadata"])
 merged_documents = []
-for wiki, _ in mediawikis.items():
+
+for dump, _ in wiki.mediawikis.items():
     # https://python.langchain.com/docs/integrations/document_loaders/mediawikidump
     loader = MWDumpLoader(
         encoding="utf-8",
-        file_path=f"{source}/{wiki}_pages_current.xml",
+        file_path=f"{wiki.source}/{dump}_pages_current.xml",
         # https://www.mediawiki.org/wiki/Help:Namespaces
         namespaces=[0],
         skip_redirects=True,
@@ -195,8 +196,6 @@ Access the Chatbot GUI at `http://localhost:8000`.
 pip install pytest
 # Basic Testing
 pytest test/test.py -W ignore::DeprecationWarning
-# With Embedding
-pytest test/test.py -W ignore::DeprecationWarning --embed
 # With Ollama Model Backend
 pytest test/test.py -W ignore::DeprecationWarning --ollama
 ```
diff --git a/config.yaml b/config.yaml
index 81edf42..e7a9ab8 100644
--- a/config.yaml
+++ b/config.yaml
@@ -1,6 +1,7 @@
 data_dir: ./data
 # Huggingface
 embeddings_model: sentence-transformers/all-mpnet-base-v2
+introduction: Ah my good fellow!
 # Sources
 mediawikis:
   # - dnd4e
diff --git a/main.py b/main.py
index 4182e61..7dc6fa8 100644
--- a/main.py
+++ b/main.py
@@ -18,6 +18,10 @@
 import chainlit as cl
 import yaml
 
+import torch
+if not torch.cuda.is_available():
+    torch.set_num_threads(torch.get_num_threads() * 2)
+
 
 class MultiWiki:
     def __init__(self):
@@ -40,16 +44,18 @@ def __init__(self):
     def set_args(self, args):
         self.args = args
 
-    # https://github.com/jmorganca/ollama/blob/main/docs/api.md#show-model-information
     def set_chat_settings(self, settings):
+        # Update global wiki, not local self
+        global wiki
         if not settings:
-            self.inputs = wiki.settings
+            wiki.inputs = wiki.settings
         else:
-            self.inputs = settings
+            wiki.inputs = settings
 
 
 ### Globals
 wiki = MultiWiki()
+wiki.set_chat_settings(None)
 
 
 def rename_duplicates(documents):
@@ -67,19 +73,15 @@ def rename_duplicates(documents):
     return documents
 
 
-def create_vector_db(data_dir, embeddings_model, source, mediawikis):
-    if not source:
-        print("No data sources found")
-        exit(1)
-
+def create_vector_db():
     Document = namedtuple("Document", ["page_content", "metadata"])
     merged_documents = []
 
-    for wiki, _ in mediawikis.items():
+    for dump, _ in wiki.mediawikis.items():
         # https://python.langchain.com/docs/integrations/document_loaders/mediawikidump
         loader = MWDumpLoader(
             encoding="utf-8",
-            file_path=f"{source}/{wiki}_pages_current.xml",
+            file_path=f"{wiki.source}/{dump}_pages_current.xml",
             # https://www.mediawiki.org/wiki/Help:Namespaces
             namespaces=[0],
             skip_redirects=True,
@@ -89,24 +91,24 @@ def create_vector_db(data_dir, embeddings_model, source, mediawikis):
         # Modify the source metadata by accounting for duplicates (_n)
         # And add the mediawiki title (_n - )
         merged_documents.extend(
-            Document(doc.page_content, {"source": doc.metadata["source"] + f" - {wiki}"})
+            Document(doc.page_content, {"source": doc.metadata["source"] + f" - {dump}"})
             for doc in rename_duplicates(loader.load())
         )
     print(f"Embedding {len(merged_documents)} Pages, this may take a while.")
     # https://python.langchain.com/docs/integrations/text_embedding/huggingfacehub
     embeddings = HuggingFaceEmbeddings(
-        model_name=embeddings_model, cache_folder="./model"
+        model_name=wiki.embeddings_model, cache_folder="./model"
     )
     # https://python.langchain.com/docs/integrations/vectorstores/chroma
     vectordb = Chroma.from_documents(
         documents=merged_documents,
         embedding=embeddings,
-        persist_directory=data_dir,
+        persist_directory=wiki.data_dir,
     )
     vectordb.persist()
 
 
-def create_chain(embeddings_model, model):
+def create_chain():
     # https://python.langchain.com/docs/modules/memory/chat_messages/
     message_history = ChatMessageHistory()
     # https://python.langchain.com/docs/modules/memory/
@@ -119,22 +121,24 @@ def create_chain(embeddings_model, model):
     # https://python.langchain.com/docs/integrations/text_embedding/huggingfacehub
     embeddings = HuggingFaceEmbeddings(
         cache_folder="./model",
-        model_name=embeddings_model,
+        model_name=wiki.embeddings_model,
     )
     vectordb = Chroma(persist_directory=wiki.data_dir, embedding_function=embeddings)
     callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
     # https://python.langchain.com/docs/integrations/llms/llm_caching
     set_llm_cache(SQLiteCache(database_path="memory/cache.db"))
-    wiki.set_chat_settings(None)
     # https://python.langchain.com/docs/integrations/llms/ollama
+    # Wiki loses its global scope when used as a parameter in ChatOllama()
+    # Unsure as to why
+    inputs = wiki.inputs
     model = ChatOllama(
         cache=True,
         callback_manager=callback_manager,
-        model=model,
-        repeat_penalty=wiki.inputs["repeat_penalty"],
-        temperature=wiki.inputs["temperature"],
-        top_k=wiki.inputs["top_k"],
-        top_p=wiki.inputs["top_p"],
+        model=wiki.model,
+        repeat_penalty=inputs["repeat_penalty"],
+        temperature=inputs["temperature"],
+        top_k=inputs["top_k"],
+        top_p=inputs["top_p"],
     )
     # https://api.python.langchain.com/en/latest/chains/langchain.chains.conversational_retrieval.base.ConversationalRetrievalChain.html
     chain = ConversationalRetrievalChain.from_llm(
@@ -147,19 +151,14 @@ def create_chain(embeddings_model, model):
         return_source_documents=True,
     )
 
-    return chain, model
+    return chain
 
 
 # https://docs.chainlit.io/integrations/langchain
 # https://docs.chainlit.io/examples/qa
 @cl.on_chat_start
 async def on_chat_start():
-    chain, llm = create_chain(
-        wiki.embeddings_model,
-        wiki.model,
-    )
-    wiki.set_chat_settings(None)
-
+    chain = create_chain()
     # https://docs.chainlit.io/api-reference/chat-settings
     inputs = [
         TextInput(
@@ -209,9 +208,9 @@ async def on_chat_start():
     # https://docs.chainlit.io/observability-iteration/prompt-playground/llm-providers#langchain-provider
     add_llm_provider(
         LangchainGenericProvider(
-            id=llm._llm_type,
+            id=chain.combine_docs_chain.llm_chain.llm._llm_type,
             name="Ollama",
-            llm=llm,
+            llm=chain.combine_docs_chain.llm_chain.llm,
             is_chat=True,
             # Not enough context to LangchainGenericProvider
             # https://github.com/Chainlit/chainlit/blob/main/backend/chainlit/playground/providers/langchain.py#L27
@@ -219,6 +218,7 @@ async def on_chat_start():
         )
     )
     await cl.ChatSettings(inputs).send()
+    # await cl.Message(content=wiki.introduction, disable_human_feedback=True).send()
     cl.user_session.set("chain", chain)
 
 
@@ -267,17 +267,9 @@ async def setup_agent(settings):
     wiki.set_args(parser.parse_args())
 
     if wiki.args.embed:
-        create_vector_db(
-            wiki.data_dir,
-            wiki.embeddings_model,
-            wiki.source,
-            wiki.mediawikis,
-        )
+        create_vector_db()
 
-    chain, llm = create_chain(
-        wiki.embeddings_model,
-        wiki.model,
-    )
+    chain = create_chain()
 
     if not wiki.question:
         print("No Prompt for Chatbot found")
diff --git a/provider.py b/provider.py
index 46089f8..af18fab 100644
--- a/provider.py
+++ b/provider.py
@@ -65,6 +65,7 @@ async def create_completion(self, request):
         await super().create_completion(request)
 
         self.require_settings(request.prompt.settings)
+        del request.prompt.settings["num_sources"]
 
         messages = self.create_prompt(request)
 
diff --git a/test/conftest.py b/test/conftest.py
index 9c7083e..fee67e9 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -1,9 +1,6 @@
 import pytest
 
 optional_markers = {
-    "embed": {"help": "Test VectorDB Embeddings",
-              "marker-descr": "Enable embed tests",
-              "skip-reason": "Test only runs with the --{} option."},
     "ollama": {"help": "Test ollama backend with Langchain",
                "marker-descr": "Enable langchain tests with ollana",
                "skip-reason": "Test only runs with the --{} option."},
@@ -30,4 +27,4 @@ def pytest_collection_modifyitems(config, items):
             )
             for item in items:
                 if marker in item.keywords:
-                    item.add_marker(skip_test)
\ No newline at end of file
+                    item.add_marker(skip_test)
diff --git a/test/test.py b/test/test.py
index 4d0a065..e171951 100644
--- a/test/test.py
+++ b/test/test.py
@@ -1,17 +1,18 @@
-from collections import Counter
-from langchain.document_loaders import MWDumpLoader
-from langchain.document_loaders.merge import MergedDataLoader
-from main import MultiWiki, create_chain, create_vector_db, rename_duplicates
+from main import MultiWiki, create_chain
 
 import argparse
 import pytest
-import shutil
+
+import torch
+torch.set_num_threads(22)
+
+wiki = MultiWiki()
 
 
 def test_multiwiki():
-    wiki = MultiWiki()
     assert wiki.data_dir == "./data"
     assert wiki.embeddings_model == "sentence-transformers/all-mpnet-base-v2"
+    assert wiki.introduction == "Ah my good fellow!"
     assert wiki.model == "volo"
     assert wiki.question == "How many eyestalks does a Beholder have?"
     assert wiki.source == "./sources"
@@ -31,7 +32,6 @@ def test_multiwiki():
 
 
 def test_multiwiki_set_args():
-    wiki = MultiWiki()
     parser = argparse.ArgumentParser()
     parser.add_argument("--no-embed", dest="embed", action="store_false")
     wiki.set_args(parser.parse_args([]))
@@ -39,46 +39,9 @@ def test_multiwiki_set_args():
     assert wiki.args.embed == True
 
 
-@pytest.mark.embed
-def test_rename_duplicates():
-    wiki = MultiWiki()
-    source = wiki.source
-    mediawikis = wiki.mediawikis
-    for wiki in mediawikis.keys():
-        mediawikis[wiki] = MWDumpLoader(
-            encoding="utf-8",
-            file_path=f"{source}/{wiki}_pages_current.xml",
-            namespaces=[0],
-            skip_redirects=True,
-            stop_on_error=False,
-        )
-    loader_all = MergedDataLoader(loaders=mediawikis.values())
-    documents = loader_all.load()
-
-    doc_counter = Counter([doc.metadata["source"] for doc in documents])
-    duplicates = {source: count for source, count in doc_counter.items() if count > 1}
-
-    if len(duplicates) > 1:
-        documents_renamed = rename_duplicates(documents)
-        doc_names = [getattr(item, "metadata")["source"] for item in documents_renamed]
-        for dup in duplicates.items():
-            for i in range(1, dup[1]):
-                assert f"{dup[0]}_{i}" in doc_names
-        assert len(documents) == len(documents_renamed)
-
-
-@pytest.mark.embed
-def test_create_vector_db():
-    create_vector_db(
-        "test_data", "sentence-transformers/all-mpnet-base-v2", "./sources", {"dnd5e": ""}
-    )
-    shutil.rmtree("test_data")
-
-
 @pytest.mark.ollama
 def test_create_chain():
-    wiki = MultiWiki()
-    chain, llm = create_chain(wiki.embeddings_model, wiki.model)
+    chain = create_chain()
     res = chain(wiki.question)
     assert res["answer"] != ""
     assert res["source_documents"] != []
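
A note on the `global wiki` pattern this patch standardizes on (see the "Update global wiki, not local self" comment in `set_chat_settings` and the "loses its global scope" comment in `create_chain`): in Python, a `global` statement is only required when a function rebinds the module-level name itself; reading that object, or assigning to its attributes, needs no `global`. The sketch below is a minimal illustration of that distinction and is not part of the patch; the `Settings` class and function names are hypothetical stand-ins for `MultiWiki` and the real helpers.

```python
# Minimal illustration of Python's `global` semantics (hypothetical names).

class Settings:
    def __init__(self):
        self.inputs = {}


settings = Settings()  # module-level object, analogous to `wiki = MultiWiki()`


def update_inputs(values):
    # Assigning to an attribute mutates the existing module-level object,
    # so no `global` statement is needed here.
    settings.inputs = values


def replace_settings(new_settings):
    # Rebinding the module-level *name* does require `global`; without it,
    # `settings = ...` would only create a function-local variable.
    global settings
    settings = new_settings


update_inputs({"temperature": 0.7})
print(settings.inputs)   # {'temperature': 0.7}

replace_settings(Settings())
print(settings.inputs)   # {}
```

On that reading, copying `wiki.inputs` into a local `inputs` variable before the `ChatOllama()` call, as the patch does, simply snapshots the current settings for that one call.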