-
Notifications
You must be signed in to change notification settings - Fork 59
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enable inference serving capabilities on sagemaker endpoint using tor…
…nado
- Loading branch information
Showing
13 changed files
with
500 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from __future__ import absolute_import | ||
|
||
import utils.logger |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
#!/bin/bash | ||
python /etc/sagemaker-inference-server/serve.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from __future__ import absolute_import | ||
|
||
""" | ||
TODO: when adding support for more serving frameworks, move the below logic into a condition statement. | ||
We also need to define the right environment variable for signify what serving framework to use. | ||
Ex. | ||
inference_server = None | ||
serving_framework = os.getenv("SAGEMAKER_INFERENCE_FRAMEWORK", None) | ||
if serving_framework == "FastAPI": | ||
inference_server = FastApiServer() | ||
elif serving_framework == "Flask": | ||
inference_server = FlaskServer() | ||
else: | ||
inference_server = TornadoServer() | ||
inference_server.serve() | ||
""" | ||
from tornado_server.server import TornadoServer | ||
|
||
inference_server = TornadoServer() | ||
inference_server.serve() |
12 changes: 12 additions & 0 deletions
12
template/v3/dirs/etc/sagemaker-inference-server/tornado_server/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from __future__ import absolute_import | ||
|
||
import pathlib | ||
import sys | ||
|
||
# make the utils modules accessible to modules from within the tornado_server folder | ||
utils_path = pathlib.Path(__file__).parent.parent / "utils" | ||
sys.path.insert(0, str(utils_path.resolve())) | ||
|
||
# make the tornado_server modules accessible to each other | ||
tornado_module_path = pathlib.Path(__file__).parent | ||
sys.path.insert(0, str(tornado_module_path.resolve())) |
76 changes: 76 additions & 0 deletions
76
template/v3/dirs/etc/sagemaker-inference-server/tornado_server/async_handler.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from __future__ import absolute_import | ||
|
||
import asyncio | ||
import logging | ||
from typing import AsyncIterator, Iterator | ||
|
||
import tornado.web | ||
from stream_handler import StreamHandler | ||
|
||
from utils.environment import Environment | ||
from utils.exception import AsyncInvocationsException | ||
from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER | ||
|
||
logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER) | ||
|
||
|
||
class InvocationsHandler(tornado.web.RequestHandler, StreamHandler): | ||
"""Handler mapped to the /invocations POST route. | ||
This handler wraps the async handler retrieved from the inference script | ||
and encapsulates it behind the post() method. The post() method is done | ||
asynchronously. | ||
""" | ||
|
||
def initialize(self, handler: callable, environment: Environment): | ||
"""Initializes the handler function and the serving environment.""" | ||
|
||
self._handler = handler | ||
self._environment = environment | ||
|
||
async def post(self): | ||
"""POST method used to encapsulate and invoke the async handle method asynchronously""" | ||
|
||
try: | ||
response = await self._handler(self.request) | ||
|
||
if isinstance(response, Iterator): | ||
await self.stream(response) | ||
elif isinstance(response, AsyncIterator): | ||
await self.astream(response) | ||
else: | ||
self.write(response) | ||
except Exception as e: | ||
raise AsyncInvocationsException(e) | ||
|
||
|
||
class PingHandler(tornado.web.RequestHandler): | ||
"""Handler mapped to the /ping GET route. | ||
Ping handler to monitor the health of the Tornados server. | ||
""" | ||
|
||
def get(self): | ||
"""Simple GET method to assess the health of the server.""" | ||
|
||
self.write("") | ||
|
||
|
||
async def handle(handler: callable, environment: Environment): | ||
"""Serves the async handler function using Tornado. | ||
Opens the /invocations and /ping routes used by a SageMaker Endpoint | ||
for inference serving capabilities. | ||
""" | ||
|
||
logger.info("Starting inference server in asynchronous mode...") | ||
|
||
app = tornado.web.Application( | ||
[ | ||
(r"/invocations", InvocationsHandler, dict(handler=handler, environment=environment)), | ||
(r"/ping", PingHandler), | ||
] | ||
) | ||
app.listen(environment.port) | ||
logger.debug(f"Asynchronous inference server listening on port: `{environment.port}`") | ||
await asyncio.Event().wait() |
121 changes: 121 additions & 0 deletions
121
template/v3/dirs/etc/sagemaker-inference-server/tornado_server/server.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
from __future__ import absolute_import | ||
|
||
import asyncio | ||
import importlib | ||
import logging | ||
import subprocess | ||
import sys | ||
from pathlib import Path | ||
|
||
from utils.environment import Environment | ||
from utils.exception import ( | ||
InferenceCodeLoadException, | ||
RequirementsInstallException, | ||
ServerStartException, | ||
) | ||
from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER | ||
|
||
logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER) | ||
|
||
|
||
class TornadoServer: | ||
"""Holds serving logic using the Tornado framework. | ||
The serve.py script will invoke TornadoServer.serve() to start the serving process. | ||
The TornadoServer will install the runtime requirements specified through a requirements file. | ||
It will then load an handler function within an inference script and then front it will an /invocations | ||
route using the Tornado framework. | ||
""" | ||
|
||
def __init__(self): | ||
"""Initialize the serving behaviors. | ||
Defines the serving behavior through Environment() and locate where | ||
the inference code is contained. | ||
""" | ||
|
||
self._environment = Environment() | ||
logger.setLevel(self._environment.logging_level) | ||
logger.debug(f"Environment: {str(self._environment)}") | ||
|
||
self._path_to_inference_code = ( | ||
Path(self._environment.base_directory).joinpath(self._environment.code_directory) | ||
if self._environment.code_directory | ||
else Path(self._environment.base_directory) | ||
) | ||
logger.debug(f"Path to inference code: `{str(self._path_to_inference_code)}`") | ||
|
||
def initialize(self): | ||
"""Initialize the serving artifacts and dependencies. | ||
Install the runtime requirements and then locate the handler function from | ||
the inference script. | ||
""" | ||
|
||
logger.info("Initializing inference server...") | ||
self._install_runtime_requirements() | ||
self._handler = self._load_inference_handler() | ||
|
||
def serve(self): | ||
"""Orchestrate the initialization and server startup behavior. | ||
Call the initalize() method, determine the right Tornado serving behavior (async or sync), | ||
and then start the Tornado server through asyncio | ||
""" | ||
|
||
logger.info("Serving inference requests using Tornado...") | ||
self.initialize() | ||
|
||
if asyncio.iscoroutinefunction(self._handler): | ||
import async_handler as inference_handler | ||
else: | ||
import sync_handler as inference_handler | ||
|
||
try: | ||
asyncio.run(inference_handler.handle(self._handler, self._environment)) | ||
except Exception as e: | ||
raise ServerStartException(e) | ||
|
||
def _install_runtime_requirements(self): | ||
"""Install the runtime requirements.""" | ||
|
||
logger.info("Installing runtime requirements...") | ||
requirements_txt = self._path_to_inference_code.joinpath(self._environment.requirements) | ||
if requirements_txt.is_file(): | ||
try: | ||
subprocess.check_call(["micromamba", "install", "--yes", "--file", str(requirements_txt)]) | ||
except Exception as e: | ||
raise RequirementsInstallException(e) | ||
else: | ||
logger.debug(f"No requirements file was found at `{str(requirements_txt)}`") | ||
|
||
def _load_inference_handler(self) -> callable: | ||
"""Load the handler function from the inference script.""" | ||
|
||
logger.info("Loading inference handler...") | ||
inference_module_name, handle_name = self._environment.code.split(".") | ||
if inference_module_name and handle_name: | ||
inference_module_file = f"{inference_module_name}.py" | ||
module_spec = importlib.util.spec_from_file_location( | ||
inference_module_file, str(self._path_to_inference_code.joinpath(inference_module_file)) | ||
) | ||
if module_spec: | ||
sys.path.insert(0, str(self._path_to_inference_code.resolve())) | ||
module = importlib.util.module_from_spec(module_spec) | ||
module_spec.loader.exec_module(module) | ||
|
||
if hasattr(module, handle_name): | ||
handler = getattr(module, handle_name) | ||
else: | ||
raise InferenceCodeLoadException( | ||
f"Handler `{handle_name}` could not be found in module `{inference_module_file}`" | ||
) | ||
logger.debug(f"Loaded handler `{handle_name}` from module `{inference_module_name}`") | ||
return handler | ||
else: | ||
raise InferenceCodeLoadException( | ||
f"Inference code could not be found at `{str(self._path_to_inference_code.joinpath(inference_module_file))}`" | ||
) | ||
raise InferenceCodeLoadException( | ||
f"Inference code expected in the format of `<module>.<handler>` but was provided as {self._environment.code}" | ||
) |
59 changes: 59 additions & 0 deletions
59
template/v3/dirs/etc/sagemaker-inference-server/tornado_server/stream_handler.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from __future__ import absolute_import | ||
|
||
import logging | ||
from typing import AsyncIterator, Iterator | ||
|
||
from tornado.ioloop import IOLoop | ||
|
||
from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER | ||
|
||
logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER) | ||
|
||
|
||
class StreamHandler: | ||
"""Mixin that enables async and sync streaming capabilities to the async and sync handlers | ||
stream() runs a provided iterator/generator fn in an async manner. | ||
astream() runs a provided async iterator/generator fn in an async manner. | ||
""" | ||
|
||
async def stream(self, iterator: Iterator): | ||
"""Streams the response from a sync response iterator | ||
A sync iterator must be manually iterated through asynchronously. | ||
In a loop, iterate through each next(iterator) call in an async execution. | ||
""" | ||
|
||
self._set_stream_headers() | ||
|
||
while True: | ||
try: | ||
chunk = await IOLoop.current().run_in_executor(None, next, iterator) | ||
# Some iterators do not throw a StopIteration upon exhaustion. | ||
# Instead, they return an empty response. Account for this case. | ||
if not chunk: | ||
raise StopIteration() | ||
|
||
self.write(chunk) | ||
await self.flush() | ||
except StopIteration: | ||
break | ||
except Exception as e: | ||
logger.error("Unexpected exception occurred when streaming response...") | ||
break | ||
|
||
async def astream(self, aiterator: AsyncIterator): | ||
"""Streams the response from an async response iterator""" | ||
|
||
self._set_stream_headers() | ||
|
||
async for chunk in aiterator: | ||
self.write(chunk) | ||
await self.flush() | ||
|
||
def _set_stream_headers(self): | ||
"""Set the headers in preparation for the streamed response""" | ||
|
||
self.set_header("Content-Type", "text/event-stream") | ||
self.set_header("Cache-Control", "no-cache") | ||
self.set_header("Connection", "keep-alive") |
77 changes: 77 additions & 0 deletions
77
template/v3/dirs/etc/sagemaker-inference-server/tornado_server/sync_handler.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
from __future__ import absolute_import | ||
|
||
import asyncio | ||
import logging | ||
from typing import AsyncIterator, Iterator | ||
|
||
import tornado.web | ||
from stream_handler import StreamHandler | ||
from tornado.ioloop import IOLoop | ||
|
||
from utils.environment import Environment | ||
from utils.exception import SyncInvocationsException | ||
from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER | ||
|
||
logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER) | ||
|
||
|
||
class InvocationsHandler(tornado.web.RequestHandler, StreamHandler): | ||
"""Handler mapped to the /invocations POST route. | ||
This handler wraps the sync handler retrieved from the inference script | ||
and encapsulates it behind the post() method. The post() method is done | ||
asynchronously. | ||
""" | ||
|
||
def initialize(self, handler: callable, environment: Environment): | ||
"""Initializes the handler function and the serving environment.""" | ||
|
||
self._handler = handler | ||
self._environment = environment | ||
|
||
async def post(self): | ||
"""POST method used to encapsulate and invoke the sync handle method asynchronously""" | ||
|
||
try: | ||
response = await IOLoop.current().run_in_executor(None, self._handler, self.request) | ||
|
||
if isinstance(response, Iterator): | ||
await self.stream(response) | ||
elif isinstance(response, AsyncIterator): | ||
await self.astream(response) | ||
else: | ||
self.write(response) | ||
except Exception as e: | ||
raise SyncInvocationsException(e) | ||
|
||
|
||
class PingHandler(tornado.web.RequestHandler): | ||
"""Handler mapped to the /ping GET route. | ||
Ping handler to monitor the health of the Tornados server. | ||
""" | ||
|
||
def get(self): | ||
"""Simple GET method to assess the health of the server.""" | ||
|
||
self.write("") | ||
|
||
|
||
async def handle(handler: callable, environment: Environment): | ||
"""Serves the sync handler function using Tornado. | ||
Opens the /invocations and /ping routes used by a SageMaker Endpoint | ||
for inference serving capabilities. | ||
""" | ||
|
||
logger.info("Starting inference server in synchronous mode...") | ||
|
||
app = tornado.web.Application( | ||
[ | ||
(r"/invocations", InvocationsHandler, dict(handler=handler, environment=environment)), | ||
(r"/ping", PingHandler), | ||
] | ||
) | ||
app.listen(environment.port) | ||
logger.debug(f"Synchronous inference server listening on port: `{environment.port}`") | ||
await asyncio.Event().wait() |
1 change: 1 addition & 0 deletions
1
template/v3/dirs/etc/sagemaker-inference-server/utils/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from __future__ import absolute_import |
Oops, something went wrong.