Skip to content

Commit

Permalink
Enable inference serving capabilities on sagemaker endpoint using tor…
Browse files Browse the repository at this point in the history
…nado
  • Loading branch information
gwang111 committed Jan 14, 2025
1 parent 33b6986 commit 141031a
Show file tree
Hide file tree
Showing 13 changed files with 500 additions and 1 deletion.
2 changes: 1 addition & 1 deletion template/v3/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ RUN mkdir -p $SAGEMAKER_LOGGING_DIR && \
&& ${HOME_DIR}/oss_compliance/generate_oss_compliance.sh ${HOME_DIR} python \
&& rm -rf ${HOME_DIR}/oss_compliance*

ENV PATH="/opt/conda/bin:/opt/conda/condabin:$PATH"
ENV PATH="/etc/sagemaker-inference-server:/opt/conda/bin:/opt/conda/condabin:$PATH"
WORKDIR "/home/${NB_USER}"
ENV SHELL=/bin/bash
ENV OPENSSL_MODULES=/opt/conda/lib64/ossl-modules/
Expand Down
3 changes: 3 additions & 0 deletions template/v3/dirs/etc/sagemaker-inference-server/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from __future__ import absolute_import

import utils.logger
2 changes: 2 additions & 0 deletions template/v3/dirs/etc/sagemaker-inference-server/serve
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#!/bin/bash
# Entry point script: runs the inference server's serve.py with the `python` found on PATH.
python /etc/sagemaker-inference-server/serve.py
25 changes: 25 additions & 0 deletions template/v3/dirs/etc/sagemaker-inference-server/serve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from __future__ import absolute_import

"""
TODO: when adding support for more serving frameworks, move the below logic into a condition statement.
We also need to define the right environment variable to signify which serving framework to use.
Ex.
inference_server = None
serving_framework = os.getenv("SAGEMAKER_INFERENCE_FRAMEWORK", None)
if serving_framework == "FastAPI":
inference_server = FastApiServer()
elif serving_framework == "Flask":
inference_server = FlaskServer()
else:
inference_server = TornadoServer()
inference_server.serve()
"""
from tornado_server.server import TornadoServer

# Tornado is the only serving framework wired up so far, so it is used unconditionally.
inference_server = TornadoServer()
# Hand control to the server; everything else happens inside TornadoServer.serve().
inference_server.serve()
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from __future__ import absolute_import

import pathlib
import sys

# Directory of this package: putting it on sys.path lets the tornado_server modules
# import each other by their bare module names.
tornado_module_path = pathlib.Path(__file__).parent

# Sibling "utils" directory: putting it on sys.path makes the utils modules reachable
# from the modules inside the tornado_server folder.
utils_path = tornado_module_path.parent / "utils"

# Prepend both entries in one step; the package directory ends up first, followed by utils.
sys.path[:0] = [str(tornado_module_path.resolve()), str(utils_path.resolve())]
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from __future__ import absolute_import

import asyncio
import logging
from typing import AsyncIterator, Iterator

import tornado.web
from stream_handler import StreamHandler

from utils.environment import Environment
from utils.exception import AsyncInvocationsException
from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER

logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER)


class InvocationsHandler(tornado.web.RequestHandler, StreamHandler):
    """Tornado request handler bound to the POST /invocations route.

    Wraps the coroutine handler loaded from the inference script: post() awaits it
    and then either streams or writes whatever it returned.
    """

    def initialize(self, handler: callable, environment: Environment):
        """Store the inference handler and the serving environment passed in by the route."""

        self._environment = environment
        self._handler = handler

    async def post(self):
        """Await the inference handler with `self.request` and send its response.

        Iterators are streamed through stream(), async iterators through astream(),
        and any other value is written as-is. Any exception raised on the way is
        re-raised wrapped in an AsyncInvocationsException.
        """

        try:
            result = await self._handler(self.request)

            if isinstance(result, Iterator):
                # Sync iterator/generator: stream it chunk by chunk.
                await self.stream(result)
                return
            if isinstance(result, AsyncIterator):
                # Async iterator/generator: stream it chunk by chunk.
                await self.astream(result)
                return

            # Plain response: write it in one go.
            self.write(result)
        except Exception as error:
            raise AsyncInvocationsException(error)


class PingHandler(tornado.web.RequestHandler):
    """Tornado request handler bound to the GET /ping route.

    Used to monitor the health of the Tornado server.
    """

    def get(self):
        """Answer the health check with an empty response body."""

        # An empty body is enough; no payload is written for a ping.
        self.write("")


async def handle(handler: callable, environment: Environment):
    """Serve the async inference handler with Tornado.

    Registers the /invocations and /ping routes used by a SageMaker Endpoint,
    starts listening on `environment.port`, and then waits indefinitely.
    """

    logger.info("Starting inference server in asynchronous mode...")

    # Keyword arguments Tornado forwards to InvocationsHandler.initialize().
    invocation_kwargs = {"handler": handler, "environment": environment}
    routes = [
        (r"/invocations", InvocationsHandler, invocation_kwargs),
        (r"/ping", PingHandler),
    ]

    application = tornado.web.Application(routes)
    application.listen(environment.port)
    logger.debug(f"Asynchronous inference server listening on port: `{environment.port}`")

    # This event is never set, so the coroutine (and with it the event loop) stays alive.
    never_set = asyncio.Event()
    await never_set.wait()
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from __future__ import absolute_import

import asyncio
import importlib
import logging
import subprocess
import sys
from pathlib import Path

from utils.environment import Environment
from utils.exception import (
InferenceCodeLoadException,
RequirementsInstallException,
ServerStartException,
)
from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER

logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER)


class TornadoServer:
    """Holds the serving logic built on the Tornado framework.

    The serve.py script invokes TornadoServer.serve() to start the serving process.
    The server installs the runtime requirements listed in a requirements file (if one
    exists), loads a handler function from an inference script, and fronts it with an
    /invocations route using Tornado.
    """

    def __init__(self):
        """Read the serving configuration and work out where the inference code lives.

        The configuration comes from Environment(). The inference code directory is
        `base_directory/code_directory` when a code directory is configured, otherwise
        `base_directory` itself.
        """

        self._environment = Environment()
        logger.setLevel(self._environment.logging_level)
        logger.debug(f"Environment: {str(self._environment)}")

        self._path_to_inference_code = (
            Path(self._environment.base_directory).joinpath(self._environment.code_directory)
            if self._environment.code_directory
            else Path(self._environment.base_directory)
        )
        logger.debug(f"Path to inference code: `{str(self._path_to_inference_code)}`")

    def initialize(self):
        """Prepare the serving artifacts and dependencies.

        Installs the runtime requirements and then loads the handler function from
        the inference script into `self._handler`.
        """

        logger.info("Initializing inference server...")
        self._install_runtime_requirements()
        self._handler = self._load_inference_handler()

    def serve(self):
        """Orchestrate initialization and server startup.

        Calls initialize(), picks the async or sync Tornado handler module depending on
        whether the loaded handler is a coroutine function, and runs it through asyncio.

        Raises:
            ServerStartException: if running the server fails; the original error is
                attached as the cause.
        """

        logger.info("Serving inference requests using Tornado...")
        self.initialize()

        # Imported lazily so that only the module matching the handler type is loaded.
        if asyncio.iscoroutinefunction(self._handler):
            import async_handler as inference_handler
        else:
            import sync_handler as inference_handler

        try:
            asyncio.run(inference_handler.handle(self._handler, self._environment))
        except Exception as e:
            raise ServerStartException(e) from e

    def _install_runtime_requirements(self):
        """Install the runtime requirements with micromamba, if a requirements file exists.

        Raises:
            RequirementsInstallException: if the install command cannot be run or exits
                with a non-zero status; the original error is attached as the cause.
        """

        logger.info("Installing runtime requirements...")
        requirements_txt = self._path_to_inference_code.joinpath(self._environment.requirements)
        if not requirements_txt.is_file():
            logger.debug(f"No requirements file was found at `{str(requirements_txt)}`")
            return

        try:
            subprocess.check_call(["micromamba", "install", "--yes", "--file", str(requirements_txt)])
        except Exception as e:
            raise RequirementsInstallException(e) from e

    def _load_inference_handler(self) -> callable:
        """Load the handler function from the inference script.

        The configured code reference must have the form `<module>.<handler>`;
        `<module>.py` is looked up inside the inference code directory.

        Returns:
            The attribute named `<handler>` of the loaded module.

        Raises:
            InferenceCodeLoadException: if the code reference is malformed, the module
                file does not exist, or the module has no attribute named `<handler>`.
        """

        # A bare `import importlib` does not guarantee that the `importlib.util`
        # submodule is loaded, so import it explicitly where it is used.
        import importlib.util

        logger.info("Loading inference handler...")

        # Validate the shape before unpacking so that malformed values ("a", "a.b.c",
        # ".b", unset) produce the descriptive load error instead of a bare ValueError.
        code_parts = (self._environment.code or "").split(".")
        if len(code_parts) != 2 or not all(code_parts):
            raise InferenceCodeLoadException(
                f"Inference code expected in the format of `<module>.<handler>` but was provided as {self._environment.code}"
            )
        inference_module_name, handle_name = code_parts

        inference_module_file = f"{inference_module_name}.py"
        inference_module_path = self._path_to_inference_code.joinpath(inference_module_file)

        # A spec can be created for a path that does not exist, in which case the failure
        # would only surface later from exec_module(); check for the file up front.
        module_spec = (
            importlib.util.spec_from_file_location(inference_module_file, str(inference_module_path))
            if inference_module_path.is_file()
            else None
        )
        if module_spec is None or module_spec.loader is None:
            raise InferenceCodeLoadException(
                f"Inference code could not be found at `{str(inference_module_path)}`"
            )

        # Make sibling modules of the inference script importable from within it.
        sys.path.insert(0, str(self._path_to_inference_code.resolve()))
        module = importlib.util.module_from_spec(module_spec)
        module_spec.loader.exec_module(module)

        if not hasattr(module, handle_name):
            raise InferenceCodeLoadException(
                f"Handler `{handle_name}` could not be found in module `{inference_module_file}`"
            )

        handler = getattr(module, handle_name)
        logger.debug(f"Loaded handler `{handle_name}` from module `{inference_module_name}`")
        return handler
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from __future__ import absolute_import

import logging
from typing import AsyncIterator, Iterator

from tornado.ioloop import IOLoop

from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER

logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER)


class StreamHandler:
    """Mixin that adds streaming capabilities to the async and sync invocation handlers.

    The class it is mixed into must provide `write()`, an awaitable `flush()` and
    `set_header()`.

    stream() drains a sync iterator/generator without blocking the event loop.
    astream() drains an async iterator/generator.
    """

    # Default handed to next() so that exhaustion of a sync iterator is reported as a
    # return value instead of a StopIteration exception.
    _END_OF_STREAM = object()

    async def stream(self, iterator: Iterator):
        """Stream the response from a sync response iterator.

        Each next() call runs in the IOLoop's executor so a slow iterator does not block
        the event loop. Streaming stops when the iterator is exhausted, when it yields an
        empty chunk, or when an unexpected error occurs (which is logged, not re-raised).
        """

        self._set_stream_headers()

        while True:
            try:
                # StopIteration cannot be delivered through an asyncio Future (it is
                # rejected or converted to another exception type), so exhaustion must
                # not be signalled by letting next() raise inside the executor.
                chunk = await IOLoop.current().run_in_executor(None, next, iterator, self._END_OF_STREAM)

                # Some iterators do not throw a StopIteration upon exhaustion.
                # Instead, they return an empty response. Account for this case too.
                if chunk is self._END_OF_STREAM or not chunk:
                    break

                self.write(chunk)
                await self.flush()
            except Exception:
                # Best effort: end the stream, but keep the traceback in the logs.
                logger.exception("Unexpected exception occurred when streaming response...")
                break

    async def astream(self, aiterator: AsyncIterator):
        """Stream the response from an async response iterator, flushing after every chunk."""

        self._set_stream_headers()

        async for chunk in aiterator:
            self.write(chunk)
            await self.flush()

    def _set_stream_headers(self):
        """Set the headers in preparation for the streamed response."""

        self.set_header("Content-Type", "text/event-stream")
        self.set_header("Cache-Control", "no-cache")
        self.set_header("Connection", "keep-alive")
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from __future__ import absolute_import

import asyncio
import logging
from typing import AsyncIterator, Iterator

import tornado.web
from stream_handler import StreamHandler
from tornado.ioloop import IOLoop

from utils.environment import Environment
from utils.exception import SyncInvocationsException
from utils.logger import SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER

logger = logging.getLogger(SAGEMAKER_DISTRIBUTION_INFERENCE_LOGGER)


class InvocationsHandler(tornado.web.RequestHandler, StreamHandler):
    """Tornado request handler bound to the POST /invocations route.

    Wraps the sync handler loaded from the inference script: post() runs it in the
    IOLoop's executor and then either streams or writes whatever it returned.
    """

    def initialize(self, handler: callable, environment: Environment):
        """Store the inference handler and the serving environment passed in by the route."""

        self._environment = environment
        self._handler = handler

    async def post(self):
        """Run the sync inference handler off the event loop and send its response.

        The handler is called with `self.request` inside the executor. Iterators are
        streamed through stream(), async iterators through astream(), and any other
        value is written as-is. Any exception raised on the way is re-raised wrapped
        in a SyncInvocationsException.
        """

        try:
            io_loop = IOLoop.current()
            result = await io_loop.run_in_executor(None, self._handler, self.request)

            if isinstance(result, Iterator):
                # Sync iterator/generator: stream it chunk by chunk.
                await self.stream(result)
                return
            if isinstance(result, AsyncIterator):
                # Async iterator/generator: stream it chunk by chunk.
                await self.astream(result)
                return

            # Plain response: write it in one go.
            self.write(result)
        except Exception as error:
            raise SyncInvocationsException(error)


class PingHandler(tornado.web.RequestHandler):
    """Tornado request handler bound to the GET /ping route.

    Used to monitor the health of the Tornado server.
    """

    def get(self):
        """Answer the health check with an empty response body."""

        # An empty body is enough; no payload is written for a ping.
        self.write("")


async def handle(handler: callable, environment: Environment):
    """Serve the sync inference handler with Tornado.

    Registers the /invocations and /ping routes used by a SageMaker Endpoint,
    starts listening on `environment.port`, and then waits indefinitely.
    """

    logger.info("Starting inference server in synchronous mode...")

    # Keyword arguments Tornado forwards to InvocationsHandler.initialize().
    invocation_kwargs = {"handler": handler, "environment": environment}
    routes = [
        (r"/invocations", InvocationsHandler, invocation_kwargs),
        (r"/ping", PingHandler),
    ]

    application = tornado.web.Application(routes)
    application.listen(environment.port)
    logger.debug(f"Synchronous inference server listening on port: `{environment.port}`")

    # This event is never set, so the coroutine (and with it the event loop) stays alive.
    never_set = asyncio.Event()
    await never_set.wait()
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from __future__ import absolute_import
Loading

0 comments on commit 141031a

Please sign in to comment.