Skip to content

Commit

Permalink
Bert demo ci (#556)
Browse files Browse the repository at this point in the history
* revised demo testing to check all demos

* separated demos

* changed demo test order

* rearranged test order

* updated attribution patching to run different code in GitHub

* rearranged tests

* updated header

* updated grokking demo

* updated bert for testing

* updated bert demo

* ran cells

* removed github check

* removed cells to skip

* ignored output of loading cells

* removed other tests
  • Loading branch information
bryce13950 authored Apr 26, 2024
1 parent 6cd64d5 commit 1139caf
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 46 deletions.
2 changes: 1 addition & 1 deletion demos/Attribution_Patching_Demo.ipynb

Large diffs are not rendered by default.

77 changes: 53 additions & 24 deletions demos/BERT.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,45 +29,70 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running as a Jupyter notebook - intended for development only!\n"
"Running as a Jupyter notebook - intended for development only!\n",
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_39188/4022418010.py:26: DeprecationWarning:\n",
"\n",
"`magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
"\n",
"/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_39188/4022418010.py:27: DeprecationWarning:\n",
"\n",
"`magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
"\n"
]
}
],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"import os\n",
"\n",
"# Janky code to do different setup when run in a Colab notebook vs VSCode\n",
"DEVELOPMENT_MODE = False\n",
"IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n",
"try:\n",
" import google.colab\n",
"\n",
" IN_COLAB = True\n",
" print(\"Running as a Colab notebook\")\n",
" %pip install git+https://github.com/neelnanda-io/TransformerLens.git\n",
" %pip install circuitsvis\n",
" \n",
"\n",
" # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n",
" # # Install another version of node that makes PySvelte work way faster\n",
" # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n",
" # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n",
"except:\n",
" IN_COLAB = False\n",
"\n",
"if not IN_GITHUB and not IN_COLAB:\n",
" print(\"Running as a Jupyter notebook - intended for development only!\")\n",
" from IPython import get_ipython\n",
"\n",
" ipython = get_ipython()\n",
" # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n",
" ipython.magic(\"load_ext autoreload\")\n",
" ipython.magic(\"autoreload 2\")"
" ipython.magic(\"autoreload 2\")\n",
"\n",
"if IN_COLAB:\n",
" %pip install transformer_lens\n",
" %pip install circuitsvis"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [
{
Expand All @@ -81,6 +106,7 @@
"source": [
"# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n",
"import plotly.io as pio\n",
"\n",
"if IN_COLAB or not DEVELOPMENT_MODE:\n",
" pio.renderers.default = \"colab\"\n",
"else:\n",
Expand All @@ -90,40 +116,41 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div id=\"circuits-vis-3dab7238-6dd6\" style=\"margin: 15px 0;\"/>\n",
"<div id=\"circuits-vis-8c91db10-74f4\" style=\"margin: 15px 0;\"/>\n",
" <script crossorigin type=\"module\">\n",
" import { render, Hello } from \"https://unpkg.com/circuitsvis@1.39.1/dist/cdn/esm.js\";\n",
" import { render, Hello } from \"https://unpkg.com/circuitsvis@1.43.2/dist/cdn/esm.js\";\n",
" render(\n",
" \"circuits-vis-3dab7238-6dd6\",\n",
" \"circuits-vis-8c91db10-74f4\",\n",
" Hello,\n",
" {\"name\": \"Neel\"}\n",
" )\n",
" </script>"
],
"text/plain": [
"<circuitsvis.utils.render.RenderedHTML at 0x1090aa4d0>"
"<circuitsvis.utils.render.RenderedHTML at 0x13a9760d0>"
]
},
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import circuitsvis as cv\n",
"\n",
"# Testing that the library works\n",
"cv.examples.hello(\"Neel\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -137,16 +164,16 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch.autograd.grad_mode.set_grad_enabled at 0x104e56b60>"
"<torch.autograd.grad_mode.set_grad_enabled at 0x2a285a790>"
]
},
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -167,26 +194,28 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:root:HookedEncoder is still in beta. Please be aware that model preprocessing (e.g. LayerNorm folding) is not yet supported and backward compatibility is not guaranteed.\n"
"WARNING:root:Support for BERT in TransformerLens is currently experimental, until such a time when it has feature parity with HookedTransformer and has been tested on real research tasks. Until then, backward compatibility is not guaranteed. Please see the docs for information on the limitations of the current implementation.\n",
"If using BERT for interpretability research, keep in mind that BERT has some significant architectural differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning that the last LayerNorm in a block cannot be folded.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Moving model to device: cpu\n",
"Moving model to device: mps\n",
"Loaded pretrained model bert-base-cased into HookedTransformer\n"
]
}
],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"bert = HookedEncoder.from_pretrained(\"bert-base-cased\")\n",
"tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")"
]
Expand All @@ -201,7 +230,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -213,7 +242,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 12,
"metadata": {},
"outputs": [
{
Expand All @@ -230,7 +259,7 @@
"prediction = tokenizer.decode(logprobs.argmax(dim=-1).item())\n",
"\n",
"print(f\"Prompt: {prompt}\")\n",
"print(f\"Prediction: \\\"{prediction}\\\"\")"
"print(f'Prediction: \"{prediction}\"')"
]
},
{
Expand Down Expand Up @@ -258,7 +287,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
"version": "3.11.8"
},
"orig_nbformat": 4
},
Expand Down
44 changes: 26 additions & 18 deletions demos/Grokking_Demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,14 @@
],
"source": [
"# Janky code to do different setup when run in a Colab notebook vs VSCode\n",
"import os\n",
"\n",
"DEVELOPMENT_MODE = True\n",
"IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n",
"try:\n",
" import google.colab\n",
" IN_COLAB = True\n",
" print(\"Running as a Colab notebook\")\n",
" %pip install transformer-lens\n",
" %pip install circuitsvis\n",
" \n",
" # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n",
" # # Install another version of node that makes PySvelte work way faster\n",
Expand All @@ -73,7 +74,11 @@
" ipython = get_ipython()\n",
" # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n",
" ipython.magic(\"load_ext autoreload\")\n",
" ipython.magic(\"autoreload 2\")"
" ipython.magic(\"autoreload 2\")\n",
" \n",
"if IN_COLAB or IN_GITHUB:\n",
" %pip install transformer_lens\n",
" %pip install circuitsvis"
]
},
{
Expand Down Expand Up @@ -154,7 +159,10 @@
" HookedRootModule,\n",
" HookPoint,\n",
") # Hooking utilities\n",
"from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache"
"from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache\n",
"\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
]
},
{
Expand Down Expand Up @@ -281,7 +289,7 @@
}
],
"source": [
"dataset = torch.stack([a_vector, b_vector, equals_vector], dim=1).cuda()\n",
"dataset = torch.stack([a_vector, b_vector, equals_vector], dim=1).to(device)\n",
"print(dataset[:5])\n",
"print(dataset.shape)"
]
Expand Down Expand Up @@ -386,7 +394,7 @@
" d_vocab_out=p,\n",
" n_ctx=3,\n",
" init_weights=True,\n",
" device=\"cuda\",\n",
" device=device,\n",
" seed = 999,\n",
")"
]
Expand Down Expand Up @@ -1645,7 +1653,7 @@
" fourier_basis_names.append(f\"Sin {freq}\")\n",
" fourier_basis.append(torch.cos(torch.arange(p)*2 * torch.pi * freq / p))\n",
" fourier_basis_names.append(f\"Cos {freq}\")\n",
"fourier_basis = torch.stack(fourier_basis, dim=0).cuda()\n",
"fourier_basis = torch.stack(fourier_basis, dim=0).to(device)\n",
"fourier_basis = fourier_basis/fourier_basis.norm(dim=-1, keepdim=True)\n",
"imshow(fourier_basis, xaxis=\"Input\", yaxis=\"Component\", y=fourier_basis_names)"
]
Expand Down Expand Up @@ -2394,7 +2402,7 @@
}
],
"source": [
"neuron_freq_norm = torch.zeros(p//2, model.cfg.d_mlp).cuda()\n",
"neuron_freq_norm = torch.zeros(p//2, model.cfg.d_mlp).to(device)\n",
"for freq in range(0, p//2):\n",
" for x in [0, 2*(freq+1) - 1, 2*(freq+1)]:\n",
" for y in [0, 2*(freq+1) - 1, 2*(freq+1)]:\n",
Expand Down Expand Up @@ -2993,7 +3001,7 @@
" a = torch.arange(p)[:, None, None]\n",
" b = torch.arange(p)[None, :, None]\n",
" c = torch.arange(p)[None, None, :]\n",
" cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * (a + b - c)).cuda()\n",
" cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * (a + b - c)).to(device)\n",
" cube_predicted_logits /= cube_predicted_logits.norm()\n",
" coses[freq] = cube_predicted_logits"
]
Expand Down Expand Up @@ -3124,7 +3132,7 @@
" a = torch.arange(p)[:, None, None]\n",
" b = torch.arange(p)[None, :, None]\n",
" c = torch.arange(p)[None, None, :]\n",
" cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * (a + b - c)).cuda()\n",
" cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * (a + b - c)).to(device)\n",
" cube_predicted_logits /= cube_predicted_logits.norm()\n",
" cos_cube.append(cube_predicted_logits)\n",
"cos_cube = torch.stack(cos_cube, dim=0)\n",
Expand Down Expand Up @@ -3486,11 +3494,11 @@
"a = torch.arange(p)[:, None]\n",
"b = torch.arange(p)[None, :]\n",
"for freq in key_freqs:\n",
" cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).cuda()\n",
" cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).to(device)\n",
" cos_apb_vec /= cos_apb_vec.norm()\n",
" cos_apb_vec = einops.rearrange(cos_apb_vec, \"a b -> (a b) 1\")\n",
" approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec\n",
" sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).cuda()\n",
" sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).to(device)\n",
" sin_apb_vec /= sin_apb_vec.norm()\n",
" sin_apb_vec = einops.rearrange(sin_apb_vec, \"a b -> (a b) 1\")\n",
" approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec\n",
Expand Down Expand Up @@ -3555,11 +3563,11 @@
" a = torch.arange(p)[:, None]\n",
" b = torch.arange(p)[None, :]\n",
" for freq in key_freqs:\n",
" cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).cuda()\n",
" cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).to(device)\n",
" cos_apb_vec /= cos_apb_vec.norm()\n",
" cos_apb_vec = einops.rearrange(cos_apb_vec, \"a b -> (a b) 1\")\n",
" approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec\n",
" sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).cuda()\n",
" sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).to(device)\n",
" sin_apb_vec /= sin_apb_vec.norm()\n",
" sin_apb_vec = einops.rearrange(sin_apb_vec, \"a b -> (a b) 1\")\n",
" approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec\n",
Expand Down Expand Up @@ -3718,11 +3726,11 @@
"a = torch.arange(p)[:, None]\n",
"b = torch.arange(p)[None, :]\n",
"for freq in key_freqs:\n",
" cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).cuda()\n",
" cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).to(device)\n",
" cos_apb_vec /= cos_apb_vec.norm()\n",
" cos_apb_vec = einops.rearrange(cos_apb_vec, \"a b -> (a b) 1\")\n",
" approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec\n",
" sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).cuda()\n",
" sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).to(device)\n",
" sin_apb_vec /= sin_apb_vec.norm()\n",
" sin_apb_vec = einops.rearrange(sin_apb_vec, \"a b -> (a b) 1\")\n",
" approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec\n",
Expand Down Expand Up @@ -3765,11 +3773,11 @@
" a = torch.arange(p)[:, None]\n",
" b = torch.arange(p)[None, :]\n",
" for freq in key_freqs:\n",
" cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).cuda()\n",
" cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).to(device)\n",
" cos_apb_vec /= cos_apb_vec.norm()\n",
" cos_apb_vec = einops.rearrange(cos_apb_vec, \"a b -> (a b) 1\")\n",
" approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec\n",
" sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).cuda()\n",
" sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).to(device)\n",
" sin_apb_vec /= sin_apb_vec.norm()\n",
" sin_apb_vec = einops.rearrange(sin_apb_vec, \"a b -> (a b) 1\")\n",
" approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec\n",
Expand Down
3 changes: 1 addition & 2 deletions demos/Main_Demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@
" ip.extension_manager.load('autoreload')\n",
" %autoreload 2\n",
" \n",
"IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n",
"IN_GITHUB = True\n"
"IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n"
]
},
{
Expand Down
3 changes: 2 additions & 1 deletion makefile
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ docstring-test:
poetry run pytest transformer_lens/

notebook-test:
poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Main_Demo.ipynb
poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/BERT.ipynb
poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Exploratory_Analysis_Demo.ipynb
poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Main_Demo.ipynb

test:
make unit-test
Expand Down

0 comments on commit 1139caf

Please sign in to comment.