Add example for OPT model with distribution. #1727
@@ -0,0 +1,174 @@
""" | ||
Title: OPT2 Text Generation with KerasNLP and Keras Distribution API | ||
Author: Qianli (Scott) Zhu | ||
Date created: 01/11/2024 | ||
Last modified: 01/11/2024 | ||
Description: Use OPT2 model and Keras distribution API to do text generation. | ||
Accelerator: GPU | ||
""" | ||
|
||
""" | ||
In this tutorial, you will learn to use [KerasNLP](https://keras.io/keras_nlp/) | ||
to load a pre-trained Large Language Model (LLM) - [OPT-2 model](https://arxiv.org/abs/2205.01068) | ||
(originally invented by Meta), finetune and generate with a distribute hardware | ||
setting. | ||
""" | ||
|
||
""" | ||
## Before we begin | ||
|
||
Colab offers different kinds of runtimes. Make sure to go to **Runtime -> | ||
Change runtime type** and choose the GPU Hardware Accelerator runtime | ||
(which should have >12G host RAM and ~16G GPU RAM) since you will finetune the | ||
OPT-2 model. Running this tutorial on CPU runtime will take hours. | ||
|
||
Also note that this example was originally created with 8 V100 GPUs, explicitly | ||
to simulate how to do inference of large model with limited hardwares. | ||
""" | ||
|
||
""" | ||
## Install KerasNLP, Choose Backend and Import Dependencies | ||
|
||
This examples uses the latest distribution API from [Keras](https://keras.io/keras/). | ||
The API is currently supporting JAX backend, and we are adding more backends | ||
support in the coming future. | ||
""" | ||
|
||
import os

# Disable JAX's default memory preallocation and use the platform allocator,
# so that memory is allocated on demand and all of the available GPU memory
# can be used.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax

# We have 8 V100 GPUs, each of which has 16G of GPU memory. A single device
# does not have enough memory to host all the model weights and the optimizer
# state. We are going to showcase how to distribute the large model weights,
# so that a popular LLM (~7B parameters) can be finetuned on a previous
# generation of hardware.
print(jax.devices())

# The Keras backend must be set before Keras is imported.
os.environ["KERAS_BACKEND"] = "jax"

import keras

print(keras.version())
print(keras.backend.backend())

# Use mixed precision to reduce memory usage: computations run in float16
# while the variables are kept in float32.
keras.mixed_precision.set_global_policy("mixed_float16")

Review comment: Please add a comment about using mixed precision.

import keras_nlp

""" | ||
## Introduction to KerasNLP | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I suggest focusing the example purely on the distribution aspects, so we can replace the KerasNLP intro with ~1 sentence. Meanwhile maybe we could flesh out the distribution part, e.g. include fine-tuning or other inference performance considerations. |
||
|
||
Large Language Models are complex to build and expensive to train from scratch. | ||
Luckily there are pretrained LLMs available for use right away. [KerasNLP](https://keras.io/keras_nlp/) | ||
provides a large number of pre-trained checkpoints that allow you to experiment | ||
with SOTA models without needing to train them yourself. | ||
|
||
KerasNLP is a natural language processing library that supports users through | ||
their entire development cycle. KerasNLP offers both pretrained models and | ||
modularized building blocks, so developers could easily reuse pretrained models | ||
or stack their own LLM. | ||
|
||
In a nutshell, for generative LLM, KerasNLP offers: | ||
|
||
- Pretrained models with `generate()` method, e.g., `keras_nlp.models.OPTCausalLM`. | ||
- Sampler class that implements generation algorithms such as Top-K, Beam and | ||
contrastive search. These samplers can be used to generate text with | ||
custom models. | ||
""" | ||

def create_opt_model(model_spec):
    # Load a pretrained OPT causal LM from the given preset and print its summary.
    opt_model = keras_nlp.models.OPTCausalLM.from_preset(model_spec)
    opt_model.summary()
    return opt_model

""" | ||
we are going to first try to create the 7B model without any distribution, | ||
and it will error out with a OOM message from JAX. The 7B param would take | ||
about 28G GPU memory, and the per-GPU memory limit 16G. This doesn't even | ||
count other items like optimizer states, as well as forward and backward path. | ||
""" | ||
# model_spec = 'opt_6.7b_en' | ||
# langauge_model = create_opt_model(model_spec) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be uncomented? |
||
|
||
""" | ||
Now let's create a new with distributions. In Keras 3, we introduce a new | ||
unified distribution API that allow you to do data and model parallel | ||
trainings. You can find more details of the API in https://keras.io/api/distribution/. | ||
""" | ||
|
||
# Create a 2D mesh for model parallelism; change the mesh shape to tune the
# ratio of data parallelism to model parallelism.
_BATCH_DIM_NAME = "batch"
_MODEL_DIM_NAME = "model"

Review comment: No need for leading underscores.

# Create a mesh with a (1, 8) shape so that the weights are sharded across all
# 8 GPUs.
mesh = keras.distribution.DeviceMesh(
    (1, 8),
    [_BATCH_DIM_NAME, _MODEL_DIM_NAME],
    devices=keras.distribution.list_devices(),
)

""" | ||
The following code specifies how we would like to distribute the model weights. | ||
The layout map is a dict like object, which maps the string key to a Layout. | ||
The string key is used to indentify the variables in the Keras model, and the | ||
corresponding Layout sharding will be applied to the weights. Note that the | ||
key is like a regex, so it can be applied to both variables and its related | ||
optimizer states. | ||
|
||
You can find more details about the Layout Map in https://keras.io/api/distribution/layout_map/#layoutmap-class. | ||
""" | ||
unshard_dim = None
model_dim = _MODEL_DIM_NAME

layout_map = keras.distribution.LayoutMap(mesh)

layout_map[r"embeddings.*"] = (unshard_dim, model_dim)

# Transformer block sharding
layout_map[r"self_attention.*(query|key|value).*kernel.*"] = (
    unshard_dim, unshard_dim, model_dim)
layout_map[r"self_attention.*(query|key|value).*bias.*"] = (
    model_dim, unshard_dim)
layout_map[r"self_attention.*attention_output.*kernel.*"] = (
    unshard_dim, model_dim, unshard_dim)
layout_map[r"intermediate_dense.*kernel.*"] = (
    unshard_dim, model_dim)
layout_map[r"intermediate_dense.*bias.*"] = (
    model_dim,)
layout_map[r"output_dense.*kernel.*"] = (model_dim, unshard_dim)
layout_map[r"output_dense.*bias.*"] = (unshard_dim,)
""" | ||
Next we will create a global distribut setting, and all the variables/data | ||
created afterwards will be distributed according to this setting. | ||
|
||
There is also a scope based API available with `model_parallel.scope()`. | ||
""" | ||
model_parallel = keras.distribution.ModelParallel(
    mesh, layout_map, batch_dim_name=_BATCH_DIM_NAME)
keras.distribution.set_distribution(model_parallel)
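
"""
If you prefer not to change the global setting, the same distribution can also
be applied to just a block of code via the scope API (a sketch using the
`model_parallel` object created above; the small preset is only illustrative).
"""
# Illustrative only; kept commented out to avoid loading an extra model here.
# with model_parallel.scope():
#     scoped_model = create_opt_model("opt_125m_en")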
""" | ||
Let's create the 2.7B model here, and with the model weights and forward path, | ||
it won't be able to fit into GPU memory without any distribution. | ||
""" | ||
# Other avaiable model_spec are 'opt_125m_en', 'opt_1.3b_en' and 'opt_6.7b_en' | ||
model_spec = 'opt_2.7b_en' | ||
large_model = create_opt_model(model_spec) | ||
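
"""
Optionally (a hedged sketch assuming the JAX backend), once the model is built
you can inspect how individual weights are laid out across the mesh through the
underlying `jax.Array`.
"""
# Illustrative only; the exact attributes depend on the backend and Keras version.
# for variable in large_model.weights[:4]:
#     print(variable.path, variable.value.sharding)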
""" | ||
Inference | ||
|
||
Note that the first run will take long time, since JAX need to compile the | ||
generate function with XLA. The follow up runs will be much faster. | ||
""" | ||
prompt = "What is machine learning?" | ||
print(large_model.generate(prompt)) | ||
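
"""
A second, illustrative prompt (the prompt text and `max_length` value below are
arbitrary, not part of the original example).
"""
print(large_model.generate("The distribution API in Keras 3", max_length=64))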

Review comment: Please add a second prompt, possibly with some ...

Review comment: Please group the imports at the top.