Commit
Merge pull request #115 from Whitelisted1/master
Utilize model class & add ability to create conversation with specific LLM
Soulter authored Oct 16, 2023
2 parents f040e63 + cf5708d commit f2f8f6f
Showing 1 changed file with 82 additions and 26 deletions.
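In short: self.llms now holds model objects rather than bare name strings, and new_conversation() accepts a modelIndex so a conversation can be pinned to a specific LLM. A minimal usage sketch of the new surface (not part of the diff; the cookie setup is assumed to follow hugchat's existing login flow):

# Hypothetical usage of the API added in this PR.
from hugchat.hugchat import ChatBot

chatbot = ChatBot(cookies=cookies)           # cookies obtained via the usual login helper
print([llm.name for llm in chatbot.llms])    # llms are now model objects, not strings
conv = chatbot.new_conversation(modelIndex=1, switch_to=True)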
108 changes: 82 additions & 26 deletions src/hugchat/hugchat.py
@@ -18,25 +18,23 @@ def __init__(
         self,
         id: str = None,
         title: str = None,
-        model: str = None,
+        model = None,
         system_prompt: str = None,
         history: list = []
     ):
         '''
         Returns a conversation object
         '''

-        self.id = id
-        self.title = title
+        self.id: str = id
+        self.title: str = title
         self.model = model
-        self.system_prompt = system_prompt
-        self.history = history
+        self.system_prompt: str = system_prompt
+        self.history: list = history

     def __str__(self) -> str:
         return self.id


-# For future use -- is not used currently
 class model:
     def __init__(
         self,
@@ -58,18 +56,18 @@ def __init__(
         Returns a model object
         '''

-        self.id: str = id,
-        self.name: str = name,
-        self.displayName: str = displayName,
+        self.id: str = id
+        self.name: str = name
+        self.displayName: str = displayName

-        self.preprompt: str = preprompt,
-        self.promptExamples: list = promptExamples,
-        self.websiteUrl: str = websiteUrl,
-        self.description: str = description,
+        self.preprompt: str = preprompt
+        self.promptExamples: list = promptExamples
+        self.websiteUrl: str = websiteUrl
+        self.description: str = description

-        self.datasetName: str = datasetName,
-        self.datasetUrl: str = datasetUrl,
-        self.modelUrl: str = modelUrl,
+        self.datasetName: str = datasetName
+        self.datasetUrl: str = datasetUrl
+        self.modelUrl: str = modelUrl
         self.parameters: dict = parameters

     def __str__(self) -> str:
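The dropped trailing commas were a real bug, not just style: an assignment ending in a comma builds a one-element tuple, so every attribute of model was a tuple instead of its value. A quick demonstration:

# Trailing-comma pitfall the old code hit:
class Broken:
    def __init__(self, id: str):
        self.id: str = id,   # the comma makes this a tuple

print(Broken("some-id").id)   # ('some-id',) rather than 'some-id'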
@@ -124,10 +122,9 @@ def __init__(
         self.llms = self.get_remote_llms()

         if type(default_llm) == str:
-            if default_llm in self.llms:
-                self.active_model = default_llm
-            else:
-                raise Exception(f"Given model is not in llms list. LLM list: {self.llms}")
+            self.active_model = self.get_llm_from_name(default_llm)
+            if self.active_model is None:
+                raise Exception(f"Given model is not in llms list. LLM list: {[model.id for model in self.llms]}")
         else:
             self.active_model = self.llms[default_llm]

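After this change default_llm can be either an integer index into self.llms or a model name, resolved through the new get_llm_from_name() helper. A hedged sketch of both call styles (the model name below is only an illustrative example, not necessarily in the current remote list):

# Both forms should now work after this PR.
bot_a = ChatBot(cookies=cookies, default_llm=0)
bot_b = ChatBot(cookies=cookies, default_llm="OpenAssistant/oasst-sft-6-llama-30b")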
@@ -204,12 +201,20 @@ def accept_ethics_modal(self):
         return True


-    def new_conversation(self, system_prompt: str = "", switch_to: bool = False) -> str:
+    def new_conversation(self, modelIndex: int = None, system_prompt: str = "", switch_to: bool = False) -> str:
         '''
         Create a new conversation. Return the conversation object. You should change the conversation by calling change_conversation() after calling this method.
         '''
         err_count = 0

+        if modelIndex == None:
+            model = self.active_model
+        else:
+            if modelIndex < 0 or modelIndex >= len(self.llms):
+                raise IndexError("Out of range of llm index")
+
+            model = self.llms[modelIndex]
+
         # Accept the welcome modal when init.
         # 17/5/2023: This is not required anymore.
         # if not self.accepted_welcome_modal:
@@ -225,15 +230,15 @@ def new_conversation(self, system_prompt: str = "", switch_to: bool = False) ->
             try:
                 resp = self.session.post(
                     self.hf_base_url + "/chat/conversation",
-                    json={"model": self.active_model, "preprompt": system_prompt},
+                    json={"model": model.id, "preprompt": system_prompt if system_prompt != "" else model.preprompt},
                     headers=_header,
                     cookies = self.get_cookies()
                 )

                 logging.debug(resp.text)
                 cid = json.loads(resp.text)['conversationId']

-                c = conversation(id=cid, system_prompt=system_prompt, model=self.active_model)
+                c = conversation(id=cid, system_prompt=system_prompt, model=model)

                 self.conversation_list.append(c)
                 self.__not_summarize_cids.append(cid) # For the 1st chat, the conversation needs to be summarized.
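Note the preprompt fallback introduced here: when system_prompt is empty, the request now sends the selected model's own preprompt instead of an empty string. A short sketch (assumes the index is valid for the account's model list):

# With no system_prompt, model.preprompt is sent as the preprompt.
cid_default = chatbot.new_conversation(modelIndex=2)
cid_custom = chatbot.new_conversation(modelIndex=2, system_prompt="Answer tersely.")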
@@ -394,6 +399,10 @@ def switch_llm(self, index: int) -> bool:
     #         print(f"Switch LLM {llms[to]} failed. Please submit an issue to https://github.com/Soulter/hugging-chat-api")
     #         return False

+    def get_llm_from_name(self, name: str) -> Union[model, None]:
+        for model in self.llms:
+            if model.name == name:
+                return model

     # Gives information such as name, websiteUrl, description, displayName, parameters, etc.
     # We can use it in the future if we need to get information about models
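get_llm_from_name() returns None implicitly when nothing matches, which is exactly what the constructor check in the earlier hunk relies on. A small usage sketch (the model name is an example):

llm = chatbot.get_llm_from_name("meta-llama/Llama-2-70b-chat-hf")
if llm is not None:
    print(llm.displayName, llm.description)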
@@ -406,11 +415,57 @@ def get_remote_llms(self) -> list:

         if r.status_code != 200:
             raise Exception(f"Failed to get remote LLMs with status code: {r.status_code}")

         data = r.json()["nodes"][0]["data"]
         modelsIndices = data[data[0]["models"]]
+        model_list = []
+
+        return_data_from_index = lambda index: None if index == -1 else data[index]
+
+        for modelIndex in modelsIndices:
+            model_data = data[modelIndex]
+
+            m = model(
+                id = return_data_from_index(model_data["id"]),
+                name = return_data_from_index(model_data["name"]),
+                displayName = return_data_from_index(model_data["displayName"]),
+
+                preprompt = return_data_from_index(model_data["preprompt"]),
+                # promptExamples = return_data_from_index(model_data["promptExamples"]),
+                websiteUrl = return_data_from_index(model_data["websiteUrl"]),
+                description = return_data_from_index(model_data["description"]),
+
+                datasetName = return_data_from_index(model_data["datasetName"]),
+                datasetUrl = return_data_from_index(model_data["datasetUrl"]),
+                modelUrl = return_data_from_index(model_data["modelUrl"]),
+                # parameters = return_data_from_index(model_data["parameters"]),
+            )
+
+            prompt_list = return_data_from_index(model_data["promptExamples"])
+            if prompt_list is not None:
+                _promptExamples = [return_data_from_index(index) for index in prompt_list]
+                m.promptExamples = [{ "title": data[prompt["title"]], "prompt": data[prompt["prompt"]] } for prompt in _promptExamples]
+
+            indices_parameters_dict = return_data_from_index(model_data["parameters"])
+            out_parameters_dict = {}
+            for key in indices_parameters_dict.keys():
+                value = indices_parameters_dict[key]
+
+                if value == -1:
+                    out_parameters_dict[key] = None
+                    continue
+
+                if type(data[value]) == list:
+                    out_parameters_dict[key] = [data[index] for index in data[value]]
+                    continue
+
+                out_parameters_dict[key] = data[value]
+
+            m.parameters = out_parameters_dict
+
+            model_list.append(m)

-        return [data[data[index]["name"]] for index in modelsIndices]
+        return model_list

     def get_remote_conversations(self, replace_conversation_list=True):
         '''
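The parser above assumes HuggingChat's index-encoded payload (what appears to be SvelteKit's flattened serialization: one flat list where each field stores an index into the same list, with -1 meaning absent), and return_data_from_index performs the one-hop dereference. A toy payload in that shape, with invented values, shows the mechanics:

# Invented miniature of the index-based format the parser expects.
data = [
    {"models": 1},                          # root object: "models" points at index 1
    [2],                                    # list of model indices
    {"id": 3, "name": 3, "preprompt": -1},  # field values are indices; -1 = absent
    "example/model",                        # the string both id and name point at
]
deref = lambda index: None if index == -1 else data[index]
model_data = data[data[data[0]["models"]][0]]
print(deref(model_data["name"]))       # "example/model"
print(deref(model_data["preprompt"]))  # None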
@@ -598,6 +653,7 @@ def _stream_query(
                     raise ModelOverloadedError(
                         "Model is overloaded, please try again later or switch to another model."
                     )
+                logging.debug(resp.headers)
                 raise ChatError(f"Failed to parse response: {res}")
             if break_label:
                 break
