[ADD] Change KFAC and KFAC inverse to purely PyTorch #149

Draft · wants to merge 36 commits into base: main

Commits (36)

06ebbf2
[ADD] Minimal linear operator interface for PyTorch
f-dangel Sep 21, 2024
f652a02
[FIX] Linters
f-dangel Sep 21, 2024
9fa1b6a
[ADD] Replicate current base class but inherit from `PyTorchLinearOpe…
f-dangel Sep 21, 2024
62dc997
[DEL] Unused import
f-dangel Sep 21, 2024
59c6369
[ADD] Implement Hessian as `CurvatureLinearOperator`
f-dangel Sep 21, 2024
07144a3
[REF] Combine full and block-diagonal matrix multiply tests
f-dangel Sep 22, 2024
9c538ea
[ADD] Automatically implement `adjoint` for self-adjoint ops
f-dangel Sep 22, 2024
3a8ac95
[ADD] Implement GGN as `CurvatureLinearOperator`
f-dangel Sep 22, 2024
73cdb5a
[FIX] flake8
f-dangel Sep 22, 2024
e8f77e8
[FIX] Increase tolerance to make test pass
f-dangel Sep 22, 2024
8b6fed7
[DEL] Unused import
f-dangel Sep 22, 2024
0ed579f
[ADD] Implement `FisherMCLinearOperator` as `CurvatureLinearOperator`
f-dangel Sep 23, 2024
d72094b
[ADD] Implement Jacobians as `CurvatureLinearOperator`
f-dangel Sep 23, 2024
85e6c6b
[REF] Extract error test with ``MutableMapping`` and ``batch_size_fn=…
f-dangel Sep 23, 2024
c819189
[ADD] Class attribute for checking deterministic batches
f-dangel Sep 23, 2024
d38bb4f
[ADD] Extract getters for `_in_shape` and `_out_shape`
f-dangel Sep 23, 2024
4f546fc
[ADD] Test deterministic batches
f-dangel Sep 23, 2024
87c6eba
[DOC] Add documentation and type annotations
f-dangel Sep 23, 2024
c6fe684
[ADD] Implement empirical Fisher as `CurvatureLinearOperator`
f-dangel Sep 23, 2024
eba5976
Merge branch 'main' into fisher-linop
f-dangel Sep 23, 2024
e678ea2
[REF] Extract testing matrix multiplies in expectation
f-dangel Sep 23, 2024
32523ad
[FIX] Examples
f-dangel Sep 23, 2024
9ff958b
[ADD] Check deterministic batches in MC-Fisher (fixes #75)
f-dangel Sep 23, 2024
76a5190
Merge branch 'jacobian-linop' into fisher-linop
f-dangel Sep 23, 2024
f67ca0f
Merge branch 'ef-linop' into act-hessian-linop
f-dangel Sep 23, 2024
9be72fc
[ADD] Implement activation Hessian as `CurvatureLinearOperator`
f-dangel Sep 23, 2024
a9ea3fc
[ADD] Implement KFAC as `CurvatureLinearOperator`
f-dangel Sep 23, 2024
012cc21
[DEL] Remove original base class
f-dangel Sep 23, 2024
4d0020c
[FIX] Darglint
f-dangel Oct 26, 2024
efaf74a
[DEL] Remove unused imports
f-dangel Oct 26, 2024
712c4c5
[ADD] Declare activation Hessian as self-adjoint
f-dangel Oct 26, 2024
bd90a86
[FIX] Some tests
f-dangel Oct 26, 2024
73add14
[FIX] KFAC tests
f-dangel Oct 26, 2024
f24e11e
[FIX] Implement KFAC inverse as Curvature linop
f-dangel Oct 26, 2024
e422dc9
[FIX] Make tests work on GPU and CPU
f-dangel Oct 27, 2024
7b7f746
[DOC] Remove `torch_matvec` and `torch_matmat`
f-dangel Oct 27, 2024
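
Taken together, the commits replace the SciPy-backed `_LinearOperator` base class with a PyTorch-native `CurvatureLinearOperator`, port the Hessian, GGN, Jacobian, Fisher, activation-Hessian, KFAC, and KFAC-inverse operators to it, and keep SciPy compatibility via `to_scipy()`. The snippet below is a minimal sketch of the intended workflow after this change; it assumes the constructors keep their current signatures, that the new operators support `@` and expose `.shape` for PyTorch matrices, and that `KFACInverseLinearOperator` still accepts a `damping` argument, none of which is shown verbatim in this diff.

```python
from torch import nn, rand, randint

from curvlinops import KFACInverseLinearOperator, KFACLinearOperator

# toy classification setup
X, y = rand(32, 10), randint(0, 3, (32,))
model = nn.Sequential(nn.Linear(10, 8), nn.ReLU(), nn.Linear(8, 3))
loss_func = nn.CrossEntropyLoss()
params = [p for p in model.parameters() if p.requires_grad]

# KFAC and its inverse as pure PyTorch linear operators
kfac = KFACLinearOperator(model, loss_func, params, [(X, y)])
inv_kfac = KFACInverseLinearOperator(kfac, damping=1e-3)  # `damping` assumed unchanged

# multiply onto a PyTorch matrix without round-tripping through NumPy ...
V = rand(kfac.shape[1], 4)
KV = kfac @ V
V_back = inv_kfac @ KV  # approximately recovers V for small damping

# ... or export to SciPy where NumPy-based consumers (eigsh, CG, ...) are still used
kfac_scipy = kfac.to_scipy()
```
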
410 changes: 0 additions & 410 deletions curvlinops/_base.py

This file was deleted.

820 changes: 820 additions & 0 deletions curvlinops/_torch_base.py

Large diffs are not rendered by default.
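
Since the 820 new lines of `curvlinops/_torch_base.py` are collapsed here, the following is a minimal sketch of what such a pure-PyTorch linear operator base class typically provides, pieced together from the commit messages and the call sites visible in the rendered diffs (`SELF_ADJOINT`, `_get_in_shape`/`_get_out_shape`, `_matmat` over tensor lists, `to_scipy()`). The class name and any method not listed there are assumptions, not the actual implementation.

```python
from math import prod
from typing import List, Tuple

from scipy.sparse.linalg import LinearOperator
from torch import Tensor, cat, from_numpy


class PyTorchLinearOperatorSketch:
    """Hypothetical stand-in for the collapsed ``_torch_base.py`` base class."""

    SELF_ADJOINT: bool = False  # self-adjoint subclasses set this and get ``adjoint`` for free

    def __init__(self, in_shape: List[Tuple[int, ...]], out_shape: List[Tuple[int, ...]]):
        self._in_shape, self._out_shape = in_shape, out_shape
        self.shape = (sum(prod(s) for s in out_shape), sum(prod(s) for s in in_shape))

    def _matmat(self, M: List[Tensor]) -> List[Tensor]:
        """Multiply onto a matrix in tensor-list format (implemented by subclasses)."""
        raise NotImplementedError

    def __matmul__(self, M: Tensor) -> Tensor:
        """Multiply onto a flat vector ``(dim_in,)`` or matrix ``(dim_in, K)``."""
        is_vec = M.ndim == 1
        M = M.unsqueeze(-1) if is_vec else M
        # unflatten the input into tensor-list format with a trailing column axis
        sizes = [prod(s) for s in self._in_shape]
        M_list = [
            m.reshape(*s, M.shape[-1]) for m, s in zip(M.split(sizes), self._in_shape)
        ]
        # apply the operator and flatten the result back into a matrix
        AM = cat([r.reshape(-1, M.shape[-1]) for r in self._matmat(M_list)])
        return AM.squeeze(-1) if is_vec else AM

    def adjoint(self) -> "PyTorchLinearOperatorSketch":
        """Return the adjoint; the operator itself if it is self-adjoint."""
        if self.SELF_ADJOINT:
            return self
        raise NotImplementedError

    def to_scipy(self) -> LinearOperator:
        """Wrap as a SciPy ``LinearOperator`` for NumPy-based consumers."""

        def matmat(X):
            # the real class presumably moves X to the model's device and dtype
            return (self @ from_numpy(X).float()).detach().cpu().numpy()

        return LinearOperator(self.shape, matvec=matmat, matmat=matmat)
```
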

124 changes: 67 additions & 57 deletions curvlinops/experimental/activation_hessian.py
@@ -2,19 +2,28 @@

import contextlib
from types import TracebackType
from typing import Callable, Iterable, List, Optional, Set, Tuple, Type, Union
from typing import (
Callable,
Iterable,
List,
MutableMapping,
Optional,
Set,
Tuple,
Type,
Union,
)

from backpack.hessianfree.hvp import hessian_vector_product
from numpy import ndarray
from torch import Tensor, from_numpy, zeros_like
from torch import Tensor, zeros_like
from torch.autograd import grad
from torch.nn import Module
from torch.utils.hooks import RemovableHandle

from curvlinops._base import _LinearOperator
from curvlinops._torch_base import CurvatureLinearOperator


class ActivationHessianLinearOperator(_LinearOperator):
class ActivationHessianLinearOperator(CurvatureLinearOperator):
r"""Hessian of the loss w.r.t. hidden features in a neural network.

Consider the empirical risk on a single mini-batch
@@ -38,17 +47,22 @@ class ActivationHessianLinearOperator(_LinearOperator):
\ell(f_{\mathbf{\theta}}(\mathbf{X}), \mathbf{y})

and has dimension :math:`\mathrm{dim}(\mathbf{Z}) = N \mathrm{dim}(\mathbf{z})`.

Attributes:
SELF_ADJOINT: Whether the linear operator is self-adjoint. ``True`` for the
activation Hessian.
"""

SELF_ADJOINT: bool = True

def __init__(
self,
model_func: Module,
loss_func: Callable[[Tensor, Tensor], Tensor],
loss_func: Callable[[Union[MutableMapping, Tensor], Tensor], Tensor],
activation: Tuple[str, str, int],
data: Iterable[Tuple[Tensor, Tensor]],
data: Iterable[Tuple[Union[MutableMapping, Tensor], Tensor]],
progressbar: bool = False,
check_deterministic: bool = True,
shape: Optional[Tuple[int, int]] = None,
):
"""Linear operator for the loss Hessian w.r.t. intermediate features.

@@ -73,8 +87,6 @@ def __init__(
model's forward pass could depend on the order in which mini-batches
are presented (BatchNorm, Dropout). Default: ``True``. This is a
safeguard, only turn it off if you know what you are doing.
shape: Shape of the represented matrix. If ``None``, this dimension will be
inferred at the cost of one forward pass through the model.

Raises:
ValueError: If ``data`` contains more than one batch.
@@ -92,32 +104,23 @@ def __init__(
>>>
>>> hessian = ActivationHessianLinearOperator( # Hessian w.r.t. ReLU input
... model, loss_func, ("1", "input", 0), data
... )
... ).to_scipy()
>>> hessian.shape # batch size * feature dimension (10 * 3)
(30, 30)
>>>
>>> # The ReLU's input is the first Linear's output, let's check that
>>> hessian2 = ActivationHessianLinearOperator( # Hessian w.r.t. first output
... model, loss_func, ("0", "output", 0), data
... )
... ).to_scipy()
>>> I = eye(hessian.shape[1])
>>> allclose(hessian @ I, hessian2 @ I)
True
"""
self._activation = activation

# Compute shape of activation and ensure there is only one batch
# Ensure there is only one batch
data_iter = iter(data)
X, _ = next(data_iter)
(dev,) = {p.device for p in model_func.parameters()}
X = X.to(dev)
activation_storage = []
with store_activation(model_func, *activation, activation_storage):
model_func(X)
act = activation_storage.pop()
shape = (act.numel(), act.numel())
self._activation_shape = tuple(act.shape)

next(data_iter)
with contextlib.suppress(StopIteration):
next(data_iter)
raise ValueError(f"{self.__class__.__name__} requires a single batch.")
@@ -129,25 +132,49 @@ def __init__(
data,
progressbar=progressbar,
check_deterministic=check_deterministic,
shape=shape,
)

def _matmat_batch(
self, X: Tensor, y: Tensor, M_list: List[Tensor]
) -> Tuple[Tensor, ...]:
"""Apply the activation Hessian to a matrix.
def _get_out_shape(self) -> List[Tuple[int, ...]]:
"""Return the output shape of the activation Hessian.

Returns:
The dimensions of the activation Hessian's output space.
"""
if not hasattr(self, "_activation_shape"):
X, _ = next(iter(self._data))
dev = self._infer_device()
if isinstance(X, Tensor):
X = X.to(dev)
activation_storage = []
with store_activation(
self._model_func, *self._activation, activation_storage
):
self._model_func(X)
act = activation_storage.pop()

self._activation_shape = tuple(act.shape)

return [self._activation_shape]

def _get_in_shape(self) -> List[Tuple[int, ...]]:
"""Return the input shape of the activation Hessian.

Returns:
The dimensions of the activation Hessian's input space.
"""
return self._get_out_shape()

def _matmat_batch(self, X: Tensor, y: Tensor, M: List[Tensor]) -> List[Tensor]:
"""Apply the activation Hessian to a matrix in tensor list format.

Args:
X: Input to the DNN.
y: Ground truth.
M_list: Matrix to be multiplied with in list format.
Tensors have same shape as trainable model parameters, and an
additional leading axis for the matrix columns.
M: Matrix to be multiplied with in tensor list format.

Returns:
Result of activation Hessian multiplication in list format. Has the same
shape as ``M_list``, i.e. each tensor in the list has the shape of a
parameter and a leading dimension of matrix columns.
shape as ``M``.
"""
activation_storage = []
with store_activation(self._model_func, *self._activation, activation_storage):
@@ -159,34 +186,17 @@ def _matmat_batch(
grad_activation = grad(loss, activation, create_graph=True)

# collect
result_list = [zeros_like(M) for M in M_list]
HM = [zeros_like(m) for m in M]

num_vectors = M_list[0].shape[0]
(num_vectors,) = {m.shape[-1] for m in M}
for n in range(num_vectors):
out_n_list = hessian_vector_product(
loss, [activation], [M[n] for M in M_list], grad_params=grad_activation
HM_col = hessian_vector_product(
loss, [activation], [m[..., n] for m in M], grad_params=grad_activation
)
for result, out_n in zip(result_list, out_n_list):
result[n].add_(out_n)

return tuple(result_list)
for HM_p, HM_col_p in zip(HM, HM_col):
HM_p[..., n].add_(HM_col_p)

def _preprocess(self, M: ndarray) -> List[Tensor]:
"""Reshape the incoming matrix into the activation shape and convert to PyTorch.

Args:
M: Matrix in NumPy format onto which the linear operator is applied.

Returns:
Matrix in PyTorch format. Has same shape as the activation with an
additional leading axis for the matrix columns.
"""
num_vectors = M.shape[1]
return [
from_numpy(M.T)
.to(self._device)
.reshape(num_vectors, *self._activation_shape)
]
return HM


class store_activation:
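
One recurring change in the hunks above is the matrix layout: `_matmat_batch` previously received each matrix block with a leading column axis and now receives it with a trailing column axis (hence `m[..., n]` instead of `M[n]`). The snippet below is not part of the diff, just a small illustration of that convention using the (10, 3) activation from the docstring example.

```python
from torch import rand

activation_shape = (10, 3)  # batch size x feature dimension, as in the docstring example
num_columns = 5

# old convention (removed): leading axis for the matrix columns
M_old = [rand(num_columns, *activation_shape)]
column_0_old = M_old[0][0]  # shape (10, 3)

# new convention: trailing axis for the matrix columns
M_new = [rand(*activation_shape, num_columns)]
column_0_new = M_new[0][..., 0]  # shape (10, 3), indexed as `m[..., n]` in the new code
```
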
62 changes: 28 additions & 34 deletions curvlinops/fisher.py
@@ -1,22 +1,19 @@
"""Contains LinearOperator implementation of the (approximate) Fisher."""

from __future__ import annotations
"""Contains LinearOperator implementation of the approximate Fisher."""

from collections.abc import MutableMapping
from math import sqrt
from typing import Callable, Iterable, List, Optional, Tuple, Union

from backpack.hessianfree.ggnvp import ggn_vector_product_from_plist
from einops import einsum, rearrange
from numpy import ndarray
from torch import Generator, Tensor, as_tensor, normal, softmax, zeros, zeros_like
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss, Parameter
from torch.nn.functional import one_hot

from curvlinops._base import _LinearOperator
from curvlinops._torch_base import CurvatureLinearOperator


class FisherMCLinearOperator(_LinearOperator):
class FisherMCLinearOperator(CurvatureLinearOperator):
r"""Monte-Carlo approximation of the Fisher as SciPy linear operator.

Consider the empirical risk
@@ -101,13 +98,20 @@ class FisherMCLinearOperator(_LinearOperator):
The linear operator represents a deterministic sample from this MC Fisher estimator.
To generate different samples, you have to create instances with varying random
seed argument.

Attributes:
SELF_ADJOINT: Whether the operator is self-adjoint. ``True`` for the Fisher.
supported_losses: Supported loss functions.
FIXED_DATA_ORDER: Whether the data order must be fixed. ``True`` for MC-Fisher.
"""

SELF_ADJOINT: bool = True
FIXED_DATA_ORDER: bool = True
supported_losses = (MSELoss, CrossEntropyLoss, BCEWithLogitsLoss)

def __init__(
self,
model_func: Callable[[Tensor], Tensor],
model_func: Callable[[Union[Tensor, MutableMapping]], Tensor],
loss_func: Union[MSELoss, CrossEntropyLoss],
params: List[Parameter],
data: Iterable[Tuple[Union[Tensor, MutableMapping], Tensor]],
@@ -118,7 +122,7 @@ def __init__(
num_data: Optional[int] = None,
batch_size_fn: Optional[Callable[[MutableMapping], int]] = None,
):
"""Linear operator for the MC approximation of the Fisher.
"""Linear operator for the Monte-Carlo approximation of the type-I Fisher.

Note:
f(X; θ) denotes a neural network, parameterized by θ, that maps a mini-batch
@@ -178,16 +182,16 @@ def __init__(
batch_size_fn=batch_size_fn,
)

def _matmat(self, M: ndarray) -> ndarray:
def _matmat(self, M: List[Tensor]) -> List[Tensor]:
"""Multiply the MC-Fisher onto a matrix.

Create and seed the random number generator.

Args:
M: Matrix for multiplication.
M: Matrix for multiplication in tensor list format.

Returns:
Matrix-multiplication result ``mat @ M``.
Matrix-multiplication result ``mat @ M`` in tensor list format.
"""
if self._generator is None or self._generator.device != self._device:
self._generator = Generator(device=self._device)
@@ -196,21 +200,21 @@ def _matmat(self, M: ndarray) -> ndarray:
return super()._matmat(M)

def _matmat_batch(
self, X: Union[Tensor, MutableMapping], y: Tensor, M_list: List[Tensor]
) -> Tuple[Tensor, ...]:
self, X: Union[Tensor, MutableMapping], y: Tensor, M: List[Tensor]
) -> List[Tensor]:
"""Apply the mini-batch MC-Fisher to a matrix.

Args:
X: Input to the DNN.
y: Ground truth.
M_list: Matrix to be multiplied with in list format.
M: Matrix to be multiplied with in tensor list format.
Tensors have same shape as trainable model parameters, and an
additional leading axis for the matrix columns.
additional trailing axis for the matrix columns.

Returns:
Result of MC-Fisher multiplication in list format. Has the same shape as
``M_list``, i.e. each tensor in the list has the shape of a parameter and a
leading dimension of matrix columns.
Result of MC-Fisher multiplication in tensor list format. Has the same shape
as ``M``, i.e. each tensor in the list has the shape of a parameter and a
trailing dimension of matrix columns.
"""
# compute ∂ℓₙ(yₙₘ)/∂fₙ where fₙ is the prediction for datum n and
# yₙₘ is the m-th sampled label for datum n
@@ -248,17 +252,17 @@ def _matmat_batch(
)

# Multiply the MC Fisher onto each vector in the input matrix
result_list = [zeros_like(M) for M in M_list]
num_vectors = M_list[0].shape[0]
FM = [zeros_like(m) for m in M]
(num_vectors,) = {m.shape[-1] for m in M}
for v in range(num_vectors):
for idx, ggnvp in enumerate(
for idx, Fm in enumerate(
ggn_vector_product_from_plist(
loss, output, self._params, [M[v] for M in M_list]
loss, output, self._params, [m[..., v] for m in M]
)
):
result_list[idx][v].add_(ggnvp.detach())
FM[idx][..., v].add_(Fm.detach())

return tuple(result_list)
return FM

def sample_grad_output(self, output: Tensor, num_samples: int, y: Tensor) -> Tensor:
"""Draw would-be gradients ``∇_f log p(·|f)``.
@@ -326,13 +330,3 @@ def sample_grad_output(self, output: Tensor, num_samples: int, y: Tensor) -> Tensor:

else:
raise NotImplementedError(f"Supported losses: {self.supported_losses}")

def _adjoint(self) -> FisherMCLinearOperator:
"""Return the linear operator representing the adjoint.

The Fisher MC-approximation is real symmetric, and hence self-adjoint.

Returns:
Self.
"""
return self
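
The deleted `_adjoint` method above is no longer needed because the operator now declares `SELF_ADJOINT = True` and the new base class derives the adjoint from that flag. To close out the `fisher.py` changes, here is a hedged usage sketch of the seeding behavior described in `_matmat`: the generator is re-seeded on every product, so one instance always represents the same Monte-Carlo sample of the Fisher. The `@`-multiplication onto PyTorch tensors, the `.shape` attribute, and the `seed`/`mc_samples` constructor arguments are assumed from the surrounding docstrings and the library's existing API rather than shown in the rendered hunks.

```python
from torch import allclose, manual_seed, nn, rand, randn

from curvlinops import FisherMCLinearOperator

manual_seed(0)
X, y = rand(8, 5), rand(8, 2)
model = nn.Linear(5, 2)
loss_func = nn.MSELoss()
params = [p for p in model.parameters() if p.requires_grad]

F = FisherMCLinearOperator(model, loss_func, params, [(X, y)], seed=42, mc_samples=4)
v = randn(F.shape[1])

# the generator is re-seeded inside `_matmat`, so repeated products are deterministic
assert allclose(F @ v, F @ v)

# a different seed yields a different MC sample of the Fisher
F_other = FisherMCLinearOperator(model, loss_func, params, [(X, y)], seed=123, mc_samples=4)
```
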