feat: allow direct Hessian retrieval #17

Open · wants to merge 5 commits into base: main

Changes from all commits
3 changes: 2 additions & 1 deletion torch_influence/__init__.py
@@ -10,7 +10,8 @@
"AutogradInfluenceModule",
"CGInfluenceModule",
"LiSSAInfluenceModule",
"HVPModule"
]

from torch_influence.base import BaseInfluenceModule, BaseObjective
from torch_influence.modules import AutogradInfluenceModule, CGInfluenceModule, LiSSAInfluenceModule
from torch_influence.modules import AutogradInfluenceModule, CGInfluenceModule, LiSSAInfluenceModule, HVPModule
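With the updated exports, the helper should be importable from the package root; a minimal check (assuming the package is installed) might look like:

from torch_influence import AutogradInfluenceModule, HVPModule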
264 changes: 253 additions & 11 deletions torch_influence/modules.py
@@ -3,9 +3,15 @@

import numpy as np
import scipy.sparse.linalg as L

import cupy as cp
import cupyx.scipy.sparse.linalg as L_gpu

import torch
from torch import nn
from torch.utils import data
import gpytorch
from gpytorch.lazy import LazyTensor

from torch_influence.base import BaseInfluenceModule, BaseObjective

@@ -29,6 +35,8 @@ class AutogradInfluenceModule(BaseInfluenceModule):
check_eigvals: if ``True``, this initializer checks that the damped risk Hessian
is positive definite, and raises a :mod:`ValueError` if it is not. Otherwise,
no check is performed.
store_as_hessian: if ``True``, the damped risk Hessian is stored on the module and the
inverse Hessian is computed lazily on first access. Otherwise, the inverse Hessian is
computed immediately.

Warnings:
This module scales poorly with the number of model parameters :math:`d`. In
@@ -45,7 +53,8 @@ def __init__(
test_loader: data.DataLoader,
device: torch.device,
damp: float,
check_eigvals: bool = False
check_eigvals: bool = False,
store_as_hessian: bool = False
):
super().__init__(
model=model,
Expand All @@ -55,6 +64,7 @@ def __init__(
device=device,
)

self.store_as_hessian = store_as_hessian
self.damp = damp

params = self._model_make_functional()
@@ -68,7 +78,7 @@ def f(theta_):
self._model_reinsert_params(self._reshape_like_params(theta_))
return self.objective.train_loss(self.model, theta_, batch)

hess_batch = torch.autograd.functional.hessian(f, flat_params).detach()
hess_batch = torch.autograd.functional.hessian(f, flat_params).detach().cpu()
hess = hess + hess_batch * batch_size

with torch.no_grad():
@@ -77,16 +87,32 @@ def f(theta_):
hess = hess + damp * torch.eye(d, device=hess.device)

if check_eigvals:
eigvals = np.linalg.eigvalsh(hess.cpu().numpy())
eigvals = torch.linalg.eigvalsh(hess.cpu()).numpy()
logging.info("hessian min eigval %f", np.min(eigvals).item())
logging.info("hessian max eigval %f", np.max(eigvals).item())
if not bool(np.all(eigvals >= 0)):
raise ValueError()

self.inverse_hess = torch.inverse(hess)
if self.store_as_hessian:
self.hess = hess
self.inverse_hess = None
else:
self.inverse_hess = torch.inverse(hess)

def get_hessian(self):
if not self.store_as_hessian:
raise NotImplementedError("To access the damped risk Hessian, construct the module with store_as_hessian=True.")
return self.hess

def get_inverse_hessian(self):
if self.inverse_hess is None:
# Lazy inverse computation
self.inverse_hess = torch.inverse(self.hess)
return self.inverse_hess

def inverse_hvp(self, vec):
return self.inverse_hess @ vec
inverse_hess = self.get_inverse_hessian()
return inverse_hess @ vec
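
A brief usage sketch of the new flag (model, objective and the data loaders are placeholders defined elsewhere; values are illustrative):

module = AutogradInfluenceModule(
    model=model,
    objective=objective,
    train_loader=train_loader,
    test_loader=test_loader,
    device=torch.device("cuda:0"),
    damp=0.01,
    store_as_hessian=True,   # keep the damped risk Hessian; invert lazily
)

hess = module.get_hessian()              # damped risk Hessian, shape (d, d)
inv_hess = module.get_inverse_hessian()  # computed on first access, then cached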


class CGInfluenceModule(BaseInfluenceModule):
@@ -122,6 +148,7 @@ def __init__(
device: torch.device,
damp: float,
gnh: bool = False,
use_cupy: bool = True,
**kwargs
):
super().__init__(
@@ -134,31 +161,99 @@ def __init__(

self.damp = damp
self.gnh = gnh
self.use_cupy = use_cupy
self.cg_kwargs = kwargs

def inverse_hvp(self, vec):
params = self._model_make_functional()
flat_params = self._flatten_params_like(params)

# Work out which CUDA device CuPy arrays should live on
if self.use_cupy:
device_str = str(self.device)
if device_str == "cuda":
device_id = 0
else:
device_id = int(device_str.split(":")[-1])

def hvp_fn(v):
v = torch.tensor(v, requires_grad=False, device=self.device, dtype=vec.dtype)
if self.use_cupy:
v_ = torch.as_tensor(v, device=self.device)
else:
v_ = torch.tensor(v, requires_grad=False, device=self.device, dtype=vec.dtype)

hvp = 0.0
for batch, batch_size in self._loader_wrapper(train=True):
hvp_batch = self._hvp_at_batch(batch, flat_params, vec=v, gnh=self.gnh)
hvp_batch = self._hvp_at_batch(batch, flat_params, vec=v_, gnh=self.gnh)
hvp = hvp + hvp_batch.detach() * batch_size
hvp = hvp / len(self.train_loader.dataset)
hvp = hvp + self.damp * v

hvp = hvp + self.damp * v_

if self.use_cupy:
# Hand the result back to the solver as a CuPy array on the matching device
with cp.cuda.Device(device_id):
return cp.asarray(hvp.cpu().numpy())
return hvp.cpu().numpy()

d = vec.shape[0]
linop = L.LinearOperator((d, d), matvec=hvp_fn)
ihvp = L.cg(A=linop, b=vec.cpu().numpy(), **self.cg_kwargs)[0]
if self.use_cupy:
with cp.cuda.Device(device_id):
linop = L_gpu.LinearOperator((d, d), matvec=hvp_fn)
vec_ = cp.asarray(vec)
ihvp = L_gpu.cg(A=linop, b=vec_, **self.cg_kwargs)[0]
else:
# CPU fallback via SciPy's conjugate gradient solver
linop = L.LinearOperator((d, d), matvec=hvp_fn)
ihvp = L.cg(A=linop, b=vec.cpu().numpy(), **self.cg_kwargs)[0]

with torch.no_grad():
self._model_reinsert_params(self._reshape_like_params(flat_params), register=True)

if self.use_cupy:
# ihvp comes back as a CuPy array; convert it to a torch tensor on the target device
return torch.as_tensor(ihvp, device=self.device)
return torch.tensor(ihvp, device=self.device)
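
A hedged sketch of selecting the CuPy-backed solver (placeholder setup; extra keyword arguments are forwarded to the CG routine, so the solver option shown here is only illustrative):

module = CGInfluenceModule(
    model=model,
    objective=objective,
    train_loader=train_loader,
    test_loader=test_loader,
    device=torch.device("cuda:0"),
    damp=0.01,
    gnh=False,
    use_cupy=True,   # run CG with cupyx.scipy.sparse.linalg on the GPU
    maxiter=100,     # forwarded to the CG solver via **kwargs
)

ihvp = module.inverse_hvp(grad_vec)  # grad_vec: flat gradient tensor on the same CUDA device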


@@ -265,3 +360,150 @@ def inverse_hvp(self, vec):
self._model_reinsert_params(self._reshape_like_params(flat_params), register=True)

return ihvp / self.repeat


class HVPModule(BaseInfluenceModule):
r"""Helper module for computing Hessian-vector products (HVPs) over the training set.
"""
def __init__(
self,
model: nn.Module,
objective: BaseObjective,
train_loader: data.DataLoader,
device: torch.device
):
super().__init__(
model=model,
objective=objective,
train_loader=train_loader,
test_loader=None,
device=device,
)

# TODO: Allow the user to specify a dtype for higher precision

def hvp(self, vec):
params = self._model_make_functional()
flat_params = self._flatten_params_like(params)

hvp = 0.0
for batch, batch_size in self._loader_wrapper(train=True):
hvp_batch = self._hvp_at_batch(batch, flat_params, vec=vec.to(self.device), gnh=False)
hvp = hvp + hvp_batch.detach() * batch_size
hvp = hvp / len(self.train_loader.dataset)

with torch.no_grad():
self._model_reinsert_params(self._reshape_like_params(flat_params), register=True)

return hvp

def inverse_hvp(self, vec):
raise NotImplementedError("Inverse HVP is not implemented for HVPModule; it is a helper class for plain HVPs.")
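
A short sketch of using the helper as a building block (placeholder names; num_params stands for the flattened parameter count):

hvp_module = HVPModule(
    model=model,
    objective=objective,
    train_loader=train_loader,
    device=torch.device("cuda:0"),
)

v = torch.randn(num_params, device="cuda:0")  # flat vector matching the model's parameters
Hv = hvp_module.hvp(v)                        # dataset-averaged HVP of the training loss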


class ShanksSablonniereModule(BaseInfluenceModule):
"""
Based on "Series of Hessian-Vector Products for Tractable Saddle-Free Newton Optimisation of Neural Networks" (https://arxiv.org/abs/2310.14901).
Approximates inverse-Hessian-vector products with a truncated series of HVPs, accelerated by a recursive Shanks transformation of the partial sums using the Samelson inverse and the epsilon-algorithm with a Sablonnière modifier.
"""
def __init__(
self,
model: nn.Module,
objective: BaseObjective,
train_loader: data.DataLoader,
device: torch.device,
# damp: float,
**kwargs
):
self.acceleration_order = kwargs.get('acceleration_order', 8)
self.initial_scale_factor = kwargs.get('initial_scale_factor', 100)
self.num_update_steps = kwargs.get('num_update_steps', 20)

self.hvp_module = HVPModule(
model,
objective,
train_loader,
device=device
)
# TODO: Consider adding damping

def compute_epsilon_acceleration(
self,
source_sequence,
num_applications: int = 1,
):
"""Compute `num_applications` recursive Shanks transformation of
`source_sequence` (preferring later elements) using `Samelson` inverse and the
epsilon-algorithm, with Sablonniere modifier.
"""

def inverse(vector):
"""
Samelson inverse
"""
return vector / vector.dot(vector)

epsilon = {}
for m, source_m in enumerate(source_sequence):
epsilon[m, 0] = source_m.squeeze(1)
epsilon[m + 1, -1] = 0

s = 1
m = (len(source_sequence) - 1) - 2 * num_applications
initial_m = m
while m < len(source_sequence) - 1:
while m >= initial_m:
# Sablonniere modifier
inverse_scaling = np.floor(s / 2) + 1

epsilon[m, s] = epsilon[m + 1, s - 2] + inverse_scaling * inverse(
epsilon[m + 1, s - 1] - epsilon[m, s - 1]
)
epsilon.pop((m + 1, s - 2))
m -= 1
s += 1
m += 1
s -= 1
epsilon.pop((m, s - 1))
m = initial_m + s
s = 1

return epsilon[initial_m, 2 * num_applications]

def inverse_hvp(self, vec):
# Detach and clone input
vector_cache = vec.detach().clone()
update_sum = vec.detach().clone()
coefficient_cache = 1

cached_update_sums = []
if self.acceleration_order > 0 and self.num_update_steps == 2 * self.acceleration_order + 1:
cached_update_sums.append(update_sum)

# Accumulate the truncated Hessian-series estimate
for update_step in range(1, self.num_update_steps):
hessian2_vector_cache = self.hvp_module.hvp(self.hvp_module.hvp(vector_cache))

if update_step == 1:
scale_factor = torch.norm(hessian2_vector_cache, p=2) / torch.norm(vec, p=2)
scale_factor = max(scale_factor.item(), self.initial_scale_factor)

vector_cache = (vector_cache - (1/scale_factor)*hessian2_vector_cache).clone()
coefficient_cache *= (2 * update_step - 1) / (2 * update_step)
update_sum += coefficient_cache * vector_cache

if self.acceleration_order > 0 and update_step >= (self.num_update_steps - 2 * self.acceleration_order - 1):
cached_update_sums.append(update_sum.clone())

# Perform series acceleration (Shanks acceleration)
if self.acceleration_order > 0:
accelerated_sum = self.compute_epsilon_acceleration(
cached_update_sums, num_applications=self.acceleration_order
)
accelerated_sum /= np.sqrt(scale_factor)
accelerated_sum = accelerated_sum.unsqueeze(1)

return accelerated_sum

update_sum /= np.sqrt(scale_factor)
return update_sum
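
For context, the truncated series above appears to approximate the action of the inverse absolute Hessian |H|^{-1} used in saddle-free Newton, with the epsilon-algorithm accelerating the cached partial sums. A usage sketch (placeholder names; the input is assumed to be a (d, 1) column, matching the squeeze/unsqueeze handling in compute_epsilon_acceleration):

module = ShanksSablonniereModule(
    model=model,
    objective=objective,
    train_loader=train_loader,
    device=torch.device("cuda:0"),
    acceleration_order=8,     # number of Shanks applications
    num_update_steps=20,      # length of the truncated HVP series
    initial_scale_factor=100,
)

g = loss_gradient.reshape(-1, 1)   # placeholder flat gradient as a (d, 1) column
step = module.inverse_hvp(g)       # accelerated estimate of |H|^{-1} g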