Skip to content

Commit

Permalink
Bugfix and enhancement for ray-BERT SUT (mlcommons#1531)
Browse files Browse the repository at this point in the history
* fix: bugfix in ray_SUT.py. Now it yields correct result in accuracy test

* feat: allow to change batch_size in ray_SUT

* feat: use FP16 in ray_SUT to get an additional 3x speedup

* feat: add a reminder when no existing RAY cluster is detected
  • Loading branch information
interestingLSY authored Dec 12, 2023
1 parent 4c231dc commit 9382f57
Showing 1 changed file with 27 additions and 15 deletions.
42 changes: 27 additions & 15 deletions language/bert/ray_SUT.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
from ray.util.actor_pool import ActorPool


# Adjustable Parameters
BATCH_SIZE = 16 # Note. num_samples (called "test_query_count" in CM) must be a multiple of batch_size

@ray.remote(num_cpus=1,num_gpus=1)
class TorchPredictor:
def __init__(self, config_json, model_file, batch_size):
Expand Down Expand Up @@ -68,7 +71,7 @@ def __init__(self, config_json, model_file, batch_size):
torch_tensorrt.Input(shape=[batch_size, 384], dtype=torch.int32),
torch_tensorrt.Input(shape=[batch_size, 384], dtype=torch.int32),
],
enabled_precisions= {torch.float32},
enabled_precisions= {torch.float32, torch.float16},
workspace_size=2000000000,
truncate_long_and_double=True)

Expand Down Expand Up @@ -111,10 +114,23 @@ def __init__(self, args):
print("Finished constructing SUT.")
self.qsl = get_squad_QSL(args.max_examples)

ray.init()
self.batch_size = 10
try:
ray.init(address="auto")
except:
print("WARN: Cannot connect to existing Ray cluster.")
print("We are going to start a new RAY cluster, but pay attention that")
print("the cluster contains only one node.")
print("If you want to use multiple nodes, please start the cluster manually via:")
print("\tOn the head node, run `ray start --head`")
print("\tOn other nodes, run `ray start --address=<head node IP>:6379`")
ray.init()

self.batch_size = BATCH_SIZE
resources = ray.cluster_resources()
num_gpus = int(resources.get('GPU', 0))

print(f"The cluster has {num_gpus} GPUs.")

self.actor_list = [TorchPredictor.remote(config_json, model_file, self.batch_size) for _ in range(num_gpus)]
self.pool = ActorPool(self.actor_list)

Expand All @@ -134,6 +150,10 @@ def __init__(self, args):
print("BERT_Ray_SUT construct complete")

def issue_queries(self, query_samples):
if len(query_samples) % self.batch_size != 0:
print("ERROR: batch size must be a multiple of the number of samples")
sys.exit(1)

batch_samples = []
i = 0
while i < len(query_samples):
Expand All @@ -154,23 +174,15 @@ def issue_queries(self, query_samples):
# print("samples len", len(batch_samples))
batch_inference_results = list(self.pool.map_unordered(lambda a, v: a.forward.remote(v), batch_samples))

results = []
cur_query_index = 0
for batch_inference_result in batch_inference_results:
batch_inference_result = batch_inference_result["output"]
for inference_result in batch_inference_result:
response_array = array.array("B", inference_result.tobytes())
bi = response_array.buffer_info()
results.append(bi)

# print("results len", len(results))

responses = []
for i in range(len(query_samples)):
# print(query_samples[i].index)
bi = results[i]
response = lg.QuerySampleResponse(query_samples[i].id, bi[0], bi[1])
responses.append(response)
lg.QuerySamplesComplete(responses)
response = lg.QuerySampleResponse(query_samples[cur_query_index].id, bi[0], bi[1])
lg.QuerySamplesComplete([response])
cur_query_index += 1

def flush_queries(self):
pass
Expand Down

0 comments on commit 9382f57

Please sign in to comment.