Skip to content

Commit

Permalink
Simplify constructor and add docs
Browse files Browse the repository at this point in the history
  • Loading branch information
keenanpepper committed Jan 26, 2025
1 parent beefd67 commit 0fee32a
Show file tree
Hide file tree
Showing 3 changed files with 291 additions and 18 deletions.
266 changes: 266 additions & 0 deletions docs/docs/integrations/chat/goodfire.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
{
"cells": [
{
"cell_type": "raw",
"id": "afaf8039",
"metadata": {
"vscode": {
"languageId": "raw"
}

Check failure on line 9 in docs/docs/integrations/chat/goodfire.ipynb

View workflow job for this annotation

GitHub Actions / cd . / make lint #3.9

Ruff (I001)

docs/docs/integrations/chat/goodfire.ipynb:1:1: I001 Import block is un-sorted or un-formatted

Check failure on line 9 in docs/docs/integrations/chat/goodfire.ipynb

View workflow job for this annotation

GitHub Actions / cd . / make lint #3.12

Ruff (I001)

docs/docs/integrations/chat/goodfire.ipynb:1:1: I001 Import block is un-sorted or un-formatted
},
"source": [
"---\n",
"sidebar_label: Goodfire\n",
"---"
]
},
{
"cell_type": "markdown",
"id": "e49f1e0d",
"metadata": {},
"source": [
"# Goodfire\n",
"\n",
"Goodfire is an AI inference platform to run certain Llama models with SAE feature steering. See the [Goodfire docs](https://docs.goodfire.ai/) for more information."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "433e8d2b-9519-4b49-b2c4-7ab65b046c94",
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"if \"GOODFIRE_API_KEY\" not in os.environ:\n",
" os.environ[\"GOODFIRE_API_KEY\"] = getpass.getpass(\"Enter your Goodfire API key: \")"
]
},
{
"cell_type": "markdown",
"id": "a38cde65-254d-4219-a441-068766c0d4b5",
"metadata": {},
"source": [
"## Instantiation\n",
"\n",
"Now we can instantiate our model object and generate chat completions:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "cb09c344-1836-4e0c-acf8-11d13ac1dbae",
"metadata": {},
"outputs": [
{
"ename": "ValueError",
"evalue": "model must be a Goodfire variant, got <class 'str'>",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[2], line 15\u001b[0m\n\u001b[1;32m 12\u001b[0m enthusiasm_variant \u001b[38;5;241m=\u001b[39m goodfire\u001b[38;5;241m.\u001b[39mVariant(MODEL_NAME)\n\u001b[1;32m 13\u001b[0m enthusiasm_variant\u001b[38;5;241m.\u001b[39mset(enthusiasm_feature, \u001b[38;5;241m0.3\u001b[39m)\n\u001b[0;32m---> 15\u001b[0m llm \u001b[38;5;241m=\u001b[39m \u001b[43mChatGoodfire\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 16\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mMODEL_NAME\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 17\u001b[0m \u001b[43m \u001b[49m\u001b[43mvariant\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43menthusiasm_variant\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 18\u001b[0m \u001b[43m \u001b[49m\u001b[43mtemperature\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.6\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 19\u001b[0m \u001b[43m \u001b[49m\u001b[43mseed\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m42\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 20\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# other params...\u001b[39;49;00m\n\u001b[1;32m 21\u001b[0m \u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/langchain/libs/community/langchain_community/chat_models/goodfire.py:80\u001b[0m, in \u001b[0;36mChatGoodfire.__init__\u001b[0;34m(self, model, goodfire_api_key, **kwargs)\u001b[0m\n\u001b[1;32m 74\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mImportError\u001b[39;00m(\n\u001b[1;32m 75\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCould not import goodfire python package. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 76\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPlease install it with `pip install goodfire`.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 77\u001b[0m ) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01me\u001b[39;00m\n\u001b[1;32m 79\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(model, goodfire\u001b[38;5;241m.\u001b[39mVariant):\n\u001b[0;32m---> 80\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel must be a Goodfire variant, got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(model)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 82\u001b[0m \u001b[38;5;66;03m# Include model in kwargs for parent initialization\u001b[39;00m\n\u001b[1;32m 83\u001b[0m kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m model\n",
"\u001b[0;31mValueError\u001b[0m: model must be a Goodfire variant, got <class 'str'>"
]
}
],
"source": [
"from langchain_community.chat_models import ChatGoodfire\n",
"import goodfire\n",
"\n",
"MODEL_NAME = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
"\n",
"goodfire_client = goodfire.Client(api_key=os.environ[\"GOODFIRE_API_KEY\"])\n",
"\n",
"base_variant = goodfire.Variant(MODEL_NAME)\n",
"\n",
"enthusiasm_feature = goodfire_client.features.lookup([55543], base_variant)[55543]\n",
"\n",
"enthusiasm_variant = goodfire.Variant(MODEL_NAME)\n",
"enthusiasm_variant.set(enthusiasm_feature, 0.3)\n",
"\n",
"llm = ChatGoodfire(\n",
" model=enthusiasm_variant,\n",
" temperature=0.6,\n",
" seed=42,\n",
" # other params...\n",
")"
]
},
{
"cell_type": "markdown",
"id": "2b4f3e15",
"metadata": {},
"source": [
"## Invocation"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "62e0dbc3",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='J\\'ADORE LA PROGRAMMATION! \\n\\n(or in a more casual tone: J\\'ADORE LE CODAGE!)\\n\\nNote: \"J\\'adore\" is a stronger way to say \"I love\" in French, it\\'s more like \"I\\'m crazy about\" or \"I\\'m absolutely passionate about\". If you want to use a more literal translation, you can say: \"J\\'aime la programmation\" which means \"I like programming\".', additional_kwargs={}, response_metadata={}, id='run-d91dd50b-1d6a-4c04-a78c-b1b922c1fc92-0')"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages = [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates English to French. Translate the user sentence.\",\n",
" ),\n",
" (\"human\", \"I love programming.\"),\n",
"]\n",
"ai_msg = llm.invoke(messages)\n",
"ai_msg"
]
},
{
"cell_type": "markdown",
"id": "39f7d928",
"metadata": {},
"source": [
"Note: The variant can be overridden after instantiation by providing a new variant to the `model` parameter."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "ceac2cb6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\"J'adore la programmation.\", additional_kwargs={}, response_metadata={}, id='run-b646d8ed-74c3-40a2-8530-7f094060bf23-0')"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ai_msg = llm.invoke(messages, model=base_variant)\n",
"ai_msg"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "d86145b3-bfef-46e8-b227-4dda5c9c2705",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"J'ADORE LA PROGRAMMATION! \n",
"\n",
"(or in a more casual tone: J'ADORE LE CODAGE!)\n",
"\n",
"Note: \"J'adore\" is a stronger way to say \"I love\" in French, it's more like \"I'm crazy about\" or \"I'm absolutely passionate about\". If you want to use a more literal translation, you can say: \"J'aime la programmation\" which means \"I like programming\".\n"
]
}
],
"source": [
"print(ai_msg.content)"
]
},
{
"cell_type": "markdown",
"id": "18e2bfc0-7e78-4528-a73f-499ac150dca8",
"metadata": {},
"source": [
"## Chaining\n",
"\n",
"We can [chain](/docs/how_to/sequence/) our model with a prompt template like so:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e197d1d7-a070-4c96-9f8a-a0e86d046e0b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='Ich liebe das Programmieren.', additional_kwargs={}, response_metadata={}, id='run-f77167ac-e9a8-4fc0-9e43-5a4800290324-0')"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n",
" ),\n",
" (\"human\", \"{input}\"),\n",
" ]\n",
")\n",
"\n",
"chain = prompt | llm\n",
"chain.invoke(\n",
" {\n",
" \"input_language\": \"English\",\n",
" \"output_language\": \"German\",\n",
" \"input\": \"I love programming.\",\n",
" }\n",
")"
]
},
{
"cell_type": "markdown",
"id": "3a5bb5ca-c3ae-4a58-be67-2cd18574b9a3",
"metadata": {},
"source": [
"## API reference\n",
"\n",
"For detailed documentation of all ChatGoodfire features and configurations head to the API reference: https://python.langchain.com/api_reference/goodfire/chat_models/langchain_goodfire.chat_models.ChatGoodfire.html\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
21 changes: 9 additions & 12 deletions libs/community/langchain_community/chat_models/goodfire.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class ChatGoodfire(BaseChatModel):
goodfire_api_key: SecretStr = Field(default=SecretStr(""))
sync_client: Any = Field(default=None)
async_client: Any = Field(default=None)
variant: Any # Changed type hint since we can't import goodfire at module level
model: Any # Changed type hint since we can't import goodfire at module level

@property
def _llm_type(self) -> str:
Expand All @@ -57,19 +57,16 @@ def lc_secrets(self) -> Dict[str, str]:

def __init__(
self,
model: str, # Changed from SUPPORTED_MODELS since we can't import it
model: Any,
goodfire_api_key: Optional[str] = None,
variant: Optional[Any] = None,
**kwargs: Any,
):
"""Initialize the Goodfire chat model.
Args:
model: The model to use, must be one of the supported models.
model: The Goodfire variant to use.
goodfire_api_key: The API key to use. If None, will look for
GOODFIRE_API_KEY env var.
variant: Optional variant to use. If not provided, will be created
from the model parameter.
"""
try:
import goodfire
Expand All @@ -79,11 +76,11 @@ def __init__(
"Please install it with `pip install goodfire`."
) from e

# Create variant first
variant_instance = variant or goodfire.Variant(model)
if not isinstance(model, goodfire.Variant):
raise ValueError(f"model must be a Goodfire variant, got {type(model)}")

# Include variant in kwargs for parent initialization
kwargs["variant"] = variant_instance
# Include model in kwargs for parent initialization
kwargs["model"] = model

# Initialize parent class
super().__init__(**kwargs)
Expand Down Expand Up @@ -136,7 +133,7 @@ def _generate(
if "model" in kwargs:
model = kwargs.pop("model")
else:
model = self.variant
model = self.model

goodfire_response = self.sync_client.chat.completions.create(
messages=format_for_goodfire(messages),
Expand Down Expand Up @@ -167,7 +164,7 @@ async def _agenerate(
if "model" in kwargs:
model = kwargs.pop("model")
else:
model = self.variant
model = self.model

goodfire_response = await self.async_client.chat.completions.create(
messages=format_for_goodfire(messages),
Expand Down
22 changes: 16 additions & 6 deletions libs/community/tests/unit_tests/chat_models/test_goodfire.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Test Goodfire Chat API wrapper."""

import os
from typing import List
from typing import Any, List

import pytest
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
Expand All @@ -14,7 +14,16 @@

os.environ["GOODFIRE_API_KEY"] = "test_key"

VALID_MODEL: str = "meta-llama/Llama-3.3-70B-Instruct"

def get_valid_variant() -> Any:
    """Build the goodfire.Variant instance shared by these unit tests.

    Returns:
        A ``goodfire.Variant`` wrapping the Llama 3.3 70B Instruct model.

    Raises:
        ImportError: If the optional ``goodfire`` package is not installed.
    """
    try:
        import goodfire
    except ImportError as err:
        # Same guidance message the library uses elsewhere for this dependency.
        raise ImportError(
            "Could not import goodfire python package. "
            "Please install it with `pip install goodfire`."
        ) from err
    return goodfire.Variant("meta-llama/Llama-3.3-70B-Instruct")


@pytest.mark.requires("goodfire")
Expand All @@ -26,9 +35,10 @@ def test_goodfire_model_param() -> None:
"Could not import goodfire python package. "
"Please install it with `pip install goodfire`."
) from e
llm = ChatGoodfire(model=VALID_MODEL)
assert isinstance(llm.variant, goodfire.Variant)
assert llm.variant.base_model == VALID_MODEL
base_variant = get_valid_variant()
llm = ChatGoodfire(model=base_variant)
assert isinstance(llm.model, goodfire.Variant)
assert llm.model.base_model == base_variant.base_model


@pytest.mark.requires("goodfire")
Expand All @@ -41,7 +51,7 @@ def test_goodfire_initialization() -> None:
"Could not import goodfire python package. "
"Please install it with `pip install goodfire`."
) from e
llm = ChatGoodfire(model=VALID_MODEL, goodfire_api_key="test_key")
llm = ChatGoodfire(model=get_valid_variant(), goodfire_api_key="test_key")
assert llm.goodfire_api_key.get_secret_value() == "test_key"
assert isinstance(llm.sync_client, goodfire.Client)
assert isinstance(llm.async_client, goodfire.AsyncClient)
Expand Down

0 comments on commit 0fee32a

Please sign in to comment.