Skip to content

Commit

Permalink
Merge pull request #30 from thehunmonkgroup/chatbot-example-improvements
Browse files Browse the repository at this point in the history
Chatbot example improvements
  • Loading branch information
Bam4d authored Dec 20, 2023
2 parents 8cfcaa1 + f970fe0 commit ef64c3e
Showing 1 changed file with 34 additions and 20 deletions.
54 changes: 34 additions & 20 deletions examples/chatbot_with_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,37 +19,45 @@
DEFAULT_MODEL = "mistral-small"
DEFAULT_TEMPERATURE = 0.7
LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
COMMAND_LIST = [
"/new",
"/help",
"/model",
"/system",
"/temperature",
"/config",
"/quit",
"/exit",
]
# A dictionary of all commands and their arguments, used for tab completion.
COMMAND_LIST = {
"/new": {},
"/help": {},
"/model": {model: {} for model in MODEL_LIST}, # Nested completions for models
"/system": {},
"/temperature": {},
"/config": {},
"/quit": {},
"/exit": {},
}

LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"

logger = logging.getLogger("chatbot")


def find_completions(command_dict, parts):
if not parts:
return command_dict.keys()
if parts[0] in command_dict:
return find_completions(command_dict[parts[0]], parts[1:])
else:
return [cmd for cmd in command_dict if cmd.startswith(parts[0])]


def completer(text, state):
buffer = readline.get_line_buffer()
if not buffer.startswith(text):
return None
line_parts = buffer.lstrip().split(" ")
options = find_completions(COMMAND_LIST, line_parts[:-1])

options = [command for command in COMMAND_LIST if command.startswith(text)]
if state < len(options):
return options[state]
else:
try:
return [option for option in options if option.startswith(line_parts[-1])][state]
except IndexError:
return None


readline.set_completer(completer)
# Remove all delimiters to ensure completion only at the beginning of the line
readline.set_completer_delims("")
readline.set_completer_delims(" ")
# Enable tab completion
readline.parse_and_bind("tab: complete")

Expand All @@ -58,6 +66,8 @@ class ChatBot:
def __init__(
self, api_key, model, system_message=None, temperature=DEFAULT_TEMPERATURE
):
if not api_key:
raise ValueError("An API key must be provided to use the Mistral API.")
self.client = MistralClient(api_key=api_key)
self.model = model
self.temperature = temperature
Expand Down Expand Up @@ -253,5 +263,9 @@ def exit(self):
f"system message: {args.system_message}"
)

bot = ChatBot(args.api_key, args.model, args.system_message, args.temperature)
bot.start()
try:
bot = ChatBot(args.api_key, args.model, args.system_message, args.temperature)
bot.start()
except Exception as e:
logger.error(e)
sys.exit(1)

0 comments on commit ef64c3e

Please sign in to comment.