diff --git a/optformer/embed_then_regress/icl_transformer.py b/optformer/embed_then_regress/icl_transformer.py
index 43b082c..6a04eba 100644
--- a/optformer/embed_then_regress/icl_transformer.py
+++ b/optformer/embed_then_regress/icl_transformer.py
@@ -14,9 +14,11 @@
 """Transformer model for ICL regression."""
 
+import dataclasses
 import functools
 from typing import Callable
 
 from flax import linen as nn
+from flax import struct
 import jax
 import jax.numpy as jnp
 import jaxtyping as jt
@@ -86,14 +88,12 @@ def __call__(
     return x
 
 
-class EmbeddingCache(dict[str, AnyTensor]):
+@struct.dataclass
+class EmbeddingCache:
+  """Cache for storing previously computed embeddings."""
 
-  def get_or_set(self, key: str, fn: Callable[[], AnyTensor]):
-    value = self.get(key)
-    if value is None:
-      value = fn()
-    self.update({key: value})
-    return value
+  x_emb: jt.Float[jax.Array, 'L E'] | None = None
+  metadata_emb: jt.Float[jax.Array, 'E'] | None = None
 
 
 class ICLTransformer(nn.Module):
@@ -206,36 +206,33 @@ def infer(
       x_targ: jt.Int[jax.Array, 'Q T'],  # Q is fixed to avoid re-jitting.
       metadata: jt.Int[jax.Array, 'T'],
       mask: jt.Bool[jax.Array, 'L'],
-      cache: EmbeddingCache | None = None,  # For caching embeddings.
+      cache: EmbeddingCache,  # For caching embeddings.
   ) -> tuple[
       jt.Float[jax.Array, 'L'],
       jt.Float[jax.Array, 'L'],
-      dict[str, jax.Array],
+      EmbeddingCache,
   ]:
     """Friendly for inference, no batch dimension."""
-    if cache is None:
-      cache = EmbeddingCache()
-
-    # [L, E]
-    x_pad_emb = cache.get_or_set('x_pad_emb', lambda: self.embed(x_padded))
+    if cache.x_emb is None:
+      cache = dataclasses.replace(cache, x_emb=self.embed(x_padded))
+    x_pad_emb = cache.x_emb  # [L, E]
     x_targ_emb = self.embed(x_targ)  # [Q, E]
-    L, E = x_pad_emb.shape  # pylint: disable=invalid-name
 
     # Combine target and historical (padded) embeddings.
     target_index = jnp.sum(mask, dtype=jnp.int32)  # [1]
-    padded_target_emb = jnp.zeros((L, E), dtype=x_targ_emb.dtype)
+    padded_target_emb = jnp.zeros_like(x_pad_emb)
     padded_target_emb = jax.lax.dynamic_update_slice_in_dim(
         padded_target_emb, x_targ_emb, start_index=target_index, axis=0
     )
     w_mask = jnp.expand_dims(mask, axis=-1)  # [L, 1]
-    x_emb = x_pad_emb * w_mask + padded_target_emb * (1 - w_mask)
+    x_emb = x_pad_emb * w_mask + padded_target_emb * (1 - w_mask)  # [L, E]
 
     if self.use_metadata:  # Attach metadata embeddings too.
-      metadata_emb = cache.get_or_set(
-          'metadata_emb', lambda: self.embed(metadata)
-      )
+      if cache.metadata_emb is None:
+        cache = dataclasses.replace(cache, metadata_emb=self.embed(metadata))
+      metadata_emb = cache.metadata_emb  # [E]
       metadata_emb = jnp.expand_dims(metadata_emb, axis=0)  # [1, E]
-      metadata_emb = jnp.repeat(metadata_emb, L, axis=0)  # [L, E]
+      metadata_emb = jnp.repeat(metadata_emb, x_emb.shape[0], axis=0)  # [L, E]
       x_emb = jnp.concatenate((x_emb, metadata_emb), axis=-1)  # [L, 2E]
 
     mean, std = self.__call__(
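
Usage sketch (not part of the diff): with EmbeddingCache now a flax.struct dataclass, the cache is a pytree that infer updates functionally via dataclasses.replace and returns alongside the predictions, rather than mutating a dict in place. The snippet below shows the intended threading pattern under stated assumptions: model, variables, x_padded, y_padded, metadata, mask, and target_batches are hypothetical placeholders, and the leading infer arguments are inferred from context since they are not visible in this hunk.

    from optformer.embed_then_regress.icl_transformer import EmbeddingCache, ICLTransformer

    cache = EmbeddingCache()  # both fields start as None

    for x_targ in target_batches:  # hypothetical loop over query batches
      mean, std, cache = model.apply(
          variables,
          x_padded, y_padded, x_targ, metadata, mask, cache,
          method=ICLTransformer.infer,
      )
      # The first call embeds x_padded (and metadata, if enabled) and stores the
      # results in the returned cache; subsequent calls reuse them, so only the
      # target points x_targ are re-embedded.

Returning the updated cache instead of relying on in-place dict mutation keeps infer side-effect free, which matters under jax.jit: mutations to Python containers happen only at trace time and do not persist across cached executions, whereas an explicit pytree return value does.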