Fix caching #145

Merged 1 commit on Nov 16, 2024

39 changes: 18 additions & 21 deletions optformer/embed_then_regress/icl_transformer.py
@@ -14,9 +14,11 @@
 
 """Transformer model for ICL regression."""
 
+import dataclasses
 import functools
 from typing import Callable
 from flax import linen as nn
+from flax import struct
 import jax
 import jax.numpy as jnp
 import jaxtyping as jt
@@ -86,14 +88,12 @@ def __call__(
     return x
 
 
-class EmbeddingCache(dict[str, AnyTensor]):
+@struct.dataclass
+class EmbeddingCache:
   """Cache for storing previously computed embeddings."""
 
-  def get_or_set(self, key: str, fn: Callable[[], AnyTensor]):
-    value = self.get(key)
-    if value is None:
-      value = fn()
-      self.update({key: value})
-    return value
+  x_emb: jt.Float[jax.Array, 'L E'] | None = None
+  metadata_emb: jt.Float[jax.Array, 'E'] | None = None
 
 
 class ICLTransformer(nn.Module):
@@ -206,36 +206,33 @@ def infer(
       x_targ: jt.Int[jax.Array, 'Q T'],  # Q is fixed to avoid re-jitting.
       metadata: jt.Int[jax.Array, 'T'],
       mask: jt.Bool[jax.Array, 'L'],
-      cache: EmbeddingCache | None = None,  # For caching embeddings.
+      cache: EmbeddingCache,  # For caching embeddings.
   ) -> tuple[
       jt.Float[jax.Array, 'L'],
       jt.Float[jax.Array, 'L'],
       dict[str, jax.Array],
       EmbeddingCache,
   ]:
     """Friendly for inference, no batch dimension."""
-    if cache is None:
-      cache = EmbeddingCache()
 
-    # [L, E]
-    x_pad_emb = cache.get_or_set('x_pad_emb', lambda: self.embed(x_padded))
+    if cache.x_emb is None:
+      cache = dataclasses.replace(cache, x_emb=self.embed(x_padded))
+    x_pad_emb = cache.x_emb  # [L, E]
     x_targ_emb = self.embed(x_targ)  # [Q, E]
     L, E = x_pad_emb.shape  # pylint: disable=invalid-name
 
     # Combine target and historical (padded) embeddings.
     target_index = jnp.sum(mask, dtype=jnp.int32)  # [1]
-    padded_target_emb = jnp.zeros((L, E), dtype=x_targ_emb.dtype)
+    padded_target_emb = jnp.zeros_like(x_pad_emb)
     padded_target_emb = jax.lax.dynamic_update_slice_in_dim(
         padded_target_emb, x_targ_emb, start_index=target_index, axis=0
     )
     w_mask = jnp.expand_dims(mask, axis=-1)  # [L, 1]
-    x_emb = x_pad_emb * w_mask + padded_target_emb * (1 - w_mask)
+    x_emb = x_pad_emb * w_mask + padded_target_emb * (1 - w_mask)  # [L, E]
 
     if self.use_metadata:  # Attach metadata embeddings too.
-      metadata_emb = cache.get_or_set(
-          'metadata_emb', lambda: self.embed(metadata)
-      )
+      if cache.metadata_emb is None:
+        cache = dataclasses.replace(cache, metadata_emb=self.embed(metadata))
+      metadata_emb = cache.metadata_emb  # [E]
       metadata_emb = jnp.expand_dims(metadata_emb, axis=0)  # [1, E]
-      metadata_emb = jnp.repeat(metadata_emb, L, axis=0)  # [L, E]
+      metadata_emb = jnp.repeat(metadata_emb, x_emb.shape[0], axis=0)  # [L, E]
       x_emb = jnp.concatenate((x_emb, metadata_emb), axis=-1)  # [L, 2E]
 
     mean, std = self.__call__(
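The core of this change is that EmbeddingCache is now an immutable flax.struct.dataclass (a JAX pytree) rather than a mutable dict subclass: infer() no longer fills the cache in place via get_or_set, but builds an updated copy with dataclasses.replace and returns it alongside the predictions, so the caller can thread the filled cache into the next call. Below is a minimal, self-contained sketch of that pattern; TinyCache, embed_fn, and run_step are hypothetical names used only for illustration and are not part of the optformer API.

# Minimal sketch of the functional caching pattern used in this PR.
# TinyCache, embed_fn, and run_step are hypothetical illustrations,
# not part of the optformer codebase.
import dataclasses

import jax
import jax.numpy as jnp
from flax import struct


@struct.dataclass
class TinyCache:
  """Pytree-compatible cache; a None field means 'not computed yet'."""

  x_emb: jax.Array | None = None


def embed_fn(x: jax.Array) -> jax.Array:
  # Stand-in for an expensive embedding computation.
  return jnp.tanh(x * 2.0)


def run_step(x: jax.Array, cache: TinyCache) -> tuple[jax.Array, TinyCache]:
  # Embed only if the cache is still empty, and return the (possibly
  # updated) cache instead of mutating it in place.
  if cache.x_emb is None:
    cache = dataclasses.replace(cache, x_emb=embed_fn(x))
  return cache.x_emb.sum(), cache


x = jnp.ones((4, 8))
cache = TinyCache()
out1, cache = run_step(x, cache)  # First call embeds x and fills the cache.
out2, cache = run_step(x, cache)  # Second call reuses cache.x_emb.

Because struct.dataclass registers the class as a pytree, the filled cache can be passed into and returned from jit-compiled functions; the None-to-array transition changes the pytree structure and so costs one retrace, after which the cached embedding is reused. This mirrors how the updated infer() takes a required cache argument and returns the EmbeddingCache as part of its output tuple.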