"""A simple streaming chatbot for the Mistral API with slash-command support."""

import argparse
import logging
import os
import readline
import sys

from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

# NOTE(review): the patch hunk only shows the tail of MODEL_LIST
# ("mistral-medium"); the first entries are reconstructed from the model
# family names — confirm against the full file.
MODEL_LIST = [
    "mistral-tiny",
    "mistral-small",
    "mistral-medium",
]
DEFAULT_MODEL = "mistral-small"
DEFAULT_TEMPERATURE = 0.7
# Defined exactly once (the patch had introduced a second, duplicate
# assignment of LOG_FORMAT above the pre-existing one).
LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
COMMAND_LIST = [
    "/new",
    "/help",
    "/model",
    "/system",
    "/temperature",
    "/config",
    "/quit",
    "/exit",
]

logger = logging.getLogger("chatbot")


def completer(text, state):
    """Readline completer offering slash commands, only at the start of a line.

    Args:
        text: the prefix readline is trying to complete.
        state: 0-based index of the candidate to return.

    Returns:
        The state-th matching command, or None when exhausted / mid-line.
    """
    buffer = readline.get_line_buffer()
    # Only complete while the whole buffer equals the prefix being completed,
    # i.e. the cursor is still on the first word of the line.
    if not buffer.startswith(text):
        return None

    options = [command for command in COMMAND_LIST if command.startswith(text)]
    return options[state] if state < len(options) else None


readline.set_completer(completer)
# Remove all delimiters to ensure completion only at the beginning of the line
readline.set_completer_delims("")
# Enable tab completion
readline.parse_and_bind("tab: complete")


class ChatBot:
    """Interactive REPL around MistralClient.chat_stream with slash commands."""

    def __init__(
        self, api_key, model, system_message=None, temperature=DEFAULT_TEMPERATURE
    ):
        self.client = MistralClient(api_key=api_key)
        self.model = model
        self.temperature = temperature
        self.system_message = system_message

    def opening_instructions(self):
        """Print the help text listing every supported command."""
        print(
            """
To chat: type your message and hit enter
To start a new chat: /new
To switch model: /model
To switch system message: /system
To switch temperature: /temperature
To see current config: /config
To exit: /exit, /quit, or hit CTRL+C
To see this help: /help
"""
        )

    def new_chat(self):
        """Reset the conversation, re-seeding the optional system message."""
        print("")
        print(
            f"Starting new chat with model: {self.model}, "
            f"temperature: {self.temperature}"
        )
        print("")
        self.messages = []
        if self.system_message:
            self.messages.append(
                ChatMessage(role="system", content=self.system_message)
            )

    def switch_model(self, user_input):
        """Handle /model: switch to a known model name or log an error."""
        model = self.get_arguments(user_input)
        if model in MODEL_LIST:
            self.model = model
            logger.info(f"Switching model: {model}")
        else:
            logger.error(f"Invalid model name: {model}")

    def switch_system_message(self, user_input):
        """Handle /system: set a new system message and restart the chat."""
        system_message = self.get_arguments(user_input)
        if system_message:
            self.system_message = system_message
            logger.info(f"Switching system message: {system_message}")
            # A new system message only takes effect at position 0, so the
            # conversation must be restarted.
            self.new_chat()
        else:
            logger.error(f"Invalid system message: {system_message}")

    def switch_temperature(self, user_input):
        """Handle /temperature: accept a float in [0, 1] or log an error."""
        temperature = self.get_arguments(user_input)
        try:
            temperature = float(temperature)
            # float() raises ValueError on garbage; re-raise for out-of-range
            # values so both cases share one error path.
            if temperature < 0 or temperature > 1:
                raise ValueError
            self.temperature = temperature
            logger.info(f"Switching temperature: {temperature}")
        except ValueError:
            logger.error(f"Invalid temperature: {temperature}")

    def show_config(self):
        """Handle /config: print the current model/temperature/system message."""
        print("")
        print(f"Current model: {self.model}")
        print(f"Current temperature: {self.temperature}")
        print(f"Current system message: {self.system_message}")
        print("")

    def collect_user_input(self):
        """Prompt the user and return the raw line they typed."""
        print("")
        return input("YOU: ")

    def run_inference(self, content):
        """Send `content` to the API, stream the reply, and record both turns.

        Appends the user message and (when non-empty) the streamed assistant
        reply to self.messages.
        """
        print("")
        print("MISTRAL:")
        print("")

        self.messages.append(ChatMessage(role="user", content=content))

        assistant_response = ""
        logger.debug(
            f"Running inference with model: {self.model}, temperature: {self.temperature}"
        )
        logger.debug(f"Sending messages: {self.messages}")
        for chunk in self.client.chat_stream(
            model=self.model, temperature=self.temperature, messages=self.messages
        ):
            response = chunk.choices[0].delta.content
            if response is not None:
                print(response, end="", flush=True)
                assistant_response += response

        print("", flush=True)
        if assistant_response:
            self.messages.append(
                ChatMessage(role="assistant", content=assistant_response)
            )
        logger.debug(f"Current messages: {self.messages}")

    def get_command(self, user_input):
        """Return the first whitespace-separated token, or "" for blank input."""
        tokens = user_input.split()
        # Guard blank input: hitting enter on an empty line must not crash
        # with IndexError.
        return tokens[0].strip() if tokens else ""

    def get_arguments(self, user_input):
        """Return everything after the command word, joined by single spaces."""
        # Slicing past the end of a list is safe, so no exception handling is
        # needed: an argument-less command simply yields "".
        return " ".join(user_input.split()[1:])

    def is_command(self, user_input):
        """True when the first token of `user_input` is a known slash command."""
        return self.get_command(user_input) in COMMAND_LIST

    def execute_command(self, user_input):
        """Dispatch a slash command to its handler."""
        command = self.get_command(user_input)
        if command in ["/exit", "/quit"]:
            self.exit()
        elif command == "/help":
            self.opening_instructions()
        elif command == "/new":
            self.new_chat()
        elif command == "/model":
            self.switch_model(user_input)
        elif command == "/system":
            self.switch_system_message(user_input)
        elif command == "/temperature":
            self.switch_temperature(user_input)
        elif command == "/config":
            self.show_config()

    def start(self):
        """Main REPL loop: read input, dispatch commands, otherwise chat."""
        self.opening_instructions()
        self.new_chat()
        while True:
            try:
                user_input = self.collect_user_input()
                if self.is_command(user_input):
                    self.execute_command(user_input)
                else:
                    self.run_inference(user_input)
            except KeyboardInterrupt:
                self.exit()

    def exit(self):
        # NOTE(review): the body of exit() is hunk context the patch omits;
        # reconstructed minimally — confirm against the full file.
        logger.debug("Exiting chatbot")
        sys.exit(0)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="A simple chatbot using the Mistral API"
    )
    parser.add_argument(
        "--api-key",
        default=os.environ.get("MISTRAL_API_KEY"),
        help="Mistral API key. Defaults to environment variable MISTRAL_API_KEY",
    )
    parser.add_argument(
        "-m",
        "--model",
        choices=MODEL_LIST,
        default=DEFAULT_MODEL,
        help="Model for chat inference. Choices are %(choices)s. Defaults to %(default)s",
    )
    parser.add_argument(
        "-s", "--system-message", help="Optional system message to prepend."
    )
    parser.add_argument(
        "-t",
        "--temperature",
        type=float,
        default=DEFAULT_TEMPERATURE,
        help="Optional temperature for chat inference. Defaults to %(default)s",
    )
    parser.add_argument(
        "-d", "--debug", action="store_true", help="Enable debug logging"
    )

    args = parser.parse_args()

    # NOTE(review): the logging bootstrap below is hunk context the patch
    # omits (only `ch.setFormatter(formatter)` onward is visible);
    # reconstructed minimally — confirm against the full file.
    logger.setLevel(logging.DEBUG if args.debug else logging.INFO)
    formatter = logging.Formatter(LOG_FORMAT)
    ch = logging.StreamHandler()
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    logger.debug(
        f"Starting chatbot with model: {args.model}, "
        f"temperature: {args.temperature}, "
        f"system message: {args.system_message}"
    )

    bot = ChatBot(args.api_key, args.model, args.system_message, args.temperature)
    bot.start()