diff --git a/daam/trace.py b/daam/trace.py index 6932a4a..54b9f71 100644 --- a/daam/trace.py +++ b/daam/trace.py @@ -279,8 +279,12 @@ def __call__( return hidden_states def _hook_impl(self): + self.original_processor = self.module.processor self.module.set_processor(self) + def _unhook_impl(self): + self.module.set_processor(self.original_processor) + @property def num_heat_maps(self): return len(next(iter(self.heat_maps.values())))