Commit ba96284: simplify the embedding process

tylertitsworth committed Dec 31, 2023
1 parent 93e7515 commit ba96284
Showing 4 changed files with 58 additions and 43 deletions.
33 changes: 19 additions & 14 deletions README.md
@@ -153,20 +153,25 @@ within Kara-Tur.
Choose a new [Filetype Document Loader](https://python.langchain.com/docs/modules/data_connection/document_loaders/) or [App Document Loader](https://python.langchain.com/docs/integrations/document_loaders/) and include those files in your VectorDB.

```python
-merged_documents = []
-for idx, wiki in enumerate(wikis.keys()):
-    wikis[wiki] = MWDumpLoader(
-        encoding="utf-8",
-        file_path=f"{source}/{wiki}_pages_current.xml",
-        namespaces=[0],
-        skip_redirects=True,
-        stop_on_error=False,
-    )
-    wikis[wiki] = wikis[wiki].load()
-    wikis[wiki] = rename_duplicates(wikis[wiki])
-    for jdx, doc in enumerate(wikis[wiki]):
-        wikis[wiki][jdx].metadata["source"] = doc.metadata["source"] + " - " + list(wikis)[idx]
-        merged_documents.append(wikis[wiki][jdx])
+Document = namedtuple("Document", ["page_content", "metadata"])
+merged_documents = []
+for wiki, _ in mediawikis.items():
+    # https://python.langchain.com/docs/integrations/document_loaders/mediawikidump
+    loader = MWDumpLoader(
+        encoding="utf-8",
+        file_path=f"{source}/{wiki}_pages_current.xml",
+        # https://www.mediawiki.org/wiki/Help:Namespaces
+        namespaces=[0],
+        skip_redirects=True,
+        stop_on_error=False,
+    )
+    # For each Document provided:
+    # Modify the source metadata by accounting for duplicates (<name>_n)
+    # And add the mediawiki title (<name>_n - <wikiname>)
+    merged_documents.extend(
+        Document(doc.page_content, {"source": doc.metadata["source"] + f" - {wiki}"})
+        for doc in rename_duplicates(loader.load())
+    )
### Insert a new loader
from langchain.document_loaders import TextLoader
myloader = TextLoader("./mydocument.md")
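The README's "Insert a new loader" stub constructs `myloader` but never merges its output. A minimal sketch of that missing step, assuming the `Document` namedtuple and `merged_documents` list from the loop above are still in scope (the `" - mydocument"` source suffix is illustrative, not part of this commit):

```python
# Fold the new loader's documents into the merged set, mirroring the
# MediaWiki handling above; rename_duplicates is optional for a single file.
merged_documents.extend(
    Document(doc.page_content, {"source": doc.metadata["source"] + " - mydocument"})
    for doc in myloader.load()
)
```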
9 changes: 5 additions & 4 deletions config.yaml
@@ -1,26 +1,27 @@
+data_dir: ./data
# Huggingface
embeddings_model: sentence-transformers/all-mpnet-base-v2
# Sources
mediawikis:
# - dnd4e
- dnd5e
# - darksun
-# - dragonlance
-# - eberron
+- dragonlance
+- eberron
# - exandria
-# - greyhawk
- forgottenrealms
+- greyhawk
# - planescape
# - ravenloft
# - spelljammer
# Ollama
model: volo
question: How many eyestalks does a Beholder have?
settings:
+  num_sources: 4
  repeat_penalty: 1.3
  temperature: 0.4
  top_k: 20
  top_p: 0.35
-  num_sources: 4
# Sources Path
source: ./sources
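For orientation, `MultiWiki.__init__` in main.py consumes this file by copying each top-level key onto the instance, special-casing `mediawikis` into a dict keyed by wiki name. A condensed sketch of that flow, assuming the file is read with PyYAML (the `config.yaml` open is assumed, not shown in these hunks):

```python
import yaml

class MultiWiki:
    def __init__(self):
        # Assumed: load the YAML config shown above.
        with open("config.yaml", encoding="utf-8") as f:
            data = yaml.safe_load(f)
        for key, val in data.items():
            if key == "mediawikis":
                # One entry per enabled wiki; values are placeholders.
                self.mediawikis = {wiki: "" for wiki in data["mediawikis"]}
            else:
                # data_dir, embeddings_model, model, question, settings, source
                setattr(self, key, val)
```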
38 changes: 22 additions & 16 deletions main.py
@@ -1,5 +1,6 @@
from chainlit.input_widget import Slider, TextInput
from chainlit.playground.config import add_llm_provider
+from collections import namedtuple
from langchain.cache import SQLiteCache
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
@@ -32,7 +33,7 @@ def __init__(self):

        for key, val in data.items():
            if key == "mediawikis":
-                self.wikis = {wiki: "" for wiki in data["mediawikis"]}
+                self.mediawikis = {wiki: "" for wiki in data["mediawikis"]}
            else:
                setattr(self, key, val)

@@ -66,37 +67,41 @@ def rename_duplicates(documents):
    return documents


-def create_vector_db(embeddings_model, source, wikis):
+def create_vector_db(data_dir, embeddings_model, source, mediawikis):
    if not source:
        print("No data sources found")
        exit(1)

-    # https://python.langchain.com/docs/integrations/text_embedding/huggingfacehub
-    embeddings = HuggingFaceEmbeddings(
-        model_name=embeddings_model, cache_folder="./model"
-    )
+    Document = namedtuple("Document", ["page_content", "metadata"])
    merged_documents = []
-    for idx, wiki in enumerate(wikis.keys()):
+
+    for wiki, _ in mediawikis.items():
        # https://python.langchain.com/docs/integrations/document_loaders/mediawikidump
-        wikis[wiki] = MWDumpLoader(
+        loader = MWDumpLoader(
            encoding="utf-8",
            file_path=f"{source}/{wiki}_pages_current.xml",
            # https://www.mediawiki.org/wiki/Help:Namespaces
            namespaces=[0],
            skip_redirects=True,
            stop_on_error=False,
        )
-        wikis[wiki] = wikis[wiki].load()
-        wikis[wiki] = rename_duplicates(wikis[wiki])
-        for jdx, doc in enumerate(wikis[wiki]):
-            wikis[wiki][jdx].metadata["source"] = doc.metadata["source"] + " - " + list(wikis)[idx]
-            merged_documents.append(wikis[wiki][jdx])
+        # For each Document provided:
+        # Modify the source metadata by accounting for duplicates (<name>_n)
+        # And add the mediawiki title (<name>_n - <wikiname>)
+        merged_documents.extend(
+            Document(doc.page_content, {"source": doc.metadata["source"] + f" - {wiki}"})
+            for doc in rename_duplicates(loader.load())
+        )
    print(f"Embedding {len(merged_documents)} Pages, this may take a while.")
+    # https://python.langchain.com/docs/integrations/text_embedding/huggingfacehub
+    embeddings = HuggingFaceEmbeddings(
+        model_name=embeddings_model, cache_folder="./model"
+    )
    # https://python.langchain.com/docs/integrations/vectorstores/chroma
    vectordb = Chroma.from_documents(
        documents=merged_documents,
        embedding=embeddings,
-        persist_directory="data",
+        persist_directory=data_dir,
    )
    vectordb.persist()

@@ -116,7 +121,7 @@ def create_chain(embeddings_model, model):
        cache_folder="./model",
        model_name=embeddings_model,
    )
-    vectordb = Chroma(persist_directory="data", embedding_function=embeddings)
+    vectordb = Chroma(persist_directory=wiki.data_dir, embedding_function=embeddings)
    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
    # https://python.langchain.com/docs/integrations/llms/llm_caching
    set_llm_cache(SQLiteCache(database_path="memory/cache.db"))
@@ -263,9 +268,10 @@ async def setup_agent(settings):

if wiki.args.embed:
    create_vector_db(
+        wiki.data_dir,
        wiki.embeddings_model,
        wiki.source,
-        wiki.wikis,
+        wiki.mediawikis,
    )

chain, llm = create_chain(
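The refactor leans on `rename_duplicates`, whose body sits outside these hunks. Going only by the comments' `<name>_n` convention, a helper along these lines would fit; this is an assumed sketch, not the committed implementation:

```python
from collections import Counter

def rename_duplicates(documents):
    # Count how many times each source title occurs across the dump.
    totals = Counter(doc.metadata["source"] for doc in documents)
    seen = Counter()
    for doc in documents:
        source = doc.metadata["source"]
        if totals[source] > 1:
            # Disambiguate repeated titles as <name>_1, <name>_2, ...
            seen[source] += 1
            doc.metadata["source"] = f"{source}_{seen[source]}"
    return documents
```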
21 changes: 12 additions & 9 deletions test/test.py
@@ -5,22 +5,24 @@

import argparse
import pytest
+import shutil


def test_multiwiki():
    wiki = MultiWiki()
+    assert wiki.data_dir == "./data"
    assert wiki.embeddings_model == "sentence-transformers/all-mpnet-base-v2"
    assert wiki.model == "volo"
    assert wiki.question == "How many eyestalks does a Beholder have?"
    assert wiki.source == "./sources"
-    assert wiki.wikis == {
+    assert wiki.mediawikis == {
        # "dnd4e": "",
        "dnd5e": "",
        # "darksun": "",
-        # "dragonlance": "",
-        # "eberron": "",
+        "dragonlance": "",
+        "eberron": "",
        # "exandria": "",
-        # "greyhawk": "",
+        "greyhawk": "",
        "forgottenrealms": "",
        # "planescape": "",
        # "ravenloft": "",
@@ -41,16 +43,16 @@ def test_multiwiki_set_args():
def test_rename_duplicates():
    wiki = MultiWiki()
    source = wiki.source
-    wikis = wiki.wikis
-    for wiki in wikis.keys():
-        wikis[wiki] = MWDumpLoader(
+    mediawikis = wiki.mediawikis
+    for wiki in mediawikis.keys():
+        mediawikis[wiki] = MWDumpLoader(
            encoding="utf-8",
            file_path=f"{source}/{wiki}_pages_current.xml",
            namespaces=[0],
            skip_redirects=True,
            stop_on_error=False,
        )
-    loader_all = MergedDataLoader(loaders=wikis.values())
+    loader_all = MergedDataLoader(loaders=mediawikis.values())
    documents = loader_all.load()

    doc_counter = Counter([doc.metadata["source"] for doc in documents])
@@ -68,8 +70,9 @@ def test_rename_duplicates():
@pytest.mark.embed
def test_create_vector_db():
    create_vector_db(
-        "sentence-transformers/all-mpnet-base-v2", "./sources", {"dnd5e": ""}
+        "test_data", "sentence-transformers/all-mpnet-base-v2", "./sources", {"dnd5e": ""}
    )
+    shutil.rmtree("test_data")


@pytest.mark.ollama
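Because `test_create_vector_db` now writes its Chroma store to a throwaway `test_data` directory and deletes it afterwards, the embedding test no longer touches the real `./data` store. It is gated behind the `embed` marker, so it must be selected explicitly; a programmatic equivalent of `pytest -m embed` (assuming the marker is registered in the project's pytest config):

```python
import pytest

# Run only the tests marked "embed" in this file.
pytest.main(["-m", "embed", "test/test.py"])
```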
