Skip to content

Commit

Permalink
Few improvements on chatbot_with_streaming example
Browse files Browse the repository at this point in the history
Colored user and assistant prefixes

Black formatting on script

Add temperature parameter
  • Loading branch information
lionelchg committed Dec 14, 2023
1 parent be8bfb5 commit 7bce47b
Showing 1 changed file with 42 additions and 23 deletions.
65 changes: 42 additions & 23 deletions examples/chatbot_with_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,27 @@


class ChatBot:
def __init__(self, api_key, model, system_message=None):
def __init__(self, api_key, model, temperature, system_message=None):
self.client = MistralClient(api_key=api_key)
self.model = model
self.temperature = temperature
self.system_message = system_message

def opening_instructions(self):
print("""
print(
"""
To chat: type your message and hit enter
To start a new chat: type /new
To exit: type /exit, /quit, or hit CTRL+C
""")
"""
)

def new_chat(self):
self.messages = []
if self.system_message:
self.messages.append(ChatMessage(role="system", content=self.system_message))
self.messages.append(
ChatMessage(role="system", content=self.system_message)
)

def check_exit(self, content):
if content.lower().strip() in ["/exit", "/quit"]:
Expand Down Expand Up @@ -67,23 +72,20 @@ def run_inference(self, content):
print("", flush=True)

if assistant_response:
self.messages.append(ChatMessage(role="assistant", content=assistant_response))
self.messages.append(
ChatMessage(role="assistant", content=assistant_response)
)
logger.debug(f"Current messages: {self.messages}")

def start(self):

self.opening_instructions()
self.new_chat()

while True:
try:
print("")
content = input("YOU: ")
content = input(f"\033[38;2;50;168;82mUser: \033[0m")
self.check_exit(content)
if not self.check_new_chat(content):
print("")
print("MISTRAL:")
print("")
print(f"\033[38;2;253;112;0m{self.model}: \033[0m", end="")
self.run_inference(content)

except KeyboardInterrupt:
Expand All @@ -95,16 +97,33 @@ def exit(self):


if __name__ == "__main__":

parser = argparse.ArgumentParser(description="A simple chatbot using the Mistral API")
parser.add_argument("--api-key", default=os.environ.get("MISTRAL_API_KEY"),
help="Mistral API key. Defaults to environment variable MISTRAL_API_KEY")
parser.add_argument("-m", "--model", choices=MODEL_LIST,
default=DEFAULT_MODEL,
help="Model for chat inference. Choices are %(choices)s. Defaults to %(default)s")
parser.add_argument("-s", "--system-message",
help="Optional system message to prepend.")
parser.add_argument("-d", "--debug", action="store_true", help="Enable debug logging")
parser = argparse.ArgumentParser(
description="A simple chatbot using the Mistral API"
)
parser.add_argument(
"--api-key",
default=os.environ.get("MISTRAL_API_KEY"),
help="Mistral API key. Defaults to environment variable MISTRAL_API_KEY",
)
parser.add_argument(
"-m",
"--model",
choices=MODEL_LIST,
default=DEFAULT_MODEL,
help="Model for chat inference. Choices are %(choices)s. Defaults to %(default)s",
)
parser.add_argument(
"-t",
"--temperature",
default=0.7,
help="Temperature of the model. Defaults to 0.7.",
)
parser.add_argument(
"-s", "--system-message", help="Optional system message to prepend."
)
parser.add_argument(
"-d", "--debug", action="store_true", help="Enable debug logging"
)

args = parser.parse_args()

Expand All @@ -121,5 +140,5 @@ def exit(self):

logger.debug(f"Starting chatbot with model: {args.model}")

bot = ChatBot(args.api_key, args.model, args.system_message)
bot = ChatBot(args.api_key, args.model, args.temperature, args.system_message)
bot.start()

0 comments on commit 7bce47b

Please sign in to comment.