
Commit

Update Patch v0.2.1
seungahdev committed Jul 26, 2024
1 parent a2fedf0 commit db40b30
Showing 4 changed files with 81 additions and 26 deletions.
35 changes: 16 additions & 19 deletions README.md
@@ -2,7 +2,7 @@
Copyright (c) 2024-present, FriendliAI Inc. All rights reserved.
-->

- <h2><p align="center">Friendli Model Optimizer (FMO) for Supercharge Generative AI Serving 🚀</p></h2>
+ <h2><p align="center">Friendli Model Optimizer (FMO) for supercharging generative AI serving 🚀</p></h2>

<p align="center">
<a href="https://github.com/friendliai/friendli-model-optimizer/actions/workflows/ci.yaml">
@@ -24,20 +24,20 @@ Copyright (c) 2024-present, FriendliAI Inc. All rights reserved.


# Overview
- FMO is a tool that provides model optimizations for efficient generative AI serving with [Friendli Engine](https://friendli.ai/solutions/engine/).
- It provides features to improve generative AI serving performance without compromising task accuracy.
+ Friendli Model Optimizer (FMO) is a tool that provides model optimizations for efficient generative AI serving with [Friendli Engine](https://friendli.ai/solutions/engine/).
+ The optimizations improve generative AI serving performance without compromising task accuracy.

- FMO is designed to work with Huggingface pretrained model, which can be loaded using ['PreTrainedModel.from_pretrained()'](https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.from_pretrained).
+ FMO is designed to work with Hugging Face pretrained models, which can be loaded using ['PreTrainedModel.from_pretrained()'](https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.from_pretrained).

> [!NOTE]
- > The Huggingface model architectures that can be optimized with FMO is specified in [Supported Features & Model Architecture](#supported-features--model-architecture).
+ > The list of Hugging Face model architectures that can be optimized with FMO is specified in [Supported Features & Model Architecture](#supported-features--model-architecture).
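
For readers skimming the diff: loading such a model is the standard Hugging Face call. The snippet below is a minimal sketch for orientation only (the checkpoint name is an example, not mandated by FMO):

```python
from transformers import AutoModelForCausalLM

# Any checkpoint loadable through `from_pretrained()` is valid FMO input;
# the model name below is only an example.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
```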

# Table of Contents
- [Quick Installation](#quick-installation)
- [Supported Features & Model Architecture](#supported-features--model-architecture)
- [User Guides](#user-guides)
- - [How to Serve Optimized Model?](#how-to-serve-optimized-model-with-frinedli-engine)
+ - [Serving an Optimized Model](#how-to-serve-an-optimized-model-with-friendli-engine)


# Quick Installation
@@ -47,7 +47,7 @@ pip install friendli-model-optimizer


# Supported Features & Model Architecture
- FMO currently supports the following PTQ(Post-Training Quantization) techniques:
+ FMO currently supports the following PTQ (Post-Training Quantization) techniques:

## FP8

@@ -59,9 +59,9 @@ This leads to increased throughput and reduced latency while maintaining high ou
> FP8 is only supported by NVIDIA Ada, Hopper, and Blackwell GPU architectures.
> [!NOTE]
- > For now, we only support E4M3 (4-bit exponent and 3-bit mantissa) encoding format.
+ > For now, we only support the E4M3 (4-bit exponent and 3-bit mantissa) encoding format.
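
As a rough illustration of what E4M3 weight scaling involves (a minimal PyTorch sketch, not FMO's actual implementation; it assumes PyTorch >= 2.1 for the `float8_e4m3fn` dtype):

```python
import torch

def quantize_weight_fp8_e4m3(w: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Scale a weight tensor into E4M3's representable range, then cast."""
    finfo = torch.finfo(torch.float8_e4m3fn)  # E4M3 max magnitude is 448.0
    scale = w.abs().max().clamp(min=1e-12) / finfo.max
    w_fp8 = (w / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    # A serving engine would dequantize on the fly as w_fp8.to(w.dtype) * scale.
    return w_fp8, scale

w_fp8, scale = quantize_weight_fp8_e4m3(torch.randn(4096, 4096))
```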
- ### Supported Model Architecutre for FP8 Quantization
+ ### Supported Model Architectures for FP8 Quantization
- `LlamaForCausalLM`
- `MistralForCausalLM`
- `CohereForCausalLM`
@@ -77,9 +77,9 @@ This leads to increased throughput and reduced latency while maintaining high ou

INT8 Quantization represents weights and activations using the INT8 format with acceptable accuracy drops.
Friendli Engine enables dynamic activation scaling, where scales are computed on the fly during runtime.
- Thus, FMO only quantize weight, and Friendli Engine will load quantized weight.
+ Thus, FMO only quantizes model weights, and Friendli Engine will load the quantized weights.
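
A sketch of the weight-only half of this scheme (symmetric per-output-channel INT8, an illustration rather than FMO's exact algorithm):

```python
import torch

def quantize_weight_int8(w: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Symmetric per-output-channel INT8 quantization of a 2-D weight."""
    # One scale per output channel (row), mapping the channel max |w| to 127.
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / 127.0
    w_int8 = torch.round(w / scale).clamp(-127, 127).to(torch.int8)
    # Activation scales are computed dynamically at runtime by the engine.
    return w_int8, scale

w_int8, scale = quantize_weight_int8(torch.randn(4096, 11008))
```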

- ### Supported Model Architecutre for INT8 Quantization
+ ### Supported Model Architectures for INT8 Quantization
- `LlamaForCausalLM`
- `MistralForCausalLM`
- `CohereForCausalLM`
@@ -100,14 +100,11 @@ fmo quantize \
The command-line arguments mean:
- **`model-name-or-path`**: Hugging Face pretrained model name or directory path of the saved model checkpoint.
- **`output-dir`**: Directory path to save the quantized checkpoint and related configurations.
- - **`mode`**: Qantization techniques to apply. You can use `fp8`, `int8`.
+ - **`mode`**: Quantization techniques to apply. You can use `fp8`, `int8`.
- **`device`**: Device to run the quantization process. Defaults to "cuda:0".
- **`offload`**: When enabled, this option significantly reduces GPU memory usage by offloading model layers onto CPU RAM. Defaults to true.

- > [!TIP]
- > If you want to use more advanced quantization options(e.g., calibration dataset), Please checkout our [official documentations](https://docs.friendli.ai/guides/container/running_friendli_container/quantization).
- ## Example: Run FP8 uantization with Meta-Llama-3-8B-Instruct
+ ## Example: Run FP8 quantization with Meta-Llama-3-8B-Instruct
```bash
export MODEL_NAME_OR_PATH="meta-llama/Meta-Llama-3-8B-Instruct"
export OUTPUT_DIR="./"
@@ -120,6 +117,6 @@ fmo quantize \
--offload
```

- # How to serve optimized model with Frinedli Engine?
- If your optimized model is ready, now, you can serve the model with Friendli Engine.\
- Please checkout our [official documentations](https://docs.friendli.ai/guides/container/running_friendli_container/quantization) to learn more!
+ # How to serve an optimized model with Friendli Engine?
+ Once your optimized model is ready, you can serve the model with Friendli Engine.\
+ Please check out our [official documentation](https://docs.friendli.ai/guides/container/running_friendli_container) to learn more!
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "friendli-model-optimizer"
version = "0.2.0"
version = "0.2.1"
authors = [
{ name = "FriendliAI teams", email = "[email protected]" },
]
11 changes: 5 additions & 6 deletions src/fmo/main.py
@@ -10,7 +10,7 @@

import typer

- from fmo.utils.dataset import get_tokenizer, safe_load_datasets
+ from fmo.utils.dataset import get_calib_dataloader, get_tokenizer, safe_load_datasets
from fmo.utils.format import secho_error_and_exit
from fmo.utils.version import get_installed_version

@@ -93,9 +93,9 @@ def quantize(
"--dataset-target-column-name",
help=("Huggingface dataset column name for gathering sample activations."),
),
- dataset_num_sample: int = typer.Option(
+ dataset_num_samples: int = typer.Option(
128,
"--dataset-num-sample",
"--dataset-num-samples",
help=("The number of samples for gathering sample activations."),
),
dataset_max_length: int = typer.Option(
@@ -119,7 +119,7 @@
):
"""Quantize huggingface's model."""
# pylint: disable=too-many-locals, import-outside-toplevel
- from fmo_core import get_calib_dataloader, quantize  # type: ignore
+ from fmo_core import quantize  # type: ignore

# pylint: enable=import-outside-toplevel

@@ -135,11 +135,10 @@
model_name_or_path=model_name_or_path, cache_dir=cache_dir
)
calib_dataloader = get_calib_dataloader(
- mode=mode,
dataset=dataset,
lookup_column_name=dataset_target_column_name,
max_length=dataset_max_length,
- num_sample=dataset_num_sample,
+ num_samples=dataset_num_samples,
batch_size=dataset_batch_size,
seed=seed,
tokenizer=tokenizer,
59 changes: 59 additions & 0 deletions src/fmo/utils/dataset.py
@@ -8,7 +8,9 @@
from typing import Optional

import datasets # type: ignore
+ import torch
from fmo_core import NotSupportedError, QuantizationError # type: ignore
+ from torch.utils.data import DataLoader
from transformers import AutoTokenizer, PreTrainedTokenizer # type: ignore


@@ -74,3 +76,60 @@ def get_tokenizer(
tokenizer.pad_token = tokenizer.eos_token

return tokenizer


+ def get_calib_dataloader(  # pylint: disable=too-many-arguments
+     tokenizer: PreTrainedTokenizer,
+     dataset: datasets.Dataset,
+     lookup_column_name: str,
+     seed: Optional[int] = 42,
+     max_length: Optional[int] = 512,
+     num_samples: Optional[int] = 512,
+     batch_size: Optional[int] = 32,
+ ) -> DataLoader:
+     """Return Calibration DataLoader."""
+     try:
+         dataset = dataset.shuffle(seed=seed).select(range(num_samples * 2))  # type: ignore
+         encoded_ds_w_special_tokens = tokenizer(
+             dataset[lookup_column_name][:num_samples],
+             return_tensors="pt",
+             truncation=True,
+             padding=True,
+             max_length=max_length,
+             add_special_tokens=True,
+         ).input_ids
+         encoded_ds_wo_special_tokens = tokenizer(
+             dataset[lookup_column_name][num_samples:],
+             return_tensors="pt",
+             truncation=True,
+             padding=True,
+             max_length=max_length,
+             add_special_tokens=False,
+         ).input_ids
+
+         max_length_diff = (
+             encoded_ds_w_special_tokens.shape[-1]
+             - encoded_ds_wo_special_tokens.shape[-1]
+         )
+         if max_length_diff > 0:
+             padded_tokens = torch.full(
+                 (encoded_ds_wo_special_tokens.shape[0], max_length_diff),
+                 tokenizer.pad_token_id,
+             )
+             encoded_ds_wo_special_tokens = torch.cat(
+                 [encoded_ds_wo_special_tokens, padded_tokens], dim=1
+             )
+         assert (
+             encoded_ds_w_special_tokens.shape[-1]
+             == encoded_ds_wo_special_tokens.shape[-1]
+         )
+         encoded_dataset = torch.cat(
+             [encoded_ds_w_special_tokens, encoded_ds_wo_special_tokens], dim=0
+         )
+
+     except KeyError as exc:
+         raise NotSupportedError(
+             f"`{lookup_column_name}` is not valid column name in given dataset."
+         ) from exc
+
+     return DataLoader(encoded_dataset, batch_size=batch_size)  # type: ignore
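
For context on how this new helper is driven, a hypothetical invocation might look like the following (the dataset, column, and tokenizer names are illustrative placeholders, not taken from this commit):

```python
from datasets import load_dataset
from transformers import AutoTokenizer

from fmo.utils.dataset import get_calib_dataloader

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # padding is required above

# Draw 128 calibration samples from an example dataset's text column.
calib_dataloader = get_calib_dataloader(
    tokenizer=tokenizer,
    dataset=load_dataset("cnn_dailymail", "3.0.0", split="validation"),
    lookup_column_name="article",
    num_samples=128,
    max_length=512,
    batch_size=32,
)
```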
