-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
122 lines (103 loc) · 3.77 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import json
import time
import torch
import argparse
import os
from flask import Flask, request
from batching import BatchingManager
from threading import Thread
from stats import ServerStats
from model import load_base_model_config, ServerModel, DynamicBatchingServerModel
import tiktoken
from generate.generate import static_batch_generate, generate, dynamic_batch_generate
from generate.generate_mock import (
mock_generate,
mock_dynamic_batch_generate,
mock_static_batch_generate,
)
# Batching strategy is read from the BATCHING env var. A missing or
# unrecognised value falls back to "nobatch".
_VALID_BATCHING = ("static", "dynamic", "nobatch")
BATCHING = os.getenv("BATCHING")
if BATCHING not in _VALID_BATCHING:
    if BATCHING is not None:
        # Surface a misconfigured BATCHING instead of silently ignoring it.
        print(
            "SERVER LOGS: unknown BATCHING={!r}, falling back to nobatch".format(
                BATCHING
            )
        )
    BATCHING = "nobatch"
# When mock is True, stub generate functions replace the GPT model, so the
# server can be tested without loading an actual model.
# Defaults to False. Set the MOCK env var to 1/true/yes to enable it without
# editing this file.
mock = os.getenv("MOCK", "").strip().lower() in ("1", "true", "yes")
app = Flask(__name__)
# server_stats collects per-request experiment statistics (reported by /stats).
server_stats = ServerStats()
@app.route("/stats", methods=["GET"])
def stats():
    """Return aggregate serving stats and dump a per-token-count CSV breakdown.

    Side effect: (re)writes ``data/token_data_<BATCHING>.csv`` relative to the
    current working directory. Each row is ``<num_tokens>,<mean>``, where
    ``mean`` is the average of the values ``server_stats.token_breakdown()``
    holds for that token count (presumably per-request latencies; confirm
    against ServerStats).

    Returns:
        A dict that Flask serialises to JSON, with latency-per-token,
        throughput, total tokens handled and total elapsed time.
    """
    print("SERVER LOGS: Handling stats request")
    latency_per_token = server_stats.latency_per_token()
    throughput = server_stats.throughput()
    # Break down stats by number of tokens requested. Create the output
    # directory first so a fresh checkout does not fail with FileNotFoundError.
    os.makedirs("data", exist_ok=True)
    with open("data/token_data_{}.csv".format(BATCHING), "w", encoding="utf-8") as f:
        for data_len, lst in server_stats.token_breakdown().items():
            if not lst:
                # No samples for this token count: skip the row rather than
                # crash the whole endpoint with ZeroDivisionError.
                continue
            f.write("{},{}\n".format(data_len, sum(lst) / len(lst)))
    return {
        "latency-per-token": latency_per_token,
        "throughput": throughput,
        "total-tokens-handled": server_stats.total_tokens,
        "total-elapsed-time": server_stats.total_elapsed,
    }
# GET is accepted alongside POST so that simply browsing to / (or a plain curl)
# returns this usage hint instead of 405 Method Not Allowed.
@app.route("/", methods=["GET", "POST"])
def home():
    """Return a short usage hint pointing clients at /inference and /stats."""
    return (
        "hello use /inference endpoint for inferences and /stats to get infrence stats"
    )
def _parse_inference_request(raw):
    """Parse and validate the body of an /inference request.

    Args:
        raw: the raw request body (bytes or str) expected to hold a JSON
            object with a string ``prompt`` and an integer-like ``num_tokens``.

    Returns:
        A ``(prompt, num_tokens)`` tuple with ``num_tokens`` converted to int.

    Raises:
        ValueError: with a client-facing message when the body is malformed.
    """
    try:
        data = json.loads(raw)
    except ValueError as err:  # JSONDecodeError / UnicodeDecodeError are ValueErrors
        raise ValueError("request body is not valid JSON: {}".format(err)) from err
    if not isinstance(data, dict):
        raise ValueError("request body must be a JSON object")
    try:
        prompt = data["prompt"]
        raw_num_tokens = data["num_tokens"]
    except KeyError as err:
        raise ValueError("missing required field {}".format(err)) from err
    if not isinstance(prompt, str):
        raise ValueError("'prompt' must be a string")
    try:
        num_tokens = int(raw_num_tokens)
    except (TypeError, ValueError, OverflowError) as err:
        # OverflowError covers JSON Infinity, which json.loads accepts.
        raise ValueError("'num_tokens' must be an integer") from err
    if num_tokens < 0:
        raise ValueError("'num_tokens' must be non-negative")
    return prompt, num_tokens


@app.route("/inference", methods=["POST"])
def inference():
    """Enqueue a completion request and block until the batching loop serves it.

    Expects a JSON body ``{"prompt": str, "num_tokens": int}``. Returns
    ``{"completion": ...}`` on success, or ``({"error": msg}, 400)`` when the
    body is malformed. The previous code used ``assert`` for validation, which
    is stripped under ``python -O``, and crashed with 500 on bad input.
    """
    try:
        prompt, num_tokens = _parse_inference_request(request.get_data())
    except ValueError as err:
        return {"error": str(err)}, 400
    # Rough word count, used only for logging.
    prompt_len = len(prompt.split(" "))
    print(
        "SERVER LOGS: NEW, prompt len ~{} | requesting {} tokens".format(
            prompt_len, num_tokens
        )
    )
    # Make the inference. The manager's background loop fills in the completion.
    rid = server_stats.start_request(num_tokens)
    job = manager.enqueue(prompt, num_tokens)
    # Wait for the inference to be completed.
    completion = job.wait_for_completion()
    server_stats.finish_request(rid)
    print(
        "SERVER LOGS: FINISHED, prompt len ~{} | requesting {} tokens".format(
            prompt_len, num_tokens
        )
    )
    return {"completion": completion}
print("SERVER LOGS: Launching with batching strategy of ({})".format(BATCHING))

# Load the real GPT model unless running in mock mode. In mock mode only a
# tokenizer and a device are needed, because stub generate functions are used.
if not mock:
    gpt_model, enc, device = load_base_model_config()
    gpt_model.eval()
    gpt_model.to(device)
else:
    gpt_model = None
    enc = tiktoken.get_encoding("gpt2")
    device = "cpu"

# Pick the model wrapper, generate function and scheduling loop for the chosen
# strategy. The selected function goes into generate_fn rather than rebinding
# the imported generate / static_batch_generate / dynamic_batch_generate names.
if BATCHING == "nobatch":
    model = ServerModel(gpt_model, enc, device)
    generate_fn = mock_generate if mock else generate
    manager = BatchingManager(model, generate_fn)
    run_inferences = Thread(target=manager.no_batching_loop)
elif BATCHING == "static":
    model = ServerModel(gpt_model, enc, device)
    generate_fn = mock_static_batch_generate if mock else static_batch_generate
    manager = BatchingManager(model, generate_fn)
    run_inferences = Thread(target=manager.static_batching_loop)
elif BATCHING == "dynamic":
    model = DynamicBatchingServerModel(gpt_model, enc, device)
    generate_fn = mock_dynamic_batch_generate if mock else dynamic_batch_generate
    manager = BatchingManager(model, generate_fn)
    run_inferences = Thread(target=manager.dynamic_batching_loop)
else:
    # Fail loudly here rather than with a NameError on run_inferences below.
    raise ValueError("Unsupported batching strategy: {!r}".format(BATCHING))

# NOTE(review): this is a non-daemon thread. If the *_batching_loop methods
# never return, it will keep the process alive after the web server stops.
# Confirm whether daemon=True is desired for the deployment setup.
run_inferences.start()