Skip to content

Commit

Permalink
use nested attrs instead of dict items
Browse files (browse the repository at this point in the history)
  • Loading branch information
tylertitsworth committed Dec 31, 2023
1 parent 66c8197 commit f53dc38
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 41 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ for dump, _ in wiki.mediawikis.items():
# Modify the source metadata by accounting for duplicates (<name>_n)
# And add the mediawiki title (<name>_n - <wikiname>)
merged_documents.extend(
Document(doc.page_content, {"source": doc.metadata["source"] + f" - {wiki}"})
Document(doc.page_content, {"source": doc.metadata["source"] + f" - {dump}"})
for doc in rename_duplicates(loader.load())
)
### Insert a new loader
Expand Down
59 changes: 32 additions & 27 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections import namedtuple
from chainlit.input_widget import Slider, TextInput
from chainlit.playground.config import add_llm_provider
from collections import namedtuple
from langchain.cache import SQLiteCache
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
Expand All @@ -19,6 +19,7 @@
import yaml

import torch

if not torch.cuda.is_available():
torch.set_num_threads(torch.get_num_threads() * 2)

Expand All @@ -34,23 +35,30 @@ def __init__(self):
except yaml.YAMLError as e:
print(f"Error reading YAML file: {e}")
exit(1)
self.convert_struct(**data)

for key, val in data.items():
if key == "mediawikis":
self.mediawikis = {wiki: "" for wiki in data["mediawikis"]}
def __getattr__(self, attr):
return self[attr]

def convert_struct(self, **kwargs):
for key, value in kwargs.items():
if isinstance(value, dict):
self.__dict__[key] = self.convert_struct(**value)
else:
setattr(self, key, val)
self.__dict__[key] = value
return self

def set_args(self, args):
self.args = args

def set_chat_settings(self, settings):
# Update global wiki, not local self
global wiki
if not settings:
wiki.inputs = wiki.settings
else:
wiki.inputs = settings
if isinstance(settings, dict):
for key, val in settings.items():
setattr(wiki, key, val)
if settings:
wiki.settings = settings


### Globals
Expand All @@ -77,7 +85,7 @@ def create_vector_db():
Document = namedtuple("Document", ["page_content", "metadata"])
merged_documents = []

for dump, _ in wiki.mediawikis.items():
for dump in wiki.mediawikis:
# https://python.langchain.com/docs/integrations/document_loaders/mediawikidump
loader = MWDumpLoader(
encoding="utf-8",
Expand All @@ -91,7 +99,9 @@ def create_vector_db():
# Modify the source metadata by accounting for duplicates (<name>_n)
# And add the mediawiki title (<name>_n - <wikiname>)
merged_documents.extend(
Document(doc.page_content, {"source": doc.metadata["source"] + f" - {dump}"})
Document(
doc.page_content, {"source": doc.metadata["source"] + f" - {dump}"}
)
for doc in rename_duplicates(loader.load())
)
print(f"Embedding {len(merged_documents)} Pages, this may take a while.")
Expand Down Expand Up @@ -128,26 +138,21 @@ def create_chain():
# https://python.langchain.com/docs/integrations/llms/llm_caching
set_llm_cache(SQLiteCache(database_path="memory/cache.db"))
# https://python.langchain.com/docs/integrations/llms/ollama
# Wiki loses its global scope when used as a parameter in ChatOllama()
# Unsure as to why
inputs = wiki.inputs
model = ChatOllama(
cache=True,
callback_manager=callback_manager,
model=wiki.model,
repeat_penalty=inputs["repeat_penalty"],
temperature=inputs["temperature"],
top_k=inputs["top_k"],
top_p=inputs["top_p"],
repeat_penalty=wiki.repeat_penalty,
temperature=wiki.temperature,
top_k=wiki.top_k,
top_p=wiki.top_p,
)
# https://api.python.langchain.com/en/latest/chains/langchain.chains.conversational_retrieval.base.ConversationalRetrievalChain.html
chain = ConversationalRetrievalChain.from_llm(
chain_type="stuff",
llm=model,
memory=memory,
retriever=vectordb.as_retriever(
search_kwargs={"k": int(wiki.inputs["num_sources"])}
),
retriever=vectordb.as_retriever(search_kwargs={"k": int(wiki.num_sources)}),
return_source_documents=True,
)

Expand All @@ -164,13 +169,13 @@ async def on_chat_start():
TextInput(
id="num_sources",
label="# of Sources",
initial=str(wiki.inputs["num_sources"]),
initial=str(wiki.num_sources),
description="Number of sources returned based on their similarity source. The same source can be returned more than once. (Default: 4)",
),
Slider(
id="temperature",
label="Temperature",
initial=wiki.inputs["temperature"],
initial=wiki.temperature,
min=0,
max=1,
step=0.1,
Expand All @@ -179,7 +184,7 @@ async def on_chat_start():
Slider(
id="repeat_penalty",
label="Repeat Penalty",
initial=wiki.inputs["repeat_penalty"],
initial=wiki.repeat_penalty,
min=0.5,
max=2.5,
step=0.1,
Expand All @@ -188,7 +193,7 @@ async def on_chat_start():
Slider(
id="top_k",
label="Top K",
initial=wiki.inputs["top_k"],
initial=wiki.top_k,
min=0,
max=100,
step=1,
Expand All @@ -197,7 +202,7 @@ async def on_chat_start():
Slider(
id="top_p",
label="Top P",
initial=wiki.inputs["top_p"],
initial=wiki.top_p,
min=0,
max=1,
step=0.1,
Expand All @@ -214,7 +219,7 @@ async def on_chat_start():
is_chat=True,
# Not enough context to LangchainGenericProvider
# https://github.com/Chainlit/chainlit/blob/main/backend/chainlit/playground/providers/langchain.py#L27
inputs=wiki.inputs,
inputs=wiki.settings,
)
)
await cl.ChatSettings(inputs).send()
Expand Down
26 changes: 13 additions & 13 deletions test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@ def test_multiwiki():
assert wiki.model == "volo"
assert wiki.question == "How many eyestalks does a Beholder have?"
assert wiki.source == "./sources"
assert wiki.mediawikis == {
# "dnd4e": "",
"dnd5e": "",
# "darksun": "",
"dragonlance": "",
"eberron": "",
# "exandria": "",
"greyhawk": "",
"forgottenrealms": "",
# "planescape": "",
# "ravenloft": "",
# "spelljammer": "",
}
assert wiki.mediawikis == [
# "dnd4e",
"dnd5e",
# "darksun",
"dragonlance",
"eberron",
# "exandria",
"forgottenrealms",
"greyhawk",
# "planescape",
# "ravenloft",
# "spelljammer",
]


def test_multiwiki_set_args():
Expand Down

0 comments on commit f53dc38

Please sign in to comment.