forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path build.py
519 lines (472 loc) · 20.7 KB
/
build.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import time
from pathlib import Path
import tensorrt as trt
import torch
import torch.multiprocessing as mp
from transformers import BloomConfig, BloomForCausalLM
import tensorrt_llm
from tensorrt_llm._utils import str_dtype_to_trt
from tensorrt_llm.builder import Builder
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models import smooth_quantize, weight_only_quantize
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin.plugin import ContextFMHAType
from tensorrt_llm.quantization import QuantMode
from weight import load_from_hf_bloom, load_from_bin, parse_config, check_embedding_share # isort:skip
# Model identifier: used as the prefix of every engine file name
# (get_engine_name) and as the builder config name (create_builder_config).
MODEL_NAME = "bloom"
import onnx
import tensorrt as trt
from onnx import TensorProto, helper
def trt_dtype_to_onnx(dtype):
    """Map a TensorRT data type to the matching ONNX tensor element type.

    Used by ``to_onnx`` to describe the network's input/output tensors when
    the ``--visualize`` option dumps the network graph.

    Args:
        dtype: A TensorRT ``trt.DataType`` value, e.g. ``trt.float16``.

    Returns:
        The corresponding ``onnx.TensorProto.DataType`` value.

    Raises:
        TypeError: If ``dtype`` has no entry in the table below.
    """
    table = [
        (trt.float16, TensorProto.DataType.FLOAT16),
        (trt.float32, TensorProto.DataType.FLOAT),
        (trt.int32, TensorProto.DataType.INT32),
        # Assumption: quantized builds (SmoothQuant / INT8 KV cache) and
        # boolean masks can expose int8 / bool tensors; without these entries
        # the ONNX dump raised TypeError on such networks.
        (trt.int8, TensorProto.DataType.INT8),
        (trt.bool, TensorProto.DataType.BOOL),
    ]
    # bfloat16 is only present in newer TensorRT releases.
    trt_bf16 = getattr(trt, 'bfloat16', None)
    if trt_bf16 is not None:
        table.append((trt_bf16, TensorProto.DataType.BFLOAT16))
    # Compare with == (as before) rather than hashing the enum.
    for trt_type, onnx_type in table:
        if dtype == trt_type:
            return onnx_type
    raise TypeError("%s is not supported" % dtype)
def to_onnx(network, path, graph_name='attention'):
    """Dump the topology of a TensorRT network to an ONNX file.

    Every TensorRT layer becomes one ONNX node in the ``com.nvidia`` domain
    whose op type is ``str(layer.type)``. No initializers (weights) are
    written, so the file describes graph structure only; in this script it is
    produced under ``--visualize``.

    Args:
        network: The TensorRT network definition to dump.
        path: Destination path of the ``.onnx`` file.
        graph_name: Name of the ONNX graph. Defaults to ``'attention'`` to
            keep the previous output unchanged.
    """

    def _onnx_dims(shape):
        # TensorRT reports dynamic dimensions as -1; ONNX dims must be
        # non-negative, so emit them as unknown (None) dimensions instead.
        return [None if d < 0 else d for d in shape]

    def _value_info(tensor):
        return helper.make_tensor_value_info(tensor.name,
                                             trt_dtype_to_onnx(tensor.dtype),
                                             _onnx_dims(tensor.shape))

    inputs = [
        _value_info(network.get_input(i)) for i in range(network.num_inputs)
    ]
    outputs = [
        _value_info(network.get_output(i))
        for i in range(network.num_outputs)
    ]

    nodes = []
    for i in range(network.num_layers):
        layer = network.get_layer(i)
        # Optional layer inputs may be unset (None); leave them out.
        candidate_inputs = (layer.get_input(j) for j in range(layer.num_inputs))
        layer_inputs = [ipt.name for ipt in candidate_inputs if ipt is not None]
        layer_outputs = [
            layer.get_output(j).name for j in range(layer.num_outputs)
        ]
        nodes.append(
            helper.make_node(str(layer.type),
                             name=layer.name,
                             inputs=layer_inputs,
                             outputs=layer_outputs,
                             domain="com.nvidia"))

    graph = helper.make_graph(nodes,
                              graph_name,
                              inputs,
                              outputs,
                              initializer=None)
    onnx.save(helper.make_model(graph, producer_name='NVIDIA'), path)
def get_engine_name(model, dtype, tp_size, rank):
    """Build the engine file name for one tensor-parallel rank.

    For example, ``get_engine_name('bloom', 'float16', 2, 1)`` gives
    ``'bloom_float16_tp2_rank1.engine'``.
    """
    stem = f'{model}_{dtype}_tp{tp_size}_rank{rank}'
    return stem + '.engine'
def serialize_engine(engine, path):
    """Write a serialized engine to ``path`` and log how long it took.

    Args:
        engine: The value returned by ``builder.build_engine`` (presumably a
            TensorRT ``IHostMemory`` buffer). Any object supporting the
            buffer protocol, or an iterable of byte values, is accepted.
        path: Destination file path; an existing file is overwritten.
    """
    logger.info(f'Serializing engine to {path}...')
    tik = time.time()
    try:
        # Write straight from the engine's own buffer: copying it into a
        # bytearray first would temporarily double host memory for the
        # (potentially multi-GB) engine.
        payload = memoryview(engine)
    except TypeError:
        # Object without the buffer protocol: fall back to the old copy.
        payload = bytearray(engine)
    with open(path, 'wb') as f:
        f.write(payload)
    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Engine serialized. Total time: {t}')
def parse_arguments():
    """Parse the command line and derive model and quantization settings.

    In addition to plain argument parsing this function:
      * sets the logger level from ``--log_level``;
      * resolves ``--use_layernorm_plugin`` / ``--use_lookup_plugin`` given
        without a value to the model dtype (``--dtype``);
      * overrides ``n_embd``, ``n_head``, ``n_layer`` and ``vocab_size`` from
        the HF config in ``--model_dir`` or, failing that, from
        ``config.ini`` in ``--bin_model_dir``;
      * stores the resulting ``QuantMode`` in ``args.quant_mode``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--world_size',
                        type=int,
                        default=1,
                        help='World size; only tensor parallelism is supported for now.')
    parser.add_argument('--model_dir', type=str, default=None)
    parser.add_argument('--bin_model_dir', type=str, default=None)
    parser.add_argument('--dtype',
                        type=str,
                        default='float16',
                        choices=['float32', 'float16'])
    parser.add_argument(
        '--timing_cache',
        type=str,
        default='model.cache',
        help=
        'The path to read the timing cache from; ignored if the file does not exist.'
    )
    parser.add_argument('--log_level', type=str, default='info')
    parser.add_argument('--vocab_size', type=int, default=250680)
    parser.add_argument('--n_layer', type=int, default=32)
    parser.add_argument('--n_positions', type=int, default=2048)
    parser.add_argument('--n_embd', type=int, default=4096)
    parser.add_argument('--n_head', type=int, default=32)
    parser.add_argument('--mlp_hidden_size', type=int, default=None)
    parser.add_argument('--max_batch_size', type=int, default=8)
    parser.add_argument('--max_input_len', type=int, default=1024)
    parser.add_argument('--max_output_len', type=int, default=1024)
    parser.add_argument('--max_beam_width', type=int, default=1)
    parser.add_argument('--use_gpt_attention_plugin',
                        nargs='?',
                        const='float16',
                        type=str,
                        default=False,
                        choices=['float16', 'float32'])
    parser.add_argument('--use_gemm_plugin',
                        nargs='?',
                        const='float16',
                        type=str,
                        default=False,
                        choices=['float16', 'float32'])
    parser.add_argument('--enable_context_fmha',
                        default=False,
                        action='store_true')
    parser.add_argument('--enable_context_fmha_fp32_acc',
                        default=False,
                        action='store_true')
    parser.add_argument(
        '--use_layernorm_plugin',
        nargs='?',
        # None marks "flag given without a value"; it is resolved to the
        # model dtype below, as the help text promises.
        const=None,
        type=str,
        default=False,
        choices=['float16', 'float32'],
        help=
        "Activates layernorm plugin. You can specify the plugin dtype or leave blank to use the model dtype."
    )
    parser.add_argument('--parallel_build', default=False, action='store_true')
    parser.add_argument('--visualize', default=False, action='store_true')
    parser.add_argument('--enable_debug_output',
                        default=False,
                        action='store_true')
    parser.add_argument('--gpus_per_node', type=int, default=8)
    parser.add_argument(
        '--output_dir',
        type=str,
        default='bloom_outputs',
        help=
        'The path to save the serialized engine files, timing cache file and model configs'
    )
    # Arguments related to the quantization of the model.
    parser.add_argument(
        '--use_smooth_quant',
        default=False,
        action="store_true",
        help=
        'Use the SmoothQuant method to quantize activations and weights for the various GEMMs. '
        'See --per_channel and --per_token for finer-grained quantization options.'
    )
    parser.add_argument(
        '--use_weight_only',
        default=False,
        action="store_true",
        help='Quantize weights for the various GEMMs to INT4/INT8. '
        'See --weight_only_precision to set the precision')
    parser.add_argument(
        '--weight_only_precision',
        const='int8',
        type=str,
        nargs='?',
        default='int8',
        choices=['int8', 'int4'],
        help=
        'Define the precision for the weights when using weight-only quantization. '
        'You must also use --use_weight_only for that argument to have an impact.'
    )
    parser.add_argument(
        '--per_channel',
        default=False,
        action="store_true",
        help=
        'By default, we use a single static scaling factor for the GEMM\'s result. '
        'per_channel instead uses a different static scaling factor for each channel. '
        'The latter is usually more accurate, but a little slower.')
    parser.add_argument(
        '--per_token',
        default=False,
        action="store_true",
        help=
        'By default, we use a single static scaling factor to scale activations in the int8 range. '
        'per_token chooses at run time, and for each token, a custom scaling factor. '
        'The latter is usually more accurate, but a little slower.')
    parser.add_argument(
        '--int8_kv_cache',
        default=False,
        action="store_true",
        help=
        'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
    )
    parser.add_argument(
        '--use_parallel_embedding',
        action="store_true",
        default=False,
        help=
        'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
    )
    parser.add_argument(
        '--embedding_sharding_dim',
        type=int,
        default=0,
        choices=[0, 1],
        help=
        'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
        'To shard it along hidden dimension, set embedding_sharding_dim=1. '
        'Note: embedding sharing is only enabled when embedding_sharding_dim = 0'
    )
    parser.add_argument(
        '--use_embedding_sharing',
        action="store_true",
        default=False,
        help=
        'Try to reduce the engine size by sharing the embedding lookup table between two layers. '
        'Note: the flag might not take effect when the criteria are not met.')
    parser.add_argument(
        '--use_lookup_plugin',
        nargs='?',
        const=None,
        default=False,
        choices=['float16', 'float32', 'bfloat16'],
        help="Activates the lookup plugin which enables embedding sharing.")
    args = parser.parse_args()
    logger.set_level(args.log_level)

    # Plugin flags given without a value arrive as None (their `const`).
    # None is falsy, so build_rank_engine would treat a bare
    # --use_lookup_plugin as "disabled"; resolve both to the model dtype.
    if args.use_layernorm_plugin is None:
        args.use_layernorm_plugin = args.dtype
    if args.use_lookup_plugin is None:
        args.use_lookup_plugin = args.dtype

    if args.model_dir is not None:
        hf_config = BloomConfig.from_pretrained(args.model_dir)
        args.n_embd = hf_config.hidden_size
        args.n_head = hf_config.num_attention_heads
        args.n_layer = hf_config.num_hidden_layers
        args.vocab_size = hf_config.vocab_size
    elif args.bin_model_dir is not None:
        logger.info(f"Setting model configuration from {args.bin_model_dir}.")
        # Only the first four config values are needed here.
        n_embd, n_head, n_layer, vocab_size, *_ = parse_config(
            Path(args.bin_model_dir) / "config.ini")
        args.n_embd = n_embd
        args.n_head = n_head
        args.n_layer = n_layer
        args.vocab_size = vocab_size

    # Validate with parser.error rather than assert so the check survives
    # `python -O` and reports a proper CLI usage error.
    if args.use_smooth_quant and args.use_weight_only:
        parser.error(
            "You cannot enable both SmoothQuant and INT8 weight-only together.")

    if args.use_smooth_quant:
        args.quant_mode = QuantMode.use_smooth_quant(args.per_token,
                                                     args.per_channel)
    elif args.use_weight_only:
        args.quant_mode = QuantMode.use_weight_only(
            args.weight_only_precision == 'int4')
    else:
        args.quant_mode = QuantMode(0)

    if args.int8_kv_cache:
        args.quant_mode = args.quant_mode.set_int8_kv_cache()

    return args
def _resolve_share_embedding_table(args):
    """Decide whether the lm_head may reuse the embedding lookup table.

    Sharing is attempted only when ``--use_embedding_sharing`` is set and an
    HF checkpoint (``--model_dir``) is given; ``check_embedding_share`` then
    decides. With more than one rank it also requires
    ``--use_parallel_embedding`` and ``--embedding_sharding_dim 0``.
    A warning is logged whenever sharing was requested but is not possible.
    """
    if not args.use_embedding_sharing:
        return False
    sharding_ok = args.world_size <= 1 or (args.embedding_sharding_dim == 0
                                           and args.use_parallel_embedding)
    share = False
    if args.model_dir is not None and sharding_ok:
        share = check_embedding_share(args.model_dir)
    if not share:
        logger.warning('Cannot share the embedding lookup table.')
    return share


def _load_weights(model, args, rank, share_embedding_table):
    """Load the checkpoint slice for ``rank`` into the TRT-LLM ``model``.

    Reads the HF checkpoint in ``--model_dir`` if set, otherwise the binary
    export in ``--bin_model_dir``; if neither is set no weights are loaded.
    """
    if args.model_dir is not None:
        logger.info(f'Loading HF BLOOM ... from {args.model_dir}')
        tik = time.time()
        hf_bloom = BloomForCausalLM.from_pretrained(args.model_dir,
                                                    torch_dtype="auto")
        tok = time.time()
        t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
        logger.info(f'HF BLOOM loaded. Total time: {t}')
        # NOTE(review): dumps the whole HF module repr to stdout regardless
        # of --log_level; consider routing it through the logger.
        print(hf_bloom)
        load_from_hf_bloom(model,
                           hf_bloom,
                           rank,
                           args.world_size,
                           fp16=(args.dtype == 'float16'),
                           use_parallel_embedding=args.use_parallel_embedding,
                           sharding_dim=args.embedding_sharding_dim,
                           share_embedding_table=share_embedding_table)
        # Release the HF copy of the weights before the TensorRT build runs.
        del hf_bloom
    elif args.bin_model_dir is not None:
        load_from_bin(model,
                      args.bin_model_dir,
                      rank,
                      args.world_size,
                      args.dtype,
                      use_parallel_embedding=args.use_parallel_embedding,
                      sharding_dim=args.embedding_sharding_dim,
                      share_embedding_table=share_embedding_table)


def _configure_plugins(network, args):
    """Enable the TensorRT-LLM plugins selected on the command line."""
    if args.use_gpt_attention_plugin:
        network.plugin_config.set_gpt_attention_plugin(
            dtype=args.use_gpt_attention_plugin)
    if args.use_gemm_plugin:
        network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
    if args.use_layernorm_plugin:
        network.plugin_config.set_layernorm_plugin(
            dtype=args.use_layernorm_plugin)
    if args.use_lookup_plugin:
        # Use the plugin for the embedding parallelism
        network.plugin_config.set_lookup_plugin(dtype=args.dtype)
    if args.enable_context_fmha:
        network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
    if args.enable_context_fmha_fp32_acc:
        network.plugin_config.set_context_fmha(
            ContextFMHAType.enabled_with_fp32_acc)

    # Quantization plugins.
    if args.use_smooth_quant:
        network.plugin_config.set_smooth_quant_gemm_plugin(dtype=args.dtype)
        network.plugin_config.set_layernorm_quantization_plugin(
            dtype=args.dtype)
        network.plugin_config.set_quantize_tensor_plugin()
        network.plugin_config.set_quantize_per_token_plugin()
    elif args.use_weight_only:
        network.plugin_config.set_weight_only_quant_matmul_plugin(
            dtype=args.dtype)

    if args.world_size > 1:
        network.plugin_config.set_nccl_plugin(args.dtype)


def build_rank_engine(builder: Builder,
                      builder_config: tensorrt_llm.builder.BuilderConfig,
                      engine_name, rank, args):
    '''
    @brief: Build the engine on the given rank.
    @param builder: The TensorRT-LLM builder used to create the network and engine.
    @param builder_config: Builder configuration for this rank.
    @param engine_name: Name given to the TensorRT network.
    @param rank: The rank to build the engine.
    @param args: The cmd line arguments.
    @return: The built engine (None if TensorRT failed to build it).
    @raise ValueError: If both context-FMHA flags are set.
    '''
    # Validate up front so a bad flag combination fails before the
    # expensive checkpoint load instead of after it.
    if args.enable_context_fmha and args.enable_context_fmha_fp32_acc:
        raise ValueError('--enable_context_fmha and '
                         '--enable_context_fmha_fp32_acc are mutually exclusive.')

    model_dtype = str_dtype_to_trt(args.dtype)

    # Sharing requires that the lm_head weight is absent from the checkpoint
    # (checked by check_embedding_share) and, with several ranks, parallel
    # embedding sharded along dim 0. With TensorRT 9.0 the engine size only
    # shrinks when the lookup and gemm plugins are enabled.
    share_embedding_table = _resolve_share_embedding_table(args)
    if share_embedding_table:
        logger.info(
            'Engine will share embedding and language modeling weights.')

    # Initialize Module
    tensorrt_llm_bloom = tensorrt_llm.models.BloomForCausalLM(
        num_layers=args.n_layer,
        num_heads=args.n_head,
        hidden_size=args.n_embd,
        vocab_size=args.vocab_size,
        max_position_embeddings=args.n_positions,
        dtype=model_dtype,
        mapping=Mapping(world_size=args.world_size,
                        rank=rank,
                        tp_size=args.world_size),  # TP only
        use_parallel_embedding=args.use_parallel_embedding,
        embedding_sharding_dim=args.embedding_sharding_dim,
        share_embedding_table=share_embedding_table,
        quant_mode=args.quant_mode)
    # Quantize the module structure before weights are loaded into it.
    if args.use_smooth_quant:
        tensorrt_llm_bloom = smooth_quantize(tensorrt_llm_bloom,
                                             args.quant_mode)
    elif args.use_weight_only:
        tensorrt_llm_bloom = weight_only_quantize(tensorrt_llm_bloom,
                                                  args.quant_mode)

    _load_weights(tensorrt_llm_bloom, args, rank, share_embedding_table)

    # Module -> Network
    network = builder.create_network()
    network.trt_network.name = engine_name
    _configure_plugins(network, args)

    with net_guard(network):
        # Prepare
        network.set_named_parameters(tensorrt_llm_bloom.named_parameters())

        # Forward
        inputs = tensorrt_llm_bloom.prepare_inputs(args.max_batch_size,
                                                   args.max_input_len,
                                                   args.max_output_len, True,
                                                   args.max_beam_width)
        tensorrt_llm_bloom(*inputs)
        if args.enable_debug_output:
            # Mark intermediate nodes' outputs; they are declared with the
            # model dtype.
            for k, v in tensorrt_llm_bloom.named_network_outputs():
                v = v.trt_tensor
                v.name = k
                network.trt_network.mark_output(v)
                v.dtype = model_dtype
        if args.visualize:
            model_path = os.path.join(args.output_dir, 'test.onnx')
            to_onnx(network.trt_network, model_path)

    tensorrt_llm.graph_rewriting.optimize(network)

    # Network -> Engine
    engine = builder.build_engine(network, builder_config)
    if rank == 0:
        config_path = os.path.join(args.output_dir, 'config.json')
        builder.save_config(builder_config, config_path)
    return engine
def build(rank, args):
    """Build and serialize the engines handled by process ``rank``.

    With ``args.parallel_build`` each process builds only its own rank.
    Otherwise this process builds every rank in turn and reuses the in-memory
    timing cache captured after building rank 0. Rank 0 finally saves the
    timing cache to ``<output_dir>/model.cache``.

    Args:
        rank: Index of the calling process; also selects the CUDA device
            ``rank % args.gpus_per_node``.
        args: The namespace returned by ``parse_arguments``.
    """
    torch.cuda.set_device(rank % args.gpus_per_node)
    tensorrt_llm.logger.set_level(args.log_level)
    # exist_ok: processes started by mp.spawn may race to create the
    # directory; a separate exists()/makedirs() pair could raise here.
    os.makedirs(args.output_dir, exist_ok=True)

    # One Builder instance is reused for every rank built by this process.
    builder = Builder()
    cache = None
    # NOTE: when only int8 kv cache is used together with paged kv cache no int8 tensors are exposed to TRT
    int8_trt_flag = (args.quant_mode.has_act_and_weight_quant()
                     or args.quant_mode.has_int8_kv_cache())
    for cur_rank in range(args.world_size):
        # skip other ranks if parallel_build is enabled
        if args.parallel_build and cur_rank != rank:
            continue
        builder_config = builder.create_builder_config(
            name=MODEL_NAME,
            precision=args.dtype,
            timing_cache=args.timing_cache if cache is None else cache,
            tensor_parallel=args.world_size,  # TP only
            parallel_build=args.parallel_build,
            num_layers=args.n_layer,
            num_heads=args.n_head,
            hidden_size=args.n_embd,
            vocab_size=args.vocab_size,
            max_position_embeddings=args.n_positions,
            max_batch_size=args.max_batch_size,
            max_input_len=args.max_input_len,
            max_output_len=args.max_output_len,
            int8=int8_trt_flag,
            quant_mode=args.quant_mode)
        builder_config.trt_builder_config.builder_optimization_level = 1
        engine_name = get_engine_name(MODEL_NAME, args.dtype, args.world_size,
                                      cur_rank)
        engine = build_rank_engine(builder, builder_config, engine_name,
                                   cur_rank, args)
        assert engine is not None, f'Failed to build engine for rank {cur_rank}'

        if cur_rank == 0:
            # Use in-memory timing cache for multiple builder passes.
            if not args.parallel_build:
                cache = builder_config.trt_builder_config.get_timing_cache()

        serialize_engine(engine, os.path.join(args.output_dir, engine_name))
        # Drop this rank's engine before building the next one so two
        # engines are never held in memory at once.
        del engine

    if rank == 0:
        ok = builder.save_timing_cache(
            builder_config, os.path.join(args.output_dir, "model.cache"))
        assert ok, "Failed to save timing cache."
if __name__ == '__main__':
    args = parse_arguments()
    logger.set_level(args.log_level)
    tik = time.time()

    # Build ranks in separate processes only when asked to and when there is
    # at least one visible GPU per rank; otherwise build them one by one.
    spawn_per_rank = (args.parallel_build and args.world_size > 1
                      and torch.cuda.device_count() >= args.world_size)
    if not spawn_per_rank:
        args.parallel_build = False
        logger.info('Serially build TensorRT engines.')
        build(0, args)
    else:
        logger.warning(
            f'Parallelly build TensorRT engines. Please make sure that all of the {args.world_size} GPUs are totally free.'
        )
        mp.spawn(build, nprocs=args.world_size, args=(args, ))

    elapsed = time.strftime('%H:%M:%S', time.gmtime(time.time() - tik))
    logger.info(f'Total time of building all {args.world_size} engines: {elapsed}')