Improve CI speed and resolve issues of run_quantization_check (kera…
james77777778 authored Jul 8, 2024
1 parent a219e96 commit f9faaf1
Showing 7 changed files with 61 additions and 44 deletions.
15 changes: 9 additions & 6 deletions keras_nlp/src/models/backbone.py
@@ -109,12 +109,15 @@ def get_config(self):
         }
 
         # Add quantization support by utilizing `DTypePolicyMap`
-        policy_map = keras.dtype_policies.DTypePolicyMap()
-        for layer in self._flatten_layers():
-            if layer.quantization_mode is not None:
-                policy_map[layer.path] = layer.dtype_policy
-        if len(policy_map) > 0:
-            config.update({"dtype": policy_map})
+        if isinstance(self.dtype_policy, keras.dtype_policies.DTypePolicyMap):
+            config.update({"dtype": self.dtype_policy})
+        else:
+            policy_map = keras.dtype_policies.DTypePolicyMap()
+            for layer in self._flatten_layers():
+                if layer.quantization_mode is not None:
+                    policy_map[layer.path] = layer.dtype_policy
+            if len(policy_map) > 0:
+                config.update({"dtype": policy_map})
         return config
 
     @classmethod
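The new `isinstance` branch matters when the backbone was itself constructed with a `DTypePolicyMap` (for example after deserializing a quantized model): the existing map is reused instead of being rebuilt from `_flatten_layers()`. A minimal sketch of the underlying idea, using plain Keras layers as a stand-in for a KerasNLP backbone (the `Sequential` model and layer names are illustrative, not from this commit):

import keras

# Build and quantize a tiny stand-in model.
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(8, name="dense")])
model.quantize("int8")

# Mirrors the `else` branch above: collect per-layer policies into a map.
policy_map = keras.dtype_policies.DTypePolicyMap()
for layer in model._flatten_layers():
    if layer.quantization_mode is not None:
        policy_map[layer.path] = layer.dtype_policy

# The map now holds an `int8_from_float32` policy keyed by the Dense layer's
# path; serializing it under "dtype" lets `from_config` restore the same
# per-layer quantization without re-quantizing weights.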
2 changes: 1 addition & 1 deletion keras_nlp/src/models/bloom/bloom_backbone.py
@@ -107,7 +107,7 @@ def __init__(
         self.embeddings_layer_norm = keras.layers.LayerNormalization(
             epsilon=layer_norm_epsilon,
             dtype=dtype,
-            name="token_embedding_layernorm",
+            name="embedding_layernorm",
         )
         self.transformer_layers = []
         for i in range(num_layers):
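The rename is presumably tied to how `DTypePolicyMap` matches keys against layer paths: a key such as `token_embedding` can also match the longer path `token_embedding_layernorm`, handing the layer norm a quantized policy it cannot honor (the TODO removed in bloom_backbone_test.py below mentions exactly that missing `quantized_call()` error). A hedged illustration of the overlap, inferred rather than taken from this commit:

import keras

policy_map = keras.dtype_policies.DTypePolicyMap("float32")
policy_map["token_embedding"] = keras.dtype_policies.get("int8_from_float32")

# With the old name, the layer norm's path also matches the "token_embedding"
# key and would be looked up as int8; "embedding_layernorm" should not match
# and fall back to the float32 default (behavior inferred, not verified here).
print(policy_map["token_embedding_layernorm"].name)
print(policy_map["embedding_layernorm"].name)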
3 changes: 0 additions & 3 deletions keras_nlp/src/models/bloom/bloom_backbone_test.py
@@ -39,9 +39,6 @@ def test_backbone_basics(self):
             init_kwargs=self.init_kwargs,
             input_data=self.input_data,
             expected_output_shape=(2, 5, 8),
-            # TODO: Set to `True`. Error msg: Layer LayerNormalization does not
-            # have a `quantized_call()` method implemented.
-            run_quantization_check=False,
         )
 
     @pytest.mark.large
22 changes: 13 additions & 9 deletions keras_nlp/src/models/opt/opt_backbone.py
@@ -158,12 +158,16 @@ def __init__(
         self.max_sequence_length = max_sequence_length
 
     def get_config(self):
-        return {
-            "vocabulary_size": self.vocabulary_size,
-            "num_layers": self.num_layers,
-            "num_heads": self.num_heads,
-            "hidden_dim": self.hidden_dim,
-            "intermediate_dim": self.intermediate_dim,
-            "dropout": self.dropout,
-            "max_sequence_length": self.max_sequence_length,
-        }
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "max_sequence_length": self.max_sequence_length,
+            }
+        )
+        return config
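Delegating to `super().get_config()` is what lets the OPT backbone pick up the entries added by the base class, including the `dtype`/`DTypePolicyMap` entry from `Backbone.get_config()` above; the old override replaced the whole dict and silently dropped them. A minimal sketch of the pattern with a plain Keras layer standing in for the backbone (class and attribute names are illustrative):

import keras


class TinyBlock(keras.layers.Layer):
    # Hypothetical stand-in for a backbone; not part of KerasNLP.
    def __init__(self, hidden_dim=8, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim

    def get_config(self):
        config = super().get_config()  # keeps "name", "dtype", "trainable", ...
        config.update({"hidden_dim": self.hidden_dim})
        return config


block = TinyBlock(hidden_dim=16, name="block_0")
restored = TinyBlock.from_config(block.get_config())
assert restored.name == "block_0" and restored.hidden_dim == 16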
4 changes: 0 additions & 4 deletions keras_nlp/src/models/opt/opt_backbone_test.py
@@ -40,10 +40,6 @@ def test_backbone_basics(self):
             init_kwargs=self.init_kwargs,
             input_data=self.input_data,
             expected_output_shape=(2, 5, 2),
-            # TODO: Set to `True`. Error msg: Layer 'token_embedding' expected 1
-            # variables, but received 0 variables during loading. Expected:
-            # ['embeddings']
-            run_quantization_check=False,
         )
 
     @pytest.mark.large
1 change: 1 addition & 0 deletions keras_nlp/src/models/xlnet/xlnet_backbone_test.py
@@ -40,6 +40,7 @@ def test_backbone_basics(self):
             init_kwargs=self.init_kwargs,
             input_data=self.input_data,
             expected_output_shape=(2, 5, 2),
+            run_quantization_check=False,  # TODO(hongyu): set to `True`
         )
 
     @pytest.mark.large
58 changes: 37 additions & 21 deletions keras_nlp/src/tests/test_case.py
@@ -29,6 +29,7 @@
 from keras import ops
 from keras import tree
 
+from keras_nlp.src import layers as keras_nlp_layers
 from keras_nlp.src.tokenizers.tokenizer import Tokenizer
 from keras_nlp.src.utils.tensor_utils import is_float_dtype
 
@@ -336,29 +337,44 @@ def run_precision_test(self, cls, init_kwargs, input_data):
                 self.assertEqual(policy.compute_dtype, sublayer.compute_dtype)
                 self.assertEqual(policy.variable_dtype, sublayer.variable_dtype)
 
-    def run_quantization_test(self, cls, init_kwargs, input_data):
-        policy = keras.DTypePolicy("float32")
+    def run_quantization_test(self, instance, cls, init_kwargs, input_data):
+        def _get_supported_layers(mode):
+            supported_layers = [keras.layers.Dense, keras.layers.EinsumDense]
+            if mode == "int8":
+                supported_layers.append(keras.layers.Embedding)
+                supported_layers.append(keras_nlp_layers.ReversibleEmbedding)
+            return supported_layers
+
         for mode in ["int8", "float8"]:
-            layer = cls(**{**init_kwargs, "dtype": policy})
-            layer.quantize(mode)
-            # Try eager call
-            if isinstance(layer, keras.Model):
-                _ = layer(input_data)
+            # Manually configure DTypePolicyMap to avoid intensive computation
+            # in `Model.quantize`.
+            policy_map = keras.dtype_policies.DTypePolicyMap("float32")
+            for layer in instance._flatten_layers():
+                if type(layer) in _get_supported_layers(mode):
+                    policy_map[layer.path] = keras.dtype_policies.get(
+                        f"{mode}_from_float32"
+                    )
+            # Instantiate the layer.
+            model = cls(**{**init_kwargs, "dtype": policy_map})
+            # Call layer eagerly.
+            if isinstance(model, keras.Model):
+                _ = model(input_data)
             elif isinstance(input_data, dict):
-                _ = layer(**input_data)
+                _ = model(**input_data)
             else:
-                _ = layer(input_data)
-            # Verify sublayer's dtype policy
-            for sublayer in layer._flatten_layers():
-                if type(sublayer) is keras.layers.Dense:
-                    self.assertEqual(
-                        f"{mode}_from_float32", sublayer.dtype_policy.name
-                    )
-            # Try saving and reloading the model
-            temp_filepath = os.path.join(self.get_temp_dir(), "layer.keras")
-            layer.save(temp_filepath)
-            reloaded_layer = keras.models.load_model(temp_filepath)
-            self.assertAllClose(layer(input_data), reloaded_layer(input_data))
+                _ = model(input_data)
+            # Verify sublayer's dtype policy.
+            for sublayer in model._flatten_layers():
+                if type(sublayer) in _get_supported_layers(mode):
+                    self.assertEqual(mode, sublayer.quantization_mode)
+            # `get_config` roundtrip.
+            cfg = model.get_config()
+            revived_model = cls.from_config(cfg)
+            revived_cfg = revived_model.get_config()
+            self.assertEqual(cfg, revived_cfg)
+            # Check weights loading.
+            weights = model.get_weights()
+            revived_model.set_weights(weights)
 
     def run_model_saving_test(
         self,
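The "Manually configure DTypePolicyMap" comment is where the CI-speed claim in the commit title comes from: the old helper built full float32 weights and then converted them via `quantize()`, whereas constructing layers directly under a quantized policy creates the low-precision variables up front. A rough sketch of that trade-off with a single Dense layer standing in for a backbone (layer sizes are illustrative):

import keras

# Heavier path (what the old helper effectively exercised): build float32
# kernels first, then convert them in place.
layer_a = keras.layers.Dense(8)
layer_a.build((None, 4))
layer_a.quantize("int8")

# Lighter path used above: create the layer under the quantized policy, so the
# int8 kernel and its scale are created directly.
layer_b = keras.layers.Dense(8, dtype="int8_from_float32")
layer_b.build((None, 4))

assert layer_a.quantization_mode == layer_b.quantization_mode == "int8"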
@@ -436,7 +452,7 @@ def run_backbone_test(
 
         # Check quantization.
         if run_quantization_check:
-            self.run_quantization_test(cls, init_kwargs, input_data)
+            self.run_quantization_test(backbone, cls, init_kwargs, input_data)
 
     def run_task_test(
         self,
