Skip to content

Commit

Permalink
Other steer functions. Deprecated InterventionModel instantiation fro…
Browse files Browse the repository at this point in the history
…m hooks list; now only csv
  • Loading branch information
AMindToThink committed Jan 22, 2025
1 parent 88d3d4c commit 0543c45
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions lm_eval/models/sae_steered_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def steering_hook_add_scaled_one_hot(
# return sae.decode(activations)
# return activations + steering_coefficient * sae.W_dec[latent_idx]

string_to_steering_function_dict : dict = {'add':steering_hook_add_scaled_one_hot, 'clamp':steering_hook_clamp}

def clamp_sae_feature(sae_acts:Tensor, hook:HookPoint, latent_idx:int, value:float) -> Tensor:
"""Clamps a specific latent feature in the SAE activations to a fixed value.
Expand All @@ -59,13 +58,16 @@ def clamp_sae_feature(sae_acts:Tensor, hook:HookPoint, latent_idx:int, value:flo
sae_acts[:, :, latent_idx] = value
return sae_acts

string_to_steering_function_dict : dict = {'add':steering_hook_add_scaled_one_hot, 'clamp':clamp_sae_feature}

class InterventionModel(HookedSAETransformer): # Replace with the specific model class
def __init__(self, base_name: str, device: str = "cuda:0", model=None):
trueconfig = loading_from_pretrained.get_pretrained_model_config(
base_name, device=device
)
super().__init__(trueconfig)
self.model = model or HookedSAETransformer.from_pretrained(base_name, device=device)
self.model.use_error_term = True
self.model.eval()
self.device = device # Add device attribute
self.to(device) # Ensure model is on the correct device
Expand Down Expand Up @@ -112,8 +114,8 @@ def get_sae(sae_release, sae_id):
steering_coefficient = float(row["steering_coefficient"])
sae = get_sae(sae_release=sae_release, sae_id=sae_id)
sae.eval()
model.add_sae(sae)
hook_action = row.get("hook_action", "add")

if hook_action == "add":
hook_name = f"{sae.cfg.hook_name}.hook_sae_input" # we aren't actually putting the input through the model
hook = partial(steering_hook_add_scaled_one_hot,
Expand All @@ -124,7 +126,7 @@ def get_sae(sae_release, sae_id):
model.add_hook(hook_name, hook)
elif hook_action == "clamp":
sae.add_hook("hook_sae_acts_post", partial(clamp_sae_feature, latent_idx=latent_idx, value=steering_coefficient))
model.add_sae(sae)

else:
raise ValueError(f"Unknown hook type: {hook_action}")

Expand Down

0 comments on commit 0543c45

Please sign in to comment.