Add support for Phi family of models #120

Open
wants to merge 12 commits into base: main
2 changes: 1 addition & 1 deletion src/optimum/nvidia/config.py
@@ -131,7 +131,7 @@ def get_quantization_config(
exclude_modules = qconfig.get("module_to_not_convert", [])

return mode, TensorRTQuantizationConfig(
quantization_algo=quant_method,
quant_algo=quant_method,
kv_cache_quant_algo=None,
group_size=group_size,
has_zero_point=has_zero_point,
40 changes: 38 additions & 2 deletions src/optimum/nvidia/hub.py
@@ -19,6 +19,7 @@
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
List,
Optional,
@@ -46,7 +47,11 @@
from optimum.nvidia.builder.config import EngineConfigBuilder
from optimum.nvidia.quantization import AutoQuantizationConfig
from optimum.nvidia.quantization.ammo import AmmoQuantizer
from optimum.nvidia.utils import get_user_agent, maybe_offload_weights_to_cpu
from optimum.nvidia.utils import (
get_user_agent,
iter_safetensors,
maybe_offload_weights_to_cpu,
)
from optimum.nvidia.utils.nvml import get_max_memory


@@ -119,6 +124,9 @@ def find_prebuilt_engines(root: Path) -> Tuple[List[Path], List[Path]]:
class SupportsTensorrtConversion(Protocol):
MODEL_CONFIG: Type[TensorRTConfig]
HF_LIBRARY_TARGET_MODEL_CLASS: Type[ModelHubMixin]
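# Optional hook applied to every safetensors shard to remap a checkpoint onto the
# target transformers architecture (used by Phi-3, whose weights are loaded as Llama)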
HF_CHECKPOINT_TRANSFORM_FN: Optional[
Callable[[Dict[str, np.ndarray]], Dict[str, np.ndarray]]
] = None
TRT_LLM_TARGET_MODEL_CLASS: Type[PretrainedModel]

@staticmethod
@@ -157,7 +165,15 @@ def convert_and_build(
engines_folder.mkdir(exist_ok=True, parents=True)

# Retrieve configuration
config = AutoConfig.for_model(**hf_model_config)
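# Phi-3 goes through the Llama code path: rebuild the config as a transformers
# PhiConfig, disable the attention bias and retag the model_type as "llama"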
if hf_model_config["model_type"] == "phi3":
from transformers import PhiConfig

config = PhiConfig.from_dict(hf_model_config)
config.attention_bias = False
config.model_type = "llama"
else:
config = AutoConfig.for_model(**hf_model_config)

if "torch_dtype" in model_kwargs:
config.torch_dtype = model_kwargs["torch_dtype"]

@@ -166,6 +182,9 @@
cls, config, config_class=config_class, **model_kwargs
)

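# Phi-3 engines are built with TRT-LLM's Llama implementation, so advertise the
# Llama architecture to the checkpoint builder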
if config.architectures[0] == "Phi3ForCausalLM":
model_config.architecture = "LlamaForCausalLM"

# We now have a TRTLLM compatible config, so let's feed it to the target TRTLLM model to create a checkpoint
LOGGER.debug("Allocating TRTLLM model to build the checkpoint")
model = cls.TRT_LLM_TARGET_MODEL_CLASS.from_config(model_config)
@@ -183,6 +202,21 @@

max_memory = get_max_memory()

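# When a transform function is declared, remap every safetensors shard onto the
# target architecture and load the result into the target transformers class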
if cls.HF_CHECKPOINT_TRANSFORM_FN:
LOGGER.info("Detected checkpoint transformation function")

transformed_state = {}
for state in iter_safetensors(local_path):
transformed_state |= cls.HF_CHECKPOINT_TRANSFORM_FN(state)

if transformed_state:
LOGGER.info(
f"Loading the model from a transformed state_dict "
f"({config.model_type} -> {cls.HF_LIBRARY_TARGET_MODEL_CLASS.config_class.model_type})"
)
hf_model = cls.HF_LIBRARY_TARGET_MODEL_CLASS(config)
hf_model.load_state_dict(transformed_state)

if hf_model is None:
LOGGER.debug(
f"Loading weights from {local_path} into the model ({cls.HF_LIBRARY_TARGET_MODEL_CLASS.__name__})"
@@ -383,6 +417,8 @@ def _from_pretrained(
except OSError:
generation_config = None

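# Phi-3 engines were built through the Llama code path, so remap the model_type
# before rebuilding the transformers config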
if config["model_type"] == "phi3":
config["model_type"] = "llama"
transformers_config = AutoConfig.for_model(**config)
model = cls(
engines_folders,
3 changes: 3 additions & 0 deletions src/optimum/nvidia/models/auto.py
@@ -24,12 +24,15 @@
from .gemma import GemmaForCausalLM
from .llama import LlamaForCausalLM
from .mistral import MistralForCausalLM
from .phi import Phi3ForCausalLM, PhiForCausalLM


_SUPPORTED_MODEL_CLASS = {
"llama": LlamaForCausalLM,
"mistral": MistralForCausalLM,
"gemma": GemmaForCausalLM,
"phi": PhiForCausalLM,
"phi3": Phi3ForCausalLM,
}


217 changes: 217 additions & 0 deletions src/optimum/nvidia/models/phi.py
@@ -0,0 +1,217 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from logging import getLogger
from typing import Dict

import numpy as np
import torch
from tensorrt_llm.models import PretrainedConfig, PretrainedModel
from tensorrt_llm.models.llama.model import LLaMAForCausalLM as TrtLlamaForCausalLM
from tensorrt_llm.models.llama.weight import load_from_hf_llama
from tensorrt_llm.models.phi.convert import convert_hf_weights
from tensorrt_llm.models.phi.model import PhiForCausalLM as TrtPhiForCausalLM
from tensorrt_llm.plugin import PluginConfig
from transformers import LlamaForCausalLM as TransformersLlamaForCausalLM
from transformers import PhiForCausalLM as TransformersPhiForCausalLM
from transformers import PretrainedConfig as TransformersPretrainedConfig
from transformers import PreTrainedModel as TransformersPretrainedModel

from optimum.nvidia import TensorRTConfig
from optimum.nvidia.config import dtype_to_str
from optimum.nvidia.hub import HuggingFaceHubModel
from optimum.nvidia.runtime import CausalLM


LOGGER = getLogger(__name__)


class PhiConfig(TensorRTConfig):
r"""
This is the configuration class to store the configuration of a [`PhiModel`]. It is used to instantiate a Phi
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a configuration similar to that of the Phi models.

Configuration objects inherit from [`TensorRTConfig`] and can be used to control the model outputs. Read the
documentation from [`TensorRTConfig`] for more information.
"""

@staticmethod
def from_config(config: TransformersPretrainedConfig) -> "TensorRTConfig":
# Retrieve the quantization from the transformers config (if provided)
_, qconfig = TensorRTConfig.get_quantization_config(config)

trt_config = PhiConfig(
architecture=config.architectures[0],
dtype=dtype_to_str(config.torch_dtype),
logits_dtype="float32",
vocab_size=config.vocab_size,
max_position_embeddings=config.max_position_embeddings,
hidden_size=config.hidden_size,
num_hidden_layers=config.num_hidden_layers,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=getattr(
config, "num_key_value_heads", config.num_attention_heads
),
hidden_act=config.hidden_act,
intermediate_size=config.intermediate_size,
norm_epsilon=config.layer_norm_eps,
position_embedding_type="rope_gpt_neox",
partial_rotary_factor=config.partial_rotary_factor,
rope_theta=config.rope_theta,
world_size=1,
tp_size=1,
pp_size=1,
use_prompt_tuning=False,
use_parallel_embedding=False,
embedding_sharding_dim=0,
share_embedding_table=False,
max_lora_rank=64,
head_size=config.hidden_size // config.num_attention_heads,
quantization=qconfig,
)

trt_config.mapping.gpus_per_node = min(trt_config.mapping.world_size, 8)

return trt_config

def get_plugins_config(self) -> PluginConfig:
config = super().get_plugins_config()
config.moe_plugin = "disable" # TODO : Mixtral?
config.bert_attention_plugin = "disable"
config.gpt_attention_plugin = self.dtype
config.gemm_plugin = self.dtype

return config

@staticmethod
def supports_strong_typing() -> bool:
return False


class Phi3Config(TensorRTConfig):
r"""
This is the configuration class to store the configuration of a [`Phi3Model`]. It is used to instantiate a Phi-3
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a configuration similar to that of Phi-3.

Configuration objects inherit from [`TensorRTConfig`] and can be used to control the model outputs. Read the
documentation from [`TensorRTConfig`] for more information.
"""

@staticmethod
def from_config(config: TransformersPretrainedConfig) -> "TensorRTConfig":
# Retrieve the quantization from the transformers config (if provided)
_, qconfig = TensorRTConfig.get_quantization_config(config)

trt_config = Phi3Config(
architecture=config.architectures[0],
dtype=dtype_to_str(config.torch_dtype),
logits_dtype="float32",
vocab_size=config.vocab_size,
max_position_embeddings=config.max_position_embeddings,
hidden_size=config.hidden_size,
num_hidden_layers=config.num_hidden_layers,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=getattr(
config, "num_key_value_heads", config.num_attention_heads
),
hidden_act=config.hidden_act,
intermediate_size=config.intermediate_size,
norm_epsilon=config.rms_norm_eps,
position_embedding_type="rope_gpt_neox",
partial_rotary_factor=config.partial_rotary_factor,
rope_theta=config.rope_theta,
world_size=1,
tp_size=1,
pp_size=1,
use_prompt_tuning=False,
use_parallel_embedding=False,
embedding_sharding_dim=0,
share_embedding_table=False,
max_lora_rank=64,
head_size=config.hidden_size // config.num_attention_heads,
quantization=qconfig,
)

trt_config.mapping.gpus_per_node = min(trt_config.mapping.world_size, 8)

return trt_config

def get_plugins_config(self) -> PluginConfig:
config = super().get_plugins_config()
config.moe_plugin = "disable" # TODO : Mixtral?
config.bert_attention_plugin = "disable"
config.gpt_attention_plugin = self.dtype
config.gemm_plugin = self.dtype

return config

@staticmethod
def supports_strong_typing() -> bool:
return False


class PhiForCausalLM(CausalLM, HuggingFaceHubModel):
MODEL_CONFIG = PhiConfig
HF_LIBRARY_TARGET_MODEL_CLASS = TransformersPhiForCausalLM
TRT_LLM_TARGET_MODEL_CLASS = TrtPhiForCausalLM

@staticmethod
def convert_weights(
target: PretrainedModel,
source: TransformersPretrainedModel,
config: PretrainedConfig,
) -> Dict[str, np.ndarray]:
if config.quant_mode.has_any_quant():
raise NotImplementedError("Quantization is not supported yet.")

return {
name: tensor.numpy()
for name, tensor in convert_hf_weights(source, config.dtype).items()
}


def transform_checkpoint_to_llama(
params: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
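"""Split Phi-3's fused qkv_proj and gate_up_proj tensors into the separate
q/k/v and gate/up projections expected by the Llama implementation."""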
state_dict = {}
for name, tensor in params.items():
if "qkv_proj" in name:
q, k, v = torch.chunk(tensor, 3, dim=0)
state_dict[name.replace("qkv_proj", "q_proj")] = q
state_dict[name.replace("qkv_proj", "k_proj")] = k
state_dict[name.replace("qkv_proj", "v_proj")] = v
elif "gate_up_proj" in name:
gate, up = torch.chunk(tensor, 2, dim=0)
state_dict[name.replace("gate_up_proj", "gate_proj")] = gate
state_dict[name.replace("gate_up_proj", "up_proj")] = up
else:
state_dict[name] = tensor
return state_dict


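# Phi-3 reuses the Llama implementations end-to-end: the checkpoint is transformed
# into a Llama state_dict, loaded into transformers' LlamaForCausalLM and converted
# with TRT-LLM's LLaMA weight loader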
class Phi3ForCausalLM(CausalLM, HuggingFaceHubModel):
MODEL_CONFIG = Phi3Config
HF_LIBRARY_TARGET_MODEL_CLASS = TransformersLlamaForCausalLM
TRT_LLM_TARGET_MODEL_CLASS = TrtLlamaForCausalLM
HF_CHECKPOINT_TRANSFORM_FN = transform_checkpoint_to_llama

@staticmethod
def convert_weights(
target: PretrainedModel,
source: TransformersPretrainedModel,
config: PretrainedConfig,
) -> Dict[str, torch.Tensor]:
return load_from_hf_llama(target, source, config.mapping, config.dtype)
1 change: 1 addition & 0 deletions src/optimum/nvidia/utils/__init__.py
@@ -27,6 +27,7 @@
from .nvml import has_float8_support
from .offload import maybe_offload_weights_to_cpu
from .onnx import to_onnx
from .safetensors import iter_safetensors


def rgetattr(obj, attr):
19 changes: 19 additions & 0 deletions src/optimum/nvidia/utils/safetensors.py
@@ -0,0 +1,19 @@
import json
from pathlib import Path
from typing import Dict, Generator

import torch
from safetensors.torch import load
from transformers.modeling_utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME


def iter_safetensors(path: Path) -> Generator[Dict[str, torch.Tensor], None, None]:
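"""Yield the state dict of each safetensors shard under `path`, following the
sharded-checkpoint index when present, otherwise loading the single weights file."""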
if (path / SAFE_WEIGHTS_INDEX_NAME).exists():
with open(path / SAFE_WEIGHTS_INDEX_NAME) as index_f:
indexes = json.load(index_f)
for file in set(indexes["weight_map"].values()):
with open(path / file, "rb") as shard_f:
yield load(shard_f.read())
else:
with open(path / SAFE_WEIGHTS_NAME, "rb") as safetensors_f:
yield load(safetensors_f.read())