From ba40f3ba4d1f2ce4cc9bc3745db41e4e4af171e8 Mon Sep 17 00:00:00 2001 From: Uranus <109661872+UranusSeven@users.noreply.github.com> Date: Tue, 5 Sep 2023 14:53:24 +0800 Subject: [PATCH] Uniform code style for utils for better readability (#108) --- lightllm/utils/infer_utils.py | 13 ++++++++++--- lightllm/utils/net_utils.py | 7 +++---- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/lightllm/utils/infer_utils.py b/lightllm/utils/infer_utils.py index 4178eb7b..6bf1a0b1 100644 --- a/lightllm/utils/infer_utils.py +++ b/lightllm/utils/infer_utils.py @@ -21,24 +21,29 @@ def time_func(*args, **kwargs): ans = func(*args, **kwargs) torch.cuda.synchronize() return ans + return time_func + return inner_func + time_mark = {} + def mark_start(key): torch.cuda.synchronize() global time_mark time_mark[key] = time.time() return + def mark_end(key, print_min_cost=0.0): torch.cuda.synchronize() global time_mark cost_time = (time.time() - time_mark[key]) * 1000 if cost_time > print_min_cost: print(f"cost {key}:", cost_time) - + def calculate_time(show=False, min_cost_ms=0.0): def wrapper(func): @@ -53,16 +58,18 @@ def inner_func(*args, **kwargs): if cost_time > min_cost_ms: print(f"Function {func.__name__} took {cost_time} ms to run.") return result + return inner_func + return wrapper def set_random_seed(seed: int) -> None: import random + random.seed(seed) import numpy as np + torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) - - \ No newline at end of file diff --git a/lightllm/utils/net_utils.py b/lightllm/utils/net_utils.py index ef8e0358..acac8fda 100644 --- a/lightllm/utils/net_utils.py +++ b/lightllm/utils/net_utils.py @@ -1,15 +1,14 @@ import socket - def alloc_can_use_network_port(num=3, used_nccl_port=None): port_list = [] for port in range(10000, 65536): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - result = s.connect_ex(('localhost', port)) + result = s.connect_ex(("localhost", port)) if result != 0 and port != used_nccl_port: port_list.append(port) - + if len(port_list) == num: return port_list - return None \ No newline at end of file + return None