From 9382f573e1f042ad7800ed430e26749d9853cc3e Mon Sep 17 00:00:00 2001
From: interestingLSY <2659723130@qq.com>
Date: Wed, 13 Dec 2023 02:03:53 +0800
Subject: [PATCH] Bugfix and enhancement for ray-BERT SUT (#1531)

* fix: bugfix in ray_SUT.py. Now it yields correct result in accuracy test

* feat: allow to change batch_size in ray_SUT

* feat: use FP16 in ray_SUT to get an additional 3x speedup

* feat: add a reminder when no existing RAY cluster is detected
---
 language/bert/ray_SUT.py | 42 ++++++++++++++++++++++++++--------------
 1 file changed, 27 insertions(+), 15 deletions(-)

diff --git a/language/bert/ray_SUT.py b/language/bert/ray_SUT.py
index fd53dc07c..d03e52f50 100644
--- a/language/bert/ray_SUT.py
+++ b/language/bert/ray_SUT.py
@@ -34,6 +34,9 @@
 
 from ray.util.actor_pool import ActorPool
 
+# Adjustable Parameters
+BATCH_SIZE = 16 # Note. num_samples (called "test_query_count" in CM) must be a multiple of batch_size
+
 @ray.remote(num_cpus=1,num_gpus=1)
 class TorchPredictor:
     def __init__(self, config_json, model_file, batch_size):
@@ -68,7 +71,7 @@ def __init__(self, config_json, model_file, batch_size):
                 torch_tensorrt.Input(shape=[batch_size, 384], dtype=torch.int32),
                 torch_tensorrt.Input(shape=[batch_size, 384], dtype=torch.int32),
             ],
-            enabled_precisions= {torch.float32},
+            enabled_precisions= {torch.float32, torch.float16},
             workspace_size=2000000000,
             truncate_long_and_double=True)
 
@@ -111,10 +114,23 @@ def __init__(self, args):
         print("Finished constructing SUT.")
         self.qsl = get_squad_QSL(args.max_examples)
 
-        ray.init()
-        self.batch_size = 10
+        try:
+            ray.init(address="auto")
+        except:
+            print("WARN: Cannot connect to existing Ray cluster.")
+            print("We are going to start a new RAY cluster, but pay attention that")
+            print("the cluster contains only one node.")
+            print("If you want to use multiple nodes, please start the cluster manually via:")
+            print("\tOn the head node, run `ray start --head`")
+            print("\tOn other nodes, run `ray start --address=<head-node-ip>:6379`")
+            ray.init()
+
+        self.batch_size = BATCH_SIZE
         resources = ray.cluster_resources()
         num_gpus = int(resources.get('GPU', 0))
+
+        print(f"The cluster has {num_gpus} GPUs.")
+
         self.actor_list = [TorchPredictor.remote(config_json, model_file, self.batch_size) for _ in range(num_gpus)]
         self.pool = ActorPool(self.actor_list)
 
@@ -134,6 +150,10 @@
         print("BERT_Ray_SUT construct complete")
 
     def issue_queries(self, query_samples):
+        if len(query_samples) % self.batch_size != 0:
+            print("ERROR: batch size must be a multiple of the number of samples")
+            sys.exit(1)
+
         batch_samples = []
         i = 0
         while i < len(query_samples):
@@ -154,23 +174,15 @@
 
         # print("samples len", len(batch_samples))
         batch_inference_results = list(self.pool.map_unordered(lambda a, v: a.forward.remote(v), batch_samples))
-        results = []
+        cur_query_index = 0
         for batch_inference_result in batch_inference_results:
             batch_inference_result = batch_inference_result["output"]
             for inference_result in batch_inference_result:
                 response_array = array.array("B", inference_result.tobytes())
                 bi = response_array.buffer_info()
-                results.append(bi)
-
-        # print("results len", len(results))
-
-        responses = []
-        for i in range(len(query_samples)):
-            # print(query_samples[i].index)
-            bi = results[i]
-            response = lg.QuerySampleResponse(query_samples[i].id, bi[0], bi[1])
-            responses.append(response)
-        lg.QuerySamplesComplete(responses)
+                response = lg.QuerySampleResponse(query_samples[cur_query_index].id, bi[0], bi[1])
+                lg.QuerySamplesComplete([response])
+                cur_query_index += 1
 
     def flush_queries(self):
         pass