From 9760adc1578473507cdd91e095ed5e6f76b11b51 Mon Sep 17 00:00:00 2001 From: Jhalak Patel Date: Mon, 30 Sep 2024 09:56:10 -0700 Subject: [PATCH] Add initial IR for alloc enqueue --- .../mlir-tensorrt/Conversion/Passes.td | 1 + .../lib/Compiler/StableHloToExecutable.cpp | 10 +- .../include/mlir-executor-c/Runtime/Runtime.h | 16 ++ .../Runtime/Backend/Lua/LuaRuntime.h | 7 + .../executor/lib/CAPI/Runtime/Runtime.cpp | 59 +++++- .../lib/Runtime/Backend/Lua/LuaRuntime.cpp | 168 +++++++++++++++++- .../Lua/Modules/TensorRT/TensorRTModule.cpp | 82 ++++++++- .../python/bindings/Runtime/RuntimePyBind.cpp | 54 +++++- .../ClusteringDynamicShape/alloc_enqueue.mlir | 13 ++ .../test_stablehlo_alloc_enqueue.py | 124 +++++++++++++ 10 files changed, 521 insertions(+), 13 deletions(-) create mode 100644 mlir-tensorrt/test/Target/Lua/IntegrationTests/ClusteringDynamicShape/alloc_enqueue.mlir create mode 100644 mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_alloc_enqueue.py diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/Passes.td b/mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/Passes.td index 1eb2f538a..f56fb2f61 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/Passes.td +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/Passes.td @@ -192,6 +192,7 @@ def ConvertCUDAToExecutorPass : Pass<"convert-cuda-to-executor", //===----------------------------------------------------------------------===// // ConvertTensorRTRuntimeToExecutorPass //===----------------------------------------------------------------------===// +// TODO: Modify this pass to generate non-DPS style enqueue functions. 
def ConvertTensorRTRuntimeToExecutorPass : Pass<"convert-tensorrt-runtime-to-executor", "::mlir::ModuleOp"> { let summary = "Converts TensorRTRuntime dialect ops to executor dialect operations"; diff --git a/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp b/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp index 1ed9d8bdd..ce68e4a3d 100644 --- a/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp +++ b/mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp @@ -487,11 +487,13 @@ StableHloToExecutableTask::compileStableHLOToExecutable( runner = pm.get(); } + runner->printAsTextualPipeline(llvm::dbgs()); + // Setup pass manager - if (failed(runner->run(module))) - return getInternalErrorStatus( - "failed to run compilation on module with symbol name: {0}", - module.getName() ? *module.getName() : "no-symbol-name"); + // if (failed(runner->run(module))) + // return getInternalErrorStatus( + // "failed to run compilation on module with symbol name: {0}", + // module.getName() ? *module.getName() : "no-symbol-name"); // Translate to Runtime Executable FailureOr> exeStorage = diff --git a/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h b/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h index 345412aee..a72f1c3b0 100644 --- a/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h +++ b/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h @@ -289,6 +289,12 @@ static inline bool mtrtRuntimeValueIsNull(MTRT_RuntimeValue value) { return !value.ptr; } +// Returns whether the RuntimeValue is MemRef. +MLIR_CAPI_EXPORTED bool mtrtRuntimeValueIsMemRef(MTRT_RuntimeValue value); + +// Returns whether the RuntimeValue is Scalar. +MLIR_CAPI_EXPORTED bool mtrtRuntimeValueIsScalar(MTRT_RuntimeValue value); + /// Cast a MTRT_MemRefValue to a generic MTRT_RuntimeValue. 
MLIR_CAPI_EXPORTED MTRT_RuntimeValue mtrtMemRefCastToRuntimeValue(MTRT_MemRefValue memref); @@ -383,6 +389,16 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionExecuteFunction( const MTRT_RuntimeValue *inArgs, size_t numInArgs, const MTRT_RuntimeValue *outArgs, size_t numOutArgs, MTRT_Stream stream); +/// Variant of the above function that returns results. +MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionExecuteFunctionWithResult( + MTRT_RuntimeSession session, MTRT_RuntimeClient client, + MTRT_StringView name, const MTRT_RuntimeValue *inArgs, size_t numInArgs, + MTRT_RuntimeValue *resultArgs, size_t numResultArgs, MTRT_Stream stream); + +MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionGetNbResults(MTRT_RuntimeSession session, + MTRT_StringView name, + int64_t *numResults); + //===----------------------------------------------------------------------===// // DLPack //===----------------------------------------------------------------------===// diff --git a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h index c953172b9..86dc3112e 100644 --- a/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h +++ b/mlir-tensorrt/executor/include/mlir-executor/Runtime/Backend/Lua/LuaRuntime.h @@ -100,6 +100,13 @@ executeFunctionWithLuaBackend(LuaRuntimeSession &session, std::string_view name, llvm::ArrayRef outputArgs, std::optional stream = {}); +/// Execute a named function in the session with the specified input args and return results. 
+StatusOr>> +executeFunctionWithResultWithLuaBackend( + LuaRuntimeSession &session, RuntimeClient &client, std::string_view name, + llvm::ArrayRef inputArgs, + std::optional stream = {}); + } // namespace mlirtrt::runtime #endif // MLIR_TENSORRT_RUNTIME_BACKEND_LUA_LUARUNTIME_H diff --git a/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp b/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp index 8b4e208e8..27bbbcb15 100644 --- a/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp +++ b/mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp @@ -641,6 +641,16 @@ MTRT_ScalarValue mtrtRuntimeValueDynCastToScalar(MTRT_RuntimeValue v) { return wrap(static_cast(x)); } +bool mtrtRuntimeValueIsMemRef(MTRT_RuntimeValue value) { + RuntimeValue *x = unwrap(value); + return x->getKind() == RuntimeValue::Kind::MemRef; +} + +bool mtrtRuntimeValueIsScalar(MTRT_RuntimeValue value) { + RuntimeValue *x = unwrap(value); + return x->getKind() == RuntimeValue::Kind::Scalar; +} + //===----------------------------------------------------------------------===// // MTRT_RuntimeSessionOptions //===----------------------------------------------------------------------===// @@ -697,7 +707,6 @@ MTRT_Status mtrtRuntimeSessionExecuteFunction( llvm::SmallVector outArgValues = llvm::map_to_vector(llvm::ArrayRef(outArgs, numOutArgs), [](MTRT_RuntimeValue arg) { return unwrap(arg); }); - StatusOr>> result = executeFunctionWithLuaBackend( *cppSession, std::string_view(name.data, name.length), inArgValues, @@ -705,11 +714,57 @@ MTRT_Status mtrtRuntimeSessionExecuteFunction( !mtrtStreamIsNull(stream) ? 
std::optional(unwrap(stream)->getRawStream()) : std::nullopt); - if (!result.isOk()) + if (!result.isOk()) { return wrap(result.getStatus()); + } + return mtrtStatusGetOk(); +} + +MTRT_Status mtrtRuntimeSessionExecuteFunctionWithResult( + MTRT_RuntimeSession session, MTRT_RuntimeClient client, + MTRT_StringView name, const MTRT_RuntimeValue *inArgs, size_t numInArgs, + MTRT_RuntimeValue *resultArgs, size_t numResultArgs, + MTRT_Stream stream) { + LuaRuntimeSession *cppSession = + static_cast(unwrap(session)); + + RuntimeClient *cppClient = unwrap(client); + + llvm::SmallVector inArgValues = + llvm::map_to_vector(llvm::ArrayRef(inArgs, numInArgs), + [](MTRT_RuntimeValue arg) { return unwrap(arg); }); + StatusOr>> results = + executeFunctionWithResultWithLuaBackend( + *cppSession, *cppClient, std::string_view(name.data, name.length), + inArgValues, + !mtrtStreamIsNull(stream) + ? std::optional(unwrap(stream)->getRawStream()) + : std::nullopt); + if (!results.isOk()) { + return wrap(results.getStatus()); + } + + assert(results->size() == numResultArgs); + + for (size_t i = 0; i < numResultArgs; ++i) { + resultArgs[i] = wrap((*results)[i].release()); + } return mtrtStatusGetOk(); } + +MTRT_Status mtrtRuntimeSessionGetNbResults(MTRT_RuntimeSession session, + MTRT_StringView name, + int64_t *numResults) { + LuaRuntimeSession *cppSession = + static_cast(unwrap(session)); + *numResults = cppSession->getExecutable() + .getFunction(std::string_view(name.data, name.length)) + .getSignature() + .getNumResults(); + return mtrtStatusGetOk(); +} + //===----------------------------------------------------------------------===// // MTRT_RuntimeClient //===----------------------------------------------------------------------===// diff --git a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp index 33c15b35c..6bed36c42 100644 --- a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp +++ 
b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/LuaRuntime.cpp @@ -423,6 +423,83 @@ static Status pushScalarArgument(sol::state_view &lua, return getOkStatus(); } +// Function to extract shape and stride from sol::object table +std::tuple, std::vector> +extractShapeAndStride(const sol::table &table) { + size_t tableSize = table.size(); + assert(tableSize >= 3 && + "Table does not contain shape and stride information"); + size_t shapeStrideSize = (tableSize - 3) / 2; + std::vector shape; + std::vector stride; + + shape.reserve(shapeStrideSize); + stride.reserve(shapeStrideSize); + + // Extract shape + for (size_t i = 4; i <= 3 + shapeStrideSize; ++i) { + shape.push_back(table[i].get()); + } + + // Extract stride + for (size_t i = 4 + shapeStrideSize; i <= tableSize; ++i) { + stride.push_back(table[i].get()); + } + + return std::make_tuple(shape, stride); +} + +// Convert sol::object to MemRefValue +StatusOr> +solObjectToMemRefValue(RuntimeClient *client, const sol::object &obj) { + assert(obj.is() && "Expected a table for MemRefValue"); + + sol::table memrefTable = obj.as(); + uintptr_t ptr = memrefTable[1].get(); + int64_t offset = memrefTable[3].get(); + + auto [shape, strides] = extractShapeAndStride(memrefTable); + + // TODO: How to extract this information. Should we use function signature to fill in this information later? + mlirtrt::runtime::PointerType addressSpace = + mlirtrt::runtime::PointerType::device; + int64_t bitsPerElement = 32; + std::optional device = + std::nullopt; + std::optional scalarType = ScalarTypeCode::f32; + + return MemRefValue::create(client, addressSpace, bitsPerElement, ptr, offset, + llvm::ArrayRef(shape), + llvm::ArrayRef(strides), device, + scalarType); +} + +// Convert sol::object to ScalarValue +std::unique_ptr solObjectToScalarValue(const sol::object &obj) { + + // TODO: ScalarType is not known. Should we use function signature to fill in + // this information later? Since ScalarValue data type is int64_t. 
Let's cast + // the object value to int64_t for now. + return std::make_unique(obj.as(), ScalarTypeCode::unknown); +} + +// Convert sol::object to RuntimeValue's +llvm::SmallVector> +solObjectToRuntimeValues(RuntimeClient *client, + std::vector const &results) { + llvm::SmallVector> values; + for (sol::object r : results) { + // if (r.is()) { + // Assume it's a MemRefValue if it's a table + values.emplace_back(std::move(*solObjectToMemRefValue(client, r))); + // } else { + // // Assume it's a ScalarValue for all other cases + // values.emplace_back(solObjectToScalarValue(r)); + // } + } + return values; +} + static Status validateArgsTypesAgainstFuncArgs(const RuntimeValue *runArg, const TypeUnionView &sigArg) { if (sigArg.isa()) { @@ -520,11 +597,11 @@ runtime::executeFunctionWithLuaBackend( return getStatusWithMsg(StatusCode::InternalError, "no function named \"", std::string(name), "\" found"); - if (sig.getNumResults() > 0) - return getInvalidArgStatus("functions with {0} results are not supported", - sig.getNumResults()); - // Validate the number of arguments against the signature. + if (sig.getNumResults() != 0) + return getInvalidArgStatus( + "function expects 0 result args but received {0}", + sig.getNumResults()); if (sig.getNumOutputArgs() != outputArgs.size()) return getInvalidArgStatus( "function expects {0} output args (destination args) but received {1}", @@ -600,3 +677,86 @@ runtime::executeFunctionWithLuaBackend( return llvm::SmallVector>{}; } + +StatusOr>> +runtime::executeFunctionWithResultWithLuaBackend( + LuaRuntimeSession &session, + RuntimeClient &client, + std::string_view name, + llvm::ArrayRef inputArgs, + std::optional stream) { + + FunctionView meta = session.getExecutable().getFunction(name); + FunctionSignatureView sig = meta.getSignature(); + + // Call the main function, if present. 
+ sol::state &lua = session.getLuaState(); + AllocTracker &tracker = session.getAllocTracker(); + sol::protected_function funcObj = lua[name]; + if (funcObj.get_type() != sol::type::function) + return getStatusWithMsg(StatusCode::InternalError, "no function named \"", + std::string(name), "\" found"); + + // Validate the number of arguments against the signature. + if (sig.getNumOutputArgs() != 0) + return getInvalidArgStatus( + "function expects 0 output args (destination args) but received {0}", + sig.getNumOutputArgs()); + if (sig.getNumInputArgs() != inputArgs.size()) + return getInvalidArgStatus("function expects {0} input args " + "(non-destination args) but received {1}", + sig.getNumInputArgs(), inputArgs.size()); + + // Validate the inferred Lua function type here against the signature. + for (unsigned i = 0; i < inputArgs.size(); ++i) { + auto status = validateArgsTypesAgainstFuncArgs(inputArgs[i], sig.getArg(i)); + if (!status.isOk()) + return getInvalidArgStatus( + "Input argument {0} validation failed against " + "corresponding function signature arg {0}. Reason: {1}", + i, status.getString()); + } + + // Create the arguments. + llvm::SmallVector args; + args.reserve(inputArgs.size()); + for (auto [idx, rv] : llvm::enumerate(inputArgs)) { + if (MemRefValue *memref = llvm::dyn_cast(rv)) { + MTRT_RETURN_IF_ERROR(pushMemRefTableArg(lua, tracker, args, *memref)); + continue; + } + if (ScalarValue *scalar = llvm::dyn_cast(rv)) { + MTRT_RETURN_IF_ERROR(pushScalarArgument(lua, args, *scalar)); + continue; + } + return getInvalidArgStatus( + "input argument #{0} to function {1} has an unsupported type; " + "arguments must be either MemRefs or scalars", + idx + 1, name); + } + if (stream) + RETURN_STATUS_IF_ERROR(session.setCudaStream(*stream)); + + // If the number of arguments exceed a particular threshold, then + // we pass arguments packed into a table, otherwise we pass as arguments. 
+ sol::protected_function_result result = + sig.getCConv() == CallingConvention::unpacked + ? funcObj(sol::as_args(args)) + : funcObj(args); + + if (!result.valid()) { + sol::error err(result); + return getStatusWithMsg(StatusCode::InternalError, + "failed to run function \"", std::string(name), + "\": ", err.what()); + } + + int returnCount = result.return_count(); + std::vector results; + // Lua index start from 1 + for (int i = 1; i <= returnCount; ++i) { + results.push_back(result[i]); + } + + return solObjectToRuntimeValues(&client, results); +} diff --git a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp index db196043b..5e6eb1827 100644 --- a/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp +++ b/mlir-tensorrt/executor/lib/Runtime/Backend/Lua/Modules/TensorRT/TensorRTModule.cpp @@ -257,7 +257,8 @@ prepareBuffers(const AllocTracker &allocTracker, unsigned argumentBuffersIdx = 1; // The number of arguments should be equal to the number of results plus the // number of arguments of the TensorRT engine's functional signature. 
- const unsigned numOperands = sig.numResults + sig.numArguments; + const unsigned numOperands = sig.numArguments; + result.reserve(va.size() / 3); std::vector &hostBuffers = context.getHostIOBuffers(); unsigned hostBufferIdx = 0; @@ -279,6 +280,10 @@ prepareBuffers(const AllocTracker &allocTracker, for (int64_t dimIdx = 0; dimIdx < rank; dimIdx++) dims.d[dimIdx] = va.get(argumentBuffersIdx++); + // Increment rank times to account for strides: This is a hack + for (int64_t dimIdx = 0; dimIdx < rank; dimIdx++) + argumentBuffersIdx++; + uintptr_t pointer = buffer.ptr + offset; MTRT_DBGF("enqueue arg %u ptr=0x%lx offset=%ld", i, buffer.ptr, offset); @@ -339,10 +344,48 @@ prepareBuffers(const AllocTracker &allocTracker, return result; } +class OutputAllocator : public nvinfer1::IOutputAllocator { +public: + OutputAllocator() = default; + ~OutputAllocator() = default; + + void *reallocateOutput(char const *tensorName, void *currentMemory, + uint64_t size, uint64_t alignment) noexcept override { + // Some memory allocators return nullptr when allocating zero bytes, but + // TensorRT requires a non-null ptr even for empty tensors, so allocate a + // dummy byte. + // Fall-back to local memory management. + void* buffer{nullptr}; + cudaMalloc(&buffer, size); + return buffer; + } + + //! IMirroredBuffer does not implement Async allocation, hence this is just a + //! 
wrap around + void *reallocateOutputAsync(char const *tensorName, void *currentMemory, + uint64_t size, uint64_t alignment, + cudaStream_t /*stream*/) noexcept override { + return reallocateOutput(tensorName, currentMemory, size, alignment); + } + + void notifyShape(char const *tensorName, + nvinfer1::Dims const &dims) noexcept override { + mFinalDims = dims; + } + + nvinfer1::Dims getFinalDims() { return mFinalDims; } + + void* getMemory() { return memory.get(); } + +private: + std::unique_ptr memory; + nvinfer1::Dims mFinalDims; +}; + static Status enqueueV3Wrapper(AllocTracker &tracker, ResourceTracker &resourceTracker, NvInferExecContextWrapper &context, - CudaStreamPtr stream, sol::table &va) { + CudaStreamPtr stream, sol::table &va, std::optional outputDescriptors = 0) { StatusOr>> buffers = prepareBuffers(tracker, context, stream, va); if (!buffers.isOk()) @@ -360,6 +403,13 @@ static Status enqueueV3Wrapper(AllocTracker &tracker, return getStatusWithMsg(StatusCode::InternalError, "failed to set input-consumed event"); + std::unique_ptr allocator(new OutputAllocator()); + if (outputDescriptors) { + // Register an output allocator. `enqueueV3` callback should set output + // pointer, and notify shapes. + context->setOutputAllocator("result0", allocator.get()); + } + if (!context->enqueueV3(stream)) return getStatusWithMsg(StatusCode::InternalError, "failed to enqueue engine execution on stream"); @@ -369,6 +419,18 @@ static Status enqueueV3Wrapper(AllocTracker &tracker, MTRT_DBGF("%s", "enqueueV3 successful and inputs are consumed"); + int64_t* desc = reinterpret_cast(*outputDescriptors); + + if (outputDescriptors) { + // Store following: number of results, [rank, ptr, [shape, ...], [stride, ...]]... 
+ // For now assume only one result and just copy input pointer to output + desc[0] = 1; // Number of result is 1 + desc[1] = std::get<2>((*buffers)[0]).nbDims; // Copy input rank + desc[2] = std::get<1>((*buffers)[0]); // Copy input pointer + desc[3] = std::get<2>((*buffers)[0]).d[0]; // Copy input shape + desc[4] = 1; // Use stride 1 + } + return getOkStatus(); } @@ -429,4 +491,20 @@ void mlirtrt::runtime::registerExecutorTensorRTModuleLuaRuntimeMethods( *context, stream, va); SET_LUA_ERROR_IF_ERROR(result, state); }; + + lua["_trtrt_alloc_enqueue"] = + [allocTracker, + resourceTracker](sol::this_state state, + std::shared_ptr context, + CudaStreamPtr stream, uintptr_t outputDescriptors, sol::table va) { + ADD_TENSORRT_MODULE_RANGE("trtrt_alloc_enqueue"); + sol::state_view luaState(state); + assert(context != nullptr); + assert(outputDescriptors != 0); + assert(stream != nullptr && "expected valid stream"); + Status result = enqueueV3Wrapper(*allocTracker, *resourceTracker, + *context, stream, va, outputDescriptors); + SET_LUA_ERROR_IF_ERROR(result, state); + }; + } diff --git a/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp b/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp index 200a7ebda..82655c059 100644 --- a/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp +++ b/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp @@ -567,6 +567,15 @@ static MTRT_RuntimeValue convertArgType(py::object obj) { throw std::runtime_error("argument must be MemRef or scalar"); } +/// Convert Runtime value to PyMemRefValue or PyScalarValue object. +static py::object convertGenericArgToPyObject(MTRT_RuntimeValue value) { + if (mtrtRuntimeValueIsMemRef(value)) + return py::cast(mtrtRuntimeValueDynCastToMemRef(value)); + if (mtrtRuntimeValueIsScalar(value)) + return py::cast(mtrtRuntimeValueDynCastToScalar(value)); + return py::none(); +} + //===----------------------------------------------------------------------===// // Declare the bindings. 
//===----------------------------------------------------------------------===// @@ -927,5 +936,48 @@ PYBIND11_MODULE(_api, m) { THROW_IF_MTRT_ERROR(s); }, py::arg("name"), py::arg("in_args"), py::arg("out_args"), - py::arg("stream") = py::none()); + py::arg("stream") = py::none()) + .def( + "execute_function_with_result", + [](PyRuntimeSession &self, PyRuntimeClient &client, std::string name, + std::vector inArgs, + std::optional stream) -> py::object { + MTRT_StringView nameRef{name.data(), name.size()}; + + auto inArgsGeneric = llvm::map_to_vector(inArgs, convertArgType); + + // Query the function metadata to get the number of results + int64_t numResults; + MTRT_Status s = mtrtRuntimeSessionGetNbResults(self, nameRef, &numResults); + THROW_IF_MTRT_ERROR(s); + + // Prepare a vector to hold the results + std::vector resultArgs(numResults); + + // Execute the function, letting it populate resultArgs + s = mtrtRuntimeSessionExecuteFunctionWithResult( + self, client, nameRef, inArgsGeneric.data(), + inArgsGeneric.size(), resultArgs.data(), resultArgs.size(), + stream ? 
*stream : mtrtStreamGetNull()); + THROW_IF_MTRT_ERROR(s); + + // Convert the output arguments to Python objects + std::vector outArgs; + for (const auto& arg : resultArgs) { + outArgs.push_back(convertGenericArgToPyObject(arg)); + } + + // Process the results + if (outArgs.empty()) { + return py::none(); + } else if (outArgs.size() == 1) { + return outArgs[0]; + } else { + return outArgs[0]; + } + }, + py::arg("client"), py::arg("name"), py::arg("in_args"), py::arg("stream") = py::none(), + "Execute a function and return the result as a Python object"); + + } diff --git a/mlir-tensorrt/test/Target/Lua/IntegrationTests/ClusteringDynamicShape/alloc_enqueue.mlir b/mlir-tensorrt/test/Target/Lua/IntegrationTests/ClusteringDynamicShape/alloc_enqueue.mlir new file mode 100644 index 000000000..b8f19afa6 --- /dev/null +++ b/mlir-tensorrt/test/Target/Lua/IntegrationTests/ClusteringDynamicShape/alloc_enqueue.mlir @@ -0,0 +1,13 @@ +// RUN: mlir-tensorrt-opt %s \ +// RUN: -pass-pipeline="builtin.module(inline{default-pipeline=canonicalize inlining-threshold=4294967295 max-iterations=4 },stablehlo-ext-lower-special-custom-calls,stablehlo-ext-expand-tuples{entry-function-name=main},stablehlo-ext-canonicalize-shapes{max-iterations=4},stablehlo-raise-qdq,stablehlo-ext-constant-folding,stablehlo-gather-to-slice,stablehlo-ext-canonicalize-shapes{max-iterations=4},stablehlo-canonicalize-dot-general,stablehlo-ext-constant-folding,stablehlo-ext-canonicalize-shapes{max-iterations=4},func.func(stablehlo-ext-canonicalize-scatter),func.func(stablehlo-ext-canonicalize-gather),stablehlo-ext-constant-folding,stablehlo-ext-canonicalize-shapes{max-iterations=4},cse,func.func(tensorrt-stablehlo-input-preprocessing),cse,stablehlo-ext-constant-folding,stablehlo-ext-canonicalize-shapes{max-iterations=4},canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false 
top-down=true},convert-stablehlo-to-scf,func.func(tensorrt-infer-plugin-shapes),func.func(plan-materialize-shape-calculations),func.func(canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}),plan-refine-types,plan-create-shape-funcs,func.func(plan-populate-func-bounds-attrs),stablehlo-clustering{disable-create-shape-func-pass=false disallow-host-tensors-in-tensorrt-clusters=false entrypoint=main},plan-create-closed-regions{test-pre-walk-order=false},plan-outline-clusters,func-ext-duplicate-function-elimination,plan-eliminate-shape-ops,func.func(stablehlo-to-arith-pipeline),func.func(stablehlo-to-std),tensorrt.module(convert-stablehlo-to-tensorrt{allow-i64-to-i32-conversion=false convert-conditionals=true convert-loops=false}),convert-tensorrt-to-runtime,func.func(post-clustering-validation),canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true},inline{default-pipeline=canonicalize inlining-threshold=4294967295 max-iterations=4 },func.func(cse),func.func(canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}),func.func(scf-detensorize-loops),func.func(canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}),tensorrt.module(tensorrt-broadcast-elimination,canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true},cse,tensorrt-transpose-elimination,canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true},cse,tensorrt-reshape-elimination,canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true},cse,tensorrt-raise-normalizations,canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false 
top-down=true},cse,tensorrt-apply-bug-wars{tensorrt-strongly-typed=false tensorrt-version=8.5},canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true},cse,tensorrt-expand-ops,canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true},cse,tensorrt-expand-ops,func.func(tensorrt-legalize-int8),translate-tensorrt-to-engine),memref-cast-elimination,plan-alloc-tensors,plan-bufferize,memref-cast-elimination,canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true},drop-equivalent-buffer-results,func.func(buffer-loop-hoisting),func.func(buffer-hoisting),expand-realloc{emit-deallocs=false},canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true},ownership-based-buffer-deallocation{private-function-dynamic-ownership=false},canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true},buffer-deallocation-simplification,bufferization-lower-deallocations,cse,canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true},func.func(canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}),convert-memref-to-cuda,convert-plan-to-executor,executor-allocs-to-globals,func.func(executor-populate-func-metadata),convert-tensorrt-runtime-to-executor{index-bitwidth=64 use-packed-memref-cconv=true},convert-cuda-to-executor{index-bitwidth=64 use-packed-memref-cconv=true},drop-nested-modules,convert-scf-to-cf,fold-memref-alias-ops,memref-expand,expand-strided-metadata,cse,canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true},cse,canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false 
top-down=true},convert-memref-to-executor{allow-unchecked-memref-cast-conversion=true index-bitwidth=64 use-packed-memref-cconv=true},cse,canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true},convert-std-to-executor{index-bitwidth=64 use-packed-memref-cconv=true},cse,canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true},executor-lower-globals,convert-executor-to-executor{index-bitwidth=64 use-packed-memref-cconv=true},executor-decompose-aggregate-loads-and-stores{target=lua},executor-expand-ops{lower-alloca=true lower-getoffset=true},cse,canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true},executor-lower-to-runtime-builtins{index-bitwidth=64 target=lua use-packed-memref-cconv=true},executor-pack-arguments{max-arguments=100})" \ +// RUN: | mlir-tensorrt-translate -mlir-to-runtime-executable -allow-unregistered-dialect | mlir-tensorrt-runner -input-type=rtexe + +func.func public @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { + // Add operation + %result = stablehlo.add %arg0, %arg1 : tensor<1xf32> + return %result: tensor<1xf32> +} + +// CHECK: result[0] = 2.000 + +// CHECK-NOT: result \ No newline at end of file diff --git a/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_alloc_enqueue.py b/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_alloc_enqueue.py new file mode 100644 index 000000000..c8f680738 --- /dev/null +++ b/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_alloc_enqueue.py @@ -0,0 +1,124 @@ +# RUN: %PYTHON %s +import time + +import mlir_tensorrt.compiler.api as compiler +import mlir_tensorrt.compiler.ir as ir +import mlir_tensorrt.runtime.api as runtime +import numpy as np + +ASM = """ +func.func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { + %1 = stablehlo.add %arg0, %arg1 : (tensor<1xf32>, 
tensor<1xf32>) -> tensor<1xf32> + func.return %1 : tensor<1xf32> +} +""" + +EXECUTOR = """ +module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry, #dlti.dl_entry, 64 : i64>, #dlti.dl_entry, 64 : i64>>, executor.global_init_func = @executor_init_globals, executor.process_grid_shape = array} { + executor.func private @_dealloc(...) + executor.func private @_inttoptr_i64_i64(i64) -> !executor.ptr + executor.func private @_load_i64(...) -> i64 + executor.func private @_store_i64(...) + executor.func private @executor_alloc(i64, i64) -> !executor.ptr + executor.func private @__cuda_stream_create() -> !executor.ptr + executor.global @stream0 constant : !executor.ptr + executor.func private @_trtrt_alloc_enqueue(!executor.opaque<"trtrt_context">, !executor.ptr, !executor.ptr, ...) + executor.func private @_trtrt_create_runtime() -> !executor.opaque<"trtrt_runtime"> + executor.func private @_trtrt_create_context(!executor.opaque<"trtrt_engine">) -> !executor.opaque<"trtrt_context"> + executor.func private @_trtrt_load(!executor.opaque<"trtrt_runtime">, !executor.ptr, i64) -> !executor.opaque<"trtrt_engine"> + executor.global @tensorrt_runtime : !executor.opaque<"trtrt_runtime"> + executor.constant_resource @tensorrt_cluster_engine_data dense<"tensor<16852xi8> + executor.global @tensorrt_cluster_exec_ctx constant : !executor.opaque<"trtrt_context"> + func.func @main(%arg0: !executor.table, !executor.ptr, i64, i64, i64>, %arg1: !executor.table, !executor.ptr, i64, i64, i64>) -> !executor.table, !executor.ptr, i64, i64, i64> attributes {executor.function_metadata = #executor.func_meta<[memref<1xf32, #executor.memory_type>, memref<1xf32, #executor.memory_type>], [memref<1xf32, #executor.memory_type>], num_output_args = 0>} { + %c0_i64 = executor.constant 0 : i64 + %c1_i64 = executor.constant 1 : i64 + %c2_i64 = executor.constant 2 : i64 + %c3_i64 = executor.constant 3 : i64 + %c4_i64 = executor.constant 4 : i64 + %c40_i64 = executor.constant 40 : i64 + %0 = 
executor.get_global @tensorrt_cluster_exec_ctx : !executor.opaque<"trtrt_context"> + %1 = executor.get_global @stream0 : !executor.ptr + %2 = executor.table.get %arg0[1] : , !executor.ptr, i64, i64, i64> + %3 = executor.table.get %arg1[1] : , !executor.ptr, i64, i64, i64> + %4 = executor.call @executor_alloc(%c40_i64, %c4_i64) : (i64, i64) -> !executor.ptr + executor.call @_store_i64(%4, %c1_i64, %c1_i64) : (!executor.ptr, i64, i64) -> () + %5 = executor.table.create(%2, %c0_i64, %c1_i64, %c1_i64, %c1_i64, %3, %c0_i64, %c1_i64, %c1_i64, %c1_i64 : !executor.ptr, i64, i64, i64, i64, !executor.ptr, i64, i64, i64, i64) : !executor.table, i64, i64, i64, i64, !executor.ptr, i64, i64, i64, i64> + executor.call @_trtrt_alloc_enqueue(%0, %1, %4, %5) : (!executor.opaque<"trtrt_context">, !executor.ptr, !executor.ptr, !executor.table, i64, i64, i64, i64, !executor.ptr, i64, i64, i64, i64>) -> () + %7 = executor.call @_load_i64(%4, %c2_i64) : (!executor.ptr, i64) -> i64 + %8 = executor.call @_inttoptr_i64_i64(%7) : (i64) -> !executor.ptr + %9 = executor.call @_load_i64(%4, %c3_i64) : (!executor.ptr, i64) -> i64 + %10 = executor.call @_load_i64(%4, %c4_i64) : (!executor.ptr, i64) -> i64 + %11 = executor.table.create(%8, %8, %c0_i64, %9, %10 : !executor.ptr, !executor.ptr, i64, i64, i64) : , !executor.ptr, i64, i64, i64> + executor.call @_dealloc(%4) : (!executor.ptr) -> () + return %11 : !executor.table, !executor.ptr, i64, i64, i64> + } + func.func private @executor_init_globals() { + %c16852_i64 = executor.constant 16852 : i64 + %0 = executor.call @__cuda_stream_create() : () -> !executor.ptr + executor.set_global %0, @stream0 : !executor.ptr + %1 = executor.call @_trtrt_create_runtime() : () -> !executor.opaque<"trtrt_runtime"> + executor.set_global %1, @tensorrt_runtime : !executor.opaque<"trtrt_runtime"> + %2 = executor.load_constant_resource @tensorrt_cluster_engine_data : !executor.ptr + %3 = executor.get_global @tensorrt_runtime : !executor.opaque<"trtrt_runtime"> + %4 
= executor.call @_trtrt_load(%3, %2, %c16852_i64) : (!executor.opaque<"trtrt_runtime">, !executor.ptr, i64) -> !executor.opaque<"trtrt_engine"> + %5 = executor.call @_trtrt_create_context(%4) : (!executor.opaque<"trtrt_engine">) -> !executor.opaque<"trtrt_context"> + executor.set_global %5, @tensorrt_cluster_exec_ctx : !executor.opaque<"trtrt_context"> + return + } +} +""" + + +def stablehlo_add(): + # Build/parse the main function. + with ir.Context() as context: + m = ir.Module.parse(EXECUTOR) + + # Use the compiler API to compile to executable. + client = compiler.CompilerClient(context) + opts = compiler.StableHLOToExecutableOptions( + client, + ["--tensorrt-builder-opt-level=3", "--tensorrt-strongly-typed=false"], + ) + opts.set_debug_options(False, [], "alloc_enqueue") + exe = compiler.compiler_stablehlo_to_executable(client, m.operation, opts) + + # The RuntimeClient can and should persist across multiple Executables, RuntimeSessions, etc. + # It is primarily an interface for creating and manipulating buffers. + client = runtime.RuntimeClient() + stream = client.create_stream() + devices = client.get_devices() + + import pdb + + pdb.set_trace() + + if len(devices) == 0: + return + + session_options = runtime.RuntimeSessionOptions(num_devices=1, device_id=0) + session = runtime.RuntimeSession(session_options, exe) + + arg0 = client.create_memref( + np.array([1.0], dtype=np.float32).data, + device=devices[0], + stream=stream, + ) + arg1 = client.create_memref( + np.array([2.0], dtype=np.float32).data, + device=devices[0], + stream=stream, + ) + + result = session.execute_function_with_result( + client, "main", in_args=[arg0, arg1], stream=stream + ) + + data = np.asarray(client.copy_to_host(result, stream=stream)) + stream.sync() + + print(data) + + +if __name__ == "__main__": + stablehlo_add()