Commit
Other hook functions
AMindToThink committed Jan 23, 2025
1 parent 0543c45 commit 8bf726f
Showing 3 changed files with 126 additions and 2 deletions.
27 changes: 27 additions & 0 deletions lm_eval/models/add_to_sae_steered_beta_then_delete.py
@@ -0,0 +1,27 @@
import einops
from torch import Tensor

# These imports are assumed from the surrounding project:
from sae_lens import SAE
from transformer_lens.hook_points import HookPoint


# Andre was working in Matthew's folders, and Matthew didn't want us both
# editing the same file at the same time.
def steering_hook_projection(
    activations,  # Float[Tensor, "batch pos d_in"]; jaxtyping and lm-evaluation-harness' pre-commit script both reject a type hint here.
    hook: HookPoint,
    sae: SAE,
    latent_idx: int,
    steering_coefficient: float,
) -> Tensor:
    """
    Steers the model by projecting each activation vector onto the specified
    SAE feature direction and adding a multiple of that projection back to
    the activation.
    """
    bad_feature = sae.W_dec[latent_idx]  # shape (d_embedding,): decoder direction of the chosen latent
    bad_feature = bad_feature / bad_feature.norm()  # unit vector, so the result below is a true orthogonal projection
    dot_products = einops.einsum(
        activations, bad_feature, "batch pos d_embedding, d_embedding -> batch pos"
    )

    # Calculate the projection of activations onto the feature direction
    projection = einops.einsum(
        dot_products, bad_feature, "batch pos, d_embedding -> batch pos d_embedding"
    )

    # Add scaled projection to original activations
    return activations + steering_coefficient * projection
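
A note on usage (not part of the commit): a minimal sketch of how this hook could be wired up, mirroring the "add" branch in sae_steered_beta.py below. The names model, hook_name, sae, latent_idx, and steering_coefficient are assumptions, standing in for a hooked model, a residual-stream hook point, and values parsed from a steering config.

from functools import partial

# Illustrative only: bind the SAE-specific arguments, then register the hook.
hook = partial(
    steering_hook_projection,
    sae=sae,
    latent_idx=latent_idx,
    steering_coefficient=steering_coefficient,
)
model.add_hook(hook_name, hook)  # e.g. hook_name = "blocks.5.hook_resid_post"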

57 changes: 57 additions & 0 deletions lm_eval/models/projection_deleteme.py
@@ -0,0 +1,57 @@
import torch


def batch_vector_projection(
    vectors: torch.Tensor, target: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Projects each vector in a batch onto a target vector.

    Args:
        vectors: Tensor of shape (b, p, d) where:
            b is the batch size
            p is the number of vectors per batch
            d is the dimension of each vector
        target: Tensor of shape (d,) - the vector to project onto

    Returns:
        A tuple (projections, dot_products):
            projections: Tensor of shape (b, p, d) containing the projected vectors
            dot_products: Tensor of shape (b, p, 1) containing each vector's dot
                product with the normalized target

    Example:
        b, p, d = 32, 10, 3  # batch of 32, 10 vectors each, in 3D
        vectors = torch.randn(b, p, d)
        target = torch.randn(d)
        projections, dot_products = batch_vector_projection(vectors, target)
    """
    # Ensure target is a unit vector
    target = torch.nn.functional.normalize(target, dim=0)

    # Reshape target to (1, 1, d) for broadcasting
    target_reshaped = target.view(1, 1, -1)

    # Compute dot product between each vector and target
    # Result shape: (b, p, 1)
    dot_products = torch.sum(vectors * target_reshaped, dim=-1, keepdim=True)

    # Project each vector onto target by scaling it with the dot products
    # Result shape: (b, p, d)
    projections = dot_products * target_reshaped

    return projections, dot_products


# Test function
if __name__ == "__main__":
    # Create sample data
    batch_size, vectors_per_batch, dim = 2, 3, 4
    vectors = torch.randn(batch_size, vectors_per_batch, dim)
    target = torch.randn(dim)

    # Compute projections
    projected, dot_products = batch_vector_projection(vectors, target)

    # After subtracting the projections, the residual should be orthogonal to the target
    _, zero_dot_products = batch_vector_projection(vectors - projected, target)
    assert torch.allclose(zero_dot_products, torch.zeros_like(zero_dot_products), atol=1e-6)
    print("Residual after removing the projection is orthogonal to the target")

    # Verify shapes
    print(f"Input shape: {vectors.shape}")
    print(f"Target shape: {target.shape}")
    print(f"Output shape: {projected.shape}")

44 changes: 42 additions & 2 deletions lm_eval/models/sae_steered_beta.py
@@ -55,7 +55,45 @@ def clamp_sae_feature(sae_acts: Tensor, hook: HookPoint, latent_idx: int, value: float) -> Tensor:
    Returns:
        Tensor: The modified SAE activations with the specified feature clamped
    """

    sae_acts[:, :, latent_idx] = value

    return sae_acts

def clamp_original(sae_acts: Tensor, hook: HookPoint, latent_idx: int, value: float) -> Tensor:
    """Clamps a specific latent feature to a fixed value, but only at the
    positions where that feature is already active (activation greater than 0).

    Args:
        sae_acts (Tensor): The SAE activations tensor, shape [batch, pos, features]
        hook (HookPoint): The transformer-lens hook point
        latent_idx (int): Index of the latent feature to clamp
        value (float): Value to clamp the feature to

    Returns:
        Tensor: The modified SAE activations, with the specified feature clamped
        wherever it was active
    """

    mask = sae_acts[:, :, latent_idx] > 0  # Boolean mask of positions where the feature fires
    sae_acts[:, :, latent_idx][mask] = value  # Overwrite the feature only at those positions

    return sae_acts
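
The difference between clamp_sae_feature and clamp_original is easiest to see on a toy tensor (an illustrative sketch, not part of the commit; hook=None is passed because neither function actually uses the hook argument).

import torch

acts = torch.tensor([[[0.0, 2.0, 5.0],
                      [0.0, 0.0, 1.0]]])  # [batch=1, pos=2, features=3]

# clamp_sae_feature overwrites latent 1 everywhere, even where it was inactive.
everywhere = clamp_sae_feature(acts.clone(), hook=None, latent_idx=1, value=3.0)
print(everywhere[0, :, 1])    # tensor([3., 3.])

# clamp_original overwrites latent 1 only where it was already firing (> 0).
where_active = clamp_original(acts.clone(), hook=None, latent_idx=1, value=3.0)
print(where_active[0, :, 1])  # tensor([3., 0.])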

def print_sae_acts(sae_acts: Tensor, hook: HookPoint) -> Tensor:
    """Debugging hook: prints the shape of the SAE latent activations at this
    hook point and whether every activation is positive, then returns the
    activations unchanged.

    Args:
        sae_acts (Tensor): The SAE activations tensor, shape [batch, pos, features]
        hook (HookPoint): The transformer-lens hook point

    Returns:
        Tensor: The unmodified SAE activations
    """
    print(40 * "----")
    print(f"These are the latent activations of {hook.name}")
    print(sae_acts.shape)
    print(torch.all(sae_acts > 0))
    return sae_acts

string_to_steering_function_dict: dict = {'add': steering_hook_add_scaled_one_hot, 'clamp': clamp_sae_feature}
@@ -113,6 +151,7 @@ def get_sae(sae_release, sae_id):
        latent_idx = int(row["latent_idx"])
        steering_coefficient = float(row["steering_coefficient"])
        sae = get_sae(sae_release=sae_release, sae_id=sae_id)
        sae.use_error_term = True
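        # use_error_term=True makes the spliced-in SAE add its reconstruction error
        # back to the output, so the forward pass matches the base model except
        # where the hooks below intervene.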
        sae.eval()
        model.add_sae(sae)
        hook_action = row.get("hook_action", "add")
@@ -125,8 +164,9 @@ def get_sae(sae_release, sae_id):
            )
            model.add_hook(hook_name, hook)
        elif hook_action == "clamp":
-           sae.add_hook("hook_sae_acts_post", partial(clamp_sae_feature, latent_idx=latent_idx, value=steering_coefficient))
+           sae.add_hook("hook_sae_acts_post", partial(clamp_original, latent_idx=latent_idx, value=steering_coefficient))
+       elif hook_action == 'print':
+           sae.add_hook("hook_sae_acts_post", partial(print_sae_acts))
        else:
            raise ValueError(f"Unknown hook type: {hook_action}")
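
For reference, a hypothetical shape for the steering-config rows this loop consumes (a sketch: the column names come from the row[...] accesses above, but the values and release/id strings are placeholders, not from the commit).

rows = [
    {"sae_release": "<release>", "sae_id": "<sae_id>",
     "latent_idx": 42, "steering_coefficient": 5.0, "hook_action": "add"},
    {"sae_release": "<release>", "sae_id": "<sae_id>",
     "latent_idx": 42, "steering_coefficient": 0.0, "hook_action": "clamp"},
    {"sae_release": "<release>", "sae_id": "<sae_id>",
     "latent_idx": 42, "steering_coefficient": 0.0, "hook_action": "print"},
]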

