Skip to content

Commit

Permalink
Overhaul backend function execution for improved performance and flex…
Browse files Browse the repository at this point in the history
…ibility

This PR replaces the DPS-style calling convention with a non-DPS approach, eliminating the requirement for call sites to preallocate output buffers. This change enables us to bypass the computation of output shapes and advance allocation of output buffers, laying the groundwork for supporting data-dependent shapes where network outputs can have dynamic dimensions.

The underlying compiler stack has been enhanced to avoid allocating oversized buffers and eliminate an extra device-to-device copy operation from TensorRT-allocated memory to MLIR-TRT managed memory.

Additionally, we've improved the copy operation to support copying to host memory. This enhancement removes the need to track output device allocations for device-to-host copies. Previously, copy outputs were restricted to device allocations; now they can be allocated on both device and host.

Tests have been updated to align with the new calling convention, ensuring compatibility and correctness.
  • Loading branch information
jhalakpatel committed Nov 4, 2024
1 parent 3ac751b commit ceeeaf6
Show file tree
Hide file tree
Showing 10 changed files with 50 additions and 144 deletions.
2 changes: 1 addition & 1 deletion tripy/tests/backend/api/test_executable.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_signature(self, single_return_executable):
assert param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
assert param.annotation == tp.Tensor

assert signature.return_annotation == tp.Tensor
assert signature.return_annotation == Sequence[tp.Tensor]

def test_signature_multiple_return_values(self, multiple_return_executable):
signature = inspect.signature(multiple_return_executable)
Expand Down
3 changes: 1 addition & 2 deletions tripy/tests/frontend/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,7 @@ def test_no_explicit_cast(self):
"devices",
[
("cpu", "gpu"),
# TODO(#155)
# ("gpu", "cpu"),
("gpu", "cpu"),
],
)
def test_explicit_copy(self, devices):
Expand Down
2 changes: 1 addition & 1 deletion tripy/tests/integration/test_iota.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_negative_no_casting(self, dtype):
a = tp.ones((2, 2))
out = Iota.build([frontend_utils.tensor_from_shape_like(a.shape)], dim=0, output_rank=2, dtype=dtype)

exception_str = "error: 'tensorrt.linspace' op result #0 must be 0D/1D/2D/3D/4D/5D/6D/7D/8D tensor of 32-bit float or 32-bit signless integer values"
exception_str = "InternalError: failed to run compilation on module with symbol name."
if dtype == tp.bool:
exception_str = "InternalError: failed to run compilation"
with helper.raises(
Expand Down
3 changes: 2 additions & 1 deletion tripy/tests/integration/test_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,5 +117,6 @@ def test_non_constant_scale(self):
input = tp.ones((4, 4))
scale = tp.ones((4,))
quantized = tp.quantize(input, scale, tp.int8, dim=0)
quantized_int32 = tp.cast(quantized, tp.int32)

assert bool(tp.all(quantized == tp.ones((4, 4), dtype=tp.int8)))
assert bool(tp.all(quantized_int32 == tp.ones((4, 4), dtype=tp.int32)))
1 change: 0 additions & 1 deletion tripy/tripy/backend/api/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,5 +196,4 @@ def process_arg(name, arg):
return Executable(
executable,
compiled_arg_names,
output_devices=[out.device for out in trace.outputs],
)
50 changes: 29 additions & 21 deletions tripy/tripy/backend/api/executable.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
import base64
import inspect
from typing import Sequence, Union
from typing import Sequence, Union, Tuple, Callable

import mlir_tensorrt.runtime.api as runtime

Expand All @@ -37,13 +37,11 @@ class Executable:
"""

# The constructor is intentionally undocumented because it is not meant to be called by users.
# TODO(#155): output_devices is not needed after they can be queried from executable
def __init__(self, executable, arg_names, output_devices):
def __init__(self, executable, arg_names):
self._executable = executable
self._executor = Executor(self._executable)
self._arg_names = arg_names
self._num_expected_args = len(arg_names)
self._output_devices = output_devices
self._executable_signature = self._executable.get_signature("main")

# Build a signature so the executable works with `inspect.signature`
Expand Down Expand Up @@ -128,7 +126,7 @@ def add(a, b):
tensor.eval()

try:
executor_outputs = self._executor.execute(self._output_devices, input_tensors)
executor_outputs = self._executor.execute(input_tensors)
except runtime.MTRTException as err:
# TODO: Evaluate whether this should be moved into the executor
if "function expects a memref type with element type" in str(err):
Expand Down Expand Up @@ -170,15 +168,22 @@ def add(a, b):
output_tensors = output_tensors[0]
return output_tensors

def _get_arg_info(self, idx):
arg = self._executable_signature.get_arg(idx)
arg = runtime.MemRefType(arg)
arg_bound = self._executable_signature.get_arg_bound(idx)
shape_bounds = tuple(zip(arg_bound.min(), arg_bound.max()))
if len(shape_bounds) == 0:
# For static shape arguments, get_arg_bound returns an empty list and we fallback to arg.shape
shape_bounds = tuple((x, x) for x in arg.shape)
return ArgInfo(shape_bounds, mlir_utils.convert_runtime_dtype_to_tripy_dtype(arg.dtype))
def _get_info(self, idx: int, get_item: Callable, get_bound: Callable) -> ArgInfo:
item = runtime.MemRefType(get_item(idx))
bound = get_bound(idx)
shape_bounds = tuple(zip(bound.min(), bound.max()))

if not shape_bounds:
# For static shape, fallback to item.shape
shape_bounds = tuple((x, x) for x in item.shape)

return ArgInfo(shape_bounds, mlir_utils.convert_runtime_dtype_to_tripy_dtype(item.dtype))

def _get_arg_info(self, idx: int) -> ArgInfo:
return self._get_info(idx, self._executable_signature.get_arg, self._executable_signature.get_arg_bound)

def _get_result_info(self, idx: int) -> ArgInfo:
return self._get_info(idx, self._executable_signature.get_result, self._executable_signature.get_res_bound)

def get_input_info(self) -> Sequence[ArgInfo]:
"""
Expand Down Expand Up @@ -221,11 +226,16 @@ def add(a, b):
compiled_add = tp.compile(add, args=[tp.InputInfo(([1, 2, 3],), dtype=tp.float32), tp.InputInfo(([1, 2, 3],), dtype=tp.float32)])
print(compiled_add.get_output_info())
"""
output_info = []
offset = self._executable_signature.get_num_input_args()
for idx in range(self._executable_signature.get_num_output_args()):
output_info.append(self._get_arg_info(idx + offset))
return output_info
num_input_args = self._executable_signature.get_num_input_args()
num_output_args = self._executable_signature.get_num_output_args()
num_results = self._executable_signature.get_num_results()

assert not (num_output_args and num_results), "Cannot have both output arguments and results"

if num_output_args:
return [self._get_arg_info(idx + num_input_args) for idx in range(num_output_args)]
else:
return [self._get_result_info(idx) for idx in range(num_results)]

def save(self, path: str) -> None:
"""
Expand Down Expand Up @@ -289,7 +299,6 @@ def add(a, b):
def encode_executable(executable):
return {
"arg_names": executable._arg_names,
"output_devices": executable._output_devices,
"executable": base64.b64encode(executable._executable.serialize()).decode(),
}

Expand All @@ -300,5 +309,4 @@ def decode_executable(executable_dict):
return Executable(
runtime.Executable(executable_bytes),
executable_dict["arg_names"],
executable_dict["output_devices"],
)
1 change: 1 addition & 0 deletions tripy/tripy/backend/mlir/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def _make_mlir_opts(self, trt_builder_opt_level):
f"--tensorrt-timing-cache-path={G_TIMING_CACHE_FILE}",
f"--tensorrt-builder-opt-level={trt_builder_opt_level}",
"--tensorrt-strongly-typed=True",
"--enable-non-dps-returns",
]
if config.enable_mlir_debug or config.enable_tensorrt_debug:
opts.append("--debug=true")
Expand Down
117 changes: 3 additions & 114 deletions tripy/tripy/backend/mlir/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,89 +31,14 @@

class Executor:
def __init__(self, executable: runtime.Executable) -> None:

self.runtime_client = MLIRRuntimeClient()
session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0)
self.session = runtime.RuntimeSession(session_options, executable)
self.device = self.runtime_client.get_devices()[0] # Assume a single device is available.
self.signature = executable.get_signature("main")
self.stream = default_stream()
self.num_input_args = self.signature.get_num_input_args()
self.num_output_args = self.signature.get_num_output_args()
self.output_args = [
self.signature.get_arg(index + self.num_input_args) for index in range(self.num_output_args)
]
self.output_memrefs = [runtime.MemRefType(out) for out in self.output_args]

def _create_shape_memref(self, shape):
shape = make_tuple(shape)
if len(shape) == 0:
return create_memref(
shape=(0,),
dtype=datatype.int64,
device=device("cpu"),
)
return create_memref(
array=convert_list_to_array(shape, datatype.int64),
shape=(len(shape),),
dtype=datatype.int64,
device=device("cpu"),
)

def _get_outputs_shape(self):
outputs_shape = []
all_outputs_known = True
for memref in self.output_memrefs:
outputs_shape.append(memref.shape)
all_outputs_known &= all(dim >= 0 for dim in memref.shape)
return outputs_shape, all_outputs_known

def _get_inputs_runtime_shape(self, inputs):
inputs_shape = []
for input in inputs:
inputs_shape.append(input.trace_tensor.producer.data.shape)
return inputs_shape

def _execute_shape_inference(self, inputs_shape, outputs_shape):
inputs_shape_memref = [self._create_shape_memref(inp_shape) for inp_shape in inputs_shape]
outputs_shape_memref = [self._create_shape_memref(out_shape) for out_shape in outputs_shape]
self.session.execute_function(
name=self.signature.get_shape_func_name(), in_args=inputs_shape_memref, out_args=outputs_shape_memref
)

outputs_runtime_shape = [memoryview(s).tolist() for s in outputs_shape_memref]
return outputs_runtime_shape

def _get_output_tensor_info(self, outputs_runtime_shape, output_devices):
outputs_tensor_info = []
for index in range(self.num_output_args):
memref = self.output_memrefs[index]
dtype = convert_runtime_dtype_to_tripy_dtype(memref.dtype)

output_device = output_devices[index]
if not output_device:
output_device = device(("gpu" if memref.address_space == runtime.PointerType.device else "cpu", 0))

runtime_shape = [rs if dim < 0 else dim for dim, rs in zip(memref.shape, outputs_runtime_shape[index])]
outputs_tensor_info.append(
TensorInfo(
len(runtime_shape),
tuple(runtime_shape),
dtype,
output_device,
)
)
return outputs_tensor_info

def get_output_tensor_runtime_info(self, inputs, output_devices=List[device]):
outputs_shape, all_outputs_known = self._get_outputs_shape()
if not all_outputs_known:
inputs_shape = self._get_inputs_runtime_shape(inputs)
outputs_shape = self._execute_shape_inference(inputs_shape, outputs_shape)
output_tensor_info = self._get_output_tensor_info(outputs_shape, output_devices)
return output_tensor_info

def execute(self, output_devices: List[device], inputs: List["Tensor"] = []) -> List[runtime.MemRefValue]:
def execute(self, inputs: List["Tensor"] = []) -> List[runtime.MemRefValue]:
in_args = []
for inp in inputs:
memref = inp.trace_tensor.producer.data
Expand All @@ -131,45 +56,9 @@ def execute(self, output_devices: List[device], inputs: List["Tensor"] = []) ->
)
in_args.append(memref)

# HACK (#155): Remove `get_devices` once executable output tensor location matches Trace IR.
out_tensor_info = self.get_output_tensor_runtime_info(inputs, output_devices)

# Allocate output memory and store buffer pointers.
outputs = [
create_memref(
shape=info.shape, dtype=info.dtype, device=info.device, stream=self.stream._active_cuda_stream
)
for info in out_tensor_info
]

out_args = []
for out in outputs:
memref = out
# HACK (#155): MLIR-TensorRT requires inputs to be on device.
# Remove explicit copy to device once #155 is addressed.
if memref.address_space != runtime.PointerType.device:
memref = self.runtime_client.copy_to_device(
host_memref=memref,
device=self.runtime_client.get_devices()[0],
stream=self.stream._active_cuda_stream,
)
if not memref:
raise_error("Could not allocate output memref", details=memref.error_details)
out_args.append(memref)

# Execute and populate device pointers.
self.session.execute_function(
"main", in_args=in_args, out_args=out_args, stream=self.stream._active_cuda_stream
outputs = self.session.execute_function(
"main", in_args=in_args, stream=self.stream._active_cuda_stream, client=self.runtime_client
)

# For outputs that were on the host, do the copy back
# TODO(#155): MLIR-TensorRT should allow output tensor placements on host.
for idx, out_info in enumerate(out_tensor_info):
if out_info.device.kind != "gpu":
self.runtime_client.copy_to_host(
device_memref=out_args[idx],
existing_host_memref=outputs[idx],
stream=self.stream._active_cuda_stream,
)

return outputs
9 changes: 9 additions & 0 deletions tripy/tripy/flat_ir/ops/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ class CopyOp(BaseFlatIROp):

target: tripy.common.device

def set_memory_space_attr(self, tensor, mem_space_attr):
current_type = tensor.type
# Set the encoding attribute on the operation's result
new_type = ir.RankedTensorType.get(current_type.shape, current_type.element_type, encoding=mem_space_attr)
tensor.set_type(new_type)

def to_mlir(self, operands):
from mlir_tensorrt.compiler.dialects import bufferization, tensor, arith

Expand All @@ -46,7 +52,10 @@ def to_mlir(self, operands):
sliced_dims.append(dim)

alloc_tensor = bufferization.alloc_tensor(inp_type, sliced_dims, memory_space=mem_space_attr)
self.set_memory_space_attr(alloc_tensor, mem_space_attr)
result_tensor = bufferization.materialize_in_destination(inp_type, operands[0], alloc_tensor)
self.set_memory_space_attr(result_tensor, mem_space_attr)
cast_tensor = tensor.cast(self.outputs[0].to_mlir(), result_tensor)
self.set_memory_space_attr(cast_tensor, mem_space_attr)

return [cast_tensor]
6 changes: 3 additions & 3 deletions tripy/tripy/frontend/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,11 @@ def eval(self) -> runtime.MemRefValue:

compiler = Compiler(trt_builder_opt_level=0)
executable = compiler.compile(mlir, flat_ir=flat_ir)
executor = Executor(executable)
self.executor = Executor(executable)
# Upon computing the value of this tensor, we switch it to have a `Storage`
# parameter so that it does not need to be computed again.
data = executor.execute([out.device for out in flat_ir.outputs])
executor.stream.synchronize()
data = self.executor.execute()
self.executor.stream.synchronize()
assert len(data) == 1, "Expects only one output from mlir_tensorrt.compiler executor"
data = data[0]

Expand Down

0 comments on commit ceeeaf6

Please sign in to comment.