Skip to content

Commit

Permalink
Missing torch import
Browse files Browse the repository at this point in the history
  • Loading branch information
AMindToThink committed Jan 25, 2025
1 parent 361e0d1 commit f476b70
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions lm_eval/models/InterventionModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def __init__(self, base_name: str, device: str = "cuda:0", model=None):
self.to(device) # Ensure model is on the correct device

@classmethod
def from_dataframe(cls, dataframe, base_name:str, device:str='cuda:0'):
def from_dataframe(cls, df, base_name:str, device:str='cuda:0'):
model = HookedSAETransformer.from_pretrained(base_name, device=device)
original_saes = model.acts_to_saes
assert original_saes == {} # There shouldn't be any SAEs to start
Expand Down Expand Up @@ -190,7 +190,7 @@ def from_csv(
import pandas as pd
df = pd.read_csv(csv_path)

return InterventionModel.from_dataframe(dataframe=df, base_name=base_name, device=device)
return InterventionModel.from_dataframe(df=df, base_name=base_name, device=device)

def forward(self, *args, **kwargs):
# Handle both input_ids and direct tensor inputs
Expand Down
2 changes: 1 addition & 1 deletion lm_eval/models/sae_steered_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM
from lm_eval.models.InterventionModel import InterventionModel

import torch


@register_model("sae_steered_beta")
Expand Down

0 comments on commit f476b70

Please sign in to comment.