Skip to content

Commit

Permalink
docs: update mistral models in zoo
Browse files Browse the repository at this point in the history
  • Loading branch information
zhudotexe committed Oct 4, 2024
1 parent 746b67f commit b5e4556
Showing 1 changed file with 7 additions and 17 deletions.
24 changes: 7 additions & 17 deletions examples/4_engines_zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,13 @@
from kani.engines.huggingface.llama2 import LlamaEngine
engine = LlamaEngine(model_id="meta-llama/Llama-2-7b-chat-hf", use_auth_token=True) # log in with huggingface-cli

# ---- Mixtral-8x22B (Hugging Face) ----
# ---- Mistral Small/Large (Hugging Face) ----
from kani.engines.huggingface import HuggingEngine
from kani.prompts.impl.mistral import MISTRAL_V3_PIPELINE, MixtralFunctionCallingAdapter
model = HuggingEngine(
model_id="mistralai/Mixtral-8x22B-Instruct-v0.1",
prompt_pipeline=MISTRAL_V3_PIPELINE,
model_load_kwargs={"device_map": "auto", "torch_dtype": torch.bfloat16},
)

# to enable function calling:
# NOTE: as of May 2024, the huggingface implementation of Mixtral-8x22B function calling is broken:
# https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1/discussions/27
# this comment will be removed once it is fixed - until then, you should use another backend
engine = MixtralFunctionCallingAdapter(model)
from kani.prompts.impl.mistral import MISTRAL_V3_PIPELINE, MistralFunctionCallingAdapter
# small (22B): mistralai/Mistral-Small-Instruct-2409
# large (123B): mistralai/Mistral-Large-Instruct-2407
model = HuggingEngine(model_id="mistralai/Mistral-Small-Instruct-2409", prompt_pipeline=MISTRAL_V3_PIPELINE)
engine = MistralFunctionCallingAdapter(model)

# ---- Mistral-7B (Hugging Face) ----
# v0.3 (supports function calling):
Expand All @@ -71,11 +64,8 @@
# Also use the MISTRAL_V1_PIPELINE for Mixtral-8x7B (i.e. mistralai/Mixtral-8x7B-Instruct-v0.1).

# ---- Command R (Hugging Face) ----
import torch
from kani.engines.huggingface.cohere import CommandREngine
engine = CommandREngine(
model_id="CohereForAI/c4ai-command-r-v01", model_load_kwargs={"device_map": "auto", "torch_dtype": torch.float16}
)
engine = CommandREngine(model_id="CohereForAI/c4ai-command-r-v01")

# ---- Gemma (Hugging Face) ----
from kani.engines.huggingface import HuggingEngine
Expand Down

0 comments on commit b5e4556

Please sign in to comment.