diff --git a/torch_influence/__init__.py b/torch_influence/__init__.py
index 337ce2a..a39af67 100644
--- a/torch_influence/__init__.py
+++ b/torch_influence/__init__.py
@@ -10,7 +10,8 @@
     "AutogradInfluenceModule",
     "CGInfluenceModule",
     "LiSSAInfluenceModule",
+    "HVPModule"
 ]
 
 from torch_influence.base import BaseInfluenceModule, BaseObjective
-from torch_influence.modules import AutogradInfluenceModule, CGInfluenceModule, LiSSAInfluenceModule
+from torch_influence.modules import AutogradInfluenceModule, CGInfluenceModule, LiSSAInfluenceModule, HVPModule
diff --git a/torch_influence/modules.py b/torch_influence/modules.py
index 27eef6c..d9f1eb1 100644
--- a/torch_influence/modules.py
+++ b/torch_influence/modules.py
@@ -3,9 +3,15 @@
 
 import numpy as np
 import scipy.sparse.linalg as L
+
+# CuPy is optional: it is only needed when CGInfluenceModule is used with use_cupy=True.
+try:
+    import cupy as cp
+    import cupyx.scipy.sparse.linalg as L_gpu
+except ImportError:
+    cp = None
+    L_gpu = None
+
 import torch
 from torch import nn
 from torch.utils import data
 
 from torch_influence.base import BaseInfluenceModule, BaseObjective
 
@@ -29,6 +35,8 @@ class AutogradInfluenceModule(BaseInfluenceModule):
         check_eigvals: if ``True``, this initializer checks that the damped risk Hessian
             is positive definite, and raises a :mod:`ValueError` if it is not. Otherwise,
             no check is performed.
+        store_as_hessian: if ``True``, the damped risk Hessian is stored on the module and its
+            inverse is only computed lazily, on first access. Otherwise, the inverse Hessian
+            is computed immediately in this initializer.
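+
+    Example:
+        A minimal sketch of the lazy workflow (assumes the module was constructed with
+        ``store_as_hessian=True``)::
+
+            hess = module.get_hessian()            # damped risk Hessian, already computed
+            ihess = module.get_inverse_hessian()   # inverse is computed and cached here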
 
     Warnings:
         This module scales poorly with the number of model parameters :math:`d`. In
@@ -45,7 +53,8 @@ def __init__(
             test_loader: data.DataLoader,
             device: torch.device,
             damp: float,
-            check_eigvals: bool = False
+            check_eigvals: bool = False,
+            store_as_hessian: bool = False
     ):
         super().__init__(
             model=model,
@@ -55,6 +64,7 @@ def __init__(
             device=device,
         )
 
+        self.store_as_hessian = store_as_hessian
         self.damp = damp
 
         params = self._model_make_functional()
@@ -68,7 +78,7 @@ def f(theta_):
                 self._model_reinsert_params(self._reshape_like_params(theta_))
                 return self.objective.train_loss(self.model, theta_, batch)
 
-            hess_batch = torch.autograd.functional.hessian(f, flat_params).detach()
+            hess_batch = torch.autograd.functional.hessian(f, flat_params).detach().cpu()
             hess = hess + hess_batch * batch_size
 
         with torch.no_grad():
@@ -77,16 +87,32 @@ def f(theta_):
             hess = hess + damp * torch.eye(d, device=hess.device)
 
             if check_eigvals:
-                eigvals = np.linalg.eigvalsh(hess.cpu().numpy())
+                eigvals = torch.linalg.eigvalsh(hess.cpu()).numpy()
                 logging.info("hessian min eigval %f", np.min(eigvals).item())
                 logging.info("hessian max eigval %f", np.max(eigvals).item())
                 if not bool(np.all(eigvals >= 0)):
                     raise ValueError()
 
-            self.inverse_hess = torch.inverse(hess)
+            if self.store_as_hessian:
+                self.hess = hess
+                self.inverse_hess = None
+            else:
+                self.inverse_hess = torch.inverse(hess)
+
+    def get_hessian(self):
+        """Returns the damped risk Hessian.
+
+        Only available when the module was constructed with ``store_as_hessian=True``.
+        """
+        if not self.store_as_hessian:
+            raise NotImplementedError("To access the damped risk Hessian, set store_as_hessian=True.")
+        return self.hess
+
+    def get_inverse_hessian(self):
+        """Returns the inverse of the damped risk Hessian, computing and caching it on first access."""
+        if self.inverse_hess is None:
+            # Lazy inverse computation: invert once and reuse the cached result afterwards.
+            self.inverse_hess = torch.inverse(self.hess)
+        return self.inverse_hess
 
     def inverse_hvp(self, vec):
-        return self.inverse_hess @ vec
+        inverse_hess = self.get_inverse_hessian()
+        # The Hessian is accumulated on the CPU, so move ``vec`` there for the product and
+        # return the result on its original device.
+        return (inverse_hess @ vec.to(inverse_hess.device)).to(vec.device)
 
 
 class CGInfluenceModule(BaseInfluenceModule):
@@ -122,6 +148,7 @@ def __init__(
             device: torch.device,
             damp: float,
             gnh: bool = False,
+            use_cupy: bool = True,
             **kwargs
     ):
         super().__init__(
@@ -134,31 +161,99 @@ def __init__(
 
         self.damp = damp
         self.gnh = gnh
+        self.use_cupy = use_cupy
+        if self.use_cupy and cp is None:
+            raise ImportError("use_cupy=True requires CuPy; install it or pass use_cupy=False.")
         self.cg_kwargs = kwargs
 
     def inverse_hvp(self, vec):
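+        """Solves the damped Hessian system for ``vec`` with conjugate gradients.
+
+        With ``use_cupy=True`` the CG solve runs on the GPU through
+        :mod:`cupyx.scipy.sparse.linalg`; otherwise it falls back to the original SciPy
+        solver on the CPU. Either way, every matrix-vector product inside CG is a full
+        pass over ``train_loader``.
+        """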
         params = self._model_make_functional()
         flat_params = self._flatten_params_like(params)
 
+        # Work out which CUDA device the CuPy arrays should live on
+        # (``self.device`` may be a string such as "cuda:1" or a ``torch.device``).
+        if self.use_cupy:
+            device_id = torch.device(self.device).index or 0
+
         def hvp_fn(v):
-            v = torch.tensor(v, requires_grad=False, device=self.device, dtype=vec.dtype)
+            if self.use_cupy:
+                # ``v`` arrives as a CuPy array from the GPU CG solver.
+                v_ = torch.as_tensor(v, device=self.device)
+            else:
+                # ``v`` arrives as a NumPy array from SciPy's CG solver.
+                v_ = torch.tensor(v, requires_grad=False, device=self.device, dtype=vec.dtype)
 
             hvp = 0.0
             for batch, batch_size in self._loader_wrapper(train=True):
-                hvp_batch = self._hvp_at_batch(batch, flat_params, vec=v, gnh=self.gnh)
+                hvp_batch = self._hvp_at_batch(batch, flat_params, vec=v_, gnh=self.gnh)
                 hvp = hvp + hvp_batch.detach() * batch_size
             hvp = hvp / len(self.train_loader.dataset)
-            hvp = hvp + self.damp * v
-
+            hvp = hvp + self.damp * v_
+
+            if self.use_cupy:
+                # Hand the result back to the GPU CG solver as a CuPy array on its device.
+                # Note this currently round-trips through host memory (GPU -> CPU -> GPU).
+                with cp.cuda.Device(device_id):
+                    return cp.asarray(hvp.cpu().numpy())
             return hvp.cpu().numpy()
 
         d = vec.shape[0]
-        linop = L.LinearOperator((d, d), matvec=hvp_fn)
-        ihvp = L.cg(A=linop, b=vec.cpu().numpy(), **self.cg_kwargs)[0]
+        if self.use_cupy:
+            """
+            # Wrap hvp_fn in a LinearOperator-like class
+            class HvpLazyTensor(LazyTensor):
+                def __init__(self, hvp_fn, vec_shape):
+                    self.hvp_fn = hvp_fn
+                    self.vec_shape = vec_shape
+                    super(HvpLazyTensor, self).__init__(torch.eye(vec_shape, dtype=vec.dtype))
+
+                def _matmul(self, rhs):
+                    return self.hvp_fn(rhs)
+
+                def _transpose_nonbatch(self):
+                    return self  # Symmetric Hessian
+
+                def _size(self):
+                    return torch.Size((self.vec_shape, self.vec_shape))
+
+            # Create the lazy tensor
+            hvp_lazy_tensor = HvpLazyTensor(hvp_fn, d)
+
+            # rtol: 1e-5
+            # default niters: len(g) * 10
+
+            # Convert vec to a tensor if it isn't already
+            vec_tensor = vec if torch.is_tensor(vec) else torch.tensor(vec, dtype=vec.dtype)
+            print(vec_tensor.shape)
+
+            # Solve using conjugate gradients
+            ihvp = gpytorch.utils.linear_cg(hvp_fn, vec_tensor, **self.cg_kwargs)
+            """
+
+            # Run conjugate gradients on the GPU via cupyx.scipy.sparse.linalg.
+            with cp.cuda.Device(device_id):
+                linop = L_gpu.LinearOperator((d, d), matvec=hvp_fn)
+                # Move ``vec`` through host memory so this works regardless of where it lives.
+                vec_ = cp.asarray(vec.detach().cpu().numpy())
+                ihvp = L_gpu.cg(A=linop, b=vec_, **self.cg_kwargs)[0]
+        else:
+            # CPU fallback: the original SciPy-based conjugate gradient solve.
+            linop = L.LinearOperator((d, d), matvec=hvp_fn)
+            ihvp = L.cg(A=linop, b=vec.cpu().numpy(), **self.cg_kwargs)[0]
 
         with torch.no_grad():
             self._model_reinsert_params(self._reshape_like_params(flat_params), register=True)
 
+        if self.use_cupy:
+            # ``ihvp`` is a CuPy array here; bring it back as a torch tensor on ``self.device``.
+            return torch.tensor(cp.asnumpy(ihvp), device=self.device)
         return torch.tensor(ihvp, device=self.device)
 
 
@@ -265,3 +360,150 @@ def inverse_hvp(self, vec):
             self._model_reinsert_params(self._reshape_like_params(flat_params), register=True)
 
         return ihvp / self.repeat
+
+
+class HVPModule(BaseInfluenceModule):
+    r"""Basic module for easily computing HPV
+
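+    Example:
+        A minimal sketch (``model``, ``objective`` and ``train_loader`` are assumed to be set
+        up exactly as for the other influence modules)::
+
+            hvp_module = HVPModule(model, objective, train_loader, device=torch.device("cuda"))
+            v = torch.randn(sum(p.numel() for p in model.parameters()), device="cuda")
+            hv = hvp_module.hvp(v)  # averaged Hessian-vector product over the training set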
+    """
+    def __init__(
+            self,
+            model: nn.Module,
+            objective: BaseObjective,
+            train_loader: data.DataLoader,
+            device: torch.device
+    ):
+        super().__init__(
+            model=model,
+            objective=objective,
+            train_loader=train_loader,
+            test_loader=None,
+            device=device,
+        )
+
+    # TODO: Allow the user to specify a dtype for higher-precision accumulation.
+
+    def hvp(self, vec):
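+        """Returns the Hessian-vector product of the training risk with ``vec``.
+
+        The product is accumulated batch by batch over ``train_loader`` and averaged over the
+        dataset; no damping term is added here.
+        """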
+        params = self._model_make_functional()
+        flat_params = self._flatten_params_like(params)
+
+        hvp = 0.0
+        for batch, batch_size in self._loader_wrapper(train=True):
+            hvp_batch = self._hvp_at_batch(batch, flat_params, vec=vec.to(self.device), gnh=False)
+            hvp = hvp + hvp_batch.detach() * batch_size
+        hvp = hvp / len(self.train_loader.dataset)
+
+        with torch.no_grad():
+            self._model_reinsert_params(self._reshape_like_params(flat_params), register=True)
+
+        return hvp
+
+    def inverse_hvp(self, vec):
+        raise NotImplementedError("Inverse HVP not implemented for HVPModule - this is a heler class for HVP")
+
+
+class ShanksSablonniereModule(BaseInfluenceModule):
+    """
+    Inverse-HVP module based on "Series of Hessian-Vector Products for Tractable Saddle-Free
+    Newton Optimisation of Neural Networks" (https://arxiv.org/abs/2310.14901).
+
+    The partial sums of the Hessian series are accelerated with a recursive Shanks
+    transformation, computed via the Samelson (vector) inverse and the epsilon-algorithm
+    with Sablonnière's modifier.
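+
+    The :math:`k`-th series term is weighted by :math:`(2k-1)!!/(2k)!!`, the Taylor coefficients
+    of :math:`(1-x)^{-1/2}`, and applied to powers of :math:`I - H^2/\text{scale}`, so the
+    rescaled partial sums approximate the saddle-free product :math:`|H|^{-1} v`. Accepted
+    keyword arguments (all optional): ``acceleration_order`` (default 8),
+    ``initial_scale_factor`` (default 100) and ``num_update_steps`` (default 20).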
+    """
+    def __init__(
+            self,
+            model: nn.Module,
+            objective: BaseObjective,
+            train_loader: data.DataLoader,
+            device: torch.device,
+            # damp: float,
+            **kwargs
+    ):
+        super().__init__(
+            model=model,
+            objective=objective,
+            train_loader=train_loader,
+            test_loader=None,
+            device=device,
+        )
+
+        self.acceleration_order = kwargs.get('acceleration_order', 8)
+        self.initial_scale_factor = kwargs.get('initial_scale_factor', 100)
+        self.num_update_steps = kwargs.get('num_update_steps', 20)
+
+        self.hvp_module = HVPModule(
+            model,
+            objective,
+            train_loader,
+            device=device
+        )
+        # TODO: Consider adding damping
+
+    def compute_epsilon_acceleration(
+        self,
+        source_sequence,
+        num_applications: int = 1,
+    ):
+        """Compute `num_applications` recursive Shanks transformation of
+        `source_sequence` (preferring later elements) using `Samelson` inverse and the
+        epsilon-algorithm, with Sablonniere modifier.
+        """
+
+        def inverse(vector):
+            """Samelson (vector) inverse: v / <v, v>."""
+            return vector / vector.dot(vector)
+
+        epsilon = {}
+        for m, source_m in enumerate(source_sequence):
+            epsilon[m, 0] = source_m.squeeze(1)
+            epsilon[m + 1, -1] = 0
+
+        s = 1
+        m = (len(source_sequence) - 1) - 2 * num_applications
+        initial_m = m
+        while m < len(source_sequence) - 1:
+            while m >= initial_m:
+                # Sablonniere modifier
+                inverse_scaling = np.floor(s / 2) + 1
+
+                epsilon[m, s] = epsilon[m + 1, s - 2] + inverse_scaling * inverse(
+                    epsilon[m + 1, s - 1] - epsilon[m, s - 1]
+                )
+                epsilon.pop((m + 1, s - 2))
+                m -= 1
+                s += 1
+            m += 1
+            s -= 1
+            epsilon.pop((m, s - 1))
+            m = initial_m + s
+            s = 1
+
+        return epsilon[initial_m, 2 * num_applications]
+
+    def inverse_hvp(self, vec):
+        # Detach and clone the input so the series updates do not modify the caller's tensor
+        vector_cache = vec.detach().clone()
+        update_sum = vec.detach().clone()
+        coefficient_cache = 1
+
+        cached_update_sums = []
+        if self.acceleration_order > 0 and self.num_update_steps == 2 * self.acceleration_order + 1:
+            # Clone here as well: ``update_sum`` is updated in place below.
+            cached_update_sums.append(update_sum.clone())
+
+        # Do HessianSeries calculation
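+        # Each step multiplies ``vector_cache`` by (I - H^2/scale) and adds it to the running sum
+        # with weight (2k-1)!!/(2k)!!; after the final 1/sqrt(scale) rescaling the sum approximates
+        # |H|^{-1} vec (see the class docstring).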
+        for update_step in range(1, self.num_update_steps):
+            hessian2_vector_cache = self.hvp_module.hvp(self.hvp_module.hvp(vector_cache))
+
+            if update_step == 1:
+                scale_factor = torch.norm(hessian2_vector_cache, p=2) / torch.norm(vec, p=2)
+                scale_factor = max(scale_factor.item(), self.initial_scale_factor)
+
+            vector_cache = (vector_cache - (1/scale_factor)*hessian2_vector_cache).clone()
+            coefficient_cache *= (2 * update_step - 1) / (2 * update_step)
+            update_sum += coefficient_cache * vector_cache
+
+            if self.acceleration_order > 0 and update_step >= (self.num_update_steps - 2 * self.acceleration_order - 1):
+                cached_update_sums.append(update_sum.clone())
+
+        # Perform series acceleration (Shanks acceleration)
+        if self.acceleration_order > 0:
+            accelerated_sum = self.compute_epsilon_acceleration(
+                cached_update_sums, num_applications=self.acceleration_order
+            )
+            accelerated_sum /= np.sqrt(scale_factor)
+            accelerated_sum = accelerated_sum.unsqueeze(1)
+
+            return accelerated_sum
+
+        update_sum /= np.sqrt(scale_factor)
+        return update_sum