Skip to content

Commit

Permalink
Add Gemma 2 (3x sizes) and Llama 3.1 405B
Browse files Browse the repository at this point in the history
  • Loading branch information
scosman committed Sep 9, 2024
1 parent fbf155f commit 7d6466c
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 1 deletion.
85 changes: 85 additions & 0 deletions libs/core/kiln_ai/adapters/ml_model_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,21 @@ class ModelFamily(str, Enum):
llama = "llama"
phi = "phi"
mistral = "mistral"
gemma = "gemma"


class ModelName(str, Enum):
    """Stable string identifiers for every model Kiln knows about.

    Each member's value is the ID used to look the model up (e.g. the test
    file below selects models by these strings, see "gemma_2_3b" at L145's
    equivalent call site). Presumably these values are persisted in configs —
    do not rename or renumber existing members; only append new ones.
    """

    llama_3_1_8b = "llama_3_1_8b"
    llama_3_1_70b = "llama_3_1_70b"
    llama_3_1_405b = "llama_3_1_405b"
    gpt_4o_mini = "gpt_4o_mini"
    gpt_4o = "gpt_4o"
    phi_3_5 = "phi_3_5"
    mistral_large = "mistral_large"
    mistral_nemo = "mistral_nemo"
    # NOTE(review): this maps to Ollama's "gemma2:2b" in the provider options
    # below — the member name says 3b but the underlying model is 2B; confirm
    # whether the name or the provider tag is intended.
    gemma_2_3b = "gemma_2_3b"
    gemma_2_9b = "gemma_2_9b"
    gemma_2_27b = "gemma_2_27b"


class KilnModelProvider(BaseModel):
Expand Down Expand Up @@ -137,6 +142,34 @@ class KilnModel(BaseModel):
# ),
],
),
# Llama 3.1 405b
KilnModel(
family=ModelFamily.llama,
name=ModelName.llama_3_1_405b,
providers=[
# TODO: bring back when groq does: https://console.groq.com/docs/models
# KilnModelProvider(
# name=ModelProviderName.groq,
# provider_options={"model": "llama-3.1-405b-instruct-v1:0"},
# ),
KilnModelProvider(
name=ModelProviderName.amazon_bedrock,
provider_options={
"model": "meta.llama3-1-405b-instruct-v1:0",
"region_name": "us-west-2", # Llama 3.1 only in west-2
},
),
# TODO: enable once tests update to check if model is available
# KilnModelProvider(
# name=ModelProviderName.ollama,
# provider_options={"model": "llama3.1:405b"},
# ),
KilnModelProvider(
name=ModelProviderName.openrouter,
provider_options={"model": "meta-llama/llama-3.1-405b-instruct"},
),
],
),
# Mistral Nemo
KilnModel(
family=ModelFamily.mistral,
Expand Down Expand Up @@ -187,6 +220,58 @@ class KilnModel(BaseModel):
),
],
),
# Gemma 2 2b
KilnModel(
family=ModelFamily.gemma,
name=ModelName.gemma_2_3b,
supports_structured_output=False,
providers=[
KilnModelProvider(
name=ModelProviderName.ollama,
provider_options={
"model": "gemma2:2b",
},
),
],
),
# Gemma 2 9b
KilnModel(
family=ModelFamily.gemma,
name=ModelName.gemma_2_9b,
supports_structured_output=False,
providers=[
# TODO: enable once tests update to check if model is available
# KilnModelProvider(
# name=ModelProviderName.ollama,
# provider_options={
# "model": "gemma2:9b",
# },
# ),
KilnModelProvider(
name=ModelProviderName.openrouter,
provider_options={"model": "google/gemma-2-9b-it"},
),
],
),
# Gemma 2 27b
KilnModel(
family=ModelFamily.gemma,
name=ModelName.gemma_2_27b,
supports_structured_output=False,
providers=[
# TODO: enable once tests update to check if model is available
# KilnModelProvider(
# name=ModelProviderName.ollama,
# provider_options={
# "model": "gemma2:27b",
# },
# ),
KilnModelProvider(
name=ModelProviderName.openrouter,
provider_options={"model": "google/gemma-2-27b-it"},
),
],
),
]


Expand Down
17 changes: 16 additions & 1 deletion libs/core/kiln_ai/adapters/test_prompt_adaptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ async def test_ollama_phi(tmp_path):
await run_simple_test(tmp_path, "phi_3_5", "ollama")


@pytest.mark.ollama
async def test_ollama_gemma(tmp_path):
    """Smoke-test the gemma_2_3b model through the local Ollama provider."""
    # Skip (rather than fail) when no local Ollama server is reachable.
    server_up = await ollama_online()
    if not server_up:
        pytest.skip("Ollama API not running. Expect it running on localhost:11434")

    await run_simple_test(tmp_path, "gemma_2_3b", "ollama")


@pytest.mark.ollama
async def test_autoselect_provider(tmp_path):
# Check if Ollama API is running
Expand Down Expand Up @@ -109,7 +118,13 @@ def build_test_task(tmp_path: Path):
instruction="If the problem has anything other than addition, subtraction, multiplication, division, and brackets, you will not answer it. Reply instead with 'I'm just a basic calculator, I don't know how to do that'.",
)
r2.save_to_file()
assert len(task.requirements()) == 2
r3 = models.TaskRequirement(
parent=task,
name="Answer format",
instruction="The answer can contain any content about your reasoning, but at the end it should include the final answer in numerals in square brackets. For example if the answer is one hundred, the end of your response should be [100].",
)
r3.save_to_file()
assert len(task.requirements()) == 3
return task


Expand Down

0 comments on commit 7d6466c

Please sign in to comment.