Skip to content

Commit

Permalink
Merge pull request #675 from TransformerLensOrg/dev
Browse files Browse the repository at this point in the history
Release 2.2.2
  • Loading branch information
bryce13950 authored Jul 12, 2024
2 parents 67ed0d6 + b96c1dd commit cec7ed3
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
20 changes: 20 additions & 0 deletions tests/integration/test_hooks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

import pytest
import torch

Expand Down Expand Up @@ -73,6 +75,24 @@ def test_context_manager_run_with_cache():
model.remove_all_hook_fns(including_permanent=True)


def test_backward_hook_runs_successfully():
    """A backward hook registered via the context manager fires exactly once
    during backprop and is detached again when the context exits."""
    counter = Counter()

    def passthrough_grad(output_grad: torch.Tensor, hook: Any):
        # Record the invocation, then hand the gradient back unchanged.
        counter.inc()
        return (output_grad,)

    with model.hooks(bwd_hooks=[(embed, passthrough_grad)]):
        # Hook is attached for the duration of the context.
        assert len(model.hook_dict["hook_embed"].bwd_hooks) == 1
        logits = model(prompt)
        # Forward pass alone must not trigger a backward hook.
        assert counter.count == 0
        logits.sum().backward()  # backprop should invoke the hook
        assert len(model.hook_dict["hook_embed"].bwd_hooks) == 1
    # Leaving the context removes the hook; the single firing is recorded.
    assert len(model.hook_dict["hook_embed"].bwd_hooks) == 0
    assert counter.count == 1
    model.remove_all_hook_fns(including_permanent=True)


def test_hook_context_manager_with_permanent_hook():
c = Counter()
model.add_perma_hook(embed, c.inc)
Expand Down
2 changes: 1 addition & 1 deletion transformer_lens/hook_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def full_hook(
_internal_hooks = self._forward_hooks
visible_hooks = self.fwd_hooks
elif dir == "bwd":
pt_handle = self.register_backward_hook(full_hook)
pt_handle = self.register_full_backward_hook(full_hook)
_internal_hooks = self._backward_hooks
visible_hooks = self.bwd_hooks
else:
Expand Down

0 comments on commit cec7ed3

Please sign in to comment.