
Improve Key-Value Caching #383

Closed · wants to merge 3 commits

Conversation

@UFO-101 (Contributor) commented on Sep 18, 2023

Description

Commit 1:

  • Support passing in multiple tokens when using past_kv_cache.
  • Add tests for past_kv_cache.
  • Add documentation for past_kv_cache.
  • Fix type hints for some components that assume left_attention_mask has the same number of tokens as the input. This previously went unnoticed because there were no tests covering past_kv_cache.

Commit 2:

  • Support freezing key-value caches.

Commit 3:

  • Integrate past_left_attention_mask into HookedTransformerKeyValueCache so that it doesn't need to be managed manually. Remove it from HookedTransformer.forward().

Motivation for allowing multiple tokens to be run with the key-value cache

In ACDC we run the same prompt many times. Patching only affects token positions after the point where the clean and corrupt prompts differ, so we want to run the shared prefix once, freeze the cache, and then pass in only the tokens after the point of divergence for each patched run.
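As a rough illustration, the workflow this PR is meant to enable might look like the sketch below. The names follow the TransformerLens API, but freeze() is the behaviour commit 2 proposes, so treat this as a sketch of the intended interface rather than final code.

```python
from transformer_lens import HookedTransformer
from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCache

model = HookedTransformer.from_pretrained("gpt2")

shared_prefix = "When John and Mary went to the store,"  # identical in clean and corrupt prompts
clean_suffix = " John gave a drink to"
corrupt_suffix = " Mary gave a drink to"

# 1. Run the shared prefix once, filling the cache.
cache = HookedTransformerKeyValueCache.init_cache(model.cfg, model.cfg.device, batch_size=1)
model(shared_prefix, past_kv_cache=cache)

# 2. Freeze the cache so repeated patched runs don't keep appending to it (commit 2).
cache.freeze()

# 3. Pass in only the tokens after the point of divergence, several at a time (commit 1).
clean_logits = model(clean_suffix, past_kv_cache=cache, prepend_bos=False)
corrupt_logits = model(corrupt_suffix, past_kv_cache=cache, prepend_bos=False)
```

Without the frozen cache, each patched run would either have to re-run the shared prefix from scratch or would append its suffix on top of whatever the previous run left in the cache.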

Breaking change

Commit 3 removes past_left_attention_mask from HookedTransformer.forward(), because the left_attention_mask of previous inputs is now stored automatically by HookedTransformerKeyValueCache. I don't expect this to break many people's code, as the argument was only added two weeks ago in #344, and the fix is trivial: just delete the argument (see the changes to the tests in commit 3 for an example).

Overall I think the benefits are worth the cost, as this makes caching easier to use and generally reduces complexity. I can't think of a case where someone would want to pass in a past_left_attention_mask that doesn't match past_kv_cache.
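For anyone who did adopt the #344 interface, the migration would look roughly like the snippet below (continuing the sketch above; prefix_mask is a hypothetical placeholder for the mask a caller previously tracked by hand):

```python
import torch

# Before commit 3 (interface from #344): the caller tracked the attention mask of
# the cached prefix themselves and passed it back in on every continuation.
prefix_mask = torch.ones(1, 8, dtype=torch.int64)  # hypothetical placeholder
logits = model(
    clean_suffix,
    past_kv_cache=cache,
    prepend_bos=False,
    past_left_attention_mask=prefix_mask,  # argument removed by commit 3
)

# After commit 3: the cache records the attention mask of everything previously
# run through it, so the extra argument is simply deleted.
logits = model(clean_suffix, past_kv_cache=cache, prepend_bos=False)
```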

Type of change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

@UFO-101 changed the title from "Support running multiple tokens with KV cache. Support freezing KV cache." to "Improve Key-Value Caching" on Sep 18, 2023
@UFO-101 (Contributor, Author) commented on Sep 19, 2023

Sorry, actually this isn't ready. I just realized there's an edge case when passing in a left-padded input while using the cache.
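For context, a left-padded batch run against a cache (the scenario this comment refers to) might look something like the sketch below, continuing the earlier example; the prompts and the framing of the problem are my own illustration, not taken from the PR.

```python
# Two prompts of different lengths: with padding_side="left" the shorter one is
# padded on the left when the shared run populates the cache.
prompts = ["The Eiffel Tower is in", "It is in"]
cache = HookedTransformerKeyValueCache.init_cache(model.cfg, model.cfg.device, batch_size=2)
model(prompts, past_kv_cache=cache, padding_side="left")

# The awkward part: any later call that reuses this cache has to combine the
# padding positions hidden inside the cached prefix with the attention mask of
# the newly supplied tokens.
model([" the city of Paris", " the city of Paris"], past_kv_cache=cache, prepend_bos=False)
```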

@UFO-101 closed this on Sep 26, 2023
@UFO-101 (Contributor, Author) commented on Sep 26, 2023

Superseded by #386
