Merge pull request #435 from VainF/v2.0
Add SliceOp; Add Phi-3 & Qwen-2
VainF authored Nov 17, 2024
2 parents 224d7f8 + d74742e commit 6dc7e57
Showing 10 changed files with 380 additions and 25 deletions.
54 changes: 45 additions & 9 deletions examples/LLMs/prune_llama.py → examples/LLMs/prune_llm.py
@@ -247,14 +247,15 @@ def pattern_match(patterns, source_list):
print('accelerate', version('accelerate'))
print('# of gpus: ', torch.cuda.device_count())

def get_llm(model_name):
def get_llm(model_name, max_seq_len=None):
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map="auto"
)

model.seqlen = model.config.max_position_embeddings
model.seqlen = min(max_seq_len, model.config.max_position_embeddings) if max_seq_len is not None else model.config.max_position_embeddings
# cap the evaluation sequence length to avoid OOM; feel free to change this
return model

def main():
@@ -266,6 +267,7 @@ def main():
parser.add_argument('--save', type=str, default=None, help='Path to save results.')
parser.add_argument('--save_model', type=str, default=None, help='Path to save the pruned model.')
parser.add_argument("--eval_zero_shot", action="store_true")
parser.add_argument("--max_seq_len", type=int, default=None)
args = parser.parse_args()

# Setting seeds for reproducibility
@@ -274,7 +276,7 @@

model_name = args.model.split("/")[-1]
print(f"loading llm model {args.model}")
model = get_llm(args.model)
model = get_llm(args.model, max_seq_len=args.max_seq_len)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
device = torch.device("cuda:0")
@@ -291,11 +293,22 @@ def main():
inputs = torch.tensor(tokenizer.encode(text)).unsqueeze(0).to(model.device)
import torch_pruning as tp
num_heads = {}
out_channel_groups = {}
seperate_qkv = False
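# record the head count of each attention projection so that whole heads are pruned together,
# and mark fused layers (Phi-3's qkv_proj / gate_up_proj) whose output channels come in aligned sub-blocks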
for name, m in model.named_modules():
if name.endswith("self_attn"):
num_heads[m.q_proj] = model.config.num_attention_heads
num_heads[m.k_proj] = model.config.num_key_value_heads
num_heads[m.v_proj] = model.config.num_key_value_heads
if hasattr(m, "q_proj"):
seperate_qkv = True
num_heads[m.q_proj] = model.config.num_attention_heads
num_heads[m.k_proj] = model.config.num_key_value_heads
num_heads[m.v_proj] = model.config.num_key_value_heads
elif hasattr(m, "qkv_proj"):
seperate_qkv = False
num_heads[m.qkv_proj] = model.config.num_attention_heads
if name.endswith('mlp'):
if hasattr(m, "gate_up_proj"):
out_channel_groups[m.gate_up_proj] = 2

_is_gqa = model.config.num_attention_heads != model.config.num_key_value_heads
head_pruning_ratio = args.pruning_ratio
hidden_size_pruning_ratio = args.pruning_ratio
@@ -311,33 +324,56 @@ def main():
prune_num_heads=True,
prune_head_dims=False, # we do not prune head dims so that we don't need to prune RoPE
head_pruning_ratio=head_pruning_ratio,
out_channel_groups=out_channel_groups,
round_to=4,
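# round_to=4: round the number of kept channels to a multiple of 4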
)


#with torch.no_grad():
# with importance.compute_importance(model):
# calibration_data = "We recommend at least a 1TB hard drive for 4 channels, more if you plan on using 8MP \/ 4K cameras.\nDahua's Lite Series network video recorders offer excellent performance and high recording quality for IP video surveillance applications. For applications where details are critical for identification, this professional NVR provides a powerful processor with up to 4K resolution. Additionally, the NVR features a mouse shortcut operation menu, remote management and control, center storage, edge storage, and back up storage."
# calibration_data = torch.tensor(tokenizer.encode(text)).unsqueeze(0).to(model.device)
# _ = model(calibration_data)
pruner.step()

#group = pruner.DG.get_pruning_group(model.model.layers[31].mlp.gate_up_proj, tp.prune_linear_out_channels, idxs=list(range(16384)))
#print(group)

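# interactive mode yields each dependency group so it can be inspected before it is pruned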
for g in pruner.step(interactive=True):
print(g)
g.prune()

# Update model attributes
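# (keep the transformers config and module attributes consistent with the pruned weight shapes)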
model.config.hidden_size = model.lm_head.in_features
for name, m in model.named_modules():
if name.endswith("self_attn"):
m.hidden_size = m.q_proj.out_features
if seperate_qkv:
m.hidden_size = m.q_proj.out_features
else:
m.hidden_size = m.qkv_proj.out_features // 3
m.num_heads = m.hidden_size // m.head_dim
model.config.num_attention_heads = m.num_heads
#m.head_dim = m.q_proj.out_features // m.num_heads
if not _is_gqa:
m.num_key_value_heads = m.num_heads
m.num_key_value_groups = m.num_heads // m.num_key_value_heads
elif name.endswith("mlp"):
model.config.intermediate_size = m.gate_proj.out_features
if hasattr(m, "gate_proj"):
m.hidden_size = m.gate_proj.out_features
elif hasattr(m, "gate_up_proj"):
m.hidden_size = m.gate_up_proj.in_features
else:
raise ValueError("Unknown mlp layer")

if not _is_gqa:
model.config.num_key_value_heads = model.config.num_attention_heads
print("----------------- After Pruning -----------------")
print(model)
print(model.config)


del pruner
torch.cuda.empty_cache()
model.eval()
num_params = sum(p.numel() for p in model.parameters())
print(f"num_params {num_params}")
ppl_test = eval_ppl(args, model, tokenizer, device)
231 changes: 228 additions & 3 deletions examples/LLMs/readme.md
@@ -1,7 +1,16 @@
# Prune Large Language Models
# Pruning Large Language Models

This is a minimal example of pruning large language models with magnitude-based pruning. We use the `transformers` library to load the model and the `datasets` library to evaluate perplexity on `Wikitext2`. **For more comprehensive examples of gradient-based pruning or finetuning, please refer to [LLM-Pruner](https://github.com/horseee/LLM-Pruner)**.
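
A condensed sketch of what the script sets up is shown below. It is illustrative rather than a drop-in copy of `prune_llm.py`: the importance criterion, the `ignored_layers` choice, and the concrete ratios are assumptions, so please refer to the script for the exact configuration.

```python
import torch
import torch_pruning as tp
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
example_inputs = torch.tensor(tokenizer.encode("Hello world.")).unsqueeze(0).to(model.device)

# Tell the pruner how many heads each attention projection carries,
# so that whole heads are removed instead of arbitrary channels.
num_heads = {}
for name, m in model.named_modules():
    if name.endswith("self_attn"):
        num_heads[m.q_proj] = model.config.num_attention_heads
        num_heads[m.k_proj] = model.config.num_key_value_heads
        num_heads[m.v_proj] = model.config.num_key_value_heads

pruner = tp.pruner.MetaPruner(
    model,
    example_inputs,
    importance=tp.importance.GroupNormImportance(p=2),  # group L2 magnitude
    pruning_ratio=0.5,             # hidden / MLP width pruning ratio
    head_pruning_ratio=0.5,        # fraction of attention heads to remove
    prune_num_heads=True,
    prune_head_dims=False,         # keep head_dim untouched so RoPE does not need pruning
    num_heads=num_heads,
    ignored_layers=[model.lm_head],  # assumption: keep the vocabulary dimension intact
    round_to=4,
)
pruner.step()  # prune in place; the script then syncs model.config with the new shapes
```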

This script has been tested with the following models:

1. [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B)
2. [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf)
3. [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
4. [Qwen/Qwen2-7B](https://huggingface.co/Qwen/Qwen2-7B)



## 0. Requirements

```bash
pip install transformers datasets
```

@@ -13,7 +22,7 @@
### Llama-3 8B

```bash
python prune_llama.py --model meta-llama/Meta-Llama-3-8B --pruning_ratio 0.5
python prune_llm.py --model meta-llama/Meta-Llama-3-8B --pruning_ratio 0.5
```

<details>
@@ -120,7 +129,7 @@ wikitext perplexity 552648.25
### Llama-2 7B

```bash
python prune_llama.py --model meta-llama/Llama-2-7b-hf --pruning_ratio 0.5
python prune_llm.py --model meta-llama/Llama-2-7b-hf --pruning_ratio 0.5
```


@@ -224,3 +233,219 @@ wikitext perplexity 8479.0673828125
</details>


### microsoft/Phi-3-mini-4k-instruct

```bash
python prune_llm.py --model microsoft/Phi-3-mini-4k-instruct --pruning_ratio 0.5
```


<details>
<summary>Output:</summary>

```
----------------- Before Pruning -----------------
Phi3ForCausalLM(
(model): Phi3Model(
(embed_tokens): Embedding(32064, 3072, padding_idx=32000)
(embed_dropout): Dropout(p=0.0, inplace=False)
(layers): ModuleList(
(0-31): 32 x Phi3DecoderLayer(
(self_attn): Phi3Attention(
(o_proj): Linear(in_features=3072, out_features=3072, bias=False)
(qkv_proj): Linear(in_features=3072, out_features=9216, bias=False)
(rotary_emb): Phi3RotaryEmbedding()
)
(mlp): Phi3MLP(
(gate_up_proj): Linear(in_features=3072, out_features=16384, bias=False)
(down_proj): Linear(in_features=8192, out_features=3072, bias=False)
(activation_fn): SiLU()
)
(input_layernorm): Phi3RMSNorm()
(resid_attn_dropout): Dropout(p=0.0, inplace=False)
(resid_mlp_dropout): Dropout(p=0.0, inplace=False)
(post_attention_layernorm): Phi3RMSNorm()
)
)
(norm): Phi3RMSNorm()
)
(lm_head): Linear(in_features=3072, out_features=32064, bias=False)
)
----------------- After Pruning -----------------
Token indices sequence length is longer than the specified maximum sequence length for this model (2824490 > 4096). Running this sequence through the model will result in indexing errors
Phi3ForCausalLM(
(model): Phi3Model(
(embed_tokens): Embedding(32064, 1536, padding_idx=32000)
(embed_dropout): Dropout(p=0.0, inplace=False)
(layers): ModuleList(
(0-31): 32 x Phi3DecoderLayer(
(self_attn): Phi3Attention(
(o_proj): Linear(in_features=1536, out_features=1536, bias=False)
(qkv_proj): Linear(in_features=1536, out_features=4608, bias=False)
(rotary_emb): Phi3RotaryEmbedding()
)
(mlp): Phi3MLP(
(gate_up_proj): Linear(in_features=1536, out_features=8192, bias=False)
(down_proj): Linear(in_features=4096, out_features=1536, bias=False)
(activation_fn): SiLU()
)
(input_layernorm): Phi3RMSNorm()
(resid_attn_dropout): Dropout(p=0.0, inplace=False)
(resid_mlp_dropout): Dropout(p=0.0, inplace=False)
(post_attention_layernorm): Phi3RMSNorm()
)
)
(norm): Phi3RMSNorm()
)
(lm_head): Linear(in_features=1536, out_features=32064, bias=False)
)
Phi3Config {
"_name_or_path": "microsoft/Phi-3-mini-4k-instruct",
"architectures": [
"Phi3ForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"auto_map": {
"AutoConfig": "microsoft/Phi-3-mini-4k-instruct--configuration_phi3.Phi3Config",
"AutoModelForCausalLM": "microsoft/Phi-3-mini-4k-instruct--modeling_phi3.Phi3ForCausalLM"
},
"bos_token_id": 1,
"embd_pdrop": 0.0,
"eos_token_id": 32000,
"hidden_act": "silu",
"hidden_size": 1536,
"initializer_range": 0.02,
"intermediate_size": 8192,
"max_position_embeddings": 4096,
"model_type": "phi3",
"num_attention_heads": 16,
"num_hidden_layers": 32,
"num_key_value_heads": 16,
"original_max_position_embeddings": 4096,
"pad_token_id": 32000,
"resid_pdrop": 0.0,
"rms_norm_eps": 1e-05,
"rope_scaling": null,
"rope_theta": 10000.0,
"sliding_window": 2047,
"tie_word_embeddings": false,
"torch_dtype": "float16",
"transformers_version": "4.36.2",
"use_cache": true,
"vocab_size": 32064
}
num_params 1004570112
evaluating on wikitext2
nsamples 83
sample 0
sample 50
wikitext perplexity 92795.3984375
```

</details>

### Qwen/Qwen2-7B

```bash
python prune_llm.py --model Qwen/Qwen2-7B --pruning_ratio 0.5
```
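
Qwen2-7B declares `max_position_embeddings=131072`, so perplexity evaluation at the full sequence length can run out of GPU memory. The `--max_seq_len` flag added in this commit caps `model.seqlen`; the value below is only an example:

```bash
python prune_llm.py --model Qwen/Qwen2-7B --pruning_ratio 0.5 --max_seq_len 4096
```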


<details>
<summary>Output:</summary>

```
----------------- Before Pruning -----------------
Qwen2ForCausalLM(
(model): Qwen2Model(
(embed_tokens): Embedding(152064, 3584)
(layers): ModuleList(
(0-27): 28 x Qwen2DecoderLayer(
(self_attn): Qwen2SdpaAttention(
(q_proj): Linear(in_features=3584, out_features=3584, bias=True)
(k_proj): Linear(in_features=3584, out_features=512, bias=True)
(v_proj): Linear(in_features=3584, out_features=512, bias=True)
(o_proj): Linear(in_features=3584, out_features=3584, bias=False)
(rotary_emb): Qwen2RotaryEmbedding()
)
(mlp): Qwen2MLP(
(gate_proj): Linear(in_features=3584, out_features=18944, bias=False)
(up_proj): Linear(in_features=3584, out_features=18944, bias=False)
(down_proj): Linear(in_features=18944, out_features=3584, bias=False)
(act_fn): SiLU()
)
(input_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
(post_attention_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
)
)
(norm): Qwen2RMSNorm((3584,), eps=1e-06)
(rotary_emb): Qwen2RotaryEmbedding()
)
(lm_head): Linear(in_features=3584, out_features=152064, bias=False)
)
----------------- After Pruning -----------------
Qwen2ForCausalLM(
(model): Qwen2Model(
(embed_tokens): Embedding(152064, 1792)
(layers): ModuleList(
(0-27): 28 x Qwen2DecoderLayer(
(self_attn): Qwen2SdpaAttention(
(q_proj): Linear(in_features=1792, out_features=2048, bias=True)
(k_proj): Linear(in_features=1792, out_features=512, bias=True)
(v_proj): Linear(in_features=1792, out_features=512, bias=True)
(o_proj): Linear(in_features=2048, out_features=1792, bias=False)
(rotary_emb): Qwen2RotaryEmbedding()
)
(mlp): Qwen2MLP(
(gate_proj): Linear(in_features=1792, out_features=9472, bias=False)
(up_proj): Linear(in_features=1792, out_features=9472, bias=False)
(down_proj): Linear(in_features=9472, out_features=1792, bias=False)
(act_fn): SiLU()
)
(input_layernorm): Qwen2RMSNorm((1792,), eps=1e-06)
(post_attention_layernorm): Qwen2RMSNorm((1792,), eps=1e-06)
)
)
(norm): Qwen2RMSNorm((1792,), eps=1e-06)
(rotary_emb): Qwen2RotaryEmbedding()
)
(lm_head): Linear(in_features=1792, out_features=152064, bias=False)
)
Qwen2Config {
"_attn_implementation_autoset": true,
"_name_or_path": "Qwen/Qwen2-7B",
"architectures": [
"Qwen2ForCausalLM"
],
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151643,
"hidden_act": "silu",
"hidden_size": 1792,
"initializer_range": 0.02,
"intermediate_size": 18944,
"max_position_embeddings": 131072,
"max_window_layers": 28,
"model_type": "qwen2",
"num_attention_heads": 16,
"num_hidden_layers": 28,
"num_key_value_heads": 4,
"rms_norm_eps": 1e-06,
"rope_scaling": null,
"rope_theta": 1000000.0,
"sliding_window": null,
"tie_word_embeddings": false,
"torch_dtype": "float16",
"transformers_version": "4.46.2",
"use_cache": true,
"use_sliding_window": false,
"vocab_size": 152064
}
num_params 2227887872
```

</details>

2 changes: 0 additions & 2 deletions examples/torchvision_models/torchvision_global_pruning.py
@@ -16,8 +16,6 @@
)
from torchvision.models.detection.fcos import fcos_resnet50_fpn
from torchvision.models.detection.keypoint_rcnn import keypointrcnn_resnet50_fpn
from torchvision.models.detection.mask_rcnn import maskrcnn_resnet50_fpn_v2
from torchvision.models.detection.retinanet import retinanet_resnet50_fpn_v2
from torchvision.models.alexnet import alexnet

from torchvision.models.vision_transformer import (
2 changes: 0 additions & 2 deletions examples/torchvision_models/torchvision_pruning.py
@@ -16,8 +16,6 @@
)
from torchvision.models.detection.fcos import fcos_resnet50_fpn
from torchvision.models.detection.keypoint_rcnn import keypointrcnn_resnet50_fpn
from torchvision.models.detection.mask_rcnn import maskrcnn_resnet50_fpn_v2
from torchvision.models.detection.retinanet import retinanet_resnet50_fpn_v2
from torchvision.models.alexnet import alexnet

from torchvision.models.vision_transformer import (
