Merge pull request #1 from shawnzxf/neuron-2-nxd
Support DBRX for NxD integration with vLLM
shawnzxf authored Jul 16, 2024
2 parents 8c51a8e + 2db64e2 commit 5da48c0
Showing 4 changed files with 194 additions and 5 deletions.
36 changes: 36 additions & 0 deletions examples/offline_inference_neuron_dbrx.py
@@ -0,0 +1,36 @@
from vllm import LLM, SamplingParams


if __name__ == "__main__":
    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # Create a sampling params object.
    sampling_params = SamplingParams(top_k=1)

    # Create an LLM.
    llm = LLM(
        model="databricks/dbrx-instruct",
        tensor_parallel_size=32,
        max_num_seqs=4,
        # max_model_len and block_size must equal the maximum sequence length
        # when targeting a Neuron device; this is currently a known limitation
        # of continuous-batching support in neuronx-distributed.
        max_model_len=64,
        block_size=64,
        dtype="bfloat16",
        # The device is detected automatically when the AWS Neuron SDK is
        # installed; it can also be assigned explicitly, as done here.
        device="neuron")
    # Generate texts from the prompts. The output is a list of RequestOutput
    # objects that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
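A note on local validation: the DBRX loader added below (vllm/model_executor/models/neuron/dbrx.py) reads two environment variables, NXD_CPU and NXD_DEBUG. A minimal sketch of using them, assuming they are exported before the LLM above is constructed:

import os

# Sketch: NXD_CPU switches the loader onto its CPU validation path
# (TP degree 1), and NXD_DEBUG enables verbose logging. Any non-empty
# value works, since the loader only checks that the variables are set.
os.environ["NXD_CPU"] = "1"
os.environ["NXD_DEBUG"] = "1"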
6 changes: 5 additions & 1 deletion vllm/model_executor/models/__init__.py
@@ -28,6 +28,7 @@
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
@@ -60,7 +61,10 @@
}

# Models supported by Neuron.
_NEURON_SUPPORTED_MODELS = {
    "LlamaForCausalLM": "neuron.llama",
    "DbrxForCausalLM": "neuron.dbrx",
}


class ModelRegistry:
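For orientation, here is a minimal sketch of the dispatch this registry entry feeds into. It mirrors the pattern of vLLM's registry rather than reproducing its actual code, so the helper name load_neuron_model_cls and its error message are illustrative only: on a Neuron device, the architecture string is resolved through _NEURON_SUPPORTED_MODELS instead of the default model table.

import importlib

_NEURON_SUPPORTED_MODELS = {
    "LlamaForCausalLM": "neuron.llama",
    "DbrxForCausalLM": "neuron.dbrx",
}


def load_neuron_model_cls(arch: str):
    # On device="neuron", look the architecture up in the Neuron table and
    # import the matching module under vllm/model_executor/models/ in place
    # of the default GPU implementation.
    if arch not in _NEURON_SUPPORTED_MODELS:
        raise ValueError(f"Model {arch} is not supported on Neuron.")
    module_name = _NEURON_SUPPORTED_MODELS[arch]
    module = importlib.import_module(
        f"vllm.model_executor.models.{module_name}")
    return getattr(module, arch)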
152 changes: 152 additions & 0 deletions vllm/model_executor/models/neuron/dbrx.py
@@ -0,0 +1,152 @@
"""Inference-only DBRX model compatible with HuggingFace weights."""
import os
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import DbrxConfig

import neuronx_distributed as nxd

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput


KVCache = Tuple[torch.Tensor, torch.Tensor]


class DbrxForCausalLM(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        linear_method=None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = None
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        with torch.inference_mode():
            # block_size equals n_positions (the max sequence length), so
            # each sequence occupies exactly one block: during prefill the
            # sequence id is recovered from the first slot index, and during
            # decode the block table already carries it.
            block_size = self.model.config.n_positions
            if input_metadata.is_prompt:
                seq_ids = input_metadata.slot_mapping[:, 0] // block_size
            else:
                seq_ids = input_metadata.block_tables

            output = self.model(input_ids,
                                attention_mask=None,
                                position_ids=positions,
                                seq_ids=seq_ids.flatten() - 1)
            return output.logits[:, -1, :]

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(None, hidden_states, sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None,
                     **kwargs):
        # The NeuronxDistributed/examples/inference directory must be on
        # PYTHONPATH for this import to succeed.
        from dbrx.neuron_modeling_dbrx import (NeuronDbrxForCausalLM,
                                               NeuronDbrxConfig,
                                               NeuronDbrxModel,
                                               preshard_hook_fn)
        from neuronx_distributed.parallel_layers.checkpointing import (
            _invoke_preshard_hook)
        from transformers import DbrxForCausalLM as DbrxForCausalLMHF

        config = NeuronDbrxConfig.from_pretrained(model_name_or_path)
        config.tp_degree = kwargs["tp_degree"]
        config.max_batch_size = kwargs["batch_size"]
        config.torch_dtype = kwargs["amp"]
        config.n_positions = kwargs["n_positions"][-1]
        config.buckets = [config.n_positions]
        config.tkg_batch_size = kwargs["batch_size"]
        config.ctx_batch_size = 1
        config.attn_cls = 'NeuronLlamaAttention'
        config.padding_side = "right"
        config.is_continuous_batching = True
        config.do_sample = True
        config.top_k = 1
        config.quantized = False

        print(config)

        if os.environ.get("NXD_DEBUG", None):
            from importlib import reload
            import logging

            # Reload logging so basicConfig takes effect even if logging
            # was already configured.
            reload(logging)
            logging.basicConfig(level=logging.DEBUG)

        # Save the checkpoint locally if model_name_or_path is not a local
        # directory (e.g. it is a HuggingFace Hub model id).
        if not os.path.exists(model_name_or_path):
            model = DbrxForCausalLMHF.from_pretrained(model_name_or_path)
            saved_path = os.path.join("local-models", model_name_or_path)
            model.save_pretrained(saved_path)
            model_name_or_path = saved_path

        cpu_mode = os.environ.get("NXD_CPU", None)
        if cpu_mode is not None:
            # CPU validation path: run with TP degree 1 and build the
            # context-encoding and token-generation models by hand.
            config.tp_degree = 1

            self.init_distributed_env()
            dbrx_model = NeuronDbrxModel(config)
            state_dict = NeuronDbrxForCausalLM.get_state_dict(
                model_name_or_path, config)
            _invoke_preshard_hook(dbrx_model, state_dict)
            dbrx_model.load_state_dict(state_dict, strict=False)

            config.torch_dtype = torch.float32

            self.model = NeuronDbrxForCausalLM("", config)
            config.batch_size = config.ctx_batch_size
            config.n_active_tokens = config.n_positions
            dbrx_model_ctx = NeuronDbrxModel.from_pretrained(
                None, config=config, state_dict=state_dict)

            config.batch_size = config.tkg_batch_size
            config.n_active_tokens = 1
            dbrx_model_tkg = NeuronDbrxModel.from_pretrained(
                None, config=config, state_dict=state_dict)

            self.model.context_encoding_model.model = dbrx_model_ctx
            self.model.token_generation_model.model = dbrx_model_tkg
        else:
            self.model = NeuronDbrxForCausalLM.from_pretrained(
                model_name_or_path, config)
            self.model.to_neuron()


    def init_distributed_env(self):
        """
        Initialize a minimal neuronx-distributed tensor-parallel environment
        with a TP degree of 1. This is only used for running
        NeuronxDistributed models on CPU to validate correctness.
        """
        os.environ["RANK"] = str(0)
        os.environ["WORLD_SIZE"] = str(1)
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "2024"

        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="xla")

        nxd.parallel_layers.parallel_state.destroy_model_parallel()
        nxd.parallel_layers.parallel_state.initialize_model_parallel(
            tensor_model_parallel_size=1)
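A note on the kwargs contract above: load_weights() is normally invoked by vLLM's Neuron model loader, not by user code. The hypothetical call below only illustrates how the keyword names map onto the config fields set in the method; the concrete values and the direct construction of the model are assumptions for illustration.

from transformers import DbrxConfig

from vllm.model_executor.models.neuron.dbrx import DbrxForCausalLM

config = DbrxConfig.from_pretrained("databricks/dbrx-instruct")
model = DbrxForCausalLM(config)
model.load_weights(
    "databricks/dbrx-instruct",
    tp_degree=32,      # -> config.tp_degree, one rank per NeuronCore
    batch_size=4,      # -> config.max_batch_size and config.tkg_batch_size
    amp="bfloat16",    # -> config.torch_dtype
    n_positions=[64],  # last entry -> config.n_positions
)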
5 changes: 1 addition & 4 deletions vllm/model_executor/models/neuron/llama.py
@@ -4,7 +4,7 @@

import torch
from torch import nn
from transformers import LlamaConfig, LlamaForCausalLM
from transformers import LlamaConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.sampler import Sampler
@@ -14,9 +14,6 @@
KVCache = Tuple[torch.Tensor, torch.Tensor]





class LlamaForCausalLM(nn.Module):

    def __init__(
