Support erasing up to a specified worst-case covariance difference or intervention rank
luciaquirke committed Dec 16, 2024
1 parent 6449386 commit a3bcf93
Showing 2 changed files with 62 additions and 25 deletions.
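In brief: the fitter now takes `target_erasure` and `max_rank` and greedily erases principal directions of the class-conditional covariance differences until either bound is hit. A minimal usage sketch, assuming the constructor and `update` signatures shown in the diff below (the data and shapes here are illustrative, not from the repo):

import torch
import torch.nn.functional as F
from concept_erasure.alf_qleace import AlfQLeaceFitter

x = torch.randn(1000, 64)          # hypothetical features
z = torch.randint(0, 3, (1000,))   # hypothetical labels for 3 classes

# Stop once 90% of the worst-case covariance difference between classes
# is erased, but never exceed a rank-8 intervention.
fitter = AlfQLeaceFitter(64, 3, target_erasure=0.9, max_rank=8)
fitter.update(x, F.one_hot(z, num_classes=3))

x_erased = fitter.eraser(x)        # apply the fitted eraser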
79 changes: 54 additions & 25 deletions concept_erasure/alf_qleace.py
@@ -1,5 +1,4 @@
from dataclasses import dataclass
-from typing import Literal

import torch
from torch import Tensor
@@ -8,8 +7,6 @@
from .groupby import groupby
from .shrinkage import optimal_linear_shrinkage

-ErasureMethod = Literal["leace", "orth"]


@dataclass(frozen=True)
class AlfQLeaceEraser:
@@ -70,21 +67,20 @@ def __call__(self, x: Tensor) -> Tensor:

def to(self, device: torch.device | str) -> "AlfQLeaceEraser":
"""Move eraser to a new device."""
-        return AlfQLeaceEraser(
-            self.proj_left.to(device),
-            self.proj_right.to(device),
-            self.bias.to(device) if self.bias is not None else None,
-            self.alf_qleace_vecs.to(device),
-        )
+        self.proj_left = self.proj_left.to(device)
+        self.proj_right = self.proj_right.to(device)
+        self.bias = self.bias.to(device) if self.bias is not None else None
+        self.alf_qleace_vecs = self.alf_qleace_vecs.to(device)
+
+        return self


class AlfQLeaceFitter:
"""Fits LEACE plus a linear transform that surgically erases the direction of
maximum covariance from a representation.
-    This class implements Least-squares Concept Erasure (LEACE) from
-    https://arxiv.org/abs/2306.03819. You can also use a slightly simpler orthogonal
-    projection-based method by setting `method="orth"`.
+    This class implements Least-squares Concept Erasure (LEACE) from
+    https://arxiv.org/abs/2306.03819.
This class stores all the covariance statistics needed to compute the QLEACE eraser.
This allows the statistics to be updated incrementally with `update()`.
@@ -99,7 +95,7 @@ class AlfQLeaceFitter:
sigma_xz_: Tensor
"""Unnormalized cross-covariance matrix X^T Z."""

-    sigma_xx_: Tensor | None
+    sigma_xx_: Tensor
"""Unnormalized covariance matrix X^T X."""

sigma_xx_z_: Tensor
@@ -121,14 +117,15 @@ def __init__(
self,
x_dim: int,
z_dim: int,
-        method: ErasureMethod = "leace",
*,
affine: bool = True,
constrain_cov_trace: bool = True,
device: str | torch.device | None = None,
dtype: torch.dtype | None = None,
shrinkage: bool = True,
svd_tol: float = 0.01,
+        max_rank: int | None = None,
+        target_erasure: float = 0.5,
):
"""Initialize a `LeaceFitter`.
@@ -150,6 +147,9 @@ def __init__(
phase where we compute the pseudoinverse of the projected covariance
matrix. Higher values are more numerically stable and result in less
damage to the representation, but may leave trace correlations intact.
+            target_erasure: Fraction of the worst-case covariance difference between
+                classes to erase. Higher values result in higher rank interventions.
+            max_rank: Maximum rank of the intervention.
"""
super().__init__()
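Concretely, the erased quantity is the largest Frobenius-norm deviation of any class-conditional covariance from their mean, matching the norm computed in `eraser` further down. A standalone sketch (`sigmas` is a made-up stand-in for the fitted `sigma_xx_z_` stack):

import torch

def worst_case_cov_diff(sigmas: torch.Tensor) -> torch.Tensor:
    """sigmas: (n_classes, d, d) stack of class-conditional covariances."""
    diffs = sigmas - sigmas.mean(dim=0)   # each class's deviation from the mean
    return diffs.norm(dim=(1, 2)).max()   # worst Frobenius norm over classes

# Fitting stops once this norm falls below (1 - target_erasure) times its
# value before any directions were erased.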

@@ -158,9 +158,9 @@

self.affine = affine
self.constrain_cov_trace = constrain_cov_trace
-        self.method = method
self.shrinkage = shrinkage

+        self.target_erasure = target_erasure
+        self.max_rank = max_rank
assert svd_tol > 0.0, "`svd_tol` must be positive for numerical stability."
self.svd_tol = svd_tol

@@ -202,10 +202,8 @@ def update(self, x: Tensor, z: Tensor) -> "AlfQLeaceFitter":
self.global_mean_x += delta_x.sum(dim=0) / self.global_n
delta_x2 = x - self.global_mean_x

-        # Update the covariance matrix of X if needed (for LEACE)
-        if self.method == "leace":
-            assert self.sigma_xx_ is not None
-            self.sigma_xx_.addmm_(delta_x.mH, delta_x2)
+        # Update the covariance matrix of X for LEACE
+        self.sigma_xx_.addmm_(delta_x.mH, delta_x2)

z = z.reshape(n, -1).type_as(x)
assert z.shape[-1] == c, f"Unexpected number of classes {z.shape[-1]}"
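The `addmm_` call above is the usual one-pass covariance update: deviations are taken from the mean before and after it is updated, so the accumulated X^T X stays exact without a second pass over the data. The same pattern in isolation (a hypothetical helper, not repo code):

import torch

class RunningCov:
    """Accumulate an unnormalized covariance matrix X^T X in one pass."""

    def __init__(self, d: int):
        self.n = 0
        self.mean = torch.zeros(d)
        self.sigma = torch.zeros(d, d)

    def update(self, x: torch.Tensor) -> None:
        self.n += len(x)
        delta = x - self.mean                 # deviation from the old mean
        self.mean += delta.sum(dim=0) / self.n
        delta2 = x - self.mean                # deviation from the updated mean
        self.sigma.addmm_(delta.mH, delta2)   # sigma += delta^T @ delta2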
@@ -239,7 +237,6 @@ def update_single(self, x: Tensor, z: int) -> "AlfQLeaceFitter":
@cached_property
def eraser(self) -> AlfQLeaceEraser:
"""Erasure function lazily computed given the current statistics."""
-        n_dims = 10
eye = torch.eye(
self.x_dim, device=self.global_mean_x.device, dtype=self.global_mean_x.dtype
)
@@ -318,17 +315,22 @@ def eraser(self) -> AlfQLeaceEraser:
transformed_sigma_xx_z_ = torch.stack(
[P @ self.sigma_xx_z_[i] @ P for i in range(self.z_dim)]
)
+        base_cov_diff_norm = (
+            (transformed_sigma_xx_z_ - transformed_sigma_xx_z_.mean(dim=0))
+            .norm(dim=(1, 2))
+            .max()
+        )

+        # Erase to max_rank principal directions to minimize the worst-case
+        # covariance difference between classes
        principal_directions = []
-        for _ in range(n_dims):
+        max_rank = self.max_rank or transformed_sigma_xx_z_.flatten(1).shape[1]
+        for i in range(max_rank):
# Compute the class conditional covariance differences from the mean
mean_sigma_xx_z = transformed_sigma_xx_z_.mean(dim=0)
sigma_xx_z_diffs = transformed_sigma_xx_z_ - mean_sigma_xx_z

-            batch_svd = torch.vmap(
-                lambda x: torch.svd_lowrank(x, q=1, niter=10), randomness="different"
-            )
-            U, S, Vh = batch_svd(sigma_xx_z_diffs)
+            U, S, Vh = torch.svd_lowrank(sigma_xx_z_diffs, q=1, niter=10)

max_idx = torch.argmax(S.squeeze())
principal_directions.append(U.squeeze()[max_idx])
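The `vmap` wrapper was droppable because `torch.svd_lowrank` already accepts batched input of shape (*, m, n); note it returns `V`, not `V^H`. A quick shape check (illustrative):

import torch

A = torch.randn(3, 64, 64)                     # batch of 3 covariance differences
U, S, V = torch.svd_lowrank(A, q=1, niter=10)  # rank-1 approximation per matrix
print(U.shape, S.shape, V.shape)               # (3, 64, 1), (3, 1), (3, 64, 1)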
@@ -341,6 +343,19 @@
[proj_qleace @ sigma @ proj_qleace for sigma in transformed_sigma_xx_z_]
)

+            current_cov_diff_norm = (
+                (transformed_sigma_xx_z_ - transformed_sigma_xx_z_.mean(dim=0))
+                .norm(dim=(1, 2))
+                .max()
+            )
+
+            if current_cov_diff_norm < (1 - self.target_erasure) * base_cov_diff_norm:
+                print(
+                    f"Found rank {i + 1} intervention to reduce worst-case covariance "
+                    f"difference norm by {self.target_erasure:.0%}"
+                )
+                break

return AlfQLeaceEraser(
proj_left,
proj_right,
@@ -374,3 +389,17 @@ def sigma_xz(self) -> Tensor:
"""The cross-covariance matrix."""
assert self.global_n > 1, "Call update() with labels before accessing sigma_xz"
return self.sigma_xz_ / (self.global_n - 1)

+    def to(self, device: torch.device | str) -> "AlfQLeaceFitter":
+        """Move fitter to a new device."""
+        self.global_mean_x = self.global_mean_x.to(device)
+        self.global_mean_z = self.global_mean_z.to(device)
+        self.global_n = self.global_n.to(device)
+
+        self.sigma_xz_ = self.sigma_xz_.to(device)
+        self.sigma_xx_ = self.sigma_xx_.to(device)
+        self.mean_x = self.mean_x.to(device)
+        self.n = self.n.to(device)
+        self.sigma_xx_z_ = self.sigma_xx_z_.to(device)
+
+        return self
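Taken together, the new `eraser` logic is a greedy loop: find the single direction along which some class covariance deviates most from the mean, project it out of every class covariance, and stop once the residual worst-case norm drops below the target or `max_rank` is reached. A condensed standalone sketch of that procedure (names are illustrative, not the repo's API):

import torch

def greedy_erasure_directions(
    sigmas: torch.Tensor, target_erasure: float = 0.5, max_rank: int | None = None
) -> torch.Tensor:
    """Greedily pick unit directions that shrink the worst-case covariance
    difference between classes. sigmas: (n_classes, d, d)."""
    d = sigmas.shape[-1]
    base = (sigmas - sigmas.mean(dim=0)).norm(dim=(1, 2)).max()
    directions = []
    for _ in range(max_rank or d):
        diffs = sigmas - sigmas.mean(dim=0)
        U, S, _ = torch.svd_lowrank(diffs, q=1, niter=10)
        u = U.squeeze(-1)[S.squeeze(-1).argmax()]  # top direction of the worst class
        directions.append(u)
        P = torch.eye(d) - torch.outer(u, u)       # rank-1 projection away from u
        sigmas = torch.stack([P @ s @ P for s in sigmas])
        residual = (sigmas - sigmas.mean(dim=0)).norm(dim=(1, 2)).max()
        if residual < (1 - target_erasure) * base:
            break
    return torch.stack(directions)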
8 changes: 8 additions & 0 deletions concept_erasure/quadratic.py
@@ -208,3 +208,11 @@ def sigma_xx(self) -> Tensor:
# Just apply Bessel's correction
else:
return S_hat / (n - 1)

+    def to(self, device: torch.device | str) -> "QuadraticFitter":
+        """Move fitter to a new device."""
+        self.mean_x = self.mean_x.to(device)
+        self.n = self.n.to(device)
+        self.sigma_xx_ = self.sigma_xx_.to(device)
+
+        return self
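Both new `to` methods return `self`, so statistics can be accumulated on one device and moved before computing the eraser; a minimal usage sketch (illustrative):

fitter = AlfQLeaceFitter(64, 3, device="cuda")
# ... stream update() calls on the GPU ...
eraser = fitter.to("cpu").eraser   # move the statistics, then compute on CPU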
