diff --git a/torch_influence/__init__.py b/torch_influence/__init__.py
index 337ce2a..a39af67 100644
--- a/torch_influence/__init__.py
+++ b/torch_influence/__init__.py
@@ -10,7 +10,8 @@
     "AutogradInfluenceModule",
     "CGInfluenceModule",
     "LiSSAInfluenceModule",
+    "HVPModule"
 ]
 
 from torch_influence.base import BaseInfluenceModule, BaseObjective
-from torch_influence.modules import AutogradInfluenceModule, CGInfluenceModule, LiSSAInfluenceModule
+from torch_influence.modules import AutogradInfluenceModule, CGInfluenceModule, LiSSAInfluenceModule, HVPModule
diff --git a/torch_influence/modules.py b/torch_influence/modules.py
index 27eef6c..d9f1eb1 100644
--- a/torch_influence/modules.py
+++ b/torch_influence/modules.py
@@ -3,9 +3,15 @@
 
 import numpy as np
 import scipy.sparse.linalg as L
+
+import cupy as cp
+import cupyx.scipy.sparse.linalg as L_gpu
+
 import torch
 from torch import nn
 from torch.utils import data
+import gpytorch
+from gpytorch.lazy import LazyTensor
 
 from torch_influence.base import BaseInfluenceModule, BaseObjective
 
@@ -29,6 +35,8 @@ class AutogradInfluenceModule(BaseInfluenceModule):
         check_eigvals: if ``True``, this initializer checks that the damped risk Hessian
             is positive definite, and raises a :mod:`ValueError` if it is not. Otherwise,
             no check is performed.
+        store_as_hessian: if ``True``, the damped risk Hessian is stored on the object and the
+            inverse Hessian is computed lazily on first access. Otherwise, it is computed immediately.
 
     Warnings:
         This module scales poorly with the number of model parameters :math:`d`. In
@@ -45,7 +53,8 @@ def __init__(
             test_loader: data.DataLoader,
             device: torch.device,
             damp: float,
-            check_eigvals: bool = False
+            check_eigvals: bool = False,
+            store_as_hessian: bool = False
     ):
         super().__init__(
             model=model,
@@ -55,6 +64,7 @@ def __init__(
             device=device,
         )
 
+        self.store_as_hessian = store_as_hessian
         self.damp = damp
 
         params = self._model_make_functional()
@@ -68,7 +78,7 @@ def f(theta_):
                 self._model_reinsert_params(self._reshape_like_params(theta_))
                 return self.objective.train_loss(self.model, theta_, batch)
 
-            hess_batch = torch.autograd.functional.hessian(f, flat_params).detach()
+            hess_batch = torch.autograd.functional.hessian(f, flat_params).detach().cpu()
             hess = hess + hess_batch * batch_size
 
         with torch.no_grad():
@@ -77,16 +87,32 @@ def f(theta_):
             hess = hess + damp * torch.eye(d, device=hess.device)
 
             if check_eigvals:
-                eigvals = np.linalg.eigvalsh(hess.cpu().numpy())
+                eigvals = torch.linalg.eigvalsh(hess.cpu()).numpy()
                 logging.info("hessian min eigval %f", np.min(eigvals).item())
                 logging.info("hessian max eigval %f", np.max(eigvals).item())
                 if not bool(np.all(eigvals >= 0)):
                     raise ValueError()
 
-            self.inverse_hess = torch.inverse(hess)
+            if self.store_as_hessian:
+                self.hess = hess
+                self.inverse_hess = None
+            else:
+                self.inverse_hess = torch.inverse(hess)
+
+    def get_hessian(self):
+        if not self.store_as_hessian:
+            raise NotImplementedError("To access damped risk Hessian, set store_as_hessian to True.")
+        return self.hess
+
+    def get_inverse_hessian(self):
+        if self.inverse_hess is None:
+            # Lazy inverse computation
+            self.inverse_hess = torch.inverse(self.hess)
+        return self.inverse_hess
 
     def inverse_hvp(self, vec):
-        return self.inverse_hess @ vec
+        inverse_hess = self.get_inverse_hessian()
+        return (inverse_hess @ vec.to(inverse_hess.device)).to(vec.device)
 
 
 class CGInfluenceModule(BaseInfluenceModule):
@@ -122,6 +148,7 @@ def __init__(
             device: torch.device,
             damp: float,
             gnh: bool = False,
+            use_cupy: bool = True,
             **kwargs
     ):
         super().__init__(
@@ -134,31 +161,99 @@ def __init__(
 
         self.damp = damp
         self.gnh = gnh
+        self.use_cupy = use_cupy and torch.device(device).type == "cuda"  # CuPy path needs CUDA
         self.cg_kwargs = kwargs
 
     def inverse_hvp(self, vec):
         params = self._model_make_functional()
         flat_params = self._flatten_params_like(params)
 
+        # Know which device tensors should sit on
+        if self.use_cupy:
+            # torch.device has no .split(); a bare "cuda" device reports index None
+            device_id = torch.device(self.device).index
+            if device_id is None:
+                device_id = 0
+
         def hvp_fn(v):
-            v = torch.tensor(v, requires_grad=False, device=self.device, dtype=vec.dtype)
+            # v_ = v.clone().detach().requires_grad_(False)
+            if self.use_cupy:
+                v_ = torch.as_tensor(v, device=self.device)
+            else:
+                v_ = torch.tensor(v, requires_grad=False, device=self.device, dtype=vec.dtype)
+            """
+            was_squeezed = len(v_.shape) > 1 and v_.shape[1] == 1
+            if was_squeezed:
+                v_ = v_.squeeze(1)
+            """
 
             hvp = 0.0
             for batch, batch_size in self._loader_wrapper(train=True):
-                hvp_batch = self._hvp_at_batch(batch, flat_params, vec=v, gnh=self.gnh)
+                hvp_batch = self._hvp_at_batch(batch, flat_params, vec=v_, gnh=self.gnh)
                 hvp = hvp + hvp_batch.detach() * batch_size
             hvp = hvp / len(self.train_loader.dataset)
-            hvp = hvp + self.damp * v
-
+            hvp = hvp + self.damp * v_
+
+            """
+            if was_squeezed:
+                hvp = hvp.unsqueeze(1)
+            """
+
+            if self.use_cupy:
+                # Wrap as cupy array, but create a copy
+                # return cp.asarray(hvp.cpu().numpy())
+                # Get ID that self.device corresponds to
+                with cp.cuda.Device(device_id):
+                    return cp.asarray(hvp.cpu().numpy())
             return hvp.cpu().numpy()
 
         d = vec.shape[0]
-        linop = L.LinearOperator((d, d), matvec=hvp_fn)
-        ihvp = L.cg(A=linop, b=vec.cpu().numpy(), **self.cg_kwargs)[0]
+        if self.use_cupy:
+            """
+            # Wrap hvp_fn in a LinearOperator-like class
+            class HvpLazyTensor(LazyTensor):
+                def __init__(self, hvp_fn, vec_shape):
+                    self.hvp_fn = hvp_fn
+                    self.vec_shape = vec_shape
+                    super(HvpLazyTensor, self).__init__(torch.eye(vec_shape, dtype=vec.dtype))
+
+                def _matmul(self, rhs):
+                    return self.hvp_fn(rhs)
+
+                def _transpose_nonbatch(self):
+                    return self  # Symmetric Hessian
+
+                def _size(self):
+                    return torch.Size((self.vec_shape, self.vec_shape))
+
+            # Create the lazy tensor
+            hvp_lazy_tensor = HvpLazyTensor(hvp_fn, d)
+
+            # rtol: 1e-5
+            # default niters: len(g) * 10
+
+            # Convert vec to a tensor if it isn't already
+            vec_tensor = vec if torch.is_tensor(vec) else torch.tensor(vec, dtype=vec.dtype)
+            print(vec_tensor.shape)
+
+            # Solve using conjugate gradients
+            ihvp = gpytorch.utils.linear_cg(hvp_fn, vec_tensor, **self.cg_kwargs)
+            """
+
+            with cp.cuda.Device(device_id):
+                linop = L_gpu.LinearOperator((d, d), matvec=hvp_fn)
+                vec_ = cp.asarray(vec)
+                ihvp = L_gpu.cg(A=linop, b=vec_, **self.cg_kwargs)[0]
+        else:
+            # Slow Scipy-based method
+            linop = L.LinearOperator((d, d), matvec=hvp_fn)
+            ihvp = L.cg(A=linop, b=vec.cpu().numpy(), **self.cg_kwargs)[0]
 
         with torch.no_grad():
             self._model_reinsert_params(self._reshape_like_params(flat_params), register=True)
 
+        # if self.use_cupy:
+        #     return torch.as_tensor(ihvp, device=self.device)
         return torch.tensor(ihvp, device=self.device)
 
 
@@ -265,3 +360,178 @@ def inverse_hvp(self, vec):
             self._model_reinsert_params(self._reshape_like_params(flat_params), register=True)
 
         return ihvp / self.repeat
+
+
+class HVPModule(BaseInfluenceModule):
+    r"""Basic module for easily computing HVPs (Hessian-vector products).
+
+    """
+    def __init__(
+            self,
+            model: nn.Module,
+            objective: BaseObjective,
+            train_loader: data.DataLoader,
+            device: torch.device
+    ):
+        super().__init__(
+            model=model,
+            objective=objective,
+            train_loader=train_loader,
+            test_loader=None,
+            device=device,
+        )
+
+        #TODO: Allow user to specify dtype for higher precision
+
+    def hvp(self, vec):
+        """Average ``_hvp_at_batch(..., gnh=False)`` over the training set for ``vec``.
+
+        Each batch result is weighted by its batch size and the sum is divided by
+        ``len(self.train_loader.dataset)``; ``vec`` is moved to ``self.device`` first.
+        """
+        params = self._model_make_functional()
+        flat_params = self._flatten_params_like(params)
+
+        hvp = 0.0
+        for batch, batch_size in self._loader_wrapper(train=True):
+            hvp_batch = self._hvp_at_batch(batch, flat_params, vec=vec.to(self.device), gnh=False)
+            hvp = hvp + hvp_batch.detach() * batch_size
+        hvp = hvp / len(self.train_loader.dataset)
+
+        with torch.no_grad():
+            self._model_reinsert_params(self._reshape_like_params(flat_params), register=True)
+
+        return hvp
+
+    def inverse_hvp(self, vec):
+        raise NotImplementedError("Inverse HVP not implemented for HVPModule - this is a heler class for HVP")
+
+
+class ShanksSablonniereModule(BaseInfluenceModule):
+    """
+    Based on Series of Hessian-Vector Products for Tractable Saddle-Free Newton Optimisation of Neural Networks (https://arxiv.org/abs/2310.14901).
+    Compute recursive Shanks transformation of given sequence using Samelson inverse and the epsilon-algorithm with a Sablonniere modifier.
+    """
+    def __init__(
+            self,
+            model: nn.Module,
+            objective: BaseObjective,
+            train_loader: data.DataLoader,
+            device: torch.device,
+            # damp: float,
+            **kwargs
+    ):
+        # Initialise the base class the same way HVPModule does, so attributes set
+        # by BaseInfluenceModule.__init__ exist on this object too.
+        super().__init__(
+            model=model,
+            objective=objective,
+            train_loader=train_loader,
+            test_loader=None,
+            device=device,
+        )
+
+        self.acceleration_order = kwargs.get('acceleration_order', 8)
+        self.initial_scale_factor = kwargs.get('initial_scale_factor', 100)
+        self.num_update_steps = kwargs.get('num_update_steps', 20)
+
+        # inverse_hvp needs at least one update step to define scale_factor, and the
+        # epsilon algorithm consumes 2 * acceleration_order + 1 cached partial sums.
+        min_update_steps = max(2, 2 * self.acceleration_order + 1)
+        if self.num_update_steps < min_update_steps:
+            raise ValueError(
+                f"num_update_steps must be at least {min_update_steps}, got {self.num_update_steps}"
+            )
+
+        self.hvp_module = HVPModule(
+            model,
+            objective,
+            train_loader,
+            device=device
+        )
+        # TODO: Consider adding damping
+
+    def compute_epsilon_acceleration(
+            self,
+            source_sequence,
+            num_applications: int=1,):
+        """Compute `num_applications` recursive Shanks transformation of
+        `source_sequence` (preferring later elements) using `Samelson` inverse and the
+        epsilon-algorithm, with Sablonniere modifier.
+        """
+
+        def inverse(vector):
+            """
+            Samelson inverse
+            """
+            return vector / vector.dot(vector)
+
+        epsilon = {}
+        for m, source_m in enumerate(source_sequence):
+            epsilon[m, 0] = source_m.squeeze(1)
+            epsilon[m + 1, -1] = 0
+
+        s = 1
+        m = (len(source_sequence) - 1) - 2 * num_applications
+        initial_m = m
+        while m < len(source_sequence) - 1:
+            while m >= initial_m:
+                # Sablonniere modifier
+                inverse_scaling = np.floor(s / 2) + 1
+
+                epsilon[m, s] = epsilon[m + 1, s - 2] + inverse_scaling * inverse(
+                    epsilon[m + 1, s - 1] - epsilon[m, s - 1]
+                )
+                epsilon.pop((m + 1, s - 2))
+                m -= 1
+                s += 1
+            m += 1
+            s -= 1
+            epsilon.pop((m, s - 1))
+            m = initial_m + s
+            s = 1
+
+        return epsilon[initial_m, 2 * num_applications]
+
+    def inverse_hvp(self, vec):
+        """Approximate the inverse-HVP of ``vec`` with a truncated series of squared-HVP
+        updates; when ``acceleration_order > 0`` the partial sums are Shanks-accelerated.
+        """
+        # Detach and clone input
+        vector_cache = vec.detach().clone()
+        update_sum = vec.detach().clone()
+        coefficient_cache = 1
+
+        cached_update_sums = []
+        if self.acceleration_order > 0 and self.num_update_steps == 2 * self.acceleration_order + 1:
+            # Clone: update_sum is modified in place below, so caching the tensor
+            # itself would silently change this entry on every iteration.
+            cached_update_sums.append(update_sum.clone())
+
+        # Do HessianSeries calculation
+        for update_step in range(1, self.num_update_steps):
+            hessian2_vector_cache = self.hvp_module.hvp(self.hvp_module.hvp(vector_cache))
+
+            if update_step == 1:
+                scale_factor = torch.norm(hessian2_vector_cache, p=2) / torch.norm(vec, p=2)
+                scale_factor = max(scale_factor.item(), self.initial_scale_factor)
+
+            vector_cache = (vector_cache - (1/scale_factor)*hessian2_vector_cache).clone()
+            coefficient_cache *= (2 * update_step - 1) / (2 * update_step)
+            update_sum += coefficient_cache * vector_cache
+
+            if self.acceleration_order > 0 and update_step >= (self.num_update_steps - 2 * self.acceleration_order - 1):
+                cached_update_sums.append(update_sum.clone())
+
+        # Perform series acceleration (Shanks acceleration)
+        if self.acceleration_order > 0:
+            accelerated_sum = self.compute_epsilon_acceleration(
+                cached_update_sums, num_applications=self.acceleration_order
+            )
+            accelerated_sum /= np.sqrt(scale_factor)
+            accelerated_sum = accelerated_sum.unsqueeze(1)
+
+            return accelerated_sum
+
+        update_sum /= np.sqrt(scale_factor)
+        return update_sum