Add SliceOp; Add Phi-3 & Qwen-2 #435

Merged (7 commits) on Nov 17, 2024
54 changes: 45 additions & 9 deletions examples/LLMs/prune_llama.py → examples/LLMs/prune_llm.py
@@ -247,14 +247,15 @@ def pattern_match(patterns, source_list):
print('accelerate', version('accelerate'))
print('# of gpus: ', torch.cuda.device_count())

def get_llm(model_name):
def get_llm(model_name, max_seq_len=None):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto"
    )

    model.seqlen = model.config.max_position_embeddings
    model.seqlen = max(max_seq_len, model.config.max_position_embeddings) if max_seq_len is not None else model.config.max_position_embeddings
    # avoid OOM, feel free to change this
    return model

def main():
@@ -266,6 +267,7 @@ def main():
    parser.add_argument('--save', type=str, default=None, help='Path to save results.')
    parser.add_argument('--save_model', type=str, default=None, help='Path to save the pruned model.')
    parser.add_argument("--eval_zero_shot", action="store_true")
    parser.add_argument("--max_seq_len", type=int, default=None)
    args = parser.parse_args()

    # Setting seeds for reproducibility
@@ -274,7 +276,7 @@

    model_name = args.model.split("/")[-1]
    print(f"loading llm model {args.model}")
    model = get_llm(args.model)
    model = get_llm(args.model, max_seq_len=args.max_seq_len)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
    device = torch.device("cuda:0")
@@ -291,11 +293,22 @@
    inputs = torch.tensor(tokenizer.encode(text)).unsqueeze(0).to(model.device)
    import torch_pruning as tp
    num_heads = {}
    out_channel_groups = {}
    seperate_qkv = False
    for name, m in model.named_modules():
        if name.endswith("self_attn"):
            num_heads[m.q_proj] = model.config.num_attention_heads
            num_heads[m.k_proj] = model.config.num_key_value_heads
            num_heads[m.v_proj] = model.config.num_key_value_heads
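            # Llama and Qwen-2 keep separate q_proj / k_proj / v_proj modules, while Phi-3 fuses
            # them into a single qkv_proj, so the head bookkeeping branches on whichever attribute exists.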
            if hasattr(m, "q_proj"):
                seperate_qkv = True
                num_heads[m.q_proj] = model.config.num_attention_heads
                num_heads[m.k_proj] = model.config.num_key_value_heads
                num_heads[m.v_proj] = model.config.num_key_value_heads
            elif hasattr(m, "qkv_proj"):
                seperate_qkv = False
                num_heads[m.qkv_proj] = model.config.num_attention_heads
        if name.endswith('mlp'):
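            # Phi-3 also fuses the gate and up projections into gate_up_proj, whose output
            # consists of two equally sized channel groups.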
            if hasattr(m, "gate_up_proj"):
                out_channel_groups[m.gate_up_proj] = 2

    _is_gqa = model.config.num_attention_heads != model.config.num_key_value_heads
    head_pruning_ratio = args.pruning_ratio
    hidden_size_pruning_ratio = args.pruning_ratio
@@ -311,33 +324,56 @@ def main():
        prune_num_heads=True,
        prune_head_dims=False, # we do not prune head dims so that we don't need to prune the ROPE
        head_pruning_ratio=head_pruning_ratio,
        out_channel_groups=out_channel_groups,
        round_to=4,
    )


    #with torch.no_grad():
    #    with importance.compute_importance(model):
    #        calibration_data = "We recommend at least a 1TB hard drive for 4 channels, more if you plan on using 8MP \/ 4K cameras.\nDahua's Lite Series network video recorders offer excellent performance and high recording quality for IP video surveillance applications. For applications where details are critical for identification, this professional NVR provides a powerful processor with up to 4K resolution. Additionally, the NVR features a mouse shortcut operation menu, remote management and control, center storage, edge storage, and back up storage."
    #        calibration_data = torch.tensor(tokenizer.encode(text)).unsqueeze(0).to(model.device)
    #        _ = model(calibration_data)
    pruner.step()

    #group = pruner.DG.get_pruning_group(model.model.layers[31].mlp.gate_up_proj, tp.prune_linear_out_channels, idxs=list(range(16384)))
    #print(group)

    for g in pruner.step(interactive=True):
        print(g)
        g.prune()

    # Update model attributes
    model.config.hidden_size = model.lm_head.in_features
    for name, m in model.named_modules():
        if name.endswith("self_attn"):
            m.hidden_size = m.q_proj.out_features
            if seperate_qkv:
                m.hidden_size = m.q_proj.out_features
            else:
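                # In the fused qkv_proj, q, k, and v have equal width here, so the hidden size is one third of its output.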
                m.hidden_size = m.qkv_proj.out_features // 3
            m.num_heads = m.hidden_size // m.head_dim
            model.config.num_attention_heads = m.num_heads
            #m.head_dim = m.q_proj.out_features // m.num_heads
            if not _is_gqa:
                m.num_key_value_heads = m.num_heads
            m.num_key_value_groups = m.num_heads // m.num_key_value_heads
        elif name.endswith("mlp"):
            model.config.intermediate_size = m.gate_proj.out_features
            if hasattr(m, "gate_proj"):
                m.hidden_size = m.gate_proj.out_features
            elif hasattr(m, "gate_up_proj"):
                m.hidden_size = m.gate_up_proj.in_features
            else:
                raise ValueError("Unknown mlp layer")

    if not _is_gqa:
        model.config.num_key_value_heads = model.config.num_attention_heads
    print("----------------- After Pruning -----------------")
    print(model)
    print(model.config)


    del pruner
    torch.cuda.empty_cache()
    model.eval()
    num_params = sum(p.numel() for p in model.parameters())
    print(f"num_params {num_params}")
    ppl_test = eval_ppl(args, model, tokenizer, device)
231 changes: 228 additions & 3 deletions examples/LLMs/readme.md
@@ -1,7 +1,16 @@
# Prune Large Language Models
# Pruning Large Language Models

This is a minimal example of pruning large language models with magnitude-based pruning. We use the `transformers` library to load the model and the `datasets` library to evaluate perplexity on `Wikitext2`. **For more comprehensive examples of gradient-based pruning or finetuning, please refer to [LLM-Pruner](https://github.com/horseee/LLM-Pruner)**.
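
For orientation, the snippet below sketches what `prune_llm.py` does for a model with separate q/k/v projections (Llama-style). It is a simplified sketch, not the script itself: the pruner and importance classes (`tp.pruner.MagnitudePruner`, `tp.importance.GroupNormImportance`) and the example input text are illustrative choices, while the head-related arguments mirror the script.

```python
import torch
import torch_pruning as tp
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B"  # any of the tested models listed below
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
example_inputs = torch.tensor(tokenizer.encode("Hello world.")).unsqueeze(0).to(model.device)

# Register the head count of every attention projection so heads are pruned as whole units.
num_heads = {}
for name, m in model.named_modules():
    if name.endswith("self_attn"):
        num_heads[m.q_proj] = model.config.num_attention_heads
        num_heads[m.k_proj] = model.config.num_key_value_heads
        num_heads[m.v_proj] = model.config.num_key_value_heads

pruner = tp.pruner.MagnitudePruner(  # assumed pruner class
    model,
    example_inputs,
    importance=tp.importance.GroupNormImportance(p=2),  # magnitude-based importance
    pruning_ratio=0.5,               # width pruning ratio (--pruning_ratio)
    ignored_layers=[model.lm_head],  # keep the vocabulary projection intact
    num_heads=num_heads,
    prune_num_heads=True,            # drop whole attention heads
    prune_head_dims=False,           # keep head_dim so RoPE is untouched
    head_pruning_ratio=0.5,
    round_to=4,
)
pruner.step()  # prune the model in place

# Afterwards, the script syncs model.config (hidden_size, num_attention_heads, ...)
# with the pruned module shapes before measuring perplexity.
```

Fused-projection models such as Phi-3 are handled by registering `qkv_proj` in `num_heads` and `gate_up_proj` as two output-channel groups instead (see `prune_llm.py`).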

This script has been tested with the following models:

1. [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B)
2. [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf)
3. [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
4. [Qwen/Qwen2-7B](https://huggingface.co/Qwen/Qwen2-7B)



## 0. Requirements

```bash
@@ -13,7 +22,7 @@ pip install transformers datasets
### Llama-3 8B

```bash
python prune_llama.py --model meta-llama/Meta-Llama-3-8B --pruning_ratio 0.5
python prune_llm.py --model meta-llama/Meta-Llama-3-8B --pruning_ratio 0.5
```

<details>
@@ -120,7 +129,7 @@ wikitext perplexity 552648.25
### Llama-2 7B

```bash
python prune_llama.py --model meta-llama/Llama-2-7b-hf --pruning_ratio 0.5
python prune_llm.py --model meta-llama/Llama-2-7b-hf --pruning_ratio 0.5
```


@@ -224,3 +233,219 @@ wikitext perplexity 8479.0673828125
</details>


### microsoft/Phi-3-mini-4k-instruct

```bash
python prune_llm.py --model microsoft/Phi-3-mini-4k-instruct --pruning_ratio 0.5
```


<details>
<summary>Output:</summary>

```
----------------- Before Pruning -----------------
Phi3ForCausalLM(
(model): Phi3Model(
(embed_tokens): Embedding(32064, 3072, padding_idx=32000)
(embed_dropout): Dropout(p=0.0, inplace=False)
(layers): ModuleList(
(0-31): 32 x Phi3DecoderLayer(
(self_attn): Phi3Attention(
(o_proj): Linear(in_features=3072, out_features=3072, bias=False)
(qkv_proj): Linear(in_features=3072, out_features=9216, bias=False)
(rotary_emb): Phi3RotaryEmbedding()
)
(mlp): Phi3MLP(
(gate_up_proj): Linear(in_features=3072, out_features=16384, bias=False)
(down_proj): Linear(in_features=8192, out_features=3072, bias=False)
(activation_fn): SiLU()
)
(input_layernorm): Phi3RMSNorm()
(resid_attn_dropout): Dropout(p=0.0, inplace=False)
(resid_mlp_dropout): Dropout(p=0.0, inplace=False)
(post_attention_layernorm): Phi3RMSNorm()
)
)
(norm): Phi3RMSNorm()
)
(lm_head): Linear(in_features=3072, out_features=32064, bias=False)
)
----------------- After Pruning -----------------
Token indices sequence length is longer than the specified maximum sequence length for this model (2824490 > 4096). Running this sequence through the model will result in indexing errors
Phi3ForCausalLM(
(model): Phi3Model(
(embed_tokens): Embedding(32064, 1536, padding_idx=32000)
(embed_dropout): Dropout(p=0.0, inplace=False)
(layers): ModuleList(
(0-31): 32 x Phi3DecoderLayer(
(self_attn): Phi3Attention(
(o_proj): Linear(in_features=1536, out_features=1536, bias=False)
(qkv_proj): Linear(in_features=1536, out_features=4608, bias=False)
(rotary_emb): Phi3RotaryEmbedding()
)
(mlp): Phi3MLP(
(gate_up_proj): Linear(in_features=1536, out_features=8192, bias=False)
(down_proj): Linear(in_features=4096, out_features=1536, bias=False)
(activation_fn): SiLU()
)
(input_layernorm): Phi3RMSNorm()
(resid_attn_dropout): Dropout(p=0.0, inplace=False)
(resid_mlp_dropout): Dropout(p=0.0, inplace=False)
(post_attention_layernorm): Phi3RMSNorm()
)
)
(norm): Phi3RMSNorm()
)
(lm_head): Linear(in_features=1536, out_features=32064, bias=False)
)
Phi3Config {
"_name_or_path": "microsoft/Phi-3-mini-4k-instruct",
"architectures": [
"Phi3ForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"auto_map": {
"AutoConfig": "microsoft/Phi-3-mini-4k-instruct--configuration_phi3.Phi3Config",
"AutoModelForCausalLM": "microsoft/Phi-3-mini-4k-instruct--modeling_phi3.Phi3ForCausalLM"
},
"bos_token_id": 1,
"embd_pdrop": 0.0,
"eos_token_id": 32000,
"hidden_act": "silu",
"hidden_size": 1536,
"initializer_range": 0.02,
"intermediate_size": 8192,
"max_position_embeddings": 4096,
"model_type": "phi3",
"num_attention_heads": 16,
"num_hidden_layers": 32,
"num_key_value_heads": 16,
"original_max_position_embeddings": 4096,
"pad_token_id": 32000,
"resid_pdrop": 0.0,
"rms_norm_eps": 1e-05,
"rope_scaling": null,
"rope_theta": 10000.0,
"sliding_window": 2047,
"tie_word_embeddings": false,
"torch_dtype": "float16",
"transformers_version": "4.36.2",
"use_cache": true,
"vocab_size": 32064
}

num_params 1004570112
evaluating on wikitext2
nsamples 83
sample 0
sample 50
wikitext perplexity 92795.3984375
```

</details>

### Qwen/Qwen2-7B

```bash
python prune_llm.py --model Qwen/Qwen2-7B --pruning_ratio 0.5
```


<details>
<summary>Output:</summary>

```
----------------- Before Pruning -----------------
Qwen2ForCausalLM(
(model): Qwen2Model(
(embed_tokens): Embedding(152064, 3584)
(layers): ModuleList(
(0-27): 28 x Qwen2DecoderLayer(
(self_attn): Qwen2SdpaAttention(
(q_proj): Linear(in_features=3584, out_features=3584, bias=True)
(k_proj): Linear(in_features=3584, out_features=512, bias=True)
(v_proj): Linear(in_features=3584, out_features=512, bias=True)
(o_proj): Linear(in_features=3584, out_features=3584, bias=False)
(rotary_emb): Qwen2RotaryEmbedding()
)
(mlp): Qwen2MLP(
(gate_proj): Linear(in_features=3584, out_features=18944, bias=False)
(up_proj): Linear(in_features=3584, out_features=18944, bias=False)
(down_proj): Linear(in_features=18944, out_features=3584, bias=False)
(act_fn): SiLU()
)
(input_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
(post_attention_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
)
)
(norm): Qwen2RMSNorm((3584,), eps=1e-06)
(rotary_emb): Qwen2RotaryEmbedding()
)
(lm_head): Linear(in_features=3584, out_features=152064, bias=False)
)
----------------- After Pruning -----------------
Qwen2ForCausalLM(
(model): Qwen2Model(
(embed_tokens): Embedding(152064, 1792)
(layers): ModuleList(
(0-27): 28 x Qwen2DecoderLayer(
(self_attn): Qwen2SdpaAttention(
(q_proj): Linear(in_features=1792, out_features=2048, bias=True)
(k_proj): Linear(in_features=1792, out_features=512, bias=True)
(v_proj): Linear(in_features=1792, out_features=512, bias=True)
(o_proj): Linear(in_features=2048, out_features=1792, bias=False)
(rotary_emb): Qwen2RotaryEmbedding()
)
(mlp): Qwen2MLP(
(gate_proj): Linear(in_features=1792, out_features=9472, bias=False)
(up_proj): Linear(in_features=1792, out_features=9472, bias=False)
(down_proj): Linear(in_features=9472, out_features=1792, bias=False)
(act_fn): SiLU()
)
(input_layernorm): Qwen2RMSNorm((1792,), eps=1e-06)
(post_attention_layernorm): Qwen2RMSNorm((1792,), eps=1e-06)
)
)
(norm): Qwen2RMSNorm((1792,), eps=1e-06)
(rotary_emb): Qwen2RotaryEmbedding()
)
(lm_head): Linear(in_features=1792, out_features=152064, bias=False)
)
Qwen2Config {
"_attn_implementation_autoset": true,
"_name_or_path": "Qwen/Qwen2-7B",
"architectures": [
"Qwen2ForCausalLM"
],
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151643,
"hidden_act": "silu",
"hidden_size": 1792,
"initializer_range": 0.02,
"intermediate_size": 18944,
"max_position_embeddings": 131072,
"max_window_layers": 28,
"model_type": "qwen2",
"num_attention_heads": 16,
"num_hidden_layers": 28,
"num_key_value_heads": 4,
"rms_norm_eps": 1e-06,
"rope_scaling": null,
"rope_theta": 1000000.0,
"sliding_window": null,
"tie_word_embeddings": false,
"torch_dtype": "float16",
"transformers_version": "4.46.2",
"use_cache": true,
"use_sliding_window": false,
"vocab_size": 152064
}

num_params 2227887872
```

</details>

2 changes: 0 additions & 2 deletions examples/torchvision_models/torchvision_global_pruning.py
@@ -16,8 +16,6 @@
)
from torchvision.models.detection.fcos import fcos_resnet50_fpn
from torchvision.models.detection.keypoint_rcnn import keypointrcnn_resnet50_fpn
from torchvision.models.detection.mask_rcnn import maskrcnn_resnet50_fpn_v2
from torchvision.models.detection.retinanet import retinanet_resnet50_fpn_v2
from torchvision.models.alexnet import alexnet

from torchvision.models.vision_transformer import (
2 changes: 0 additions & 2 deletions examples/torchvision_models/torchvision_pruning.py
@@ -16,8 +16,6 @@
)
from torchvision.models.detection.fcos import fcos_resnet50_fpn
from torchvision.models.detection.keypoint_rcnn import keypointrcnn_resnet50_fpn
from torchvision.models.detection.mask_rcnn import maskrcnn_resnet50_fpn_v2
from torchvision.models.detection.retinanet import retinanet_resnet50_fpn_v2
from torchvision.models.alexnet import alexnet

from torchvision.models.vision_transformer import (