From b02c0a5be0b5974aeaa9919e434a417afa9e9d65 Mon Sep 17 00:00:00 2001
From: Alex Cheema
Date: Wed, 18 Dec 2024 23:26:42 +0000
Subject: [PATCH] new approach to mlx async operations and make tokenizer operations async too

---
 exo/inference/mlx/sharded_inference_engine.py | 43 ++++++----
 exo/inference/mlx/sharded_utils.py            |  4 +-
 exo/inference/mlx/test_non_blocking.py        | 81 +++++++++++++++++++
 exo/main.py                                   | 12 +++
 exo/orchestration/node.py                     |  7 +-
 5 files changed, 128 insertions(+), 19 deletions(-)
 create mode 100644 exo/inference/mlx/test_non_blocking.py

diff --git a/exo/inference/mlx/sharded_inference_engine.py b/exo/inference/mlx/sharded_inference_engine.py
index bbe4d435e..89dcd602d 100644
--- a/exo/inference/mlx/sharded_inference_engine.py
+++ b/exo/inference/mlx/sharded_inference_engine.py
@@ -12,6 +12,7 @@
 import asyncio
 from collections import OrderedDict
 from mlx_lm.models.cache import make_prompt_cache
+from concurrent.futures import ThreadPoolExecutor
 
 class MLXDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
@@ -20,6 +21,12 @@ def __init__(self, shard_downloader: ShardDownloader):
     self.caches = OrderedDict()
     self.sampler_params: tuple[float, float] = (0.0, 0.0, 0.0, 1)
     self.sampler = make_sampler(*self.sampler_params)
+    self._mlx_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx")
+    self._tokenizer_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="tokenizer")
+
+  async def _eval_mlx(self, *args):
+    loop = asyncio.get_running_loop()
+    await loop.run_in_executor(self._mlx_thread, mx.eval, *args)
 
   async def poll_state(self, request_id: str, max_caches=2):
     if request_id in self.caches:
@@ -38,16 +45,19 @@ async def sample(self, x: np.ndarray, temp: float = 0.0, top_p: float = 1.0) ->
     logits = mx.array(x)
     logits = logits[:, -1, :]
     logprobs = logits - mx.logsumexp(logits, keepdims=True)
-    return np.asarray(self.sampler(logprobs), dtype=int)
+    result = self.sampler(logprobs)
+    await self._eval_mlx(result)
+    return np.asarray(result, dtype=int)
 
   async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     await self.ensure_shard(shard)
-    tokens = self.tokenizer.encode(prompt)
-    return np.asarray(tokens)
+    loop = asyncio.get_running_loop()
+    return np.asarray(await loop.run_in_executor(self._tokenizer_thread, self.tokenizer.encode, prompt))
 
   async def decode(self, shard: Shard, tokens) -> str:
     await self.ensure_shard(shard)
-    return self.tokenizer.decode(tokens)
+    loop = asyncio.get_running_loop()
+    return await loop.run_in_executor(self._tokenizer_thread, self.tokenizer.decode, tokens)
 
   async def save_checkpoint(self, shard: Shard, path: str):
     await self.ensure_shard(shard)
@@ -61,8 +71,9 @@ async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarr
     await self.ensure_shard(shard)
     state = await self.poll_state(request_id)
     x = mx.array(input_data)
-    output_data = np.array(self.model(x, **state), copy=False)
-    return output_data
+    output = self.model(x, **state)
+    await self._eval_mlx(output)
+    return np.array(output, copy=False)
 
   async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
     await self.ensure_shard(shard)
@@ -87,26 +98,25 @@ async def ensure_train(self, shard: Shard, loss: str, opt=optim.SGD, lr=1e-5, tr
     return True
 
   async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce", opt=optim.SGD, lr=1e-5):
-    loop = asyncio.get_running_loop()
-    nothin = await self.ensure_train(shard, loss, opt, lr)
+    await self.ensure_train(shard, loss, opt, lr)
+
     def train_step(inp, tar, lng):
       lval, grad = self.session['LVaG'](self.model, inp, tar, lng)
       gradlayers = grad['model']['layers']
       self.session['opt'].update(self.model, grad)
-      mx.eval(self.model.parameters(), self.session['opt'].state, lval)
-      return lval, gradlayers
+      return lval, gradlayers, (self.model.parameters(), self.session['opt'].state, lval)
 
     x = mx.array(inputs)
     y = mx.array(targets)
     l = mx.array(lengths)
-    score, gradients = await loop.run_in_executor(self.executor, train_step, x, y, l)
-    #print(f"{score=}")
+    score, gradients, eval_args = train_step(x, y, l)
+    await self._eval_mlx(*eval_args)
 
     layers = [{k: v["weight"] for k,v in l.items() if 'weight' in v} for l in gradients if l]
-    #print(layers[0])
-
-    return score, np.array(layers[0]['input_layernorm'], copy=False)
+    first_layer = np.array(layers[0]['input_layernorm'], copy=False)
+    await self._eval_mlx(first_layer)
+    return score, first_layer
 
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
@@ -121,3 +131,6 @@ async def ensure_shard(self, shard: Shard):
     self.caches = OrderedDict()
     self.session = {}
 
+  async def cleanup(self):
+    self._mlx_thread.shutdown(wait=True)
+
diff --git a/exo/inference/mlx/sharded_utils.py b/exo/inference/mlx/sharded_utils.py
index fca15a1f6..2c4b45388 100644
--- a/exo/inference/mlx/sharded_utils.py
+++ b/exo/inference/mlx/sharded_utils.py
@@ -164,8 +164,8 @@ def class_predicate(p, m):
 
   model.load_weights(list(weights.items()), strict=True)
 
-  if not lazy:
-    mx.eval(model.parameters())
+  # if not lazy:
+  #   mx.eval(model.parameters())
 
   model.eval()
   return model
diff --git a/exo/inference/mlx/test_non_blocking.py b/exo/inference/mlx/test_non_blocking.py
new file mode 100644
index 000000000..64eedfdde
--- /dev/null
+++ b/exo/inference/mlx/test_non_blocking.py
@@ -0,0 +1,81 @@
+import asyncio
+import time
+import numpy as np
+from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+from exo.download.hf.hf_shard_download import HFShardDownloader
+from exo.inference.shard import Shard
+from exo.models import build_base_shard
+from collections import deque
+from statistics import mean, median
+
+async def test_non_blocking():
+  # Setup
+  shard_downloader = HFShardDownloader()
+  engine = MLXDynamicShardInferenceEngine(shard_downloader)
+  _shard = build_base_shard("llama-3.1-8b", "MLXDynamicShardInferenceEngine")
+  shard = Shard(_shard.model_id, _shard.start_layer, _shard.n_layers - 1, _shard.n_layers)
+  await engine.ensure_shard(shard)
+
+  queue = asyncio.Queue()
+  measurements = deque(maxlen=1000000)
+  running = True
+
+  async def mlx_worker():
+    try:
+      start_time = time.time()
+      count = 0
+      while running and (time.time() - start_time) < 5:  # Hard time limit
+        start = time.perf_counter_ns()
+        await engine.infer_prompt("req1", shard, "test prompt")
+        duration = (time.perf_counter_ns() - start) / 1_000_000  # Convert to ms
+        count += 1
+        print(f"MLX operation {count} took: {duration:.3f}ms")
+    except asyncio.CancelledError:
+      pass
+    finally:
+      print(f"\nTotal MLX operations completed: {count}")
+      print(f"Average rate: {count/5:.1f} ops/second")
+
+  async def latency_producer():
+    try:
+      start_time = time.perf_counter_ns()
+      count = 0
+      while running:
+        await queue.put(time.perf_counter_ns())
+        count += 1
+        await asyncio.sleep(0)  # Yield to event loop without delay
+      duration = (time.perf_counter_ns() - start_time) / 1e9  # Convert to seconds
+      print(f"\nProducer iterations: {count}")
+      print(f"Producer rate: {count/duration:.1f} iterations/second")
+    except asyncio.CancelledError:
+      pass
+
+  async def latency_consumer():
+    try:
+      while running:
+        timestamp = await queue.get()
+        latency = (time.perf_counter_ns() - timestamp) / 1_000_000  # Convert to ms
+        measurements.append(latency)
+        queue.task_done()
+    except asyncio.CancelledError:
+      pass
+
+  tasks = [
+    asyncio.create_task(mlx_worker()),
+    asyncio.create_task(latency_producer()),
+    asyncio.create_task(latency_consumer())
+  ]
+
+  try:
+    await asyncio.wait_for(asyncio.gather(*tasks), timeout=6)
+  except asyncio.TimeoutError:
+    print("\nTest timed out")
+  finally:
+    running = False
+    for task in tasks:
+      task.cancel()
+    await asyncio.gather(*tasks, return_exceptions=True)
+    print(f"\nFinal measurement count: {len(measurements)}")
+
+if __name__ == "__main__":
+  asyncio.run(test_non_blocking())
diff --git a/exo/main.py b/exo/main.py
index 9d4110b6a..e30cc98ed 100644
--- a/exo/main.py
+++ b/exo/main.py
@@ -235,12 +235,24 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
     print(f"Processing prompt: {prompt}")
     await node.process_prompt(shard, prompt, request_id=request_id)
+    first_token_time = time.time()
     tokens = []
+    i = 0
     def on_token(_request_id, _token, _is_finished):
+      nonlocal i
+      i += 1
+      if i % 20 == 0:
+        print(f"TPS: {i / (time.time() - first_token_time)}")
+
       tokens.append(_token)
       return _request_id == request_id and _is_finished
 
     await callback.wait(on_token, timeout=300)
 
+    print("=== Stats ===")
+    print(f"Total time: {time.time() - first_token_time}")
+    print(f"Total tokens: {len(tokens)}")
+    print(f"Total tokens per second: {len(tokens) / (time.time() - first_token_time)}")
+
     print("\nGenerated response:")
     print(tokenizer.decode(tokens))
   except Exception as e:
diff --git a/exo/orchestration/node.py b/exo/orchestration/node.py
index 2a8c831aa..cbd5ceedf 100644
--- a/exo/orchestration/node.py
+++ b/exo/orchestration/node.py
@@ -123,6 +123,7 @@ async def process_inference_result(
 
     context = TraceContext(request_id=request_id or str(uuid.uuid4()), sequence_number=0)
     tracer.set_context(request_id, context)
+    is_finished = False
     try:
       with tracer.start_span(
         f"process_inference_result.{self.get_partition_index()}",
@@ -136,9 +137,10 @@
       ):
         if request_id not in self.buffered_token_output:
           self.buffered_token_output[request_id] = ([], False)
-        is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
 
-        if shard.is_last_layer() and not is_finished:
+        if shard.is_last_layer():
+          is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
+
           # Add span for sampling
           with tracer.start_span(
             "sample_token",
@@ -203,6 +205,7 @@
           self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
           self.outstanding_requests.pop(request_id)
 
+      return np.array(self.buffered_token_output[request_id][0])
     except Exception as e:
       if request_id in self.outstanding_requests: