Skip to content

Commit

Permalink
submodule updates
Browse files Browse the repository at this point in the history
  • Loading branch information
AMindToThink committed Jan 25, 2025
1 parent 37d8c96 commit 361e0d1
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 52 deletions.
106 changes: 56 additions & 50 deletions lm_eval/models/InterventionModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def clamp_original(sae_acts:Tensor, hook:HookPoint, latent_idx:int, value:float)
Returns:
Tensor: The modified SAE activations with the specified feature clamped
"""

#import pdb;pdb.set_trace()
mask = sae_acts[:, :, latent_idx] > 0 # Create a boolean mask where values are greater than 0
sae_acts[:, :, latent_idx][mask] = value # Replace values conditionally

Expand Down Expand Up @@ -115,6 +115,59 @@ def __init__(self, base_name: str, device: str = "cuda:0", model=None):
self.device = device # Add device attribute
self.to(device) # Ensure model is on the correct device

@classmethod
def from_dataframe(cls, dataframe, base_name:str, device:str='cuda:0'):
    """Build an intervention model whose steering hooks are described by a DataFrame.

    Each row of ``dataframe`` configures one hook. Rows are processed in
    order, so hooks are registered on the model in row order.

    Columns read from each row:
        sae_release: Release name passed to ``SAE.from_pretrained``.
        sae_id: SAE id passed to ``SAE.from_pretrained``.
        latent_idx: Index of the SAE latent to act on (converted with ``int``).
        steering_coefficient: Strength/value for the hook (converted with ``float``).
        hook_action: Optional; one of ``"add"``, ``"clamp"``, ``"print"``,
            ``"debug"``. Defaults to ``"add"`` when the key is absent.

    Args:
        dataframe: Table of steering configurations (iterated with ``iterrows``).
        base_name: Model name passed to ``HookedSAETransformer.from_pretrained``.
        device: Device string used for both the model and the SAEs.

    Returns:
        An instance of ``cls`` wrapping the hooked model.

    Raises:
        ValueError: If a row's ``hook_action`` is not one of the known actions.
    """
    model = HookedSAETransformer.from_pretrained(base_name, device=device)
    # A freshly loaded model is expected to have no SAEs attached yet.
    assert model.acts_to_saes == {}

    # Load every (release, id) pair at most once, even if many rows use it.
    sae_cache = {}

    def get_sae(sae_release, sae_id):
        cache_key = (sae_release, sae_id)
        if cache_key not in sae_cache:
            # from_pretrained returns a tuple; the SAE itself is element 0.
            sae_cache[cache_key] = SAE.from_pretrained(
                sae_release, sae_id, device=str(device)
            )[0]
        return sae_cache[cache_key]

    # Iterate the DataFrame that was passed in. The previous code referenced
    # an undefined name ``df`` here, which raised NameError on every call.
    for _, row in dataframe.iterrows():
        sae_release = row["sae_release"]
        sae_id = row["sae_id"]
        latent_idx = int(row["latent_idx"])
        steering_coefficient = float(row["steering_coefficient"])

        sae = get_sae(sae_release=sae_release, sae_id=sae_id)
        sae.use_error_term = True
        sae.eval()
        # Attach the SAE first; the hooks below are registered on the model
        # under names derived from the SAE's hook point.
        model.add_sae(sae)

        hook_action = row.get("hook_action", "add")
        after_activation_fn = f"{sae.cfg.hook_name}.hook_sae_acts_post"
        if hook_action == "add":
            # The "add" action hooks the SAE input rather than its activations.
            hook_name = f"{sae.cfg.hook_name}.hook_sae_input"
            hook = partial(
                steering_hook_add_scaled_one_hot,
                sae=sae,
                latent_idx=latent_idx,
                steering_coefficient=steering_coefficient,
            )
            model.add_hook(hook_name, hook)
        elif hook_action == "clamp":
            model.add_hook(
                after_activation_fn,
                partial(clamp_original, latent_idx=latent_idx, value=steering_coefficient),
            )
        elif hook_action == 'print':
            model.add_hook(after_activation_fn, print_sae_acts)
        elif hook_action == 'debug':
            model.add_hook(after_activation_fn, debug_steer)
        else:
            raise ValueError(f"Unknown hook type: {hook_action}")

    # Create and return the model
    return cls(base_name=base_name, device=device, model=model)

@classmethod
def from_csv(
cls, csv_path: str, base_name: str, device: str = "cuda:0"
Expand All @@ -135,56 +188,9 @@ def from_csv(
InterventionModel with configured steering hooks
"""
import pandas as pd
model = HookedSAETransformer.from_pretrained(base_name, device=device)
original_saes = model.acts_to_saes
assert original_saes == {} # There shouldn't be any SAEs to start
# Read steering configurations
df = pd.read_csv(csv_path)

# Group SAEs by hook point
hook_groups = df.groupby(['sae_release', 'sae_id'])
sae_cache = {}

for (sae_release, sae_id), group in hook_groups:
# Get or create SAE
cache_key = (sae_release, sae_id)
if cache_key not in sae_cache:
sae = SAE.from_pretrained(sae_release, sae_id, device=device)[0]
sae.use_error_term = True
sae.eval()
sae_cache[cache_key] = sae
else:
sae = sae_cache[cache_key]

# Add SAE after configuring all its hooks
model.add_sae(sae)
# Add all hooks for this SAE
for _, row in group.iterrows():
latent_idx = int(row["latent_idx"])
steering_coefficient = float(row["steering_coefficient"])
hook_action = row.get("hook_action", "add")

if hook_action == "add":
hook_name = f"{sae.cfg.hook_name}.hook_sae_input"
hook = partial(steering_hook_add_scaled_one_hot,
sae=sae,
latent_idx=latent_idx,
steering_coefficient=steering_coefficient)
model.add_hook(hook_name, hook)
elif hook_action == "clamp":
sae.add_hook("hook_sae_acts_post",
partial(clamp_original,
latent_idx=latent_idx,
value=steering_coefficient))
elif hook_action == 'print':
sae.add_hook("hook_sae_acts_post", print_sae_acts)
elif hook_action == 'debug':
sae.add_hook("hook_sae_acts_post", debug_steer)
else:
raise ValueError(f"Unknown hook action: {hook_action}")


return cls(base_name=base_name, device=device, model=model)

return InterventionModel.from_dataframe(dataframe=df, base_name=base_name, device=device)

def forward(self, *args, **kwargs):
# Handle both input_ids and direct tensor inputs
Expand Down
2 changes: 1 addition & 1 deletion lm_eval/models/debug_steer.csv
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ debug,0,0,gemma-scope-2b-pt-res-canonical,layer_20/width_16k/canonical,applied s
clamp,12082, -69420.0,gemma-scope-2b-pt-res-canonical,layer_20/width_16k/canonical,applied third
clamp,0, 4233030303.0,gemma-scope-2b-pt-res-canonical,layer_20/width_16k/canonical,applied fourth
debug,0,0,gemma-scope-2b-pt-res-canonical,layer_20/width_16k/canonical,applied fifth
add,12082,10000240.0,gemma-scope-2b-pt-res-canonical,layer_20/width_16k/canonical,applied first
add,12082,10000240.0,gemma-scope-2b-pt-res-canonical,layer_20/width_16k/canonical,applied first
2 changes: 1 addition & 1 deletion lm_eval/models/test_im_deleteme.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Manual smoke script: build an InterventionModel from the debug steering CSV
# and run generation on sample prompts.
from InterventionModel import InterventionModel
# NOTE(review): csv_path is an absolute, machine-specific path and device is
# pinned to 'cuda:1' — adjust both before running elsewhere.
m = InterventionModel.from_csv(base_name='google/gemma-2-2b', csv_path='/home/cs29824/matthew/sae_jailbreak_unlearning/src/scripts/evaluation/lm-evaluation-harness/lm_eval/models/debug_steer.csv', device='cuda:1')
# Short prompt.
m.generate('hi')
# Longer prompt, so the hooks see many token positions.
m.generate('hi, this is a long text because we need the law of large numbers on our side in order to be able to trust that things will be positive a reasonable amount of the time; otherwise I get confused')

0 comments on commit 361e0d1

Please sign in to comment.