Adding torchao apis to gpt-fast #208

Open
wants to merge 2 commits into main
11 changes: 11 additions & 0 deletions README.md
@@ -180,6 +180,17 @@ To run with int4, just pass the int4 checkpoint to generate.py.
python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.pth --compile
```

### TorchAO Quantization APIs
There are also options to use the TorchAO quantization APIs with quantize.py, via the torchao-int4, torchao-int8, and torchao-int4-hqq modes.
To generate this version of the model:
```bash
# Spits out model at checkpoints/$MODEL_REPO/model_torchao-int4-hqq.pth
python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode torchao-int4-hqq --groupsize 32
```
In addition to adding the hqq option for int4 quantization, the primary difference between the TorchAO quantization APIs and the gpt-fast ones is that checkpoints saved with the TorchAO APIs can be loaded directly, rather than requiring the model to be converted for runtime before loading the state dict.
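
For example, to run generation with the resulting checkpoint (assuming the filename produced by the command above):
```bash
python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_torchao-int4-hqq.pth --compile
```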


## Speculative Sampling
To generate with speculative sampling (DRAFT_MODEL_REPO should point to a smaller model compared with MODEL_REPO).

6 changes: 5 additions & 1 deletion eval.py
@@ -92,7 +92,11 @@ def __init__(
tokenizer,
max_seq_length: Optional[int]=None,
):
super().__init__()
try:
super().__init__()
except TypeError:
# lm_eval 0.4.2 removed the default init
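# "gpt2" is only a placeholder model name; the wrapped model and tokenizer are set below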
super().__init__("gpt2", device="cuda")
self._model = model
self._tokenizer = tokenizer
self._device = torch.device('cuda')
9 changes: 6 additions & 3 deletions generate.py
@@ -221,21 +221,24 @@ def _load_model(checkpoint_path, device, precision, use_tp):
with torch.device('meta'):
model = Transformer.from_name(checkpoint_path.parent.name)

if "int8" in str(checkpoint_path):
# don't have to transform the model when using torchao apis
is_torchao = 'torchao-' in str(checkpoint_path)

if "int8" in str(checkpoint_path) and not is_torchao:
print("Using int8 weight-only quantization!")
from quantize import WeightOnlyInt8QuantHandler
simple_quantizer = WeightOnlyInt8QuantHandler(model)
model = simple_quantizer.convert_for_runtime()

if "int4" in str(checkpoint_path):
if "int4" in str(checkpoint_path) and not is_torchao:
print("Using int4 weight-only quantization!")
path_comps = checkpoint_path.name.split(".")
groupsize = int(path_comps[-2][1:])
from quantize import WeightOnlyInt4QuantHandler
simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
model = simple_quantizer.convert_for_runtime()

checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
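# torchao checkpoints store quantized tensor subclasses, so weights_only loading is disabled for them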
checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=not is_torchao)
if "model" in checkpoint and "stories" in str(checkpoint_path):
checkpoint = checkpoint["model"]
model.load_state_dict(checkpoint, assign=True)
34 changes: 21 additions & 13 deletions quantize.py
@@ -554,22 +554,33 @@ def quantize(
model.load_state_dict(checkpoint, assign=True)
model = model.to(dtype=precision, device=device)

if mode == 'int8':
dir_name = checkpoint_path.parent
base_name = checkpoint_path.name

if 'torchao-int4' in mode:
import torchao
from torchao.quantization import (quantize_, int4_weight_only)
use_hqq = 'hqq' in mode
print(f"Quantizing model weights for int4 weight-only symmetric per-channel quantization {'with hqq' if use_hqq else ''}")
quantize_(model, int4_weight_only(group_size=groupsize, use_hqq=use_hqq), device='cuda')
quantized_state_dict = model.state_dict()
new_base_name = base_name.replace('.pth', f'{label}{mode}.pth')
elif 'torchao-int8' in mode:
import torchao
from torchao.quantization import (quantize_, int8_weight_only)
print("Quantizing model weights for int8 weight-only symmetric per-channel quantization")
quantize_(model, int8_weight_only())
quantized_state_dict = model.state_dict()
new_base_name = base_name.replace('.pth', f'{label}{mode}.pth')
elif mode == 'int8':
print("Quantizing model weights for int8 weight-only symmetric per-channel quantization")
quant_handler = WeightOnlyInt8QuantHandler(model)
quantized_state_dict = quant_handler.create_quantized_state_dict()

dir_name = checkpoint_path.parent
base_name = checkpoint_path.name
new_base_name = base_name.replace('.pth', f'{label}int8.pth')

elif mode == 'int4':
print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization")
quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
quantized_state_dict = quant_handler.create_quantized_state_dict()

dir_name = checkpoint_path.parent
base_name = checkpoint_path.name
new_base_name = base_name.replace('.pth', f"{label}int4.g{groupsize}.pth")

elif mode == 'int4-gptq':
@@ -590,12 +601,9 @@ def quantize(
calibration_seq_length,
pad_calibration_inputs
)

dir_name = checkpoint_path.parent
base_name = checkpoint_path.name
new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.pth")
else:
raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]")
raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq, torchao-int4, torchao-int8, torchao-int4-hqq]")

quantize_path = dir_name / new_base_name
print(f"Writing quantized weights to {quantize_path}")
@@ -608,7 +616,7 @@ def quantize(
import argparse
parser = argparse.ArgumentParser(description='Quantize a model.')
parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Path to the model checkpoint to be quantized.')
parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq'], help='type of quantization to perform')
parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq', 'torchao-int4', 'torchao-int8', 'torchao-int4-hqq'], help='type of quantization to perform')
parser.add_argument('--groupsize', type=int, default=32, help='Group size for int4 quantization.')
parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
parser.add_argument('--calibration_limit', type=int, default=1000, help='number of samples to use for gptq calibration')