Commit e9e98b3
Allow metadata embedding to be optional.
PiperOrigin-RevId: 696652400
xingyousong authored and copybara-github committed Nov 14, 2024
1 parent e76d7d4 commit e9e98b3
Showing 3 changed files with 41 additions and 33 deletions.
optformer/embed_then_regress/configs.py (1 addition, 0 deletions)
@@ -78,6 +78,7 @@ class ModelConfig:
   nhead: int = 16
   dropout: float = 0.1
   num_layers: int = 8
+  use_metadata: bool = True
   std_transform: str = 'exp'
 
   def create_model(
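For context, a rough standalone sketch (not part of this commit) of what the new flag controls: with use_metadata enabled, the per-trial x embeddings are concatenated with a repeated task-metadata embedding, doubling the feature width from E to 2E; with it disabled, the x embeddings pass through unchanged. The shapes mirror the comments in icl_transformer.py below; the toy sizes are arbitrary.

import jax.numpy as jnp

B, L, E = 2, 5, 4                          # toy batch, trial, and embedding sizes
x_emb = jnp.ones((B, L, E))                # [B, L, E] per-trial x embeddings
metadata_emb = jnp.ones((B, E))            # [B, E] one metadata embedding per task

use_metadata = True
if use_metadata:
  m = jnp.expand_dims(metadata_emb, axis=1)        # [B, 1, E]
  m = jnp.repeat(m, L, axis=1)                     # [B, L, E]
  x_emb = jnp.concatenate((x_emb, m), axis=-1)     # [B, L, 2E]

print(x_emb.shape)  # (2, 5, 8) with metadata, (2, 5, 4) without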
optformer/embed_then_regress/icl_transformer.py (35 additions, 30 deletions)
@@ -86,6 +86,16 @@ def __call__(
     return x
 
 
+class EmbeddingCache(dict[str, AnyTensor]):
+
+  def get_or_set(self, key: str, fn: Callable[[], AnyTensor]):
+    value = self.get(key)
+    if value is None:
+      value = fn()
+      self[key] = value
+    return value
+
+
 class ICLTransformer(nn.Module):
   """ICL Transformer model for regression."""
 
@@ -94,6 +104,7 @@ class ICLTransformer(nn.Module):
   nhead: int  # H
   dropout: float
   num_layers: int
+  use_metadata: bool
   std_transform_fn: Callable[[AnyTensor], AnyTensor]
   embedder_factory: Callable[[], nn.Module]  # __call__: [B, T] -> [B, D]
 
Expand All @@ -102,7 +113,7 @@ def setup(self):
     self.embedder = self.embedder_factory()
 
     # X, Y, and concatenated X,Y embedders.
-    self.xm_proj = nn.Sequential(
+    self.x_proj = nn.Sequential(
         [Dense(self.d_model), nn.relu, Dense(self.d_model)]
     )
     self.y_proj = nn.Sequential(
@@ -132,7 +143,6 @@ def __call__(
       self,
       x_emb: jt.Float[jax.Array, 'B L E'],
       y: jt.Float[jax.Array, 'B L'],
-      metadata_emb: jt.Float[jax.Array, 'B E'],
       mask: jt.Bool[jax.Array, 'B L'],
       deterministic: bool | None = None,
       rng: jax.Array | None = None,
@@ -178,16 +188,16 @@ def fit(
       rng: jax.Array | None = None,
   ) -> tuple[jt.Float[jax.Array, 'B L'], jt.Float[jax.Array, 'B L']]:
     """For training / eval loss metrics only."""
-    # pylint: disable=invalid-name
-    L = x.shape[1]
 
     x_emb = self.embed(x)  # [B, L, E]
-    metadata_emb = self.embed(metadata)  # [B, E]
-
-    metadata_emb = jnp.expand_dims(metadata_emb, axis=1)  # [B, 1, E]
-    metadata_emb = jnp.repeat(metadata_emb, L, axis=1)  # [B, L, E]
-    xm_emb = jnp.concatenate((x_emb, metadata_emb), axis=-1)  # [B, L, 2E]
-    return self.__call__(xm_emb, y, metadata_emb, mask, deterministic, rng)
+    if self.use_metadata:
+      L = x_emb.shape[1]  # pylint: disable=invalid-name
+      metadata_emb = self.embed(metadata)  # [B, E]
+      metadata_emb = jnp.expand_dims(metadata_emb, axis=1)  # [B, 1, E]
+      metadata_emb = jnp.repeat(metadata_emb, L, axis=1)  # [B, L, E]
+      x_emb = jnp.concatenate((x_emb, metadata_emb), axis=-1)  # [B, L, 2E]
+
+    return self.__call__(x_emb, y, mask, deterministic, rng)
 
   def infer(
       self,
@@ -196,49 +206,44 @@ def infer(
       x_targ: jt.Int[jax.Array, 'Q T'],  # Q is fixed to avoid re-jitting.
       metadata: jt.Int[jax.Array, 'T'],
       mask: jt.Bool[jax.Array, 'L'],
-      cache: dict[str, jax.Array] | None = None,  # For caching embeddings.
+      cache: EmbeddingCache | None = None,  # For caching embeddings.
   ) -> tuple[
       jt.Float[jax.Array, 'L'],
       jt.Float[jax.Array, 'L'],
       dict[str, jax.Array],
   ]:
     """Friendly for inference, no batch dimension."""
     if cache is None:
-      x_padded_emb = self.embed(x_padded)  # [L, E]
-      metadata_emb = self.embed(metadata)  # [E]
-      cache = {'x_padded_emb': x_padded_emb, 'metadata_emb': metadata_emb}
-    else:
-      x_padded_emb = cache['x_padded_emb']  # [L, E]
-      metadata_emb = cache['metadata_emb']  # [E]
+      cache = EmbeddingCache()
+
+    # [L, E]
+    x_pad_emb = cache.get_or_set('x_pad_emb', lambda: self.embed(x_padded))
     x_targ_emb = self.embed(x_targ)  # [Q, E]
 
-    L, E = x_padded_emb.shape  # pylint: disable=invalid-name
-
-    target_index = jnp.sum(mask, dtype=jnp.int32)  # [1]
+    L, E = x_pad_emb.shape  # pylint: disable=invalid-name
 
     # Combine target and historical (padded) embeddings.
+    target_index = jnp.sum(mask, dtype=jnp.int32)  # [1]
     padded_target_emb = jnp.zeros((L, E), dtype=x_targ_emb.dtype)
     padded_target_emb = jax.lax.dynamic_update_slice_in_dim(
         padded_target_emb, x_targ_emb, start_index=target_index, axis=0
    )
     w_mask = jnp.expand_dims(mask, axis=-1)  # [L, 1]
-    x_emb = x_padded_emb * w_mask + padded_target_emb * (1 - w_mask)
+    x_emb = x_pad_emb * w_mask + padded_target_emb * (1 - w_mask)
 
-    # Attach metadata embeddings too.
-    metadata_emb = jnp.expand_dims(metadata_emb, axis=0)  # [1, E]
-    metadata_emb = jnp.repeat(metadata_emb, L, axis=0)  # [L, E]
-    xm_emb = jnp.concatenate((x_emb, metadata_emb), axis=-1)  # [L, 2E]
+    if self.use_metadata:  # Attach metadata embeddings too.
+      metadata_emb = cache.get_or_set(
+          'metadata_emb', lambda: self.embed(metadata)
+      )
+      metadata_emb = jnp.expand_dims(metadata_emb, axis=0)  # [1, E]
+      metadata_emb = jnp.repeat(metadata_emb, L, axis=0)  # [L, E]
+      x_emb = jnp.concatenate((x_emb, metadata_emb), axis=-1)  # [L, 2E]
 
     # TODO: Are these batch=1 expands necessary?
     mean, std = self.__call__(
-        x_emb=jnp.expand_dims(xm_emb, axis=0),
+        x_emb=jnp.expand_dims(x_emb, axis=0),
         y=jnp.expand_dims(y_padded, axis=0),
-        metadata_emb=jnp.expand_dims(metadata_emb, axis=0),
         mask=jnp.expand_dims(mask, axis=0),
         deterministic=True,
     )
 
     return jnp.squeeze(mean, axis=0), jnp.squeeze(std, axis=0), cache
 
   @nn.remat  # Reduce memory consumption during backward pass.
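A quick illustration of the new EmbeddingCache helper defined above: get_or_set computes and stores a value only when the key is missing, and later lookups return the stored value without calling the function again. A minimal sketch, with a plain Python list standing in for the jax arrays that infer actually caches:

from optformer.embed_then_regress.icl_transformer import EmbeddingCache

calls = []

def expensive_embed():
  calls.append(1)            # count how often the "embedding" is recomputed
  return [0.1, 0.2, 0.3]     # stand-in for self.embed(x_padded)

cache = EmbeddingCache()
first = cache.get_or_set('x_pad_emb', expensive_embed)   # computes and stores
second = cache.get_or_set('x_pad_emb', expensive_embed)  # returns the cached value

assert first is second and len(calls) == 1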
optformer/embed_then_regress/regressor.py (5 additions, 3 deletions)
@@ -31,6 +31,8 @@
 
 tfd = tfp.distributions
 
+EmbeddingCache = icl_transformer.EmbeddingCache
+
 
 # TODO: Maybe refactor omnipred2 regressor base class.
 @attrs.define
@@ -54,7 +56,7 @@ class StatefulICLRegressor:
   _all_yt: jt.Float[np.ndarray, 'L'] = attrs.field(init=False)
   _mt: jt.Int[np.ndarray, 'T'] = attrs.field(init=False)
   _num_prev: int = attrs.field(init=False)
-  _cache: dict[str, jax.Array] | None = attrs.field(init=False)
+  _cache: EmbeddingCache = attrs.field(init=False)
 
   # Jitted function.
   _jit_apply: Callable[..., Any] = attrs.field(init=False)
@@ -95,7 +97,7 @@ def absorb(self, xs: Sequence[str], ys: Sequence[float]):
     self._all_xt[self._num_prev : self._num_prev + num_pts] = self._tokenize(xs)
     self._all_yt[self._num_prev : self._num_prev + num_pts] = np.array(ys)
     self._num_prev += num_pts
-    self._cache = None  # Need to recompute historical embeddings.
+    self._cache = EmbeddingCache()  # Need to recompute historical embeddings.
 
     self.warper.train(self._all_yt[: self._num_prev])
 
@@ -109,7 +111,7 @@ def reset(self) -> None:
     self._all_yt = np.zeros(self.max_trial_length, dtype=np.float32)
     self._mt = np.zeros(self.max_token_length, dtype=np.int32)
     self._num_prev = 0
-    self._cache = None
+    self._cache = EmbeddingCache()
 
   def _tokenize(self, ss: Sequence[str]) -> jt.Int[np.ndarray, 'S T']:
     """Converts ss (strings) to tokens."""
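On the regressor side, _cache is now always an EmbeddingCache rather than an Optional dict, and both absorb() and reset() install a fresh, empty cache so that historical embeddings are recomputed once on the next inference and then reused. A hedged usage sketch, assuming an already-constructed StatefulICLRegressor named reg (its constructor arguments and prediction path are not shown in this diff) and placeholder x strings:

reg.reset()                                     # _cache is a fresh, empty EmbeddingCache
reg.absorb(['x: 0.1', 'x: 0.9'], [0.3, 0.7])    # new points clear the cache; warper retrains
assert isinstance(reg._cache, EmbeddingCache) and len(reg._cache) == 0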
