Merge branch 'keras-team:master' into kaggle-upload
SamanehSaadat authored Mar 19, 2024
2 parents 7b18970 + 4511580 commit e578789
Showing 9 changed files with 608 additions and 63 deletions.
3 changes: 0 additions & 3 deletions .kokoro/github/ubuntu/gpu/build.sh
@@ -14,11 +14,8 @@ if [[ -z "${KAGGLE_USERNAME}" ]]; then
fi

set -x

cd "${KOKORO_ROOT}/"

sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1

PYTHON_BINARY="/usr/bin/python3.9"

"${PYTHON_BINARY}" -m venv venv
47 changes: 29 additions & 18 deletions keras_nlp/layers/modeling/rotary_embedding.py
@@ -85,30 +85,42 @@ def __init__(
self.built = True

def call(self, inputs, start_index=0):
inputs = ops.moveaxis(
inputs, (self.feature_axis, self.sequence_axis), (-1, 1)
)
cos_emb, sin_emb = self._compute_cos_sin_embedding(inputs, start_index)
return self._apply_rotary_pos_emb(inputs, cos_emb, sin_emb)
output = self._apply_rotary_pos_emb(inputs, cos_emb, sin_emb)
return ops.moveaxis(
output, (-1, 1), (self.feature_axis, self.sequence_axis)
)

def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
x1, x2 = ops.split(tensor, 2, axis=self.feature_axis)
half_rot_tensor = ops.concatenate((-x2, x1), axis=self.feature_axis)
x1, x2 = ops.split(tensor, 2, axis=-1)
# Avoid `ops.concatenate` for now, to avoid an obscure bug with XLA
# compilation on jax. We should be able to remove this once the
# following PR is in all jax releases we care about:
# https://github.com/openxla/xla/pull/7875
half_rot_tensor = ops.stack((-x2, x1), axis=-2)
half_rot_tensor = ops.reshape(half_rot_tensor, ops.shape(tensor))
return (tensor * cos_emb) + (half_rot_tensor * sin_emb)
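
As an aside (not part of the diff): the stack-then-reshape used above is numerically identical to the `ops.concatenate` it replaces, so the rotate-half step is unchanged. A minimal sketch, assuming Keras 3's `keras.ops` API and a made-up 1x8 tensor:

import numpy as np
from keras import ops

x = ops.convert_to_tensor(np.arange(8, dtype="float32").reshape(1, 8))
x1, x2 = ops.split(x, 2, axis=-1)
# Rotate-half with and without `ops.concatenate`: both yield [-x2, x1]
# joined along the last axis.
via_concat = ops.concatenate((-x2, x1), axis=-1)
via_stack = ops.reshape(ops.stack((-x2, x1), axis=-2), ops.shape(x))
np.testing.assert_allclose(
    ops.convert_to_numpy(via_concat), ops.convert_to_numpy(via_stack)
)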

def _compute_cos_sin_embedding(self, inputs, start_index=0):
def get_axis(axis):
return axis if axis > 0 else len(inputs.shape) + axis
start_index = ops.cast(start_index, dtype="float32")

feature_axis = get_axis(self.feature_axis)
sequence_axis = get_axis(self.sequence_axis)
feature_axis = len(inputs.shape) - 1
sequence_axis = 1

rotary_dim = ops.shape(inputs)[feature_axis]
inverse_freq = self._get_inverse_freq(rotary_dim)

seq_len = ops.shape(inputs)[self.sequence_axis]
tensor = ops.cast(ops.arange(seq_len), self.compute_dtype) + start_index
seq_len = ops.shape(inputs)[sequence_axis]
tensor = ops.arange(seq_len, dtype="float32") + start_index

tensor = ops.cast(tensor, dtype=inverse_freq.dtype)
freq = ops.einsum("i,j->ij", tensor, inverse_freq)
embedding = ops.concatenate((freq, freq), axis=-1)
embedding = ops.stack((freq, freq), axis=-2)
embedding = ops.reshape(
embedding, (*ops.shape(freq)[:-1], ops.shape(freq)[-1] * 2)
)

# Reshape the embedding to be broadcastable with input shape.
if feature_axis < sequence_axis:
@@ -117,17 +129,16 @@ def get_axis(axis):
if axis != sequence_axis and axis != feature_axis:
embedding = ops.expand_dims(embedding, axis)

return ops.cos(embedding), ops.sin(embedding)
cos_emb = ops.cast(ops.cos(embedding), self.compute_dtype)
sin_emb = ops.cast(ops.sin(embedding), self.compute_dtype)
return cos_emb, sin_emb

def _get_inverse_freq(self, rotary_dim):
freq_range = ops.arange(0, rotary_dim, 2)
freq_range = ops.cast(freq_range, self.compute_dtype)
freq_range = freq_range / ops.cast(
self.scaling_factor, self.compute_dtype
)
freq_range = ops.arange(0, rotary_dim, 2, dtype="float32")
freq_range = freq_range / ops.cast(self.scaling_factor, "float32")
inverse_freq = 1.0 / (
self.max_wavelength
** (freq_range / ops.cast(rotary_dim, self.compute_dtype))
** (freq_range / ops.cast(rotary_dim, "float32"))
)
return inverse_freq
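
A side note on the dtype handling above (illustrative, not from the diff): the frequency table is now built entirely in float32 and only the final cos/sin tables are cast to the layer's compute dtype, which keeps the rotation angles accurate under mixed precision. A NumPy-only sketch with made-up sizes (head_dim=8, seq_len=4) and a made-up float16 compute dtype:

import numpy as np

head_dim, seq_len, max_wavelength = 8, 4, 10000.0
freq_range = np.arange(0, head_dim, 2, dtype="float32")
inverse_freq = 1.0 / (max_wavelength ** (freq_range / np.float32(head_dim)))
positions = np.arange(seq_len, dtype="float32")
freq = np.einsum("i,j->ij", positions, inverse_freq)  # (seq_len, head_dim // 2)
# Duplicate the half-size table to full width, as the layer does.
embedding = np.reshape(np.stack((freq, freq), axis=-2), (seq_len, head_dim))
cos_emb = np.cos(embedding).astype("float16")  # cast only the final tables
sin_emb = np.sin(embedding).astype("float16")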

43 changes: 17 additions & 26 deletions keras_nlp/models/gemma/gemma_attention.py
@@ -15,6 +15,7 @@

from keras_nlp.backend import keras
from keras_nlp.backend import ops
from keras_nlp.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_nlp.utils.keras_utils import clone_initializer


@@ -87,28 +88,23 @@ def build(self, inputs_shape):
(None, None, self.num_query_heads, self.head_dim)
)
self.softmax = keras.layers.Softmax(dtype="float32")

self.rope_layer = RotaryEmbedding(
max_wavelength=10_000.0, dtype=self.dtype_policy
)

self.built = True

def _apply_rope(self, x, positions):
def _apply_rope(self, x, start_index):
"""Rope rotate q or k."""
# TODO: refactor to use RotaryEmbedding layer?
max_wavelength = 10000
x_shape = ops.shape(x)
freq_exponents = (2.0 / x_shape[-1]) * ops.arange(
x_shape[-1] // 2, dtype="float32"
x = self.rope_layer(x, start_index=start_index)
# Gemma uses a different layout for positional embeddings.
# The transformation below ensures the embeddings are numerically
# equivalent to the original gemma implementation.
x = ops.reshape(
ops.stack(ops.split(x, 2, axis=-1), axis=-1), ops.shape(x)
)
timescale = max_wavelength**freq_exponents
radians = positions[..., None] / timescale[None, None, :]
radians = radians[..., None, :]
sin = ops.cast(ops.sin(radians), self.compute_dtype)
cos = ops.cast(ops.cos(radians), self.compute_dtype)
x1, x2 = ops.split(x, 2, axis=-1)
# Avoid `ops.concatenate` for now, to avoid an obscure bug with XLA
# compilation on jax. We should be able to remove this once the
# following PR is in all jax releases we care about:
# https://github.com/openxla/xla/pull/7875
output = ops.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
return ops.reshape(output, x_shape)
return x
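
For intuition (not part of the diff), the reshape/stack above converts the half-split layout produced by RotaryEmbedding, [a0..a3, b0..b3], into the interleaved layout [a0, b0, a1, b1, ...] that matches the original Gemma implementation. A NumPy sketch on a made-up 1x8 tensor:

import numpy as np

x = np.arange(8, dtype="float32").reshape(1, 8)  # [[0. 1. 2. 3. 4. 5. 6. 7.]]
x1, x2 = np.split(x, 2, axis=-1)                 # halves [0..3] and [4..7]
interleaved = np.reshape(np.stack((x1, x2), axis=-1), x.shape)
print(interleaved)                               # [[0. 4. 1. 5. 2. 6. 3. 7.]]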

def _compute_attention(
self,
@@ -155,27 +151,22 @@ def call(
cache_update_index=0,
training=False,
):
seq_len = ops.shape(x)[1]
start_index = cache_update_index
positions = ops.arange(seq_len, dtype="float32")

positions = positions + ops.cast(start_index, "float32")
query = self.query_dense(x)
query = self._apply_rope(query, positions)
query = self._apply_rope(query, cache_update_index)

if cache is not None:
key_cache = cache[:, 0, ...]
value_cache = cache[:, 1, ...]
key_update = self.key_dense(x)
key_update = self._apply_rope(key_update, positions)
key_update = self._apply_rope(key_update, cache_update_index)
value_update = self.value_dense(x)
start = [0, cache_update_index, 0, 0]
key = ops.slice_update(key_cache, start, key_update)
value = ops.slice_update(value_cache, start, value_update)
cache = ops.stack((key, value), axis=1)
else:
key = self.key_dense(x)
key = self._apply_rope(key, positions)
key = self._apply_rope(key, cache_update_index)
value = self.value_dense(x)

attention_vec = self._compute_attention(
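
To illustrate the cache branch in `call` above, here is a hedged sketch of the `ops.slice_update` pattern with made-up shapes, assuming Keras 3's `keras.ops`: the key projections for the current tokens are written into the preallocated cache at `cache_update_index` rather than recomputing the whole sequence.

from keras import ops

# (batch, max_seq_len, num_heads, head_dim), made-up sizes.
key_cache = ops.zeros((1, 8, 1, 4))
key_update = ops.ones((1, 2, 1, 4))  # projections for 2 new tokens
cache_update_index = 3
start = [0, cache_update_index, 0, 0]
key = ops.slice_update(key_cache, start, key_update)
# Sequence positions 3 and 4 now hold the new keys; the rest is unchanged.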
31 changes: 23 additions & 8 deletions keras_nlp/models/gemma/gemma_backbone.py
@@ -194,7 +194,11 @@ def presets(cls):
return copy.deepcopy(backbone_presets)

@staticmethod
def get_layout_map(device_mesh, model_parallel_dim_name="model"):
def get_layout_map(
device_mesh,
model_parallel_dim_name="model",
data_parallel_dim_name="batch",
):
"""Get a `keras.distribution.LayoutMap` for model parallel distribution.
The returned `LayoutMap` contains the sharding spec for the gemma
Expand All @@ -221,6 +225,8 @@ def get_layout_map(device_mesh, model_parallel_dim_name="model"):
distribution.
model_parallel_dim_name: The axis name of the device mesh, where
the weights should be partitioned on.
data_parallel_dim_name: The axis name of the device mesh, where
the data should be partitioned on.
Return:
`keras.distribution.LayoutMap` that contains the sharding spec
of all the model weights.
@@ -248,21 +254,30 @@ def get_layout_map(device_mesh, model_parallel_dim_name="model"):
f"{model_parallel_dim_name} is not found in the "
f"device_mesh.axis_names. {device_mesh.axis_name=}"
)
if data_parallel_dim_name not in device_mesh.axis_names:
raise ValueError(
f"{data_parallel_dim_name} is not found in the "
f"device_mesh.axis_names. {device_mesh.axis_name=}"
)
# Note that it is possible to further configure the mesh to be 3D, e.g.
# (data, seq, model). We leave it as 2D for now for simplicity.
data_dim = data_parallel_dim_name
model_dim = model_parallel_dim_name
# The sharding is partition for the hidden_dim of the model.
# The sharding config is based on the Gemma team training config.
# See https://arxiv.org/abs/2403.08295
layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["token_embedding/embeddings"] = (None, model_dim)
layout_map["token_embedding/embeddings"] = (model_dim, data_dim)
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
None,
model_dim,
data_dim,
None,
)
layout_map["decoder_block.*attention_output.*kernel"] = (
None,
None,
model_dim,
None,
data_dim,
)
layout_map["decoder_block.*ffw_gating.*kernel"] = (model_dim, None)
layout_map["decoder_block.*ffw_linear.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_gating.*kernel"] = (data_dim, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, data_dim)

return layout_map
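
A hedged usage sketch of the updated layout map (not part of the diff; it assumes 8 accelerators, the Keras 3 distribution API, and an illustrative preset name):

import keras
import keras_nlp

# 2-way data parallel x 4-way model parallel mesh over 8 devices.
device_mesh = keras.distribution.DeviceMesh(
    shape=(2, 4),
    axis_names=("batch", "model"),
    devices=keras.distribution.list_devices(),
)
layout_map = keras_nlp.models.GemmaBackbone.get_layout_map(
    device_mesh,
    model_parallel_dim_name="model",
    data_parallel_dim_name="batch",
)
distribution = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch"
)
keras.distribution.set_distribution(distribution)
gemma = keras_nlp.models.GemmaBackbone.from_preset("gemma_2b_en")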
24 changes: 16 additions & 8 deletions keras_nlp/models/gemma/gemma_backbone_test.py
@@ -106,26 +106,34 @@ def test_distribution(self):

for w in model.weights:
if "token_embedding/embeddings" in w.path:
self.assertEqual(tuple(w.value.sharding.spec), (None, "model"))
self.assertEqual(
tuple(w.value.sharding.spec), ("model", "batch")
)
if "attention/query/kernel" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), (None, "model", None)
tuple(w.value.sharding.spec), ("model", "batch", None)
)
if "attention/key/kernel" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), (None, "model", None)
tuple(w.value.sharding.spec), ("model", "batch", None)
)
if "attention/value/kernel" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), (None, "model", None)
tuple(w.value.sharding.spec), ("model", "batch", None)
)
if "attention/attention_output/kernel" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), (None, None, "model")
tuple(w.value.sharding.spec), ("model", None, "batch")
)
if "ffw_gating/kernel" in w.path:
self.assertEqual(tuple(w.value.sharding.spec), ("model", None))
self.assertEqual(
tuple(w.value.sharding.spec), ("batch", "model")
)
if "ffw_gating_2/kernel" in w.path:
self.assertEqual(tuple(w.value.sharding.spec), ("model", None))
self.assertEqual(
tuple(w.value.sharding.spec), ("batch", "model")
)
if "ffw_linearl" in w.path:
self.assertEqual(tuple(w.value.sharding.spec), (None, "model"))
self.assertEqual(
tuple(w.value.sharding.spec), ("model", "batch")
)