Skip to content

Commit

Permalink
Add embedding caching to reduce inference costs.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 696234274
  • Loading branch information
xingyousong authored and copybara-github committed Nov 13, 2024
1 parent 5d528d9 commit d5777ea
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 34 deletions.
61 changes: 44 additions & 17 deletions optformer/embed_then_regress/icl_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def setup(self):
self.embedder = self.embedder_factory()

# X, Y, and concatenated X,Y embedders.
self.x_proj = nn.Sequential(
self.xm_proj = nn.Sequential(
[Dense(self.d_model), nn.relu, Dense(self.d_model)]
)
self.y_proj = nn.Sequential(
Expand Down Expand Up @@ -136,20 +136,43 @@ def __call__(
mask: jt.Bool[jax.Array, 'B L'],
deterministic: bool | None = None,
rng: jax.Array | None = None,
) -> tuple[jt.Float[jax.Array, 'B L'], jt.Float[jax.Array, 'B L']]:
embedding_cache: dict[str, jax.Array] | None = None,
) -> tuple[
jt.Float[jax.Array, 'B L'],
jt.Float[jax.Array, 'B L'],
dict[str, jax.Array],
]:
# pylint: disable=invalid-name

B, L, T = x.shape
x = jnp.reshape(x, (B * L, T))
x = self.embed(x) # [B*L, E]
x = jnp.reshape(x, (B, L, -1)) # [B, L, E]

metadata = self.embed(metadata) # [B, E]
metadata = jnp.expand_dims(metadata, axis=1) # [B, 1, E]
metadata = jnp.repeat(metadata, L, axis=1) # [B, L, E]
x = jnp.concatenate((x, metadata), axis=-1) # [B, L, 2E]

xt_emb = self.x_proj(x) # [B, L, D]
L = x.shape[1]

if embedding_cache is None:
x_emb = self.embed(x) # [B, L, E]
metadata_emb = self.embed(metadata) # [B, E]
embedding_cache = {'x': x_emb, 'metadata': metadata_emb}
else:
# Find starting index of target. Raise value error if masks are not all
# same, since dynamic_update_slice wouldn't work.
target_indices = jnp.sum(mask, axis=-1, dtype=jnp.int32)
if not jnp.all(target_indices == target_indices[0]):
raise ValueError('At inference, all masks must be the same in batch.')
target_index = target_indices[0]

# Embed only the new tokens.
target_x = x[:, target_index:, :] # [B=1, target_index, T]
target_x_emb = self.embed(target_x) # [B=1, target_index, E]

x_emb = jax.lax.dynamic_update_slice(
embedding_cache['x'],
target_x_emb,
start_indices=(0, target_index, 0),
)
metadata_emb = embedding_cache['metadata']

metadata_emb = jnp.expand_dims(metadata_emb, axis=1) # [B, 1, E]
metadata_emb = jnp.repeat(metadata_emb, L, axis=1) # [B, L, E]
xm_emb = jnp.concatenate((x_emb, metadata_emb), axis=-1) # [B, L, 2E]

xt_emb = self.xm_proj(xm_emb) # [B, L, D]

# Force 0.0 values for target points using the mask.
y = y * mask # [B, L], element-wise multiplication
Expand All @@ -173,8 +196,12 @@ def __call__(

mean = jnp.squeeze(mean, axis=-1)
std = jnp.squeeze(std, axis=-1)
return mean, std
return mean, std, embedding_cache

@nn.remat # Reduce memory consumption during backward pass.
def embed(self, tokens: jt.Int[jax.Array, 'X T']):
return self.embedder(tokens)
def embed(
self, tokens: jt.Int[jax.Array, '*X T']
) -> jt.Float[jax.Array, '*X E']:
reshaped_tokens = jnp.reshape(tokens, (-1, tokens.shape[-1]))
embeddings = self.embedder(reshaped_tokens) # [-1, E]
return jnp.reshape(embeddings, tokens.shape[:-1] + (embeddings.shape[-1],))
31 changes: 20 additions & 11 deletions optformer/embed_then_regress/regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,21 @@ class StatefulICLRegressor:
params: flax_typing.FrozenVariableDict = attrs.field()
vocab: seqio.Vocabulary = attrs.field()

max_trial_length: int = attrs.field(default=300, kw_only=True) # L
max_memory_length: int = attrs.field(default=10000, kw_only=True) # M >> L
max_token_length: int = attrs.field(default=256, kw_only=True) # T

warper: normalization.StatefulWarper = attrs.field(
factory=normalization.default_warper, kw_only=True
)

# Internal state containing tokens.
_all_xt: jt.Int[np.ndarray, 'L T'] = attrs.field(init=False)
_all_yt: jt.Float[np.ndarray, 'L'] = attrs.field(init=False)
# Internal state containing history.
_all_xt: jt.Int[np.ndarray, 'M T'] = attrs.field(init=False)
_all_yt: jt.Float[np.ndarray, 'M'] = attrs.field(init=False)
_mt: jt.Int[np.ndarray, 'T'] = attrs.field(init=False)
_num_prev: int = attrs.field(init=False)
_embedding_cache: dict[str, jax.Array] | None = attrs.field(init=False)

# Jitted function.
_jit_apply: Callable[..., Any] = attrs.field(init=False)

def __attrs_post_init__(self):
Expand All @@ -64,25 +67,29 @@ def predict(self, xs: Sequence[str]) -> tfd.Distribution:
"""Returns prediction in normalized/warped space."""
num_query = len(xs)

temp_xt = np.copy(self._all_xt)
temp_xt[self._num_prev : self._num_prev + num_query] = self._tokenize(xs)
# Use instead of max_trial_length to reduce embedding costs.
max_length = self._num_prev + num_query # L

temp_xt = np.copy(self._all_xt)[:max_length]
temp_xt[self._num_prev :] = self._tokenize(xs)

temp_yt = np.copy(self._all_yt)
temp_yt = np.copy(self._all_yt)[:max_length]
temp_yt = self.warper.warp(temp_yt)

temp_mt = np.copy(self._mt)

mask = np.ones(self.max_trial_length, dtype=bool)
mask = np.ones(max_length, dtype=bool)
mask[self._num_prev :] = False

# Need to add batch dimension to all inputs.
mean, std = self._jit_apply(
mean, std, self._embedding_cache = self._jit_apply(
self.params,
x=np.expand_dims(temp_xt, axis=0), # [B=1, L, T],
y=np.expand_dims(temp_yt, axis=0), # [B=1, L],
metadata=np.expand_dims(temp_mt, axis=0), # [B=1, T],
mask=np.expand_dims(mask, axis=0), # [B=1, L],
deterministic=True,
embedding_cache=self._embedding_cache,
)

mean, std = np.squeeze(mean, axis=0), np.squeeze(std, axis=0)
Expand All @@ -97,6 +104,7 @@ def absorb(self, xs: Sequence[str], ys: Sequence[float]):
self._all_xt[self._num_prev : self._num_prev + num_pts] = self._tokenize(xs)
self._all_yt[self._num_prev : self._num_prev + num_pts] = np.array(ys)
self._num_prev += num_pts
self._embedding_cache = None # Need to recompute historical embeddings.

self.warper.train(self._all_yt[: self._num_prev])

Expand All @@ -105,11 +113,12 @@ def set_metadata(self, metadata: str) -> None:

def reset(self) -> None:
self._all_xt = np.zeros(
(self.max_trial_length, self.max_token_length), dtype=np.int32
(self.max_memory_length, self.max_token_length), dtype=np.int32
)
self._all_yt = np.zeros(self.max_trial_length, dtype=np.float32)
self._all_yt = np.zeros(self.max_memory_length, dtype=np.float32)
self._mt = np.zeros(self.max_token_length, dtype=np.int32)
self._num_prev = 0
self._embedding_cache = None

def _tokenize(self, ss: Sequence[str]) -> jt.Int[np.ndarray, 'S T']:
"""Converts ss (strings) to tokens."""
Expand Down
4 changes: 3 additions & 1 deletion optformer/embed_then_regress/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ def loss_fn(
) -> tuple[jax.Array, Mapping[str, Scalar]]:
"""Loss function with metrics."""
# pylint: disable=invalid-name
mean, std = model.apply(params, deterministic=not training, rng=rng, **batch)
mean, std, _ = model.apply(
params, deterministic=not training, rng=rng, **batch
)
nlogprob = -jax.scipy.stats.norm.logpdf(batch['y'], mean, std) # [B, L]

# Only compute loss over target ys. Mask is BxL where True denotes context
Expand Down
6 changes: 1 addition & 5 deletions optformer/embed_then_regress/vizier/designer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,6 @@
use_fori=False,
)

default_scoring_function_factory = acq_lib.bayesian_scoring_function_factory(
lambda _: acq_lib.UCB()
)


@attrs.define
class TransformerICLOptDesigner(vza.Designer):
Expand All @@ -57,7 +53,7 @@ class TransformerICLOptDesigner(vza.Designer):
default=default_optimizer_factory, kw_only=True
)
_acq_fn: acq_lib.AcquisitionFunction = attrs.field(
default=acq_lib.UCB(), kw_only=True
factory=acq_lib.UCB, kw_only=True
)

_rng: jax.Array = attrs.field(
Expand Down

0 comments on commit d5777ea

Please sign in to comment.