Commit

new approach to mlx async operations and make tokenizer operations async too
AlexCheema committed Dec 18, 2024
1 parent 165a9e1 commit b02c0a5
Showing 5 changed files with 128 additions and 19 deletions.
43 changes: 28 additions & 15 deletions exo/inference/mlx/sharded_inference_engine.py
@@ -12,6 +12,7 @@
import asyncio
from collections import OrderedDict
from mlx_lm.models.cache import make_prompt_cache
from concurrent.futures import ThreadPoolExecutor

class MLXDynamicShardInferenceEngine(InferenceEngine):
def __init__(self, shard_downloader: ShardDownloader):
@@ -20,6 +21,12 @@ def __init__(self, shard_downloader: ShardDownloader):
self.caches = OrderedDict()
self.sampler_params: tuple[float, float] = (0.0, 0.0, 0.0, 1)
self.sampler = make_sampler(*self.sampler_params)
self._mlx_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx")
self._tokenizer_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="tokenizer")

async def _eval_mlx(self, *args):
loop = asyncio.get_running_loop()
await loop.run_in_executor(self._mlx_thread, mx.eval, *args)

async def poll_state(self, request_id: str, max_caches=2):
if request_id in self.caches:
@@ -38,16 +45,19 @@ async def sample(self, x: np.ndarray, temp: float = 0.0, top_p: float = 1.0) ->
logits = mx.array(x)
logits = logits[:, -1, :]
logprobs = logits - mx.logsumexp(logits, keepdims=True)
return np.asarray(self.sampler(logprobs), dtype=int)
result = self.sampler(logprobs)
await self._eval_mlx(result)
return np.asarray(result, dtype=int)

async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
await self.ensure_shard(shard)
tokens = self.tokenizer.encode(prompt)
return np.asarray(tokens)
loop = asyncio.get_running_loop()
return np.asarray(await loop.run_in_executor(self._tokenizer_thread, self.tokenizer.encode, prompt))

async def decode(self, shard: Shard, tokens) -> str:
await self.ensure_shard(shard)
return self.tokenizer.decode(tokens)
loop = asyncio.get_running_loop()
return await loop.run_in_executor(self._tokenizer_thread, self.tokenizer.decode, tokens)

async def save_checkpoint(self, shard: Shard, path: str):
await self.ensure_shard(shard)
@@ -61,8 +71,9 @@ async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarr
await self.ensure_shard(shard)
state = await self.poll_state(request_id)
x = mx.array(input_data)
output_data = np.array(self.model(x, **state), copy=False)
return output_data
output = self.model(x, **state)
await self._eval_mlx(output)
return np.array(output, copy=False)

async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
await self.ensure_shard(shard)
@@ -87,26 +98,25 @@ async def ensure_train(self, shard: Shard, loss: str, opt=optim.SGD, lr=1e-5, tr
return True

async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce", opt=optim.SGD, lr=1e-5):
loop = asyncio.get_running_loop()
nothin = await self.ensure_train(shard, loss, opt, lr)
await self.ensure_train(shard, loss, opt, lr)

def train_step(inp, tar, lng):
lval, grad = self.session['LVaG'](self.model, inp, tar, lng)
gradlayers = grad['model']['layers']
self.session['opt'].update(self.model, grad)
mx.eval(self.model.parameters(), self.session['opt'].state, lval)
return lval, gradlayers
return lval, gradlayers, (self.model.parameters(), self.session['opt'].state, lval)

x = mx.array(inputs)
y = mx.array(targets)
l = mx.array(lengths)

score, gradients = await loop.run_in_executor(self.executor, train_step, x, y, l)
#print(f"{score=}")
score, gradients, eval_args = train_step(x, y, l)
await self._eval_mlx(*eval_args)

layers = [{k: v["weight"] for k,v in l.items() if 'weight' in v} for l in gradients if l]
#print(layers[0])

return score, np.array(layers[0]['input_layernorm'], copy=False)
first_layer = np.array(layers[0]['input_layernorm'], copy=False)
await self._eval_mlx(first_layer)
return score, first_layer

async def ensure_shard(self, shard: Shard):
if self.shard == shard:
@@ -121,3 +131,6 @@ async def ensure_shard(self, shard: Shard):
self.caches = OrderedDict()
self.session = {}

async def cleanup(self):
self._mlx_thread.shutdown(wait=True)
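
The core pattern introduced above: blocking MLX and tokenizer calls are pushed onto dedicated single-worker ThreadPoolExecutors via loop.run_in_executor, so the asyncio event loop stays responsive while the heavy work runs on its own thread. Below is a minimal, self-contained sketch of that pattern, not part of the commit; slow_eval is a hypothetical stand-in for a blocking call such as mx.eval or tokenizer.encode.

import asyncio
import time
from concurrent.futures import ThreadPoolExecutor

def slow_eval(x):
    # Hypothetical stand-in for a blocking call (e.g. mx.eval); sleeps to simulate compute.
    time.sleep(0.05)
    return x * 2

class Engine:
    def __init__(self):
        # One dedicated worker, mirroring the _mlx_thread / _tokenizer_thread executors above.
        self._mlx_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx")

    async def _eval(self, x):
        loop = asyncio.get_running_loop()
        # The event loop keeps servicing other coroutines while the worker thread runs.
        return await loop.run_in_executor(self._mlx_thread, slow_eval, x)

    def close(self):
        self._mlx_thread.shutdown(wait=True)

async def main():
    engine = Engine()
    results = await asyncio.gather(*(engine._eval(i) for i in range(4)))
    print(results)  # [0, 2, 4, 6]; calls serialize on the single worker but never block the loop
    engine.close()

if __name__ == "__main__":
    asyncio.run(main())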

4 changes: 2 additions & 2 deletions exo/inference/mlx/sharded_utils.py
@@ -164,8 +164,8 @@ def class_predicate(p, m):

model.load_weights(list(weights.items()), strict=True)

if not lazy:
mx.eval(model.parameters())
# if not lazy:
# mx.eval(model.parameters())

model.eval()
return model
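
Context for the two commented-out lines above: MLX evaluates lazily, so skipping the eager mx.eval(model.parameters()) defers weight materialization until the model is first used. A tiny illustrative sketch of that lazy behaviour (not part of the commit), using only the mx.array and mx.eval calls already seen in this diff:

import mlx.core as mx

a = mx.array([1.0, 2.0, 3.0])
b = a * 2 + 1  # recorded lazily; nothing is computed yet
mx.eval(b)     # forces materialization, the step the load path above no longer does eagerly
print(b)       # array([3, 5, 7], dtype=float32)
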
81 changes: 81 additions & 0 deletions exo/inference/mlx/test_non_blocking.py
@@ -0,0 +1,81 @@
import asyncio
import time
import numpy as np
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
from exo.download.hf.hf_shard_download import HFShardDownloader
from exo.inference.shard import Shard
from exo.models import build_base_shard
from collections import deque
from statistics import mean, median

async def test_non_blocking():
# Setup
shard_downloader = HFShardDownloader()
engine = MLXDynamicShardInferenceEngine(shard_downloader)
_shard = build_base_shard("llama-3.1-8b", "MLXDynamicShardInferenceEngine")
shard = Shard(_shard.model_id, _shard.start_layer, _shard.n_layers - 1, _shard.n_layers)
await engine.ensure_shard(shard)

queue = asyncio.Queue()
measurements = deque(maxlen=1000000)
running = True

async def mlx_worker():
try:
start_time = time.time()
count = 0
while running and (time.time() - start_time) < 5: # Hard time limit
start = time.perf_counter_ns()
await engine.infer_prompt("req1", shard, "test prompt")
duration = (time.perf_counter_ns() - start) / 1_000_000 # Convert to ms
count += 1
print(f"MLX operation {count} took: {duration:.3f}ms")
except asyncio.CancelledError:
pass
finally:
print(f"\nTotal MLX operations completed: {count}")
print(f"Average rate: {count/5:.1f} ops/second")

async def latency_producer():
try:
start_time = time.perf_counter_ns()
count = 0
while running:
await queue.put(time.perf_counter_ns())
count += 1
await asyncio.sleep(0) # Yield to event loop without delay
duration = (time.perf_counter_ns() - start_time) / 1e9 # Convert to seconds
print(f"\nProducer iterations: {count}")
print(f"Producer rate: {count/duration:.1f} iterations/second")
except asyncio.CancelledError:
pass

async def latency_consumer():
try:
while running:
timestamp = await queue.get()
latency = (time.perf_counter_ns() - timestamp) / 1_000_000 # Convert to ms
measurements.append(latency)
queue.task_done()
except asyncio.CancelledError:
pass

tasks = [
asyncio.create_task(mlx_worker()),
asyncio.create_task(latency_producer()),
asyncio.create_task(latency_consumer())
]

try:
await asyncio.wait_for(asyncio.gather(*tasks), timeout=6)
except asyncio.TimeoutError:
print("\nTest timed out")
finally:
running = False
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
print(f"\nFinal measurement count: {len(measurements)}")

if __name__ == "__main__":
asyncio.run(test_non_blocking())
12 changes: 12 additions & 0 deletions exo/main.py
@@ -235,12 +235,24 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
print(f"Processing prompt: {prompt}")
await node.process_prompt(shard, prompt, request_id=request_id)

first_token_time = time.time()
tokens = []
i = 0
def on_token(_request_id, _token, _is_finished):
nonlocal i
i += 1
if i % 20 == 0:
print(f"TPS: {i / (time.time() - first_token_time)}")

tokens.append(_token)
return _request_id == request_id and _is_finished
await callback.wait(on_token, timeout=300)

print("=== Stats ===")
print(f"Total time: {time.time() - first_token_time}")
print(f"Total tokens: {len(tokens)}")
print(f"Total tokens per second: {len(tokens) / (time.time() - first_token_time)}")

print("\nGenerated response:")
print(tokenizer.decode(tokens))
except Exception as e:
7 changes: 5 additions & 2 deletions exo/orchestration/node.py
@@ -123,6 +123,7 @@ async def process_inference_result(
context = TraceContext(request_id=request_id or str(uuid.uuid4()), sequence_number=0)
tracer.set_context(request_id, context)

is_finished = False
try:
with tracer.start_span(
f"process_inference_result.{self.get_partition_index()}",
@@ -136,9 +137,10 @@
):
if request_id not in self.buffered_token_output:
self.buffered_token_output[request_id] = ([], False)
is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens

if shard.is_last_layer() and not is_finished:
if shard.is_last_layer():
is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens

# Add span for sampling
with tracer.start_span(
"sample_token",
@@ -203,6 +205,7 @@ async def process_inference_result(
self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
self.outstanding_requests.pop(request_id)


return np.array(self.buffered_token_output[request_id][0])
except Exception as e:
if request_id in self.outstanding_requests:
