-
Notifications
You must be signed in to change notification settings - Fork 16.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
beefd67
commit 0fee32a
Showing
3 changed files
with
291 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,266 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "raw", | ||
"id": "afaf8039", | ||
"metadata": { | ||
"vscode": { | ||
"languageId": "raw" | ||
} | ||
}, | ||
"source": [ | ||
"---\n", | ||
"sidebar_label: Goodfire\n", | ||
"---" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "e49f1e0d", | ||
"metadata": {}, | ||
"source": [ | ||
"# Goodfire\n", | ||
"\n", | ||
"Goodfire is an AI inference platform to run certain Llama models with SAE feature steering. See the [Goodfire docs](https://docs.goodfire.ai/) for more information." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "433e8d2b-9519-4b49-b2c4-7ab65b046c94", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import getpass\n", | ||
"import os\n", | ||
"\n", | ||
"if \"GOODFIRE_API_KEY\" not in os.environ:\n", | ||
" os.environ[\"GOODFIRE_API_KEY\"] = getpass.getpass(\"Enter your Goodfire API key: \")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "a38cde65-254d-4219-a441-068766c0d4b5", | ||
"metadata": {}, | ||
"source": [ | ||
"## Instantiation\n", | ||
"\n", | ||
"Now we can instantiate our model object and generate chat completions:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "cb09c344-1836-4e0c-acf8-11d13ac1dbae", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import goodfire\n", | ||
"\n", | ||
"from langchain_community.chat_models import ChatGoodfire\n", | ||
"MODEL_NAME = \"meta-llama/Llama-3.3-70B-Instruct\"\n", | ||
"\n", | ||
"goodfire_client = goodfire.Client(api_key=os.environ[\"GOODFIRE_API_KEY\"])\n", | ||
"\n", | ||
"base_variant = goodfire.Variant(MODEL_NAME)\n", | ||
"\n", | ||
"enthusiasm_feature = goodfire_client.features.lookup([55543], base_variant)[55543]\n", | ||
"\n", | ||
"enthusiasm_variant = goodfire.Variant(MODEL_NAME)\n", | ||
"enthusiasm_variant.set(enthusiasm_feature, 0.3)\n", | ||
"\n", | ||
"llm = ChatGoodfire(\n", | ||
" model=enthusiasm_variant,\n", | ||
" temperature=0.6,\n", | ||
" seed=42,\n", | ||
" # other params...\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "2b4f3e15", | ||
"metadata": {}, | ||
"source": [ | ||
"## Invocation" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 14, | ||
"id": "62e0dbc3", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"AIMessage(content='J\\'ADORE LA PROGRAMMATION! \\n\\n(or in a more casual tone: J\\'ADORE LE CODAGE!)\\n\\nNote: \"J\\'adore\" is a stronger way to say \"I love\" in French, it\\'s more like \"I\\'m crazy about\" or \"I\\'m absolutely passionate about\". If you want to use a more literal translation, you can say: \"J\\'aime la programmation\" which means \"I like programming\".', additional_kwargs={}, response_metadata={}, id='run-d91dd50b-1d6a-4c04-a78c-b1b922c1fc92-0')" | ||
] | ||
}, | ||
"execution_count": 14, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"messages = [\n", | ||
" (\n", | ||
" \"system\",\n", | ||
" \"You are a helpful assistant that translates English to French. Translate the user sentence.\",\n", | ||
" ),\n", | ||
" (\"human\", \"I love programming.\"),\n", | ||
"]\n", | ||
"ai_msg = llm.invoke(messages)\n", | ||
"ai_msg" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "39f7d928", | ||
"metadata": {}, | ||
"source": [ | ||
"Note: The variant can be overridden after instantiation by providing a new variant to the `model` parameter." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 17, | ||
"id": "ceac2cb6", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"AIMessage(content=\"J'adore la programmation.\", additional_kwargs={}, response_metadata={}, id='run-b646d8ed-74c3-40a2-8530-7f094060bf23-0')" | ||
] | ||
}, | ||
"execution_count": 17, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"ai_msg = llm.invoke(messages, model=base_variant)\n", | ||
"ai_msg" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 15, | ||
"id": "d86145b3-bfef-46e8-b227-4dda5c9c2705", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"J'ADORE LA PROGRAMMATION! \n", | ||
"\n", | ||
"(or in a more casual tone: J'ADORE LE CODAGE!)\n", | ||
"\n", | ||
"Note: \"J'adore\" is a stronger way to say \"I love\" in French, it's more like \"I'm crazy about\" or \"I'm absolutely passionate about\". If you want to use a more literal translation, you can say: \"J'aime la programmation\" which means \"I like programming\".\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"print(ai_msg.content)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "18e2bfc0-7e78-4528-a73f-499ac150dca8", | ||
"metadata": {}, | ||
"source": [ | ||
"## Chaining\n", | ||
"\n", | ||
"We can [chain](/docs/how_to/sequence/) our model with a prompt template like so:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"id": "e197d1d7-a070-4c96-9f8a-a0e86d046e0b", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"AIMessage(content='Ich liebe das Programmieren.', additional_kwargs={}, response_metadata={}, id='run-f77167ac-e9a8-4fc0-9e43-5a4800290324-0')" | ||
] | ||
}, | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"from langchain_core.prompts import ChatPromptTemplate\n", | ||
"\n", | ||
"prompt = ChatPromptTemplate.from_messages(\n", | ||
" [\n", | ||
" (\n", | ||
" \"system\",\n", | ||
" \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n", | ||
" ),\n", | ||
" (\"human\", \"{input}\"),\n", | ||
" ]\n", | ||
")\n", | ||
"\n", | ||
"chain = prompt | llm\n", | ||
"chain.invoke(\n", | ||
" {\n", | ||
" \"input_language\": \"English\",\n", | ||
" \"output_language\": \"German\",\n", | ||
" \"input\": \"I love programming.\",\n", | ||
" }\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "3a5bb5ca-c3ae-4a58-be67-2cd18574b9a3", | ||
"metadata": {}, | ||
"source": [ | ||
"## API reference\n", | ||
"\n", | ||
"For detailed documentation of all ChatGoodfire features and configurations head to the API reference: https://python.langchain.com/api_reference/goodfire/chat_models/langchain_goodfire.chat_models.ChatGoodfire.html\n" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.8" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters