Commit 8bf726f
AMindToThink committed on Jan 23, 2025
1 parent: 0543c45
Showing 3 changed files with 126 additions and 2 deletions.
@@ -0,0 +1,27 @@
import einops
from torch import Tensor
from transformer_lens.hook_points import HookPoint
from sae_lens import SAE

# Andre was working on Matthew's folders, and Matthew didn't want to edit the same doc at the same time.
def steering_hook_projection(
    activations,  # Float[Tensor, "batch pos d_in"]; either jaxtyping or lm-evaluation-harness' pre-commit git script rejects a type hint here.
    hook: HookPoint,
    sae: SAE,
    latent_idx: int,
    steering_coefficient: float,
) -> Tensor:
    """
    Steers the model by finding the projection of each activation
    along the specified SAE feature and adding some multiple of that
    projection to the activation.
    """
    # Decoder direction for the chosen latent, shape (d_embedding,).
    # Normalize it so the result below is a true orthogonal projection.
    bad_feature = sae.W_dec[latent_idx]
    bad_feature = bad_feature / bad_feature.norm()

    # Scalar component of each activation along the feature direction
    dot_products = einops.einsum(
        activations, bad_feature,
        "batch pos d_embedding, d_embedding -> batch pos",
    )

    # Calculate the projection of activations onto the feature direction
    projection = einops.einsum(
        dot_products, bad_feature,
        "batch pos, d_embedding -> batch pos d_embedding",
    )

    # Add scaled projection to original activations
    return activations + steering_coefficient * projection
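A minimal usage sketch, not part of this commit: the hook can be attached with TransformerLens' run_with_hooks. The model name, SAE release, hook point, and latent index below are placeholder assumptions; the return shape of SAE.from_pretrained also varies across sae_lens versions.

# Hypothetical usage sketch -- model, SAE release, hook point, and latent index are assumptions.
from functools import partial
from transformer_lens import HookedTransformer
from sae_lens import SAE

model = HookedTransformer.from_pretrained("gpt2")
# Recent sae_lens versions may return just the SAE instead of a 3-tuple.
sae, _, _ = SAE.from_pretrained("gpt2-small-res-jb", "blocks.7.hook_resid_pre")

hook_fn = partial(
    steering_hook_projection,
    sae=sae,
    latent_idx=123,              # placeholder latent index
    steering_coefficient=-2.0,   # negative coefficient suppresses the feature
)
output = model.run_with_hooks(
    model.to_tokens("Hello"),
    fwd_hooks=[(sae.cfg.hook_name, hook_fn)],
)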
@@ -0,0 +1,57 @@
import torch


def batch_vector_projection(vectors: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Projects each vector in a batch onto a target vector.

    Args:
        vectors: Tensor of shape (b, p, d) where:
            b is the batch size
            p is the number of vectors per batch
            d is the dimension of each vector
        target: Tensor of shape (d,) - the vector to project onto

    Returns:
        A tuple (projections, dot_products):
            projections: Tensor of shape (b, p, d) containing the projected vectors
            dot_products: Tensor of shape (b, p, 1) containing each vector's
                component along the normalized target

    Example:
        b, p, d = 32, 10, 3  # batch of 32, 10 vectors each, in 3D
        vectors = torch.randn(b, p, d)
        target = torch.randn(d)
        projections, dot_products = batch_vector_projection(vectors, target)
    """
    # Ensure target is a unit vector
    target = torch.nn.functional.normalize(target, dim=0)

    # Reshape target to (1, 1, d) for broadcasting
    target_reshaped = target.view(1, 1, -1)

    # Compute dot product between each vector and target
    # Result shape: (b, p, 1)
    dot_products = torch.sum(vectors * target_reshaped, dim=-1, keepdim=True)

    # Project each vector onto target by scaling the unit target
    # by each dot product. Result shape: (b, p, d)
    projections = dot_products * target_reshaped

    return projections, dot_products


# Test function
if __name__ == "__main__":
    # Create sample data
    batch_size, vectors_per_batch, dim = 2, 3, 4
    vectors = torch.randn(batch_size, vectors_per_batch, dim)
    target = torch.randn(dim)

    # Compute projections
    projected, dot_products = batch_vector_projection(vectors, target)

    # After subtracting its projection, each vector should be orthogonal
    # to the target, so the remaining dot products should be ~zero.
    _, zero_dot_products = batch_vector_projection(vectors - projected, target)
    assert torch.allclose(zero_dot_products, torch.zeros_like(zero_dot_products), atol=1e-6)
    print("Residuals are orthogonal to the target (dot products ~ 0)")

    # Verify shapes
    print(f"Input shape: {vectors.shape}")
    print(f"Target shape: {target.shape}")
    print(f"Output shape: {projected.shape}")
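One further property worth checking, sketched here as a hedged addition rather than part of the commit: projection onto a fixed direction is idempotent, so projecting the projections a second time should change nothing. This assumes batch_vector_projection from the file above is in scope.

# Hypothetical extra check: P(P(v)) == P(v) for projection onto a fixed direction.
import torch

vectors = torch.randn(2, 3, 4)
target = torch.randn(4)
proj_once, _ = batch_vector_projection(vectors, target)
proj_twice, _ = batch_vector_projection(proj_once, target)
assert torch.allclose(proj_once, proj_twice, atol=1e-6)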