WIP: Adding TRT options/task #435
base: main
Conversation
```cpp
// TODO (pranavm): Figure out a better way to reuse TRT translation options -
// maybe move to options providers?
struct TensorRTOptions
    : public mlirtrt::compiler::OptionsProvider<TensorRTOptions> {
  mlir::tensorrt::TensorRTTranslationOptions options;

  void addToOptions(mlir::OptionsContext &context) {
    options.addToOptions(context);
  }
};
```
I can move `TensorRTTranslationOptions` to make it an options provider if that makes sense to do.
```cpp
TensorRTToExecutableOptions::TensorRTToExecutableOptions(
    TaskExtensionRegistry extensions) {
  // TODO (pranavm): Do we need to support extensions?
}
```
Not sure if we want to require extensions for all options types or if we need to handle both cases in the options registry. If it's the former, then I can just assert that the extensions are empty here (or maybe even just add support?). If it's the latter, we could have a `setExtensions` method so extensions become optional instead of being part of the constructor.
Force-pushed 137caf5 to 08f90f3
Force-pushed 08f90f3 to 88d48b9
Force-pushed 0c2e89c to 250f6f4
Fix TensorRTOptions registration
```cmake
LINK_LIBS PUBLIC
  MLIRIR
  MLIRTensorRTRegistration
```
This catch-all library has been removed in a commit that hasn't been synced up yet. Just remove it and add the dependent libraries explicitly.
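For illustration, an explicit dependency list might look like the fragment below. The library names other than `MLIRIR` are illustrative upstream-MLIR targets; the actual list depends on what this target uses:

```cmake
LINK_LIBS PUBLIC
  MLIRIR
  MLIRPass
  MLIRTransforms
```

Listing dependencies explicitly keeps link-time coupling visible and avoids pulling in everything the catch-all registration library transitively dragged along.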
```cpp
buildPostClusteringPipeline(pm, options);

mlir::executor::ConvertStdToExecutorPassOptions stdToExecOpts;
```
You will want to add a `host-target` flag which can take the values `executor`, `llvm`, or `emitc`. This executor lowering pass is just for the `executor` option. The pipelines for the other branches will be synced up Friday.
```cpp
SmallVector<Value> inputs =
    makeRegionIsolatedFromAbove(rewriter, inlineGroupOp.getRegion());

tensorrt::TensorRTModuleOp trtModule =
    getOrCreateTensorRTModuleOp(inlineGroupOp);
```
@christopherbate I had to move `getOrCreateTensorRTModuleOp` here to make this pass work; otherwise the `SymbolTable(module)` call in `createOutlinedFunc` hits a segfault. Could you check if this is correct?
An example run, input:
```mlir
func.func @trt_gather_default1(%arg0: tensor<10x20x30xf32>, %arg1: tensor<2x5xi32>,
    %arg2: tensor<10x2x5x30xf32>) -> tensor<10x2x5x30xf32> {
  %0 = tensorrt.gather {
    axis = 1 : i64
  } ins(%arg0, %arg1 : tensor<10x20x30xf32>, tensor<2x5xi32>) -> tensor<10x2x5x30xf32>
  %1 = tensorrt.element_wise <kSUM>(%0, %arg2 : tensor<10x2x5x30xf32>, tensor<10x2x5x30xf32>) -> tensor<10x2x5x30xf32>
  return %1 : tensor<10x2x5x30xf32>
}
```
Output:
```mlir
module {
  func.func @trt_gather_default1(%arg0: tensor<10x20x30xf32>, %arg1: tensor<2x5xi32>, %arg2: tensor<10x2x5x30xf32>) -> tensor<10x2x5x30xf32> {
    %0 = tensorrt.call_alloc @trt_engines::@tensorrt_cluster(%arg0, %arg1, %arg2 : tensor<10x20x30xf32>, tensor<2x5xi32>, tensor<10x2x5x30xf32>) -> tensor<10x2x5x30xf32>
    return %0 : tensor<10x2x5x30xf32>
  }
  tensorrt.module @trt_engines {
    func.func @tensorrt_cluster(%arg0: tensor<10x20x30xf32>, %arg1: tensor<2x5xi32>, %arg2: tensor<10x2x5x30xf32>) -> tensor<10x2x5x30xf32> {
      %0 = tensorrt.gather {axis = 1 : i64} ins(%arg0, %arg1 : tensor<10x20x30xf32>, tensor<2x5xi32>) -> tensor<10x2x5x30xf32>
      %1 = tensorrt.element_wise <kSUM>(%0, %arg2 : tensor<10x2x5x30xf32>, tensor<10x2x5x30xf32>) -> tensor<10x2x5x30xf32>
      return %1 : tensor<10x2x5x30xf32>
    }
  }
}
```
Commented above on how to fix. You should move it out of `outlineOp` back to its original position in order to avoid performing multiple linear scans unnecessarily.
No description provided.