
Commit

Fix caching
PiperOrigin-RevId: 697005664
xingyousong authored and copybara-github committed Nov 15, 2024
1 parent 72ab8db commit 74ee180
Showing 1 changed file with 18 additions and 21 deletions.
39 changes: 18 additions & 21 deletions optformer/embed_then_regress/icl_transformer.py
@@ -14,9 +14,11 @@

"""Transformer model for ICL regression."""

import dataclasses
import functools
from typing import Callable
from flax import linen as nn
from flax import struct
import jax
import jax.numpy as jnp
import jaxtyping as jt
@@ -86,14 +88,12 @@ def __call__(
     return x


-class EmbeddingCache(dict[str, AnyTensor]):
+@struct.dataclass
+class EmbeddingCache:
   """Cache for storing previously computed embeddings."""

-  def get_or_set(self, key: str, fn: Callable[[], AnyTensor]):
-    value = self.get(key)
-    if value is None:
-      value = fn()
-    self.update({key: value})
-    return value
+  x_emb: jt.Float[jax.Array, 'L E'] | None = None
+  metadata_emb: jt.Float[jax.Array, 'E'] | None = None


 class ICLTransformer(nn.Module):
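Note: with this change the cache is an immutable flax `struct.dataclass` (a JAX pytree) rather than a mutable dict subclass: a missing field is filled by building a new cache with `dataclasses.replace`, and the updated cache is handed back to the caller instead of being updated in place. A minimal sketch of that get-or-compute pattern (illustrative only; the stand-in `embed` function, the plain `jax.Array` annotations, and the dummy shapes are assumptions, not part of the diff):

import dataclasses

import jax
import jax.numpy as jnp
from flax import struct


@struct.dataclass
class EmbeddingCache:
  """Cache for storing previously computed embeddings."""

  x_emb: jax.Array | None = None
  metadata_emb: jax.Array | None = None


def embed(x: jax.Array) -> jax.Array:
  return jnp.tile(x[:, None], (1, 4))  # stand-in for the model's embedder


cache = EmbeddingCache()                 # both fields start out empty
if cache.x_emb is None:                  # compute the embedding once ...
  cache = dataclasses.replace(cache, x_emb=embed(jnp.arange(3.0)))
print(cache.x_emb.shape)                 # (3, 4); later calls reuse this value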
@@ -206,36 +206,33 @@ def infer(
       x_targ: jt.Int[jax.Array, 'Q T'],  # Q is fixed to avoid re-jitting.
       metadata: jt.Int[jax.Array, 'T'],
       mask: jt.Bool[jax.Array, 'L'],
-      cache: EmbeddingCache | None = None,  # For caching embeddings.
+      cache: EmbeddingCache,  # For caching embeddings.
   ) -> tuple[
       jt.Float[jax.Array, 'L'],
       jt.Float[jax.Array, 'L'],
       dict[str, jax.Array],
+      EmbeddingCache,
   ]:
     """Friendly for inference, no batch dimension."""
-    if cache is None:
-      cache = EmbeddingCache()
-
-    # [L, E]
-    x_pad_emb = cache.get_or_set('x_pad_emb', lambda: self.embed(x_padded))
+    if cache.x_emb is None:
+      cache = dataclasses.replace(cache, x_emb=self.embed(x_padded))
+    x_pad_emb = cache.x_emb  # [L, E]
     x_targ_emb = self.embed(x_targ)  # [Q, E]
     L, E = x_pad_emb.shape  # pylint: disable=invalid-name

     # Combine target and historical (padded) embeddings.
     target_index = jnp.sum(mask, dtype=jnp.int32)  # [1]
-    padded_target_emb = jnp.zeros((L, E), dtype=x_targ_emb.dtype)
+    padded_target_emb = jnp.zeros_like(x_pad_emb)
     padded_target_emb = jax.lax.dynamic_update_slice_in_dim(
         padded_target_emb, x_targ_emb, start_index=target_index, axis=0
     )
     w_mask = jnp.expand_dims(mask, axis=-1)  # [L, 1]
-    x_emb = x_pad_emb * w_mask + padded_target_emb * (1 - w_mask)
+    x_emb = x_pad_emb * w_mask + padded_target_emb * (1 - w_mask)  # [L, E]

     if self.use_metadata:  # Attach metadata embeddings too.
-      metadata_emb = cache.get_or_set(
-          'metadata_emb', lambda: self.embed(metadata)
-      )
+      if cache.metadata_emb is None:
+        cache = dataclasses.replace(cache, metadata_emb=self.embed(metadata))
+      metadata_emb = cache.metadata_emb  # [E]
       metadata_emb = jnp.expand_dims(metadata_emb, axis=0)  # [1, E]
-      metadata_emb = jnp.repeat(metadata_emb, L, axis=0)  # [L, E]
+      metadata_emb = jnp.repeat(metadata_emb, x_emb.shape[0], axis=0)  # [L, E]
       x_emb = jnp.concatenate((x_emb, metadata_emb), axis=-1)  # [L, 2E]

     mean, std = self.__call__(
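Note: because the cache is no longer mutated in place, `infer` now returns the (possibly updated) `EmbeddingCache` as an extra output, and the caller is expected to thread it through repeated calls so the history and metadata embeddings are computed only once. A rough caller-side sketch, assuming the usual flax `apply` entry point; `variables`, `aux`, and the input arrays are hypothetical names, not part of the commit:

cache = EmbeddingCache()  # empty cache for the first query batch
for x_targ in candidate_batches:
  mean, std, aux, cache = model.apply(
      variables,
      x_padded=x_padded,
      x_targ=x_targ,
      metadata=metadata,
      mask=mask,
      cache=cache,
      method=ICLTransformer.infer,
  )
  # The first iteration fills cache.x_emb (and cache.metadata_emb when
  # use_metadata is on); later iterations reuse them and only embed x_targ.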
