From e3b31f2376ff4343b244f9b2b4007a6f42360e3a Mon Sep 17 00:00:00 2001 From: Chengji Yao Date: Fri, 8 Dec 2023 02:46:34 -0800 Subject: [PATCH] [Sync] distributed runtime support and many updates (#83) * [Sync] distributed runtime support and many updates * update torch-mlir * fix multi process --------- Co-authored-by: Zhekun Zhang --- NOTICE | 77 + compiler/doc/byteir_mhlo_custom_call.md | 61 +- compiler/include/byteir/Dialect/Ace/AceOps.td | 16 +- .../Dialect/mhlo/Transforms/CanonicalizeExt.h | 7 + .../byteir/Dialect/mhlo/Util/CustomCallUtil.h | 4 + .../include/byteir/Dialect/mhlo/Util/Util.h | 4 +- .../include/byteir/Target/Cpp/CppEmitter.h | 2 + .../byteir/Transforms/GraphClusteringAlgo.h | 1 + compiler/include/byteir/Transforms/Passes.td | 3 +- .../HloToByreTensor/HloToByreTensor.cpp | 17 +- compiler/lib/Dialect/Ace/IR/AceDialect.cpp | 44 + compiler/lib/Dialect/Byre/IR/ByreDialect.cpp | 8 - compiler/lib/Dialect/Cat/IR/CatDialect.cpp | 4 +- .../mhlo/Transforms/CanonicalizeExt.cpp | 48 +- .../mhlo/Transforms/HloAggressiveFusion.cpp | 10 + .../lib/Dialect/mhlo/Transforms/HloFolder.cpp | 2 +- .../mhlo/Transforms/LayoutTransformation.cpp | 7 +- compiler/lib/Dialect/mhlo/Util/Util.cpp | 8 +- .../Transforms/GraphClusteringByDevice.cpp | 132 +- compiler/python/byteir/compile.py | 10 +- .../python/byteir/dialects/cat/ait_cache.py | 2 +- .../byteir/dialects/cat/ir_processor.py | 24 +- compiler/python/version.txt | 2 +- compiler/scripts/gen_testcases.py | 2 +- .../HloToByreTensor/compute_ops.mlir | 10 + compiler/test/Dialect/Byre/invalid.mlir | 9 - .../Pipelines/Host/E2E/Case0/00_Input.mlir | 2 +- .../Pipelines/Host/E2E/Case0/01_HostOpt.mlir | 8 +- .../Host/E2E/Case0/02a_ByreHost.mlir | 4 +- .../Pipelines/Host/E2E/Case0/02b_ToLLVM.mlir | 4 +- .../Host/E2E/Case0/03b_ToLLVMIR.mlir | 6 +- .../test/Pipelines/Host/E2E/Case0/Output.ll | 6 +- .../test/Pipelines/Host/E2E/Case0/Output.mlir | 2 +- .../Host/E2E/Case0/TotalPipeline.mlir | 2 +- .../test/Pipelines/Host/E2E/Case0/template.py | 2 +- .../Pipelines/Host/E2E/Case1/00_Input.mlir | 2 +- .../Pipelines/Host/E2E/Case1/01_HostOpt.mlir | 8 +- .../Host/E2E/Case1/02a_ByreHost.mlir | 4 +- .../Pipelines/Host/E2E/Case1/02b_ToLLVM.mlir | 4 +- .../Host/E2E/Case1/03b_ToLLVMIR.mlir | 227 +- .../test/Pipelines/Host/E2E/Case1/Output.ll | 18 +- .../test/Pipelines/Host/E2E/Case1/Output.mlir | 2 +- .../Host/E2E/Case1/TotalPipeline.mlir | 2 +- .../test/Pipelines/Host/E2E/Case1/template.py | 2 +- .../Host/E2E/RngUniform/00_Input.mlir | 11 + .../Host/E2E/RngUniform/01_HostOpt.mlir | 42 + .../Host/E2E/RngUniform/02a_ByreHost.mlir | 44 + .../Host/E2E/RngUniform/02b_ToLLVM.mlir | 44 + .../Host/E2E/RngUniform/03b_ToLLVMIR.mlir | 62 + .../Pipelines/Host/E2E/RngUniform/Output.ll | 64 + .../Pipelines/Host/E2E/RngUniform/Output.mlir | 15 + .../Host/E2E/RngUniform/TotalPipeline.mlir | 11 + .../Pipelines/Host/E2E/RngUniform/template.py | 34 + .../Host/E2E/Transpose/00_Input.mlir | 2 +- .../Host/E2E/Transpose/01_HostOpt.mlir | 2 +- .../Host/E2E/Transpose/TotalPipeline.mlir | 2 +- .../Pipelines/Host/E2E/TypeCvt/00_Input.mlir | 2 +- .../Host/E2E/TypeCvt/01_HostOpt.mlir | 2 +- .../Host/E2E/TypeCvt/TotalPipeline.mlir | 2 +- .../graphClusteringByDeviceGreedy.mlir | 122 + .../graphClusteringByDeviceTopDown.mlir | 25 +- .../src/Compiler/OFCompilerUtils.cpp | 7 +- .../src/Conversion/OFCheckNonLowered.cpp | 3 +- .../src/Conversion/OFRewriteCustomOnnxOps.cpp | 10 +- .../src/Conversion/OFRewriteToCustomCall.cpp | 124 +- .../src/Conversion/OFRewriteToCustomCall.td | 16 +- .../onnx-frontend/src/onnx-frontend.cpp | 2 +- .../test/of_check_non_lowered.mlir | 12 +- .../test/of_rewrite_custom_onnx_op.mlir | 30 +- .../test/of_rewrite_to_custom_call.mlir | 39 +- frontends/onnx-frontend/test/base.py | 13 +- frontends/onnx-frontend/test/ops/test_math.py | 27 + frontends/onnx-frontend/test/ops/test_rnn.py | 34 + .../onnx-frontend/test/ops/test_tensor.py | 20 + .../third_party/patches/OnnxMlirNewOps.patch | 2100 ++++++++++++++++- .../third_party/patches/OnnxMlirPRelu.patch | 73 - .../patches/OnnxMlirRegisterLibrary.patch | 3 +- .../tf_mlir_ext/numerical/numerical_test.py | 3 + .../numerical/rewrite_to_custom_call.mlir | 57 + .../tests/fallback_to_custom_call.mlir | 6 + .../tests/rewrite_to_custom_call.mlir | 57 + .../transforms/rewrite_to_custom_call.cc | 35 +- .../transforms/rewrite_to_custom_call.td | 85 + .../transforms/tf_fallback_to_custom_call.cc | 20 + .../torch-frontend/examples/demo/backend.py | 87 +- .../torch-frontend/examples/demo/context.py | 350 +++ .../torch-frontend/examples/demo/main.py | 22 +- .../patches/backend_contract.patch | 79 + .../third_party/patches/einsum.patch | 12 - .../third_party/patches/tuple.patch | 16 + .../torch-frontend/third_party/torch-mlir | 2 +- .../torch-frontend/python/test/test_ops.py | 25 + .../torch-frontend/python/version.txt | 2 +- runtime/VERSION_NUMBER | 2 +- runtime/cmake/CMakeLists.txt | 36 +- runtime/cmake/Modules/FindNCCL.cmake | 66 + runtime/cmake/brt_device_nccl.cmake | 24 + runtime/cmake/brt_framework.cmake | 3 + runtime/cmake/brt_provider_nccl.cmake | 27 + runtime/cmake/brt_unittests.cmake | 43 + .../brt/backends/cuda/device/common/util.h | 44 - runtime/include/brt/backends/device_api.h | 7 +- .../brt/backends/nccl/device/d_context_nccl.h | 31 + .../nccl/device/distributed_backend_nccl.h | 72 + .../include/brt/backends/nccl/device/utils.h | 32 + .../backends/nccl/providers/nccl_provider.h | 46 + .../backends/nccl/providers/op_registration.h | 27 + .../include/brt/backends/rng_state_context.h | 86 + runtime/include/brt/core/common/enums.h | 31 + .../brt/core/context/execution_context.h | 7 +- .../include/brt/core/distributed/d_context.h | 21 + .../core/distributed/distributed_backend.h | 123 + .../core/distributed/distributed_session.h | 80 + .../brt/core/distributed/rendezvous_socket.h | 46 + .../brt/core/framework/execution_plan.h | 5 +- .../brt/core/session/request_context.h | 2 + .../cpu/providers/default/cpu_provider.cc | 15 +- ...ringToNumber.cc => tf_string_to_number.cc} | 2 +- ...stringToNumber.h => tf_string_to_number.h} | 0 .../providers/default/math/elementwise_ops.cc | 2 +- .../providers/default/math/elementwise_ops.h | 0 .../default/tensor_generate/rng_state.cc | 98 + .../default/tensor_generate/rng_state.h | 42 + .../default/tensor_generate/rng_state.cc | 7 +- .../default/tensor_generate/rng_state.h | 4 +- .../nccl/device/distributed_backend_nccl.cc | 232 ++ runtime/lib/backends/nccl/device/utils.cc | 51 + .../backends/nccl/providers/nccl_provider.cc | 67 + .../nccl/providers/op_registration.cc | 43 + runtime/lib/backends/nccl/providers/recv.cc | 71 + .../backends/nccl/providers/recv.h} | 24 +- runtime/lib/backends/nccl/providers/send.cc | 71 + runtime/lib/backends/nccl/providers/send.h | 34 + .../core/distributed/distributed_backend.cc | 35 + .../core/distributed/distributed_session.cc | 92 + .../lib/core/distributed/rendezvous_socket.cc | 337 +++ runtime/lib/core/framework/execution_plan.cc | 5 +- .../default/kernel/rng_state_test.cc | 58 + .../default/kernel/rng_state_test.cc | 5 +- .../nccl/device/test_distributed_backend.cc | 78 + .../test/backends/nccl/device/test_utils.cc | 54 + .../providers/test_distributed_session.cc | 194 ++ .../distributed/test_rendezvous_socket.cc | 200 ++ .../include/brt/test/common/nccl/test_base.h | 65 + .../include/brt/test/common/nccl/test_utils.h | 57 + .../test/test_files/Distributed/add_send.mlir | 9 + runtime/test/test_files/Distributed/recv.mlir | 6 + .../test/test_files/Distributed/recv_add.mlir | 9 + runtime/test/test_files/Distributed/send.mlir | 6 + runtime/test/test_files/rng_state_cpu.mlir | 12 + .../{rng_state.mlir => rng_state_cuda.mlir} | 0 scripts/runtime/build_and_test.sh | 5 + 152 files changed, 6955 insertions(+), 572 deletions(-) create mode 100644 compiler/test/Pipelines/Host/E2E/RngUniform/00_Input.mlir create mode 100644 compiler/test/Pipelines/Host/E2E/RngUniform/01_HostOpt.mlir create mode 100644 compiler/test/Pipelines/Host/E2E/RngUniform/02a_ByreHost.mlir create mode 100644 compiler/test/Pipelines/Host/E2E/RngUniform/02b_ToLLVM.mlir create mode 100644 compiler/test/Pipelines/Host/E2E/RngUniform/03b_ToLLVMIR.mlir create mode 100644 compiler/test/Pipelines/Host/E2E/RngUniform/Output.ll create mode 100644 compiler/test/Pipelines/Host/E2E/RngUniform/Output.mlir create mode 100644 compiler/test/Pipelines/Host/E2E/RngUniform/TotalPipeline.mlir create mode 100644 compiler/test/Pipelines/Host/E2E/RngUniform/template.py create mode 100644 compiler/test/Transforms/graphClusteringByDeviceGreedy.mlir create mode 100644 frontends/onnx-frontend/test/ops/test_rnn.py delete mode 100644 frontends/onnx-frontend/third_party/patches/OnnxMlirPRelu.patch create mode 100644 frontends/torch-frontend/examples/demo/context.py create mode 100644 frontends/torch-frontend/third_party/patches/backend_contract.patch create mode 100644 frontends/torch-frontend/third_party/patches/tuple.patch create mode 100644 runtime/cmake/Modules/FindNCCL.cmake create mode 100644 runtime/cmake/brt_device_nccl.cmake create mode 100644 runtime/cmake/brt_provider_nccl.cmake create mode 100644 runtime/include/brt/backends/nccl/device/d_context_nccl.h create mode 100644 runtime/include/brt/backends/nccl/device/distributed_backend_nccl.h create mode 100644 runtime/include/brt/backends/nccl/device/utils.h create mode 100644 runtime/include/brt/backends/nccl/providers/nccl_provider.h create mode 100644 runtime/include/brt/backends/nccl/providers/op_registration.h create mode 100644 runtime/include/brt/backends/rng_state_context.h create mode 100644 runtime/include/brt/core/common/enums.h create mode 100644 runtime/include/brt/core/distributed/d_context.h create mode 100644 runtime/include/brt/core/distributed/distributed_backend.h create mode 100644 runtime/include/brt/core/distributed/distributed_session.h create mode 100644 runtime/include/brt/core/distributed/rendezvous_socket.h rename runtime/lib/backends/cpu/providers/default/custom_call/{tf_stringToNumber.cc => tf_string_to_number.cc} (99%) rename runtime/lib/backends/cpu/providers/default/custom_call/{tf_stringToNumber.h => tf_string_to_number.h} (100%) rename runtime/{include/brt => lib}/backends/cpu/providers/default/math/elementwise_ops.h (100%) create mode 100644 runtime/lib/backends/cpu/providers/default/tensor_generate/rng_state.cc create mode 100644 runtime/lib/backends/cpu/providers/default/tensor_generate/rng_state.h create mode 100644 runtime/lib/backends/nccl/device/distributed_backend_nccl.cc create mode 100644 runtime/lib/backends/nccl/device/utils.cc create mode 100644 runtime/lib/backends/nccl/providers/nccl_provider.cc create mode 100644 runtime/lib/backends/nccl/providers/op_registration.cc create mode 100644 runtime/lib/backends/nccl/providers/recv.cc rename runtime/{include/brt/backends/cuda/providers/default/tensor_generate/rng_state_context.h => lib/backends/nccl/providers/recv.h} (66%) create mode 100644 runtime/lib/backends/nccl/providers/send.cc create mode 100644 runtime/lib/backends/nccl/providers/send.h create mode 100644 runtime/lib/core/distributed/distributed_backend.cc create mode 100644 runtime/lib/core/distributed/distributed_session.cc create mode 100644 runtime/lib/core/distributed/rendezvous_socket.cc create mode 100644 runtime/test/backends/cpu/providers/default/kernel/rng_state_test.cc create mode 100644 runtime/test/backends/nccl/device/test_distributed_backend.cc create mode 100644 runtime/test/backends/nccl/device/test_utils.cc create mode 100644 runtime/test/backends/nccl/providers/test_distributed_session.cc create mode 100644 runtime/test/distributed/test_rendezvous_socket.cc create mode 100644 runtime/test/include/brt/test/common/nccl/test_base.h create mode 100644 runtime/test/include/brt/test/common/nccl/test_utils.h create mode 100644 runtime/test/test_files/Distributed/add_send.mlir create mode 100644 runtime/test/test_files/Distributed/recv.mlir create mode 100644 runtime/test/test_files/Distributed/recv_add.mlir create mode 100644 runtime/test/test_files/Distributed/send.mlir create mode 100644 runtime/test/test_files/rng_state_cpu.mlir rename runtime/test/test_files/{rng_state.mlir => rng_state_cuda.mlir} (100%) diff --git a/NOTICE b/NOTICE index 8f370de33..6fc598fd6 100644 --- a/NOTICE +++ b/NOTICE @@ -1355,3 +1355,80 @@ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +_____ + +MegRay is Licensed under the Apache License, Version 2.0 (the "License") + +Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + +Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + + +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + +"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. + +"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. + +"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. + +"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." + +"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. + +2. Grant of Copyright License. + +Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. + +Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. + +4. Redistribution. + +You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: + +You must give any other recipients of the Work or Derivative Works a copy of this License; and +You must cause any modified files to carry prominent notices stating that You changed the files; and +You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and +If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. +You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. + +5. Submission of Contributions. + +Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. + +6. Trademarks. + +This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. + +Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. + +In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. + +While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS diff --git a/compiler/doc/byteir_mhlo_custom_call.md b/compiler/doc/byteir_mhlo_custom_call.md index a70927446..e0e80b608 100644 --- a/compiler/doc/byteir_mhlo_custom_call.md +++ b/compiler/doc/byteir_mhlo_custom_call.md @@ -42,13 +42,15 @@ If an op is frontend-specific, it uses a frontend-specific prefix, such as `tf` Further needed infomation for a given coarse-grained op are encoded in a dictionary attribute, called `byteir_attrs`, which includes all named attributes. -```Op Attribute: byteir_attrs = {approximate = "none"} or byteir_attrs = {} of if none``` +**Op Attribute**: + * ```byteir_attrs = {approximate = "none"}``` or ```byteir_attrs = {}``` if no attribute + * ```axis``` attribute must be positive ### byteir.layer_norm - Operands: - input: Tensor - - weight: Tensor - - bias: Tensor + - weight: Tensor (shape should be same as axis of input tensor) + - bias: Tensor (shape should be same as axis of input tensor) - Attrs - epsilon: F64Attr - axis: I64ArrayAttr @@ -138,6 +140,12 @@ Further needed infomation for a given coarse-grained op are encoded in a diction %0 = "mhlo.custom_call"(%arg0) {call_target_name = "byteir.erf", has_side_effect = false} : (tensor) -> tensor ``` +### byteir.addn +- Operands: + - inputs: Variadic\ +- Results: + - outputs: Tensor + ### byteir.one_hot - Operands: - indices: IntTensor @@ -149,13 +157,20 @@ Further needed infomation for a given coarse-grained op are encoded in a diction - Results: - output: Tensor (ElementType same as on_value and off_value) +### byteir.repeat +- Operands: + - input: Tensor + - repeats: Int32/Int64 Tensor +- Results: + - output: Tensor (ElementType same as input) + ### byteir.quantize - Operands: - input: FloatTensor - scale: FloatTensor (rank=0 for per-tensor quantization, or rank=1 for per-channel quantization) - zero_point: Int8/Int16/Uint8/Uint16 Tensor (shape same as scale) - Attrs - - axis: I64Attr (Optional, required only for per-channel quantization) + - axis: Optional\ (required only for per-channel quantization) - Results: - output: Int8/Int16/Uint8/Uint16 Tensor (type same as zero_point) @@ -165,7 +180,7 @@ Further needed infomation for a given coarse-grained op are encoded in a diction - scale: FloatTensor (rank=0 for per-tensor dequantization, or rank=1 for per-channel dequantization) - zero_point: Int8/Int16/Uint8/Uint16 Tensor (shape same as scale, type same as input) - Attrs - - axis: I64Attr (Optional, channel axis index, required only for per-channel dequantization) + - axis: Optional\ (channel axis index, required only for per-channel dequantization) - Results: - output: FloatTensor @@ -220,3 +235,39 @@ Further needed infomation for a given coarse-grained op are encoded in a diction %shape = shape.shape_of %arg0 : tensor<3xindex> %0 = "mhlo.custom_call"(%low, %high, %seed, %offset, %shape) {call_target_name = "byteir.rng_uniform", has_side_effect = false} : (tensor, tensor, tensor, tensor, tensor<3xindex>) -> tensor ``` + +### byteir.flash_attn_fwd +- Operands: + - q: Tensor + - k: Tensor + - v: Tensor +- Attrs: + - dropout_p: FloatAttr + - softmax_scale: FloatAttr + - causal: BoolAttr + - return_softmax: BoolAttr +- Results: + - output: Tensor + - softmax_lse: Tensor + - softmax_return: Tensor + - rng: Tensor + +### byteir.flash_attn_bwd +- Operands: + - dout: Tensor + - q: Tensor + - k: Tensor + - v: Tensor + - out: Tensor + - softmax_lse: Tensor + - rng_state: Tensor +- Attrs: + - dropout_p: FloatAttr + - softmax_scale: FloatAttr + - causal: BoolAttr +- Results: + - dq: Tensor + - dk: Tensor + - dv: Tensor + - d_softmax: Tensor + - dq_accum: Tensor diff --git a/compiler/include/byteir/Dialect/Ace/AceOps.td b/compiler/include/byteir/Dialect/Ace/AceOps.td index 0d1a181fa..69c604036 100644 --- a/compiler/include/byteir/Dialect/Ace/AceOps.td +++ b/compiler/include/byteir/Dialect/Ace/AceOps.td @@ -48,13 +48,27 @@ def Ace_ConstOp : Ace_Op<"constant", [ let builders = [ OpBuilder<(ins "Attribute":$value)>]; - // FIXME: let this op only has generic format for golden ref calculate. + // FIXME: let this op only has generic format for loading without register ace dialect. // let assemblyFormat = "attr-dict $value"; let hasFolder = 1; let hasVerifier = 0; } +def Ace_ReshapeOp : Ace_Op<"reshape", [Pure, SameOperandsAndResultElementType]> { + let summary = "Reshape operation"; + let description = [{ + Performs reshape of `operand` tensor to a `result` tensor, like `mhlo.reshape`. + }]; + + let arguments = (ins AnyTensor:$operand); + + let results = (outs Ace_StaticShapeTensor); + let hasVerifier = 1; + // note: let this op only has generic format for loading without register ace dialect. +} + + def Ace_ActivateOp : Ace_Op<"activate", [Pure]> { let summary = "Activate operation"; diff --git a/compiler/include/byteir/Dialect/mhlo/Transforms/CanonicalizeExt.h b/compiler/include/byteir/Dialect/mhlo/Transforms/CanonicalizeExt.h index 18c1e13a7..87be51b4d 100644 --- a/compiler/include/byteir/Dialect/mhlo/Transforms/CanonicalizeExt.h +++ b/compiler/include/byteir/Dialect/mhlo/Transforms/CanonicalizeExt.h @@ -32,6 +32,13 @@ namespace mhlo { // fold multi op with zero void populateFoldMultiplyZeroPattern(RewritePatternSet &patterns); +// fold large binary Op +void populateFoldLargeBinaryOpPatterns(RewritePatternSet &patterns); + +// fold benefical convert with constant +void populateFoldBeneficialConstantConvertOpPattern( + RewritePatternSet &patterns); + // populate canonicalizeExt patterns void populateCanonicalizeExtPatterns(RewritePatternSet &patterns, MLIRContext *context, diff --git a/compiler/include/byteir/Dialect/mhlo/Util/CustomCallUtil.h b/compiler/include/byteir/Dialect/mhlo/Util/CustomCallUtil.h index 83a53e329..8879042a7 100644 --- a/compiler/include/byteir/Dialect/mhlo/Util/CustomCallUtil.h +++ b/compiler/include/byteir/Dialect/mhlo/Util/CustomCallUtil.h @@ -81,6 +81,10 @@ constexpr llvm::StringRef getDequantizeName() { return CUSTOM_CALL_NAME_PREFIX "dequantize"; } +constexpr llvm::StringRef getResizeName() { + return CUSTOM_CALL_NAME_PREFIX "resize"; +} + constexpr llvm::StringRef getRngUniformName() { return CUSTOM_CALL_NAME_PREFIX "rng_uniform"; } diff --git a/compiler/include/byteir/Dialect/mhlo/Util/Util.h b/compiler/include/byteir/Dialect/mhlo/Util/Util.h index e64f12ffc..bad9354e3 100644 --- a/compiler/include/byteir/Dialect/mhlo/Util/Util.h +++ b/compiler/include/byteir/Dialect/mhlo/Util/Util.h @@ -41,8 +41,8 @@ enum class NamedLayout : uint32_t { inline std::string stringifyEnum(NamedLayout layout) { switch (layout) { - case NamedLayout::UNKNOWN: - return "UNKNOWN"; + // case NamedLayout::UNKNOWN: + // return "UNKNOWN"; case NamedLayout::NHWC: return "NHWC"; case NamedLayout::NDHWC: diff --git a/compiler/include/byteir/Target/Cpp/CppEmitter.h b/compiler/include/byteir/Target/Cpp/CppEmitter.h index 75f1e8cf4..24592dd24 100644 --- a/compiler/include/byteir/Target/Cpp/CppEmitter.h +++ b/compiler/include/byteir/Target/Cpp/CppEmitter.h @@ -29,6 +29,8 @@ class CppEmitter { public: explicit CppEmitter(llvm::raw_ostream &os, bool declareVariablesAtTop); + virtual ~CppEmitter() = default; + /// Emits attribute or returns failure. virtual mlir::LogicalResult emitAttribute(mlir::Location loc, mlir::Attribute attr); diff --git a/compiler/include/byteir/Transforms/GraphClusteringAlgo.h b/compiler/include/byteir/Transforms/GraphClusteringAlgo.h index 90c22df95..7468e6ba5 100644 --- a/compiler/include/byteir/Transforms/GraphClusteringAlgo.h +++ b/compiler/include/byteir/Transforms/GraphClusteringAlgo.h @@ -26,6 +26,7 @@ enum class GraphClusteringAlgo : uint32_t { kFallback = 0, kTopDown = 1, kBottomUp = 2, + kGreedy = 3, }; } // namespace mlir diff --git a/compiler/include/byteir/Transforms/Passes.td b/compiler/include/byteir/Transforms/Passes.td index 8ac2f3e7c..97d69c022 100644 --- a/compiler/include/byteir/Transforms/Passes.td +++ b/compiler/include/byteir/Transforms/Passes.td @@ -207,7 +207,8 @@ def GraphClusteringByDevice : Pass<"graph-clustering-by-device", "ModuleOp"> { [{llvm::cl::values( clEnumValN(mlir::GraphClusteringAlgo::kFallback, "Fallback", "Fallback clustering algorithm"), clEnumValN(mlir::GraphClusteringAlgo::kTopDown, "TopDown", "Create and merge device cluster progressively from top to bottom"), - clEnumValN(mlir::GraphClusteringAlgo::kBottomUp, "BottomUp", "Create and merge device cluster progressively from bottom to top") + clEnumValN(mlir::GraphClusteringAlgo::kBottomUp, "BottomUp", "Create and merge device cluster progressively from bottom to top"), + clEnumValN(mlir::GraphClusteringAlgo::kGreedy, "Greedy", "Choose largest subgraph from BottomUp or TopDown strategy") )}]> ]; } diff --git a/compiler/lib/Conversion/HloToByreTensor/HloToByreTensor.cpp b/compiler/lib/Conversion/HloToByreTensor/HloToByreTensor.cpp index f5dfbd6a9..fe9ccda70 100644 --- a/compiler/lib/Conversion/HloToByreTensor/HloToByreTensor.cpp +++ b/compiler/lib/Conversion/HloToByreTensor/HloToByreTensor.cpp @@ -129,16 +129,16 @@ class ConvertConstLikeOp : public OpConversionPattern { } }; -class ConvertReshapeOp : public OpConversionPattern { +template class ConvertReshapeOp : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(mhlo::ReshapeOp op, mhlo::ReshapeOp::Adaptor adaptor, + matchAndRewrite(OP op, typename OP::Adaptor adaptor, ConversionPatternRewriter &rewriter) const final { auto operand = adaptor.getOperand(); - auto operandType = operand.getType().cast(); - auto resultType = op.getType().cast(); + auto operandType = operand.getType().template cast(); + auto resultType = op.getType().template cast(); if (!operandType.hasStaticShape() || !resultType.hasStaticShape()) return failure(); @@ -749,9 +749,10 @@ void mlir::populateHloToByreTensorPattern( ConvertSelectAndScatterOpToByrePattern>(patterns.getContext(), appendArgTypes); - patterns.add, - ConvertConstLikeOp, ConvertReshapeOp, - ConvertSliceOp, ConvertConcatenateOp>(patterns.getContext()); + patterns.add< + ConvertConstLikeOp, ConvertConstLikeOp, + ConvertReshapeOp, ConvertReshapeOp, + ConvertSliceOp, ConvertConcatenateOp>(patterns.getContext()); } std::unique_ptr> diff --git a/compiler/lib/Dialect/Ace/IR/AceDialect.cpp b/compiler/lib/Dialect/Ace/IR/AceDialect.cpp index cc54b7f97..239cef09e 100644 --- a/compiler/lib/Dialect/Ace/IR/AceDialect.cpp +++ b/compiler/lib/Dialect/Ace/IR/AceDialect.cpp @@ -14,6 +14,24 @@ // limitations under the License. // //===----------------------------------------------------------------------===// +// +// Some code comes from openxla/stablehlo project, the original license: +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// Copyright 2022 The StableHLO Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// #include "byteir/Dialect/Ace/AceDialect.h" #include "mlir/IR/Builders.h" @@ -75,3 +93,29 @@ void AceDialect::initialize() { //===----------------------------------------------------------------------===// OpFoldResult mlir::ace::ConstOp::fold(FoldAdaptor) { return getValue(); } + +//===----------------------------------------------------------------------===// +// ReshapeOp +//===----------------------------------------------------------------------===// + +LogicalResult mlir::ace::ReshapeOp::verify() { + auto operandTy = getOperand().getType().dyn_cast(); + // If the operand type is dynamically shaped there is nothing to verify. + if (!operandTy || !operandTy.hasStaticShape()) + return success(); + + // If the operand type is statically shaped (not required) the number of + // elements must match that of the result type. + auto resultTy = getResult().getType().cast(); + assert(resultTy && resultTy.hasStaticShape() && + "result type must be statically shaped"); + int64_t numResultElements = resultTy.getNumElements(); + int64_t numOperandElements = operandTy.getNumElements(); + if (numResultElements != numOperandElements) + return emitOptionalError(getLoc(), "number of output elements (", + numResultElements, + ") doesn't match expected number of elements (", + numOperandElements, ")"); + + return success(); +} diff --git a/compiler/lib/Dialect/Byre/IR/ByreDialect.cpp b/compiler/lib/Dialect/Byre/IR/ByreDialect.cpp index 7883dc82f..aa6d81bd8 100644 --- a/compiler/lib/Dialect/Byre/IR/ByreDialect.cpp +++ b/compiler/lib/Dialect/Byre/IR/ByreDialect.cpp @@ -251,14 +251,6 @@ LogicalResult ByreDialect::verifyOperationAttribute(Operation *op, } } - if (!numOutputs) { - return op->emitError( - "expected at least 1 argument which was attached with '") - << ByreDialect::getEntryPointFuncArgTypeAttrName() - << "' attribute contained '" - << stringifyEnum(EntryFuncArgType::Output) << '\''; - } - // FuncOp has no return if (funcOp.getNumResults() != 0) { return op->emitError("expected no return in ") diff --git a/compiler/lib/Dialect/Cat/IR/CatDialect.cpp b/compiler/lib/Dialect/Cat/IR/CatDialect.cpp index 966a22e22..ea3ba83cc 100644 --- a/compiler/lib/Dialect/Cat/IR/CatDialect.cpp +++ b/compiler/lib/Dialect/Cat/IR/CatDialect.cpp @@ -36,7 +36,7 @@ LogicalResult VerifyBMMLayout(Value lhs, Value rhs, Value out, llvm::StringRef layoutStr) { auto lhsType = lhs.getType().cast(); auto rhsType = rhs.getType().cast(); - auto outType = out.getType().cast(); + // auto outType = out.getType().cast(); if (lhsType.getRank() != 3 || rhsType.getRank() != 3) return failure(); @@ -58,7 +58,7 @@ LogicalResult VerifyGemmLayout(Value lhs, Value rhs, Value out, llvm::StringRef layoutStr) { auto lhsShape = lhs.getType().cast().getShape(); auto rhsShape = rhs.getType().cast().getShape(); - auto outShape = out.getType().cast().getShape(); + // auto outShape = out.getType().cast().getShape(); if (layoutStr == "rrr" && lhsShape[1] == rhsShape[0]) return success(); if (layoutStr == "rcr" && lhsShape[1] == rhsShape[1]) diff --git a/compiler/lib/Dialect/mhlo/Transforms/CanonicalizeExt.cpp b/compiler/lib/Dialect/mhlo/Transforms/CanonicalizeExt.cpp index 2738ebba6..89fc8a094 100644 --- a/compiler/lib/Dialect/mhlo/Transforms/CanonicalizeExt.cpp +++ b/compiler/lib/Dialect/mhlo/Transforms/CanonicalizeExt.cpp @@ -652,9 +652,9 @@ struct EliminateRedundantConvertFromI1 }; /// tensor -/// / | \ \ +/// / | \ | /// slice_0 slice_1 ... slice_n -/// | | | \ +/// | | | | /// ... reshape_0 reshape_1 ... reshape_n ... /// \ \ | / / /li /// concatenate @@ -1078,29 +1078,27 @@ template struct Xor { T operator()(const T &a, const T &b) const { return a ^ b; } }; -template struct Pow; +template struct Pow; // note: the power op in XLA will return 0 in case of power(-1,-n), where n>0. -template -struct Pow>> { - T operator()(const T &a, const T &b) const { +template <> struct Pow { + APSInt operator()(const APSInt &a, const APSInt &b) const { int64_t aPromoted = a.getSExtValue(); int64_t bPromoted = b.getSExtValue(); auto bitWidth = a.getBitWidth(); APInt res_(bitWidth, std::pow(aPromoted, bPromoted), true); - T res(res_); + APSInt res(res_); return res; } }; -template -struct Pow>> { - T operator()(const T &a, const T &b) const { +template <> struct Pow { + APFloat operator()(const APFloat &a, const APFloat &b) const { double aPromoted = a.convertToDouble(); double bPromoted = b.convertToDouble(); auto &semantics = a.getSemantics(); bool loses_info; - T res(std::pow(aPromoted, bPromoted)); + APFloat res(std::pow(aPromoted, bPromoted)); res.convert(semantics, APFloat::rmNearestTiesToEven, &loses_info); return res; } @@ -1911,13 +1909,9 @@ void mlir::mhlo::populateFoldMultiplyZeroPattern(RewritePatternSet &patterns) { patterns.add(patterns.getContext()); } -void mlir::mhlo::populateCanonicalizeExtPatterns(RewritePatternSet &patterns, - MLIRContext *ctx, - bool blindFold) { - patterns.add(ctx); - patterns.add(ctx); - patterns.add(ctx); - patterns.add(ctx); +void mlir::mhlo::populateFoldLargeBinaryOpPatterns( + RewritePatternSet &patterns) { + auto ctx = patterns.getContext(); patterns.add>(ctx); patterns.add>(ctx); patterns.add>(ctx); @@ -1927,9 +1921,25 @@ void mlir::mhlo::populateCanonicalizeExtPatterns(RewritePatternSet &patterns, patterns.add>(ctx); patterns.add>(ctx); patterns.add(ctx); +} + +void mlir::mhlo::populateFoldBeneficialConstantConvertOpPattern( + RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + +// TODO: split more patterns to populate function +void mlir::mhlo::populateCanonicalizeExtPatterns(RewritePatternSet &patterns, + MLIRContext *ctx, + bool blindFold) { + patterns.add(ctx); + patterns.add(ctx); + patterns.add(ctx); + patterns.add(ctx); + populateFoldLargeBinaryOpPatterns(patterns); patterns.add(ctx); patterns.add(ctx); - patterns.add(ctx); + populateFoldBeneficialConstantConvertOpPattern(patterns); patterns.add(ctx); patterns.add(ctx); patterns.add(ctx); diff --git a/compiler/lib/Dialect/mhlo/Transforms/HloAggressiveFusion.cpp b/compiler/lib/Dialect/mhlo/Transforms/HloAggressiveFusion.cpp index 6db573e5f..23db43748 100644 --- a/compiler/lib/Dialect/mhlo/Transforms/HloAggressiveFusion.cpp +++ b/compiler/lib/Dialect/mhlo/Transforms/HloAggressiveFusion.cpp @@ -18,6 +18,7 @@ #include "byteir/Dialect/mhlo/Transforms/HloFuser.h" #include "byteir/Dialect/mhlo/Transforms/GenericFusionCommon.h" +#include "byteir/Dialect/mhlo/Util/CustomCallUtil.h" #include "byteir/Dialect/mhlo/Util/FusionUtil.h" #include "byteir/Dialect/mhlo/Util/Util.h" #include "byteir/Utils/IRRewrite.h" @@ -35,7 +36,16 @@ using namespace mlir::mhlo; namespace { namespace aggressive_fusion { +bool isCustomMhloRngUniformOp(Operation *op) { + if (auto customOp = llvm::dyn_cast_or_null(op)) { + return customOp.getCallTargetName() == getRngUniformName(); + } + return false; +} + bool isFusibleCandidate(Operation *op) { + if (isCustomMhloRngUniformOp(op)) + return true; return isMhlo(op) && !llvm::isa(op); } diff --git a/compiler/lib/Dialect/mhlo/Transforms/HloFolder.cpp b/compiler/lib/Dialect/mhlo/Transforms/HloFolder.cpp index 8364a25da..ebb6f319f 100644 --- a/compiler/lib/Dialect/mhlo/Transforms/HloFolder.cpp +++ b/compiler/lib/Dialect/mhlo/Transforms/HloFolder.cpp @@ -398,7 +398,7 @@ struct ConvOrDotWithBiasFollowedByBroadcastPattern< auto weightType = weight.getType().template cast(); BroadcastInDimOp newBroadInDimOp = builder.create( constOp->getLoc(), weightType, constOp.getOutput(), - rewriter.getI64TensorAttr(weightFeatureDim)); + rewriter.getI64TensorAttr({weightFeatureDim})); MulOp newMulOp = builder.create(constOp->getLoc(), weight, newBroadInDimOp->getResult(0)); convOrDotOp->setOperand(1, newMulOp->getResult(0)); diff --git a/compiler/lib/Dialect/mhlo/Transforms/LayoutTransformation.cpp b/compiler/lib/Dialect/mhlo/Transforms/LayoutTransformation.cpp index 7fc513780..0509d749c 100644 --- a/compiler/lib/Dialect/mhlo/Transforms/LayoutTransformation.cpp +++ b/compiler/lib/Dialect/mhlo/Transforms/LayoutTransformation.cpp @@ -567,7 +567,6 @@ struct BatchNormTrainingLayoutTransformationPattern if (op->hasAttr(TransformationDisableKey)) { return failure(); } - auto inputType = op.getOperand().getType().cast(); if (targetLayout == "NHWC") { Value inputTranspose = createNCHW2NHWCValue(rewriter, op->getLoc(), op.getOperand()); @@ -608,7 +607,7 @@ struct BatchNormInferenceLayoutTransformationPattern if (op->hasAttr(TransformationDisableKey)) { return failure(); } - auto inputType = op.getOperand().getType().cast(); + // auto inputType = op.getOperand().getType().cast(); if (targetLayout == "NHWC") { Value inputTranspose = createNCHW2NHWCValue(rewriter, op->getLoc(), op.getOperand()); @@ -645,7 +644,7 @@ struct BatchNormGradLayoutTransformationPattern if (op->hasAttr(TransformationDisableKey)) { return failure(); } - auto inputType = op.getOperand().getType().cast(); + // auto inputType = op.getOperand().getType().cast(); if (targetLayout == "NHWC") { Value operandTranspose = createNCHW2NHWCValue(rewriter, op->getLoc(), op.getOperand()); @@ -677,7 +676,7 @@ struct BatchNormGradLayoutTransformationPattern // return the input layout of conv op when there are conv op which have // the same input layout in funcOp byteir::NamedLayout findGlobalLayout(func::FuncOp func) { - Region &body = func.getBody(); + // Region &body = func.getBody(); byteir::NamedLayout inputLayout = byteir::NamedLayout::UNKNOWN; func.walk([&inputLayout](mhlo::ConvolutionOp conv) { diff --git a/compiler/lib/Dialect/mhlo/Util/Util.cpp b/compiler/lib/Dialect/mhlo/Util/Util.cpp index a5e8531ab..977f43cac 100644 --- a/compiler/lib/Dialect/mhlo/Util/Util.cpp +++ b/compiler/lib/Dialect/mhlo/Util/Util.cpp @@ -160,7 +160,7 @@ std::optional mlir::getCumsumIndex(mhlo::ReduceWindowOp op) { return std::nullopt; } int64_t index = K_INITIAL; - for (size_t i = 0; i < inputShape.getRank(); i++) { + for (int64_t i = 0; i < inputShape.getRank(); i++) { if (window_dimensions[i] == 1 && padding[i * 2] == 0 && padding[i * 2 + 1] == 0) { // not cumsum index @@ -597,7 +597,7 @@ mlir::computeReshapeInputOutputRankMapIndex(ShapedType inputType, j++; } } - if (result.size() != inputType.getRank()) { + if (result.size() != static_cast(inputType.getRank())) { return std::nullopt; } return result; @@ -626,10 +626,10 @@ mlir::computeReshapeExpandDim(mhlo::ReshapeOp reshapeOp) { return std::nullopt; } auto index = *maybeIndex; - for (int64_t i = 0; i < reshapeResultType.getRank(); i++) { + for (int64_t i = 0; i < reshapeOperandType.getRank(); i++) { if (index[i] != i) { return i; } } - return std::nullopt; + return reshapeOperandType.getRank(); } diff --git a/compiler/lib/Transforms/GraphClusteringByDevice.cpp b/compiler/lib/Transforms/GraphClusteringByDevice.cpp index ae38b7b19..e224a07ac 100644 --- a/compiler/lib/Transforms/GraphClusteringByDevice.cpp +++ b/compiler/lib/Transforms/GraphClusteringByDevice.cpp @@ -135,6 +135,7 @@ getFunctionMetadatasFallback(func::FuncOp funcOp, StringRef attrName, struct ActiveDeviceCluster { using OpList = llvm::SetVector; + using OpClusterMap = llvm::DenseMap; OpList operations; ActiveDeviceCluster *mergedInto; @@ -160,7 +161,8 @@ struct ActiveDeviceCluster { // return merged ActiveDeviceCluster or nullptr for merge failure // arg order sensitive, prefer merge lhs into rhs static ActiveDeviceCluster *tryMerge(ActiveDeviceCluster *lhs, - ActiveDeviceCluster *rhs); + ActiveDeviceCluster *rhs, + OpClusterMap &op2cluster); struct CompareByNumOps { bool operator()(ActiveDeviceCluster *lhs, ActiveDeviceCluster *rhs) { @@ -169,7 +171,8 @@ struct ActiveDeviceCluster { }; private: - static bool tryMergeInto(ActiveDeviceCluster *from, ActiveDeviceCluster *to); + static bool tryMergeInto(ActiveDeviceCluster *from, ActiveDeviceCluster *to, + OpClusterMap &op2cluster); static bool anyDefIn(Operation *op, const OpList &operations) { for (auto &&operand : op->getOperands()) @@ -189,12 +192,27 @@ struct ActiveDeviceCluster { // \p moveUp in pre-order, and the remaining operations will be kept in \p src // in pre-order static auto computeMoveUpSet(const OpList &target, OpList &src, - OpList &moveUp) { + OpList &moveUp, OpClusterMap &op2cluster) { std::vector vec = src.takeVector(); OpList &remain = src; for (auto &&op : vec) { + if (remain.contains(op)) + continue; if (anyDefIn(op, target) || anyDefIn(op, remain)) { - remain.insert(op); + auto &&iter = op2cluster.find(op); + OpList ops; + if (iter == op2cluster.end()) { + remain.insert(op); + continue; + } + ActiveDeviceCluster *cluster = iter->second.getRoot(); + for (Operation *clusterOp : cluster->operations) { + assert(std::find(vec.begin(), vec.end(), clusterOp) != vec.end()); + assert(remain.insert(clusterOp)); + if (moveUp.contains(clusterOp)) { + moveUp.remove(clusterOp); + } + } } else { moveUp.insert(op); } @@ -205,24 +223,39 @@ struct ActiveDeviceCluster { // \p moveDown in post-order, and the remaining operations will be kept in \p // src in pre-order static auto computeMoveDownSet(const OpList &target, OpList &src, - OpList &moveDown) { + OpList &moveDown, OpClusterMap &op2cluster) { std::vector vec = src.takeVector(); OpList &remain = src; for (auto &&op : llvm::reverse(vec)) { + if (remain.contains(op)) + continue; if (anyUseIn(op, target) || anyUseIn(op, remain)) { - remain.insert(op); + auto &&iter = op2cluster.find(op); + OpList ops; + if (iter == op2cluster.end()) { + remain.insert(op); + continue; + } + ActiveDeviceCluster *cluster = iter->second.getRoot(); + for (Operation *clusterOp : llvm::reverse(cluster->operations)) { + assert(std::find(vec.begin(), vec.end(), clusterOp) != vec.end()); + assert(remain.insert(clusterOp)); + if (moveDown.contains(clusterOp)) { + moveDown.remove(clusterOp); + } + } } else { moveDown.insert(op); } } - vec = remain.takeVector(); remain.insert(vec.rbegin(), vec.rend()); } }; bool ActiveDeviceCluster::tryMergeInto(ActiveDeviceCluster *from, - ActiveDeviceCluster *to) { + ActiveDeviceCluster *to, + OpClusterMap &op2cluster) { static auto takePointer = [](Operation &op) { return &op; }; if (from->isBeforeInBlock(to)) { OpList toMove( @@ -231,8 +264,8 @@ bool ActiveDeviceCluster::tryMergeInto(ActiveDeviceCluster *from, llvm::map_iterator(to->operations.front()->getIterator(), takePointer)); OpList moveUp, moveDown; - computeMoveUpSet(from->operations, toMove, moveUp); - computeMoveDownSet(to->operations, toMove, moveDown); + computeMoveUpSet(from->operations, toMove, moveUp, op2cluster); + computeMoveDownSet(to->operations, toMove, moveDown, op2cluster); if (!toMove.empty()) return false; @@ -257,8 +290,8 @@ bool ActiveDeviceCluster::tryMergeInto(ActiveDeviceCluster *from, takePointer)); OpList moveUp, moveDown; - computeMoveDownSet(from->operations, toMove, moveDown); - computeMoveUpSet(to->operations, toMove, moveUp); + computeMoveDownSet(from->operations, toMove, moveDown, op2cluster); + computeMoveUpSet(to->operations, toMove, moveUp, op2cluster); if (!toMove.empty()) return false; @@ -280,18 +313,19 @@ bool ActiveDeviceCluster::tryMergeInto(ActiveDeviceCluster *from, } ActiveDeviceCluster *ActiveDeviceCluster::tryMerge(ActiveDeviceCluster *lhs, - ActiveDeviceCluster *rhs) { + ActiveDeviceCluster *rhs, + OpClusterMap &op2cluster) { if (!lhs || !rhs || lhs == rhs) return nullptr; if (lhs->mergedInto || rhs->mergedInto) return nullptr; - if (tryMergeInto(lhs, rhs)) { + if (tryMergeInto(lhs, rhs, op2cluster)) { return rhs; } - if (tryMergeInto(rhs, lhs)) { + if (tryMergeInto(rhs, lhs, op2cluster)) { return lhs; } @@ -337,7 +371,16 @@ DeviceClusteringAlgoBaseHelper::DeviceClusteringAlgoBaseHelper( continue; } } - + // if a constant is only used by host op, mark it as host + if (isMhloConstantLike(&op) && op.getResult(0).hasOneUse()) { + Operation *user = *op.getResult(0).getUsers().begin(); + if (user->hasAttr(attrName)) { + StringAttr attr = user->getAttrOfType(attrName); + if (attr.getValue().str() == DEVICE_ATTR_HOST) { + continue; + } + } + } op2cluster.try_emplace(&op, ActiveDeviceCluster(&op)); } } @@ -392,7 +435,8 @@ void DeviceClusteringAlgoBaseHelper::populateCandidates() { ActiveDeviceCluster *cluster = workList.front(); workList.pop_front(); for (auto &&iter = workList.begin(); iter != workList.end();) { - if (auto merged = ActiveDeviceCluster::tryMerge(*iter, cluster)) { + if (auto merged = + ActiveDeviceCluster::tryMerge(*iter, cluster, op2cluster)) { cluster = merged; iter = workList.erase(iter); } else { @@ -434,7 +478,8 @@ void TopDownDeviceClustering::mergeDeviceClustersProgressively() { auto curCluster = getCluster(op); for (auto &&operand : op->getOperands()) { auto preCluster = getCluster(operand); - if (auto merged = ActiveDeviceCluster::tryMerge(curCluster, preCluster)) { + if (auto merged = ActiveDeviceCluster::tryMerge(preCluster, curCluster, + op2cluster)) { curCluster = merged; } } @@ -457,7 +502,8 @@ void BottomUpDeviceClustering::mergeDeviceClustersProgressively() { auto curCluster = getCluster(op); for (auto &&use : op->getUses()) { auto preCluster = getCluster(use.getOwner()); - if (auto merged = ActiveDeviceCluster::tryMerge(preCluster, curCluster)) { + if (auto merged = ActiveDeviceCluster::tryMerge(preCluster, curCluster, + op2cluster)) { curCluster = merged; } } @@ -624,21 +670,63 @@ void GraphClusteringByDevicePass::runOnOperation() { for (auto funcOp : originalFuncs) { std::optional> metadatas; switch (this->clusterAlgo) { - case GraphClusteringAlgo::kTopDown: + case GraphClusteringAlgo::kTopDown: { metadatas = TopDownDeviceClustering(funcOp, attrName) .getFunctionMetadatas(attrName, device, deviceAnchorName, dupOutputs); break; - case GraphClusteringAlgo::kBottomUp: + } + case GraphClusteringAlgo::kBottomUp: { metadatas = BottomUpDeviceClustering(funcOp, attrName) .getFunctionMetadatas(attrName, device, deviceAnchorName, dupOutputs); break; + } + case GraphClusteringAlgo::kGreedy: { + std::optional> topDownMetadatas; + std::optional> bottomUpMetadatas; + auto topDownFunc = funcOp.clone(); + auto bottomUpFunc = funcOp.clone(); + + topDownMetadatas = + TopDownDeviceClustering(topDownFunc, attrName) + .getFunctionMetadatas(attrName, device, deviceAnchorName, + dupOutputs); + bottomUpMetadatas = + BottomUpDeviceClustering(bottomUpFunc, attrName) + .getFunctionMetadatas(attrName, device, deviceAnchorName, + dupOutputs); + if (topDownMetadatas && bottomUpMetadatas) { + auto topDownSize = (*topDownMetadatas)[0].ops.size(); + auto bottomUpSize = (*bottomUpMetadatas)[0].ops.size(); + if (topDownSize > bottomUpSize) { + metadatas = TopDownDeviceClustering(funcOp, attrName) + .getFunctionMetadatas(attrName, device, + deviceAnchorName, dupOutputs); + } else { + metadatas = BottomUpDeviceClustering(funcOp, attrName) + .getFunctionMetadatas(attrName, device, + deviceAnchorName, dupOutputs); + } + } else if (topDownMetadatas) { + metadatas = TopDownDeviceClustering(funcOp, attrName) + .getFunctionMetadatas(attrName, device, + deviceAnchorName, dupOutputs); + } else if (bottomUpMetadatas) { + metadatas = BottomUpDeviceClustering(funcOp, attrName) + .getFunctionMetadatas(attrName, device, + deviceAnchorName, dupOutputs); + } + topDownFunc.erase(); + bottomUpFunc.erase(); + break; + } case GraphClusteringAlgo::kFallback: - default: + default: { metadatas = getFunctionMetadatasFallback(funcOp, attrName, device, deviceAnchorName, dupOutputs); } + } if (!metadatas) { signalPassFailure(); diff --git a/compiler/python/byteir/compile.py b/compiler/python/byteir/compile.py index 5c5d618e1..7b93e29a9 100644 --- a/compiler/python/byteir/compile.py +++ b/compiler/python/byteir/compile.py @@ -107,7 +107,7 @@ def compile_cuda_with_ait( name: str = "model", aggressive_mode: bool = False, parallelism: int = 1, - disable_byteir_cache: bool = False, + disable_byteir_ait_cache: bool = False, **kwargs, ): target = "cuda" @@ -123,7 +123,7 @@ def compile_cuda_with_ait( processor = IRProcessor(name, "./workspace", compile_parallelism=parallelism, - disable_byteir_cache=disable_byteir_cache, + disable_byteir_ait_cache=disable_byteir_ait_cache, verbose=verbose) processor.module = module @@ -224,7 +224,7 @@ def compile( target: str = "cuda", verbose: bool = False, parallelism: int = 1, - disable_byteir_cache: bool = False, + disable_byteir_ait_cache: bool = False, **kwargs, ): context = ir.Context() @@ -245,7 +245,7 @@ def compile( entry_func, verbose, parallelism=parallelism, - disable_byteir_cache=disable_byteir_cache) + disable_byteir_ait_cache=disable_byteir_ait_cache) elif target == "cuda_with_ait_aggressive": compile_cuda_with_ait(module, output, @@ -253,6 +253,6 @@ def compile( verbose, aggressive_mode=True, parallelism=parallelism, - disable_byteir_cache=disable_byteir_cache) + disable_byteir_ait_cache=disable_byteir_ait_cache) else: raise NotImplemented("not implemented target: {}".format(target)) diff --git a/compiler/python/byteir/dialects/cat/ait_cache.py b/compiler/python/byteir/dialects/cat/ait_cache.py index 9f1c43713..37f524dcc 100644 --- a/compiler/python/byteir/dialects/cat/ait_cache.py +++ b/compiler/python/byteir/dialects/cat/ait_cache.py @@ -9,7 +9,7 @@ HOME_DIR = "/tmp/" DEFAULT_CACHE_DIR = os.path.join(HOME_DIR, ".byteir_cache/ait_cache/") CACHE_FILE_NAME = "ait_global_cache.json" -IDX_KEY = "byteir_ait_chache_auto_increment_idx" +IDX_KEY = "byteir_ait_cache_auto_increment_idx" class AITCache: def __init__(self, cache_dir = DEFAULT_CACHE_DIR) -> None: diff --git a/compiler/python/byteir/dialects/cat/ir_processor.py b/compiler/python/byteir/dialects/cat/ir_processor.py index c1388385b..c4c7e14f9 100644 --- a/compiler/python/byteir/dialects/cat/ir_processor.py +++ b/compiler/python/byteir/dialects/cat/ir_processor.py @@ -28,18 +28,21 @@ def __init__(self, job_name, workdir, compile_parallelism = MAX_COMPILATION_PARALLELISM, - disable_byteir_cache = False, + disable_byteir_ait_cache = False, verbose = False): self.job_name = job_name self.workdir = workdir self.module = None self.ait_reuse_recorder = {} # key: hash str, value: Tuple(dll_name, ait_module_path) self.compile_parallelism = min(compile_parallelism, MAX_COMPILATION_PARALLELISM) - self.pool = multiprocessing.Pool(compile_parallelism) + if self.compile_parallelism > 1: + self.pool = multiprocessing.Pool(compile_parallelism) + else: + self.pool = None self.byteir_cache = AITCache() self.verbose = verbose - self.disable_byteir_cache = disable_byteir_cache - if not disable_byteir_cache: + self.disable_byteir_ait_cache = disable_byteir_ait_cache + if not disable_byteir_ait_cache: self.byteir_cache.load_or_create_cache() def _get_builder(self, module, subgraph_name, backend="ait"): @@ -159,16 +162,19 @@ def ait_opt_pass(self, anchor_only=False, dump_ir=False): print("compile ait module using {} processes".format(min(len(work_items_not_in_cache), self.compile_parallelism))) t_st = time.time() for func_ir_str in work_items_not_in_cache: - self.pool.apply_async(_parallel_ait_compile, (self.workdir, func_ir_str)) - # _parallel_ait_compile(self.workdir, func_ir_str) + if self.pool: + self.pool.apply_async(_parallel_ait_compile, (self.workdir, func_ir_str)) + else: + _parallel_ait_compile(self.workdir, func_ir_str) - self.pool.close() - self.pool.join() + if self.pool: + self.pool.close() + self.pool.join() t_ed = time.time() print("compilation finished in {}s".format(t_ed-t_st)) # update byteir cache - if not self.disable_byteir_cache: + if not self.disable_byteir_ait_cache: for key, lib_path in libs_to_add_to_cache.items(): self.byteir_cache.add(gpu_type, key, lib_path, override=False) self.byteir_cache._save() diff --git a/compiler/python/version.txt b/compiler/python/version.txt index 3e1ad720b..9dbb0c005 100644 --- a/compiler/python/version.txt +++ b/compiler/python/version.txt @@ -1 +1 @@ -1.5.0 \ No newline at end of file +1.7.0 \ No newline at end of file diff --git a/compiler/scripts/gen_testcases.py b/compiler/scripts/gen_testcases.py index ae99aa0d9..18971330a 100644 --- a/compiler/scripts/gen_testcases.py +++ b/compiler/scripts/gen_testcases.py @@ -45,7 +45,7 @@ class HostPipelineCollections: # pipelines InputPipeline = functools.partial(OptPipeline, Input, [HostOpt], [ - "--hlo-opt=\"target=CPU\"", "--linalg-tensor-opt=\"target=CPU\"", "--byteir-bufferize-opt", "--scf-opt=\"target=CPU\"", + "--hlo-opt=\"target=CPU\"", "--linalg-tensor-opt=\"target=CPU\"", "--byre-tensor-opt=\"entry-func=main append-arg-types\"", "--byteir-bufferize-opt", "--scf-opt=\"target=CPU\"", ]) HostOptPipeline = functools.partial(OptPipeline, HostOpt, [ByreHost, ToLLVM], [ "--host-opt", "--byre-opt", diff --git a/compiler/test/Conversion/HloToByreTensor/compute_ops.mlir b/compiler/test/Conversion/HloToByreTensor/compute_ops.mlir index b317a6242..b025684b3 100644 --- a/compiler/test/Conversion/HloToByreTensor/compute_ops.mlir +++ b/compiler/test/Conversion/HloToByreTensor/compute_ops.mlir @@ -78,6 +78,16 @@ func.func @test_mhlo_reshape_2(%arg0 : tensor<2x3x4x5xf32>) -> tensor<8x15xf32> // ----- +func.func @test_ace_reshape_0(%arg0 : tensor<2x3x4x5x!ace.string>) -> tensor<6x20x!ace.string> attributes {__placeholder__byre.entry_point} { + %0 = "ace.reshape"(%arg0) : (tensor<2x3x4x5x!ace.string>) -> tensor<6x20x!ace.string> + return %0 : tensor<6x20x!ace.string> +} +// CHECK-LABEL: func.func @test_ace_reshape_0 +// CHECK-NEXT: tensor.collapse_shape +// CHECK-SAME: {{\[}}[0, 1], [2, 3]] + +// ----- + func.func @test_mhlo_dot(%arg0: tensor<128x64xf32>, %arg1: tensor<64x32xf32>) -> tensor<128x32xf32> attributes {__placeholder__byre.entry_point} { %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<128x64xf32>, tensor<64x32xf32>) -> tensor<128x32xf32> return %0 : tensor<128x32xf32> diff --git a/compiler/test/Dialect/Byre/invalid.mlir b/compiler/test/Dialect/Byre/invalid.mlir index b95c46caf..c059e3487 100644 --- a/compiler/test/Dialect/Byre/invalid.mlir +++ b/compiler/test/Dialect/Byre/invalid.mlir @@ -46,15 +46,6 @@ module attributes {byre.container_module} { // ----- -module attributes {byre.container_module} { - // expected-error @+1 {{expected at least 1 argument which was attached with 'byre.argtype' attribute contained 'Output'}} - func.func @invalid_entry_func(%arg0 : memref<100x?xf32> {byre.argtype = 1: i32, byre.argname = "output"}) attributes {byre.entry_point} { - return - } -} - -// ----- - module attributes {byre.container_module} { // expected-error @+1 {{invalid argtype 'Input|Output'}} func.func @invalid_entry_func(%arg0 : memref<100x?xf32> {byre.argtype = 3: i32}) attributes {byre.entry_point} { diff --git a/compiler/test/Pipelines/Host/E2E/Case0/00_Input.mlir b/compiler/test/Pipelines/Host/E2E/Case0/00_Input.mlir index a34147ba2..c71767d10 100644 --- a/compiler/test/Pipelines/Host/E2E/Case0/00_Input.mlir +++ b/compiler/test/Pipelines/Host/E2E/Case0/00_Input.mlir @@ -1,4 +1,4 @@ -// RUN: byteir-opt %s --hlo-opt="target=CPU" --linalg-tensor-opt="target=CPU" --byteir-bufferize-opt --scf-opt="target=CPU" | FileCheck %s +// RUN: byteir-opt %s --hlo-opt="target=CPU" --linalg-tensor-opt="target=CPU" --byre-tensor-opt="entry-func=main append-arg-types" --byteir-bufferize-opt --scf-opt="target=CPU" | FileCheck %s // CHECK-LABEL: func.func @main diff --git a/compiler/test/Pipelines/Host/E2E/Case0/01_HostOpt.mlir b/compiler/test/Pipelines/Host/E2E/Case0/01_HostOpt.mlir index ea6b0b88f..f4d18de58 100644 --- a/compiler/test/Pipelines/Host/E2E/Case0/01_HostOpt.mlir +++ b/compiler/test/Pipelines/Host/E2E/Case0/01_HostOpt.mlir @@ -1,13 +1,13 @@ // RUN: byteir-opt %s --host-opt --byre-opt | FileCheck %s // CHECK-LABEL: memref.global "private" -// CHECK-LABEL: func.func @Unknown1( +// CHECK-LABEL: func.func @Unknown0( // CHECK-SAME: %[[ARG0:.*]]: memref<1xi64>, %[[ARG1:.*]]: memref<1xi64>, %[[ARG2:.*]]: memref<1xi64>, %[[ARG3:.*]]: memref<1x128xi32>, %[[RES0:.*]]: memref<1x128xi32>, %[[RES1:.*]]: memref<1x128xi32>) // CHECK-LABEL: func.func @main module { memref.global "private" constant @__constant_1x128xi32 : memref<1x128xi32> = dense<"0x000000000100000002000000030000000400000005000000060000000700000008000000090000000A0000000B0000000C0000000D0000000E0000000F000000100000001100000012000000130000001400000015000000160000001700000018000000190000001A0000001B0000001C0000001D0000001E0000001F000000200000002100000022000000230000002400000025000000260000002700000028000000290000002A0000002B0000002C0000002D0000002E0000002F000000300000003100000032000000330000003400000035000000360000003700000038000000390000003A0000003B0000003C0000003D0000003E0000003F000000400000004100000042000000430000004400000045000000460000004700000048000000490000004A0000004B0000004C0000004D0000004E0000004F000000500000005100000052000000530000005400000055000000560000005700000058000000590000005A0000005B0000005C0000005D0000005E0000005F000000600000006100000062000000630000006400000065000000660000006700000068000000690000006A0000006B0000006C0000006D0000006E0000006F000000700000007100000072000000730000007400000075000000760000007700000078000000790000007A0000007B0000007C0000007D0000007E0000007F000000"> - func.func private @Unknown1(%arg0: memref<1xi64>, %arg1: memref<1xi64>, %arg2: memref<1xi64>, %arg3: memref<1x128xi32>) -> (memref<1x128xi32>, memref<1x128xi32>) attributes {__byteir_hlo_aggressive_fusion__} { + func.func private @Unknown0(%arg0: memref<1xi64>, %arg1: memref<1xi64>, %arg2: memref<1xi64>, %arg3: memref<1x128xi32>) -> (memref<1x128xi32>, memref<1x128xi32>) attributes {__byteir_hlo_aggressive_fusion__} { %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %c128 = arith.constant 128 : index @@ -37,8 +37,8 @@ module { } return %alloc_0, %alloc : memref<1x128xi32>, memref<1x128xi32> } - func.func @main(%arg0: memref<1xi64>, %arg1: memref<1xi64>, %arg2: memref<1xi64>, %arg3: memref<1x128xi32>) -> (memref<1x128xi32>, memref<1x128xi32>) { - %0:2 = call @Unknown1(%arg0, %arg1, %arg2, %arg3) : (memref<1xi64>, memref<1xi64>, memref<1xi64>, memref<1x128xi32>) -> (memref<1x128xi32>, memref<1x128xi32>) + func.func @main(%arg0: memref<1xi64>, %arg1: memref<1xi64>, %arg2: memref<1xi64>, %arg3: memref<1x128xi32>) -> (memref<1x128xi32>, memref<1x128xi32>) attributes {__placeholder__byre.entry_point} { + %0:2 = call @Unknown0(%arg0, %arg1, %arg2, %arg3) : (memref<1xi64>, memref<1xi64>, memref<1xi64>, memref<1x128xi32>) -> (memref<1x128xi32>, memref<1x128xi32>) return %0#0, %0#1 : memref<1x128xi32>, memref<1x128xi32> } } \ No newline at end of file diff --git a/compiler/test/Pipelines/Host/E2E/Case0/02a_ByreHost.mlir b/compiler/test/Pipelines/Host/E2E/Case0/02a_ByreHost.mlir index 9e05b4a70..1033c725f 100644 --- a/compiler/test/Pipelines/Host/E2E/Case0/02a_ByreHost.mlir +++ b/compiler/test/Pipelines/Host/E2E/Case0/02a_ByreHost.mlir @@ -5,7 +5,7 @@ module attributes {byre.container_module} { module attributes {byteir.llvm_module} { memref.global "private" constant @__constant_1x128xi32 : memref<1x128xi32> = dense<"0x000000000100000002000000030000000400000005000000060000000700000008000000090000000A0000000B0000000C0000000D0000000E0000000F000000100000001100000012000000130000001400000015000000160000001700000018000000190000001A0000001B0000001C0000001D0000001E0000001F000000200000002100000022000000230000002400000025000000260000002700000028000000290000002A0000002B0000002C0000002D0000002E0000002F000000300000003100000032000000330000003400000035000000360000003700000038000000390000003A0000003B0000003C0000003D0000003E0000003F000000400000004100000042000000430000004400000045000000460000004700000048000000490000004A0000004B0000004C0000004D0000004E0000004F000000500000005100000052000000530000005400000055000000560000005700000058000000590000005A0000005B0000005C0000005D0000005E0000005F000000600000006100000062000000630000006400000065000000660000006700000068000000690000006A0000006B0000006C0000006D0000006E0000006F000000700000007100000072000000730000007400000075000000760000007700000078000000790000007A0000007B0000007C0000007D0000007E0000007F000000"> - func.func @Unknown1(%arg0: memref<1xi64>, %arg1: memref<1xi64>, %arg2: memref<1xi64>, %arg3: memref<1x128xi32>, %arg4: memref<1x128xi32>, %arg5: memref<1x128xi32>) attributes {__byre__kernel_name = "Unknown1", __byre__llvm_file_name = "host_kernels.ll", __byteir_hlo_aggressive_fusion__, arg_offsets = [0 : i32, 1 : i32, 2 : i32, 3 : i32, 4 : i32, 5 : i32], byre_compute_name = "LLVMJITOp", byre_force_compute_name, llvm.emit_c_interface} { + func.func @Unknown0(%arg0: memref<1xi64>, %arg1: memref<1xi64>, %arg2: memref<1xi64>, %arg3: memref<1x128xi32>, %arg4: memref<1x128xi32>, %arg5: memref<1x128xi32>) attributes {__byre__kernel_name = "Unknown0", __byre__llvm_file_name = "host_kernels.ll", __byteir_hlo_aggressive_fusion__, arg_offsets = [0 : i32, 1 : i32, 2 : i32, 3 : i32, 4 : i32, 5 : i32], byre_compute_name = "LLVMJITOp", byre_force_compute_name, llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %c128 = arith.constant 128 : index @@ -35,7 +35,7 @@ module attributes {byre.container_module} { } } func.func @main(%arg0: memref<1xi64> {byre.argname = "Input0", byre.argtype = 1 : i32}, %arg1: memref<1xi64> {byre.argname = "Input1", byre.argtype = 1 : i32}, %arg2: memref<1xi64> {byre.argname = "Input2", byre.argtype = 1 : i32}, %arg3: memref<1x128xi32> {byre.argname = "Input3", byre.argtype = 1 : i32}, %arg4: memref<1x128xi32> {byre.argname = "Output0", byre.argtype = 2 : i32}, %arg5: memref<1x128xi32> {byre.argname = "Output1", byre.argtype = 2 : i32}) attributes {byre.entry_point} { - byre.compute @LLVMJITOp(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {kernel_name = "Unknown1", llvm_file_name = "host_kernels.ll", memory_effects = [1 : i32, 1 : i32, 1 : i32, 1 : i32, 2 : i32, 2 : i32]} : memref<1xi64>, memref<1xi64>, memref<1xi64>, memref<1x128xi32>, memref<1x128xi32>, memref<1x128xi32> + byre.compute @LLVMJITOp(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {kernel_name = "Unknown0", llvm_file_name = "host_kernels.ll", memory_effects = [1 : i32, 1 : i32, 1 : i32, 1 : i32, 2 : i32, 2 : i32]} : memref<1xi64>, memref<1xi64>, memref<1xi64>, memref<1x128xi32>, memref<1x128xi32>, memref<1x128xi32> return } } \ No newline at end of file diff --git a/compiler/test/Pipelines/Host/E2E/Case0/02b_ToLLVM.mlir b/compiler/test/Pipelines/Host/E2E/Case0/02b_ToLLVM.mlir index 9c9c452f4..0a40e4071 100644 --- a/compiler/test/Pipelines/Host/E2E/Case0/02b_ToLLVM.mlir +++ b/compiler/test/Pipelines/Host/E2E/Case0/02b_ToLLVM.mlir @@ -6,7 +6,7 @@ module attributes {byre.container_module} { module attributes {byteir.llvm_module} { memref.global "private" constant @__constant_1x128xi32 : memref<1x128xi32> = dense<"0x000000000100000002000000030000000400000005000000060000000700000008000000090000000A0000000B0000000C0000000D0000000E0000000F000000100000001100000012000000130000001400000015000000160000001700000018000000190000001A0000001B0000001C0000001D0000001E0000001F000000200000002100000022000000230000002400000025000000260000002700000028000000290000002A0000002B0000002C0000002D0000002E0000002F000000300000003100000032000000330000003400000035000000360000003700000038000000390000003A0000003B0000003C0000003D0000003E0000003F000000400000004100000042000000430000004400000045000000460000004700000048000000490000004A0000004B0000004C0000004D0000004E0000004F000000500000005100000052000000530000005400000055000000560000005700000058000000590000005A0000005B0000005C0000005D0000005E0000005F000000600000006100000062000000630000006400000065000000660000006700000068000000690000006A0000006B0000006C0000006D0000006E0000006F000000700000007100000072000000730000007400000075000000760000007700000078000000790000007A0000007B0000007C0000007D0000007E0000007F000000"> - func.func @Unknown1(%arg0: memref<1xi64>, %arg1: memref<1xi64>, %arg2: memref<1xi64>, %arg3: memref<1x128xi32>, %arg4: memref<1x128xi32>, %arg5: memref<1x128xi32>) attributes {__byre__kernel_name = "Unknown1", __byre__llvm_file_name = "host_kernels.ll", __byteir_hlo_aggressive_fusion__, arg_offsets = [0 : i32, 1 : i32, 2 : i32, 3 : i32, 4 : i32, 5 : i32], byre_compute_name = "LLVMJITOp", byre_force_compute_name, llvm.emit_c_interface} { + func.func @Unknown0(%arg0: memref<1xi64>, %arg1: memref<1xi64>, %arg2: memref<1xi64>, %arg3: memref<1x128xi32>, %arg4: memref<1x128xi32>, %arg5: memref<1x128xi32>) attributes {__byre__kernel_name = "Unknown0", __byre__llvm_file_name = "host_kernels.ll", __byteir_hlo_aggressive_fusion__, arg_offsets = [0 : i32, 1 : i32, 2 : i32, 3 : i32, 4 : i32, 5 : i32], byre_compute_name = "LLVMJITOp", byre_force_compute_name, llvm.emit_c_interface} { %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %c128 = arith.constant 128 : index @@ -36,7 +36,7 @@ module attributes {byre.container_module} { } } func.func @main(%arg0: memref<1xi64> {byre.argname = "Input0", byre.argtype = 1 : i32}, %arg1: memref<1xi64> {byre.argname = "Input1", byre.argtype = 1 : i32}, %arg2: memref<1xi64> {byre.argname = "Input2", byre.argtype = 1 : i32}, %arg3: memref<1x128xi32> {byre.argname = "Input3", byre.argtype = 1 : i32}, %arg4: memref<1x128xi32> {byre.argname = "Output0", byre.argtype = 2 : i32}, %arg5: memref<1x128xi32> {byre.argname = "Output1", byre.argtype = 2 : i32}) attributes {byre.entry_point} { - byre.compute @LLVMJITOp(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {kernel_name = "Unknown1", llvm_file_name = "host_kernels.ll", memory_effects = [1 : i32, 1 : i32, 1 : i32, 1 : i32, 2 : i32, 2 : i32]} : memref<1xi64>, memref<1xi64>, memref<1xi64>, memref<1x128xi32>, memref<1x128xi32>, memref<1x128xi32> + byre.compute @LLVMJITOp(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {kernel_name = "Unknown0", llvm_file_name = "host_kernels.ll", memory_effects = [1 : i32, 1 : i32, 1 : i32, 1 : i32, 2 : i32, 2 : i32]} : memref<1xi64>, memref<1xi64>, memref<1xi64>, memref<1x128xi32>, memref<1x128xi32>, memref<1x128xi32> return } } \ No newline at end of file diff --git a/compiler/test/Pipelines/Host/E2E/Case0/03b_ToLLVMIR.mlir b/compiler/test/Pipelines/Host/E2E/Case0/03b_ToLLVMIR.mlir index 8beba6563..dfd56398e 100644 --- a/compiler/test/Pipelines/Host/E2E/Case0/03b_ToLLVMIR.mlir +++ b/compiler/test/Pipelines/Host/E2E/Case0/03b_ToLLVMIR.mlir @@ -5,7 +5,7 @@ module attributes {byre.container_module, llvm.data_layout = ""} { llvm.mlir.global private constant @__constant_1x128xi32(dense<"0x000000000100000002000000030000000400000005000000060000000700000008000000090000000A0000000B0000000C0000000D0000000E0000000F000000100000001100000012000000130000001400000015000000160000001700000018000000190000001A0000001B0000001C0000001D0000001E0000001F000000200000002100000022000000230000002400000025000000260000002700000028000000290000002A0000002B0000002C0000002D0000002E0000002F000000300000003100000032000000330000003400000035000000360000003700000038000000390000003A0000003B0000003C0000003D0000003E0000003F000000400000004100000042000000430000004400000045000000460000004700000048000000490000004A0000004B0000004C0000004D0000004E0000004F000000500000005100000052000000530000005400000055000000560000005700000058000000590000005A0000005B0000005C0000005D0000005E0000005F000000600000006100000062000000630000006400000065000000660000006700000068000000690000006A0000006B0000006C0000006D0000006E0000006F000000700000007100000072000000730000007400000075000000760000007700000078000000790000007A0000007B0000007C0000007D0000007E0000007F000000"> : tensor<1x128xi32>) {addr_space = 0 : i32} : !llvm.array<1 x array<128 x i32>> - llvm.func @Unknown1(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: !llvm.ptr, %arg6: !llvm.ptr, %arg7: i64, %arg8: i64, %arg9: i64, %arg10: !llvm.ptr, %arg11: !llvm.ptr, %arg12: i64, %arg13: i64, %arg14: i64, %arg15: !llvm.ptr, %arg16: !llvm.ptr, %arg17: i64, %arg18: i64, %arg19: i64, %arg20: i64, %arg21: i64, %arg22: !llvm.ptr, %arg23: !llvm.ptr, %arg24: i64, %arg25: i64, %arg26: i64, %arg27: i64, %arg28: i64, %arg29: !llvm.ptr, %arg30: !llvm.ptr, %arg31: i64, %arg32: i64, %arg33: i64, %arg34: i64, %arg35: i64) attributes {__byre__kernel_name = "Unknown1", __byre__llvm_file_name = "host_kernels.ll", __byteir_hlo_aggressive_fusion__, arg_offsets = [0 : i32, 1 : i32, 2 : i32, 3 : i32, 4 : i32, 5 : i32], byre_compute_name = "LLVMJITOp", byre_force_compute_name, llvm.emit_c_interface} { + llvm.func @Unknown0(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: !llvm.ptr, %arg6: !llvm.ptr, %arg7: i64, %arg8: i64, %arg9: i64, %arg10: !llvm.ptr, %arg11: !llvm.ptr, %arg12: i64, %arg13: i64, %arg14: i64, %arg15: !llvm.ptr, %arg16: !llvm.ptr, %arg17: i64, %arg18: i64, %arg19: i64, %arg20: i64, %arg21: i64, %arg22: !llvm.ptr, %arg23: !llvm.ptr, %arg24: i64, %arg25: i64, %arg26: i64, %arg27: i64, %arg28: i64, %arg29: !llvm.ptr, %arg30: !llvm.ptr, %arg31: i64, %arg32: i64, %arg33: i64, %arg34: i64, %arg35: i64) attributes {__byre__kernel_name = "Unknown0", __byre__llvm_file_name = "host_kernels.ll", __byteir_hlo_aggressive_fusion__, arg_offsets = [0 : i32, 1 : i32, 2 : i32, 3 : i32, 4 : i32, 5 : i32], byre_compute_name = "LLVMJITOp", byre_force_compute_name, llvm.emit_c_interface} { %0 = llvm.mlir.constant(128 : index) : i64 %1 = llvm.mlir.constant(0 : index) : i64 %2 = llvm.mlir.constant(1 : index) : i64 @@ -52,7 +52,7 @@ module attributes {byre.container_module, llvm.data_layout = ""} { ^bb6: // pred: ^bb4 llvm.return } - llvm.func @_mlir_ciface_Unknown1(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr, %arg3: !llvm.ptr, %arg4: !llvm.ptr, %arg5: !llvm.ptr) attributes {__byre__kernel_name = "Unknown1", __byre__llvm_file_name = "host_kernels.ll", __byteir_hlo_aggressive_fusion__, arg_offsets = [0 : i32, 1 : i32, 2 : i32, 3 : i32, 4 : i32, 5 : i32], byre_compute_name = "LLVMJITOp", byre_force_compute_name, llvm.emit_c_interface} { + llvm.func @_mlir_ciface_Unknown0(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr, %arg3: !llvm.ptr, %arg4: !llvm.ptr, %arg5: !llvm.ptr) attributes {__byre__kernel_name = "Unknown0", __byre__llvm_file_name = "host_kernels.ll", __byteir_hlo_aggressive_fusion__, arg_offsets = [0 : i32, 1 : i32, 2 : i32, 3 : i32, 4 : i32, 5 : i32], byre_compute_name = "LLVMJITOp", byre_force_compute_name, llvm.emit_c_interface} { %0 = llvm.load %arg0 : !llvm.ptr -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %1 = llvm.extractvalue %0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %2 = llvm.extractvalue %0[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> @@ -95,7 +95,7 @@ module attributes {byre.container_module, llvm.data_layout = ""} { %39 = llvm.extractvalue %34[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %40 = llvm.extractvalue %34[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %41 = llvm.extractvalue %34[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - llvm.call @Unknown1(%1, %2, %3, %4, %5, %7, %8, %9, %10, %11, %13, %14, %15, %16, %17, %19, %20, %21, %22, %23, %24, %25, %27, %28, %29, %30, %31, %32, %33, %35, %36, %37, %38, %39, %40, %41) : (!llvm.ptr, !llvm.ptr, i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64) -> () + llvm.call @Unknown0(%1, %2, %3, %4, %5, %7, %8, %9, %10, %11, %13, %14, %15, %16, %17, %19, %20, %21, %22, %23, %24, %25, %27, %28, %29, %30, %31, %32, %33, %35, %36, %37, %38, %39, %40, %41) : (!llvm.ptr, !llvm.ptr, i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64) -> () llvm.return } } \ No newline at end of file diff --git a/compiler/test/Pipelines/Host/E2E/Case0/Output.ll b/compiler/test/Pipelines/Host/E2E/Case0/Output.ll index a45d49585..bd069d692 100644 --- a/compiler/test/Pipelines/Host/E2E/Case0/Output.ll +++ b/compiler/test/Pipelines/Host/E2E/Case0/Output.ll @@ -7,7 +7,7 @@ declare ptr @malloc(i64) declare void @free(ptr) -define void @Unknown1(ptr %0, ptr %1, i64 %2, i64 %3, i64 %4, ptr %5, ptr %6, i64 %7, i64 %8, i64 %9, ptr %10, ptr %11, i64 %12, i64 %13, i64 %14, ptr %15, ptr %16, i64 %17, i64 %18, i64 %19, i64 %20, i64 %21, ptr %22, ptr %23, i64 %24, i64 %25, i64 %26, i64 %27, i64 %28, ptr %29, ptr %30, i64 %31, i64 %32, i64 %33, i64 %34, i64 %35) { +define void @Unknown0(ptr %0, ptr %1, i64 %2, i64 %3, i64 %4, ptr %5, ptr %6, i64 %7, i64 %8, i64 %9, ptr %10, ptr %11, i64 %12, i64 %13, i64 %14, ptr %15, ptr %16, i64 %17, i64 %18, i64 %19, i64 %20, i64 %21, ptr %22, ptr %23, i64 %24, i64 %25, i64 %26, i64 %27, i64 %28, ptr %29, ptr %30, i64 %31, i64 %32, i64 %33, i64 %34, i64 %35) { br label %37 37: ; preds = %40, %36 @@ -56,7 +56,7 @@ define void @Unknown1(ptr %0, ptr %1, i64 %2, i64 %3, i64 %4, ptr %5, ptr %6, i6 ret void } -define void @_mlir_ciface_Unknown1(ptr %0, ptr %1, ptr %2, ptr %3, ptr %4, ptr %5) { +define void @_mlir_ciface_Unknown0(ptr %0, ptr %1, ptr %2, ptr %3, ptr %4, ptr %5) { %7 = load { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %0, align 8 %8 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %7, 0 %9 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %7, 1 @@ -99,7 +99,7 @@ define void @_mlir_ciface_Unknown1(ptr %0, ptr %1, ptr %2, ptr %3, ptr %4, ptr % %46 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %41, 3, 1 %47 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %41, 4, 0 %48 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %41, 4, 1 - call void @Unknown1(ptr %8, ptr %9, i64 %10, i64 %11, i64 %12, ptr %14, ptr %15, i64 %16, i64 %17, i64 %18, ptr %20, ptr %21, i64 %22, i64 %23, i64 %24, ptr %26, ptr %27, i64 %28, i64 %29, i64 %30, i64 %31, i64 %32, ptr %34, ptr %35, i64 %36, i64 %37, i64 %38, i64 %39, i64 %40, ptr %42, ptr %43, i64 %44, i64 %45, i64 %46, i64 %47, i64 %48) + call void @Unknown0(ptr %8, ptr %9, i64 %10, i64 %11, i64 %12, ptr %14, ptr %15, i64 %16, i64 %17, i64 %18, ptr %20, ptr %21, i64 %22, i64 %23, i64 %24, ptr %26, ptr %27, i64 %28, i64 %29, i64 %30, i64 %31, i64 %32, ptr %34, ptr %35, i64 %36, i64 %37, i64 %38, i64 %39, i64 %40, ptr %42, ptr %43, i64 %44, i64 %45, i64 %46, i64 %47, i64 %48) ret void } diff --git a/compiler/test/Pipelines/Host/E2E/Case0/Output.mlir b/compiler/test/Pipelines/Host/E2E/Case0/Output.mlir index b9012bd6d..ea67c5cac 100644 --- a/compiler/test/Pipelines/Host/E2E/Case0/Output.mlir +++ b/compiler/test/Pipelines/Host/E2E/Case0/Output.mlir @@ -4,7 +4,7 @@ module attributes {byre.container_module} { func.func @main(%arg0: memref<1xi64, "cpu"> {byre.argname = "Input0", byre.argtype = 1 : i32}, %arg1: memref<1xi64, "cpu"> {byre.argname = "Input1", byre.argtype = 1 : i32}, %arg2: memref<1xi64, "cpu"> {byre.argname = "Input2", byre.argtype = 1 : i32}, %arg3: memref<1x128xi32, "cpu"> {byre.argname = "Input3", byre.argtype = 1 : i32}, %arg4: memref<1x128xi32, "cpu"> {byre.argname = "Output0", byre.argtype = 2 : i32}, %arg5: memref<1x128xi32, "cpu"> {byre.argname = "Output1", byre.argtype = 2 : i32}) attributes {byre.entry_point, device_file_name = "your_file"} { - byre.compute @LLVMJITOp(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {device = "cpu", kernel_name = "Unknown1", llvm_file_name = "host_kernels.ll", memory_effects = [1 : i32, 1 : i32, 1 : i32, 1 : i32, 2 : i32, 2 : i32]} : memref<1xi64, "cpu">, memref<1xi64, "cpu">, memref<1xi64, "cpu">, memref<1x128xi32, "cpu">, memref<1x128xi32, "cpu">, memref<1x128xi32, "cpu"> + byre.compute @LLVMJITOp(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {device = "cpu", kernel_name = "Unknown0", llvm_file_name = "host_kernels.ll", memory_effects = [1 : i32, 1 : i32, 1 : i32, 1 : i32, 2 : i32, 2 : i32]} : memref<1xi64, "cpu">, memref<1xi64, "cpu">, memref<1xi64, "cpu">, memref<1x128xi32, "cpu">, memref<1x128xi32, "cpu">, memref<1x128xi32, "cpu"> return } } \ No newline at end of file diff --git a/compiler/test/Pipelines/Host/E2E/Case0/TotalPipeline.mlir b/compiler/test/Pipelines/Host/E2E/Case0/TotalPipeline.mlir index 5a4d601c7..a01a117e1 100644 --- a/compiler/test/Pipelines/Host/E2E/Case0/TotalPipeline.mlir +++ b/compiler/test/Pipelines/Host/E2E/Case0/TotalPipeline.mlir @@ -1,4 +1,4 @@ -// RUN: byteir-opt %s --hlo-opt="target=CPU" --linalg-tensor-opt="target=CPU" --byteir-bufferize-opt --scf-opt="target=CPU" --host-opt --byre-opt --to-llvm | byteir-translate --mlir-to-llvmir | FileCheck %s +// RUN: byteir-opt %s --hlo-opt="target=CPU" --linalg-tensor-opt="target=CPU" --byre-tensor-opt="entry-func=main append-arg-types" --byteir-bufferize-opt --scf-opt="target=CPU" --host-opt --byre-opt --to-llvm | byteir-translate --mlir-to-llvmir | FileCheck %s // CHECK-LABEL: constant // CHECK-LABEL: define void @_mlir_ciface_Unknown diff --git a/compiler/test/Pipelines/Host/E2E/Case0/template.py b/compiler/test/Pipelines/Host/E2E/Case0/template.py index 37ce99ea6..c80e6c08f 100644 --- a/compiler/test/Pipelines/Host/E2E/Case0/template.py +++ b/compiler/test/Pipelines/Host/E2E/Case0/template.py @@ -18,7 +18,7 @@ """), HostOptPipeline(r""" // CHECK-LABEL: memref.global "private" -// CHECK-LABEL: func.func @Unknown1( +// CHECK-LABEL: func.func @Unknown0( // CHECK-SAME: %[[ARG0:.*]]: memref<1xi64>, %[[ARG1:.*]]: memref<1xi64>, %[[ARG2:.*]]: memref<1xi64>, %[[ARG3:.*]]: memref<1x128xi32>, %[[RES0:.*]]: memref<1x128xi32>, %[[RES1:.*]]: memref<1x128xi32>) // CHECK-LABEL: func.func @main """), diff --git a/compiler/test/Pipelines/Host/E2E/Case1/00_Input.mlir b/compiler/test/Pipelines/Host/E2E/Case1/00_Input.mlir index 5f5d7f843..5529af945 100644 --- a/compiler/test/Pipelines/Host/E2E/Case1/00_Input.mlir +++ b/compiler/test/Pipelines/Host/E2E/Case1/00_Input.mlir @@ -1,4 +1,4 @@ -// RUN: byteir-opt %s --hlo-opt="target=CPU" --linalg-tensor-opt="target=CPU" --byteir-bufferize-opt --scf-opt="target=CPU" | FileCheck %s +// RUN: byteir-opt %s --hlo-opt="target=CPU" --linalg-tensor-opt="target=CPU" --byre-tensor-opt="entry-func=main append-arg-types" --byteir-bufferize-opt --scf-opt="target=CPU" | FileCheck %s // CHECK-LABEL: func.func @main diff --git a/compiler/test/Pipelines/Host/E2E/Case1/01_HostOpt.mlir b/compiler/test/Pipelines/Host/E2E/Case1/01_HostOpt.mlir index 949999818..27efe0e37 100644 --- a/compiler/test/Pipelines/Host/E2E/Case1/01_HostOpt.mlir +++ b/compiler/test/Pipelines/Host/E2E/Case1/01_HostOpt.mlir @@ -2,7 +2,7 @@ // CHECK-LABEL: memref.global "private" // CHECK-LABEL: memref.global "private" -// CHECK-LABEL: func.func @Unknown6( +// CHECK-LABEL: func.func @Unknown0( // CHECK-SAME: %[[ARG0:.*]]: memref<1x100x27x48x3xf32>, %[[ARG1:.*]]: memref<51200xi32>) // CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<1x100x27x48x3xi32> // CHECK: %[[ALLOC0:.*]] = memref.alloc() : memref<100x1296x1xi32> @@ -12,7 +12,7 @@ module { memref.global "private" constant @__constant_100xi32 : memref<100xi32> = dense<[0, 512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4608, 5120, 5632, 6144, 6656, 7168, 7680, 8192, 8704, 9216, 9728, 10240, 10752, 11264, 11776, 12288, 12800, 13312, 13824, 14336, 14848, 15360, 15872, 16384, 16896, 17408, 17920, 18432, 18944, 19456, 19968, 20480, 20992, 21504, 22016, 22528, 23040, 23552, 24064, 24576, 25088, 25600, 26112, 26624, 27136, 27648, 28160, 28672, 29184, 29696, 30208, 30720, 31232, 31744, 32256, 32768, 33280, 33792, 34304, 34816, 35328, 35840, 36352, 36864, 37376, 37888, 38400, 38912, 39424, 39936, 40448, 40960, 41472, 41984, 42496, 43008, 43520, 44032, 44544, 45056, 45568, 46080, 46592, 47104, 47616, 48128, 48640, 49152, 49664, 50176, 50688]> memref.global "private" constant @__constant_100x1296xi32 : memref<100x1296xi32> = dense<1> - func.func private @Unknown6(%arg0: memref<1x100x27x48x3xf32>) -> memref<51200xi32> attributes {__byteir_hlo_aggressive_fusion__} { + func.func private @Unknown0(%arg0: memref<1x100x27x48x3xf32>) -> memref<51200xi32> attributes {__byteir_hlo_aggressive_fusion__} { %c0_i32 = arith.constant 0 : i32 %c5_i32 = arith.constant 5 : i32 %c3_i32 = arith.constant 3 : i32 @@ -71,8 +71,8 @@ module { } return %alloc_5 : memref<51200xi32> } - func.func @main(%arg0: memref<1x100x27x48x3xf32>) -> memref<51200xi32> { - %0 = call @Unknown6(%arg0) : (memref<1x100x27x48x3xf32>) -> memref<51200xi32> + func.func @main(%arg0: memref<1x100x27x48x3xf32>) -> memref<51200xi32> attributes {__placeholder__byre.entry_point} { + %0 = call @Unknown0(%arg0) : (memref<1x100x27x48x3xf32>) -> memref<51200xi32> return %0 : memref<51200xi32> } } \ No newline at end of file diff --git a/compiler/test/Pipelines/Host/E2E/Case1/02a_ByreHost.mlir b/compiler/test/Pipelines/Host/E2E/Case1/02a_ByreHost.mlir index 650c67e0c..8c83bdc8a 100644 --- a/compiler/test/Pipelines/Host/E2E/Case1/02a_ByreHost.mlir +++ b/compiler/test/Pipelines/Host/E2E/Case1/02a_ByreHost.mlir @@ -6,7 +6,7 @@ module attributes {byre.container_module} { module attributes {byteir.llvm_module} { memref.global "private" constant @__constant_100xi32 : memref<100xi32> = dense<[0, 512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4608, 5120, 5632, 6144, 6656, 7168, 7680, 8192, 8704, 9216, 9728, 10240, 10752, 11264, 11776, 12288, 12800, 13312, 13824, 14336, 14848, 15360, 15872, 16384, 16896, 17408, 17920, 18432, 18944, 19456, 19968, 20480, 20992, 21504, 22016, 22528, 23040, 23552, 24064, 24576, 25088, 25600, 26112, 26624, 27136, 27648, 28160, 28672, 29184, 29696, 30208, 30720, 31232, 31744, 32256, 32768, 33280, 33792, 34304, 34816, 35328, 35840, 36352, 36864, 37376, 37888, 38400, 38912, 39424, 39936, 40448, 40960, 41472, 41984, 42496, 43008, 43520, 44032, 44544, 45056, 45568, 46080, 46592, 47104, 47616, 48128, 48640, 49152, 49664, 50176, 50688]> memref.global "private" constant @__constant_100x1296xi32 : memref<100x1296xi32> = dense<1> - func.func @Unknown6(%arg0: memref<1x100x27x48x3xf32>, %arg1: memref<51200xi32>) attributes {__byre__kernel_name = "Unknown6", __byre__llvm_file_name = "host_kernels.ll", __byteir_hlo_aggressive_fusion__, arg_offsets = [0 : i32, 1 : i32], byre_compute_name = "LLVMJITOp", byre_force_compute_name, llvm.emit_c_interface} { + func.func @Unknown0(%arg0: memref<1x100x27x48x3xf32>, %arg1: memref<51200xi32>) attributes {__byre__kernel_name = "Unknown0", __byre__llvm_file_name = "host_kernels.ll", __byteir_hlo_aggressive_fusion__, arg_offsets = [0 : i32, 1 : i32], byre_compute_name = "LLVMJITOp", byre_force_compute_name, llvm.emit_c_interface} { %c0_i32 = arith.constant 0 : i32 %c5_i32 = arith.constant 5 : i32 %c3_i32 = arith.constant 3 : i32 @@ -68,7 +68,7 @@ module attributes {byre.container_module} { } } func.func @main(%arg0: memref<1x100x27x48x3xf32> {byre.argname = "Input0", byre.argtype = 1 : i32}, %arg1: memref<51200xi32> {byre.argname = "Output0", byre.argtype = 2 : i32}) attributes {byre.entry_point} { - byre.compute @LLVMJITOp(%arg0, %arg1) {kernel_name = "Unknown6", llvm_file_name = "host_kernels.ll", memory_effects = [1 : i32, 2 : i32]} : memref<1x100x27x48x3xf32>, memref<51200xi32> + byre.compute @LLVMJITOp(%arg0, %arg1) {kernel_name = "Unknown0", llvm_file_name = "host_kernels.ll", memory_effects = [1 : i32, 2 : i32]} : memref<1x100x27x48x3xf32>, memref<51200xi32> return } } \ No newline at end of file diff --git a/compiler/test/Pipelines/Host/E2E/Case1/02b_ToLLVM.mlir b/compiler/test/Pipelines/Host/E2E/Case1/02b_ToLLVM.mlir index 54b0fa382..fb9aae3f1 100644 --- a/compiler/test/Pipelines/Host/E2E/Case1/02b_ToLLVM.mlir +++ b/compiler/test/Pipelines/Host/E2E/Case1/02b_ToLLVM.mlir @@ -7,7 +7,7 @@ module attributes {byre.container_module} { module attributes {byteir.llvm_module} { memref.global "private" constant @__constant_100xi32 : memref<100xi32> = dense<[0, 512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4608, 5120, 5632, 6144, 6656, 7168, 7680, 8192, 8704, 9216, 9728, 10240, 10752, 11264, 11776, 12288, 12800, 13312, 13824, 14336, 14848, 15360, 15872, 16384, 16896, 17408, 17920, 18432, 18944, 19456, 19968, 20480, 20992, 21504, 22016, 22528, 23040, 23552, 24064, 24576, 25088, 25600, 26112, 26624, 27136, 27648, 28160, 28672, 29184, 29696, 30208, 30720, 31232, 31744, 32256, 32768, 33280, 33792, 34304, 34816, 35328, 35840, 36352, 36864, 37376, 37888, 38400, 38912, 39424, 39936, 40448, 40960, 41472, 41984, 42496, 43008, 43520, 44032, 44544, 45056, 45568, 46080, 46592, 47104, 47616, 48128, 48640, 49152, 49664, 50176, 50688]> memref.global "private" constant @__constant_100x1296xi32 : memref<100x1296xi32> = dense<1> - func.func @Unknown6(%arg0: memref<1x100x27x48x3xf32>, %arg1: memref<51200xi32>) attributes {__byre__kernel_name = "Unknown6", __byre__llvm_file_name = "host_kernels.ll", __byteir_hlo_aggressive_fusion__, arg_offsets = [0 : i32, 1 : i32], byre_compute_name = "LLVMJITOp", byre_force_compute_name, llvm.emit_c_interface} { + func.func @Unknown0(%arg0: memref<1x100x27x48x3xf32>, %arg1: memref<51200xi32>) attributes {__byre__kernel_name = "Unknown0", __byre__llvm_file_name = "host_kernels.ll", __byteir_hlo_aggressive_fusion__, arg_offsets = [0 : i32, 1 : i32], byre_compute_name = "LLVMJITOp", byre_force_compute_name, llvm.emit_c_interface} { %c0_i32 = arith.constant 0 : i32 %c5_i32 = arith.constant 5 : i32 %c3_i32 = arith.constant 3 : i32 @@ -69,7 +69,7 @@ module attributes {byre.container_module} { } } func.func @main(%arg0: memref<1x100x27x48x3xf32> {byre.argname = "Input0", byre.argtype = 1 : i32}, %arg1: memref<51200xi32> {byre.argname = "Output0", byre.argtype = 2 : i32}) attributes {byre.entry_point} { - byre.compute @LLVMJITOp(%arg0, %arg1) {kernel_name = "Unknown6", llvm_file_name = "host_kernels.ll", memory_effects = [1 : i32, 2 : i32]} : memref<1x100x27x48x3xf32>, memref<51200xi32> + byre.compute @LLVMJITOp(%arg0, %arg1) {kernel_name = "Unknown0", llvm_file_name = "host_kernels.ll", memory_effects = [1 : i32, 2 : i32]} : memref<1x100x27x48x3xf32>, memref<51200xi32> return } } \ No newline at end of file diff --git a/compiler/test/Pipelines/Host/E2E/Case1/03b_ToLLVMIR.mlir b/compiler/test/Pipelines/Host/E2E/Case1/03b_ToLLVMIR.mlir index a416f6dd7..81904b677 100644 --- a/compiler/test/Pipelines/Host/E2E/Case1/03b_ToLLVMIR.mlir +++ b/compiler/test/Pipelines/Host/E2E/Case1/03b_ToLLVMIR.mlir @@ -8,128 +8,127 @@ module attributes {byre.container_module, llvm.data_layout = ""} { llvm.func @malloc(i64) -> !llvm.ptr llvm.mlir.global private constant @__constant_100xi32(dense<[0, 512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4608, 5120, 5632, 6144, 6656, 7168, 7680, 8192, 8704, 9216, 9728, 10240, 10752, 11264, 11776, 12288, 12800, 13312, 13824, 14336, 14848, 15360, 15872, 16384, 16896, 17408, 17920, 18432, 18944, 19456, 19968, 20480, 20992, 21504, 22016, 22528, 23040, 23552, 24064, 24576, 25088, 25600, 26112, 26624, 27136, 27648, 28160, 28672, 29184, 29696, 30208, 30720, 31232, 31744, 32256, 32768, 33280, 33792, 34304, 34816, 35328, 35840, 36352, 36864, 37376, 37888, 38400, 38912, 39424, 39936, 40448, 40960, 41472, 41984, 42496, 43008, 43520, 44032, 44544, 45056, 45568, 46080, 46592, 47104, 47616, 48128, 48640, 49152, 49664, 50176, 50688]> : tensor<100xi32>) {addr_space = 0 : i32} : !llvm.array<100 x i32> llvm.mlir.global private constant @__constant_100x1296xi32(dense<1> : tensor<100x1296xi32>) {addr_space = 0 : i32} : !llvm.array<100 x array<1296 x i32>> - llvm.func @Unknown6(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: i64, %arg8: i64, %arg9: i64, %arg10: i64, %arg11: i64, %arg12: i64, %arg13: !llvm.ptr, %arg14: !llvm.ptr, %arg15: i64, %arg16: i64, %arg17: i64) attributes {__byre__kernel_name = "Unknown6", __byre__llvm_file_name = "host_kernels.ll", __byteir_hlo_aggressive_fusion__, arg_offsets = [0 : i32, 1 : i32], byre_compute_name = "LLVMJITOp", byre_force_compute_name, llvm.emit_c_interface} { - %0 = llvm.mlir.constant(2 : index) : i64 - %1 = llvm.mlir.constant(3888 : index) : i64 - %2 = llvm.mlir.constant(3 : index) : i64 - %3 = llvm.mlir.constant(129600 : index) : i64 - %4 = llvm.mlir.constant(0 : index) : i64 - %5 = llvm.mlir.constant(1 : index) : i64 - %6 = llvm.mlir.constant(388800 : index) : i64 - %7 = llvm.mlir.constant(1296 : index) : i64 - %8 = llvm.mlir.constant(51200 : index) : i64 - %9 = llvm.mlir.constant(6 : i32) : i32 - %10 = llvm.mlir.constant(3 : i32) : i32 - %11 = llvm.mlir.constant(5 : i32) : i32 - %12 = llvm.mlir.constant(0 : i32) : i32 - %13 = llvm.mlir.addressof @__constant_100x1296xi32 : !llvm.ptr - %14 = llvm.getelementptr %13[0, 0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<100 x array<1296 x i32>> - %15 = llvm.mlir.addressof @__constant_100xi32 : !llvm.ptr - %16 = llvm.getelementptr %15[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<100 x i32> - %17 = llvm.mlir.null : !llvm.ptr - %18 = llvm.getelementptr %17[388800] : (!llvm.ptr) -> !llvm.ptr, i32 - %19 = llvm.ptrtoint %18 : !llvm.ptr to i64 - %20 = llvm.call @malloc(%19) : (i64) -> !llvm.ptr - llvm.br ^bb1(%4 : i64) - ^bb1(%21: i64): // 2 preds: ^bb0, ^bb2 - %22 = llvm.icmp "slt" %21, %6 : i64 - llvm.cond_br %22, ^bb2, ^bb3 + llvm.func @Unknown0(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: i64, %arg8: i64, %arg9: i64, %arg10: i64, %arg11: i64, %arg12: i64, %arg13: !llvm.ptr, %arg14: !llvm.ptr, %arg15: i64, %arg16: i64, %arg17: i64) attributes {__byre__kernel_name = "Unknown0", __byre__llvm_file_name = "host_kernels.ll", __byteir_hlo_aggressive_fusion__, arg_offsets = [0 : i32, 1 : i32], byre_compute_name = "LLVMJITOp", byre_force_compute_name, llvm.emit_c_interface} { + %0 = llvm.mlir.constant(3888 : index) : i64 + %1 = llvm.mlir.constant(3 : index) : i64 + %2 = llvm.mlir.constant(129600 : index) : i64 + %3 = llvm.mlir.constant(0 : index) : i64 + %4 = llvm.mlir.constant(1 : index) : i64 + %5 = llvm.mlir.constant(388800 : index) : i64 + %6 = llvm.mlir.constant(1296 : index) : i64 + %7 = llvm.mlir.constant(51200 : index) : i64 + %8 = llvm.mlir.constant(6 : i32) : i32 + %9 = llvm.mlir.constant(3 : i32) : i32 + %10 = llvm.mlir.constant(5 : i32) : i32 + %11 = llvm.mlir.constant(0 : i32) : i32 + %12 = llvm.mlir.addressof @__constant_100x1296xi32 : !llvm.ptr + %13 = llvm.getelementptr %12[0, 0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<100 x array<1296 x i32>> + %14 = llvm.mlir.addressof @__constant_100xi32 : !llvm.ptr + %15 = llvm.getelementptr %14[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<100 x i32> + %16 = llvm.mlir.null : !llvm.ptr + %17 = llvm.getelementptr %16[388800] : (!llvm.ptr) -> !llvm.ptr, i32 + %18 = llvm.ptrtoint %17 : !llvm.ptr to i64 + %19 = llvm.call @malloc(%18) : (i64) -> !llvm.ptr + llvm.br ^bb1(%3 : i64) + ^bb1(%20: i64): // 2 preds: ^bb0, ^bb2 + %21 = llvm.icmp "slt" %20, %5 : i64 + llvm.cond_br %21, ^bb2, ^bb3 ^bb2: // pred: ^bb1 - %23 = llvm.getelementptr %arg1[%21] : (!llvm.ptr, i64) -> !llvm.ptr, f32 - %24 = llvm.load %23 : !llvm.ptr -> f32 - %25 = llvm.fptosi %24 : f32 to i32 - %26 = llvm.getelementptr %20[%21] : (!llvm.ptr, i64) -> !llvm.ptr, i32 - llvm.store %25, %26 : i32, !llvm.ptr - %27 = llvm.add %21, %5 : i64 - llvm.br ^bb1(%27 : i64) + %22 = llvm.getelementptr %arg1[%20] : (!llvm.ptr, i64) -> !llvm.ptr, f32 + %23 = llvm.load %22 : !llvm.ptr -> f32 + %24 = llvm.fptosi %23 : f32 to i32 + %25 = llvm.getelementptr %19[%20] : (!llvm.ptr, i64) -> !llvm.ptr, i32 + llvm.store %24, %25 : i32, !llvm.ptr + %26 = llvm.add %20, %4 : i64 + llvm.br ^bb1(%26 : i64) ^bb3: // pred: ^bb1 - %28 = llvm.mlir.null : !llvm.ptr - %29 = llvm.getelementptr %28[129600] : (!llvm.ptr) -> !llvm.ptr, i32 - %30 = llvm.ptrtoint %29 : !llvm.ptr to i64 - %31 = llvm.call @malloc(%30) : (i64) -> !llvm.ptr - llvm.br ^bb4(%4 : i64) - ^bb4(%32: i64): // 2 preds: ^bb3, ^bb5 - %33 = llvm.icmp "slt" %32, %3 : i64 - llvm.cond_br %33, ^bb5, ^bb6 + %27 = llvm.mlir.null : !llvm.ptr + %28 = llvm.getelementptr %27[129600] : (!llvm.ptr) -> !llvm.ptr, i32 + %29 = llvm.ptrtoint %28 : !llvm.ptr to i64 + %30 = llvm.call @malloc(%29) : (i64) -> !llvm.ptr + llvm.br ^bb4(%3 : i64) + ^bb4(%31: i64): // 2 preds: ^bb3, ^bb5 + %32 = llvm.icmp "slt" %31, %2 : i64 + llvm.cond_br %32, ^bb5, ^bb6 ^bb5: // pred: ^bb4 - %34 = llvm.srem %32, %7 : i64 - %35 = llvm.sdiv %32, %7 : i64 - %36 = llvm.mul %35, %1 : i64 - %37 = llvm.add %36, %0 : i64 - %38 = llvm.mul %34, %2 : i64 - %39 = llvm.add %37, %38 : i64 - %40 = llvm.add %39, %4 : i64 - %41 = llvm.getelementptr %20[%40] : (!llvm.ptr, i64) -> !llvm.ptr, i32 - %42 = llvm.load %41 : !llvm.ptr -> i32 - %43 = llvm.mul %35, %1 : i64 - %44 = llvm.mul %34, %2 : i64 - %45 = llvm.add %43, %44 : i64 - %46 = llvm.add %45, %4 : i64 - %47 = llvm.getelementptr %20[%46] : (!llvm.ptr, i64) -> !llvm.ptr, i32 - %48 = llvm.load %47 : !llvm.ptr -> i32 - %49 = llvm.mul %35, %1 : i64 - %50 = llvm.add %49, %5 : i64 - %51 = llvm.mul %34, %2 : i64 - %52 = llvm.add %50, %51 : i64 - %53 = llvm.add %52, %4 : i64 - %54 = llvm.getelementptr %20[%53] : (!llvm.ptr, i64) -> !llvm.ptr, i32 - %55 = llvm.load %54 : !llvm.ptr -> i32 - %56 = llvm.getelementptr %16[%35] : (!llvm.ptr, i64) -> !llvm.ptr, i32 - %57 = llvm.load %56 : !llvm.ptr -> i32 - %58 = llvm.ashr %55, %11 : i32 - %59 = llvm.shl %58, %10 : i32 - %60 = llvm.ashr %48, %11 : i32 - %61 = llvm.shl %60, %9 : i32 - %62 = llvm.add %61, %59 : i32 - %63 = llvm.ashr %42, %11 : i32 - %64 = llvm.add %63, %62 : i32 - %65 = llvm.add %64, %57 : i32 - %66 = llvm.mul %35, %7 : i64 - %67 = llvm.add %66, %34 : i64 - %68 = llvm.add %67, %4 : i64 - %69 = llvm.getelementptr %31[%68] : (!llvm.ptr, i64) -> !llvm.ptr, i32 - llvm.store %65, %69 : i32, !llvm.ptr - %70 = llvm.add %32, %5 : i64 - llvm.br ^bb4(%70 : i64) + %33 = llvm.srem %31, %6 : i64 + %34 = llvm.sdiv %31, %6 : i64 + %35 = llvm.getelementptr %19[2] : (!llvm.ptr) -> !llvm.ptr, i32 + %36 = llvm.mul %34, %0 : i64 + %37 = llvm.mul %33, %1 : i64 + %38 = llvm.add %36, %37 : i64 + %39 = llvm.add %38, %3 : i64 + %40 = llvm.getelementptr %35[%39] : (!llvm.ptr, i64) -> !llvm.ptr, i32 + %41 = llvm.load %40 : !llvm.ptr -> i32 + %42 = llvm.mul %34, %0 : i64 + %43 = llvm.mul %33, %1 : i64 + %44 = llvm.add %42, %43 : i64 + %45 = llvm.add %44, %3 : i64 + %46 = llvm.getelementptr %19[%45] : (!llvm.ptr, i64) -> !llvm.ptr, i32 + %47 = llvm.load %46 : !llvm.ptr -> i32 + %48 = llvm.getelementptr %19[1] : (!llvm.ptr) -> !llvm.ptr, i32 + %49 = llvm.mul %34, %0 : i64 + %50 = llvm.mul %33, %1 : i64 + %51 = llvm.add %49, %50 : i64 + %52 = llvm.add %51, %3 : i64 + %53 = llvm.getelementptr %48[%52] : (!llvm.ptr, i64) -> !llvm.ptr, i32 + %54 = llvm.load %53 : !llvm.ptr -> i32 + %55 = llvm.getelementptr %15[%34] : (!llvm.ptr, i64) -> !llvm.ptr, i32 + %56 = llvm.load %55 : !llvm.ptr -> i32 + %57 = llvm.ashr %54, %10 : i32 + %58 = llvm.shl %57, %9 : i32 + %59 = llvm.ashr %47, %10 : i32 + %60 = llvm.shl %59, %8 : i32 + %61 = llvm.add %60, %58 : i32 + %62 = llvm.ashr %41, %10 : i32 + %63 = llvm.add %62, %61 : i32 + %64 = llvm.add %63, %56 : i32 + %65 = llvm.mul %34, %6 : i64 + %66 = llvm.add %65, %33 : i64 + %67 = llvm.add %66, %3 : i64 + %68 = llvm.getelementptr %30[%67] : (!llvm.ptr, i64) -> !llvm.ptr, i32 + llvm.store %64, %68 : i32, !llvm.ptr + %69 = llvm.add %31, %4 : i64 + llvm.br ^bb4(%69 : i64) ^bb6: // pred: ^bb4 - llvm.call @free(%20) : (!llvm.ptr) -> () - llvm.br ^bb7(%4 : i64) - ^bb7(%71: i64): // 2 preds: ^bb6, ^bb8 - %72 = llvm.icmp "slt" %71, %8 : i64 - llvm.cond_br %72, ^bb8, ^bb9(%4 : i64) + llvm.call @free(%19) : (!llvm.ptr) -> () + llvm.br ^bb7(%3 : i64) + ^bb7(%70: i64): // 2 preds: ^bb6, ^bb8 + %71 = llvm.icmp "slt" %70, %7 : i64 + llvm.cond_br %71, ^bb8, ^bb9(%3 : i64) ^bb8: // pred: ^bb7 - %73 = llvm.getelementptr %arg14[%71] : (!llvm.ptr, i64) -> !llvm.ptr, i32 - llvm.store %12, %73 : i32, !llvm.ptr - %74 = llvm.add %71, %5 : i64 - llvm.br ^bb7(%74 : i64) - ^bb9(%75: i64): // 2 preds: ^bb7, ^bb10 - %76 = llvm.icmp "slt" %75, %3 : i64 - llvm.cond_br %76, ^bb10, ^bb11 + %72 = llvm.getelementptr %arg14[%70] : (!llvm.ptr, i64) -> !llvm.ptr, i32 + llvm.store %11, %72 : i32, !llvm.ptr + %73 = llvm.add %70, %4 : i64 + llvm.br ^bb7(%73 : i64) + ^bb9(%74: i64): // 2 preds: ^bb7, ^bb10 + %75 = llvm.icmp "slt" %74, %2 : i64 + llvm.cond_br %75, ^bb10, ^bb11 ^bb10: // pred: ^bb9 - %77 = llvm.srem %75, %7 : i64 - %78 = llvm.sdiv %75, %7 : i64 - %79 = llvm.mul %78, %7 : i64 - %80 = llvm.add %79, %77 : i64 - %81 = llvm.add %80, %4 : i64 - %82 = llvm.getelementptr %31[%81] : (!llvm.ptr, i64) -> !llvm.ptr, i32 - %83 = llvm.load %82 : !llvm.ptr -> i32 - %84 = llvm.sext %83 : i32 to i64 - %85 = llvm.getelementptr %arg14[%84] : (!llvm.ptr, i64) -> !llvm.ptr, i32 - %86 = llvm.load %85 : !llvm.ptr -> i32 - %87 = llvm.mul %78, %7 : i64 - %88 = llvm.add %87, %77 : i64 - %89 = llvm.getelementptr %14[%88] : (!llvm.ptr, i64) -> !llvm.ptr, i32 - %90 = llvm.load %89 : !llvm.ptr -> i32 - %91 = llvm.add %86, %90 : i32 - %92 = llvm.getelementptr %arg14[%84] : (!llvm.ptr, i64) -> !llvm.ptr, i32 - llvm.store %91, %92 : i32, !llvm.ptr - %93 = llvm.add %75, %5 : i64 - llvm.br ^bb9(%93 : i64) + %76 = llvm.srem %74, %6 : i64 + %77 = llvm.sdiv %74, %6 : i64 + %78 = llvm.mul %77, %6 : i64 + %79 = llvm.add %78, %76 : i64 + %80 = llvm.add %79, %3 : i64 + %81 = llvm.getelementptr %30[%80] : (!llvm.ptr, i64) -> !llvm.ptr, i32 + %82 = llvm.load %81 : !llvm.ptr -> i32 + %83 = llvm.sext %82 : i32 to i64 + %84 = llvm.getelementptr %arg14[%83] : (!llvm.ptr, i64) -> !llvm.ptr, i32 + %85 = llvm.load %84 : !llvm.ptr -> i32 + %86 = llvm.mul %77, %6 : i64 + %87 = llvm.add %86, %76 : i64 + %88 = llvm.getelementptr %13[%87] : (!llvm.ptr, i64) -> !llvm.ptr, i32 + %89 = llvm.load %88 : !llvm.ptr -> i32 + %90 = llvm.add %85, %89 : i32 + %91 = llvm.getelementptr %arg14[%83] : (!llvm.ptr, i64) -> !llvm.ptr, i32 + llvm.store %90, %91 : i32, !llvm.ptr + %92 = llvm.add %74, %4 : i64 + llvm.br ^bb9(%92 : i64) ^bb11: // pred: ^bb9 - llvm.call @free(%31) : (!llvm.ptr) -> () + llvm.call @free(%30) : (!llvm.ptr) -> () llvm.return } - llvm.func @_mlir_ciface_Unknown6(%arg0: !llvm.ptr, %arg1: !llvm.ptr) attributes {__byre__kernel_name = "Unknown6", __byre__llvm_file_name = "host_kernels.ll", __byteir_hlo_aggressive_fusion__, arg_offsets = [0 : i32, 1 : i32], byre_compute_name = "LLVMJITOp", byre_force_compute_name, llvm.emit_c_interface} { + llvm.func @_mlir_ciface_Unknown0(%arg0: !llvm.ptr, %arg1: !llvm.ptr) attributes {__byre__kernel_name = "Unknown0", __byre__llvm_file_name = "host_kernels.ll", __byteir_hlo_aggressive_fusion__, arg_offsets = [0 : i32, 1 : i32], byre_compute_name = "LLVMJITOp", byre_force_compute_name, llvm.emit_c_interface} { %0 = llvm.load %arg0 : !llvm.ptr -> !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> %1 = llvm.extractvalue %0[0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> %2 = llvm.extractvalue %0[1] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> @@ -150,7 +149,7 @@ module attributes {byre.container_module, llvm.data_layout = ""} { %17 = llvm.extractvalue %14[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %18 = llvm.extractvalue %14[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %19 = llvm.extractvalue %14[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - llvm.call @Unknown6(%1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %15, %16, %17, %18, %19) : (!llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64) -> () + llvm.call @Unknown0(%1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %15, %16, %17, %18, %19) : (!llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64) -> () llvm.return } } \ No newline at end of file diff --git a/compiler/test/Pipelines/Host/E2E/Case1/Output.ll b/compiler/test/Pipelines/Host/E2E/Case1/Output.ll index 0589196e2..ce008f887 100644 --- a/compiler/test/Pipelines/Host/E2E/Case1/Output.ll +++ b/compiler/test/Pipelines/Host/E2E/Case1/Output.ll @@ -8,7 +8,7 @@ declare ptr @malloc(i64) declare void @free(ptr) -define void @Unknown6(ptr %0, ptr %1, i64 %2, i64 %3, i64 %4, i64 %5, i64 %6, i64 %7, i64 %8, i64 %9, i64 %10, i64 %11, i64 %12, ptr %13, ptr %14, i64 %15, i64 %16, i64 %17) { +define void @Unknown0(ptr %0, ptr %1, i64 %2, i64 %3, i64 %4, i64 %5, i64 %6, i64 %7, i64 %8, i64 %9, i64 %10, i64 %11, i64 %12, ptr %13, ptr %14, i64 %15, i64 %16, i64 %17) { %19 = call ptr @malloc(i64 ptrtoint (ptr getelementptr (i32, ptr null, i32 388800) to i64)) br label %20 @@ -38,12 +38,12 @@ define void @Unknown6(ptr %0, ptr %1, i64 %2, i64 %3, i64 %4, i64 %5, i64 %6, i6 34: ; preds = %31 %35 = srem i64 %32, 1296 %36 = sdiv i64 %32, 1296 - %37 = mul i64 %36, 3888 - %38 = add i64 %37, 2 + %37 = getelementptr i32, ptr %19, i32 2 + %38 = mul i64 %36, 3888 %39 = mul i64 %35, 3 %40 = add i64 %38, %39 %41 = add i64 %40, 0 - %42 = getelementptr i32, ptr %19, i64 %41 + %42 = getelementptr i32, ptr %37, i64 %41 %43 = load i32, ptr %42, align 4 %44 = mul i64 %36, 3888 %45 = mul i64 %35, 3 @@ -51,12 +51,12 @@ define void @Unknown6(ptr %0, ptr %1, i64 %2, i64 %3, i64 %4, i64 %5, i64 %6, i6 %47 = add i64 %46, 0 %48 = getelementptr i32, ptr %19, i64 %47 %49 = load i32, ptr %48, align 4 - %50 = mul i64 %36, 3888 - %51 = add i64 %50, 1 + %50 = getelementptr i32, ptr %19, i32 1 + %51 = mul i64 %36, 3888 %52 = mul i64 %35, 3 %53 = add i64 %51, %52 %54 = add i64 %53, 0 - %55 = getelementptr i32, ptr %19, i64 %54 + %55 = getelementptr i32, ptr %50, i64 %54 %56 = load i32, ptr %55, align 4 %57 = getelementptr i32, ptr @__constant_100xi32, i64 %36 %58 = load i32, ptr %57, align 4 @@ -122,7 +122,7 @@ define void @Unknown6(ptr %0, ptr %1, i64 %2, i64 %3, i64 %4, i64 %5, i64 %6, i6 ret void } -define void @_mlir_ciface_Unknown6(ptr %0, ptr %1) { +define void @_mlir_ciface_Unknown0(ptr %0, ptr %1) { %3 = load { ptr, ptr, i64, [5 x i64], [5 x i64] }, ptr %0, align 8 %4 = extractvalue { ptr, ptr, i64, [5 x i64], [5 x i64] } %3, 0 %5 = extractvalue { ptr, ptr, i64, [5 x i64], [5 x i64] } %3, 1 @@ -143,7 +143,7 @@ define void @_mlir_ciface_Unknown6(ptr %0, ptr %1) { %20 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %17, 2 %21 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %17, 3, 0 %22 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %17, 4, 0 - call void @Unknown6(ptr %4, ptr %5, i64 %6, i64 %7, i64 %8, i64 %9, i64 %10, i64 %11, i64 %12, i64 %13, i64 %14, i64 %15, i64 %16, ptr %18, ptr %19, i64 %20, i64 %21, i64 %22) + call void @Unknown0(ptr %4, ptr %5, i64 %6, i64 %7, i64 %8, i64 %9, i64 %10, i64 %11, i64 %12, i64 %13, i64 %14, i64 %15, i64 %16, ptr %18, ptr %19, i64 %20, i64 %21, i64 %22) ret void } diff --git a/compiler/test/Pipelines/Host/E2E/Case1/Output.mlir b/compiler/test/Pipelines/Host/E2E/Case1/Output.mlir index d91bfca63..38dbeedd7 100644 --- a/compiler/test/Pipelines/Host/E2E/Case1/Output.mlir +++ b/compiler/test/Pipelines/Host/E2E/Case1/Output.mlir @@ -4,7 +4,7 @@ module attributes {byre.container_module} { func.func @main(%arg0: memref<1x100x27x48x3xf32, "cpu"> {byre.argname = "Input0", byre.argtype = 1 : i32}, %arg1: memref<51200xi32, "cpu"> {byre.argname = "Output0", byre.argtype = 2 : i32}) attributes {byre.entry_point, device_file_name = "your_file"} { - byre.compute @LLVMJITOp(%arg0, %arg1) {device = "cpu", kernel_name = "Unknown6", llvm_file_name = "host_kernels.ll", memory_effects = [1 : i32, 2 : i32]} : memref<1x100x27x48x3xf32, "cpu">, memref<51200xi32, "cpu"> + byre.compute @LLVMJITOp(%arg0, %arg1) {device = "cpu", kernel_name = "Unknown0", llvm_file_name = "host_kernels.ll", memory_effects = [1 : i32, 2 : i32]} : memref<1x100x27x48x3xf32, "cpu">, memref<51200xi32, "cpu"> return } } \ No newline at end of file diff --git a/compiler/test/Pipelines/Host/E2E/Case1/TotalPipeline.mlir b/compiler/test/Pipelines/Host/E2E/Case1/TotalPipeline.mlir index 78615df68..b8ebf7bce 100644 --- a/compiler/test/Pipelines/Host/E2E/Case1/TotalPipeline.mlir +++ b/compiler/test/Pipelines/Host/E2E/Case1/TotalPipeline.mlir @@ -1,4 +1,4 @@ -// RUN: byteir-opt %s --hlo-opt="target=CPU" --linalg-tensor-opt="target=CPU" --byteir-bufferize-opt --scf-opt="target=CPU" --host-opt --byre-opt --to-llvm | byteir-translate --mlir-to-llvmir | FileCheck %s +// RUN: byteir-opt %s --hlo-opt="target=CPU" --linalg-tensor-opt="target=CPU" --byre-tensor-opt="entry-func=main append-arg-types" --byteir-bufferize-opt --scf-opt="target=CPU" --host-opt --byre-opt --to-llvm | byteir-translate --mlir-to-llvmir | FileCheck %s // CHECK-LABEL: constant // CHECK-LABEL: define void @_mlir_ciface_Unknown diff --git a/compiler/test/Pipelines/Host/E2E/Case1/template.py b/compiler/test/Pipelines/Host/E2E/Case1/template.py index 3218aa315..2680bf973 100644 --- a/compiler/test/Pipelines/Host/E2E/Case1/template.py +++ b/compiler/test/Pipelines/Host/E2E/Case1/template.py @@ -42,7 +42,7 @@ HostOptPipeline(r""" // CHECK-LABEL: memref.global "private" // CHECK-LABEL: memref.global "private" -// CHECK-LABEL: func.func @Unknown6( +// CHECK-LABEL: func.func @Unknown0( // CHECK-SAME: %[[ARG0:.*]]: memref<1x100x27x48x3xf32>, %[[ARG1:.*]]: memref<51200xi32>) // CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<1x100x27x48x3xi32> // CHECK: %[[ALLOC0:.*]] = memref.alloc() : memref<100x1296x1xi32> diff --git a/compiler/test/Pipelines/Host/E2E/RngUniform/00_Input.mlir b/compiler/test/Pipelines/Host/E2E/RngUniform/00_Input.mlir new file mode 100644 index 000000000..1ffdc49e0 --- /dev/null +++ b/compiler/test/Pipelines/Host/E2E/RngUniform/00_Input.mlir @@ -0,0 +1,11 @@ +// RUN: byteir-opt %s --hlo-opt="target=CPU" --linalg-tensor-opt="target=CPU" --byre-tensor-opt="entry-func=main append-arg-types" --byteir-bufferize-opt --scf-opt="target=CPU" | FileCheck %s + +// CHECK-LABEL: func.func @main + +func.func @main() -> tensor<1x97xf32> { + %0 = mhlo.constant dense<0.000000e+00> : tensor + %1 = mhlo.constant dense<1.000000e+00> : tensor + %2 = mhlo.constant dense<[1, 97]> : tensor<2xi64> + %3 = "mhlo.rng"(%0, %1, %2) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<2xi64>) -> tensor<1x97xf32> + return %3 : tensor<1x97xf32> +} \ No newline at end of file diff --git a/compiler/test/Pipelines/Host/E2E/RngUniform/01_HostOpt.mlir b/compiler/test/Pipelines/Host/E2E/RngUniform/01_HostOpt.mlir new file mode 100644 index 000000000..03557f538 --- /dev/null +++ b/compiler/test/Pipelines/Host/E2E/RngUniform/01_HostOpt.mlir @@ -0,0 +1,42 @@ +// RUN: byteir-opt %s --host-opt --byre-opt | FileCheck %s + +// CHECK-LABEL: func.func @Unknown + +module { + func.func private @Unknown0(%arg0: memref, %arg1: memref) -> memref<1x97xf32> attributes {__byteir_hlo_aggressive_fusion__} { + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant 2.32830644E-10 : f32 + %c12345_i32 = arith.constant 12345 : i32 + %c1103515245_i32 = arith.constant 1103515245 : i32 + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c97 = arith.constant 97 : index + %alloc = memref.alloc() : memref<1x97xf32> + scf.for %arg2 = %c0 to %c97 step %c1 { + %0 = memref.load %arg0[] : memref + %1 = memref.load %arg1[] : memref + %2 = arith.trunci %0 : i64 to i32 + %3 = arith.trunci %1 : i64 to i32 + %4 = arith.addi %2, %3 : i32 + %5 = arith.muli %4, %c1103515245_i32 : i32 + %6 = arith.addi %5, %c12345_i32 : i32 + %7 = arith.index_cast %arg2 : index to i32 + %8 = arith.addi %7, %6 : i32 + %9 = arith.muli %8, %c1103515245_i32 : i32 + %10 = arith.addi %9, %c12345_i32 : i32 + %11 = arith.uitofp %10 : i32 to f32 + %12 = arith.mulf %11, %cst_0 : f32 + %13 = arith.addf %12, %cst : f32 + memref.store %13, %alloc[%c0, %arg2] : memref<1x97xf32> + } + return %alloc : memref<1x97xf32> + } + func.func @main() -> memref<1x97xf32> attributes {__placeholder__byre.entry_point} { + %alloc = memref.alloc() : memref + byre.compute @GetSeed(%alloc) {memory_effects = [2 : i32]} : memref + %alloc_0 = memref.alloc() : memref + byre.compute @NextOffset(%alloc_0) {memory_effects = [2 : i32]} : memref + %0 = call @Unknown0(%alloc, %alloc_0) : (memref, memref) -> memref<1x97xf32> + return %0 : memref<1x97xf32> + } +} \ No newline at end of file diff --git a/compiler/test/Pipelines/Host/E2E/RngUniform/02a_ByreHost.mlir b/compiler/test/Pipelines/Host/E2E/RngUniform/02a_ByreHost.mlir new file mode 100644 index 000000000..02126fa7e --- /dev/null +++ b/compiler/test/Pipelines/Host/E2E/RngUniform/02a_ByreHost.mlir @@ -0,0 +1,44 @@ +// RUN: byteir-opt %s -byre-host="device-file-name=your_file target=cpu" | FileCheck %s + +// CHECK-LABEL: func.func @main + +module attributes {byre.container_module} { + module attributes {byteir.llvm_module} { + func.func @Unknown0(%arg0: memref, %arg1: memref, %arg2: memref<1x97xf32>) attributes {__byre__kernel_name = "Unknown0", __byre__llvm_file_name = "host_kernels.ll", __byteir_hlo_aggressive_fusion__, arg_offsets = [0 : i32, 1 : i32, 2 : i32], byre_compute_name = "LLVMJITOp", byre_force_compute_name, llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant 2.32830644E-10 : f32 + %c12345_i32 = arith.constant 12345 : i32 + %c1103515245_i32 = arith.constant 1103515245 : i32 + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c97 = arith.constant 97 : index + scf.for %arg3 = %c0 to %c97 step %c1 { + %0 = memref.load %arg0[] : memref + %1 = memref.load %arg1[] : memref + %2 = arith.trunci %0 : i64 to i32 + %3 = arith.trunci %1 : i64 to i32 + %4 = arith.addi %2, %3 : i32 + %5 = arith.muli %4, %c1103515245_i32 : i32 + %6 = arith.addi %5, %c12345_i32 : i32 + %7 = arith.index_cast %arg3 : index to i32 + %8 = arith.addi %7, %6 : i32 + %9 = arith.muli %8, %c1103515245_i32 : i32 + %10 = arith.addi %9, %c12345_i32 : i32 + %11 = arith.uitofp %10 : i32 to f32 + %12 = arith.mulf %11, %cst_0 : f32 + %13 = arith.addf %12, %cst : f32 + memref.store %13, %arg2[%c0, %arg3] : memref<1x97xf32> + } + return + } + } + func.func @main(%arg0: memref<1x97xf32> {byre.argname = "Output0", byre.argtype = 2 : i32}) attributes {byre.entry_point} { + %alloc = memref.alloc() : memref<256xi8> + %0 = "byre.alias"(%alloc) {offset = 0 : i64} : (memref<256xi8>) -> memref + byre.compute @GetSeed(%0) {memory_effects = [2 : i32]} : memref + %1 = "byre.alias"(%alloc) {offset = 128 : i64} : (memref<256xi8>) -> memref + byre.compute @NextOffset(%1) {memory_effects = [2 : i32]} : memref + byre.compute @LLVMJITOp(%0, %1, %arg0) {kernel_name = "Unknown0", llvm_file_name = "host_kernels.ll", memory_effects = [1 : i32, 1 : i32, 2 : i32]} : memref, memref, memref<1x97xf32> + return + } +} \ No newline at end of file diff --git a/compiler/test/Pipelines/Host/E2E/RngUniform/02b_ToLLVM.mlir b/compiler/test/Pipelines/Host/E2E/RngUniform/02b_ToLLVM.mlir new file mode 100644 index 000000000..4fb8711a7 --- /dev/null +++ b/compiler/test/Pipelines/Host/E2E/RngUniform/02b_ToLLVM.mlir @@ -0,0 +1,44 @@ +// RUN: byteir-opt %s --to-llvm | FileCheck %s + +// CHECK: llvm.func + +module attributes {byre.container_module} { + module attributes {byteir.llvm_module} { + func.func @Unknown0(%arg0: memref, %arg1: memref, %arg2: memref<1x97xf32>) attributes {__byre__kernel_name = "Unknown0", __byre__llvm_file_name = "host_kernels.ll", __byteir_hlo_aggressive_fusion__, arg_offsets = [0 : i32, 1 : i32, 2 : i32], byre_compute_name = "LLVMJITOp", byre_force_compute_name, llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant 2.32830644E-10 : f32 + %c12345_i32 = arith.constant 12345 : i32 + %c1103515245_i32 = arith.constant 1103515245 : i32 + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c97 = arith.constant 97 : index + scf.for %arg3 = %c0 to %c97 step %c1 { + %0 = memref.load %arg0[] : memref + %1 = memref.load %arg1[] : memref + %2 = arith.trunci %0 : i64 to i32 + %3 = arith.trunci %1 : i64 to i32 + %4 = arith.addi %2, %3 : i32 + %5 = arith.muli %4, %c1103515245_i32 : i32 + %6 = arith.addi %5, %c12345_i32 : i32 + %7 = arith.index_cast %arg3 : index to i32 + %8 = arith.addi %7, %6 : i32 + %9 = arith.muli %8, %c1103515245_i32 : i32 + %10 = arith.addi %9, %c12345_i32 : i32 + %11 = arith.uitofp %10 : i32 to f32 + %12 = arith.mulf %11, %cst_0 : f32 + %13 = arith.addf %12, %cst : f32 + memref.store %13, %arg2[%c0, %arg3] : memref<1x97xf32> + } + return + } + } + func.func @main(%arg0: memref<1x97xf32> {byre.argname = "Output0", byre.argtype = 2 : i32}) attributes {byre.entry_point} { + %alloc = memref.alloc() : memref<256xi8> + %0 = "byre.alias"(%alloc) {offset = 0 : i64} : (memref<256xi8>) -> memref + byre.compute @GetSeed(%0) {memory_effects = [2 : i32]} : memref + %1 = "byre.alias"(%alloc) {offset = 128 : i64} : (memref<256xi8>) -> memref + byre.compute @NextOffset(%1) {memory_effects = [2 : i32]} : memref + byre.compute @LLVMJITOp(%0, %1, %arg0) {kernel_name = "Unknown0", llvm_file_name = "host_kernels.ll", memory_effects = [1 : i32, 1 : i32, 2 : i32]} : memref, memref, memref<1x97xf32> + return + } +} \ No newline at end of file diff --git a/compiler/test/Pipelines/Host/E2E/RngUniform/03b_ToLLVMIR.mlir b/compiler/test/Pipelines/Host/E2E/RngUniform/03b_ToLLVMIR.mlir new file mode 100644 index 000000000..a46de32bb --- /dev/null +++ b/compiler/test/Pipelines/Host/E2E/RngUniform/03b_ToLLVMIR.mlir @@ -0,0 +1,62 @@ +// RUN: byteir-translate %s --mlir-to-llvmir | FileCheck %s + +// CHECK-LABEL: define void @_mlir_ciface_Unknown + +module attributes {byre.container_module, llvm.data_layout = ""} { + llvm.func @Unknown0(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: !llvm.ptr, %arg4: !llvm.ptr, %arg5: i64, %arg6: !llvm.ptr, %arg7: !llvm.ptr, %arg8: i64, %arg9: i64, %arg10: i64, %arg11: i64, %arg12: i64) attributes {__byre__kernel_name = "Unknown0", __byre__llvm_file_name = "host_kernels.ll", __byteir_hlo_aggressive_fusion__, arg_offsets = [0 : i32, 1 : i32, 2 : i32], byre_compute_name = "LLVMJITOp", byre_force_compute_name, llvm.emit_c_interface} { + %0 = llvm.mlir.constant(97 : index) : i64 + %1 = llvm.mlir.constant(0 : index) : i64 + %2 = llvm.mlir.constant(1 : index) : i64 + %3 = llvm.mlir.constant(1103515245 : i32) : i32 + %4 = llvm.mlir.constant(12345 : i32) : i32 + %5 = llvm.mlir.constant(2.32830644E-10 : f32) : f32 + %6 = llvm.mlir.constant(0.000000e+00 : f32) : f32 + llvm.br ^bb1(%1 : i64) + ^bb1(%7: i64): // 2 preds: ^bb0, ^bb2 + %8 = llvm.icmp "slt" %7, %0 : i64 + llvm.cond_br %8, ^bb2, ^bb3 + ^bb2: // pred: ^bb1 + %9 = llvm.load %arg1 : !llvm.ptr -> i64 + %10 = llvm.load %arg4 : !llvm.ptr -> i64 + %11 = llvm.trunc %9 : i64 to i32 + %12 = llvm.trunc %10 : i64 to i32 + %13 = llvm.add %11, %12 : i32 + %14 = llvm.mul %13, %3 : i32 + %15 = llvm.add %14, %4 : i32 + %16 = llvm.trunc %7 : i64 to i32 + %17 = llvm.add %16, %15 : i32 + %18 = llvm.mul %17, %3 : i32 + %19 = llvm.add %18, %4 : i32 + %20 = llvm.uitofp %19 : i32 to f32 + %21 = llvm.fmul %20, %5 : f32 + %22 = llvm.fadd %21, %6 : f32 + %23 = llvm.mul %1, %0 : i64 + %24 = llvm.add %23, %7 : i64 + %25 = llvm.getelementptr %arg7[%24] : (!llvm.ptr, i64) -> !llvm.ptr, f32 + llvm.store %22, %25 : f32, !llvm.ptr + %26 = llvm.add %7, %2 : i64 + llvm.br ^bb1(%26 : i64) + ^bb3: // pred: ^bb1 + llvm.return + } + llvm.func @_mlir_ciface_Unknown0(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) attributes {__byre__kernel_name = "Unknown0", __byre__llvm_file_name = "host_kernels.ll", __byteir_hlo_aggressive_fusion__, arg_offsets = [0 : i32, 1 : i32, 2 : i32], byre_compute_name = "LLVMJITOp", byre_force_compute_name, llvm.emit_c_interface} { + %0 = llvm.load %arg0 : !llvm.ptr -> !llvm.struct<(ptr, ptr, i64)> + %1 = llvm.extractvalue %0[0] : !llvm.struct<(ptr, ptr, i64)> + %2 = llvm.extractvalue %0[1] : !llvm.struct<(ptr, ptr, i64)> + %3 = llvm.extractvalue %0[2] : !llvm.struct<(ptr, ptr, i64)> + %4 = llvm.load %arg1 : !llvm.ptr -> !llvm.struct<(ptr, ptr, i64)> + %5 = llvm.extractvalue %4[0] : !llvm.struct<(ptr, ptr, i64)> + %6 = llvm.extractvalue %4[1] : !llvm.struct<(ptr, ptr, i64)> + %7 = llvm.extractvalue %4[2] : !llvm.struct<(ptr, ptr, i64)> + %8 = llvm.load %arg2 : !llvm.ptr -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %9 = llvm.extractvalue %8[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %10 = llvm.extractvalue %8[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %11 = llvm.extractvalue %8[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %12 = llvm.extractvalue %8[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %13 = llvm.extractvalue %8[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %14 = llvm.extractvalue %8[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %15 = llvm.extractvalue %8[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + llvm.call @Unknown0(%1, %2, %3, %5, %6, %7, %9, %10, %11, %12, %13, %14, %15) : (!llvm.ptr, !llvm.ptr, i64, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64) -> () + llvm.return + } +} \ No newline at end of file diff --git a/compiler/test/Pipelines/Host/E2E/RngUniform/Output.ll b/compiler/test/Pipelines/Host/E2E/RngUniform/Output.ll new file mode 100644 index 000000000..1a1f3f054 --- /dev/null +++ b/compiler/test/Pipelines/Host/E2E/RngUniform/Output.ll @@ -0,0 +1,64 @@ +; ModuleID = 'LLVMDialectModule' +source_filename = "LLVMDialectModule" + +declare ptr @malloc(i64) + +declare void @free(ptr) + +define void @Unknown0(ptr %0, ptr %1, i64 %2, ptr %3, ptr %4, i64 %5, ptr %6, ptr %7, i64 %8, i64 %9, i64 %10, i64 %11, i64 %12) { + br label %14 + +14: ; preds = %17, %13 + %15 = phi i64 [ %34, %17 ], [ 0, %13 ] + %16 = icmp slt i64 %15, 97 + br i1 %16, label %17, label %35 + +17: ; preds = %14 + %18 = load i64, ptr %1, align 4 + %19 = load i64, ptr %4, align 4 + %20 = trunc i64 %18 to i32 + %21 = trunc i64 %19 to i32 + %22 = add i32 %20, %21 + %23 = mul i32 %22, 1103515245 + %24 = add i32 %23, 12345 + %25 = trunc i64 %15 to i32 + %26 = add i32 %25, %24 + %27 = mul i32 %26, 1103515245 + %28 = add i32 %27, 12345 + %29 = uitofp i32 %28 to float + %30 = fmul float %29, 0x3DF0000000000000 + %31 = fadd float %30, 0.000000e+00 + %32 = add i64 0, %15 + %33 = getelementptr float, ptr %7, i64 %32 + store float %31, ptr %33, align 4 + %34 = add i64 %15, 1 + br label %14 + +35: ; preds = %14 + ret void +} + +define void @_mlir_ciface_Unknown0(ptr %0, ptr %1, ptr %2) { + %4 = load { ptr, ptr, i64 }, ptr %0, align 8 + %5 = extractvalue { ptr, ptr, i64 } %4, 0 + %6 = extractvalue { ptr, ptr, i64 } %4, 1 + %7 = extractvalue { ptr, ptr, i64 } %4, 2 + %8 = load { ptr, ptr, i64 }, ptr %1, align 8 + %9 = extractvalue { ptr, ptr, i64 } %8, 0 + %10 = extractvalue { ptr, ptr, i64 } %8, 1 + %11 = extractvalue { ptr, ptr, i64 } %8, 2 + %12 = load { ptr, ptr, i64, [2 x i64], [2 x i64] }, ptr %2, align 8 + %13 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %12, 0 + %14 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %12, 1 + %15 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %12, 2 + %16 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %12, 3, 0 + %17 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %12, 3, 1 + %18 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %12, 4, 0 + %19 = extractvalue { ptr, ptr, i64, [2 x i64], [2 x i64] } %12, 4, 1 + call void @Unknown0(ptr %5, ptr %6, i64 %7, ptr %9, ptr %10, i64 %11, ptr %13, ptr %14, i64 %15, i64 %16, i64 %17, i64 %18, i64 %19) + ret void +} + +!llvm.module.flags = !{!0} + +!0 = !{i32 2, !"Debug Info Version", i32 3} diff --git a/compiler/test/Pipelines/Host/E2E/RngUniform/Output.mlir b/compiler/test/Pipelines/Host/E2E/RngUniform/Output.mlir new file mode 100644 index 000000000..dd0b8e8b3 --- /dev/null +++ b/compiler/test/Pipelines/Host/E2E/RngUniform/Output.mlir @@ -0,0 +1,15 @@ +// RUN: byteir-opt %s | FileCheck %s + +// CHECK-LABEL: func.func @main + +module attributes {byre.container_module} { + func.func @main(%arg0: memref<1x97xf32, "cpu"> {byre.argname = "Output0", byre.argtype = 2 : i32}) attributes {byre.entry_point, device_file_name = "your_file"} { + %alloc = memref.alloc() : memref<256xi8, "cpu"> + %0 = "byre.alias"(%alloc) {device = "cpu", offset = 0 : i64} : (memref<256xi8, "cpu">) -> memref + byre.compute @GetSeed(%0) {device = "cpu", memory_effects = [2 : i32]} : memref + %1 = "byre.alias"(%alloc) {device = "cpu", offset = 128 : i64} : (memref<256xi8, "cpu">) -> memref + byre.compute @NextOffset(%1) {device = "cpu", memory_effects = [2 : i32]} : memref + byre.compute @LLVMJITOp(%0, %1, %arg0) {device = "cpu", kernel_name = "Unknown0", llvm_file_name = "host_kernels.ll", memory_effects = [1 : i32, 1 : i32, 2 : i32]} : memref, memref, memref<1x97xf32, "cpu"> + return + } +} \ No newline at end of file diff --git a/compiler/test/Pipelines/Host/E2E/RngUniform/TotalPipeline.mlir b/compiler/test/Pipelines/Host/E2E/RngUniform/TotalPipeline.mlir new file mode 100644 index 000000000..2020e0f4e --- /dev/null +++ b/compiler/test/Pipelines/Host/E2E/RngUniform/TotalPipeline.mlir @@ -0,0 +1,11 @@ +// RUN: byteir-opt %s --hlo-opt="target=CPU" --linalg-tensor-opt="target=CPU" --byre-tensor-opt="entry-func=main append-arg-types" --byteir-bufferize-opt --scf-opt="target=CPU" --host-opt --byre-opt --to-llvm | byteir-translate --mlir-to-llvmir | FileCheck %s + +// CHECK-LABEL: define void @_mlir_ciface_Unknown + +func.func @main() -> tensor<1x97xf32> { + %0 = mhlo.constant dense<0.000000e+00> : tensor + %1 = mhlo.constant dense<1.000000e+00> : tensor + %2 = mhlo.constant dense<[1, 97]> : tensor<2xi64> + %3 = "mhlo.rng"(%0, %1, %2) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<2xi64>) -> tensor<1x97xf32> + return %3 : tensor<1x97xf32> +} \ No newline at end of file diff --git a/compiler/test/Pipelines/Host/E2E/RngUniform/template.py b/compiler/test/Pipelines/Host/E2E/RngUniform/template.py new file mode 100644 index 000000000..46285dc87 --- /dev/null +++ b/compiler/test/Pipelines/Host/E2E/RngUniform/template.py @@ -0,0 +1,34 @@ +Testcase( + contents=[Content(stages=(Input, E2E), content=r""" +func.func @main() -> tensor<1x97xf32> { + %0 = mhlo.constant dense<0.000000e+00> : tensor + %1 = mhlo.constant dense<1.000000e+00> : tensor + %2 = mhlo.constant dense<[1, 97]> : tensor<2xi64> + %3 = "mhlo.rng"(%0, %1, %2) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<2xi64>) -> tensor<1x97xf32> + return %3 : tensor<1x97xf32> +} + """)], + pipelines=[ + InputPipeline(r""" +// CHECK-LABEL: func.func @main +"""), + HostOptPipeline(r""" +// CHECK-LABEL: func.func @Unknown +"""), + ToLLVMPipeline(r""" +// CHECK: llvm.func +"""), + ToLLVMIRPipeline(r""" +// CHECK-LABEL: define void @_mlir_ciface_Unknown +"""), + ByreHostPipeline(r""" +// CHECK-LABEL: func.func @main +"""), + TotalPipeline(r""" +// CHECK-LABEL: define void @_mlir_ciface_Unknown +"""), + ByreOutPipeline(r""" +// CHECK-LABEL: func.func @main +"""), + ] +) \ No newline at end of file diff --git a/compiler/test/Pipelines/Host/E2E/Transpose/00_Input.mlir b/compiler/test/Pipelines/Host/E2E/Transpose/00_Input.mlir index b2b79419a..ee507240f 100644 --- a/compiler/test/Pipelines/Host/E2E/Transpose/00_Input.mlir +++ b/compiler/test/Pipelines/Host/E2E/Transpose/00_Input.mlir @@ -1,4 +1,4 @@ -// RUN: byteir-opt %s --hlo-opt="target=CPU" --linalg-tensor-opt="target=CPU" --byteir-bufferize-opt --scf-opt="target=CPU" | FileCheck %s +// RUN: byteir-opt %s --hlo-opt="target=CPU" --linalg-tensor-opt="target=CPU" --byre-tensor-opt="entry-func=main append-arg-types" --byteir-bufferize-opt --scf-opt="target=CPU" | FileCheck %s // CHECK-LABEL: func.func @main diff --git a/compiler/test/Pipelines/Host/E2E/Transpose/01_HostOpt.mlir b/compiler/test/Pipelines/Host/E2E/Transpose/01_HostOpt.mlir index 846e40b4e..c3aaa04c1 100644 --- a/compiler/test/Pipelines/Host/E2E/Transpose/01_HostOpt.mlir +++ b/compiler/test/Pipelines/Host/E2E/Transpose/01_HostOpt.mlir @@ -24,7 +24,7 @@ module { } return %alloc : memref<1x64x64x32xf32> } - func.func @main(%arg0: memref<1x32x64x64xf32>) -> memref<1x64x64x32xf32> { + func.func @main(%arg0: memref<1x32x64x64xf32>) -> memref<1x64x64x32xf32> attributes {__placeholder__byre.entry_point} { %0 = call @Unknown0(%arg0) : (memref<1x32x64x64xf32>) -> memref<1x64x64x32xf32> return %0 : memref<1x64x64x32xf32> } diff --git a/compiler/test/Pipelines/Host/E2E/Transpose/TotalPipeline.mlir b/compiler/test/Pipelines/Host/E2E/Transpose/TotalPipeline.mlir index fc4b540a4..1eb3a7f4a 100644 --- a/compiler/test/Pipelines/Host/E2E/Transpose/TotalPipeline.mlir +++ b/compiler/test/Pipelines/Host/E2E/Transpose/TotalPipeline.mlir @@ -1,4 +1,4 @@ -// RUN: byteir-opt %s --hlo-opt="target=CPU" --linalg-tensor-opt="target=CPU" --byteir-bufferize-opt --scf-opt="target=CPU" --host-opt --byre-opt --to-llvm | byteir-translate --mlir-to-llvmir | FileCheck %s +// RUN: byteir-opt %s --hlo-opt="target=CPU" --linalg-tensor-opt="target=CPU" --byre-tensor-opt="entry-func=main append-arg-types" --byteir-bufferize-opt --scf-opt="target=CPU" --host-opt --byre-opt --to-llvm | byteir-translate --mlir-to-llvmir | FileCheck %s // CHECK-LABEL: define void @_mlir_ciface_Unknown diff --git a/compiler/test/Pipelines/Host/E2E/TypeCvt/00_Input.mlir b/compiler/test/Pipelines/Host/E2E/TypeCvt/00_Input.mlir index ae9ece9a1..fcfcb19bb 100644 --- a/compiler/test/Pipelines/Host/E2E/TypeCvt/00_Input.mlir +++ b/compiler/test/Pipelines/Host/E2E/TypeCvt/00_Input.mlir @@ -1,4 +1,4 @@ -// RUN: byteir-opt %s --hlo-opt="target=CPU" --linalg-tensor-opt="target=CPU" --byteir-bufferize-opt --scf-opt="target=CPU" | FileCheck %s +// RUN: byteir-opt %s --hlo-opt="target=CPU" --linalg-tensor-opt="target=CPU" --byre-tensor-opt="entry-func=main append-arg-types" --byteir-bufferize-opt --scf-opt="target=CPU" | FileCheck %s // CHECK-LABEL: func.func @main diff --git a/compiler/test/Pipelines/Host/E2E/TypeCvt/01_HostOpt.mlir b/compiler/test/Pipelines/Host/E2E/TypeCvt/01_HostOpt.mlir index 80ca0b87b..7ae77fb03 100644 --- a/compiler/test/Pipelines/Host/E2E/TypeCvt/01_HostOpt.mlir +++ b/compiler/test/Pipelines/Host/E2E/TypeCvt/01_HostOpt.mlir @@ -24,7 +24,7 @@ module { } return %alloc : memref<1x224x224x3xf16> } - func.func @main(%arg0: memref<1x224x224x3xf32>) -> memref<1x224x224x3xf16> { + func.func @main(%arg0: memref<1x224x224x3xf32>) -> memref<1x224x224x3xf16> attributes {__placeholder__byre.entry_point} { %0 = call @Unknown0(%arg0) : (memref<1x224x224x3xf32>) -> memref<1x224x224x3xf16> return %0 : memref<1x224x224x3xf16> } diff --git a/compiler/test/Pipelines/Host/E2E/TypeCvt/TotalPipeline.mlir b/compiler/test/Pipelines/Host/E2E/TypeCvt/TotalPipeline.mlir index fc85e3aca..4adace8da 100644 --- a/compiler/test/Pipelines/Host/E2E/TypeCvt/TotalPipeline.mlir +++ b/compiler/test/Pipelines/Host/E2E/TypeCvt/TotalPipeline.mlir @@ -1,4 +1,4 @@ -// RUN: byteir-opt %s --hlo-opt="target=CPU" --linalg-tensor-opt="target=CPU" --byteir-bufferize-opt --scf-opt="target=CPU" --host-opt --byre-opt --to-llvm | byteir-translate --mlir-to-llvmir | FileCheck %s +// RUN: byteir-opt %s --hlo-opt="target=CPU" --linalg-tensor-opt="target=CPU" --byre-tensor-opt="entry-func=main append-arg-types" --byteir-bufferize-opt --scf-opt="target=CPU" --host-opt --byre-opt --to-llvm | byteir-translate --mlir-to-llvmir | FileCheck %s // CHECK-LABEL: define void @_mlir_ciface_Unknown diff --git a/compiler/test/Transforms/graphClusteringByDeviceGreedy.mlir b/compiler/test/Transforms/graphClusteringByDeviceGreedy.mlir new file mode 100644 index 000000000..5b334b142 --- /dev/null +++ b/compiler/test/Transforms/graphClusteringByDeviceGreedy.mlir @@ -0,0 +1,122 @@ +// RUN: byteir-opt %s -allow-unregistered-dialect -graph-clustering-by-device="cluster-algo=Greedy" --split-input-file --canonicalize | FileCheck %s +// RUN: byteir-opt %s -allow-unregistered-dialect -graph-clustering-by-device="cluster-algo=TopDown" --split-input-file | FileCheck %s --check-prefix TOPDOWN +// RUN: byteir-opt %s -allow-unregistered-dialect -graph-clustering-by-device="cluster-algo=BottomUp" --split-input-file | FileCheck %s --check-prefix BOTTOMUP + +func.func @use_bottom_up(%arg0 : tensor<4xf32>) -> tensor<4xf32> { + %0 = "foo.bar"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> + %1 = "foo.bar"(%0) {device = "host"} : (tensor<4xf32>) -> tensor<4xf32> + %2 = "foo.bar"(%0) : (tensor<4xf32>) -> tensor<4xf32> + %3 = "foo.bar"(%1) : (tensor<4xf32>) -> tensor<4xf32> + %4 = "foo.bar"(%2, %3) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %5 = "foo.bar"(%4, %0) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %5 : tensor<4xf32> +} + +// CHECK-LABEL: func.func @use_bottom_up +// CHECK-NEXT: "foo.bar" +// CHECK-NEXT: "foo.bar" +// CHECK-NEXT: call @use_bottom_up_test +// CHECK-NEXT: return +// CHECK-LABEL: func.func @use_bottom_up_test +// CHECK-NEXT: "foo.bar" +// CHECK-NEXT: "foo.bar" +// CHECK-NEXT: "foo.bar" +// CHECK-NEXT: "foo.bar" +// CHECK-NEXT: return + +// TOPDOWN-LABEL: func.func @use_bottom_up +// TOPDOWN-NEXT: "foo.bar" +// TOPDOWN-NEXT: "foo.bar" +// TOPDOWN-NEXT: "foo.bar" +// TOPDOWN-NEXT: call @use_bottom_up_test +// TOPDOWN-NEXT: return +// TOPDOWN-LABEL: func.func @use_bottom_up_test +// TOPDOWN-NEXT: "foo.bar" +// TOPDOWN-NEXT: "foo.bar" +// TOPDOWN-NEXT: "foo.bar" +// TOPDOWN-NEXT: return + +// ----- + +func.func @use_top_down(%arg0 : tensor<4xf32>) -> tensor<4xf32> { + %0 = "foo.bar"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> + %1 = "foo.bar"(%0) : (tensor<4xf32>) -> tensor<4xf32> + %2 = "foo.bar"(%0) : (tensor<4xf32>) -> tensor<4xf32> + %3 = "foo.bar"(%2) : (tensor<4xf32>) -> tensor<4xf32> + %4 = "foo.bar"(%3) {device = "host"} : (tensor<4xf32>) -> tensor<4xf32> + %5 = "foo.bar"(%1, %4) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %4 : tensor<4xf32> +} + +// CHECK-LABEL: func.func @use_top_down +// CHECK-NEXT: call @use_top_down_test +// CHECK-NEXT: "foo.bar" +// CHECK-NEXT: "foo.bar" +// CHECK-NEXT: return +// CHECK-LABEL: func.func @use_top_down_test +// CHECK-NEXT: "foo.bar" +// CHECK-NEXT: "foo.bar" +// CHECK-NEXT: "foo.bar" +// CHECK-NEXT: "foo.bar" +// CHECK-NEXT: return + +// BOTTOMUP-LABEL: func.func @use_top_down +// BOTTOMUP-NEXT: call @use_top_down_test +// BOTTOMUP-NEXT: "foo.bar" +// BOTTOMUP-NEXT: "foo.bar" +// BOTTOMUP-NEXT: "foo.bar" +// BOTTOMUP-NEXT: return +// BOTTOMUP-LABEL: func.func @use_top_down_test +// BOTTOMUP-NEXT: "foo.bar" +// BOTTOMUP-NEXT: "foo.bar" +// BOTTOMUP-NEXT: "foo.bar" +// BOTTOMUP-NEXT: return + +// ----- +func.func @constant_used_by_host_op(%arg0: tensor) -> (tensor) { + %0 = mhlo.add %arg0, %arg0 : tensor + %1 = mhlo.add %0, %arg0 : tensor + %2 = "mhlo.constant"() {value = dense<1.0000> : tensor } : () -> tensor + %3 = "mhlo.add"(%2, %1) {device = "host"} : (tensor, tensor) -> tensor + %4 = mhlo.subtract %3, %1 : tensor + return %4 : tensor +} + +// CHECK-LABEL: func.func @constant_used_by_host_op +// CHECK-NEXT: mhlo.constant +// CHECK-NEXT: call @constant_used_by_host_op_test +// CHECK-NEXT: mhlo.add +// CHECK-NEXT: mhlo.subtract +// CHECK-NEXT: return +// CHECK-LABEL: func.func @constant_used_by_host_op_test +// CHECK-NEXT: mhlo.add +// CHECK-NEXT: mhlo.add +// CHECK-NEXT: return + +// ----- + +func.func @should_move_down(%arg0 : tensor<4xf32>, %arg1 : tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { + %0 = "foo.bar"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> + %1 = "foo.bar"(%0) {device = "host"} : (tensor<4xf32>) -> tensor<4xf32> + %2 = "foo.bar"(%arg1) : (tensor<4xf32>) -> tensor<4xf32> + %3 = "foo.bar"(%1, %2) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %4 = "foo.bar"(%2) {device = "host"} : (tensor<4xf32>) -> tensor<4xf32> + %5 = "foo.bar"(%4, %2) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %6 = "foo.bar"(%0) : (tensor<4xf32>) -> tensor<4xf32> + %7 = "foo.bar"(%6) : (tensor<4xf32>) -> tensor<4xf32> + return %7, %5 : tensor<4xf32>, tensor<4xf32> +} + +// CHECK-LABEL: func.func @should_move_down +// CHECK-NEXT: call @should_move_down_test +// CHECK-NEXT: "foo.bar" +// CHECK-NEXT: "foo.bar" +// CHECK-NEXT: "foo.bar" +// CHECK-NEXT: "foo.bar" +// CHECK-NEXT: "foo.bar" +// CHECK-NEXT: return +// CHECK-LABEL: func.func @should_move_down_test +// CHECK-NEXT: "foo.bar" +// CHECK-NEXT: "foo.bar" +// CHECK-NEXT: "foo.bar" +// CHECK-NEXT: return diff --git a/compiler/test/Transforms/graphClusteringByDeviceTopDown.mlir b/compiler/test/Transforms/graphClusteringByDeviceTopDown.mlir index f01d4714d..cab10cfdf 100644 --- a/compiler/test/Transforms/graphClusteringByDeviceTopDown.mlir +++ b/compiler/test/Transforms/graphClusteringByDeviceTopDown.mlir @@ -66,4 +66,27 @@ func.func @cannot_merge(%arg0 : tensor<4xf32>) -> tensor<4xf32> { // CHECK-LABEL: func.func @cannot_merge_test // CHECK-NEXT: "foo.bar" // CHECK-NEXT: "foo.bar" -// CHECK-NEXT: return \ No newline at end of file +// CHECK-NEXT: return + +// ----- + +func.func @no_host(%arg0 : tensor<4xf32>, %arg1 : tensor<4xf32>, %arg2 : tensor<4xf32>, %arg3 : tensor<4xf32>, %arg4 : tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { + %0 = "foo.bar"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> + %1 = "foo.bar"(%arg1) : (tensor<4xf32>) -> tensor<4xf32> + %2 = "foo.bar"(%arg2) : (tensor<4xf32>) -> tensor<4xf32> + %3 = "foo.bar"(%arg3) : (tensor<4xf32>) -> tensor<4xf32> + %4 = "foo.bar"(%arg4) : (tensor<4xf32>) -> tensor<4xf32> + %5 = "foo.bar"(%1, %2, %3, %4) : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %6 = "foo.bar"(%0, %1, %2, %3) : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %6, %5 : tensor<4xf32>, tensor<4xf32> +} + +// CHECK-LABEL: func.func @no_host_test +// CHECK-NEXT: "foo.bar" +// CHECK-NEXT: "foo.bar" +// CHECK-NEXT: "foo.bar" +// CHECK-NEXT: "foo.bar" +// CHECK-NEXT: "foo.bar" +// CHECK-NEXT: "foo.bar" +// CHECK-NEXT: "foo.bar" +// CHECK-NEXT: return diff --git a/frontends/onnx-frontend/onnx-frontend/src/Compiler/OFCompilerUtils.cpp b/frontends/onnx-frontend/onnx-frontend/src/Compiler/OFCompilerUtils.cpp index 060ba25e9..bc10e2d3f 100644 --- a/frontends/onnx-frontend/onnx-frontend/src/Compiler/OFCompilerUtils.cpp +++ b/frontends/onnx-frontend/onnx-frontend/src/Compiler/OFCompilerUtils.cpp @@ -223,9 +223,12 @@ int compileModule(mlir::OwningOpRef &module, mlir::PassManager &pm, std::string outputFilename, onnx_frontend::EmissionTargetType emissionTarget, bool emitElide) { - if (mlir::failed(pm.run(*module))) + bool runFailure = mlir::failed(pm.run(*module)); + int outputStatus = + emitOutput(module, outputFilename, emissionTarget, emitElide); + if (runFailure) return onnx_mlir::CompilerFailure; - return emitOutput(module, outputFilename, emissionTarget, emitElide); + return outputStatus; } } // namespace onnx_frontend diff --git a/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFCheckNonLowered.cpp b/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFCheckNonLowered.cpp index 603eb1c73..da05cec66 100644 --- a/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFCheckNonLowered.cpp +++ b/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFCheckNonLowered.cpp @@ -38,7 +38,7 @@ struct OFCheckNonLoweredPass if (isa(op->getDialect())) { llvm::Twine msg(op->getName().getStringRef() + ": ONNX op is not lowered"); - emitRemark(op->getLoc(), msg); + emitWarning(op->getLoc(), msg); onnxFound = true; } }); @@ -48,6 +48,7 @@ struct OFCheckNonLoweredPass } else { llvm::Twine msg("Please lower all ONNX ops"); emitError(func.getLoc(), msg); + return signalPassFailure(); } } }; diff --git a/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteCustomOnnxOps.cpp b/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteCustomOnnxOps.cpp index 2d558d8cb..67b99ff4d 100644 --- a/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteCustomOnnxOps.cpp +++ b/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteCustomOnnxOps.cpp @@ -56,10 +56,16 @@ Value createQuantizeDequantize(PatternRewriter &rewriter, Location loc, "Quantize/Dequantize's scale type must be ranked"); assert(scaleType.getRank() <= 1 && "Quantize/Dequantize's scale rank should be 0 or 1"); - // rewrite output type from float32 to uint16 for quantize + Value zeropoint = inputs[2]; + RankedTensorType zeropointType = + zeropoint.getType().dyn_cast_or_null(); + assert(zeropointType != nullptr && + "Quantize/Dequantize's zeropoint type must be ranked"); + Type zpElementType = zeropointType.getElementType(); + // rewrite output type to zpElementType for quantize if (func_name == "quantize") outputType = RankedTensorType::get( - outputType.getShape(), rewriter.getIntegerType(16, /*isSigned*/ false)); + outputType.getShape(), zpElementType); std::string call_target_name = std::string(CALL_TARGET_NAME_PREFIX) + func_name.str(); diff --git a/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.cpp b/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.cpp index 931c18c72..daaa0aaf7 100644 --- a/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.cpp +++ b/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.cpp @@ -30,6 +30,7 @@ namespace { #define CALL_TARGET_NAME_PREFIX "byteir." // clang-format off +// func(byteir_op_name, onnx_op_name) #define VALID_CUSTOM_CALL_OP(func) \ func(arg_max, ArgMax) \ func(arg_min, ArgMin) \ @@ -39,17 +40,21 @@ namespace { func(instance_norm, InstanceNorm) \ func(l2_norm, L2Norm) \ func(layer_norm, LayerNorm) \ + func(log_softmax, LogSoftmax) \ + func(one_hot, OneHot) \ func(quantize, Quantize) \ func(resize, Resize) \ - func(softmax, Softmax) \ - func(log_softmax, LogSoftmax) + func(softmax, Softmax) +// generate get name function name, function name contains onnx op name, +// return byteir custom call target op name #define GEN_FUNCNAME(call_target_name, func_name) \ constexpr const char *get##func_name##NameWithPrefix() { \ return CALL_TARGET_NAME_PREFIX #call_target_name; \ } \ constexpr const char *get##func_name##Name() { return #call_target_name; } +// Wrapname class which outputs target name for ops that can be simply replaced. #define WRAP(onnx_class, func_name) \ template <> struct WrapName { \ static constexpr const char *call_target_name = get##func_name##NameWithPrefix(); \ @@ -58,9 +63,9 @@ namespace { #define WRAP_LIST(func) \ func(ONNXArgMaxOp, ArgMax) \ func(ONNXArgMinOp, ArgMin) \ + func(ONNXDequantizeLinearOp, Dequantize) \ func(ONNXErfOp, Erf) \ - func(ONNXQuantizeLinearOp, Quantize) \ - func(ONNXDequantizeLinearOp, Dequantize) + func(ONNXQuantizeLinearOp, Quantize) // clang-format on VALID_CUSTOM_CALL_OP(GEN_FUNCNAME) @@ -334,31 +339,57 @@ Value createResize(PatternRewriter &rewriter, Location loc, Value input, //===----------------------------------------------------------------------===// // LayerNorm //===----------------------------------------------------------------------===// + +Value createSqueezedValue(PatternRewriter &rewriter, Location loc, Value input, + SmallVector &axis_vec) { + RankedTensorType inputType = + input.getType().dyn_cast_or_null(); + int64_t inputRank = inputType.getRank(); + int64_t axisSize = axis_vec.size(); + if (inputRank == axisSize) + return input; + Type elemType = inputType.getElementType(); + auto inputShape = inputType.getShape(); + SmallVector outputShape; + for (int64_t axis : axis_vec) { + outputShape.emplace_back(inputShape[axis]); + } + RankedTensorType outputType = RankedTensorType::get(outputShape, elemType); + Value output = rewriter.create(loc, outputType, input); + return output; +} + Value createLayerNorm(PatternRewriter &rewriter, Location loc, Value input, Value scale, Value B, ArrayAttr axis_attr, Attribute epsilon_attr) { RankedTensorType inputType = input.getType().dyn_cast_or_null(); assert(inputType != nullptr && "Input type must be ranked"); - int64_t axis = axis_attr[0].cast().getInt(); - // canonicalize axis to be positive - if (axis < 0) { - axis = inputType.getRank() + axis; + int64_t num_axis = axis_attr.size(); + SmallVector axis_vec; + for (int64_t i = 0; i < num_axis; i++) { + int64_t axis = axis_attr[i].cast().getInt(); + // canonicalize axis to be positive + if (axis < 0) + axis = inputType.getRank() + axis; + axis_vec.emplace_back(axis); } + Value squeezedScale = createSqueezedValue(rewriter, loc, scale, axis_vec); + Value squeezedB = createSqueezedValue(rewriter, loc, B, axis_vec); double eps = (*epsilon_attr.cast().getValues().begin()) .convertToDouble(); std::string call_target_name = getLayerNormNameWithPrefix(); mhlo::CustomCallOp customCallOp = rewriter.create( loc, llvm::ArrayRef{inputType}, - llvm::ArrayRef{input, scale, B}, call_target_name, false, - rewriter.getStringAttr(""), + llvm::ArrayRef{input, squeezedScale, squeezedB}, call_target_name, + false, rewriter.getStringAttr(""), mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL, rewriter.getArrayAttr(llvm::ArrayRef{}), mhlo::CustomCallSchedule::NONE, nullptr, nullptr, rewriter.getArrayAttr(llvm::ArrayRef{})); DictionaryAttrWrapper attrs(rewriter.getContext()); attrs.setAttr("epsilon", rewriter.getF64FloatAttr(eps)); - attrs.setAttr("axis", rewriter.getI64ArrayAttr({axis})); + attrs.setAttr("axis", rewriter.getI64ArrayAttr(ArrayRef(axis_vec))); customCallOp->setAttr(BYTEIR_ATTRS, getCleanAttr(attrs)); return customCallOp.getResults()[0]; } @@ -467,6 +498,75 @@ Value createLayerNormGeLU(PatternRewriter &rewriter, Location loc, Value input, return createGeLU(rewriter, loc, result); } +//===----------------------------------------------------------------------===// +// OneHot Pattern +//===----------------------------------------------------------------------===// + +Value createOneHot(PatternRewriter &rewriter, Location loc, Value indices, + Value depthValue, Value values, IntegerAttr axisAttr, + Value output) { + // indices + RankedTensorType indicesType = indices.getType().dyn_cast(); + assert(indicesType && indicesType.hasStaticShape() && + "indices must be static"); + int64_t indicesRank = indicesType.getRank(); + Type indicesElementType = indicesType.getElementType(); + // depth + ONNXConstantOp depthOp = depthValue.getDefiningOp(); + assert(depthOp && "onnx.OneHot's depth should be constant"); + ElementsAttr depthAttr = depthOp.getValueAttr().dyn_cast(); + int64_t depth = depthAttr.getValues()[0].getSExtValue(); + // axis + int64_t axis = axisAttr.getSInt(); + if (axis < 0) + axis += indicesRank + 1; + assert(axis >= 0 && axis <= indicesRank && "axis not in range"); + // normalized indices + Value zero = rewriter.create( + loc, + DenseIntElementsAttr::get(RankedTensorType::get({}, indicesElementType), + ArrayRef{0})); + Value broadcastZero = rewriter.create( + loc, indicesType, zero, rewriter.getI64TensorAttr({})); + Value broadcastDepth; + int64_t depthRank = depthValue.getType().cast().getRank(); + if (depthRank == 1) + broadcastDepth = rewriter.create( + loc, indicesType, depthValue, rewriter.getI64TensorAttr({0})); + else + broadcastDepth = rewriter.create( + loc, indicesType, depthValue, rewriter.getI64TensorAttr({})); + Value compareGeZero = rewriter.create( + loc, indices, broadcastZero, mhlo::ComparisonDirection::GE); + Value positiveIndices = + rewriter.create(loc, indices, broadcastDepth); + Value normalizedIndices = rewriter.create( + loc, indicesType, compareGeZero, indices, positiveIndices); + // values + ONNXConstantOp ValuesOp = values.getDefiningOp(); + assert(ValuesOp && "onnx.OneHot's values should be constant"); + ElementsAttr valuesAttr = ValuesOp.getValueAttr().dyn_cast(); + assert(valuesAttr && valuesAttr.size() == 2 && + "value should keep ElementsAttr with size = 2"); + Attribute off_value = valuesAttr.getValues()[0]; + Attribute on_value = valuesAttr.getValues()[1]; + mhlo::CustomCallOp customCallOp = rewriter.create( + loc, llvm::ArrayRef{output.getType()}, + llvm::ArrayRef{normalizedIndices}, getOneHotNameWithPrefix(), + false, rewriter.getStringAttr(""), + mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL, + rewriter.getArrayAttr(llvm::ArrayRef{}), + mhlo::CustomCallSchedule::NONE, nullptr, nullptr, + rewriter.getArrayAttr(llvm::ArrayRef{})); + DictionaryAttrWrapper attrs(rewriter.getContext()); + attrs.setAttr("depth", rewriter.getI64IntegerAttr(depth)); + attrs.setAttr("axis", rewriter.getI64IntegerAttr(axis)); + attrs.setAttr("on_value", on_value); + attrs.setAttr("off_value", off_value); + customCallOp->setAttr(BYTEIR_ATTRS, getCleanAttr(attrs)); + return customCallOp.getResults()[0]; +} + #include "onnx-frontend/src/Conversion/OFRewriteToCustomCall.inc" //===----------------------------------------------------------------------===// @@ -577,6 +677,8 @@ struct OFRewriteToCustomCallPass std::make_unique(context)); validOpSet[getLayerNormName()].emplace_back( std::make_unique(context)); + validOpSet[getOneHotName()].emplace_back( + std::make_unique(context)); validOpSet[getArgMaxName()].emplace_back( std::make_unique>( context, 1)); diff --git a/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.td b/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.td index a3ade69a8..f20ff1b08 100644 --- a/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.td +++ b/frontends/onnx-frontend/onnx-frontend/src/Conversion/OFRewriteToCustomCall.td @@ -19,7 +19,7 @@ #define ONNX_FRONTEND_SRC_CONVERSION_OF_CUSTOM_CALL_TD include "src/Dialect/ONNX/ONNX.td" // third_party/onnx-mlir/src/Dialect/ONNX/ONNX.td - + def IsOneSize : Constraint, "is of size one">; def IsOneSizeElements : Constraint, @@ -166,12 +166,12 @@ def ValueSqrtTwo : Constraint, "valu def ValueOne : Constraint, "value attr is not splat value of 1.0">; def ValueHalf : Constraint, "value attr is not splat value of 0.5">; def IsTwoTimes : Constraint()" >, + And<[ + CPred< "$0.isa()" >, CPred< "$1.isa()" >, CPred< "isFPAttrTimesCloseTo($0.cast(), $1.cast(), 2.0)">]>, And<[ - CPred< "$0.isa()" >, + CPred< "$0.isa()" >, CPred< "$1.isa()" >, CPred< "isFPAttrTimesCloseTo($0.cast(), $1.cast(), 2.0)">]>]>, "value attr $0 is not 2 times of $1">; @@ -260,4 +260,12 @@ def RewriteLayerNormGeLUWithMulConstPropagation : Pat< (IsTwoTimes $const_1, $const_2) ]>; +//===----------------------------------------------------------------------===// +// OneHot Pattern +//===----------------------------------------------------------------------===// + +def RewriteOneHot : Pat< + (ONNXOneHotOp:$output $indices, $depth, $values, $axis), + (NativeCodeCall<"createOneHot($_builder, $_loc, $0, $1, $2, $3, $4)"> $indices, $depth, $values, $axis, $output)>; + #endif // ONNX_FRONTEND_SRC_CONVERSION_OF_CUSTOM_CALL_TD diff --git a/frontends/onnx-frontend/onnx-frontend/src/onnx-frontend.cpp b/frontends/onnx-frontend/onnx-frontend/src/onnx-frontend.cpp index 67943dde6..782fed720 100644 --- a/frontends/onnx-frontend/onnx-frontend/src/onnx-frontend.cpp +++ b/frontends/onnx-frontend/onnx-frontend/src/onnx-frontend.cpp @@ -87,9 +87,9 @@ int main(int argc, char *argv[]) { if (emissionTarget == onnx_frontend::EmitMhloIR) { onnx_frontend::addCustomizedONNXToMhloPasses(pm, onnx_frontend::customCallOps); + onnx_frontend::addVerifyONNXToMhloPasses(pm); } auto status = onnx_frontend::compileModule( module, pm, onnx_mlir::outputBaseName, emissionTarget, emitElide); - onnx_frontend::addVerifyONNXToMhloPasses(pm); return status; } diff --git a/frontends/onnx-frontend/onnx-frontend/test/of_check_non_lowered.mlir b/frontends/onnx-frontend/onnx-frontend/test/of_check_non_lowered.mlir index 2b24175ec..5a5d25948 100644 --- a/frontends/onnx-frontend/onnx-frontend/test/of_check_non_lowered.mlir +++ b/frontends/onnx-frontend/onnx-frontend/test/of_check_non_lowered.mlir @@ -1,14 +1,8 @@ -// RUN: onnx-frontend-opt -check-non-lowered %s -split-input-file | FileCheck %s +// RUN: onnx-frontend-opt -check-non-lowered %s -split-input-file -verify-diagnostics func.func @test_onnx_non_lowered(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> { - // expected-error @+1 {{onnx.NoValue: ONNX op is not lowered}} + // expected-warning @+2 {{onnx.NoValue: ONNX op is not lowered}} + // expected-error @-2 {{Please lower all ONNX ops}} %0 = "onnx.NoValue"() : () -> none return %arg0 : tensor<1x2xf32> } - -func.func @test_onnx_lowered(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> { - %0 = mhlo.constant dense<[[1.000000e+00, 2.000000e+00]]> : tensor<1x2xf32> - %1 = mhlo.add %arg0, %0 : tensor<1x2xf32> - return %1 : tensor<1x2xf32> - // CHECK-LABEL: func.func @test_onnx_lowered -} \ No newline at end of file diff --git a/frontends/onnx-frontend/onnx-frontend/test/of_rewrite_custom_onnx_op.mlir b/frontends/onnx-frontend/onnx-frontend/test/of_rewrite_custom_onnx_op.mlir index 53ba3a647..a6b6744d3 100644 --- a/frontends/onnx-frontend/onnx-frontend/test/of_rewrite_custom_onnx_op.mlir +++ b/frontends/onnx-frontend/onnx-frontend/test/of_rewrite_custom_onnx_op.mlir @@ -26,4 +26,32 @@ func.func @test_dequantize16(%arg0: tensor<1x128x1x1xui16>) -> tensor<1x128x1x1x // CHECK-DAG: [[CONST1:%.*]] = onnx.Constant dense<0> : tensor // CHECK-NEXT: [[CUSTOM:%.*]] = mhlo.custom_call @byteir.dequantize([[ARG0]], [[CONST0]], [[CONST1]]) {backend_config = "", byteir_attrs = {}} : (tensor<1x128x1x1xui16>, tensor, tensor) -> tensor<1x128x1x1xf32> // CHECK-NEXT: return [[CUSTOM]] : tensor<1x128x1x1xf32> -} \ No newline at end of file +} + +func.func @test_quantize16_typed(%arg0: tensor<1x128x1x1xf32>) -> tensor<1x128x1x1xi16> { + %0 = "onnx.Constant"() {onnx_node_name = "Constant_0", value = dense<"0x1F1F1F1F"> : tensor} : () -> tensor + %1 = "onnx.Constant"() {onnx_node_name = "Constant_1", value = dense<"0x0000"> : tensor} : () -> tensor + %2 = "onnx.Custom"(%arg0, %0, %1) {domain_name = "", function_name = "quantize", inputs_for_infer = [0], onnx_node_name = "QuantLinear", shape_infer_pattern = "SameAs"} : (tensor<1x128x1x1xf32>, tensor, tensor) -> tensor<1x128x1x1xi16> + return %2 : tensor<1x128x1x1xi16> + // CHECK-LABEL: @test_quantize16_typed( + // CHECK-SAME: [[ARG0:%.*]]: tensor<1x128x1x1xf32>) -> tensor<1x128x1x1xi16> { + // CHECK-DAG: [[CONST0:%.*]] = onnx.Constant dense<3.36953024E-20> : tensor + // CHECK-DAG: [[CONST1:%.*]] = onnx.Constant dense<0> : tensor + // CHECK-NEXT: [[CUSTOM:%.*]] = mhlo.custom_call @byteir.quantize([[ARG0]], [[CONST0]], [[CONST1]]) {backend_config = "", byteir_attrs = {}} : (tensor<1x128x1x1xf32>, tensor, tensor) -> tensor<1x128x1x1xi16> + // CHECK-NEXT: return [[CUSTOM]] : tensor<1x128x1x1xi16> +} + +// ----- + +func.func @test_dequantize16_typed(%arg0: tensor<1x128x1x1xi16>) -> tensor<1x128x1x1xf32> { + %0 = "onnx.Constant"() {onnx_node_name = "Constant_0", value = dense<"0x1F1F1F1F"> : tensor} : () -> tensor + %1 = "onnx.Constant"() {onnx_node_name = "Constant_1", value = dense<"0x0000"> : tensor} : () -> tensor + %2 = "onnx.Custom"(%arg0, %0, %1) {domain_name = "", function_name = "dequantize", inputs_for_infer = [0], onnx_node_name = "DeQuantLinear", shape_infer_pattern = "SameAs"} : (tensor<1x128x1x1xi16>, tensor, tensor) -> tensor<1x128x1x1xf32> + return %2 : tensor<1x128x1x1xf32> + // CHECK-LABEL: @test_dequantize16_typed( + // CHECK-SAME: [[ARG0:%.*]]: tensor<1x128x1x1xi16>) -> tensor<1x128x1x1xf32> { + // CHECK-DAG: [[CONST0:%.*]] = onnx.Constant dense<3.36953024E-20> : tensor + // CHECK-DAG: [[CONST1:%.*]] = onnx.Constant dense<0> : tensor + // CHECK-NEXT: [[CUSTOM:%.*]] = mhlo.custom_call @byteir.dequantize([[ARG0]], [[CONST0]], [[CONST1]]) {backend_config = "", byteir_attrs = {}} : (tensor<1x128x1x1xi16>, tensor, tensor) -> tensor<1x128x1x1xf32> + // CHECK-NEXT: return [[CUSTOM]] : tensor<1x128x1x1xf32> +} diff --git a/frontends/onnx-frontend/onnx-frontend/test/of_rewrite_to_custom_call.mlir b/frontends/onnx-frontend/onnx-frontend/test/of_rewrite_to_custom_call.mlir index 6d7e11fd5..0de4d41e7 100644 --- a/frontends/onnx-frontend/onnx-frontend/test/of_rewrite_to_custom_call.mlir +++ b/frontends/onnx-frontend/onnx-frontend/test/of_rewrite_to_custom_call.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-frontend-opt -rewrite-to-custom-call="ops=arg_max,arg_min,layer_norm,erf,gelu,l2_norm,quantize,dequantize,softmax,resize" -of-canonicalize -constprop-onnx -of-canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-frontend-opt -rewrite-to-custom-call="ops=arg_max,arg_min,layer_norm,erf,gelu,l2_norm,quantize,dequantize,softmax,resize,one_hot" -of-canonicalize -constprop-onnx -of-canonicalize %s -split-input-file | FileCheck %s func.func @test_arg_max(%arg0: tensor<1x5x5x3xf32>) -> tensor<1x5x5xi64> { %0 = "onnx.ArgMax"(%arg0) {axis = 3 : si64, keepdims = 0 : si64, onnx_node_name = "ArgMax_0"} : (tensor<1x5x5x3xf32>) -> tensor<1x5x5xi64> @@ -99,6 +99,29 @@ func.func @test_layer_norm_without_last_add(%arg0: tensor<1x3xf32>) -> tensor<1x // ----- +func.func @test_layer_norm_squeeze(%arg0: tensor<2x4x3xf32>) -> tensor<2x4x3xf32> { + %22 = "onnx.ReduceMeanV13"(%arg0) {axes = [-1], onnx_node_name = "ReduceMean_25"} : (tensor<2x4x3xf32>) -> tensor<2x4x1xf32> + %23 = "onnx.Sub"(%arg0, %22) {onnx_node_name = "Sub_26"} : (tensor<2x4x3xf32>, tensor<2x4x1xf32>) -> tensor<2x4x3xf32> + %25 = "onnx.Mul"(%23, %23) : (tensor<2x4x3xf32>, tensor<2x4x3xf32>) -> tensor<2x4x3xf32> + %26 = "onnx.ReduceMeanV13"(%25) {axes = [-1], onnx_node_name = "ReduceMean_29"} : (tensor<2x4x3xf32>) -> tensor<2x4x1xf32> + %27 = "onnx.Constant"() {value = dense<9.99999974E-6> : tensor} : () -> tensor + %28 = "onnx.Add"(%26, %27) {onnx_node_name = "Add_31"} : (tensor<2x4x1xf32>, tensor) -> tensor<2x4x1xf32> + %29 = "onnx.Sqrt"(%28) {onnx_node_name = "Sqrt_32"} : (tensor<2x4x1xf32>) -> tensor<2x4x1xf32> + %30 = "onnx.Div"(%23, %29) {onnx_node_name = "Div_33"} : (tensor<2x4x3xf32>, tensor<2x4x1xf32>) -> tensor<2x4x3xf32> + %31 = "onnx.Constant"() {value = dense<[[[0.15, 0.2, 0.25]]]> : tensor<1x1x3xf32>} : () -> tensor<1x1x3xf32> + %32 = "onnx.Mul"(%30, %31) {onnx_node_name = "Mul_34"} : (tensor<2x4x3xf32>, tensor<1x1x3xf32>) -> tensor<2x4x3xf32> + %33 = "onnx.Constant"() {value = dense<[[[1.0, 2.0, 3.0]]]> : tensor<1x1x3xf32>} : () -> tensor<1x1x3xf32> + %34 = "onnx.Add"(%32, %33) {onnx_node_name = "Add_35"} : (tensor<2x4x3xf32>, tensor<1x1x3xf32>) -> tensor<2x4x3xf32> + return %34 : tensor<2x4x3xf32> +// CHECK-LABEL: @test_layer_norm_squeeze(%arg0: tensor<2x4x3xf32>) -> tensor<2x4x3xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = mhlo.constant dense<[1.500000e-01, 2.000000e-01, 2.500000e-01]> : tensor<3xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf32> +// CHECK-NEXT: %2 = mhlo.custom_call @byteir.layer_norm(%arg0, [[VAR_0_]], [[VAR_1_]]) {backend_config = "", byteir_attrs = {axis = [2], epsilon = 9.9999997473787516E-6 : f64}} : (tensor<2x4x3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<2x4x3xf32> +// CHECK-NEXT: return %2 : tensor<2x4x3xf32> +} + +// ----- + func.func @test_erf(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> { %0 = "onnx.Erf"(%arg0) : (tensor<3x2xf32>) -> tensor<3x2xf32> return %0 : tensor<3x2xf32> @@ -342,4 +365,18 @@ func.func @test_not_l2_norm_gelu_dense(%1092: tensor<1x4xf32>) -> tensor<1x4xf32 return %1105 : tensor<1x4xf32> // CHECK-LABEL: func.func @test_not_l2_norm_gelu_dense // CHECK-NOT: mhlo.custom_call @byteir.gelu +} + +func.func @test_onehot(%arg0 : tensor<2x3x4xi64>) -> tensor<2x3x4x64xi64> { + %0 = onnx.Constant dense<64> : tensor<1xi64> + %1 = onnx.Constant dense<[0, 1]> : tensor<2xi64> + %2 = "onnx.OneHot"(%arg0, %0, %1) {axis = -1 : si64} : (tensor<2x3x4xi64>, tensor<1xi64>, tensor<2xi64>) -> tensor<2x3x4x64xi64> + "func.return"(%2) : (tensor<2x3x4x64xi64>) -> () +// CHECK-LABEL: func.func @test_onehot +// CHECK-SAME: (%[[ARG0:.+]]: tensor<2x3x4xi64>) -> tensor<2x3x4x64xi64> { +// CHECK: %[[GE_ZERO:.+]] = mhlo.compare GE, %[[ARG0]], %[[ZERO:.+]], NOTYPE : (tensor<2x3x4xi64>, tensor<2x3x4xi64>) -> tensor<2x3x4xi1> +// CHECK: %[[POS_ARG0:.+]] = mhlo.add %[[ARG0]], %[[DEPTH:.+]] : tensor<2x3x4xi64> +// CHECK: %[[NORM_ARG:.+]] = mhlo.select %[[GE_ZERO]], %[[ARG0]], %[[POS_ARG0]] : tensor<2x3x4xi1>, tensor<2x3x4xi64> +// CHECK: %[[RESULT:.+]] = mhlo.custom_call @byteir.one_hot(%[[NORM_ARG]]) {backend_config = "", byteir_attrs = {axis = 3 : i64, depth = 64 : i64, off_value = 0 : i64, on_value = 1 : i64}} : (tensor<2x3x4xi64>) -> tensor<2x3x4x64xi64> +// CHECK: return %[[RESULT]] : tensor<2x3x4x64xi64> } \ No newline at end of file diff --git a/frontends/onnx-frontend/test/base.py b/frontends/onnx-frontend/test/base.py index 010c350a3..d05d88455 100644 --- a/frontends/onnx-frontend/test/base.py +++ b/frontends/onnx-frontend/test/base.py @@ -50,7 +50,7 @@ def convert_onnx_to_mhlo_ir(self, onnx_path, mhlo_ir_path, batch_size): with open(mhlo_ir_path, "w") as f: f.write(out) - def onnx_mhlo_test_helper(self, onnx_path, mhlo_ir_path, input_data: Dict[str, npt.NDArray]): + def onnx_mhlo_test_helper(self, onnx_path, mhlo_ir_path, input_data: Dict[str, npt.NDArray], decimal): # TODO: handle a model with multiple final outputs onnx_model = onnx.load(onnx_path) onnx_init_names = set([init.name for init in onnx_model.graph.initializer]) @@ -66,7 +66,7 @@ def onnx_mhlo_test_helper(self, onnx_path, mhlo_ir_path, input_data: Dict[str, n mhlo_data: List[npt.NDArray] = [input_data[name] for name in input_names] mhlo_outputs: List[npt.NDArray] = interp.call_function("main", mhlo_data) - np.testing.assert_almost_equal(onnx_outputs, mhlo_outputs, decimal=4) + np.testing.assert_almost_equal(onnx_outputs, mhlo_outputs, decimal=decimal) def run(self, model_filename: Optional[str] = None, @@ -74,15 +74,16 @@ def run(self, input_data: Optional[Dict[str, npt.NDArray]] = None, input_filename: Optional[str] = None, input_shape_dtype: Optional[List[List]] = None, - batch_size: int = 1): + batch_size: int = 1, + decimal: int = 4): assert self.data_dir is not None, "self.data_dir not initialized in derived class" - assert osp.isdir(self.data_dir), "self.data_dir (" + \ - self.data_dir + ") is not a directory" assert self.tmp_dir is not None, "self.tmp_dir not initialized in derived class" assert osp.isdir(self.tmp_dir), "self.tmp_dir (" + \ self.tmp_dir + ") is not a directory" if input_filename is not None: + assert osp.isdir(self.data_dir), "self.data_dir (" + \ + self.data_dir + ") is not a directory" # set inputs content = pickle.load( open(osp.join(self.data_dir, input_filename), 'rb')) @@ -109,7 +110,7 @@ def run(self, onnx.save(model_onnx_pb, onnx_path) self.convert_onnx_to_mhlo_ir(onnx_path, mhlo_ir_path, batch_size) - self.onnx_mhlo_test_helper(onnx_path, mhlo_ir_path, input_data) + self.onnx_mhlo_test_helper(onnx_path, mhlo_ir_path, input_data, decimal) else: raise ValueError( "Model file {} has an unkown extension name".format(model_filename)) diff --git a/frontends/onnx-frontend/test/ops/test_math.py b/frontends/onnx-frontend/test/ops/test_math.py index b86a90c08..48a67ab91 100644 --- a/frontends/onnx-frontend/test/ops/test_math.py +++ b/frontends/onnx-frontend/test/ops/test_math.py @@ -247,8 +247,35 @@ def test_max(self): model_onnx_pb=proto, input_shape_dtype=input_shape_dtype) + def test_min(self): + input_shape_dtype = [ + ["X", (3, 2), "float32"], + ["Y", (3, 2), "float32"], + ] + output_shape_dtype = [ + ["Z", (3, 2), "float32"], + ] + proto = build_onnx("Min", input_shape_dtype, output_shape_dtype) + self.run(model_filename="min.onnx", + model_onnx_pb=proto, + input_shape_dtype=input_shape_dtype) + def test_gelu(self): input_shape_dtype = [ ["input_0", (1, 5, 5, 3), "float32"], ] self.run(model_filename="gelu.onnx", input_shape_dtype=input_shape_dtype) + + def test_where(self): + input_shape_dtype = [ + ["pred", (3, 2), "bool"], + ["X", (3, 2), "float32"], + ["Y", (3, 2), "float32"], + ] + output_shape_dtype = [ + ["Z", (3, 2), "float32"], + ] + proto = build_onnx("Where", input_shape_dtype, output_shape_dtype) + self.run(model_filename="where.onnx", + model_onnx_pb=proto, + input_shape_dtype=input_shape_dtype) diff --git a/frontends/onnx-frontend/test/ops/test_rnn.py b/frontends/onnx-frontend/test/ops/test_rnn.py new file mode 100644 index 000000000..db159e80b --- /dev/null +++ b/frontends/onnx-frontend/test/ops/test_rnn.py @@ -0,0 +1,34 @@ +import numpy as np +import pytest +import onnx +from test.base import TestBase +from test.ops.utils import build_onnx + + +class TestOpsRNN(TestBase): + + @pytest.fixture(autouse=True) + def setup(self, tmpdir_factory): + self.setup_base(tmpdir_factory, "test/ops/data/rnn") + + def test_lstm(self): + input_shape_dtype = [ + ["X", (10, 16, 512), "float32"], + ["W", (2, 1024, 512), "float32"], + ["R", (2, 1024, 256), "float32"], + ["B", (2, 2048), "float32"], + ["", None, None], + ["initial_h", (2, 16, 256), "float32"], + ["initial_c", (2, 16, 256), "float32"], + ["", None, None], + ] + output_shape_dtype = [ + ["Y", (10, 2, 16, 256), "float32"], + ] + proto = build_onnx("LSTM", input_shape_dtype, output_shape_dtype, + direction="bidirectional", hidden_size=256) + input_shape_dtype = [input_shape_dtype[0], input_shape_dtype[1], + input_shape_dtype[2], input_shape_dtype[3], + input_shape_dtype[5], input_shape_dtype[6],] + self.run(model_filename="lstm.onnx", model_onnx_pb=proto, + input_shape_dtype=input_shape_dtype, decimal=3) \ No newline at end of file diff --git a/frontends/onnx-frontend/test/ops/test_tensor.py b/frontends/onnx-frontend/test/ops/test_tensor.py index 67c18a974..3ec6902f2 100644 --- a/frontends/onnx-frontend/test/ops/test_tensor.py +++ b/frontends/onnx-frontend/test/ops/test_tensor.py @@ -141,6 +141,26 @@ def test_arg_min(self): ] self.run(model_filename="arg_min.onnx", input_shape_dtype=input_shape_dtype) + def test_onehot(self): + input_shape_dtype = [ + ["X", (2, 3, 4), "int64"], + ["depth", (1,), "int64"], + ["values", (2,), "float32"], + ] + output_shape_dtype = [ + ["Y", (2, 3, 4, 5), "float32"], + ] + depth_tensor = onnx.helper.make_tensor( + "depth", onnx.TensorProto.INT64, [1], np.array([5])) + values_tensor = onnx.helper.make_tensor( + "values", onnx.TensorProto.FLOAT, [2], np.array([0.0, 1.0])) + proto = build_onnx( + "OneHot", input_shape_dtype, output_shape_dtype, + initializer=[depth_tensor, values_tensor], axis=-1 + ) + input_shape_dtype = [input_shape_dtype[0]] + self.run(model_filename="onehot.onnx", model_onnx_pb=proto, input_shape_dtype=input_shape_dtype) + def test_pad(self): input_shape_dtype = [ ["X", (1, 3, 5, 5), "float32"], diff --git a/frontends/onnx-frontend/third_party/patches/OnnxMlirNewOps.patch b/frontends/onnx-frontend/third_party/patches/OnnxMlirNewOps.patch index e38b71cc5..a55e2c9e9 100644 --- a/frontends/onnx-frontend/third_party/patches/OnnxMlirNewOps.patch +++ b/frontends/onnx-frontend/third_party/patches/OnnxMlirNewOps.patch @@ -1,54 +1,364 @@ diff --git a/src/Conversion/ONNXToMhlo/CMakeLists.txt b/src/Conversion/ONNXToMhlo/CMakeLists.txt -index bd7283a9..dcce81be 100644 +index bd7283a9..20756e93 100644 --- a/src/Conversion/ONNXToMhlo/CMakeLists.txt +++ b/src/Conversion/ONNXToMhlo/CMakeLists.txt -@@ -44,8 +44,11 @@ add_onnx_mlir_library(OMONNXToMhlo +@@ -38,14 +38,20 @@ add_onnx_mlir_library(OMONNXToMhlo + NN/ConvTranspose.cpp + NN/Normalization.cpp + NN/Pooling.cpp ++ RNN/LSTM.cpp ++ RNN/RNNBase.cpp + Tensor/ArgMax.cpp + Tensor/Concat.cpp + Tensor/Constant.cpp Tensor/Expand.cpp Tensor/Flatten.cpp Tensor/Gather.cpp + Tensor/GatherElements.cpp Tensor/Identity.cpp ++ Tensor/OneHot.cpp + Tensor/Pad.cpp Tensor/Reshape.cpp + Tensor/ScatterND.cpp Tensor/Shape.cpp Tensor/Slice.cpp Tensor/Split.cpp +@@ -75,7 +81,3 @@ target_link_libraries(MhloDialect PUBLIC + target_link_libraries(StablehloTypeInference PUBLIC + StablehloBase + ) +- +-target_link_libraries(StablehloTypeInference PUBLIC +- StablehloBase +- ) diff --git a/src/Conversion/ONNXToMhlo/ConvertONNXToMhlo.cpp b/src/Conversion/ONNXToMhlo/ConvertONNXToMhlo.cpp -index 73330c37..b347f97c 100644 +index 73330c37..7471748a 100644 --- a/src/Conversion/ONNXToMhlo/ConvertONNXToMhlo.cpp +++ b/src/Conversion/ONNXToMhlo/ConvertONNXToMhlo.cpp -@@ -41,8 +41,11 @@ void populateONNXToMhloConversionPattern( +@@ -34,6 +34,8 @@ void populateONNXToMhloConversionPattern( + populateLoweringONNXConvTransposeOpToMhloPattern(patterns, ctx); + populateLoweringONNXNormalizationOpToMhloPattern(patterns, ctx); + populateLoweringONNXPoolingOpToMhloPattern(patterns, ctx); ++ // Recurrent neural network ++ populateLoweringONNXLSTMOpToMhloPattern(patterns, ctx); + // Tensor + populateLoweringONNXArgMaxOpToMhloPattern(patterns, ctx); + populateLoweringONNXConcatOpToMhloPattern(patterns, ctx); +@@ -41,8 +43,12 @@ void populateONNXToMhloConversionPattern( populateLoweringONNXExpandOpToMhloPattern(patterns, ctx); populateLoweringONNXFlattenOpToMhloPattern(patterns, ctx); populateLoweringONNXGatherOpToMhloPattern(patterns, ctx); + populateLoweringONNXGatherElementsOpToMhloPattern(patterns, ctx); populateLoweringONNXIdentityOpToMhloPattern(patterns, ctx); ++ populateLoweringONNXOneHotOpToMhloPattern(patterns, ctx); + populateLoweringONNXPadOpToMhloPattern(patterns, ctx); populateLoweringONNXReshapeOpToMhloPattern(patterns, ctx); + populateLoweringONNXScatterNDOpToMhloPattern(patterns, ctx); populateLoweringONNXShapeOpToMhloPattern(patterns, ctx); populateLoweringONNXSliceOpToMhloPattern(patterns, ctx); populateLoweringONNXSplitOpToMhloPattern(patterns, ctx); -@@ -89,7 +92,7 @@ void FrontendToMhloLoweringPass::runOnOperation() { +@@ -89,10 +95,8 @@ void FrontendToMhloLoweringPass::runOnOperation() { // Added affine as some affine maps are generated by IndexExpression. It could // be disabled and/or replaced by shape max/min. target.addLegalDialect(); -+ arith::ArithDialect, shape::ShapeDialect, affine::AffineDialect, tensor::TensorDialect>(); - // Needed to support unsigned int computations. To be removed if we use a - // scheme that does not rely on the UnrealizedConversionCastOp. - target.addLegalOp<::mlir::UnrealizedConversionCastOp>(); +- // Needed to support unsigned int computations. To be removed if we use a +- // scheme that does not rely on the UnrealizedConversionCastOp. +- target.addLegalOp<::mlir::UnrealizedConversionCastOp>(); ++ arith::ArithDialect, shape::ShapeDialect, affine::AffineDialect, ++ tensor::TensorDialect>(); + + // Now that the conversion target has been defined, we just need to provide + // the set of patterns that will lower the frontend operations. +diff --git a/src/Conversion/ONNXToMhlo/DialectBuilder.cpp b/src/Conversion/ONNXToMhlo/DialectBuilder.cpp +index eae15683..a0547044 100644 +--- a/src/Conversion/ONNXToMhlo/DialectBuilder.cpp ++++ b/src/Conversion/ONNXToMhlo/DialectBuilder.cpp +@@ -13,6 +13,7 @@ + //===----------------------------------------------------------------------===// + + #include "mlir/Dialect/Arith/IR/Arith.h" ++#include "llvm/ADT/TypeSwitch.h" + + #include "src/Conversion/ONNXToMhlo/DialectBuilder.hpp" + #include "src/Dialect/ONNX/ONNXOps.hpp" +@@ -22,6 +23,62 @@ using namespace mlir; + + namespace onnx_mlir { + ++Value MhloBuilder::constant(mlir::Type type, double val) const { ++ Value constant = nullptr; ++ // Could be a vector type; look at the element type. ++ Type elementType = type; ++ VectorType vectorType = type.dyn_cast(); ++ if (vectorType) ++ elementType = vectorType.getElementType(); ++ TypeSwitch(elementType) ++ .Case([&](Type) { ++ constant = ++ b().create(loc(), b().getF16FloatAttr(val)); ++ }) ++ .Case([&](Type) { ++ constant = ++ b().create(loc(), b().getF32FloatAttr(val)); ++ }) ++ .Case([&](Type) { ++ constant = ++ b().create(loc(), b().getF64FloatAttr(val)); ++ }) ++ .Case([&](IntegerType elementType) { ++ assert(val == (int64_t)val && "value is ambiguous"); ++ unsigned width = elementType.getWidth(); ++ ++ if (width == 1) ++ constant = ++ b().create(loc(), b().getBoolAttr(val != 0)); ++ else { ++ if (elementType.isUnsignedInteger()) { ++ constant = b().create(loc(), ++ b().getIntegerAttr(elementType, APInt(width, (uint64_t)val, false))); ++ } else { ++ constant = b().create(loc(), ++ b().getIntegerAttr(elementType, APInt(width, (int64_t)val, true))); ++ } ++ } ++ }) ++ .Case([&](Type elementType) { ++ constant = b().create( ++ loc(), b().getIntegerAttr(elementType, val)); ++ }) ++ .Default([](Type) { llvm_unreachable("unsupported element type"); }); ++ ++ assert(constant != nullptr && "Expecting valid constant value"); ++ return constant; ++} ++ ++Value MhloBuilder::constantIndex(int64_t val) const { ++ IntegerAttr constantAttr = b().getIntegerAttr(b().getIndexType(), val); ++ return b().create(loc(), constantAttr); ++} ++ ++Value MhloBuilder::shaped_zero(mlir::Type type) const { ++ return b().create(loc(), b().getZeroAttr(type)); ++} ++ + // ============================================================================= + // IndexExpr Builder for Lowering using Shape/MHLO Dialect. + // ============================================================================= +diff --git a/src/Conversion/ONNXToMhlo/DialectBuilder.hpp b/src/Conversion/ONNXToMhlo/DialectBuilder.hpp +index 7d7ee86d..143e2ad2 100644 +--- a/src/Conversion/ONNXToMhlo/DialectBuilder.hpp ++++ b/src/Conversion/ONNXToMhlo/DialectBuilder.hpp +@@ -26,6 +26,34 @@ + + namespace onnx_mlir { + ++// ============================================================================= ++// mhlo Builder ++// ============================================================================= ++ ++struct MhloBuilder : DialectBuilder { ++ MhloBuilder(mlir::Location loc) : DialectBuilder(loc) {} ++ MhloBuilder(mlir::OpBuilder &b, mlir::Location loc) ++ : DialectBuilder(b, loc), patternRewriter(&b) {} ++ MhloBuilder(const DialectBuilder &db) : DialectBuilder(db) {} ++ virtual ~MhloBuilder() {} ++ ++ // ConstantOp ++ mlir::Value constant(mlir::Type type, double val) const; ++ mlir::Value constantIndex(int64_t val) const; ++ mlir::Value shaped_zero(mlir::Type type) const; ++ ++protected: ++ ++ // Private getters of builder (concise version). ++ mlir::OpBuilder &rewriter() const { ++ assert(patternRewriter && "rewriter is null"); ++ return *patternRewriter; ++ } ++ ++private: ++ mlir::OpBuilder *patternRewriter; ++}; ++ + // ============================================================================= + // IndexExpr Builder for Shape lowering + // ============================================================================= +@@ -43,6 +71,22 @@ protected: + mlir::Value getShapeVal(mlir::Value tensorOrMemrefValue, uint64_t i) final; + }; + ++// ============================================================================= ++// MultiDialectBuilder for Mhlo ++// ============================================================================= ++ ++// Recursive class specialized for MhloBuilder referred to as ++// mhlo. ++template ++struct MultiDialectBuilder ++ : MultiDialectBuilder { ++ MultiDialectBuilder(mlir::OpBuilder &b, mlir::Location loc) ++ : MultiDialectBuilder(b, loc), mhlo(b, loc) {} ++ MultiDialectBuilder(const DialectBuilder &db) ++ : MultiDialectBuilder(db), mhlo(db) {} ++ MhloBuilder mhlo; ++}; ++ + // Recursive class specialized for AffineBuilder refereed to as affine. + template + struct MultiDialectBuilder +diff --git a/src/Conversion/ONNXToMhlo/Math/Elementwise.cpp b/src/Conversion/ONNXToMhlo/Math/Elementwise.cpp +index 26c392b8..9d958ef5 100644 +--- a/src/Conversion/ONNXToMhlo/Math/Elementwise.cpp ++++ b/src/Conversion/ONNXToMhlo/Math/Elementwise.cpp +@@ -66,6 +66,11 @@ struct MhloDialectOp { + using Op = mhlo::MaxOp; + }; + ++template <> ++struct MhloDialectOp { ++ using Op = mhlo::MinOp; ++}; ++ + template <> + struct MhloDialectOp { + using Op = mhlo::MulOp; +@@ -106,6 +111,11 @@ struct MhloDialectOp { + using Op = mhlo::TanhOp; + }; + ++template <> ++struct MhloDialectOp { ++ using Op = mhlo::SelectOp; ++}; ++ + namespace { + + template +@@ -293,6 +303,40 @@ struct ONNXElementwiseBinaryOpLoweringToMhlo : public ConversionPattern { + } + }; + ++// ONNXPReluOp(x) = alpha * x if x < 0 else x. ++template <> ++struct ONNXElementwiseBinaryOpLoweringToMhlo ++ : public ConversionPattern { ++ ONNXElementwiseBinaryOpLoweringToMhlo(MLIRContext *ctx) ++ : ConversionPattern(ONNXPReluOp::getOperationName(), 1, ctx) {} ++ LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ++ ConversionPatternRewriter &rewriter) const final { ++ Location loc = op->getLoc(); ++ // Prior code here used the "analysis" version that did not generate code. ++ // Since code is actually not needed here at this time, one could use ++ // IndexExprBuilderForAnalysis createIE(loc) instead. ++ IndexExprBuilderForMhlo createShapeIE(rewriter, loc); ++ ONNXBroadcastOpShapeHelper shapeHelper(op, operands, &createShapeIE); ++ shapeHelper.computeShapeAndAssertOnFailure(); ++ ++ int64_t outputRank = shapeHelper.outputRank; ++ llvm::SmallVector broadcastedOperands = ++ getBroadcastedOperands(op, rewriter, loc, outputRank); ++ Value inp = broadcastedOperands[0]; ++ Value broadcastedSlope = broadcastedOperands[1]; ++ Type resultType = *op->result_type_begin(); ++ Value PReluActivationVal = ++ rewriter.create(loc, inp, broadcastedSlope); ++ Value broadcastedZero = getShapedZero(loc, rewriter, inp); ++ Value compareGtZero = rewriter.create( ++ loc, inp, broadcastedZero, mhlo::ComparisonDirection::GT); ++ Value resultOp = rewriter.create( ++ loc, resultType, compareGtZero, inp, PReluActivationVal); ++ rewriter.replaceOp(op, resultOp); ++ return success(); ++ } ++}; ++ + // Element-wise variadic ops lowering to Mhlo dialect. + //===----------------------------------------------------------------------===// + template +@@ -343,12 +387,15 @@ void populateLoweringONNXElementwiseOpToMhloPattern( + ONNXElementwiseCompareBinaryOpLoweringToMhlo, + ONNXElementwiseCompareBinaryOpLoweringToMhlo, + ONNXElementwiseBinaryOpLoweringToMhlo, ++ ONNXElementwiseBinaryOpLoweringToMhlo, + ONNXElementwiseVariadicOpLoweringToMhlo, + ONNXElementwiseVariadicOpLoweringToMhlo, + ONNXElementwiseVariadicOpLoweringToMhlo, + ONNXElementwiseVariadicOpLoweringToMhlo, ++ ONNXElementwiseVariadicOpLoweringToMhlo, + ONNXElementwiseVariadicOpLoweringToMhlo, +- ONNXElementwiseVariadicOpLoweringToMhlo>(ctx); ++ ONNXElementwiseVariadicOpLoweringToMhlo, ++ ONNXElementwiseVariadicOpLoweringToMhlo>(ctx); + } + + } // namespace onnx_mlir diff --git a/src/Conversion/ONNXToMhlo/ONNXToMhloCommon.cpp b/src/Conversion/ONNXToMhlo/ONNXToMhloCommon.cpp -index 4e4adfc1..5b39e407 100644 +index 4e4adfc1..1f10571f 100644 --- a/src/Conversion/ONNXToMhlo/ONNXToMhloCommon.cpp +++ b/src/Conversion/ONNXToMhlo/ONNXToMhloCommon.cpp -@@ -93,4 +93,17 @@ llvm::SmallVector getBroadcastedOperands( +@@ -14,6 +14,9 @@ + //===----------------------------------------------------------------------===// + + #include "src/Conversion/ONNXToMhlo/ONNXToMhloCommon.hpp" ++#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" ++#include "src/Dialect/ONNX/OnnxElementsAttrBuilder.hpp" ++ + #include "stablehlo/dialect/BroadcastUtils.h" + + using namespace mlir; +@@ -44,11 +47,6 @@ llvm::SmallVector getBroadcastedOperands(Operation *op, + Type outputType = *op->result_type_begin(); + assert(outputType.isa() && "output type is not shaped"); + ShapedType outputShapedType = outputType.cast(); +- Type elementType = +- op->getOperands()[0].getType().dyn_cast().getElementType(); +- RankedTensorType broadcastedOutputType = +- RankedTensorType::get(outputShapedType.getShape(), elementType); +- + Value resultExtents = + mlir::hlo::computeNaryElementwiseBroadcastingResultExtents( + loc, op->getOperands(), rewriter); +@@ -58,6 +56,10 @@ llvm::SmallVector getBroadcastedOperands(Operation *op, + assert(operandType != nullptr && "operand type is not ranked"); + SmallVector broadcastDimensions = llvm::to_vector<4>( + llvm::seq(outputRank - operandType.getRank(), outputRank)); ++ Type elementType = ++ operand.getType().dyn_cast().getElementType(); ++ RankedTensorType broadcastedOutputType = ++ RankedTensorType::get(outputShapedType.getShape(), elementType); + Value broadcast = rewriter.create(loc, + broadcastedOutputType, operand, resultExtents, + rewriter.getI64TensorAttr(broadcastDimensions)); +@@ -72,11 +74,6 @@ llvm::SmallVector getBroadcastedOperands( + llvm::SmallVector broadcastedOperands; + assert(outputType.isa() && "output type is not shaped"); + ShapedType outputShapedType = outputType.cast(); +- Type elementType = +- operands[0].getType().dyn_cast().getElementType(); +- RankedTensorType broadcastedOutputType = +- RankedTensorType::get(outputShapedType.getShape(), elementType); +- + Value resultExtents = + mlir::hlo::computeNaryElementwiseBroadcastingResultExtents( + loc, operands, rewriter); +@@ -86,6 +83,10 @@ llvm::SmallVector getBroadcastedOperands( + assert(operandType != nullptr && "operand type is not ranked"); + SmallVector broadcastDimensions = llvm::to_vector<4>( + llvm::seq(outputRank - operandType.getRank(), outputRank)); ++ Type elementType = ++ operand.getType().dyn_cast().getElementType(); ++ RankedTensorType broadcastedOutputType = ++ RankedTensorType::get(outputShapedType.getShape(), elementType); + Value broadcast = rewriter.create(loc, + broadcastedOutputType, operand, resultExtents, + rewriter.getI64TensorAttr(broadcastDimensions)); +@@ -93,4 +94,142 @@ llvm::SmallVector getBroadcastedOperands( } return broadcastedOperands; } + -+ElementsAttr getElementAttributeFromMhloValue(Value value) { ++ElementsAttr getElementAttributeFromConstValue(Value value) { + auto definingOp = value.getDefiningOp(); + if (auto constantOp = dyn_cast_or_null(definingOp)) { + return constantOp.getValue().dyn_cast(); @@ -59,10 +369,135 @@ index 4e4adfc1..5b39e407 100644 + } + return nullptr; +} ++ ++DenseIntElementsAttr GetI64ElementsAttr( ++ ArrayRef values, Builder *builder) { ++ RankedTensorType ty = RankedTensorType::get( ++ {static_cast(values.size())}, builder->getIntegerType(64)); ++ return DenseIntElementsAttr::get(ty, values); ++} ++ ++namespace { ++// Returns the DenseElementsAttr of input if it's a mhlo constant or ++// onnx.Constant. Otherwise returns a nullptr attribute. ++DenseElementsAttr getDenseElementAttrFromConstValue(mlir::Value value) { ++ Operation *definingOp = value.getDefiningOp(); ++ if (auto globalOp = dyn_cast_or_null(definingOp)) { ++ return globalOp.getValueAttr().dyn_cast(); ++ } else if (auto constOp = dyn_cast_or_null(definingOp)) { ++ if (constOp.getValue().has_value()) ++ return constOp.getValueAttr().dyn_cast(); ++ } ++ return nullptr; ++} ++} // namespace ++ ++// Emit an ONNXSqueezeOp. If the input is constant, do const propagation, ++/// and return a constant. ++Value foldOrEmitONNXSqueezeOpMhlo(ConversionPatternRewriter &rewriter, ++ Location loc, Type resultType, Value input, int64_t axis) { ++ MultiDialectBuilder create(rewriter, loc); ++ TensorType tensorType = create.onnx.toTensor(resultType); ++ if (DenseElementsAttr inputElements = ++ getDenseElementAttrFromConstValue(input)) { ++ DenseElementsAttr squeezedElements = inputElements.reshape(tensorType); ++ Value constVal = create.onnx.constant(squeezedElements); ++ return constVal; ++ } else { ++ return rewriter.create(loc, tensorType, ++ create.onnx.toTensor(input), create.onnx.constantInt64({axis})) ++ .getResult(); ++ } ++} ++ ++/// Emit an ONNXUnsqueezeOp. If the input is constant, do const ++/// propagation, and return a constant. ++Value foldOrEmitONNXUnsqueezeOpMhlo(ConversionPatternRewriter &rewriter, ++ Location loc, Type resultType, Value input, int64_t axis) { ++ MultiDialectBuilder create(rewriter, loc); ++ TensorType tensorType = create.onnx.toTensor(resultType); ++ if (DenseElementsAttr inputElements = ++ getDenseElementAttrFromConstValue(input)) { ++ DenseElementsAttr unsqueezedElements = inputElements.reshape(tensorType); ++ Value constVal = create.onnx.constant(unsqueezedElements); ++ return constVal; ++ } else { ++ return rewriter.create(loc, tensorType, ++ create.onnx.toTensor(input), create.onnx.constantInt64({axis})) ++ .getResult(); ++ } ++} ++ ++/// Emit an ONNXSplitOp. If the input is constant, do const propagation, and ++/// return constants. ++/// Only support evenly splitting. ++std::vector foldOrEmitONNXSplitOpMhlo(ConversionPatternRewriter &rewriter, ++ Location loc, ArrayRef resultTypes, Value input, int64_t axis) { ++ MultiDialectBuilder create(rewriter, loc); ++ std::vector resVals; ++ int outputNum = resultTypes.size(); ++ if (DenseElementsAttr inputElements = ++ getDenseElementAttrFromConstValue(input)) { ++ auto inputShape = inputElements.getType().getShape(); ++ assert(outputNum == 0 || inputShape[axis] % outputNum == 0); ++ int64_t sizeOfEachSplit = outputNum != 0 ? inputShape[axis] / outputNum : 0; ++ SmallVector sizes(outputNum, sizeOfEachSplit); ++ ++ OnnxElementsAttrBuilder elementsBuilder(rewriter.getContext()); ++ std::vector splits = ++ elementsBuilder.split(inputElements, axis, sizes); ++ for (ElementsAttr splitElements : splits) { ++ // Avoid DisposableElementsAttr during conversion. ++ DenseElementsAttr denseSplitElements = ++ elementsBuilder.toDenseElementsAttr(splitElements); ++ Value constVal = create.onnx.constant(denseSplitElements); ++ resVals.emplace_back(constVal); ++ } ++ } else { ++ SmallVector convertedTypes; ++ SmallVector splitSizesI64; ++ for (auto t : resultTypes) { ++ convertedTypes.emplace_back(create.onnx.toTensor(t)); ++ splitSizesI64.emplace_back(t.cast().getShape()[axis]); ++ } ++ Value splitSizes = create.onnx.constantInt64(splitSizesI64); ++ ONNXSplitOp split = rewriter.create(loc, convertedTypes, ++ create.onnx.toTensor(input), splitSizes, ++ /*axis=*/axis, nullptr); ++ for (int i = 0; i < outputNum; ++i) ++ resVals.emplace_back(split.getOutputs()[i]); ++ } ++ return resVals; ++} ++ ++/// Emit an ONNXTransposeOp. If the input is constant, do const propagation, ++/// and return a constant. ++Value foldOrEmitONNXTransposeOpMhlo(ConversionPatternRewriter &rewriter, ++ Location loc, Type resultType, Value input, ArrayAttr permAttr) { ++ MultiDialectBuilder create(rewriter, loc); ++ if (DenseElementsAttr inputElements = ++ getDenseElementAttrFromConstValue(input)) { ++ SmallVector perm; ++ for (auto permVal : permAttr.getValue()) ++ perm.emplace_back(permVal.cast().getInt()); ++ ++ OnnxElementsAttrBuilder elementsBuilder(rewriter.getContext()); ++ ElementsAttr transposedElements = ++ elementsBuilder.transpose(inputElements, perm); ++ // Avoid DisposableElementsAttr during conversion. ++ DenseElementsAttr denseTransposedElements = ++ elementsBuilder.toDenseElementsAttr(transposedElements); ++ Value constVal = create.onnx.constant(denseTransposedElements); ++ return constVal; ++ } else { ++ return rewriter.create(loc, create.onnx.toTensor(resultType), ++ create.onnx.toTensor(input), permAttr).getResult(); ++ } ++} + } // namespace onnx_mlir diff --git a/src/Conversion/ONNXToMhlo/ONNXToMhloCommon.hpp b/src/Conversion/ONNXToMhlo/ONNXToMhloCommon.hpp -index ec5a9f2b..2e62686a 100644 +index ec5a9f2b..9b6fd411 100644 --- a/src/Conversion/ONNXToMhlo/ONNXToMhloCommon.hpp +++ b/src/Conversion/ONNXToMhlo/ONNXToMhloCommon.hpp @@ -19,6 +19,7 @@ @@ -73,16 +508,67 @@ index ec5a9f2b..2e62686a 100644 #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" -@@ -113,6 +114,8 @@ llvm::SmallVector getBroadcastedOperands( +@@ -28,6 +29,8 @@ + + #include "mhlo/IR/hlo_ops.h" + ++#include "src/Conversion/ONNXToMhlo/DialectBuilder.hpp" ++#include "src/Dialect/Mlir/DialectBuilder.hpp" + #include "src/Dialect/Mlir/IndexExpr.hpp" + #include "src/Dialect/ONNX/DialectBuilder.hpp" + #include "src/Dialect/ONNX/ONNXOps.hpp" +@@ -113,6 +116,40 @@ llvm::SmallVector getBroadcastedOperands( llvm::SmallVector &operands, Type outputType, ConversionPatternRewriter &rewriter, Location loc, int64_t outputRank); - -+mlir::ElementsAttr getElementAttributeFromMhloValue(mlir::Value value); + ++mlir::ElementsAttr getElementAttributeFromConstValue(mlir::Value value); ++ ++DenseIntElementsAttr GetI64ElementsAttr( ++ ArrayRef values, Builder *builder); ++ ++//===----------------------------------------------------------------------===// ++// Fold and emit support. ++//===----------------------------------------------------------------------===// ++ ++/// Emit an ONNXSqueezeOp. If the input is constant, do const propagation, and ++/// return a constant. ++mlir::Value foldOrEmitONNXSqueezeOpMhlo( ++ mlir::ConversionPatternRewriter &rewriter, mlir::Location loc, ++ mlir::Type resultType, mlir::Value input, int64_t axis); ++ ++/// Emit an ONNXUnsqueezeOp. If the input is constant, do const propagation, and ++/// return a constant. ++mlir::Value foldOrEmitONNXUnsqueezeOpMhlo( ++ mlir::ConversionPatternRewriter &rewriter, mlir::Location loc, ++ mlir::Type resultType, mlir::Value input, int64_t axis); ++ ++/// Emit an ONNXSplitOp. If the input is constant, do const propagation, and ++/// return constants. ++/// Only support evenly splitting. ++std::vector foldOrEmitONNXSplitOpMhlo( ++ mlir::ConversionPatternRewriter &rewriter, mlir::Location loc, ++ llvm::ArrayRef resultTypes, mlir::Value input, int64_t axis); ++ ++/// Emit an ONNXTransposeOp. If the input is constant, do const propagation, and ++/// return a constant. ++mlir::Value foldOrEmitONNXTransposeOpMhlo(mlir::ConversionPatternRewriter &rewriter, ++ mlir::Location loc, mlir::Type resultType, mlir::Value input, ++ mlir::ArrayAttr permAttr); + // `Math` directory methods: void populateLoweringONNXClipOpToMhloPattern( RewritePatternSet &, MLIRContext *); -@@ -148,10 +151,16 @@ void populateLoweringONNXFlattenOpToMhloPattern( +@@ -133,6 +170,9 @@ void populateLoweringONNXNormalizationOpToMhloPattern( + RewritePatternSet &, MLIRContext *); + void populateLoweringONNXPoolingOpToMhloPattern( + RewritePatternSet &, MLIRContext *); ++// `RNN` directory methods: ++void populateLoweringONNXLSTMOpToMhloPattern( ++ RewritePatternSet &, MLIRContext *); + // `Tensor` directory methods: + void populateLoweringONNXArgMaxOpToMhloPattern( + RewritePatternSet &, MLIRContext *); +@@ -148,10 +188,17 @@ void populateLoweringONNXFlattenOpToMhloPattern( RewritePatternSet &, MLIRContext *); void populateLoweringONNXGatherOpToMhloPattern( RewritePatternSet &, MLIRContext *); @@ -90,8 +576,9 @@ index ec5a9f2b..2e62686a 100644 + RewritePatternSet &, MLIRContext *); void populateLoweringONNXIdentityOpToMhloPattern( RewritePatternSet &, MLIRContext *); -+void populateLoweringONNXPadOpToMhloPattern( ++void populateLoweringONNXOneHotOpToMhloPattern( + RewritePatternSet &, MLIRContext *); ++void populateLoweringONNXPadOpToMhloPattern(RewritePatternSet &, MLIRContext *); void populateLoweringONNXReshapeOpToMhloPattern( RewritePatternSet &, MLIRContext *); +void populateLoweringONNXScatterNDOpToMhloPattern( @@ -99,12 +586,1304 @@ index ec5a9f2b..2e62686a 100644 void populateLoweringONNXShapeOpToMhloPattern( RewritePatternSet &, MLIRContext *); void populateLoweringONNXSliceOpToMhloPattern( +diff --git a/src/Conversion/ONNXToMhlo/RNN/LSTM.cpp b/src/Conversion/ONNXToMhlo/RNN/LSTM.cpp +new file mode 100644 +index 00000000..9bc9f05a +--- /dev/null ++++ b/src/Conversion/ONNXToMhlo/RNN/LSTM.cpp +@@ -0,0 +1,625 @@ ++/* ++ * SPDX-License-Identifier: Apache-2.0 ++ */ ++ ++//===--------------- LSTM.cpp - Lowering LSTM Op --------------------------===// ++// ++// Copyright 2023 ++// ++// ============================================================================= ++// ++// This file lowers the ONNX LSTM Operators to Mhlo dialect. ++// ++//===----------------------------------------------------------------------===// ++ ++#include "src/Conversion/ONNXToMhlo/DialectBuilder.hpp" ++#include "src/Conversion/ONNXToMhlo/RNN/RNNBase.hpp" ++#include "src/Dialect/Mlir/DialectBuilder.hpp" ++ ++#include "llvm/Support/Debug.h" ++ ++#define DEBUG_TYPE "lstm" ++ ++using namespace mlir; ++ ++namespace onnx_mlir { ++ ++namespace mhlo { ++ ++struct LstmState { ++ // returned states. ++ SmallVector forwardAllH; ++ SmallVector reverseAllH; ++ Value ht; ++ Value ct; ++ // intermediate states. ++ Value forwardHt; ++ Value reverseHt; ++ Value forwardCt; ++ Value reverseCt; ++}; ++ ++struct LstmActivationPack { ++ RNNActivation f; ++ RNNActivation g; ++ RNNActivation h; ++}; ++ ++struct LstmWeightPack { ++ Value WT; ++ Value RT; ++}; ++ ++struct LstmBiasPack { ++ bool hasBias = false; ++ Value Wbi; ++ Value Wbo; ++ Value Wbf; ++ Value Wbc; ++ Value Rbi; ++ Value Rbo; ++ Value Rbf; ++ Value Rbc; ++ // Put peephole here. ++ bool hasPeephole = false; ++ Value Pi; ++ Value Po; ++ Value Pf; ++}; ++ ++template <> ++bool hasAllNoneOutput(ONNXLSTMOp *op) { ++ return (isNoneValue(op->getY()) && isNoneValue(op->getYH()) && ++ isNoneValue(op->getYC())); ++} ++ ++template <> ++std::tuple ++getActivationPack(ONNXLSTMOp *op) { ++ auto direction = op->getDirection(); ++ auto activations = op->getActivations(); ++ auto activationAlpha = op->getActivationAlpha(); ++ auto activationBeta = op->getActivationBeta(); ++ ++ LstmActivationPack activationForward, activationReverse; ++ ++ // Get activation function name. ++ // Default forward functions ++ activationForward.f.name = "sigmoid"; ++ activationForward.g.name = "tanh"; ++ activationForward.h.name = "tanh"; ++ // Default backward functions ++ activationReverse.f.name = "sigmoid"; ++ activationReverse.g.name = "tanh"; ++ activationReverse.h.name = "tanh"; ++ if (activations) { ++ ArrayAttr activationArrAttr = activations.value(); ++ if (direction == FORWARD || direction == BIDIRECTIONAL) { ++ // Forward activations. ++ if (activationArrAttr.size() > 0) { ++ activationForward.f.name = ++ activationArrAttr[0].cast().getValue(); ++ } ++ if (activationArrAttr.size() > 1) { ++ activationForward.g.name = ++ activationArrAttr[1].cast().getValue(); ++ } ++ if (activationArrAttr.size() > 2) { ++ activationForward.h.name = ++ activationArrAttr[2].cast().getValue(); ++ } ++ } ++ ++ // Reverse activations. ++ if (direction == REVERSE || direction == BIDIRECTIONAL) { ++ unsigned int startIndex = (direction == REVERSE) ? 0 : 3; ++ if (activationArrAttr.size() > startIndex) { ++ activationReverse.f.name = ++ activationArrAttr[startIndex].cast().getValue(); ++ } ++ if (activationArrAttr.size() > startIndex + 1) { ++ activationReverse.g.name = ++ activationArrAttr[startIndex + 1].cast().getValue(); ++ } ++ if (activationArrAttr.size() > startIndex + 2) { ++ activationReverse.h.name = ++ activationArrAttr[startIndex + 2].cast().getValue(); ++ } ++ } ++ } ++ ++ // Get alpha attributes. ++ if (activationAlpha) { ++ ArrayAttr activationArrAttr = activationAlpha.value(); ++ if (direction == FORWARD || direction == BIDIRECTIONAL) { ++ // Forward activations. ++ if (activationArrAttr.size() > 0) { ++ activationForward.f.alpha = activationArrAttr[0].cast(); ++ } ++ if (activationArrAttr.size() > 1) { ++ activationForward.g.alpha = activationArrAttr[1].cast(); ++ } ++ if (activationArrAttr.size() > 2) { ++ activationForward.h.alpha = activationArrAttr[2].cast(); ++ } ++ } ++ ++ // Reverse activations. ++ if (direction == REVERSE || direction == BIDIRECTIONAL) { ++ unsigned int startIndex = (direction == REVERSE) ? 0 : 3; ++ if (activationArrAttr.size() > startIndex) { ++ activationReverse.f.alpha = ++ activationArrAttr[startIndex].cast(); ++ } ++ if (activationArrAttr.size() > startIndex + 1) { ++ activationReverse.g.alpha = ++ activationArrAttr[startIndex + 1].cast(); ++ } ++ if (activationArrAttr.size() > startIndex + 2) { ++ activationReverse.h.alpha = ++ activationArrAttr[startIndex + 2].cast(); ++ } ++ } ++ } ++ ++ // Get beta attributes. ++ if (activationBeta) { ++ ArrayAttr activationArrAttr = activationBeta.value(); ++ if (direction == FORWARD || direction == BIDIRECTIONAL) { ++ // Forward activations. ++ if (activationArrAttr.size() > 0) { ++ activationForward.f.beta = activationArrAttr[0].cast(); ++ } ++ if (activationArrAttr.size() > 1) { ++ activationForward.g.beta = activationArrAttr[1].cast(); ++ } ++ if (activationArrAttr.size() > 2) { ++ activationForward.h.beta = activationArrAttr[2].cast(); ++ } ++ } ++ ++ // Reverse activations. ++ if (direction == REVERSE || direction == BIDIRECTIONAL) { ++ unsigned int startIndex = (direction == REVERSE) ? 0 : 3; ++ if (activationArrAttr.size() > startIndex) { ++ activationReverse.f.beta = ++ activationArrAttr[startIndex].cast(); ++ } ++ if (activationArrAttr.size() > startIndex + 1) { ++ activationReverse.g.beta = ++ activationArrAttr[startIndex + 1].cast(); ++ } ++ if (activationArrAttr.size() > startIndex + 2) { ++ activationReverse.h.beta = ++ activationArrAttr[startIndex + 2].cast(); ++ } ++ } ++ } ++ ++ return std::make_tuple(activationForward, activationReverse); ++} ++ ++template <> ++std::tuple ++getWeightPack( ++ ConversionPatternRewriter &rewriter, Location loc, ONNXLSTMOp *op) { ++ // Return values. ++ LstmWeightPack weightForward, weightReverse; ++ ++ // parameter weight: [direction, 4*hiddenSize, inputSize] ++ Value W = op->getW(); ++ // recurrence weight: [direction, 4*hiddenSize, hiddenSize] ++ Value R = op->getR(); ++ // direction ++ StringRef direction = op->getDirection(); ++ ++ ArrayRef wShape = W.getType().cast().getShape(); ++ Type elementType = W.getType().cast().getElementType(); ++ int64_t hiddenSize = wShape[1] / 4; ++ int64_t inputSize = wShape[2]; ++ ++ // RankedTensorType types for parameter weights. ++ auto w3DTy = ++ RankedTensorType::get({1, 4 * hiddenSize, inputSize}, elementType); ++ auto w2DTy = RankedTensorType::get({4 * hiddenSize, inputSize}, elementType); ++ auto wTranspose2DTy = ++ RankedTensorType::get({inputSize, 4 * hiddenSize}, elementType); ++ SmallVector w3D2Ty(2, w3DTy); ++ ++ // RankedTensorType types for recurrence weights. ++ auto r3DTy = ++ RankedTensorType::get({1, 4 * hiddenSize, hiddenSize}, elementType); ++ auto r2DTy = RankedTensorType::get({4 * hiddenSize, hiddenSize}, elementType); ++ auto rTranspose2DTy = ++ RankedTensorType::get({hiddenSize, 4 * hiddenSize}, elementType); ++ SmallVector r3D2Ty(2, r3DTy); ++ ++ // Squeeze the direction axis from W and R. ++ Value fW, bW, fR, bR; ++ if (direction == FORWARD) { ++ fW = foldOrEmitONNXSqueezeOpMhlo(rewriter, loc, w2DTy, W, /*axis=*/0); ++ fR = foldOrEmitONNXSqueezeOpMhlo(rewriter, loc, r2DTy, R, /*axis=*/0); ++ } else if (direction == REVERSE) { ++ bW = foldOrEmitONNXSqueezeOpMhlo(rewriter, loc, w2DTy, W, /*axis=*/0); ++ bR = foldOrEmitONNXSqueezeOpMhlo(rewriter, loc, r2DTy, R, /*axis=*/0); ++ } else { // BIDIRECTIONAL ++ // W ++ std::vector vals = ++ foldOrEmitONNXSplitOpMhlo(rewriter, loc, w3D2Ty, W, 0); ++ fW = foldOrEmitONNXSqueezeOpMhlo( ++ rewriter, loc, w2DTy, vals[0], /*axis=*/0); ++ bW = foldOrEmitONNXSqueezeOpMhlo( ++ rewriter, loc, w2DTy, vals[1], /*axis=*/0); ++ // R ++ vals.clear(); ++ vals = foldOrEmitONNXSplitOpMhlo(rewriter, loc, r3D2Ty, R, 0); ++ fR = foldOrEmitONNXSqueezeOpMhlo( ++ rewriter, loc, r2DTy, vals[0], /*axis=*/0); ++ bR = foldOrEmitONNXSqueezeOpMhlo( ++ rewriter, loc, r2DTy, vals[1], /*axis=*/0); ++ } ++ ++ // Transpose W and R. ++ ArrayAttr permAttr = rewriter.getI64ArrayAttr({1, 0}); ++ if (direction == FORWARD || direction == BIDIRECTIONAL) { ++ // W ++ weightForward.WT = foldOrEmitONNXTransposeOpMhlo( ++ rewriter, loc, wTranspose2DTy, fW, permAttr); ++ // R ++ weightForward.RT = foldOrEmitONNXTransposeOpMhlo( ++ rewriter, loc, rTranspose2DTy, fR, permAttr); ++ } ++ if (direction == REVERSE || direction == BIDIRECTIONAL) { ++ // W ++ weightReverse.WT = foldOrEmitONNXTransposeOpMhlo( ++ rewriter, loc, wTranspose2DTy, bW, permAttr); ++ // R ++ weightReverse.RT = foldOrEmitONNXTransposeOpMhlo( ++ rewriter, loc, rTranspose2DTy, bR, permAttr); ++ } ++ return std::make_tuple(weightForward, weightReverse); ++} ++ ++template <> ++std::tuple getBiasPack( ++ ConversionPatternRewriter &rewriter, Location loc, ONNXLSTMOp *op) { ++ // Return values. ++ LstmBiasPack biasForward, biasReverse; ++ ++ // bias: [direction, 8*hiddenSize] for both parameter and recurrence weights. ++ Value B = op->getB(); ++ // peephold: [direction, 3*hiddenSize] for input, output and forget gates. ++ Value P = op->getP(); ++ ++ // direction ++ StringRef direction = op->getDirection(); ++ ++ // Split B. ++ if (!isNoneValue(B)) { ++ ArrayRef bShape = B.getType().cast().getShape(); ++ Type elementType = B.getType().cast().getElementType(); ++ int64_t hiddenSize = bShape[1] / 8; ++ ++ // MemRef types. ++ auto bType2D = RankedTensorType::get({1, 8 * hiddenSize}, elementType); ++ auto bType1D = RankedTensorType::get({8 * hiddenSize}, elementType); ++ auto bSplitType1D = RankedTensorType::get({hiddenSize}, elementType); ++ SmallVector split1D8Ty(8, bSplitType1D); ++ SmallVector split2D2Ty(2, bType2D); ++ ++ // Squeeze the direction axis from B. ++ Value fB, bB; ++ if (direction == FORWARD) { ++ fB = ++ foldOrEmitONNXSqueezeOpMhlo(rewriter, loc, bType1D, B, /*axis=*/0); ++ } else if (direction == REVERSE) { ++ bB = ++ foldOrEmitONNXSqueezeOpMhlo(rewriter, loc, bType1D, B, /*axis=*/0); ++ } else { // BIDIRECTIONAL ++ std::vector vals; ++ vals = foldOrEmitONNXSplitOpMhlo(rewriter, loc, split2D2Ty, B, 0); ++ fB = foldOrEmitONNXSqueezeOpMhlo( ++ rewriter, loc, bType1D, vals[0], /*axis=*/0); ++ bB = foldOrEmitONNXSqueezeOpMhlo( ++ rewriter, loc, bType1D, vals[1], /*axis=*/0); ++ } ++ ++ // Split B into individual bias tensors. ++ if (direction == FORWARD || direction == BIDIRECTIONAL) { ++ std::vector vals = ++ foldOrEmitONNXSplitOpMhlo(rewriter, loc, split1D8Ty, fB, 0); ++ biasForward.Wbi = vals[0]; ++ biasForward.Wbo = vals[1]; ++ biasForward.Wbf = vals[2]; ++ biasForward.Wbc = vals[3]; ++ biasForward.Rbi = vals[4]; ++ biasForward.Rbo = vals[5]; ++ biasForward.Rbf = vals[6]; ++ biasForward.Rbc = vals[7]; ++ biasForward.hasBias = true; ++ } ++ if (direction == REVERSE || direction == BIDIRECTIONAL) { ++ std::vector vals = ++ foldOrEmitONNXSplitOpMhlo(rewriter, loc, split1D8Ty, bB, 0); ++ biasReverse.Wbi = vals[0]; ++ biasReverse.Wbo = vals[1]; ++ biasReverse.Wbf = vals[2]; ++ biasReverse.Wbc = vals[3]; ++ biasReverse.Rbi = vals[4]; ++ biasReverse.Rbo = vals[5]; ++ biasReverse.Rbf = vals[6]; ++ biasReverse.Rbc = vals[7]; ++ biasReverse.hasBias = true; ++ } ++ } ++ ++ // Split P. ++ if (!isNoneValue(P)) { ++ ArrayRef pShape = P.getType().cast().getShape(); ++ Type elementType = P.getType().cast().getElementType(); ++ int64_t hiddenSize = pShape[1] / 3; ++ ++ // MemRef types. ++ auto pType2D = RankedTensorType::get({1, 3 * hiddenSize}, elementType); ++ auto pType1D = RankedTensorType::get({3 * hiddenSize}, elementType); ++ auto pSplitType1D = RankedTensorType::get({hiddenSize}, elementType); ++ SmallVector split1D3Ty(3, pSplitType1D); ++ SmallVector split2D2Ty(2, pType2D); ++ ++ // Squeeze the direction axis from P. ++ Value fP, bP; ++ if (direction == FORWARD) { ++ fP = ++ foldOrEmitONNXSqueezeOpMhlo(rewriter, loc, pType1D, P, /*axis=*/0); ++ } else if (direction == REVERSE) { ++ bP = ++ foldOrEmitONNXSqueezeOpMhlo(rewriter, loc, pType1D, P, /*axis=*/0); ++ } else { // BIDIRECTIONAL ++ std::vector vals = ++ foldOrEmitONNXSplitOpMhlo(rewriter, loc, split2D2Ty, P, 0); ++ fP = foldOrEmitONNXSqueezeOpMhlo( ++ rewriter, loc, pType1D, vals[0], /*axis=*/0); ++ bP = foldOrEmitONNXSqueezeOpMhlo( ++ rewriter, loc, pType1D, vals[1], /*axis=*/0); ++ } ++ ++ // Split P into individual tensors. ++ if (direction == FORWARD || direction == BIDIRECTIONAL) { ++ std::vector vals = ++ foldOrEmitONNXSplitOpMhlo(rewriter, loc, split1D3Ty, fP, 0); ++ biasForward.Pi = vals[0]; ++ biasForward.Po = vals[1]; ++ biasForward.Pf = vals[2]; ++ biasForward.hasPeephole = true; ++ } ++ if (direction == REVERSE || direction == BIDIRECTIONAL) { ++ std::vector vals = ++ foldOrEmitONNXSplitOpMhlo(rewriter, loc, split1D3Ty, bP, 0); ++ biasReverse.Pi = vals[0]; ++ biasReverse.Po = vals[1]; ++ biasReverse.Pf = vals[2]; ++ biasReverse.hasPeephole = true; ++ } ++ } ++ ++ return std::make_tuple(biasForward, biasReverse); ++} ++ ++template <> ++LstmState allocAndInitializeStates( ++ ConversionPatternRewriter &rewriter, Location loc, ONNXLSTMOp *op, ++ typename ONNXLSTMOp::Adaptor operandAdaptor) { ++ LstmState state; ++ ++ // direction ++ StringRef direction = op->getDirection(); ++ ++ // Insert allocation and deallocation for the results of this operation. ++ // If the result is not returned, then no allocation happens. ++ // Y :: [seq_length, num_directions, batch_size, hidden_size] ++ // Y_h :: [num_directions, batch_size, hidden_size] ++ state.ht = allocHiddenOrCell(rewriter, loc, operandAdaptor.getX(), ++ operandAdaptor.getW(), operandAdaptor.getR()); ++ // Y_c :: [num_directions, batch_size, hidden_size] ++ state.ct = allocHiddenOrCell(rewriter, loc, operandAdaptor.getX(), ++ operandAdaptor.getW(), operandAdaptor.getR()); ++ ++ // Insert allocation and deallocation the intermediate Ht and Ct for the ++ // forward and reverse directions. ++ // Ht :: [batch_size, hidden_size] ++ // Ct :: [batch_size, hidden_size] ++ if (direction == FORWARD || direction == BIDIRECTIONAL) { ++ state.forwardHt = allocIntermediateState( ++ rewriter, loc, operandAdaptor.getX(), operandAdaptor.getR()); ++ state.forwardCt = allocIntermediateState( ++ rewriter, loc, operandAdaptor.getX(), operandAdaptor.getR()); ++ } ++ if (direction == REVERSE || direction == BIDIRECTIONAL) { ++ state.reverseHt = allocIntermediateState( ++ rewriter, loc, operandAdaptor.getX(), operandAdaptor.getR()); ++ state.reverseCt = allocIntermediateState( ++ rewriter, loc, operandAdaptor.getX(), operandAdaptor.getR()); ++ } ++ ++ // Initialize Ht and Ct. ++ initializeIntermediateStates(rewriter, loc, state.forwardHt, state.reverseHt, ++ state.forwardCt, state.reverseCt, operandAdaptor.getInitialH(), ++ operandAdaptor.getInitialC(), ++ operandAdaptor.getX().getType().cast().getElementType(), ++ direction, /*onlyHidden=*/false); ++ return state; ++} ++ ++template <> ++void calculateState(ConversionPatternRewriter &rewriter, Location loc, Value Xt, ++ LstmState &state, LstmActivationPack activationPack, ++ LstmWeightPack weightPack, LstmBiasPack biasPack, Value sequenceIV, ++ Value directionIV, Value sequenceLens, Value initialH, bool isForward) { ++ // Equations for LSTM. ++ // it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi) ++ // ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf) ++ // ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc) ++ // Ct = ft (.) Ct-1 + it (.) ct ++ // ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo) ++ // Ht = ot (.) h(Ct) ++ ++ MultiDialectBuilder create(rewriter, loc); ++ ++ ArrayRef xtShape = Xt.getType().cast().getShape(); ++ int64_t batchSize = xtShape[0]; ++ ++ // Get Ht, Ct. ++ Value Ht = (isForward) ? state.forwardHt : state.reverseHt; ++ Value Ct = (isForward) ? state.forwardCt : state.reverseCt; ++ ++ ArrayRef htShape = Ht.getType().cast().getShape(); ++ int64_t hiddenSize = htShape[1]; ++ ++ // Frequently used types. ++ RankedTensorType matrixType = Ht.getType().cast(); ++ Type elementType = matrixType.getElementType(); ++ RankedTensorType matrixAllGatesType = ++ RankedTensorType::get({batchSize, 4 * hiddenSize}, elementType); ++ ++ // Do matrix multiplications. ++ // Xt * (Wi^T ++ Wo^T ++ Wf^T ++ Wc^T) ++ // Ht * (Ri^T ++ Ro^T ++ Rf^T ++ Rc^T) ++ // where '++' is matrix concatenation. ++ // XtWT: [B, 4H], HtRT: [B, 4H] ++ Value XtWT = create.onnx.matmul(matrixAllGatesType, Xt, weightPack.WT); ++ Value HtRT = create.onnx.matmul(matrixAllGatesType, Ht, weightPack.RT); ++ Value commonSum = create.onnx.add(XtWT, HtRT); ++ RankedTensorType matrixSingleGateType = ++ RankedTensorType::get({batchSize, hiddenSize}, elementType); ++ Value zeroIndex = create.onnx.constantInt64({0}); ++ Value oneIndex = create.onnx.constantInt64({1}); ++ Value oneHiddenIndex = create.onnx.constantInt64({hiddenSize}); ++ Value twoHiddenIndex = create.onnx.constantInt64({2 * hiddenSize}); ++ Value threeHiddenIndex = create.onnx.constantInt64({3 * hiddenSize}); ++ Value fourHiddenIndex = create.onnx.constantInt64({4 * hiddenSize}); ++ Value it = create.onnx.slice(matrixSingleGateType, commonSum, zeroIndex, ++ oneHiddenIndex, oneIndex, oneIndex); ++ Value ot = create.onnx.slice(matrixSingleGateType, commonSum, oneHiddenIndex, ++ twoHiddenIndex, oneIndex, oneIndex); ++ Value ft = create.onnx.slice(matrixSingleGateType, commonSum, twoHiddenIndex, ++ threeHiddenIndex, oneIndex, oneIndex); ++ Value ct = create.onnx.slice(matrixSingleGateType, commonSum, ++ threeHiddenIndex, fourHiddenIndex, oneIndex, oneIndex); ++ if (biasPack.hasBias) { ++ it = create.onnx.add(it, biasPack.Wbi); ++ it = create.onnx.add(it, biasPack.Rbi); ++ } ++ if (biasPack.hasPeephole) { ++ Value PiCt = create.onnx.mul(biasPack.Pi, Ct); ++ it = create.onnx.add(it, PiCt); ++ } ++ it = applyActivation(rewriter, loc, activationPack.f, it); ++ if (biasPack.hasBias) { ++ ft = create.onnx.add(ft, biasPack.Wbf); ++ ft = create.onnx.add(ft, biasPack.Rbf); ++ } ++ if (biasPack.hasPeephole) { ++ Value PfCt = create.onnx.mul(biasPack.Pf, Ct); ++ ft = create.onnx.add(ft, PfCt); ++ } ++ ft = applyActivation(rewriter, loc, activationPack.f, ft); ++ if (biasPack.hasBias) { ++ ct = create.onnx.add(ct, biasPack.Wbc); ++ ct = create.onnx.add(ct, biasPack.Rbc); ++ } ++ ct = applyActivation(rewriter, loc, activationPack.g, ct); ++ ++ Value ftCt = create.onnx.mul(ft, Ct); ++ Value itct = create.onnx.mul(it, ct); ++ Value nextCt = create.onnx.add(ftCt, itct); ++ ++ if (biasPack.hasBias) { ++ ot = create.onnx.add(ot, biasPack.Wbo); ++ ot = create.onnx.add(ot, biasPack.Rbo); ++ } ++ if (biasPack.hasPeephole) { ++ Value PoCt = create.onnx.mul(biasPack.Po, nextCt); ++ ot = create.onnx.add(ot, PoCt); ++ } ++ ot = applyActivation(rewriter, loc, activationPack.f, ot); ++ // Ht = ot (.) h(Ct) ++ Value nextHt = applyActivation(rewriter, loc, activationPack.h, nextCt); ++ nextHt = create.onnx.mul(ot, nextHt); ++ if (isForward) { ++ state.forwardHt = nextHt; ++ state.forwardCt = nextCt; ++ } else { ++ state.reverseHt = nextHt; ++ state.reverseCt = nextCt; ++ } ++ RankedTensorType unsqueezedHtType = ++ RankedTensorType::get({1, 1, batchSize, hiddenSize}, elementType); ++ if (isForward) ++ state.forwardAllH.emplace_back(create.onnx.unsqueeze( ++ unsqueezedHtType, nextHt, create.onnx.constantInt64({0, 1}))); ++ else ++ state.reverseAllH.insert(state.reverseAllH.begin(), ++ create.onnx.unsqueeze(unsqueezedHtType, nextHt, ++ create.onnx.constantInt64({0, 1}))); ++} ++ ++template <> ++void stateToOutput(ConversionPatternRewriter &rewriter, ++ Location loc, ONNXLSTMOp *op, LstmState state, ++ std::vector &outputs) { ++ Value noneValue; ++ auto direction = op->getDirection(); ++ MultiDialectBuilder create(rewriter, loc); ++ // First output: all sequences. ++ if (isNoneValue(op->getY())) ++ outputs.emplace_back(noneValue); ++ else { ++ if (direction == FORWARD) { ++ outputs.emplace_back(create.onnx.concat( ++ op->getY().getType(), ValueRange(state.forwardAllH), 0)); ++ } else if (direction == REVERSE) { ++ outputs.emplace_back(create.onnx.concat( ++ op->getY().getType(), ValueRange(state.reverseAllH), 0)); ++ } else { ++ auto outputShape = op->getY().getType().cast().getShape(); ++ RankedTensorType singleDirectionType = RankedTensorType::get( ++ {outputShape[0], 1, outputShape[2], outputShape[3]}, ++ op->getY().getType().cast().getElementType()); ++ outputs.emplace_back(create.onnx.concat(op->getY().getType(), ++ {create.onnx.concat( ++ singleDirectionType, ValueRange(state.forwardAllH), 0), ++ create.onnx.concat( ++ singleDirectionType, ValueRange(state.reverseAllH), 0)}, ++ 1)); ++ } ++ } ++ // Second output: hidden. ++ if (isNoneValue(op->getYH())) ++ outputs.emplace_back(noneValue); ++ else { ++ stateToOutputForHiddenOrCell( ++ rewriter, loc, state.forwardHt, state.reverseHt, direction, state.ht); ++ outputs.emplace_back(state.ht); ++ } ++ // Third output: cell. ++ if (isNoneValue(op->getYC())) ++ outputs.emplace_back(noneValue); ++ else { ++ stateToOutputForHiddenOrCell( ++ rewriter, loc, state.forwardCt, state.reverseCt, direction, state.ct); ++ outputs.emplace_back(state.ct); ++ } ++} ++ ++} // namespace mhlo ++ ++void populateLoweringONNXLSTMOpToMhloPattern( ++ RewritePatternSet &patterns, MLIRContext *ctx) { ++ patterns.insert>( ++ ctx, UNROLL); ++} ++ ++} // namespace onnx_mlir +diff --git a/src/Conversion/ONNXToMhlo/RNN/RNNBase.cpp b/src/Conversion/ONNXToMhlo/RNN/RNNBase.cpp +new file mode 100644 +index 00000000..002d2673 +--- /dev/null ++++ b/src/Conversion/ONNXToMhlo/RNN/RNNBase.cpp +@@ -0,0 +1,221 @@ ++/* ++ * SPDX-License-Identifier: Apache-2.0 ++ */ ++ ++//===--------------- RNNBase.cpp - Lowering RNN Ops -----------------------===// ++// ++// Copyright 2023 ++// ++// ============================================================================= ++// ++// This file defines base functions for lowering the ONNX RNN Operators. ++// ++//===----------------------------------------------------------------------===// ++ ++#include "src/Conversion/ONNXToMhlo/RNN/RNNBase.hpp" ++#include "src/Conversion/ONNXToMhlo/ONNXToMhloCommon.hpp" ++ ++#include "llvm/Support/Debug.h" ++ ++#define DEBUG_TYPE "lstm" ++ ++using namespace mlir; ++ ++namespace onnx_mlir { ++ ++namespace mhlo { ++ ++// Get a dimension of the tensor's shape. ++int64_t dimAt(Value val, int index) { ++ return val.getType().cast().getShape()[index]; ++} ++ ++/// Insert Allocate and Deallocate for the hidden or cell output. ++mlir::Value allocHiddenOrCell(mlir::ConversionPatternRewriter &rewriter, ++ mlir::Location loc, mlir::Value X, mlir::Value W, mlir::Value R) { ++ LLVM_DEBUG(llvm::dbgs() << "allocHiddenOrCell\n"); ++ MultiDialectBuilder create(rewriter, loc); ++ RankedTensorType zeroType = RankedTensorType::get( ++ {/*num_directions=*/dimAt(W, 0), /*batch_size=*/dimAt(X, 1), ++ /*hidden_size=*/dimAt(R, 2)}, ++ X.getType().cast().getElementType()); ++ DenseElementsAttr zeroAttr = DenseElementsAttr::get(zeroType, 0.0f); ++ return create.onnx.constant(zeroAttr); ++} ++ ++/// Insert Allocate and Deallocate for the intermediate hidden or cell states. ++/// Shape :: [batch_size, hidden_size] ++Value allocIntermediateState( ++ ConversionPatternRewriter &rewriter, Location loc, Value X, Value R) { ++ LLVM_DEBUG(llvm::dbgs() << "allocIntermediateState\n"); ++ MultiDialectBuilder create(rewriter, loc); ++ RankedTensorType zeroType = ++ RankedTensorType::get({/*batch_size=*/dimAt(X, 1), ++ /*hidden_size=*/dimAt(R, 2)}, ++ X.getType().cast().getElementType()); ++ DenseElementsAttr zeroAttr = DenseElementsAttr::get(zeroType, 0.0f); ++ return create.onnx.constant(zeroAttr); ++} ++ ++/// Initialize the intermediate hidden and cell states. ++/// forward(reverse)Ht, forward(reverse)Ct ++void initializeIntermediateStates(ConversionPatternRewriter &rewriter, ++ Location loc, Value &forwardHt, Value &reverseHt, Value &forwardCt, ++ Value &reverseCt, Value initialH, Value initialC, Type elementType, ++ StringRef direction, bool onlyHidden) { ++ LLVM_DEBUG(llvm::dbgs() << "initializeIntermediateStates\n"); ++ MultiDialectBuilder create(rewriter, loc); ++ ++ Value zeroIndex = create.onnx.constantInt64({0}); ++ Value oneIndex = create.onnx.constantInt64({1}); ++ Value twoIndex = create.onnx.constantInt64({2}); ++ ++ Value boundVal = (direction == FORWARD || direction == BIDIRECTIONAL) ++ ? forwardHt ++ : reverseHt; ++ auto valShape = boundVal.getType().cast().getShape(); ++ RankedTensorType sliceType = ++ RankedTensorType::get({1, valShape[0], valShape[1]}, ++ boundVal.getType().cast().getElementType()); ++ RankedTensorType valType = boundVal.getType().cast(); ++ if (direction == FORWARD || direction == BIDIRECTIONAL) { ++ if (!isNoneValue(initialH)) { ++ forwardHt = create.onnx.slice( ++ sliceType, initialH, zeroIndex, oneIndex, zeroIndex, oneIndex); ++ forwardHt = create.onnx.squeeze(valType, forwardHt, zeroIndex); ++ } ++ if (!onlyHidden && !isNoneValue(initialC)) { ++ forwardCt = create.onnx.slice( ++ sliceType, initialC, zeroIndex, oneIndex, zeroIndex, oneIndex); ++ forwardCt = create.onnx.squeeze(valType, forwardCt, zeroIndex); ++ } ++ } ++ if (direction == REVERSE || direction == BIDIRECTIONAL) { ++ if (!isNoneValue(initialH)) { ++ if (direction == REVERSE) { ++ reverseHt = create.onnx.slice( ++ sliceType, initialH, zeroIndex, oneIndex, zeroIndex, oneIndex); ++ reverseHt = create.onnx.squeeze(valType, reverseHt, zeroIndex); ++ } else { ++ reverseHt = create.onnx.slice( ++ sliceType, initialH, oneIndex, twoIndex, zeroIndex, oneIndex); ++ reverseHt = create.onnx.squeeze(valType, reverseHt, zeroIndex); ++ } ++ } ++ if (!onlyHidden and !isNoneValue(initialC)) { ++ if (direction == REVERSE) { ++ reverseCt = create.onnx.slice( ++ sliceType, initialC, zeroIndex, oneIndex, zeroIndex, oneIndex); ++ reverseCt = create.onnx.squeeze(valType, reverseCt, zeroIndex); ++ } else { ++ reverseCt = create.onnx.slice( ++ sliceType, initialC, oneIndex, twoIndex, zeroIndex, oneIndex); ++ reverseCt = create.onnx.squeeze(valType, reverseCt, zeroIndex); ++ } ++ } ++ } ++} ++ ++/// Store a state into the output of the RNN op. ++/// The input state is 2D and the output state is 3D with '1' or '2' is ++/// pretended, depending on 'direction'. ++void stateToOutputForHiddenOrCell(ConversionPatternRewriter &rewriter, ++ Location loc, Value forwardVal, Value reverseVal, StringRef direction, ++ Value &output) { ++ MultiDialectBuilder create(rewriter, loc); ++ if (direction == FORWARD || direction == REVERSE) { ++ Value val = (direction == FORWARD) ? forwardVal : reverseVal; ++ output = val; ++ } else { // BIDIRECTIONAL ++ SmallVector bForwardValShape( ++ forwardVal.getType().cast().getShape()); ++ SmallVector bValShape( ++ forwardVal.getType().cast().getShape()); ++ SmallVector bReverseValShape( ++ reverseVal.getType().cast().getShape()); ++ bForwardValShape.insert(bForwardValShape.begin(), 1); ++ bReverseValShape.insert(bReverseValShape.begin(), 1); ++ bValShape.insert(bValShape.begin(), 2); ++ Type valElementType = ++ forwardVal.getType().cast().getElementType(); ++ Value zero = create.onnx.constantInt64({0}); ++ Value bForwardVal = create.onnx.unsqueeze( ++ RankedTensorType::get(bForwardValShape, valElementType), forwardVal, ++ zero); ++ Value bReverseVal = create.onnx.unsqueeze( ++ RankedTensorType::get(bReverseValShape, valElementType), reverseVal, ++ zero); ++ output = ++ create.onnx.concat(RankedTensorType::get(bValShape, valElementType), ++ {bForwardVal, bReverseVal}, 0); ++ } ++} ++ ++// Apply an activation function on a given scalar operand. ++Value applyActivation(OpBuilder &rewriter, Location loc, ++ RNNActivation activation, Value operand) { ++ Value res; ++ ++ std::vector attributes; ++ if (activation.alpha) { ++ attributes.emplace_back( ++ rewriter.getNamedAttr("alpha", activation.alpha.value())); ++ } ++ if (activation.beta) { ++ attributes.emplace_back( ++ rewriter.getNamedAttr("beta", activation.beta.value())); ++ } ++ Type resType = operand.getType(); ++ ++ // Change equality to be case insensitive. ++ if (activation.name.equals_insensitive("relu")) ++ res = rewriter.create(loc, resType, operand); ++ else if (activation.name.equals_insensitive("tanh")) ++ res = rewriter.create(loc, resType, operand); ++ else if (activation.name.equals_insensitive("sigmoid")) ++ res = rewriter.create(loc, resType, operand); ++ else if (activation.name.equals_insensitive("affine")) ++ llvm_unreachable("Unsupported activation"); ++ else if (activation.name.equals_insensitive("leakyrelu")) ++ res = rewriter.create(loc, resType, operand, attributes); ++ else if (activation.name.equals_insensitive("thresholdedrelu")) ++ res = rewriter.create( ++ loc, resType, operand, attributes); ++ else if (activation.name.equals_insensitive("scaledtanh")) ++ llvm_unreachable("Unsupported activation"); ++ else if (activation.name.equals_insensitive("hardsigmoid")) ++ res = rewriter.create(loc, resType, operand, attributes); ++ else if (activation.name.equals_insensitive("elu")) ++ res = rewriter.create(loc, resType, operand, attributes); ++ else if (activation.name.equals_insensitive("softsign")) ++ res = rewriter.create(loc, resType, operand); ++ else if (activation.name.equals_insensitive("softplus")) ++ res = rewriter.create(loc, resType, operand); ++ else ++ llvm_unreachable("Unsupported activation"); ++ ++ return res; ++} ++ ++/// Create a copy of a slice of X at a specific timestep. ++Value emitXSliceAt(ConversionPatternRewriter &rewriter, Location loc, Value X, ++ Value timestepIV) { ++ MultiDialectBuilder create(rewriter, loc); ++ int64_t batchSize = dimAt(X, 1); ++ int64_t inputSize = dimAt(X, 2); ++ Type elementType = X.getType().cast().getElementType(); ++ RankedTensorType sliceXType = ++ RankedTensorType::get({1, batchSize, inputSize}, elementType); ++ RankedTensorType squeezedXType = ++ RankedTensorType::get({batchSize, inputSize}, elementType); ++ Value sliceX = create.onnx.slice(sliceXType, X, timestepIV, ++ create.onnx.add(timestepIV, create.onnx.constantInt64({1})), ++ create.onnx.constantInt64({0}), create.onnx.constantInt64({1})); ++ sliceX = create.onnx.squeeze( ++ squeezedXType, sliceX, create.onnx.constantInt64({0})); ++ return sliceX; ++} ++ ++} // namespace mhlo ++ ++} // namespace onnx_mlir +diff --git a/src/Conversion/ONNXToMhlo/RNN/RNNBase.hpp b/src/Conversion/ONNXToMhlo/RNN/RNNBase.hpp +new file mode 100644 +index 00000000..8da9aedd +--- /dev/null ++++ b/src/Conversion/ONNXToMhlo/RNN/RNNBase.hpp +@@ -0,0 +1,211 @@ ++/* ++ * SPDX-License-Identifier: Apache-2.0 ++ */ ++ ++//===--------------- RNNBase.hpp - Lowering RNN Ops -----------------------===// ++// ++// Copyright 2023 ++// ++// ============================================================================= ++// ++// This file defines base functions for lowering the ONNX RNN Operators. ++// ++//===----------------------------------------------------------------------===// ++ ++#pragma once ++ ++#include "src/Conversion/ONNXToMhlo/ONNXToMhloCommon.hpp" ++ ++static constexpr llvm::StringRef FORWARD = "forward"; ++static constexpr llvm::StringRef REVERSE = "reverse"; ++static constexpr llvm::StringRef BIDIRECTIONAL = "bidirectional"; ++ ++static constexpr llvm::StringRef USELOOP = "useloop"; ++static constexpr llvm::StringRef UNROLL = "unroll"; ++ ++namespace onnx_mlir { ++ ++namespace mhlo { ++ ++struct RNNActivation { ++ llvm::StringRef name; ++ std::optional alpha; ++ std::optional beta; ++}; ++ ++/// Get a dimension of the tensor's shape. ++int64_t dimAt(mlir::Value val, int index); ++ ++/// Insert Allocate and Deallocate for the hidden or cell output. ++mlir::Value allocHiddenOrCell(mlir::ConversionPatternRewriter &rewriter, ++ mlir::Location loc, mlir::Value X, mlir::Value W, mlir::Value R); ++ ++/// Allocate the intermediate hidden or cell state. ++mlir::Value allocIntermediateState(mlir::ConversionPatternRewriter &rewriter, ++ mlir::Location loc, mlir::Value X, mlir::Value R); ++ ++/// Initialize the intermediate hidden and cell states. ++void initializeIntermediateStates(mlir::ConversionPatternRewriter &rewriter, ++ mlir::Location loc, mlir::Value &forwardHt, mlir::Value &reverseHt, ++ mlir::Value &forwardCt, mlir::Value &reverseCt, mlir::Value initialH, ++ mlir::Value initialC, mlir::Type elementType, llvm::StringRef direction, ++ bool onlyHidden); ++ ++/// Store a state into the output of the RNN op. ++/// The input state is 2D and the output state is 3D with '1' or '2' is ++/// pretended, depending on 'direction'. ++void stateToOutputForHiddenOrCell(mlir::ConversionPatternRewriter &rewriter, ++ mlir::Location loc, mlir::Value forwardVal, mlir::Value reverseVal, ++ llvm::StringRef direction, mlir::Value &output); ++ ++/// Apply an activation function on a given operand. ++mlir::Value applyActivation(mlir::OpBuilder &rewriter, mlir::Location loc, ++ RNNActivation activation, mlir::Value operand); ++ ++/// Get a slice of X at a specific timestep. ++mlir::Value emitXSliceAt(mlir::ConversionPatternRewriter &rewriter, ++ mlir::Location loc, mlir::Value X, mlir::Value timestep); ++ ++// Override the following methods when lowering an RNN operation: ++// - hasAllNoneOutput ++// - getActivationPack ++// - getWeightPack ++// - getBiasPack ++// - allocAndInitializeStates ++// - calculateState ++// - stateToOutput ++ ++// Check whether all outputs have NoneType or not. ++template ++bool hasAllNoneOutput(RNNOp *op); ++ ++// Obtain activations functions for a specific operation. ++template ++std::tuple getActivationPack(RNNOp *op); ++ ++/// Obtain weight tensors in 2D for each gate. ++/// In ONNX, weights for gates and directions are combined in a single tensor. ++/// This function splits them into 2D tensors. ++template ++std::tuple getWeightPack( ++ mlir::ConversionPatternRewriter &rewriter, mlir::Location loc, RNNOp *op); ++ ++/// Obtain biases in 1D for each gate. ++/// In ONNX, biases for gates and directions are combined in a single tensor. ++/// This function splits them into 1D tensors. ++template ++std::tuple getBiasPack( ++ mlir::ConversionPatternRewriter &rewriter, mlir::Location loc, RNNOp *op); ++ ++// Allocate memory for RNN states and initialize them. ++template ++S allocAndInitializeStates(mlir::ConversionPatternRewriter &rewriter, ++ mlir::Location loc, RNNOp *op, typename RNNOp::Adaptor operandAdaptor); ++ ++// Calculate new states from the current input and states. ++template ++void calculateState(mlir::ConversionPatternRewriter &rewriter, ++ mlir::Location loc, mlir::Value Xt, S &state, A activationSet, W weight, ++ B bias, mlir::Value sequenceIV, mlir::Value directionIV, ++ mlir::Value sequenceLens, mlir::Value initialH, bool isForward); ++ ++// Write states to the RNN's outputs. ++template ++void stateToOutput(mlir::ConversionPatternRewriter &rewriter, ++ mlir::Location loc, RNNOp *op, S state, std::vector &outputs); ++ ++// A common template for lowering an RNN operation. ++template ++struct ONNXRNNOpLowering : public mlir::OpConversionPattern { ++ using OpAdaptor = typename RNNOp::Adaptor; ++ StringRef loopExpansion; ++ ++ ONNXRNNOpLowering(mlir::MLIRContext *ctx, const StringRef &expansion) ++ : mlir::OpConversionPattern(ctx) { ++ loopExpansion = expansion; ++ } ++ ++ mlir::LogicalResult matchAndRewrite(RNNOp rnnOp, OpAdaptor adaptor, ++ mlir::ConversionPatternRewriter &rewriter) const final { ++ mlir::Operation *op = rnnOp.getOperation(); ++ mlir::Location loc = ONNXLoc(op); ++ mlir::Value X = adaptor.getX(); ++ mlir::Value sequenceLens = adaptor.getSequenceLens(); ++ mlir::Value initialH = adaptor.getInitialH(); ++ ++ if (hasAllNoneOutput(&rnnOp)) { ++ rewriter.eraseOp(op); ++ return mlir::success(); ++ } ++ ++ // Initialize output states. ++ S state = ++ allocAndInitializeStates(rewriter, loc, &rnnOp, adaptor); ++ ++ // Activation functions. ++ A activationForward, activationReverse; ++ std::tie(activationForward, activationReverse) = ++ getActivationPack(&rnnOp); ++ ++ // Prepare weights. ++ W weightForward, weightReverse; ++ std::tie(weightForward, weightReverse) = ++ getWeightPack(rewriter, loc, &rnnOp); ++ ++ // Prepare biases. ++ B biasForward, biasReverse; ++ std::tie(biasForward, biasReverse) = ++ getBiasPack(rewriter, loc, &rnnOp); ++ ++ int64_t sequenceDimSize = dimAt(rnnOp.getX(), 0); ++ auto direction = rnnOp.getDirection(); ++ ++ MultiDialectBuilder create(rewriter, loc); ++ ++ if (loopExpansion != UNROLL) { ++ rnnOp.emitError("only unroll is supported for now"); ++ return failure(); ++ } ++ assert(!mlir::ShapedType::isDynamic(sequenceDimSize) && ++ "Only static sequenceDimSize is supported for unroll"); ++ ++ if (direction == FORWARD || direction == BIDIRECTIONAL) { ++ for (int64_t i = 0; i < sequenceDimSize; i++) { ++ mlir::Value directionIV = create.onnx.constantInt64({0}); ++ mlir::Value sequenceIV = create.onnx.constantInt64({i}); ++ // Get a slice of X at the current timestep. ++ mlir::Value Xt = emitXSliceAt(rewriter, loc, X, sequenceIV); ++ // Emit calculation for one RNN step. ++ calculateState(rewriter, loc, Xt, state, activationForward, ++ weightForward, biasForward, sequenceIV, directionIV, sequenceLens, ++ initialH, ++ /*isForward=*/true); ++ } ++ } ++ ++ if (direction == REVERSE || direction == BIDIRECTIONAL) { ++ for (int64_t i = 0; i < sequenceDimSize; i++) { ++ mlir::Value directionIV = ++ create.onnx.constantInt64({(direction == REVERSE) ? 0 : 1}); ++ mlir::Value reverseSequenceIV = ++ create.onnx.constantInt64({sequenceDimSize - i - 1}); ++ // Get a slice of X at the current timestep. ++ mlir::Value Xt = emitXSliceAt(rewriter, loc, X, reverseSequenceIV); ++ // Emit calculation for one RNN step. ++ calculateState(rewriter, loc, Xt, state, activationReverse, ++ weightReverse, biasReverse, reverseSequenceIV, directionIV, ++ sequenceLens, initialH, ++ /*isForward=*/false); ++ } ++ } ++ ++ std::vector outputs; ++ stateToOutput(rewriter, loc, &rnnOp, state, outputs); ++ rewriter.replaceOp(op, outputs); ++ return mlir::success(); ++ } ++}; ++ ++} // namespace mhlo ++ ++} // namespace onnx_mlir +diff --git a/test/mlir/conversion/onnx_to_mhlo/RNN/LSTM.mlir b/test/mlir/conversion/onnx_to_mhlo/RNN/LSTM.mlir +new file mode 100644 +index 00000000..14a38607 +--- /dev/null ++++ b/test/mlir/conversion/onnx_to_mhlo/RNN/LSTM.mlir +@@ -0,0 +1,185 @@ ++// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-mhlo --canonicalize -split-input-file %s | FileCheck %s ++func.func @test_lstm(%arg0 : tensor<2x16x512xf32>, %arg1 : tensor<2x2048xf32>, %arg2 : tensor<2x1024x512xf32>, %arg3 : tensor<2x1024x256xf32>) -> tensor<2x2x16x256xf32> { ++ %0 = onnx.Constant dense<0.000000e+00> : tensor<2x16x256xf32> ++ %1 = "onnx.NoValue"() {value} : () -> none ++ %Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg2, %arg3, %arg1, %1, %0, %0, %1) {direction = "bidirectional", hidden_size = 256 : si64, input_forget = 0 : si64, layout = 0 : si64} : (tensor<2x16x512xf32>, tensor<2x1024x512xf32>, tensor<2x1024x256xf32>, tensor<2x2048xf32>, none, tensor<2x16x256xf32>, tensor<2x16x256xf32>, none) -> (tensor<2x2x16x256xf32>, tensor<2x16x256xf32>, tensor<2x16x256xf32>) ++ return %Y : tensor<2x2x16x256xf32> ++// CHECK: func.func @test_lstm(%arg0: tensor<2x16x512xf32>, %arg1: tensor<2x2048xf32>, %arg2: tensor<2x1024x512xf32>, %arg3: tensor<2x1024x256xf32>) -> tensor<2x2x16x256xf32> { ++// CHECK: %0 = mhlo.constant dense<0.000000e+00> : tensor<16x256xf32> ++// CHECK: %1 = "mhlo.slice"(%arg2) {limit_indices = dense<[1, 1024, 512]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<2x1024x512xf32>) -> tensor<1x1024x512xf32> ++// CHECK: %2 = "mhlo.slice"(%arg2) {limit_indices = dense<[2, 1024, 512]> : tensor<3xi64>, start_indices = dense<[1, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<2x1024x512xf32>) -> tensor<1x1024x512xf32> ++// CHECK: %3 = mhlo.reshape %1 : (tensor<1x1024x512xf32>) -> tensor<1024x512xf32> ++// CHECK: %4 = mhlo.reshape %2 : (tensor<1x1024x512xf32>) -> tensor<1024x512xf32> ++// CHECK: %5 = "mhlo.slice"(%arg3) {limit_indices = dense<[1, 1024, 256]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<2x1024x256xf32>) -> tensor<1x1024x256xf32> ++// CHECK: %6 = "mhlo.slice"(%arg3) {limit_indices = dense<[2, 1024, 256]> : tensor<3xi64>, start_indices = dense<[1, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<2x1024x256xf32>) -> tensor<1x1024x256xf32> ++// CHECK: %7 = mhlo.reshape %5 : (tensor<1x1024x256xf32>) -> tensor<1024x256xf32> ++// CHECK: %8 = mhlo.reshape %6 : (tensor<1x1024x256xf32>) -> tensor<1024x256xf32> ++// CHECK: %9 = "mhlo.transpose"(%3) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1024x512xf32>) -> tensor<512x1024xf32> ++// CHECK: %10 = "mhlo.transpose"(%7) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1024x256xf32>) -> tensor<256x1024xf32> ++// CHECK: %11 = "mhlo.transpose"(%4) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1024x512xf32>) -> tensor<512x1024xf32> ++// CHECK: %12 = "mhlo.transpose"(%8) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1024x256xf32>) -> tensor<256x1024xf32> ++// CHECK: %13 = "mhlo.slice"(%arg1) {limit_indices = dense<[1, 2048]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x2048xf32>) -> tensor<1x2048xf32> ++// CHECK: %14 = "mhlo.slice"(%arg1) {limit_indices = dense<[2, 2048]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x2048xf32>) -> tensor<1x2048xf32> ++// CHECK: %15 = mhlo.reshape %13 : (tensor<1x2048xf32>) -> tensor<2048xf32> ++// CHECK: %16 = mhlo.reshape %14 : (tensor<1x2048xf32>) -> tensor<2048xf32> ++// CHECK: %17 = "mhlo.slice"(%15) {limit_indices = dense<256> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2048xf32>) -> tensor<256xf32> ++// CHECK: %18 = "mhlo.slice"(%15) {limit_indices = dense<512> : tensor<1xi64>, start_indices = dense<256> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2048xf32>) -> tensor<256xf32> ++// CHECK: %19 = "mhlo.slice"(%15) {limit_indices = dense<768> : tensor<1xi64>, start_indices = dense<512> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2048xf32>) -> tensor<256xf32> ++// CHECK: %20 = "mhlo.slice"(%15) {limit_indices = dense<1024> : tensor<1xi64>, start_indices = dense<768> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2048xf32>) -> tensor<256xf32> ++// CHECK: %21 = "mhlo.slice"(%15) {limit_indices = dense<1280> : tensor<1xi64>, start_indices = dense<1024> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2048xf32>) -> tensor<256xf32> ++// CHECK: %22 = "mhlo.slice"(%15) {limit_indices = dense<1536> : tensor<1xi64>, start_indices = dense<1280> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2048xf32>) -> tensor<256xf32> ++// CHECK: %23 = "mhlo.slice"(%15) {limit_indices = dense<1792> : tensor<1xi64>, start_indices = dense<1536> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2048xf32>) -> tensor<256xf32> ++// CHECK: %24 = "mhlo.slice"(%15) {limit_indices = dense<2048> : tensor<1xi64>, start_indices = dense<1792> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2048xf32>) -> tensor<256xf32> ++// CHECK: %25 = "mhlo.slice"(%16) {limit_indices = dense<256> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2048xf32>) -> tensor<256xf32> ++// CHECK: %26 = "mhlo.slice"(%16) {limit_indices = dense<512> : tensor<1xi64>, start_indices = dense<256> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2048xf32>) -> tensor<256xf32> ++// CHECK: %27 = "mhlo.slice"(%16) {limit_indices = dense<768> : tensor<1xi64>, start_indices = dense<512> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2048xf32>) -> tensor<256xf32> ++// CHECK: %28 = "mhlo.slice"(%16) {limit_indices = dense<1024> : tensor<1xi64>, start_indices = dense<768> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2048xf32>) -> tensor<256xf32> ++// CHECK: %29 = "mhlo.slice"(%16) {limit_indices = dense<1280> : tensor<1xi64>, start_indices = dense<1024> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2048xf32>) -> tensor<256xf32> ++// CHECK: %30 = "mhlo.slice"(%16) {limit_indices = dense<1536> : tensor<1xi64>, start_indices = dense<1280> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2048xf32>) -> tensor<256xf32> ++// CHECK: %31 = "mhlo.slice"(%16) {limit_indices = dense<1792> : tensor<1xi64>, start_indices = dense<1536> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2048xf32>) -> tensor<256xf32> ++// CHECK: %32 = "mhlo.slice"(%16) {limit_indices = dense<2048> : tensor<1xi64>, start_indices = dense<1792> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2048xf32>) -> tensor<256xf32> ++// CHECK: %33 = "mhlo.slice"(%arg0) {limit_indices = dense<[1, 16, 512]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<2x16x512xf32>) -> tensor<1x16x512xf32> ++// CHECK: %34 = mhlo.reshape %33 : (tensor<1x16x512xf32>) -> tensor<16x512xf32> ++// CHECK: %35 = "mhlo.dot"(%34, %9) : (tensor<16x512xf32>, tensor<512x1024xf32>) -> tensor<16x1024xf32> ++// CHECK: %36 = "mhlo.dot"(%0, %10) : (tensor<16x256xf32>, tensor<256x1024xf32>) -> tensor<16x1024xf32> ++// CHECK: %37 = mhlo.add %35, %36 : tensor<16x1024xf32> ++// CHECK: %38 = "mhlo.slice"(%37) {limit_indices = dense<[16, 256]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<16x1024xf32>) -> tensor<16x256xf32> ++// CHECK: %39 = "mhlo.slice"(%37) {limit_indices = dense<[16, 512]> : tensor<2xi64>, start_indices = dense<[0, 256]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<16x1024xf32>) -> tensor<16x256xf32> ++// CHECK: %40 = "mhlo.slice"(%37) {limit_indices = dense<[16, 768]> : tensor<2xi64>, start_indices = dense<[0, 512]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<16x1024xf32>) -> tensor<16x256xf32> ++// CHECK: %41 = "mhlo.slice"(%37) {limit_indices = dense<[16, 1024]> : tensor<2xi64>, start_indices = dense<[0, 768]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<16x1024xf32>) -> tensor<16x256xf32> ++// CHECK: %42 = "mhlo.broadcast_in_dim"(%17) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %43 = mhlo.add %38, %42 : tensor<16x256xf32> ++// CHECK: %44 = "mhlo.broadcast_in_dim"(%21) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %45 = mhlo.add %43, %44 : tensor<16x256xf32> ++// CHECK: %46 = mhlo.logistic %45 : tensor<16x256xf32> ++// CHECK: %47 = "mhlo.broadcast_in_dim"(%19) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %48 = mhlo.add %40, %47 : tensor<16x256xf32> ++// CHECK: %49 = "mhlo.broadcast_in_dim"(%23) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %50 = mhlo.add %48, %49 : tensor<16x256xf32> ++// CHECK: %51 = mhlo.logistic %50 : tensor<16x256xf32> ++// CHECK: %52 = "mhlo.broadcast_in_dim"(%20) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %53 = mhlo.add %41, %52 : tensor<16x256xf32> ++// CHECK: %54 = "mhlo.broadcast_in_dim"(%24) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %55 = mhlo.add %53, %54 : tensor<16x256xf32> ++// CHECK: %56 = mhlo.tanh %55 : tensor<16x256xf32> ++// CHECK: %57 = mhlo.multiply %51, %0 : tensor<16x256xf32> ++// CHECK: %58 = mhlo.multiply %46, %56 : tensor<16x256xf32> ++// CHECK: %59 = mhlo.add %57, %58 : tensor<16x256xf32> ++// CHECK: %60 = "mhlo.broadcast_in_dim"(%18) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %61 = mhlo.add %39, %60 : tensor<16x256xf32> ++// CHECK: %62 = "mhlo.broadcast_in_dim"(%22) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %63 = mhlo.add %61, %62 : tensor<16x256xf32> ++// CHECK: %64 = mhlo.logistic %63 : tensor<16x256xf32> ++// CHECK: %65 = mhlo.tanh %59 : tensor<16x256xf32> ++// CHECK: %66 = mhlo.multiply %64, %65 : tensor<16x256xf32> ++// CHECK: %67 = mhlo.reshape %66 : (tensor<16x256xf32>) -> tensor<1x1x16x256xf32> ++// CHECK: %68 = "mhlo.slice"(%arg0) {limit_indices = dense<[2, 16, 512]> : tensor<3xi64>, start_indices = dense<[1, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<2x16x512xf32>) -> tensor<1x16x512xf32> ++// CHECK: %69 = mhlo.reshape %68 : (tensor<1x16x512xf32>) -> tensor<16x512xf32> ++// CHECK: %70 = "mhlo.dot"(%69, %9) : (tensor<16x512xf32>, tensor<512x1024xf32>) -> tensor<16x1024xf32> ++// CHECK: %71 = "mhlo.dot"(%66, %10) : (tensor<16x256xf32>, tensor<256x1024xf32>) -> tensor<16x1024xf32> ++// CHECK: %72 = mhlo.add %70, %71 : tensor<16x1024xf32> ++// CHECK: %73 = "mhlo.slice"(%72) {limit_indices = dense<[16, 256]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<16x1024xf32>) -> tensor<16x256xf32> ++// CHECK: %74 = "mhlo.slice"(%72) {limit_indices = dense<[16, 512]> : tensor<2xi64>, start_indices = dense<[0, 256]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<16x1024xf32>) -> tensor<16x256xf32> ++// CHECK: %75 = "mhlo.slice"(%72) {limit_indices = dense<[16, 768]> : tensor<2xi64>, start_indices = dense<[0, 512]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<16x1024xf32>) -> tensor<16x256xf32> ++// CHECK: %76 = "mhlo.slice"(%72) {limit_indices = dense<[16, 1024]> : tensor<2xi64>, start_indices = dense<[0, 768]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<16x1024xf32>) -> tensor<16x256xf32> ++// CHECK: %77 = "mhlo.broadcast_in_dim"(%17) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %78 = mhlo.add %73, %77 : tensor<16x256xf32> ++// CHECK: %79 = "mhlo.broadcast_in_dim"(%21) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %80 = mhlo.add %78, %79 : tensor<16x256xf32> ++// CHECK: %81 = mhlo.logistic %80 : tensor<16x256xf32> ++// CHECK: %82 = "mhlo.broadcast_in_dim"(%19) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %83 = mhlo.add %75, %82 : tensor<16x256xf32> ++// CHECK: %84 = "mhlo.broadcast_in_dim"(%23) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %85 = mhlo.add %83, %84 : tensor<16x256xf32> ++// CHECK: %86 = mhlo.logistic %85 : tensor<16x256xf32> ++// CHECK: %87 = "mhlo.broadcast_in_dim"(%20) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %88 = mhlo.add %76, %87 : tensor<16x256xf32> ++// CHECK: %89 = "mhlo.broadcast_in_dim"(%24) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %90 = mhlo.add %88, %89 : tensor<16x256xf32> ++// CHECK: %91 = mhlo.tanh %90 : tensor<16x256xf32> ++// CHECK: %92 = mhlo.multiply %86, %59 : tensor<16x256xf32> ++// CHECK: %93 = mhlo.multiply %81, %91 : tensor<16x256xf32> ++// CHECK: %94 = mhlo.add %92, %93 : tensor<16x256xf32> ++// CHECK: %95 = "mhlo.broadcast_in_dim"(%18) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %96 = mhlo.add %74, %95 : tensor<16x256xf32> ++// CHECK: %97 = "mhlo.broadcast_in_dim"(%22) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %98 = mhlo.add %96, %97 : tensor<16x256xf32> ++// CHECK: %99 = mhlo.logistic %98 : tensor<16x256xf32> ++// CHECK: %100 = mhlo.tanh %94 : tensor<16x256xf32> ++// CHECK: %101 = mhlo.multiply %99, %100 : tensor<16x256xf32> ++// CHECK: %102 = mhlo.reshape %101 : (tensor<16x256xf32>) -> tensor<1x1x16x256xf32> ++// CHECK: %103 = "mhlo.slice"(%arg0) {limit_indices = dense<[2, 16, 512]> : tensor<3xi64>, start_indices = dense<[1, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<2x16x512xf32>) -> tensor<1x16x512xf32> ++// CHECK: %104 = mhlo.reshape %103 : (tensor<1x16x512xf32>) -> tensor<16x512xf32> ++// CHECK: %105 = "mhlo.dot"(%104, %11) : (tensor<16x512xf32>, tensor<512x1024xf32>) -> tensor<16x1024xf32> ++// CHECK: %106 = "mhlo.dot"(%0, %12) : (tensor<16x256xf32>, tensor<256x1024xf32>) -> tensor<16x1024xf32> ++// CHECK: %107 = mhlo.add %105, %106 : tensor<16x1024xf32> ++// CHECK: %108 = "mhlo.slice"(%107) {limit_indices = dense<[16, 256]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<16x1024xf32>) -> tensor<16x256xf32> ++// CHECK: %109 = "mhlo.slice"(%107) {limit_indices = dense<[16, 512]> : tensor<2xi64>, start_indices = dense<[0, 256]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<16x1024xf32>) -> tensor<16x256xf32> ++// CHECK: %110 = "mhlo.slice"(%107) {limit_indices = dense<[16, 768]> : tensor<2xi64>, start_indices = dense<[0, 512]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<16x1024xf32>) -> tensor<16x256xf32> ++// CHECK: %111 = "mhlo.slice"(%107) {limit_indices = dense<[16, 1024]> : tensor<2xi64>, start_indices = dense<[0, 768]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<16x1024xf32>) -> tensor<16x256xf32> ++// CHECK: %112 = "mhlo.broadcast_in_dim"(%25) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %113 = mhlo.add %108, %112 : tensor<16x256xf32> ++// CHECK: %114 = "mhlo.broadcast_in_dim"(%29) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %115 = mhlo.add %113, %114 : tensor<16x256xf32> ++// CHECK: %116 = mhlo.logistic %115 : tensor<16x256xf32> ++// CHECK: %117 = "mhlo.broadcast_in_dim"(%27) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %118 = mhlo.add %110, %117 : tensor<16x256xf32> ++// CHECK: %119 = "mhlo.broadcast_in_dim"(%31) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %120 = mhlo.add %118, %119 : tensor<16x256xf32> ++// CHECK: %121 = mhlo.logistic %120 : tensor<16x256xf32> ++// CHECK: %122 = "mhlo.broadcast_in_dim"(%28) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %123 = mhlo.add %111, %122 : tensor<16x256xf32> ++// CHECK: %124 = "mhlo.broadcast_in_dim"(%32) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %125 = mhlo.add %123, %124 : tensor<16x256xf32> ++// CHECK: %126 = mhlo.tanh %125 : tensor<16x256xf32> ++// CHECK: %127 = mhlo.multiply %121, %0 : tensor<16x256xf32> ++// CHECK: %128 = mhlo.multiply %116, %126 : tensor<16x256xf32> ++// CHECK: %129 = mhlo.add %127, %128 : tensor<16x256xf32> ++// CHECK: %130 = "mhlo.broadcast_in_dim"(%26) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %131 = mhlo.add %109, %130 : tensor<16x256xf32> ++// CHECK: %132 = "mhlo.broadcast_in_dim"(%30) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %133 = mhlo.add %131, %132 : tensor<16x256xf32> ++// CHECK: %134 = mhlo.logistic %133 : tensor<16x256xf32> ++// CHECK: %135 = mhlo.tanh %129 : tensor<16x256xf32> ++// CHECK: %136 = mhlo.multiply %134, %135 : tensor<16x256xf32> ++// CHECK: %137 = mhlo.reshape %136 : (tensor<16x256xf32>) -> tensor<1x1x16x256xf32> ++// CHECK: %138 = "mhlo.slice"(%arg0) {limit_indices = dense<[1, 16, 512]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<2x16x512xf32>) -> tensor<1x16x512xf32> ++// CHECK: %139 = mhlo.reshape %138 : (tensor<1x16x512xf32>) -> tensor<16x512xf32> ++// CHECK: %140 = "mhlo.dot"(%139, %11) : (tensor<16x512xf32>, tensor<512x1024xf32>) -> tensor<16x1024xf32> ++// CHECK: %141 = "mhlo.dot"(%136, %12) : (tensor<16x256xf32>, tensor<256x1024xf32>) -> tensor<16x1024xf32> ++// CHECK: %142 = mhlo.add %140, %141 : tensor<16x1024xf32> ++// CHECK: %143 = "mhlo.slice"(%142) {limit_indices = dense<[16, 256]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<16x1024xf32>) -> tensor<16x256xf32> ++// CHECK: %144 = "mhlo.slice"(%142) {limit_indices = dense<[16, 512]> : tensor<2xi64>, start_indices = dense<[0, 256]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<16x1024xf32>) -> tensor<16x256xf32> ++// CHECK: %145 = "mhlo.slice"(%142) {limit_indices = dense<[16, 768]> : tensor<2xi64>, start_indices = dense<[0, 512]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<16x1024xf32>) -> tensor<16x256xf32> ++// CHECK: %146 = "mhlo.slice"(%142) {limit_indices = dense<[16, 1024]> : tensor<2xi64>, start_indices = dense<[0, 768]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<16x1024xf32>) -> tensor<16x256xf32> ++// CHECK: %147 = "mhlo.broadcast_in_dim"(%25) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %148 = mhlo.add %143, %147 : tensor<16x256xf32> ++// CHECK: %149 = "mhlo.broadcast_in_dim"(%29) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %150 = mhlo.add %148, %149 : tensor<16x256xf32> ++// CHECK: %151 = mhlo.logistic %150 : tensor<16x256xf32> ++// CHECK: %152 = "mhlo.broadcast_in_dim"(%27) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %153 = mhlo.add %145, %152 : tensor<16x256xf32> ++// CHECK: %154 = "mhlo.broadcast_in_dim"(%31) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %155 = mhlo.add %153, %154 : tensor<16x256xf32> ++// CHECK: %156 = mhlo.logistic %155 : tensor<16x256xf32> ++// CHECK: %157 = "mhlo.broadcast_in_dim"(%28) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %158 = mhlo.add %146, %157 : tensor<16x256xf32> ++// CHECK: %159 = "mhlo.broadcast_in_dim"(%32) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %160 = mhlo.add %158, %159 : tensor<16x256xf32> ++// CHECK: %161 = mhlo.tanh %160 : tensor<16x256xf32> ++// CHECK: %162 = mhlo.multiply %156, %129 : tensor<16x256xf32> ++// CHECK: %163 = mhlo.multiply %151, %161 : tensor<16x256xf32> ++// CHECK: %164 = mhlo.add %162, %163 : tensor<16x256xf32> ++// CHECK: %165 = "mhlo.broadcast_in_dim"(%26) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %166 = mhlo.add %144, %165 : tensor<16x256xf32> ++// CHECK: %167 = "mhlo.broadcast_in_dim"(%30) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<16x256xf32> ++// CHECK: %168 = mhlo.add %166, %167 : tensor<16x256xf32> ++// CHECK: %169 = mhlo.logistic %168 : tensor<16x256xf32> ++// CHECK: %170 = mhlo.tanh %164 : tensor<16x256xf32> ++// CHECK: %171 = mhlo.multiply %169, %170 : tensor<16x256xf32> ++// CHECK: %172 = mhlo.reshape %171 : (tensor<16x256xf32>) -> tensor<1x1x16x256xf32> ++// CHECK: %173 = "mhlo.concatenate"(%67, %102) {dimension = 0 : i64} : (tensor<1x1x16x256xf32>, tensor<1x1x16x256xf32>) -> tensor<2x1x16x256xf32> ++// CHECK: %174 = "mhlo.concatenate"(%172, %137) {dimension = 0 : i64} : (tensor<1x1x16x256xf32>, tensor<1x1x16x256xf32>) -> tensor<2x1x16x256xf32> ++// CHECK: %175 = "mhlo.concatenate"(%173, %174) {dimension = 1 : i64} : (tensor<2x1x16x256xf32>, tensor<2x1x16x256xf32>) -> tensor<2x2x16x256xf32> ++// CHECK: return %175 : tensor<2x2x16x256xf32> ++} +diff --git a/src/Conversion/ONNXToMhlo/Tensor/Expand.cpp b/src/Conversion/ONNXToMhlo/Tensor/Expand.cpp +index 0089fee1..8d5dec5b 100644 +--- a/src/Conversion/ONNXToMhlo/Tensor/Expand.cpp ++++ b/src/Conversion/ONNXToMhlo/Tensor/Expand.cpp +@@ -69,11 +69,8 @@ struct ONNXExpandOpLoweringToMhlo : public ConversionPattern { + RankedTensorType onesType = RankedTensorType::get(onesShape, elementType); + broadcastedOnes = rewriter.create( + loc, onesType, ones, shape, rewriter.getI64TensorAttr({})); +- } else if (ONNXConstantOp shapeOp = +- dyn_cast_or_null(shapeDefOp)) { ++ } else if (mlir::ElementsAttr constShape = getElementAttributeFromConstValue(shape)) { + llvm::SmallVector shapeValues; +- mlir::ElementsAttr constShape = +- shapeOp.getValueAttr().cast(); + for (mlir::IntegerAttr element : constShape.getValues()) + shapeValues.push_back(element.getInt()); + RankedTensorType broadcastedType = +@@ -84,7 +81,7 @@ struct ONNXExpandOpLoweringToMhlo : public ConversionPattern { + assert( + false && + "Shape argument of Expand is the output of an unexpected operation. " +- "Supported operations are: onnx.Constant and onnx.Shape"); ++ "Supported operations are: Constant and onnx.Shape"); + } + llvm::SmallVector newOperands = {input, broadcastedOnes}; + llvm::SmallVector broadcastedOperands = getBroadcastedOperands( diff --git a/src/Conversion/ONNXToMhlo/Tensor/GatherElements.cpp b/src/Conversion/ONNXToMhlo/Tensor/GatherElements.cpp new file mode 100644 -index 00000000..b7133e4f +index 00000000..09d7a6c2 --- /dev/null +++ b/src/Conversion/ONNXToMhlo/Tensor/GatherElements.cpp -@@ -0,0 +1,134 @@ +@@ -0,0 +1,139 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ @@ -132,7 +1911,8 @@ index 00000000..b7133e4f + +struct ONNXGatherElementsOpLoweringToMhlo : public ConversionPattern { + ONNXGatherElementsOpLoweringToMhlo(MLIRContext *ctx) -+ : ConversionPattern(mlir::ONNXGatherElementsOp::getOperationName(), 1, ctx) {} ++ : ConversionPattern( ++ mlir::ONNXGatherElementsOp::getOperationName(), 1, ctx) {} + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { @@ -153,7 +1933,7 @@ index 00000000..b7133e4f + int64_t axisLit = gatherOp.getAxis(); + + ShapedType inputType = data.getType().cast(); -+ int64_t rank = inputType.getRank(); // indices has the same rank ++ int64_t rank = inputType.getRank(); // indices has the same rank + ShapedType indicesType = indices.getType().cast(); + Type indexElemType = indicesType.getElementType(); + // Negative value means counting dimensions from the back. @@ -165,18 +1945,19 @@ index 00000000..b7133e4f + Value indicesShape = rewriter.create(loc, indices); + Value axisDimSize = + rewriter.create(loc, inputShape, axisLit); -+ axisDimSize = rewriter.create( -+ loc, indexElemType, axisDimSize); ++ axisDimSize = ++ rewriter.create(loc, indexElemType, axisDimSize); + axisDimSize = rewriter.create(loc, axisDimSize); -+ axisDimSize = rewriter.create( -+ loc, RankedTensorType::get(SmallVector{}, indexElemType), axisDimSize); ++ axisDimSize = rewriter.create(loc, ++ RankedTensorType::get(SmallVector{}, indexElemType), ++ axisDimSize); + Value broadcastedAxisDimSize = + rewriter.create(loc, indicesType, + axisDimSize, indicesShape, rewriter.getI64TensorAttr({})); + Value isNegative = rewriter.create( + loc, indices, zero, mhlo::ComparisonDirection::LT); -+ Value positiveIndices = -+ rewriter.create(loc, indicesType, indices, broadcastedAxisDimSize); ++ Value positiveIndices = rewriter.create( ++ loc, indicesType, indices, broadcastedAxisDimSize); + indices = rewriter.create( + loc, indicesType, isNegative, positiveIndices, indices); + @@ -184,15 +1965,20 @@ index 00000000..b7133e4f + Value toConcatIndexShape; + SmallVector toConcatIndexShapeValueVec; + for (size_t i = 0; i < rank; i++) { -+ toConcatIndexShapeValueVec.push_back(rewriter.create(loc, indicesShape, i)); ++ toConcatIndexShapeValueVec.push_back( ++ rewriter.create(loc, indicesShape, i)); + } -+ toConcatIndexShapeValueVec.push_back(rewriter.create(loc, 1)); -+ toConcatIndexShape = rewriter.create(loc, toConcatIndexShapeValueVec); ++ toConcatIndexShapeValueVec.push_back( ++ rewriter.create(loc, 1)); ++ toConcatIndexShape = rewriter.create( ++ loc, toConcatIndexShapeValueVec); + + ArrayRef indicesShapeVec = indicesType.getShape(); -+ SmallVector toConcatIndexShapeVec(indicesShapeVec.begin(), indicesShapeVec.end()); ++ SmallVector toConcatIndexShapeVec( ++ indicesShapeVec.begin(), indicesShapeVec.end()); + toConcatIndexShapeVec.push_back(1); -+ RankedTensorType toConcatIndexType = RankedTensorType::get(toConcatIndexShapeVec, indexElemType); ++ RankedTensorType toConcatIndexType = ++ RankedTensorType::get(toConcatIndexShapeVec, indexElemType); + + SmallVector toConcat; + for (int64_t i = 0; i < inputType.getRank(); ++i) { @@ -200,9 +1986,9 @@ index 00000000..b7133e4f + toConcat.push_back(rewriter.create( + loc, toConcatIndexType, indices, toConcatIndexShape)); + } else { -+ toConcat.push_back(rewriter.create( -+ loc, toConcatIndexType, toConcatIndexShape, -+ rewriter.getI64IntegerAttr(i))); ++ toConcat.push_back( ++ rewriter.create(loc, toConcatIndexType, ++ toConcatIndexShape, rewriter.getI64IntegerAttr(i))); + } + } + auto gatherIndicies = rewriter.create( @@ -215,17 +2001,15 @@ index 00000000..b7133e4f + collapsedDims.push_back(i); + startIndexMap.push_back(i); + } -+ auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get( -+ rewriter.getContext(), -+ /*offsetDims=*/{}, -+ /*collapsedSliceDims=*/collapsedDims, -+ /*startIndexMap=*/startIndexMap, -+ /*indexVecDim=*/rank); ++ auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get(rewriter.getContext(), ++ /*offsetDims=*/{}, ++ /*collapsedSliceDims=*/collapsedDims, ++ /*startIndexMap=*/startIndexMap, ++ /*indexVecDim=*/rank); + SmallVector sliceSizes(inputType.getRank(), 1); + -+ Value gatherValue = rewriter.create(loc, -+ outputType, data, gatherIndicies, dimsAttr, -+ rewriter.getI64TensorAttr(sliceSizes)); ++ Value gatherValue = rewriter.create(loc, outputType, data, ++ gatherIndicies, dimsAttr, rewriter.getI64TensorAttr(sliceSizes)); + rewriter.replaceOp(op, gatherValue); + return success(); + } @@ -239,9 +2023,142 @@ index 00000000..b7133e4f +} + +} // namespace onnx_mlir +diff --git a/src/Conversion/ONNXToMhlo/Tensor/OneHot.cpp b/src/Conversion/ONNXToMhlo/Tensor/OneHot.cpp +new file mode 100644 +index 00000000..f4537916 +--- /dev/null ++++ b/src/Conversion/ONNXToMhlo/Tensor/OneHot.cpp +@@ -0,0 +1,127 @@ ++/* ++ * SPDX-License-Identifier: Apache-2.0 ++ */ ++ ++//===---------------- OneHot.cpp - Lowering OneHot Op -------------------===// ++// ++// Copyright 2023 ++// ++// ============================================================================= ++// ++// This file lowers the ONNX OneHot Operator to Mhlo dialect. ++// ++//===----------------------------------------------------------------------===// ++ ++#include "src/Conversion/ONNXToMhlo/DialectBuilder.hpp" ++#include "src/Conversion/ONNXToMhlo/ONNXToMhloCommon.hpp" ++#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp" ++ ++#include ++ ++using namespace mlir; ++ ++namespace onnx_mlir { ++ ++struct ONNXOneHotOpLoweringToMhlo : public OpConversionPattern { ++ ONNXOneHotOpLoweringToMhlo(MLIRContext *ctx) : OpConversionPattern(ctx) {} ++ ++ LogicalResult matchAndRewrite(ONNXOneHotOp onehotOp, ++ ONNXOneHotOpAdaptor adaptor, ++ ConversionPatternRewriter &rewriter) const final { ++ Operation *op = onehotOp.getOperation(); ++ Location loc = ONNXLoc(op); ++ ValueRange operands = adaptor.getOperands(); ++ Value indices = adaptor.getIndices(); ++ Value depthValue = adaptor.getDepth(); ++ Value values = adaptor.getValues(); ++ Type outputType = *op->result_type_begin(); ++ ++ IndexExprBuilderForMhlo createIE(rewriter, loc); ++ ONNXOneHotOpShapeHelper shapeHelper(op, operands, &createIE); ++ shapeHelper.computeShapeAndAssertOnFailure(); ++ int64_t axis = shapeHelper.axis; ++ ++ RankedTensorType indicesType = ++ indices.getType().dyn_cast(); ++ if (!indicesType || !indicesType.hasStaticShape()) ++ return failure(); ++ ArrayRef indicesShape = indicesType.getShape(); ++ Type indicesElementType = indicesType.getElementType(); ++ ++ DenseIntElementsAttr depthAttr; ++ if (!matchPattern(depthValue, m_Constant(&depthAttr))) { ++ return failure(); ++ } ++ ++ int64_t depth = depthAttr.getValues()[0].getSExtValue(); ++ ++ llvm::SmallVector broadcastDims(indicesShape.size()); ++ std::iota(broadcastDims.begin(), broadcastDims.begin() + axis, 0); ++ std::iota(broadcastDims.begin() + axis, broadcastDims.end(), axis + 1); ++ ++ llvm::SmallVector outputDims = llvm::to_vector<4>(indicesShape); ++ outputDims.insert(outputDims.begin() + axis, depth); ++ ++ RankedTensorType indexType = ++ RankedTensorType::get(llvm::ArrayRef(outputDims), indicesElementType); ++ ++ Value iota = rewriter.create( ++ loc, indexType, IntegerAttr::get(rewriter.getIntegerType(64), axis)); ++ Value broadcastIndices = rewriter.create( ++ loc, indexType, indices, GetI64ElementsAttr(broadcastDims, &rewriter)); ++ Value zero = rewriter.create(loc, ++ DenseIntElementsAttr::get(RankedTensorType::get({}, indicesElementType), ++ ArrayRef{0})); ++ Value broadcastZero = rewriter.create( ++ loc, indexType, zero, rewriter.getI64TensorAttr({})); ++ Value broadcastDepth; ++ int64_t depthRank = depthValue.getType().cast().getRank(); ++ if (depthRank == 1) ++ broadcastDepth = rewriter.create( ++ loc, indexType, depthValue, rewriter.getI64TensorAttr({0})); ++ else ++ broadcastDepth = rewriter.create( ++ loc, indexType, depthValue, rewriter.getI64TensorAttr({})); ++ Value compareGeZero = rewriter.create( ++ loc, broadcastIndices, broadcastZero, mhlo::ComparisonDirection::GE); ++ Value positiveIndices = ++ rewriter.create(loc, broadcastIndices, broadcastDepth); ++ Value normalizedIndices = rewriter.create( ++ loc, indexType, compareGeZero, broadcastIndices, positiveIndices); ++ Value compare = rewriter.create( ++ loc, normalizedIndices, iota, mhlo::ComparisonDirection::EQ); ++ Type indexElementType = rewriter.getI64Type(); ++ Type valueType = values.getType().cast().getElementType(); ++ Value offValue = rewriter.create(loc, ++ RankedTensorType::get({1}, valueType), values, ++ DenseIntElementsAttr::get( ++ RankedTensorType::get({1}, indexElementType), ArrayRef{0}), ++ DenseIntElementsAttr::get( ++ RankedTensorType::get({1}, indexElementType), ArrayRef{1}), ++ DenseIntElementsAttr::get(RankedTensorType::get({1}, indexElementType), ++ ArrayRef{1})); ++ Value onValue = rewriter.create(loc, ++ RankedTensorType::get({1}, valueType), values, ++ DenseIntElementsAttr::get( ++ RankedTensorType::get({1}, indexElementType), ArrayRef{1}), ++ DenseIntElementsAttr::get( ++ RankedTensorType::get({1}, indexElementType), ArrayRef{2}), ++ DenseIntElementsAttr::get(RankedTensorType::get({1}, indexElementType), ++ ArrayRef{1})); ++ Value offValueBroadcast = rewriter.create( ++ loc, outputType, offValue, rewriter.getI64TensorAttr({0})); ++ Value onValueBroadcast = rewriter.create( ++ loc, outputType, onValue, rewriter.getI64TensorAttr({0})); ++ Value result = rewriter.create( ++ loc, outputType, compare, onValueBroadcast, offValueBroadcast); ++ rewriter.replaceOp(op, {result}); ++ return success(); ++ } ++}; ++ ++void populateLoweringONNXOneHotOpToMhloPattern( ++ RewritePatternSet &patterns, MLIRContext *ctx) { ++ patterns.insert(ctx); ++} ++ ++} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToMhlo/Tensor/Pad.cpp b/src/Conversion/ONNXToMhlo/Tensor/Pad.cpp new file mode 100644 -index 00000000..1482e9ee +index 00000000..40006578 --- /dev/null +++ b/src/Conversion/ONNXToMhlo/Tensor/Pad.cpp @@ -0,0 +1,103 @@ @@ -309,7 +2226,7 @@ index 00000000..1482e9ee + SmallVector edgePaddingLowVec(rank, 0); + SmallVector edgePaddingHighVec(rank, 0); + SmallVector interiorPaddingVec(rank, 0); -+ if (auto valueAttribute = getElementAttributeFromMhloValue(pads)) { ++ if (auto valueAttribute = getElementAttributeFromConstValue(pads)) { + // If `pads` are constants, read them." + int64_t idx = 0; + for (IntegerAttr value : valueAttribute.getValues()) { @@ -350,7 +2267,7 @@ index 00000000..1482e9ee +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToMhlo/Tensor/ScatterND.cpp b/src/Conversion/ONNXToMhlo/Tensor/ScatterND.cpp new file mode 100644 -index 00000000..50045514 +index 00000000..acb95fe9 --- /dev/null +++ b/src/Conversion/ONNXToMhlo/Tensor/ScatterND.cpp @@ -0,0 +1,93 @@ @@ -448,13 +2365,13 @@ index 00000000..50045514 + +} // namespace onnx_mlir diff --git a/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp b/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp -index beb50392..2fd7ecbf 100644 +index beb50392..b821ec34 100644 --- a/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp +++ b/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp @@ -544,6 +544,16 @@ struct ONNXPadOpShapeHelper : public ONNXOpShapeHelper { llvm::SmallVector pads; }; - + +struct ONNXPadV13OpShapeHelper : public ONNXOpShapeHelper { + ONNXPadV13OpShapeHelper(mlir::Operation *op, mlir::ValueRange operands, + IndexExprBuilder *ieBuilder = nullptr, IndexExprScope *scope = nullptr) @@ -475,7 +2392,7 @@ index b00edc4a..f33bc18a 100644 @@ -67,6 +67,49 @@ LogicalResult ONNXPadOpShapeHelper::computeShape() { return success(); } - + +LogicalResult ONNXPadV13OpShapeHelper::computeShape() { + ONNXPadOpAdaptor operandAdaptor(operands); + Value dataOperand = operandAdaptor.getData(); @@ -520,7 +2437,7 @@ index b00edc4a..f33bc18a 100644 +} + } // namespace onnx_mlir - + //===----------------------------------------------------------------------===// @@ -108,3 +151,15 @@ LogicalResult ONNXPadOp::inferShapes( ONNXPadOpShapeHelper shapeHelper(getOperation(), {}); @@ -539,10 +2456,10 @@ index b00edc4a..f33bc18a 100644 + return shapeHelper.computeShapeAndUpdateType(elementType); +} diff --git a/src/Dialect/ONNX/ONNXUnsupportedOps.hpp b/src/Dialect/ONNX/ONNXUnsupportedOps.hpp -index 758635e4..2078716a 100644 +index 758635e4..f5c07c76 100644 --- a/src/Dialect/ONNX/ONNXUnsupportedOps.hpp +++ b/src/Dialect/ONNX/ONNXUnsupportedOps.hpp -@@ -55,7 +55,6 @@ UNSUPPORTED_OPS(ONNXMomentumOp) +@@ -56,7 +56,6 @@ UNSUPPORTED_OPS(ONNXMomentumOp) UNSUPPORTED_OPS(ONNXMultinomialOp) UNSUPPORTED_OPS(ONNXNegativeLogLikelihoodLossOp) UNSUPPORTED_OPS(ONNXNormalizerOp) @@ -550,9 +2467,82 @@ index 758635e4..2078716a 100644 UNSUPPORTED_OPS(ONNXPadV11Op) UNSUPPORTED_OPS(ONNXPadV2Op) UNSUPPORTED_OPS(ONNXRandomUniformLikeOp) +diff --git a/test/mlir/conversion/onnx_to_mhlo/Math/Elementwise.mlir b/test/mlir/conversion/onnx_to_mhlo/Math/Elementwise.mlir +index 834471e3..ff14f8d8 100644 +--- a/test/mlir/conversion/onnx_to_mhlo/Math/Elementwise.mlir ++++ b/test/mlir/conversion/onnx_to_mhlo/Math/Elementwise.mlir +@@ -256,6 +256,15 @@ func.func @test_max(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> ten + // CHECK-NEXT: return [[VAR_0_]] : tensor<10x10xf32> + } + ++func.func @test_min(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<10x10xf32> { ++ %0 = "onnx.Min"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> ++ "func.return"(%0) : (tensor<10x10xf32>) -> () ++// CHECK-LABEL: func @test_min ++// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>, [[PARAM_1_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { ++// CHECK-NEXT: [[VAR_0_:%.+]] = mhlo.minimum [[PARAM_0_]], [[PARAM_1_]] : tensor<10x10xf32> ++// CHECK-NEXT: return [[VAR_0_]] : tensor<10x10xf32> ++} ++ + func.func @test_leakyrelu_dynamic(%arg0 : tensor) -> tensor { + %0 = "onnx.LeakyRelu"(%arg0) {alpha=0.5:f32} : (tensor) -> tensor + "func.return"(%0) : (tensor) -> () +@@ -275,6 +284,16 @@ func.func @test_leakyrelu_dynamic(%arg0 : tensor) -> tensor + + // ----- + ++func.func @test_prelu_dynamic(%arg0 : tensor, %arg1: tensor<10x1x1xf32>) -> tensor { ++ %0 = "onnx.PRelu"(%arg0, %arg1) : (tensor, tensor<10x1x1xf32>) -> tensor ++ "func.return"(%0) : (tensor) -> () ++// CHECK-LABEL: func.func @test_prelu_dynamic ++// CHECK-SAME: (%arg0: tensor, %arg1: tensor<10x1x1xf32>) -> tensor { ++// CHECK: [[VAR_0_:%.+]] = mhlo.multiply [[INP:%.+]], [[SLOPE:%.+]] : tensor ++// CHECK: [[VAR_1_:%.+]] = mhlo.compare GT, [[INP]], [[ZEROS:%.+]], NOTYPE : (tensor, tensor) -> tensor ++// CHECK: [[VAR_2_:%.+]] = mhlo.select [[VAR_1_]], [[INP]], [[VAR_0_]] : tensor, tensor ++} ++ + func.func @test_neg(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> { + %0 = "onnx.Neg"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32> + "func.return"(%0) : (tensor<10x10xf32>) -> () +@@ -290,3 +309,10 @@ func.func @test_sin(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> { + // CHECK-LABEL: func @test_sin + // CHECK: %0 = mhlo.sine %arg0 : tensor<10x10xf32> + } ++ ++func.func @test_where(%arg0 : tensor<16x24x36xi1>, %arg1 : tensor<16x24x36xi64>, %arg2 : tensor<16x24x36xi64>) -> tensor<16x24x36xi64> { ++ %0 = "onnx.Where"(%arg0, %arg1, %arg2) : (tensor<16x24x36xi1>, tensor<16x24x36xi64>, tensor<16x24x36xi64>) -> tensor<16x24x36xi64> ++ "func.return"(%0) : (tensor<16x24x36xi64>) -> () ++// CHECK-LABEL: func.func @test_where ++// CHECK: %0 = mhlo.select %arg0, %arg1, %arg2 : tensor<16x24x36xi1>, tensor<16x24x36xi64> ++} +diff --git a/test/mlir/conversion/onnx_to_mhlo/Tensor/OneHot.mlir b/test/mlir/conversion/onnx_to_mhlo/Tensor/OneHot.mlir +new file mode 100644 +index 00000000..a4648439 +--- /dev/null ++++ b/test/mlir/conversion/onnx_to_mhlo/Tensor/OneHot.mlir +@@ -0,0 +1,19 @@ ++// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-mhlo %s --canonicalize -split-input-file | FileCheck %s ++ ++func.func @test_onehot(%arg0 : tensor<2x3x4xi64>) -> tensor<*xi64> { ++ %0 = onnx.Constant dense<64> : tensor<1xi64> ++ %1 = onnx.Constant dense<[0, 1]> : tensor<2xi64> ++ %2 = "onnx.OneHot"(%arg0, %0, %1) {axis = -1 : si64} : (tensor<2x3x4xi64>, tensor<1xi64>, tensor<2xi64>) -> tensor<*xi64> ++ "func.return"(%2) : (tensor<*xi64>) -> () ++// CHECK-LABEL: func.func @test_onehot ++// CHECK-SAME: (%[[ARG0:.+]]: tensor<2x3x4xi64>) -> tensor<2x3x4x64xi64> { ++// CHECK: %[[IOTA:.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<64xi64> ++// CHECK: %[[BCAST_IOTA:.+]] = "mhlo.broadcast_in_dim"(%[[IOTA]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<64xi64>) -> tensor<2x3x4x64xi64> ++// CHECK: %[[BCAST_ARG0:.+]] = "mhlo.broadcast_in_dim"(%[[ARG0]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x3x4xi64>) -> tensor<2x3x4x64xi64> ++// CHECK: %[[GE_ZERO:.+]] = mhlo.compare GE, %[[BCAST_ARG0]], %[[BCAST_ZERO:.+]], NOTYPE : (tensor<2x3x4x64xi64>, tensor<2x3x4x64xi64>) -> tensor<2x3x4x64xi1> ++// CHECK: %[[POS_ARG:.+]] = mhlo.add %[[BCAST_ARG0]], %[[BCAST_DEPTH:.+]] : tensor<2x3x4x64xi64> ++// CHECK: %[[NORM_ARG:.+]] = mhlo.select %[[GE_ZERO]], %[[BCAST_ARG0]], %[[POS_ARG]] : tensor<2x3x4x64xi1>, tensor<2x3x4x64xi64> ++// CHECK: %[[COMPARE:.+]] = mhlo.compare EQ, %[[NORM_ARG]], %[[BCAST_IOTA]], NOTYPE : (tensor<2x3x4x64xi64>, tensor<2x3x4x64xi64>) -> tensor<2x3x4x64xi1> ++// CHECK: %[[RESULT:.+]] = mhlo.select %[[COMPARE]], %[[ON_VALUE:.+]], %[[OFF_VALUE:.+]] : tensor<2x3x4x64xi1>, tensor<2x3x4x64xi64> ++// CHECK: return %[[RESULT]] : tensor<2x3x4x64xi64> ++} diff --git a/test/mlir/conversion/onnx_to_mhlo/Tensor/ScatterND.mlir b/test/mlir/conversion/onnx_to_mhlo/Tensor/ScatterND.mlir new file mode 100644 -index 00000000..cb4a5c8d +index 00000000..bffb2b21 --- /dev/null +++ b/test/mlir/conversion/onnx_to_mhlo/Tensor/ScatterND.mlir @@ -0,0 +1,23 @@ diff --git a/frontends/onnx-frontend/third_party/patches/OnnxMlirPRelu.patch b/frontends/onnx-frontend/third_party/patches/OnnxMlirPRelu.patch deleted file mode 100644 index c0fad1b43..000000000 --- a/frontends/onnx-frontend/third_party/patches/OnnxMlirPRelu.patch +++ /dev/null @@ -1,73 +0,0 @@ -diff --git a/src/Conversion/ONNXToMhlo/Math/Elementwise.cpp b/src/Conversion/ONNXToMhlo/Math/Elementwise.cpp -index 26c392b8..3eb8c45d 100644 ---- a/src/Conversion/ONNXToMhlo/Math/Elementwise.cpp -+++ b/src/Conversion/ONNXToMhlo/Math/Elementwise.cpp -@@ -293,6 +293,39 @@ struct ONNXElementwiseBinaryOpLoweringToMhlo : public ConversionPattern { - } - }; - -+// ONNXPReluOp(x) = alpha * x if x < 0 else x. -+template <> -+struct ONNXElementwiseBinaryOpLoweringToMhlo -+ : public ConversionPattern { -+ ONNXElementwiseBinaryOpLoweringToMhlo(MLIRContext *ctx) -+ : ConversionPattern(ONNXPReluOp::getOperationName(), 1, ctx) {} -+ LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, -+ ConversionPatternRewriter &rewriter) const final { -+ Location loc = op->getLoc(); -+ // Prior code here used the "analysis" version that did not generate code. -+ // Since code is actually not needed here at this time, one could use -+ // IndexExprBuilderForAnalysis createIE(loc) instead. -+ IndexExprBuilderForMhlo createShapeIE(rewriter, loc); -+ ONNXBroadcastOpShapeHelper shapeHelper(op, operands, &createShapeIE); -+ shapeHelper.computeShapeAndAssertOnFailure(); -+ -+ int64_t outputRank = shapeHelper.outputRank; -+ llvm::SmallVector broadcastedOperands = -+ getBroadcastedOperands(op, rewriter, loc, outputRank); -+ Value inp = broadcastedOperands[0]; -+ Value broadcastedSlope = broadcastedOperands[1]; -+ Type resultType = *op->result_type_begin(); -+ Value PReluActivationVal = rewriter.create(loc, inp, broadcastedSlope); -+ Value broadcastedZero = getShapedZero(loc, rewriter, inp); -+ Value compareGtZero = rewriter.create( -+ loc, inp, broadcastedZero, mhlo::ComparisonDirection::GT); -+ Value resultOp = rewriter.create( -+ loc, resultType, compareGtZero, inp, PReluActivationVal); -+ rewriter.replaceOp(op, resultOp); -+ return success(); -+ } -+}; -+ - // Element-wise variadic ops lowering to Mhlo dialect. - //===----------------------------------------------------------------------===// - template -@@ -343,6 +376,7 @@ void populateLoweringONNXElementwiseOpToMhloPattern( - ONNXElementwiseCompareBinaryOpLoweringToMhlo, - ONNXElementwiseCompareBinaryOpLoweringToMhlo, - ONNXElementwiseBinaryOpLoweringToMhlo, -+ ONNXElementwiseBinaryOpLoweringToMhlo, - ONNXElementwiseVariadicOpLoweringToMhlo, - ONNXElementwiseVariadicOpLoweringToMhlo, - ONNXElementwiseVariadicOpLoweringToMhlo, -diff --git a/test/mlir/conversion/onnx_to_mhlo/Math/Elementwise.mlir b/test/mlir/conversion/onnx_to_mhlo/Math/Elementwise.mlir -index 834471e3..261aa444 100644 ---- a/test/mlir/conversion/onnx_to_mhlo/Math/Elementwise.mlir -+++ b/test/mlir/conversion/onnx_to_mhlo/Math/Elementwise.mlir -@@ -275,6 +275,16 @@ func.func @test_leakyrelu_dynamic(%arg0 : tensor) -> tensor - - // ----- - -+func.func @test_prelu_dynamic(%arg0 : tensor, %arg1: tensor<10x1x1xf32>) -> tensor { -+ %0 = "onnx.PRelu"(%arg0, %arg1) : (tensor, tensor<10x1x1xf32>) -> tensor -+ "func.return"(%0) : (tensor) -> () -+// CHECK-LABEL: func.func @test_prelu_dynamic -+// CHECK-SAME: (%arg0: tensor, %arg1: tensor<10x1x1xf32>) -> tensor { -+// CHECK: [[VAR_0_:%.+]] = mhlo.multiply [[INP:%.+]], [[SLOPE:%.+]] : tensor -+// CHECK: [[VAR_1_:%.+]] = mhlo.compare GT, [[INP]], [[ZEROS:%.+]], NOTYPE : (tensor, tensor) -> tensor -+// CHECK: [[VAR_2_:%.+]] = mhlo.select [[VAR_1_]], [[INP]], [[VAR_0_]] : tensor, tensor -+} -+ - func.func @test_neg(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> { - %0 = "onnx.Neg"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32> - "func.return"(%0) : (tensor<10x10xf32>) -> () diff --git a/frontends/onnx-frontend/third_party/patches/OnnxMlirRegisterLibrary.patch b/frontends/onnx-frontend/third_party/patches/OnnxMlirRegisterLibrary.patch index dce5688a4..f315684fc 100644 --- a/frontends/onnx-frontend/third_party/patches/OnnxMlirRegisterLibrary.patch +++ b/frontends/onnx-frontend/third_party/patches/OnnxMlirRegisterLibrary.patch @@ -2,7 +2,7 @@ diff --git a/src/Tools/onnx-mlir-opt/CMakeLists.txt b/src/Tools/onnx-mlir-opt/CM index a90a670a..0a80c88b 100644 --- a/src/Tools/onnx-mlir-opt/CMakeLists.txt +++ b/src/Tools/onnx-mlir-opt/CMakeLists.txt -@@ -20,3 +20,17 @@ add_onnx_mlir_executable(onnx-mlir-opt +@@ -20,3 +20,16 @@ add_onnx_mlir_executable(onnx-mlir-opt MLIROptLib MLIRSCFToOpenMP ) @@ -19,4 +19,3 @@ index a90a670a..0a80c88b 100644 + MLIRLinalgTransforms + MLIRMemRefTransforms +) -+ diff --git a/frontends/tf-frontend/tf_mlir_ext/numerical/numerical_test.py b/frontends/tf-frontend/tf_mlir_ext/numerical/numerical_test.py index c0d852716..5c89a769a 100755 --- a/frontends/tf-frontend/tf_mlir_ext/numerical/numerical_test.py +++ b/frontends/tf-frontend/tf_mlir_ext/numerical/numerical_test.py @@ -51,6 +51,7 @@ def get_config(config: str): "log_softmax_case0": 6, "erf_case0": 6, "gelu_erf_case0": 6, + "gelu_erf_case1": 6, "gelu_tanh_case0": 6, "gelu_tanh_case1": 2, "gelu_tanh_case2": 6, @@ -65,10 +66,12 @@ def get_config(config: str): "layer_norm_V4": 2, "layer_norm_V4_swap_squarediff": 2, "layer_norm_with_cast": 2, + "layer_norm_with_cast_v2": 2, "layer_norm_with_cast_disable_minimize_broadcast": 2, "l2_norm_V1": 6, "l2_norm_V1_swap_mul": 6, "l2_norm_V2": 3, + "l2_norm_V2_swap_mul": 3, "l2_norm_V3": 6, "onehot_case0": 6, }, diff --git a/frontends/tf-frontend/tf_mlir_ext/numerical/rewrite_to_custom_call.mlir b/frontends/tf-frontend/tf_mlir_ext/numerical/rewrite_to_custom_call.mlir index 117758e6e..c43d27187 100755 --- a/frontends/tf-frontend/tf_mlir_ext/numerical/rewrite_to_custom_call.mlir +++ b/frontends/tf-frontend/tf_mlir_ext/numerical/rewrite_to_custom_call.mlir @@ -45,6 +45,22 @@ func.func @gelu_erf_case0(%arg0: tensor<100x?x?xf32>) -> tensor<100x?x?xf32> { // CHECK-SAME: @byteir.gelu // CHECK-SAME: byteir_attrs = {approximate = "erf"} +func.func @gelu_erf_case1(%arg0: tensor<100x?x?xf32>) -> tensor<100x?x?xf32> { + %cst = "tf.Const"() {value = dense<0.707106769> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {value = dense<5.000000e-01> : tensor} : () -> tensor + %1 = "tf.Mul"(%arg0, %cst) : (tensor<100x?x?xf32>, tensor) -> tensor<100x?x?xf32> + %2 = "tf.Erf"(%1) : (tensor<100x?x?xf32>) -> tensor<100x?x?xf32> + %3 = "tf.AddV2"(%2, %cst_0) : (tensor<100x?x?xf32>, tensor) -> tensor<100x?x?xf32> + %4 = "tf.Mul"(%3, %cst_1) : (tensor<100x?x?xf32>, tensor) -> tensor<100x?x?xf32> + %5 = "tf.Mul"(%arg0, %4) : (tensor<100x?x?xf32>, tensor<100x?x?xf32>) -> tensor<100x?x?xf32> + func.return %5 : tensor<100x?x?xf32> +} +// CHECK-LABEL: func.func @gelu_erf_case1(%arg0: tensor<100x?x?xf32>) -> tensor<100x?x?xf32> { +// CHECK: mhlo.custom_call +// CHECK-SAME: @byteir.gelu +// CHECK-SAME: byteir_attrs = {approximate = "erf"} + func.func @gelu_tanh_case0(%arg0: tensor<100x?x?xf32>) -> tensor<100x?x?xf32> { %cst = "tf.Const"() {value = dense<4.471500e-02> : tensor} : () -> tensor %cst_0 = "tf.Const"() {value = dense<3.000000e+00> : tensor} : () -> tensor @@ -388,6 +404,33 @@ func.func @layer_norm_with_cast(%79: tensor<150x3xf16>) -> tensor<150x3xf16> { // CHECK-SAME: @byteir.layer_norm // CHECK-SAME: byteir_attrs = {axis = [1], epsilon = 1.0132789611816406E-6 : f64} +func.func @layer_norm_with_cast_v2(%79: tensor<150x3xf16>) -> tensor<150x3xf16> { + %cst_61 = "tf.Const"() {value = dense<[0.0401659757, -0.11370486, 0.432680517]> : tensor<3xf16>} : () -> tensor<3xf16> + %cst_62 = "tf.Const"() {value = dense<[0.445568085, 0.45303449, 3.227140e-01]> : tensor<3xf16>} : () -> tensor<3xf16> + %cst_157 = "tf.Const"() {value = dense<1.013280e-06> : tensor} : () -> tensor + %cst_158 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %80 = "tf.Cast"(%79) {Truncate = false, device = ""} : (tensor<150x3xf16>) -> tensor<150x3xf32> + %81 = "tf.Mean"(%80, %cst_158) {device = "", keep_dims = true} : (tensor<150x3xf32>, tensor) -> tensor<150x1xf32> + %82 = "tf.Cast"(%81) {Truncate = false, device = ""} : (tensor<150x1xf32>) -> tensor<150x1xf16> + %83 = "tf.SquaredDifference"(%80, %81) {device = ""} : (tensor<150x3xf32>, tensor<150x1xf32>) -> tensor<150x3xf32> + %84 = "tf.Mean"(%83, %cst_158) {device = "", keep_dims = true} : (tensor<150x3xf32>, tensor) -> tensor<150x1xf32> + %85 = "tf.Cast"(%84) {Truncate = false, device = ""} : (tensor<150x1xf32>) -> tensor<150x1xf16> + %86 = "tf.AddV2"(%85, %cst_157) {device = ""} : (tensor<150x1xf16>, tensor) -> tensor<150x1xf16> + %87 = "tf.Rsqrt"(%86) {device = ""} : (tensor<150x1xf16>) -> tensor<150x1xf16> + %88 = "tf.Mul"(%87, %cst_61) {_grappler_ArithmeticOptimizer_MinimizeBroadcasts = true, device = ""} : (tensor<150x1xf16>, tensor<3xf16>) -> tensor<150x3xf16> + %89 = "tf.Mul"(%79, %88) {_grappler_ArithmeticOptimizer_MinimizeBroadcasts = true, device = ""} : (tensor<150x3xf16>, tensor<150x3xf16>) -> tensor<150x3xf16> + %90 = "tf.Mul"(%88, %82) {_grappler_ArithmeticOptimizer_MinimizeBroadcasts = true, device = ""} : (tensor<150x3xf16>, tensor<150x1xf16>) -> tensor<150x3xf16> + %91 = "tf.Sub"(%cst_62, %90) {device = ""} : (tensor<3xf16>, tensor<150x3xf16>) -> tensor<150x3xf16> + %92 = "tf.AddV2"(%89, %91) {device = ""} : (tensor<150x3xf16>, tensor<150x3xf16>) -> tensor<150x3xf16> + return %92 : tensor<150x3xf16> +} +// CHECK-LABEL: func.func @layer_norm_with_cast_v2(%arg0: tensor<150x3xf16>) -> tensor<150x3xf16> { +// CHECK-NEXT: %cst = "tf.Const"() {value = dense<[4.016110e-02, -1.137080e-01, 4.326170e-01]> : tensor<3xf16>} : () -> tensor<3xf16> +// CHECK-NEXT: %cst_0 = "tf.Const"() {value = dense<[4.455570e-01, 4.531250e-01, 3.227540e-01]> : tensor<3xf16>} : () -> tensor<3xf16> +// CHECK: mhlo.custom_call +// CHECK-SAME: @byteir.layer_norm +// CHECK-SAME: byteir_attrs = {axis = [1], epsilon = 1.0132789611816406E-6 : f64} + func.func @layer_norm_with_cast_disable_minimize_broadcast(%46: tensor<1024x4xf16>) -> tensor<1024x4xf16> { %cst_103 = "tf.Const"() {value = dense<[0.0401659757, -0.11370486, 0.432680517, 0.4000000]> : tensor<4xf16>} : () -> tensor<4xf16> %cst_104 = "tf.Const"() {value = dense<[0.445568085, 0.45303449, 3.227140e-01, 0.4000000]> : tensor<4xf16>} : () -> tensor<4xf16> @@ -454,6 +497,20 @@ func.func @l2_norm_V2(%1871: tensor<1x128xf16>) -> tensor<1x128xf16> { // CHECK-SAME: @byteir.l2_norm // CHECK-SAME: byteir_attrs = {axis = [1], epsilon = 0.000000e+00 : f64} +func.func @l2_norm_V2_swap_mul(%1871: tensor<1x128xf16>) -> tensor<1x128xf16> { + %cst_5 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %1872 = "tf.Square"(%1871) {device = ""} : (tensor<1x128xf16>) -> tensor<1x128xf16> + %1873 = "tf.Sum"(%1872, %cst_5) {device = "", keep_dims = true} : (tensor<1x128xf16>, tensor) -> tensor<1x1xf16> + %1874 = "tf.Relu"(%1873) : (tensor<1x1xf16>) -> tensor<1x1xf16> + %1875 = "tf.Rsqrt"(%1874) {device = ""} : (tensor<1x1xf16>) -> tensor<1x1xf16> + %1876 = "tf.Mul"(%1871, %1875) {device = ""} : (tensor<1x128xf16>, tensor<1x1xf16>) -> tensor<1x128xf16> + return %1876 : tensor<1x128xf16> +} +// CHECK-LABEL: @l2_norm_V2_swap_mul +// CHECK: mhlo.custom_call +// CHECK-SAME: @byteir.l2_norm +// CHECK-SAME: byteir_attrs = {axis = [1], epsilon = 0.000000e+00 : f64} + func.func @l2_norm_V3(%15: tensor<1x100x512xf32>) -> tensor<1x100x512xf32> { %cst_96 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> %16 = "tf.Square"(%15) {device = ""} : (tensor<1x100x512xf32>) -> tensor<1x100x512xf32> diff --git a/frontends/tf-frontend/tf_mlir_ext/tests/fallback_to_custom_call.mlir b/frontends/tf-frontend/tf_mlir_ext/tests/fallback_to_custom_call.mlir index b79dc03a9..484c74956 100644 --- a/frontends/tf-frontend/tf_mlir_ext/tests/fallback_to_custom_call.mlir +++ b/frontends/tf-frontend/tf_mlir_ext/tests/fallback_to_custom_call.mlir @@ -6,6 +6,12 @@ func.func @test_tf_const_string() -> tensor { } // CHECK: %0 = "ace.constant"() {value = dense<"fork_active_pay"> : tensor} : () -> tensor +func.func @test_tf_squeeze_string(%arg0: tensor<512x1x!tf_type.string>) -> tensor<512x!tf_type.string> { + %0 = "tf.Squeeze"(%arg0) {squeeze_dims = [-1]} : (tensor<512x1x!tf_type.string>) -> tensor<512x!tf_type.string> + func.return %0 : tensor<512x!tf_type.string> +} +// CHECK: ace.reshape + func.func @test_to_mhlo_custom_call(%arg0 : tensor) -> tensor { %0 = "tf.Where"(%arg0) {_XlaCompile = false, _XlaScope = "jit_scope_0", _XlaSeparateCompiledGradients = false, device = "/device:CPU:0"} : (tensor) -> tensor func.return %0 : tensor diff --git a/frontends/tf-frontend/tf_mlir_ext/tests/rewrite_to_custom_call.mlir b/frontends/tf-frontend/tf_mlir_ext/tests/rewrite_to_custom_call.mlir index 4b070673d..fd6030bdd 100755 --- a/frontends/tf-frontend/tf_mlir_ext/tests/rewrite_to_custom_call.mlir +++ b/frontends/tf-frontend/tf_mlir_ext/tests/rewrite_to_custom_call.mlir @@ -43,6 +43,22 @@ func.func @gelu_erf_case0(%arg0: tensor<100x?x?xf32>) -> tensor<100x?x?xf32> { // CHECK-SAME: @byteir.gelu // CHECK-SAME: byteir_attrs = {approximate = "erf"} +func.func @gelu_erf_case1(%arg0: tensor<100x?x?xf32>) -> tensor<100x?x?xf32> { + %cst = "tf.Const"() {value = dense<0.707106769> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {value = dense<5.000000e-01> : tensor} : () -> tensor + %1 = "tf.Mul"(%arg0, %cst) : (tensor<100x?x?xf32>, tensor) -> tensor<100x?x?xf32> + %2 = "tf.Erf"(%1) : (tensor<100x?x?xf32>) -> tensor<100x?x?xf32> + %3 = "tf.AddV2"(%2, %cst_0) : (tensor<100x?x?xf32>, tensor) -> tensor<100x?x?xf32> + %4 = "tf.Mul"(%3, %cst_1) : (tensor<100x?x?xf32>, tensor) -> tensor<100x?x?xf32> + %5 = "tf.Mul"(%arg0, %4) : (tensor<100x?x?xf32>, tensor<100x?x?xf32>) -> tensor<100x?x?xf32> + func.return %5 : tensor<100x?x?xf32> +} +// CHECK-LABEL: func.func @gelu_erf_case1(%arg0: tensor<100x?x?xf32>) -> tensor<100x?x?xf32> { +// CHECK: mhlo.custom_call +// CHECK-SAME: @byteir.gelu +// CHECK-SAME: byteir_attrs = {approximate = "erf"} + func.func @gelu_tanh_case0(%arg0: tensor<100x?x?xf32>) -> tensor<100x?x?xf32> { %cst = "tf.Const"() {value = dense<4.471500e-02> : tensor} : () -> tensor %cst_0 = "tf.Const"() {value = dense<3.000000e+00> : tensor} : () -> tensor @@ -386,6 +402,33 @@ func.func @layer_norm_with_cast(%79: tensor<150x3xf16>) -> tensor<150x3xf16> { // CHECK-SAME: @byteir.layer_norm // CHECK-SAME: byteir_attrs = {axis = [1], epsilon = 1.0132789611816406E-6 : f64} +func.func @layer_norm_with_cast_v2(%79: tensor<150x3xf16>) -> tensor<150x3xf16> { + %cst_61 = "tf.Const"() {value = dense<[0.0401659757, -0.11370486, 0.432680517]> : tensor<3xf16>} : () -> tensor<3xf16> + %cst_62 = "tf.Const"() {value = dense<[0.445568085, 0.45303449, 3.227140e-01]> : tensor<3xf16>} : () -> tensor<3xf16> + %cst_157 = "tf.Const"() {value = dense<1.013280e-06> : tensor} : () -> tensor + %cst_158 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %80 = "tf.Cast"(%79) {Truncate = false, device = ""} : (tensor<150x3xf16>) -> tensor<150x3xf32> + %81 = "tf.Mean"(%80, %cst_158) {device = "", keep_dims = true} : (tensor<150x3xf32>, tensor) -> tensor<150x1xf32> + %82 = "tf.Cast"(%81) {Truncate = false, device = ""} : (tensor<150x1xf32>) -> tensor<150x1xf16> + %83 = "tf.SquaredDifference"(%80, %81) {device = ""} : (tensor<150x3xf32>, tensor<150x1xf32>) -> tensor<150x3xf32> + %84 = "tf.Mean"(%83, %cst_158) {device = "", keep_dims = true} : (tensor<150x3xf32>, tensor) -> tensor<150x1xf32> + %85 = "tf.Cast"(%84) {Truncate = false, device = ""} : (tensor<150x1xf32>) -> tensor<150x1xf16> + %86 = "tf.AddV2"(%85, %cst_157) {device = ""} : (tensor<150x1xf16>, tensor) -> tensor<150x1xf16> + %87 = "tf.Rsqrt"(%86) {device = ""} : (tensor<150x1xf16>) -> tensor<150x1xf16> + %88 = "tf.Mul"(%87, %cst_61) {_grappler_ArithmeticOptimizer_MinimizeBroadcasts = true, device = ""} : (tensor<150x1xf16>, tensor<3xf16>) -> tensor<150x3xf16> + %89 = "tf.Mul"(%79, %88) {_grappler_ArithmeticOptimizer_MinimizeBroadcasts = true, device = ""} : (tensor<150x3xf16>, tensor<150x3xf16>) -> tensor<150x3xf16> + %90 = "tf.Mul"(%88, %82) {_grappler_ArithmeticOptimizer_MinimizeBroadcasts = true, device = ""} : (tensor<150x3xf16>, tensor<150x1xf16>) -> tensor<150x3xf16> + %91 = "tf.Sub"(%cst_62, %90) {device = ""} : (tensor<3xf16>, tensor<150x3xf16>) -> tensor<150x3xf16> + %92 = "tf.AddV2"(%89, %91) {device = ""} : (tensor<150x3xf16>, tensor<150x3xf16>) -> tensor<150x3xf16> + return %92 : tensor<150x3xf16> +} +// CHECK-LABEL: func.func @layer_norm_with_cast_v2(%arg0: tensor<150x3xf16>) -> tensor<150x3xf16> { +// CHECK-NEXT: %cst = "tf.Const"() {value = dense<[4.016110e-02, -1.137080e-01, 4.326170e-01]> : tensor<3xf16>} : () -> tensor<3xf16> +// CHECK-NEXT: %cst_0 = "tf.Const"() {value = dense<[4.455570e-01, 4.531250e-01, 3.227540e-01]> : tensor<3xf16>} : () -> tensor<3xf16> +// CHECK: mhlo.custom_call +// CHECK-SAME: @byteir.layer_norm +// CHECK-SAME: byteir_attrs = {axis = [1], epsilon = 1.0132789611816406E-6 : f64} + func.func @layer_norm_with_cast_disable_minimize_broadcast(%46: tensor<1024x4xf16>) -> tensor<1024x4xf16> { %cst_103 = "tf.Const"() {value = dense<[0.0401659757, -0.11370486, 0.432680517, 0.4000000]> : tensor<4xf16>} : () -> tensor<4xf16> %cst_104 = "tf.Const"() {value = dense<[0.445568085, 0.45303449, 3.227140e-01, 0.4000000]> : tensor<4xf16>} : () -> tensor<4xf16> @@ -452,6 +495,20 @@ func.func @l2_norm_V2(%1871: tensor<1x128xf16>) -> tensor<1x128xf16> { // CHECK-SAME: @byteir.l2_norm // CHECK-SAME: byteir_attrs = {axis = [1], epsilon = 0.000000e+00 : f64} +func.func @l2_norm_V2_swap_mul(%1871: tensor<1x128xf16>) -> tensor<1x128xf16> { + %cst_5 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %1872 = "tf.Square"(%1871) {device = ""} : (tensor<1x128xf16>) -> tensor<1x128xf16> + %1873 = "tf.Sum"(%1872, %cst_5) {device = "", keep_dims = true} : (tensor<1x128xf16>, tensor) -> tensor<1x1xf16> + %1874 = "tf.Relu"(%1873) : (tensor<1x1xf16>) -> tensor<1x1xf16> + %1875 = "tf.Rsqrt"(%1874) {device = ""} : (tensor<1x1xf16>) -> tensor<1x1xf16> + %1876 = "tf.Mul"(%1871, %1875) {device = ""} : (tensor<1x128xf16>, tensor<1x1xf16>) -> tensor<1x128xf16> + return %1876 : tensor<1x128xf16> +} +// CHECK-LABEL: @l2_norm_V2_swap_mul +// CHECK: mhlo.custom_call +// CHECK-SAME: @byteir.l2_norm +// CHECK-SAME: byteir_attrs = {axis = [1], epsilon = 0.000000e+00 : f64} + func.func @l2_norm_V3(%15: tensor<1x100x512xf32>) -> tensor<1x100x512xf32> { %cst_96 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> %16 = "tf.Square"(%15) {device = ""} : (tensor<1x100x512xf32>) -> tensor<1x100x512xf32> diff --git a/frontends/tf-frontend/tf_mlir_ext/transforms/rewrite_to_custom_call.cc b/frontends/tf-frontend/tf_mlir_ext/transforms/rewrite_to_custom_call.cc index 5294f3007..6b897d0ee 100644 --- a/frontends/tf-frontend/tf_mlir_ext/transforms/rewrite_to_custom_call.cc +++ b/frontends/tf-frontend/tf_mlir_ext/transforms/rewrite_to_custom_call.cc @@ -60,7 +60,8 @@ namespace { cb(layer_norm, LayerNorm, CALL_TARGET_NAME_PREFIX) \ cb(l2_norm, L2Norm, CALL_TARGET_NAME_PREFIX) \ cb(addn, AddN, CALL_TARGET_NAME_PREFIX) \ - cb(one_hot, OneHot, CALL_TARGET_NAME_PREFIX) \ + cb(one_hot, OneHot, CALL_TARGET_NAME_PREFIX) \ + cb(repeat, Repeat, CALL_TARGET_NAME_PREFIX) \ cb(DynamicMaskStitch, DynamicMaskStitch, CALL_TF_TARGET_NAME_PREFIX) \ cb(DynamicPartition, DynamicPartition, CALL_TF_TARGET_NAME_PREFIX) \ cb(DynamicStitch, DynamicStitch, CALL_TF_TARGET_NAME_PREFIX) @@ -445,6 +446,30 @@ struct RewriteOneHot : public OpRewritePattern { } }; +//===----------------------------------------------------------------------===// +// Repeat Pattern +//===----------------------------------------------------------------------===// +struct RewriteRepeat : public RewritePattern { + RewriteRepeat(MLIRContext *context, PatternBenefit benefits = 1) + : RewritePattern("tf.Repeat", benefits, context) {} + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + // llvm::outs() << op->getName().getStringRef(); + assert(op->getName().getStringRef() == "tf.Repeat"); + mhlo::CustomCallOp customCallOp = rewriter.create( + op->getLoc(), op->getResults().getTypes(), op->getOperands(), + getRepeatNameWithPrefix(), false, rewriter.getStringAttr(""), + mhlo::CustomCallApiVersion{ + mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL}, + rewriter.getArrayAttr(ArrayRef{}), + mhlo::CustomCallSchedule{mhlo::CustomCallSchedule::NONE}, nullptr, + nullptr, rewriter.getArrayAttr(ArrayRef{})); + customCallOp->setAttr(getByteIRAttrs(), getCleanAttr(op)); + rewriter.replaceOp(op, customCallOp->getResults()); + return success(); + } +}; + //===----------------------------------------------------------------------===// // SimpleReplace Pattern //===----------------------------------------------------------------------===// @@ -616,6 +641,8 @@ struct RewriteToCustomCallOpsPass std::make_unique(context)); validCustomCallOpSet[getGeLUName()].emplace_back( std::make_unique(context)); + validCustomCallOpSet[getGeLUName()].emplace_back( + std::make_unique(context)); if (keepBody) { validCustomCallOpSet[getLayerNormName()].emplace_back( @@ -642,6 +669,8 @@ struct RewriteToCustomCallOpsPass std::make_unique(context)); validCustomCallOpSet[getLayerNormName()].emplace_back( std::make_unique(context)); + validCustomCallOpSet[getLayerNormName()].emplace_back( + std::make_unique(context)); validCustomCallOpSet[getLayerNormName()].emplace_back( std::make_unique( context)); @@ -652,6 +681,8 @@ struct RewriteToCustomCallOpsPass std::make_unique(context)); validCustomCallOpSet[getL2NormName()].emplace_back( std::make_unique(context)); + validCustomCallOpSet[getL2NormName()].emplace_back( + std::make_unique(context)); validCustomCallOpSet[getL2NormName()].emplace_back( std::make_unique(context)); @@ -681,6 +712,8 @@ struct RewriteToCustomCallOpsPass std::make_unique(context, 1)); validCustomCallOpSet[getDynamicStitchName()].emplace_back( std::make_unique(context, 1)); + validCustomCallOpSet[getRepeatName()].emplace_back( + std::make_unique(context, 1)); RewritePatternSet patterns(context); for (auto op : opsSet) { diff --git a/frontends/tf-frontend/tf_mlir_ext/transforms/rewrite_to_custom_call.td b/frontends/tf-frontend/tf_mlir_ext/transforms/rewrite_to_custom_call.td index b4502b150..0102233dc 100644 --- a/frontends/tf-frontend/tf_mlir_ext/transforms/rewrite_to_custom_call.td +++ b/frontends/tf-frontend/tf_mlir_ext/transforms/rewrite_to_custom_call.td @@ -450,6 +450,54 @@ def RewriteLayerNormWithCast : Pat< (SameTwoValuesOrAttrs $input_cast_mean, $input_cast_mean_1), (OneRank $gama_attr), (OneRank $beta_attr)]>; +def RewriteLayerNormWithCastV2 : Pat< + (TF_AddV2Op + (TF_MulOp + $input, + (TF_MulOp:$factor + (TF_RsqrtOp + (TF_AddV2Op + (TF_CastOp + (TF_MeanOp + (TF_SquaredDifferenceOp + $input_cast_1, + (TF_MeanOp:$input_cast_mean + (TF_CastOp:$input_cast $input, $_), + (TF_ConstOp:$axis $axis_attr), + $keep_dims_0 + ) + ), + $axis_1, + $keep_dims_1 + ), + $_ + ), + (TF_ConstOp $epsilon_attr) + ) + ), + (TF_ConstOp:$gama $gama_attr) + ) + ), + (TF_SubOp + (TF_ConstOp:$beta $beta_attr), + (TF_MulOp + $factor_1, + (TF_CastOp + $input_cast_mean_1, + $_ + ) + ) + ) + ), + (NativeCodeCall<"createLayerNorm($_builder, $_loc, $0, $1, $2, $3, $4)"> $input, $gama, $beta, $epsilon_attr, $axis_attr), + [(OneSize $epsilon_attr), (OneSize $axis_attr), + (TrueBoolAttr $keep_dims_0), (TrueBoolAttr $keep_dims_1), + (SameTwoValuesOrAttrs $axis, $axis_1), + (SameTwoValuesOrAttrs $input_cast, $input_cast_1), + (SameTwoValuesOrAttrs $input_cast_mean, $input_cast_mean_1), + (SameTwoValuesOrAttrs $factor, $factor_1), + (OneRank $gama_attr), (OneRank $beta_attr)]>; + def RewriteLayerNormWithCastDisableMinimizeBroadcast : Pat< (TF_AddV2Op (TF_MulOp @@ -557,6 +605,24 @@ def RewriteL2NormV2 : Pat< (NativeCodeCall<"createL2NormV2($_builder, $_loc, $0, $1)"> $input, $axis_attr), [(OneSize $axis_attr), (TrueBoolAttr $keep_dims)]>; +def RewriteL2NormV2SwapMul : Pat< + (TF_MulOp + $input, + (TF_RsqrtOp + (TF_ReluOp + (TF_SumOp + (TF_SquareOp + $input + ), + (TF_ConstOp:$axis $axis_attr), + $keep_dims + ) + ) + ) + ), + (NativeCodeCall<"createL2NormV2($_builder, $_loc, $0, $1)"> $input, $axis_attr), + [(OneSize $axis_attr), (TrueBoolAttr $keep_dims)]>; + // note: x^2 guarantee >= 0, so set epsilon to 0.0 def RewriteL2NormV3 : Pat< (TF_MulOp @@ -602,6 +668,25 @@ def RewriteGELUerf : Pat< (NativeCodeCall<"createGELU($_builder, $_loc, $0, \"erf\")"> $input), [(GeluValue0 $cst1), (GeluValue1 $cst0), (GeluValue5 $cst)]>; +def RewriteGELUerfV2 : Pat< + (TF_MulOp + $input, + (TF_MulOp + (TF_AddV2Op + (TF_ErfOp + (TF_MulOp + $input, + (TF_ConstOp $cst) + ) + ), + (TF_ConstOp $cst0) + ), + (TF_ConstOp $cst1) + ) + ), + (NativeCodeCall<"createGELU($_builder, $_loc, $0, \"erf\")"> $input), + [(GeluValue0 $cst1), (GeluValue1 $cst0), (GeluValue5 $cst)]>; + def RewriteGELUtanh : Pat< (TF_MulOp (TF_AddV2Op diff --git a/frontends/tf-frontend/tf_mlir_ext/transforms/tf_fallback_to_custom_call.cc b/frontends/tf-frontend/tf_mlir_ext/transforms/tf_fallback_to_custom_call.cc index 4893e7420..9db2b2b14 100644 --- a/frontends/tf-frontend/tf_mlir_ext/transforms/tf_fallback_to_custom_call.cc +++ b/frontends/tf-frontend/tf_mlir_ext/transforms/tf_fallback_to_custom_call.cc @@ -120,6 +120,22 @@ void LowerToAceConstant(TF::ConstOp op) { ReplaceOp(op, ace_const_op); } +void LowerToAceReshape(TF::SqueezeOp op) { + ShapedType ty = op.getOutput().getType().dyn_cast(); + // TODO(lyq): handle resource type + if (!ty || !ty.getElementType().isa() || + !ty.hasStaticShape()) { + return; + } + OpBuilder builder(op); + + auto new_result_ty = + ty.cloneWith(std::nullopt, ace::StringType::get(op->getContext())); + ace::ReshapeOp ace_reshape_op = builder.create( + op->getLoc(), new_result_ty, op->getOperand(0)); + ReplaceOp(op, ace_reshape_op); +} + void RewriteTFPrint(Operation *op) { auto operands = llvm::to_vector(op->getOperands()); auto results = op->getResults(); @@ -163,6 +179,10 @@ struct TfFallbackToCustomCallPass LowerToAceConstant(llvm::cast(op)); return; } + if (llvm::isa(op)) { + LowerToAceReshape(llvm::cast(op)); + return; + } if (llvm::any_of(op->getOperandTypes(), check_is_string_or_resource) || llvm::any_of(op->getResultTypes(), check_is_string_or_resource)) { LowerToAceCustomCall(op); diff --git a/frontends/torch-frontend/examples/demo/backend.py b/frontends/torch-frontend/examples/demo/backend.py index a15d4e215..8ad4cb92a 100644 --- a/frontends/torch-frontend/examples/demo/backend.py +++ b/frontends/torch-frontend/examples/demo/backend.py @@ -12,11 +12,10 @@ import torch_frontend from torch_frontend import list_decomposed_ops, preprocess_fx_graph, fx_replace_attn_pattern, replace_flash_attn, get_none_indices - +from context import FxGraphCache TRACE = False -submodule_cnt = 0 MODEL_NAME = '' FLASH = False @@ -151,50 +150,70 @@ def byteir_runner(*inputs): print("\n\n============") print(f"{category} Part") print("============\n\n") - none_indices = get_none_indices(graph) - fx_graph = preprocess_fx_graph(graph) + graph_kwargs = { + #"cudagraphs": cudagraphs, + #"num_fixed": num_fixed, + "is_backward": is_backward, + #"graph_id": graph_id, + #"cpp_wrapper": cpp_wrapper, + #"aot_mode": aot_mode, + #"is_inference": is_inference, + #"user_visible_outputs": user_visible_outputs, + #"layout_opt": layout_opt, + #"extern_node_serializer": extern_node_serializer, + } compile_type = 'stablehlo' - backend_legal_ops = [ - "aten._softmax", - "aten.softmax.int", - "aten.log_softmax.int", - "aten._log_softmax", - # "aten.native_layer_norm", - # "aten.layer_norm", - "aten.gelu", - "aten.argmax", - "aten.max.dim", - "aten.one_hot", - "aten.topk", - "byteir.flash_attn_fwd", - "byteir.flash_attn_bwd", - ] - with maybe_disable_fake_tensor_mode(): - compiled_graph = torch_frontend.compile(fx_graph, inputs, compile_type, backend_legal_ops=backend_legal_ops) - model_name = MODEL_NAME - global submodule_cnt TEMP_FOLDER="./temp" + RT_FOLDER=TEMP_FOLDER + f"/{model_name}_{category}" os.makedirs(TEMP_FOLDER, exist_ok=True) - os.makedirs(TEMP_FOLDER + f"/{model_name}_{category}_{submodule_cnt}", exist_ok=True) - mlir_file_name = f'{TEMP_FOLDER}/{model_name}_{category}_{submodule_cnt}.{compile_type}.mlir' - output_mlir_file_name = f'{TEMP_FOLDER}/{model_name}_{category}_{submodule_cnt}/{model_name}_{category}_{submodule_cnt}.rt.mlir' - submodule_cnt = submodule_cnt + 1 - with open(mlir_file_name, "w+") as fout: - compiled_graph.operation.print(file=fout, - large_elements_limit=None) - - with maybe_disable_fake_tensor_mode(): - byteir.compile(mlir_file_name, output_mlir_file_name, entry_func='forward', target='cuda_with_ait') + os.makedirs(RT_FOLDER, exist_ok=True) + mlir_file_name = f'{TEMP_FOLDER}/{model_name}_{category}.{compile_type}.mlir' + output_mlir_file_name = f'{TEMP_FOLDER}/{model_name}_{category}/{model_name}_{category}.rt.mlir' + + # load FxCache + key = FxGraphCache.get_hash_key(graph, inputs, graph_kwargs) + print(f"fx graph hash key: {key}") + cache_hit = FxGraphCache.try_load(key, RT_FOLDER) + + if not cache_hit: + fx_graph = preprocess_fx_graph(graph) + backend_legal_ops = [ + "aten._softmax", + "aten.softmax.int", + "aten.log_softmax.int", + "aten._log_softmax", + # "aten.native_layer_norm", + # "aten.layer_norm", + "aten.gelu", + "aten.argmax", + "aten.max.dim", + "aten.one_hot", + "aten.topk", + "byteir.flash_attn_fwd", + "byteir.flash_attn_bwd", + ] + with maybe_disable_fake_tensor_mode(): + compiled_graph = torch_frontend.compile(fx_graph, inputs, compile_type, backend_legal_ops=backend_legal_ops) + + with open(mlir_file_name, "w+") as fout: + compiled_graph.operation.print(file=fout, + large_elements_limit=None) + + with maybe_disable_fake_tensor_mode(): + byteir.compile(mlir_file_name, output_mlir_file_name, entry_func='forward', target='cuda_with_ait') + + # save to cache + FxGraphCache.save_to_cache(RT_FOLDER, key) + none_indices = get_none_indices(graph) outputs = FakeTensorProp(graph).propagate(*inputs) if isinstance(outputs, torch.Tensor): outputs = [outputs] mhlo_ret_dtypes = [t.dtype for t in outputs] mhlo_ret_shapes = [t.shape for t in outputs] - print(output_mlir_file_name) runner = ByteIRFunction(output_mlir_file_name, mhlo_ret_shapes, mhlo_ret_dtypes, none_indices, device=outputs[0].device) return runner(*inputs) return byteir_runner diff --git a/frontends/torch-frontend/examples/demo/context.py b/frontends/torch-frontend/examples/demo/context.py new file mode 100644 index 000000000..6c23dfdbf --- /dev/null +++ b/frontends/torch-frontend/examples/demo/context.py @@ -0,0 +1,350 @@ +# Copyright (c) 2016-present, Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## + +import base64 +import copyreg +import dataclasses +import io +import hashlib +import os +import torch +import pickle +import logging + +from shutil import copytree, rmtree +from torch._prims_common import suggest_memory_format +from torch.fx.experimental.symbolic_shapes import ShapeEnv +from typing import Any, Dict, List, Set, Optional, Tuple + +log = logging.getLogger(__name__) + +class ByteirContext: + def __init__(self): + HOME_DIR = os.getenv("HOME") + if HOME_DIR == None: + HOME_DIR = "/tmp/" + self.CACHE_HOME_DIR = os.path.join(HOME_DIR, ".byteir_cache/compile_cache/") + + def __enter__(self): + os.environ["ByteirCacheDir"] = self.CACHE_HOME_DIR + + def __exit__(self, exception_type, exception_value, traceback): + del os.environ["ByteirCacheDir"] + + +def sha256_hash(data: bytes) -> str: + # [:51] to strip off the "Q====" suffix common to every hash value. + return base64.b32encode(hashlib.sha256(data).digest())[:51].decode("utf-8").lower() + +@dataclasses.dataclass +class TensorMetadata: + """ + The Tensor metadata relevant when hashing FxGraph cache keys. + """ + dtype: torch.dtype + shape: torch.Size + stride: Tuple[Any, ...] + device: torch.device + layout: torch.layout + memory_format: Optional[torch.memory_format] + storage_offset: int + requires_grad: bool + is_quantized: bool + is_conj: bool + is_neg: bool + is_coalesced: bool + dense_dim: int + sparse_dim: int + +def extract_tensor_metadata(t: torch.Tensor) -> TensorMetadata: + """ + Extract the TensorMetadata of a tensor. + """ + memory_format: Optional[torch.memory_format] = suggest_memory_format(t) + if not t.is_contiguous(memory_format=memory_format): + memory_format = None + + return TensorMetadata( + dtype=t.dtype, + shape=t.shape, + stride=t.stride() if t.layout == torch.strided else (), + device=t.device, + layout=t.layout, + memory_format=memory_format, + storage_offset=t.storage_offset(), + requires_grad=t.requires_grad, + is_quantized=t.is_quantized, + is_conj=t.is_conj(), + is_neg=t.is_neg(), + is_coalesced=t.is_coalesced() if t.is_sparse else False, + dense_dim=t.dense_dim() if t.is_sparse else False, + sparse_dim=t.sparse_dim() if t.is_sparse else False, + ) + +def _ident(x: Any) -> Any: + return x + +def _reduce_fake_tensor(t): + """ + See FxGraphCachePickler. Custom reducer to pickle FakeTensors. + """ + metadata = extract_tensor_metadata(t) + return (_ident, (metadata,)) + +def _reduce_tensor(t): + """ + See FxGraphCachePickler. Custom reducer to pickle Tensors. + """ + # If we see tensors, we know they're contstants stored as attributes on + # the GraphModule. See tensor lowering; small constants are inlined. If + # we see a small tensor, therefore, no reference will ultimately remain + # in the generated code. So we need to include its value in the cache key. + # Large constannts are effectively treated as inputs and we consider only + # their metadata. + metadata = extract_tensor_metadata(t) + return (_ident, (metadata,)) + +def _reduce_symint(s): + """ + See FxGraphCachePickler. Custom reducer to pickle SymInts. + """ + # For hashing purposes, we only care about the name of the symbol and + # not the backed value. We evaluate guards stored with a cached graph + # to ensure a cached entity with SymInt args is safe to reuse. + return (_ident, (str(s),)) + + +class FxGraphCachePickler(pickle.Pickler): + """ + Custom pickler to customize the pickling of some objects (Tensors), only for the + purpose of computing a hash for keying into the FxGraphCache. Tensors contain + objects that don't pickle and/or vary between runs, and we want to capture the + data that allow us to compute a stable, but safe hash. + """ + + dispatch_table = copyreg.dispatch_table.copy() + dispatch_table[torch._subclasses.fake_tensor.FakeTensor] = _reduce_fake_tensor + dispatch_table[torch.Tensor] = _reduce_tensor + dispatch_table[torch.SymInt] = _reduce_symint + + @staticmethod + def dumps(obj) -> bytes: + """ + Pickle an object using the FxGraphCachePickler. + """ + with io.BytesIO() as stream: + pickler = FxGraphCachePickler(stream) + pickler.dump(obj) + return stream.getvalue() + + @staticmethod + def get_hash(obj: Any) -> str: + """ + Serialize an object using the FxGraphCachePickler and return a hash + of the pickled object. + """ + serialized_data = FxGraphCachePickler.dumps(obj) + return sha256_hash(serialized_data) + +@dataclasses.dataclass +class OrderedSetHolder: + """ + See FxGraphHashDetails. Holds a sorted list to support stable hashing + of set kwargs. + """ + + items: List[Any] + +class FxGraphHashDetails: + """ + Object to capture all the details for a compiled FX graph relevant to computing + a safe and stable cache key. + """ + + # Excluded kwargs param that are not stable between runs + EXCLUDED_KWARGS = ["graph_id"] + + def __init__( + self, + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + fx_kwargs: Dict[str, Any], + ): + self.gm = gm + self.example_inputs = example_inputs + + # Order kwargs so hashing is stable to changes in kwarg order. + self.fx_kwargs = {} + for k in sorted(fx_kwargs): + if k not in self.EXCLUDED_KWARGS: + if type(fx_kwargs[k]) is set: + # Special case to handle set params. Python sets can't be + # ordered, so sort the elements and store them in a proxy. + self.fx_kwargs[k] = OrderedSetHolder(sorted(fx_kwargs[k])) + else: + self.fx_kwargs[k] = fx_kwargs[k] + + def debug_str(self) -> str: + """ + Get a printable string describing in more detail all the attributes + comprising this object. Useful for debugging when one graph hashes + to a different value than another. + """ + + def get_str(obj) -> str: + if isinstance(obj, torch.Tensor): + return str(extract_tensor_metadata(obj)) + elif isinstance(obj, bytes): + return "" + else: + return str(obj) + + lines = [] + for attr, obj in vars(self).items(): + if isinstance(obj, list): + for ii in range(len(obj)): + h = FxGraphCachePickler.get_hash(obj[ii]) + lines.append(f"[{h}] {attr}[{ii}]: {get_str(obj[ii])}") + elif isinstance(obj, dict): + for k, v in obj.items(): + h = FxGraphCachePickler.get_hash(v) + lines.append(f"[{h}] {attr}[{k}]: {get_str(v)}") + else: + h = FxGraphCachePickler.get_hash(obj) + lines.append(f"[{h}] {attr}: {get_str(obj)}") + return "\n".join(lines) + + +def compiled_fx_graph_hash( + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + fx_kwargs: Dict[str, Any], +) -> str: + """ + Generate a unique hash of the FX graph for caching. + """ + details = FxGraphHashDetails(gm, example_inputs, fx_kwargs) + # The prefix distinguishes among the other kinds of objects we + # cache in this module. + key = "f" + FxGraphCachePickler.get_hash(details) + log.debug("[byteir] FX graph cache hash details for key %s:\n%s", key, details.debug_str()) + return key + + +class FxGraphCache: + """ + Supports caching and reusing compiled Fx graphs. + + The overall strategy is as follows: + - This cache stores entries on disk. When saving an entry, we can't + serialize callables (that could be C++, Triton, etc.), so we serialize + their own disk cache location. We then recreate the compiled artifact + after fetching from disk. + - For indexing the cache, we gather the fields relevant to identifying an + FxGraph (the graph module, graph inputs, system settings etc.) into an + FxGraphCacheDetails object, pickle it, and compute a hash for the key. + See FxGraphCachePickler. + - Among the metadata we store, we also include a guards expression that's + appropriate for validating any symbols for Tensor arguments that have + symbolic bounds. On cache lookup then, we evaluate those guards in the + current context to validate that a cached entry can be served. + - A given graph could have multiple compiled versions, corresponding to + different sets of guards. Therefore, we store cache entries in the form: + // + - On lookup, we compute the key from the graph details, iterate over all + leaf files in the corresponding subdirectory, deserialize the entry, and + evaluate its guards expression. If the evaluation succeeds, we have a + cache hit. If it fails, we compile the graph and store a new entry. + - Finally, on a cache hit, we need to make sure any guards that would + have been created during compilation are added to the current context. + """ + + # TODO(masnesral): Investigate whether it's beneficial to store compiled graphs + # in an in-memory cache after loading from disk. + @staticmethod + def _get_tmp_dir() -> str: + """ + Get the toplevel temporary directory for storing compiled graphs. + """ + return os.path.join(os.environ["ByteirCacheDir"]) + + @staticmethod + def _get_tmp_dir_for_key(key: str) -> str: + """ + Return the disk location for a given cache key. + """ + return os.path.join(FxGraphCache._get_tmp_dir(), key) + + @staticmethod + def _filter_symints(inputs: List[Any]) -> List[torch.SymInt]: + """ + Get the SymInt objects from the input list. + """ + return [s for s in inputs if isinstance(s, torch.SymInt)] + + @staticmethod + def _get_shape_env() -> ShapeEnv: + """ + Helper to get the shape env from the tracing context. + """ + return torch._guards.TracingContext.get().fake_mode.shape_env + + @staticmethod + def get_hash_key( + graph: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + fx_kwargs: Dict[str, Any], + ) -> str: + """ + Get hash key for a given graph + """ + return compiled_fx_graph_hash(graph, example_inputs, fx_kwargs) + + @staticmethod + def save_to_cache( + compiled_rt_folder: str, key: str + ): + """ + Move compiled temp runtime folder to cache + """ + cache_dir = FxGraphCache._get_tmp_dir_for_key(key) + copytree(compiled_rt_folder, cache_dir, dirs_exist_ok=True) + + @staticmethod + def try_load( + key: str, rt_folder: str + ): + """ + Load a compiled graph from the cache, return True on cache hit + """ + #from filelock import FileLock + cache_dir = FxGraphCache._get_tmp_dir_for_key(key) + #lock_dir = get_lock_dir() + #lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) + #with lock: + if os.path.exists(cache_dir): + log.debug("[byteir] fx graph cache hit for key %s", key) + copytree(cache_dir, rt_folder, dirs_exist_ok=True) + return True + else: + log.debug("[byteir] fx graph cache miss for key %s", key) + return False + + @staticmethod + def clear(): + """ + Clear out the on-disk cache. + """ + rmtree(FxGraphCache._get_tmp_dir()) diff --git a/frontends/torch-frontend/examples/demo/main.py b/frontends/torch-frontend/examples/demo/main.py index a91afb050..2712da0fa 100644 --- a/frontends/torch-frontend/examples/demo/main.py +++ b/frontends/torch-frontend/examples/demo/main.py @@ -9,11 +9,14 @@ import transformers import argparse +from context import ByteirContext MODEL_LIST = ["gpt2", "bloom-560m", "llama", "llama-2", "opt-1.3b", "nanogpt", "chatglm"] AUTH_TOKEN="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" +torch.manual_seed(4) + class InferLLAMAModule(torch.nn.Module): def __init__(self): super().__init__() @@ -224,16 +227,17 @@ def train_model(args): backend.MODEL_NAME = model_name backend.FLASH = use_flash_attn - optimized_model = torch.compile(model, backend=fuse_aware_byteir_compile_fx, fullgraph=True) + with ByteirContext(): + optimized_model = torch.compile(model, backend=fuse_aware_byteir_compile_fx, fullgraph=True) - data = make_data(optimized_model, model_name, device) - model.zero_grad(set_to_none=True) - with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16): - loss = compute_loss(optimized_model, data, model_name) - torch_loss = compute_loss(model, data, model_name) - print("loss:", loss) - print("torch_loss:", torch_loss) - loss.backward() + data = make_data(optimized_model, model_name, device) + model.zero_grad(set_to_none=True) + with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16): + loss = compute_loss(optimized_model, data, model_name) + torch_loss = compute_loss(model, data, model_name) + print("loss:", loss) + print("torch_loss:", torch_loss) + loss.backward() if __name__ == "__main__": diff --git a/frontends/torch-frontend/third_party/patches/backend_contract.patch b/frontends/torch-frontend/third_party/patches/backend_contract.patch new file mode 100644 index 000000000..c54fa4704 --- /dev/null +++ b/frontends/torch-frontend/third_party/patches/backend_contract.patch @@ -0,0 +1,79 @@ +diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +index 38198a91..fd4a40df 100644 +--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp ++++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +@@ -10,6 +10,7 @@ + #include "PassDetail.h" + + #include "mlir/IR/BuiltinOps.h" ++#include "mlir/IR/OpDefinition.h" + #include "mlir/Pass/PassManager.h" + #include "mlir/Transforms/DialectConversion.h" + #include "mlir/Transforms/Passes.h" +@@ -30,6 +31,50 @@ using namespace mlir::torch::Torch; + // Checking the backend contract. + //===----------------------------------------------------------------------===// + ++static void markDynamicShapeOpAsIllegal(ConversionTarget &target) { ++ auto isPrimConstantValue = [](Value v) -> bool { ++ Operation *op = v.getDefiningOp(); ++ return op->hasTrait() || ++ llvm::isa(op); ++ }; ++ auto isListOfConstantIntValue = [](Value v) -> bool { ++ SmallVector values; ++ if (!matchPattern(v, m_TorchListOfConstantInts(values))) { ++ return false; ++ } ++ return true; ++ }; ++ auto isOpResultLegal = [&](Operation *op) -> bool { ++ bool staticShapeConstraint = true; ++ for (auto operand : op->getOperands()) { ++ if (auto ty = operand.getType().dyn_cast()) { ++ if (!ty.areAllSizesKnown()) { ++ staticShapeConstraint = false; ++ } ++ } else if (!isPrimConstantValue(operand) && ++ !isListOfConstantIntValue(operand)) { ++ staticShapeConstraint = false; ++ } ++ } ++ if (staticShapeConstraint == false) { ++ return true; ++ } ++ ++ for (auto result : op->getResults()) { ++ if (auto ty = result.getType().dyn_cast()) { ++ if (ty.areAllSizesKnown()) ++ continue; ++ else ++ return false; ++ } ++ } ++ return true; ++ }; ++ target.addDynamicallyLegalOp(isOpResultLegal); ++ target.addDynamicallyLegalOp(isOpResultLegal); ++ target.addDynamicallyLegalOp(isOpResultLegal); ++} ++ + static void markDecomposedOpsAsIllegal(MLIRContext *context, + ConversionTarget &target, + llvm::StringSet<> backendLegalOps); +@@ -251,6 +296,7 @@ getBackendContractTarget(MLIRContext *context, bool decompose, + llvm::StringSet<> backendLegalOpsSet) { + ConversionTarget target(*context); + target.addLegalDialect(); ++ markDynamicShapeOpAsIllegal(target); + if (decompose) + markDecomposedOpsAsIllegal(context, target, backendLegalOpsSet); + return target; +@@ -386,6 +432,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); ++ target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); diff --git a/frontends/torch-frontend/third_party/patches/einsum.patch b/frontends/torch-frontend/third_party/patches/einsum.patch index f5ea6959d..a19b08852 100644 --- a/frontends/torch-frontend/third_party/patches/einsum.patch +++ b/frontends/torch-frontend/third_party/patches/einsum.patch @@ -469,15 +469,3 @@ index 1a61cf23..e0efd293 100644 addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); -diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp -index 4b823e51..3e70a02f 100644 ---- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp -+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp -@@ -386,6 +386,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); -+ target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); diff --git a/frontends/torch-frontend/third_party/patches/tuple.patch b/frontends/torch-frontend/third_party/patches/tuple.patch new file mode 100644 index 000000000..3fb5ecb9f --- /dev/null +++ b/frontends/torch-frontend/third_party/patches/tuple.patch @@ -0,0 +1,16 @@ +diff --git a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp +index 30cc4db4..96d04fd1 100644 +--- a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp ++++ b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp +@@ -194,9 +194,8 @@ static LogicalResult adjustCallingConventions(func::FuncOp func, + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + typeConverter.addConversion( +- [](Torch::TupleType type, +- SmallVectorImpl &types) -> LogicalResult { +- llvm::append_range(types, type.getContainedTypes()); ++ [](Torch::TupleType type, SmallVectorImpl &types) -> LogicalResult { ++ // llvm::append_range(types, type.getContainedTypes()); + return success(); + }); + typeConverter.addConversion( diff --git a/frontends/torch-frontend/third_party/torch-mlir b/frontends/torch-frontend/third_party/torch-mlir index 4b9db995b..f7a92d346 160000 --- a/frontends/torch-frontend/third_party/torch-mlir +++ b/frontends/torch-frontend/third_party/torch-mlir @@ -1 +1 @@ -Subproject commit 4b9db995b55c55f1548c3514f849b34341b8581f +Subproject commit f7a92d346ee8f498eb35d764d81cfd00bcdbf119 diff --git a/frontends/torch-frontend/torch-frontend/python/test/test_ops.py b/frontends/torch-frontend/torch-frontend/python/test/test_ops.py index e26f1f33b..18fd1098e 100644 --- a/frontends/torch-frontend/torch-frontend/python/test/test_ops.py +++ b/frontends/torch-frontend/torch-frontend/python/test/test_ops.py @@ -133,3 +133,28 @@ def test_clamp_derefine(): inputs = [tu.randn(3, 4)] module = convert_to_mhlo_via_torch_mlir(ClampDerefineModule(), inputs) print(module.operation.get_asm()) + +# ============================================================================== +# tuple cases + +class Tuple1Module(torch.nn.Module): + def __init__(self): + super().__init__() + def forward(self, x): + return (x, ) + +def test_tuple_one_tensor(): + inputs = [tu.randn(3, 4)] + module = convert_to_mhlo_via_torch_mlir(Tuple1Module(), inputs) + print(module.operation.get_asm()) + +class Tuple2Module(torch.nn.Module): + def __init__(self): + super().__init__() + def forward(self, x): + return (x, x) + +def test_tuple_one_tensor(): + inputs = [tu.randn(3, 4)] + module = convert_to_mhlo_via_torch_mlir(Tuple2Module(), inputs) + print(module.operation.get_asm()) diff --git a/frontends/torch-frontend/torch-frontend/python/version.txt b/frontends/torch-frontend/torch-frontend/python/version.txt index 7f207341d..e4c0d46e5 100644 --- a/frontends/torch-frontend/torch-frontend/python/version.txt +++ b/frontends/torch-frontend/torch-frontend/python/version.txt @@ -1 +1 @@ -1.0.1 \ No newline at end of file +1.0.3 \ No newline at end of file diff --git a/runtime/VERSION_NUMBER b/runtime/VERSION_NUMBER index bc80560fa..bd8bf882d 100644 --- a/runtime/VERSION_NUMBER +++ b/runtime/VERSION_NUMBER @@ -1 +1 @@ -1.5.0 +1.7.0 diff --git a/runtime/cmake/CMakeLists.txt b/runtime/cmake/CMakeLists.txt index 91cb0b4a0..af719dd09 100644 --- a/runtime/cmake/CMakeLists.txt +++ b/runtime/cmake/CMakeLists.txt @@ -32,6 +32,7 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) # Options option(brt_BUILD_SHARED_LIB "Build using shared libraries" ON) option(brt_USE_CUDA "Build with CUDA support" OFF) +option(brt_USE_NCCL "Build with NCCL support" OFF) option(brt_BUILD_UNIT_TESTS "Build Runtime unit tests" ON) option(brt_CROSS_COMPILING "Cross compiling for another platform" OFF) @@ -40,10 +41,14 @@ set(LLVM_INSTALL_PATH "" CACHE STRING "The path to the installed LLVM library") set(FLASH_ATTN_INSTALL_PATH "" CACHE STRING "The path to the installed flash attn library") if(FLASH_ATTN_INSTALL_PATH AND NOT brt_USE_CUDA) - message(FATAL_ERROR "config FLASH_ATTN_INSTALL_PATH=... must with brt_USE_CUDA=ON") endif() +if(brt_USE_NCCL AND NOT brt_USE_CUDA) + message(FATAL_ERROR "brt_USE_NCCL=ON must with brt_USE_CUDA=ON") +endif() + + set(brt_ENABLE_FLASH_ATTENTION false) if(FLASH_ATTN_INSTALL_PATH) set(brt_ENABLE_FLASH_ATTENTION true) @@ -173,6 +178,8 @@ message("LIB_ROOT = ${LIB_ROOT}") set(CUTLASS_ROOT ${REPO_ROOT}/../external/cutlass) message("CUTLASS_ROOT = ${CUTLASS_ROOT}") +list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/Modules) + file (STRINGS "${REPO_ROOT}/VERSION_NUMBER" BRT_VERSION) # TODO: check if x86 @@ -249,6 +256,13 @@ if(brt_USE_CUDA) endif() endif() +if(brt_USE_NCCL) + message(STATUS "brt_USE_NCCL On") + find_package(NCCL REQUIRED) + include_directories("${NCCL_INCLUDE_DIRS}") + message("NCCL Include Dirs = ${NCCL_INCLUDE_DIRS}") +endif() + FILE(TO_NATIVE_PATH ${CMAKE_BINARY_DIR} BRT_BINARY_DIR) FILE(TO_NATIVE_PATH ${PROJECT_SOURCE_DIR} BRT_SOURCE_DIR) @@ -272,6 +286,12 @@ if (brt_USE_CUDA) endif() endif() +if (brt_USE_NCCL) + list(APPEND BRT_PROVIDER_FLAGS -DBRT_USE_NCCL-1) + list(APPEND BRT_PROVIDER_CMAKE_FLAGS -Dbrt_USE_NCCL=1) + list(APPEND BRT_PROVIDER_NAMES nccl) +endif() + # utility functions function(brt_set_compile_flags target_name) if (MSVC) @@ -485,6 +505,12 @@ if(brt_USE_CUDA) list(APPEND BRT_TARGETS brt_provider_cuda) endif() +if(brt_USE_NCCL) + list(APPEND BRT_TARGETS brt_device_nccl) + list(APPEND BRT_TARGETS brt_provider_nccl) +endif() + + if(brt_ENABLE_PYTHON_BINDINGS) list(APPEND BRT_TARGETS brt_python_bindings) endif() @@ -553,6 +579,14 @@ if(brt_USE_CUDA) brt_device_cuda brt_provider_cuda) endif() +if(brt_USE_NCCL) + target_sources(brt.objs INTERFACE + $ + $) + target_link_libraries(brt.objs INTERFACE + brt_device_nccl + brt_provider_nccl) +endif() #The following files may use the 'brt_libs' and 'brt_EXTERNAL_LIBRARIES' vars if (brt_BUILD_SHARED_LIB OR brt_BUILD_APPLE_FRAMEWORK OR brt_ENABLE_PYTHON_BINDINGS) diff --git a/runtime/cmake/Modules/FindNCCL.cmake b/runtime/cmake/Modules/FindNCCL.cmake new file mode 100644 index 000000000..3af90b866 --- /dev/null +++ b/runtime/cmake/Modules/FindNCCL.cmake @@ -0,0 +1,66 @@ +## Copyright (c) Facebook, Inc. +## Licensed under BSD License +## =========================================================================== +## Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. + +# Try to find NCCL +# +# The following variables are optionally searched for defaults +# NCCL_ROOT_DIR: Base directory where all NCCL components are found +# NCCL_INCLUDE_DIR: Directory where NCCL header is found +# NCCL_LIB_DIR: Directory where NCCL library is found +# +# The following are set after configuration is done: +# NCCL_FOUND +# NCCL_INCLUDE_DIRS +# NCCL_LIBRARIES +# +# The path hints include CUDA_TOOLKIT_ROOT_DIR seeing as some folks +# install NCCL in the same location as the CUDA toolkit. +# See https://github.com/caffe2/caffe2/issues/1601 + +set(NCCL_ROOT_DIR $ENV{NCCL_ROOT_DIR} CACHE PATH "Folder contains NVIDIA NCCL") + +find_path(NCCL_INCLUDE_DIR + NAMES nccl.h + HINTS + ${NCCL_INCLUDE_DIR} + ${NCCL_ROOT_DIR} + ${NCCL_ROOT_DIR}/include + ${CUDA_TOOLKIT_ROOT_DIR}/include) + +if ($ENV{USE_STATIC_NCCL}) + message(STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library") + set(NCCL_LIBNAME "libnccl_static.a") +else() + set(NCCL_LIBNAME "nccl") +endif() + +find_library(NCCL_LIBRARY + NAMES ${NCCL_LIBNAME} + HINTS + ${NCCL_LIB_DIR} + ${NCCL_ROOT_DIR} + ${NCCL_ROOT_DIR}/lib + ${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu + ${NCCL_ROOT_DIR}/lib64 + ${CUDA_TOOLKIT_ROOT_DIR}/lib64) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIR NCCL_LIBRARY) + +if (NCCL_FOUND) + set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIR}/nccl.h") + message(STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}") + file (STRINGS ${NCCL_HEADER_FILE} NCCL_MAJOR_VERSION_DEFINED + REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$" LIMIT_COUNT 1) + if (NCCL_MAJOR_VERSION_DEFINED) + string (REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+" "" + NCCL_MAJOR_VERSION ${NCCL_MAJOR_VERSION_DEFINED}) + message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}") + endif() + set(NCCL_INCLUDE_DIRS ${NCCL_INCLUDE_DIR}) + set(NCCL_LIBRARIES ${NCCL_LIBRARY}) + message(STATUS "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})") + mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES) +endif() diff --git a/runtime/cmake/brt_device_nccl.cmake b/runtime/cmake/brt_device_nccl.cmake new file mode 100644 index 000000000..956b1161a --- /dev/null +++ b/runtime/cmake/brt_device_nccl.cmake @@ -0,0 +1,24 @@ +file(GLOB_RECURSE brt_device_nccl_srcs CONFIGURE_DEPENDS + "${BRT_INCLUDE_DIR}/brt/backends/nccl/device/*.h" + "${LIB_ROOT}/backends/nccl/device/*.cc" +) + +source_group(TREE ${REPO_ROOT} FILES ${brt_device_nccl_srcs}) + +brt_add_object_library(brt_device_nccl ${brt_device_nccl_srcs}) +target_link_libraries(brt_device_nccl brt_device_cuda) +target_link_libraries(brt_device_nccl ${NCCL_LIBRARIES}) +brt_add_include_to_target(brt_device_nccl brt_framework brt_common) +set_target_properties(brt_device_nccl PROPERTIES FOLDER "Brt") + +# In order to find the shared provider libraries we need to add the origin to the rpath for all executables we build +# For the shared brt library, this is set in brt.cmake through CMAKE_SHARED_LINKER_FLAGS +# But our test files don't use the shared library so this must be set for them. +# For Win32 it generates an absolute path for shared providers based on the location of the executable/brt.dll +if (UNIX AND NOT APPLE) + set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -Wl,-rpath='$ORIGIN'") +endif() + +install( + DIRECTORY "${BRT_INCLUDE_DIR}/brt/backends/nccl/device" + DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/brt/backends/nccl") diff --git a/runtime/cmake/brt_framework.cmake b/runtime/cmake/brt_framework.cmake index a2fa61506..007ffc3e2 100644 --- a/runtime/cmake/brt_framework.cmake +++ b/runtime/cmake/brt_framework.cmake @@ -1,9 +1,12 @@ file(GLOB_RECURSE brt_framework_srcs CONFIGURE_DEPENDS "${BRT_INCLUDE_DIR}/brt/core/context/*.h" + "${BRT_INCLUDE_DIR}/brt/core/distributed/*.h" "${BRT_INCLUDE_DIR}/brt/core/framework/*.h" "${BRT_INCLUDE_DIR}/brt/core/session/*.h" "${LIB_ROOT}/core/context/*.h" "${LIB_ROOT}/core/context/*.cc" + "${LIB_ROOT}/core/distributed/*.h" + "${LIB_ROOT}/core/distributed/*.cc" "${LIB_ROOT}/core/framework/*.h" "${LIB_ROOT}/core/framework/*.cc" "${LIB_ROOT}/core/session/*.h" diff --git a/runtime/cmake/brt_provider_nccl.cmake b/runtime/cmake/brt_provider_nccl.cmake new file mode 100644 index 000000000..882d337fb --- /dev/null +++ b/runtime/cmake/brt_provider_nccl.cmake @@ -0,0 +1,27 @@ +set(brt_all_includes brt_common brt_framework brt_provider_cuda) + +file(GLOB_RECURSE brt_nccl_provider_srcs CONFIGURE_DEPENDS + "${BRT_INCLUDE_DIR}/brt/backends/nccl/providers/*.h" + "${LIB_ROOT}/backends/nccl/providers/*.h" + "${LIB_ROOT}/backends/nccl/providers/*.cc" +) + + +list(APPEND brt_all_providers_srcs ${brt_nccl_provider_srcs}) +list(APPEND brt_all_includes brt_device_nccl) + +source_group(TREE ${REPO_ROOT} FILES ${brt_nccl_provider_srcs}) + +brt_add_object_library(brt_provider_nccl ${brt_nccl_provider_srcs}) + +target_link_libraries(brt_provider_nccl brt_provider_cuda) +target_link_libraries(brt_provider_nccl brt_device_nccl) +target_link_libraries(brt_provider_nccl ${NCCL_LIBRARIES}) +brt_add_include_to_target(brt_provider_nccl ${brt_all_includes}) +set_target_properties(brt_provider_nccl PROPERTIES LINKER_LANGUAGE CXX) +set_target_properties(brt_provider_nccl PROPERTIES FOLDER "Brt") + + +install( + DIRECTORY "${BRT_INCLUDE_DIR}/brt/backends/nccl/providers" + DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/brt/backends/nccl") diff --git a/runtime/cmake/brt_unittests.cmake b/runtime/cmake/brt_unittests.cmake index 2456b1e96..252fd5d8b 100644 --- a/runtime/cmake/brt_unittests.cmake +++ b/runtime/cmake/brt_unittests.cmake @@ -16,6 +16,10 @@ if(brt_USE_CUDA) list(APPEND brt_test_common_src_patterns "${TEST_SRC_DIR}/include/brt/test/common/cuda/*.h") endif() +if(brt_USE_NCCL) + list(APPEND brt_test_common_src_patterns "${TEST_SRC_DIR}/include/brt/test/common/nccl/*.h") +endif() + file(GLOB brt_test_common_src CONFIGURE_DEPENDS ${brt_test_common_src_patterns} ) @@ -53,6 +57,16 @@ file(GLOB brt_test_session_src CONFIGURE_DEPENDS ${brt_test_session_src_patterns} ) +## test distributed +set(brt_test_distributed_src_patterns + "${TEST_SRC_DIR}/distributed/*.cc" + "${TEST_SRC_DIR}/distributed/*.h" +) + +file(GLOB brt_test_distributed_src CONFIGURE_DEPENDS + ${brt_test_distributed_src_patterns} +) + ## test providers set(brt_test_providers_src "") @@ -86,6 +100,20 @@ if(brt_USE_CUDA) list(APPEND brt_test_providers_src ${brt_test_cuda_provider_src}) endif() +### test nccl providers +if(brt_USE_NCCL) + set(brt_test_nccl_provider_src_patterns + "${TEST_SRC_DIR}/backends/nccl/providers/*.cc" + "${TEST_SRC_DIR}/backends/nccl/providers/*.h" + ) + + file(GLOB brt_test_nccl_provider_src CONFIGURE_DEPENDS + ${brt_test_nccl_provider_src_patterns} + ) + + list(APPEND brt_test_providers_src ${brt_test_nccl_provider_src}) +endif() + ## test devices set(brt_test_devices_src "") @@ -118,6 +146,20 @@ if(brt_USE_CUDA) list(APPEND brt_test_devices_src ${brt_test_cuda_device_src}) endif() +### test nccl device +if(brt_USE_NCCL) + set(brt_test_nccl_device_src_patterns + "${TEST_SRC_DIR}/backends/nccl/device/*.cc" + "${TEST_SRC_DIR}/backends/nccl/device/*.h" + ) + + file(GLOB brt_test_nccl_device_src CONFIGURE_DEPENDS + ${brt_test_nccl_device_src_patterns} + ) + + list(APPEND brt_test_devices_src ${brt_test_nccl_device_src}) +endif() + ## include all src's set(all_test ${brt_test_common_src} @@ -125,6 +167,7 @@ set(all_test ${brt_test_framework_src} ${brt_test_ir_src} ${brt_test_session_src} + ${brt_test_distributed_src} ${brt_test_providers_src} ${brt_unittest_main_src} ) diff --git a/runtime/include/brt/backends/cuda/device/common/util.h b/runtime/include/brt/backends/cuda/device/common/util.h index 3ea5955cb..0441bdb73 100644 --- a/runtime/include/brt/backends/cuda/device/common/util.h +++ b/runtime/include/brt/backends/cuda/device/common/util.h @@ -19,7 +19,6 @@ #include "brt/backends/cuda/device/common/cuda_call.h" #include "brt/backends/cuda/device/cuda_work_queue.h" -#include "brt/backends/cuda/providers/default/tensor_generate/rng_state_context.h" #include "brt/core/common/status.h" #include "brt/core/context/execution_context.h" #include "brt/core/context/execution_frame.h" @@ -33,7 +32,6 @@ #define BRT_CUBLAS_HANDLE_NAME "cublasHandle" #define BRT_CUDNN_HANDLE_NAME "cudnnHandle" #define BRT_CURAND_GENERATOR_NAME "curandGenerator" -#define BRT_RNG_STATE_HANDLE_NAME "rngStateHandle" namespace brt { namespace cuda { @@ -170,47 +168,5 @@ inline common::Status DeleteCurandGenerator(const brt::ExecutionContext &ctx) { return brt::common::Status::OK(); } -//===----------------------------------------------------------------------===// -// RNGStateHandle Util -// TODO : move to common utility. -//===----------------------------------------------------------------------===// - -inline rngStateHandle_t GetRNGStateHandle(const brt::ExecutionContext &ctx) { - brt::ExecutionFrame::StateInfo &state_info = ctx.frame_state_info; - size_t offset = state_info.GetStateOffset(BRT_RNG_STATE_HANDLE_NAME); - return static_cast(ctx.exec_frame->GetState(offset)); -} - -inline common::Status CreateRNGStateHandle(const brt::ExecutionContext &ctx) { - brt::ExecutionFrame::StateInfo &state_info = ctx.frame_state_info; - return state_info.CreateStateIfNotExist( - BRT_RNG_STATE_HANDLE_NAME, ctx.exec_frame, []() { - rngStateHandle_t handle = new rngStateContext(); - return handle; - }); -} - -inline rngStateHandle_t -GetOrCreateRNGStateHandle(const brt::ExecutionContext &ctx) { - brt::ExecutionFrame::StateInfo &state_info = ctx.frame_state_info; - if (!state_info.HasState(BRT_RNG_STATE_HANDLE_NAME)) { - BRT_ENFORCE(CreateRNGStateHandle(ctx) == common::Status::OK()); - } - return GetRNGStateHandle(ctx); -} - -inline common::Status DeleteRNGStateHandle(const brt::ExecutionContext &ctx) { - brt::ExecutionFrame::StateInfo &state_info = ctx.frame_state_info; - size_t offset = state_info.GetStateOffset(BRT_RNG_STATE_HANDLE_NAME); - void *ptr = ctx.exec_frame->GetAndResetState(offset); - if (ptr != nullptr) { - rngStateHandle_t handle = static_cast(ptr); - if (handle != nullptr) { - delete handle; - } - } - return brt::common::Status::OK(); -} - } // namespace cuda } // namespace brt diff --git a/runtime/include/brt/backends/device_api.h b/runtime/include/brt/backends/device_api.h index 8998fb3d4..6c76149ca 100644 --- a/runtime/include/brt/backends/device_api.h +++ b/runtime/include/brt/backends/device_api.h @@ -1,5 +1,4 @@ -//===- device_api.h -----------------------------------------------*--- C++ -//-*-===// +//===- device_api.h -------------------------------------------*--- C++ -*-===// // // Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. // Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,6 +19,8 @@ #include +namespace brt { + enum class DeviceType { CPU = 1, CUDA }; struct Device { @@ -43,3 +44,5 @@ struct DeviceAPI { MemcpyD2DFunc MemcpyD2D; SetDeviceFunc SetDevice; }; + +} // namespace brt \ No newline at end of file diff --git a/runtime/include/brt/backends/nccl/device/d_context_nccl.h b/runtime/include/brt/backends/nccl/device/d_context_nccl.h new file mode 100644 index 000000000..9b778bae1 --- /dev/null +++ b/runtime/include/brt/backends/nccl/device/d_context_nccl.h @@ -0,0 +1,31 @@ +// Copyright (c) Megvii Inc. +// Licensed under Apache License, Version 2.0 +// =========================================================================== +// Modification Copyright 2022 ByteDance Ltd. and/or its affiliates. + +#pragma once + +#include "brt/core/distributed/d_context.h" + +#include + +#include +#include +#include + +namespace brt { + +class CudaContext : public DContext { +public: + CudaContext(cudaStream_t stream) : m_stream{stream} {} + static std::shared_ptr make(cudaStream_t stream) { + return std::make_shared(stream); + } + std::string type() const override { return "BRT_CTX_CUDA"; } + cudaStream_t get_stream() { return m_stream; } + +private: + cudaStream_t m_stream; +}; + +} // namespace brt diff --git a/runtime/include/brt/backends/nccl/device/distributed_backend_nccl.h b/runtime/include/brt/backends/nccl/device/distributed_backend_nccl.h new file mode 100644 index 000000000..4a439b789 --- /dev/null +++ b/runtime/include/brt/backends/nccl/device/distributed_backend_nccl.h @@ -0,0 +1,72 @@ +// Copyright (c) Megvii Inc. +// Licensed under Apache License, Version 2.0 +// =========================================================================== +// Modification Copyright 2022 ByteDance Ltd. and/or its affiliates. + +#pragma once + +#include "brt/core/common/status.h" +#include "brt/core/distributed/distributed_backend.h" +#include "brt/core/framework/dtype.h" + +namespace brt { + +class DistributedBackendNCCLPrivate; + +// Distributed Backend implemented by nccl +// collective communications are performed asynchronously +class DistributedBackendNCCL : public DistributedBackend { +public: + DistributedBackendNCCL(int nranks, int rank); + + ~DistributedBackendNCCL(); + + common::Status do_init() override; + common::Status do_init(BcastCallback cb) override; + + common::Status _send(const void *sendbuff, size_t size, uint32_t rank, + std::shared_ptr ctx) override; + + common::Status _recv(void *recvbuff, size_t size, uint32_t rank, + std::shared_ptr ctx) override; + + common::Status scatter(const void *sendbuff, void *recvbuff, size_t recvlen, + DTypeEnum dtype, uint32_t root, + std::shared_ptr ctx) override; + + common::Status gather(const void *sendbuff, void *recvbuff, size_t sendlen, + DTypeEnum dtype, uint32_t root, + std::shared_ptr ctx) override; + + common::Status all_to_all(const void *sendbuff, void *recvbuff, size_t len, + DTypeEnum dtype, + std::shared_ptr ctx) override; + + common::Status all_gather(const void *sendbuff, void *recvbuff, + size_t sendlen, DTypeEnum dtype, + std::shared_ptr ctx) override; + + common::Status all_reduce(const void *sendbuff, void *recvbuff, size_t len, + DTypeEnum dtype, ReduceOp op, + std::shared_ptr ctx) override; + + common::Status reduce_scatter(const void *sendbuff, void *recvbuff, + size_t recvlen, DTypeEnum dtype, ReduceOp op, + std::shared_ptr ctx) override; + + common::Status broadcast(const void *sendbuff, void *recvbuff, size_t len, + DTypeEnum dtype, uint32_t root, + std::shared_ptr ctx) override; + + common::Status reduce(const void *sendbuff, void *recvbuff, size_t len, + DTypeEnum dtype, ReduceOp op, uint32_t root, + std::shared_ptr ctx) override; + + common::Status group_start() override; + common::Status group_end() override; + +private: + std::unique_ptr m_nccl; +}; + +} // namespace brt diff --git a/runtime/include/brt/backends/nccl/device/utils.h b/runtime/include/brt/backends/nccl/device/utils.h new file mode 100644 index 000000000..7f47dd85f --- /dev/null +++ b/runtime/include/brt/backends/nccl/device/utils.h @@ -0,0 +1,32 @@ +// Copyright (c) Megvii Inc. +// Licensed under Apache License, Version 2.0 +// =========================================================================== +// Modification Copyright 2022 ByteDance Ltd. and/or its affiliates. + +#pragma once + +#include "brt/core/common/common.h" +#include "brt/core/common/enums.h" +#include "brt/core/framework/dtype.h" + +#include + +#include "nccl.h" + +namespace brt { + +#define NCCL_ASSERT(expr) \ + do { \ + ncclResult_t result = (expr); \ + if (result != ncclSuccess) { \ + BRT_LOGS_DEFAULT(ERROR) \ + << "nccl error [" << result << "]: " << ncclGetErrorString(result); \ + BRT_THROW("nccl error"); \ + } \ + } while (0); + +ncclDataType_t get_nccl_dtype(const DTypeEnum dtype); + +ncclRedOp_t get_nccl_reduce_op(const ReduceOp red_op); + +} // namespace brt diff --git a/runtime/include/brt/backends/nccl/providers/nccl_provider.h b/runtime/include/brt/backends/nccl/providers/nccl_provider.h new file mode 100644 index 000000000..6e9dda5c0 --- /dev/null +++ b/runtime/include/brt/backends/nccl/providers/nccl_provider.h @@ -0,0 +1,46 @@ +//===- nccl_provider.h ----------------------------------------*--- C++ -*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "brt/backends/common.h" +#include "brt/backends/nccl/device/distributed_backend_nccl.h" +#include "brt/core/common/status.h" +#include "brt/core/distributed/distributed_session.h" +#include "brt/core/framework/execution_provider.h" + +#include + +namespace brt { + +class NCCLExecutionProvider : public ExecutionProvider { +public: + explicit NCCLExecutionProvider(const std::string &name, int nranks, int rank, + const std::string &ip, int port); + + DistributedBackendNCCL *GetDistributedBackend() { + return nccl_backend_.get(); + } + +protected: + std::unique_ptr nccl_backend_; +}; + +common::Status DefaultNCCLExecutionProviderFactory(DistributedSession *session, + int local_rank); + +} // namespace brt \ No newline at end of file diff --git a/runtime/include/brt/backends/nccl/providers/op_registration.h b/runtime/include/brt/backends/nccl/providers/op_registration.h new file mode 100644 index 000000000..11a7b7c9f --- /dev/null +++ b/runtime/include/brt/backends/nccl/providers/op_registration.h @@ -0,0 +1,27 @@ +//===- op_registration.h --------------------------------------*--- C++ -*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#pragma once + +namespace brt { +class KernelRegistry; +namespace cuda { + +void RegisterNCCLOps(KernelRegistry *registry); + +} // namespace cuda +} // namespace brt diff --git a/runtime/include/brt/backends/rng_state_context.h b/runtime/include/brt/backends/rng_state_context.h new file mode 100644 index 000000000..69f1972fb --- /dev/null +++ b/runtime/include/brt/backends/rng_state_context.h @@ -0,0 +1,86 @@ +//===- rng_state_context.h ------------------------------------*--- C++ -*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "brt/core/common/status.h" +#include "brt/core/context/execution_context.h" +#include "brt/core/context/execution_frame.h" + +#define BRT_RNG_STATE_HANDLE_NAME "rngStateHandle" + +namespace brt { + +class RNGStateContext { +private: + int64_t seed; + int64_t offset; + +public: + explicit RNGStateContext() : seed(0), offset(0) {} + + int64_t getSeed() { return seed; } + + int64_t nextOffset() { return offset++; } + + void setSeed(int64_t seed) { this->seed = seed; } +}; + +using rngStateHandle_t = RNGStateContext *; + +//===----------------------------------------------------------------------===// +// RNGStateHandle Util +//===----------------------------------------------------------------------===// + +inline rngStateHandle_t GetRNGStateHandle(const brt::ExecutionContext &ctx) { + brt::ExecutionFrame::StateInfo &state_info = ctx.frame_state_info; + size_t offset = state_info.GetStateOffset(BRT_RNG_STATE_HANDLE_NAME); + return static_cast(ctx.exec_frame->GetState(offset)); +} + +inline common::Status CreateRNGStateHandle(const brt::ExecutionContext &ctx) { + brt::ExecutionFrame::StateInfo &state_info = ctx.frame_state_info; + return state_info.CreateStateIfNotExist( + BRT_RNG_STATE_HANDLE_NAME, ctx.exec_frame, []() { + rngStateHandle_t handle = new RNGStateContext(); + return handle; + }); +} + +inline rngStateHandle_t +GetOrCreateRNGStateHandle(const brt::ExecutionContext &ctx) { + brt::ExecutionFrame::StateInfo &state_info = ctx.frame_state_info; + if (!state_info.HasState(BRT_RNG_STATE_HANDLE_NAME)) { + BRT_ENFORCE(CreateRNGStateHandle(ctx) == common::Status::OK()); + } + return GetRNGStateHandle(ctx); +} + +inline common::Status DeleteRNGStateHandle(const brt::ExecutionContext &ctx) { + brt::ExecutionFrame::StateInfo &state_info = ctx.frame_state_info; + size_t offset = state_info.GetStateOffset(BRT_RNG_STATE_HANDLE_NAME); + void *ptr = ctx.exec_frame->GetAndResetState(offset); + if (ptr != nullptr) { + rngStateHandle_t handle = static_cast(ptr); + if (handle != nullptr) { + delete handle; + } + } + return brt::common::Status::OK(); +} + +} // namespace brt \ No newline at end of file diff --git a/runtime/include/brt/core/common/enums.h b/runtime/include/brt/core/common/enums.h new file mode 100644 index 000000000..adba5feba --- /dev/null +++ b/runtime/include/brt/core/common/enums.h @@ -0,0 +1,31 @@ +//===- enums.h ------------------------------------------------*--- C++ -*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace brt { + +typedef enum { + BRT_SUM = 0, + BRT_MAX = 1, + BRT_MIN = 2, + BRT_REDUCEOP_COUNT = 3, +} ReduceOp; + +} // namespace brt \ No newline at end of file diff --git a/runtime/include/brt/core/context/execution_context.h b/runtime/include/brt/core/context/execution_context.h index b6b5a2f3d..d16226a78 100644 --- a/runtime/include/brt/core/context/execution_context.h +++ b/runtime/include/brt/core/context/execution_context.h @@ -18,6 +18,7 @@ #pragma once #include "brt/core/context/execution_frame.h" +#include "brt/core/distributed/distributed_backend.h" namespace brt { @@ -38,12 +39,14 @@ struct ExecutionContext { WorkQueue *work_queue; ExecutionFrame::StateInfo &frame_state_info; EventListenerManager *event_listener_manager; + DistributedBackend *distributed_backend; ExecutionContext(ExecutionFrame *frame, WorkQueue *wq, ExecutionFrame::StateInfo &fs_info, - EventListenerManager *event_mgr) + EventListenerManager *event_mgr, + DistributedBackend *backend = nullptr) : exec_frame(frame), work_queue(wq), frame_state_info(fs_info), - event_listener_manager(event_mgr) {} + event_listener_manager(event_mgr), distributed_backend(backend) {} }; } // namespace brt diff --git a/runtime/include/brt/core/distributed/d_context.h b/runtime/include/brt/core/distributed/d_context.h new file mode 100644 index 000000000..6698ea843 --- /dev/null +++ b/runtime/include/brt/core/distributed/d_context.h @@ -0,0 +1,21 @@ +// Copyright (c) Megvii Inc. +// Licensed under Apache License, Version 2.0 +// =========================================================================== +// Modification Copyright 2022 ByteDance Ltd. and/or its affiliates. + +#pragma once + +#include + +namespace brt { + +// DContext is an abstraction of communication contexts (e.g. cuda stream) +// on different platforms, a context should be passed as a parameter when +// a communicator operation is called +class DContext { +public: + virtual std::string type() const = 0; + virtual ~DContext() = default; +}; + +} // namespace brt \ No newline at end of file diff --git a/runtime/include/brt/core/distributed/distributed_backend.h b/runtime/include/brt/core/distributed/distributed_backend.h new file mode 100644 index 000000000..d219e12b1 --- /dev/null +++ b/runtime/include/brt/core/distributed/distributed_backend.h @@ -0,0 +1,123 @@ +// Copyright (c) Megvii Inc. +// Licensed under Apache License, Version 2.0 +// =========================================================================== +// Modification Copyright 2022 ByteDance Ltd. and/or its affiliates. + +#pragma once + +#include +#include +#include +#include + +#include "brt/core/common/enums.h" +#include "brt/core/distributed/d_context.h" +#include "brt/core/distributed/rendezvous_socket.h" +#include "brt/core/framework/dtype.h" + +namespace brt { + +using BcastCallback = std::function; + +class DistributedBackend { +public: + DistributedBackend(uint32_t nranks, uint32_t rank) + : m_nranks(nranks), m_rank(rank) {} + + // get the number of all ranks + uint32_t nranks() { return m_nranks; } + + // get the rank of this process + uint32_t rank() { return m_rank; } + + // establish connection with server + common::Status init(const char *master_ip, int port); + common::Status init(BcastCallback cb); + + // send data to another communicator in the group + // implemented in the subclass _send() + common::Status send(const void *sendbuff, size_t len, DTypeEnum dtype, + uint32_t rank, std::shared_ptr ctx); + + // receive data from another communicator in the group + // implemented in the subclass _recv() + common::Status recv(void *recvbuf, size_t len, DTypeEnum dtype, uint32_t rank, + std::shared_ptr ctx); + + // implemented in the subclass and called in init() + virtual common::Status do_init() = 0; + virtual common::Status do_init(BcastCallback cb) = 0; + + // TODO: enhance api to support arbitary device groups + // the length of sendbuff = recvlen * m_nranks + // the length of recvbuff = recvlen + virtual common::Status scatter(const void *sendbuff, void *recvbuff, + size_t recvlen, DTypeEnum dtype, uint32_t root, + std::shared_ptr ctx) = 0; + + // TODO: enhance api to support arbitary device groups + // the length of sendbuff = sendlen + // the length of recvbuff = sendlen * m_nranks + virtual common::Status gather(const void *sendbuff, void *recvbuff, + size_t sendlen, DTypeEnum dtype, uint32_t root, + std::shared_ptr ctx) = 0; + + // TODO: enhance api to support arbitary device groups + // the length of sendbuff = the length of recvbuff = len * m_nranks + virtual common::Status all_to_all(const void *sendbuff, void *recvbuff, + size_t len, DTypeEnum dtype, + std::shared_ptr ctx) = 0; + + // TODO: enhance api to support arbitary device groups + // the length of sendbuff = sendlen + // the length of recvbuff = sendlen * m_nranks + virtual common::Status all_gather(const void *sendbuff, void *recvbuff, + size_t sendlen, DTypeEnum dtype, + std::shared_ptr ctx) = 0; + // TODO: enhance api to support arbitary device groups + // the length of sendbuff = the length of recvbuff = len + virtual common::Status all_reduce(const void *sendbuff, void *recvbuff, + size_t len, DTypeEnum dtype, ReduceOp op, + std::shared_ptr ctx) = 0; + + // TODO: enhance api to support arbitary device groups + // the length of sendbuff = recvlen * m_nranks + // the length of recvbuff = recvlen + virtual common::Status reduce_scatter(const void *sendbuff, void *recvbuff, + size_t recvlen, DTypeEnum dtype, + ReduceOp op, + std::shared_ptr ctx) = 0; + + // TODO: enhance api to support arbitary device groups + // the length of sendbuff = the length of recvbuff = len + virtual common::Status broadcast(const void *sendbuff, void *recvbuff, + size_t len, DTypeEnum dtype, uint32_t root, + std::shared_ptr ctx) = 0; + + // TODO: enhance api to support arbitary device groups + // the length of sendbuff = the length of recvbuff = len + virtual common::Status reduce(const void *sendbuff, void *recvbuff, + size_t len, DTypeEnum dtype, ReduceOp op, + uint32_t root, + std::shared_ptr ctx) = 0; + + // mark the begin of a series of (send recv) + virtual common::Status group_start() = 0; + // mark the end of a series of (send recv) + virtual common::Status group_end() = 0; + +protected: + uint32_t m_nranks; + uint32_t m_rank; + std::shared_ptr m_client; + + // send data to another communicator in the group + virtual common::Status _send(const void *sendbuff, size_t size, uint32_t rank, + std::shared_ptr ctx) = 0; + + // receive data from another communicator in the group + virtual common::Status _recv(void *recvbuf, size_t size, uint32_t rank, + std::shared_ptr ctx) = 0; +}; + +} // namespace brt diff --git a/runtime/include/brt/core/distributed/distributed_session.h b/runtime/include/brt/core/distributed/distributed_session.h new file mode 100644 index 000000000..cf9ff4b54 --- /dev/null +++ b/runtime/include/brt/core/distributed/distributed_session.h @@ -0,0 +1,80 @@ +//===- distributed_session.h ----------------------------------*--- C++ -*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "brt/backends/device_api.h" +#include "brt/core/common/status.h" +#include "brt/core/distributed/distributed_backend.h" +#include "brt/core/framework/dtype.h" +#include "brt/core/session/session.h" +#include +#include +#include +#include +#include +#include + +namespace brt { + +// forward decl +class ExecutionPlan; +class ExecutionProvider; +class IAllocator; +class OpKernelInfo; +class RequestContext; +class WorkQueue; + +namespace ir { +class IRHandle; +} + +class DistributedSession : public Session { +public: + DistributedSession(int rank, int nranks, const std::string &host, int port); + + virtual ~DistributedSession(); + + common::Status LoadConfig(const std::vector &config, + std::string &ir_url); + + void SetDistributedBackend(DistributedBackend *backend) { + distributed_backend_ = backend; + } + + common::Status Run(RequestContext &request); + + common::Status NewRequestContext(std::unique_ptr *request, + WorkQueue *work_queue = nullptr); + + int GetRank() const { return rank_; } + + int GetNRanks() const { return nranks_; } + + const std::string &GetHost() const { return host_; } + + int GetPort() const { return port_; } + +protected: + DistributedBackend *distributed_backend_; + int rank_; + int nranks_; + std::string host_; + int port_; +}; + +} // namespace brt diff --git a/runtime/include/brt/core/distributed/rendezvous_socket.h b/runtime/include/brt/core/distributed/rendezvous_socket.h new file mode 100644 index 000000000..632acd6d1 --- /dev/null +++ b/runtime/include/brt/core/distributed/rendezvous_socket.h @@ -0,0 +1,46 @@ +// Copyright (c) Megvii Inc. +// Licensed under Apache License, Version 2.0 +// =========================================================================== +// Modification Copyright 2022 ByteDance Ltd. and/or its affiliates. + +#pragma once + +#include "brt/core/common/status.h" + +#include + +namespace brt { + +int GetFreePort(); + +common::Status CreateServer(uint32_t nranks, int port); + +class RendezvousSocket { +public: + RendezvousSocket(unsigned nranks, unsigned rank); + + ~RendezvousSocket(); + + common::Status connect(const char *master_ip, int port); + + // block until all ranks reach this barrier + common::Status barrier(); + + // the length of send_buff = the length of recv_buff = len + common::Status broadcast(const void *send_buff, void *recv_buff, + size_t send_len, unsigned root); + + // the length of send_buff = sendlen + // the length of recv_buff = send_len * m_nranks + common::Status allgather(const void *send_buff, void *recv_buff, + size_t send_len); + +private: + uint32_t nranks_; + uint32_t rank_; + bool connected_ = false; + int conn_; + std::mutex mutex_; +}; + +} // namespace brt \ No newline at end of file diff --git a/runtime/include/brt/core/framework/execution_plan.h b/runtime/include/brt/core/framework/execution_plan.h index b31ed7a93..886b42fa3 100644 --- a/runtime/include/brt/core/framework/execution_plan.h +++ b/runtime/include/brt/core/framework/execution_plan.h @@ -57,7 +57,8 @@ class ExecutionPlan { virtual common::Status EpiloguePerSession() = 0; - virtual void CreateWorkQueue(std::unique_ptr *wq) = 0; + virtual void CreateWorkQueue(std::unique_ptr *wq, + int rank = 0) = 0; virtual void CreateExecutinFrame(std::unique_ptr *frame) = 0; @@ -116,7 +117,7 @@ class StaticBRTExecutionPlan final : public ExecutionPlan { common::Status EpiloguePerSession() override; - void CreateWorkQueue(std::unique_ptr *wq) override; + void CreateWorkQueue(std::unique_ptr *wq, int rank = 0) override; void CreateExecutinFrame(std::unique_ptr *frame) override; diff --git a/runtime/include/brt/core/session/request_context.h b/runtime/include/brt/core/session/request_context.h index 5da7865cc..5ed19b045 100644 --- a/runtime/include/brt/core/session/request_context.h +++ b/runtime/include/brt/core/session/request_context.h @@ -18,6 +18,7 @@ #pragma once #include "brt/core/common/status.h" +#include "brt/core/distributed/distributed_session.h" #include "brt/core/framework/event.h" #include "brt/core/session/session.h" #include @@ -65,6 +66,7 @@ class RequestContext { private: friend Session; + friend DistributedSession; /** * Private RequestContext constructor diff --git a/runtime/lib/backends/cpu/providers/default/cpu_provider.cc b/runtime/lib/backends/cpu/providers/default/cpu_provider.cc index 8a3733e66..edf7a4ff9 100644 --- a/runtime/lib/backends/cpu/providers/default/cpu_provider.cc +++ b/runtime/lib/backends/cpu/providers/default/cpu_provider.cc @@ -19,15 +19,16 @@ #include "./custom_call/tf_equal.h" #include "./custom_call/tf_select.h" -#include "./custom_call/tf_stringToNumber.h" +#include "./custom_call/tf_string_to_number.h" #include "./custom_call/tf_where.h" #include "./custom_call/topk.h" #include "./llvm/jit.h" +#include "./math/elementwise_ops.h" #include "./shape/shape_compute.h" #include "./tensor_generate/fill.h" +#include "./tensor_generate/rng_state.h" #include "./typecvt/typecvt.h" #include "brt/backends/common.h" -#include "brt/backends/cpu/providers/default/math/elementwise_ops.h" // TODO move to another header #include "brt/core/framework/execution_provider.h" #include "brt/core/session/session.h" #include "half/half.hpp" @@ -111,6 +112,16 @@ BRT_STATIC_KERNEL_REGISTRATION( [](const brt::OpKernelInfo &info) -> std::shared_ptr { return std::make_shared(info); }); + registry->Register( + "GetSeed", + [](const brt::OpKernelInfo &info) -> std::shared_ptr { + return std::make_shared(info); + }); + registry->Register( + "NextOffset", + [](const brt::OpKernelInfo &info) -> std::shared_ptr { + return std::make_shared(info); + }); RegisterCommonBuiltinOps(registry); }); diff --git a/runtime/lib/backends/cpu/providers/default/custom_call/tf_stringToNumber.cc b/runtime/lib/backends/cpu/providers/default/custom_call/tf_string_to_number.cc similarity index 99% rename from runtime/lib/backends/cpu/providers/default/custom_call/tf_stringToNumber.cc rename to runtime/lib/backends/cpu/providers/default/custom_call/tf_string_to_number.cc index c3d9865f2..9adc7e9ee 100644 --- a/runtime/lib/backends/cpu/providers/default/custom_call/tf_stringToNumber.cc +++ b/runtime/lib/backends/cpu/providers/default/custom_call/tf_string_to_number.cc @@ -15,7 +15,7 @@ // //===----------------------------------------------------------------------===// -#include "./tf_stringToNumber.h" +#include "./tf_string_to_number.h" #include "brt/core/framework/op_accessor.h" #include #include diff --git a/runtime/lib/backends/cpu/providers/default/custom_call/tf_stringToNumber.h b/runtime/lib/backends/cpu/providers/default/custom_call/tf_string_to_number.h similarity index 100% rename from runtime/lib/backends/cpu/providers/default/custom_call/tf_stringToNumber.h rename to runtime/lib/backends/cpu/providers/default/custom_call/tf_string_to_number.h diff --git a/runtime/lib/backends/cpu/providers/default/math/elementwise_ops.cc b/runtime/lib/backends/cpu/providers/default/math/elementwise_ops.cc index aa7b56eaf..39a3d726c 100644 --- a/runtime/lib/backends/cpu/providers/default/math/elementwise_ops.cc +++ b/runtime/lib/backends/cpu/providers/default/math/elementwise_ops.cc @@ -15,7 +15,7 @@ // //===----------------------------------------------------------------------===// -#include "brt/backends/cpu/providers/default/math/elementwise_ops.h" +#include "./elementwise_ops.h" #include "brt/core/common/common.h" #include "brt/core/context/execution_context.h" diff --git a/runtime/include/brt/backends/cpu/providers/default/math/elementwise_ops.h b/runtime/lib/backends/cpu/providers/default/math/elementwise_ops.h similarity index 100% rename from runtime/include/brt/backends/cpu/providers/default/math/elementwise_ops.h rename to runtime/lib/backends/cpu/providers/default/math/elementwise_ops.h diff --git a/runtime/lib/backends/cpu/providers/default/tensor_generate/rng_state.cc b/runtime/lib/backends/cpu/providers/default/tensor_generate/rng_state.cc new file mode 100644 index 000000000..fca4f4b54 --- /dev/null +++ b/runtime/lib/backends/cpu/providers/default/tensor_generate/rng_state.cc @@ -0,0 +1,98 @@ +//===- rng_state.cc -------------------------------------------*--- C++ -*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#include "./rng_state.h" + +#include "brt/backends/rng_state_context.h" +#include "brt/core/context/work_queue.h" +#include "brt/core/framework/dtype.h" +#include "brt/core/framework/op_accessor.h" + +namespace brt { +namespace cpu { + +//===----------------------------------------------------------------------===// +// GetSeed Op Kernel +//===----------------------------------------------------------------------===// + +GetSeedOpKernel::GetSeedOpKernel(const OpKernelInfo &info) + : OpKernel(info, false, false, true, true) {} + +common::Status GetSeedOpKernel::RunImpl(const ExecutionContext &ctx) { + rngStateHandle_t rngStateHandle = GetOrCreateRNGStateHandle(ctx); + int64_t rngSeed = rngStateHandle->getSeed(); + OpAccessor accessor(info_, ctx.exec_frame); + DTypeEnum dtype = accessor.GetArgDTypeEnum(0); + void *device_p = accessor.GetArgAsyncValueRef(0); +#define CASE(D) \ + case DTypeEnum::D: { \ + using ctype = DTypeTraits::type_t; \ + DispatchHostTask(ctx.work_queue, { \ + *reinterpret_cast(device_p) = static_cast(rngSeed); \ + }); \ + return common::Status::OK(); \ + } + BRT_DISPATCH_NUMBER_TYPES(dtype, CASE) +#undef CASE +} + +common::Status GetSeedOpKernel::ProloguePerFrame(const ExecutionContext &) { + return common::Status::OK(); +} + +common::Status GetSeedOpKernel::EpiloguePerFrame(const ExecutionContext &ctx) { + DeleteRNGStateHandle(ctx); + return common::Status::OK(); +} + +//===----------------------------------------------------------------------===// +// NextOffset Op Kernel +//===----------------------------------------------------------------------===// + +NextOffsetOpKernel::NextOffsetOpKernel(const OpKernelInfo &info) + : OpKernel(info, false, false, true, true) {} + +common::Status NextOffsetOpKernel::RunImpl(const ExecutionContext &ctx) { + rngStateHandle_t rngStateHandle = GetOrCreateRNGStateHandle(ctx); + int64_t rngOffset = rngStateHandle->nextOffset(); + OpAccessor accessor(info_, ctx.exec_frame); + DTypeEnum dtype = accessor.GetArgDTypeEnum(0); + void *device_p = accessor.GetArgAsyncValueRef(0); +#define CASE(D) \ + case DTypeEnum::D: { \ + using ctype = DTypeTraits::type_t; \ + DispatchHostTask(ctx.work_queue, { \ + *reinterpret_cast(device_p) = static_cast(rngOffset); \ + }); \ + return common::Status::OK(); \ + } + BRT_DISPATCH_NUMBER_TYPES(dtype, CASE) +#undef CASE +} + +common::Status NextOffsetOpKernel::ProloguePerFrame(const ExecutionContext &) { + return common::Status::OK(); +} + +common::Status +NextOffsetOpKernel::EpiloguePerFrame(const ExecutionContext &ctx) { + DeleteRNGStateHandle(ctx); + return common::Status::OK(); +} + +} // namespace cpu +} // namespace brt \ No newline at end of file diff --git a/runtime/lib/backends/cpu/providers/default/tensor_generate/rng_state.h b/runtime/lib/backends/cpu/providers/default/tensor_generate/rng_state.h new file mode 100644 index 000000000..4ea628de3 --- /dev/null +++ b/runtime/lib/backends/cpu/providers/default/tensor_generate/rng_state.h @@ -0,0 +1,42 @@ +//===- rng_state.h --------------------------------------------*--- C++ -*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "brt/core/framework/dtype.h" +#include "brt/core/framework/op_kernel.h" + +namespace brt { +namespace cpu { +class GetSeedOpKernel final : public OpKernel { +public: + explicit GetSeedOpKernel(const OpKernelInfo &info); + common::Status RunImpl(const ExecutionContext &) override; + common::Status ProloguePerFrame(const ExecutionContext &) override; + common::Status EpiloguePerFrame(const ExecutionContext &) override; +}; + +class NextOffsetOpKernel final : public OpKernel { +public: + explicit NextOffsetOpKernel(const OpKernelInfo &info); + common::Status RunImpl(const ExecutionContext &) override; + common::Status ProloguePerFrame(const ExecutionContext &) override; + common::Status EpiloguePerFrame(const ExecutionContext &) override; +}; + +} // namespace cpu +} // namespace brt \ No newline at end of file diff --git a/runtime/lib/backends/cuda/providers/default/tensor_generate/rng_state.cc b/runtime/lib/backends/cuda/providers/default/tensor_generate/rng_state.cc index 408c680e0..3d2c15bbd 100644 --- a/runtime/lib/backends/cuda/providers/default/tensor_generate/rng_state.cc +++ b/runtime/lib/backends/cuda/providers/default/tensor_generate/rng_state.cc @@ -1,5 +1,4 @@ -//===- rng_state.cc -------------------------------------------------*--- C++ -//-*-===// +//===- rng_state.cc -------------------------------------------*--- C++ -*-===// // // Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. // Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,12 +16,14 @@ //===----------------------------------------------------------------------===// #include "./rng_state.h" -#include "brt/backends/cuda/device/common/util.h" + #include "brt/backends/cuda/device/cuda_allocator.h" #include "brt/backends/cuda/device/cuda_work_queue.h" +#include "brt/backends/rng_state_context.h" #include "brt/core/common/common.h" #include "brt/core/framework/op_accessor.h" #include +#include namespace brt { namespace cuda { diff --git a/runtime/lib/backends/cuda/providers/default/tensor_generate/rng_state.h b/runtime/lib/backends/cuda/providers/default/tensor_generate/rng_state.h index fe4c14a50..d77ff0cc5 100644 --- a/runtime/lib/backends/cuda/providers/default/tensor_generate/rng_state.h +++ b/runtime/lib/backends/cuda/providers/default/tensor_generate/rng_state.h @@ -1,5 +1,4 @@ -//===- rng_state.h -------------------------------------------------*--- C++ -//-*-===// +//===- rng_state.h --------------------------------------------*--- C++ -*-===// // // Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. // Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,7 +17,6 @@ #pragma once -#include "brt/backends/cuda/device/common/util.h" #include "brt/core/framework/dtype.h" #include "brt/core/framework/op_kernel.h" diff --git a/runtime/lib/backends/nccl/device/distributed_backend_nccl.cc b/runtime/lib/backends/nccl/device/distributed_backend_nccl.cc new file mode 100644 index 000000000..7f14723ff --- /dev/null +++ b/runtime/lib/backends/nccl/device/distributed_backend_nccl.cc @@ -0,0 +1,232 @@ +// Copyright (c) Megvii Inc. +// Licensed under Apache License, Version 2.0 +// =========================================================================== +// Modification Copyright 2022 ByteDance Ltd. and/or its affiliates. + +#include "brt/backends/nccl/device/distributed_backend_nccl.h" +#include "brt/backends/nccl/device/d_context_nccl.h" +#include "brt/backends/nccl/device/utils.h" +#include "brt/core/common/logging/logging.h" + +#include "nccl.h" +#include + +#define CHECK_LAUNCH_MODE \ + do { \ + const char *str = getenv("NCCL_LAUNCH_MODE"); \ + if (!str or strcmp(str, "PARALLEL") != 0) { \ + BRT_LOGS_DEFAULT(ERROR) \ + << "please set NCCL_LAUNCH_MODE to \"PARALLEL\"\n"; \ + BRT_THROW("nccl error"); \ + } \ + } while (0) + +namespace brt { + +class DistributedBackendNCCLPrivate { +public: + ncclComm_t m_comm; + ~DistributedBackendNCCLPrivate() { ncclCommDestroy(m_comm); } +}; + +DistributedBackendNCCL::DistributedBackendNCCL(int nranks, int rank) + : DistributedBackend(nranks, rank) {} + +DistributedBackendNCCL::~DistributedBackendNCCL() {} + +Status DistributedBackendNCCL::do_init() { + uint32_t root = 0; + ncclUniqueId uid; + if (m_rank == root) { + ncclGetUniqueId(&uid); + } + auto status = m_client->broadcast(&uid, &uid, NCCL_UNIQUE_ID_BYTES, root); + if (status != Status::OK()) + return status; + m_nccl = std::make_unique(); + NCCL_ASSERT(ncclCommInitRank(&m_nccl->m_comm, m_nranks, uid, m_rank)); + return Status::OK(); +} + +Status DistributedBackendNCCL::do_init(BcastCallback cb) { + uint32_t root = 0; + ncclUniqueId uid; + if (m_rank == root) { + ncclGetUniqueId(&uid); + } + cb(uid.internal, NCCL_UNIQUE_ID_BYTES); + m_nccl = std::make_unique(); + NCCL_ASSERT(ncclCommInitRank(&m_nccl->m_comm, m_nranks, uid, m_rank)); + return Status::OK(); +} + +Status DistributedBackendNCCL::_send(const void *sendbuff, size_t size, + uint32_t rank, + std::shared_ptr ctx) { + // check context type and get cuda stream + assert(ctx->type() == "BRT_CTX_CUDA" && "only cuda context supported"); + cudaStream_t stream = static_cast(ctx.get())->get_stream(); + // perform nccl send synchronously + NCCL_ASSERT(ncclSend(sendbuff, size, ncclChar, rank, m_nccl->m_comm, stream)); + return Status::OK(); +} + +Status DistributedBackendNCCL::_recv(void *recvbuff, size_t size, uint32_t rank, + std::shared_ptr ctx) { + // check context type and get cuda stream + assert(ctx->type() == "BRT_CTX_CUDA" && "only cuda context supported"); + cudaStream_t stream = static_cast(ctx.get())->get_stream(); + // perform nccl send synchronously + NCCL_ASSERT(ncclRecv(recvbuff, size, ncclChar, rank, m_nccl->m_comm, stream)); + return Status::OK(); +} + +Status DistributedBackendNCCL::scatter(const void *sendbuff, void *recvbuff, + size_t recvlen, DTypeEnum dtype, + uint32_t root, + std::shared_ptr ctx) { + // check context type and get cuda stream + assert(ctx->type() == "BRT_CTX_CUDA" && "only cuda context supported"); + cudaStream_t stream = static_cast(ctx.get())->get_stream(); + ncclDataType_t nccl_dtype = get_nccl_dtype(dtype); + CHECK_LAUNCH_MODE; + // perform nccl send/recv in a group + ncclGroupStart(); + if (m_rank == root) { + for (size_t r = 0; r < m_nranks; r++) { + const char *p = + (const char *)sendbuff + r * recvlen * GetDTypeByte(dtype); + NCCL_ASSERT(ncclSend((const void *)p, recvlen, nccl_dtype, r, + m_nccl->m_comm, stream)); + } + } + NCCL_ASSERT( + ncclRecv(recvbuff, recvlen, nccl_dtype, root, m_nccl->m_comm, stream)); + ncclGroupEnd(); + return Status::OK(); +} + +Status DistributedBackendNCCL::gather(const void *sendbuff, void *recvbuff, + size_t sendlen, DTypeEnum dtype, + uint32_t root, + std::shared_ptr ctx) { + // check context type and get cuda stream + assert(ctx->type() == "BRT_CTX_CUDA" && "only cuda context supported"); + cudaStream_t stream = static_cast(ctx.get())->get_stream(); + ncclDataType_t nccl_dtype = get_nccl_dtype(dtype); + CHECK_LAUNCH_MODE; + // perform nccl send/recv in a group + ncclGroupStart(); + if (m_rank == root) { + for (size_t r = 0; r < m_nranks; r++) { + char *p = (char *)recvbuff + r * sendlen * GetDTypeByte(dtype); + NCCL_ASSERT( + ncclRecv((void *)p, sendlen, nccl_dtype, r, m_nccl->m_comm, stream)); + } + } + NCCL_ASSERT( + ncclSend(sendbuff, sendlen, nccl_dtype, root, m_nccl->m_comm, stream)); + ncclGroupEnd(); + return Status::OK(); +} + +Status DistributedBackendNCCL::all_to_all(const void *sendbuff, void *recvbuff, + size_t len, DTypeEnum dtype, + std::shared_ptr ctx) { + // check context type and get cuda stream + assert(ctx->type() == "BRT_CTX_CUDA" && "only cuda context supported"); + cudaStream_t stream = static_cast(ctx.get())->get_stream(); + ncclDataType_t nccl_dtype = get_nccl_dtype(dtype); + CHECK_LAUNCH_MODE; + // perform nccl send/recv in a group + ncclGroupStart(); + for (size_t r = 0; r < m_nranks; r++) { + const char *p = (const char *)sendbuff + r * len * GetDTypeByte(dtype); + char *q = (char *)recvbuff + r * len * GetDTypeByte(dtype); + NCCL_ASSERT( + ncclSend((const void *)p, len, nccl_dtype, r, m_nccl->m_comm, stream)); + NCCL_ASSERT( + ncclRecv((void *)q, len, nccl_dtype, r, m_nccl->m_comm, stream)); + } + ncclGroupEnd(); + return Status::OK(); +} + +Status DistributedBackendNCCL::all_gather(const void *sendbuff, void *recvbuff, + size_t sendlen, DTypeEnum dtype, + std::shared_ptr ctx) { + // check context type and get cuda stream + assert(ctx->type() == "BRT_CTX_CUDA" && "only cuda context supported"); + cudaStream_t stream = static_cast(ctx.get())->get_stream(); + // perform all gather synchronously + NCCL_ASSERT(ncclAllGather(sendbuff, recvbuff, sendlen, get_nccl_dtype(dtype), + m_nccl->m_comm, stream)); + return Status::OK(); +} + +Status DistributedBackendNCCL::all_reduce(const void *sendbuff, void *recvbuff, + size_t len, DTypeEnum dtype, + ReduceOp op, + std::shared_ptr ctx) { + // check context type and get cuda stream + assert(ctx->type() == "BRT_CTX_CUDA" && "only cuda context supported"); + cudaStream_t stream = static_cast(ctx.get())->get_stream(); + // perform all reduce synchronously + NCCL_ASSERT(ncclAllReduce(sendbuff, recvbuff, len, get_nccl_dtype(dtype), + get_nccl_reduce_op(op), m_nccl->m_comm, stream)); + return Status::OK(); +} + +Status DistributedBackendNCCL::reduce_scatter(const void *sendbuff, + void *recvbuff, size_t recvlen, + DTypeEnum dtype, ReduceOp op, + std::shared_ptr ctx) { + // check context type and get cuda stream + assert(ctx->type() == "BRT_CTX_CUDA" && "only cuda context supported"); + cudaStream_t stream = static_cast(ctx.get())->get_stream(); + // perform reduce scatter synchronously + NCCL_ASSERT(ncclReduceScatter(sendbuff, recvbuff, recvlen, + get_nccl_dtype(dtype), get_nccl_reduce_op(op), + m_nccl->m_comm, stream)); + return Status::OK(); +} + +Status DistributedBackendNCCL::broadcast(const void *sendbuff, void *recvbuff, + size_t len, DTypeEnum dtype, + uint32_t root, + std::shared_ptr ctx) { + // check context type and get cuda stream + assert(ctx->type() == "BRT_CTX_CUDA" && "only cuda context supported"); + cudaStream_t stream = static_cast(ctx.get())->get_stream(); + // perform broadcast synchronously + NCCL_ASSERT(ncclBroadcast(sendbuff, recvbuff, len, get_nccl_dtype(dtype), + root, m_nccl->m_comm, stream)); + return Status::OK(); +} + +Status DistributedBackendNCCL::reduce(const void *sendbuff, void *recvbuff, + size_t len, DTypeEnum dtype, ReduceOp op, + uint32_t root, + std::shared_ptr ctx) { + // check context type and get cuda stream + assert(ctx->type() == "BRT_CTX_CUDA" && "only cuda context supported"); + cudaStream_t stream = static_cast(ctx.get())->get_stream(); + // perform reduce synchronously + NCCL_ASSERT(ncclReduce(sendbuff, recvbuff, len, get_nccl_dtype(dtype), + get_nccl_reduce_op(op), root, m_nccl->m_comm, stream)); + return Status::OK(); +} + +Status DistributedBackendNCCL::group_start() { + CHECK_LAUNCH_MODE; + ncclGroupStart(); + return Status::OK(); +} + +Status DistributedBackendNCCL::group_end() { + CHECK_LAUNCH_MODE; + ncclGroupEnd(); + return Status::OK(); +} + +} // namespace brt \ No newline at end of file diff --git a/runtime/lib/backends/nccl/device/utils.cc b/runtime/lib/backends/nccl/device/utils.cc new file mode 100644 index 000000000..c8aa9a655 --- /dev/null +++ b/runtime/lib/backends/nccl/device/utils.cc @@ -0,0 +1,51 @@ +// Copyright (c) Megvii Inc. +// Licensed under Apache License, Version 2.0 +// =========================================================================== +// Modification Copyright 2022 ByteDance Ltd. and/or its affiliates. + +#include "brt/backends/nccl/device/utils.h" +#include "brt/backends/nccl/device/d_context_nccl.h" +#include "brt/core/common/enums.h" +#include "brt/core/framework/dtype.h" + +namespace brt { + +ncclDataType_t get_nccl_dtype(const DTypeEnum dtype) { + switch (dtype) { + case DTypeEnum::Int8: + return ncclInt8; + case DTypeEnum::UInt8: + return ncclUint8; + case DTypeEnum::Int32: + return ncclInt32; + case DTypeEnum::UInt32: + return ncclUint32; + case DTypeEnum::Int64: + return ncclInt64; + case DTypeEnum::UInt64: + return ncclUint64; + case DTypeEnum::Float16: + return ncclFloat16; + case DTypeEnum::Float32: + return ncclFloat32; + case DTypeEnum::Float64: + return ncclFloat64; + default: + BRT_THROW("unknown dtype"); + } +} + +ncclRedOp_t get_nccl_reduce_op(const ReduceOp red_op) { + switch (red_op) { + case BRT_SUM: + return ncclSum; + case BRT_MAX: + return ncclMax; + case BRT_MIN: + return ncclMin; + default: + BRT_THROW("unknown reduce op"); + } +} + +} // namespace brt diff --git a/runtime/lib/backends/nccl/providers/nccl_provider.cc b/runtime/lib/backends/nccl/providers/nccl_provider.cc new file mode 100644 index 000000000..0b14fa729 --- /dev/null +++ b/runtime/lib/backends/nccl/providers/nccl_provider.cc @@ -0,0 +1,67 @@ +//===- nccl_provider.cc ---------------------------------------*--- C++ -*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#include "brt/backends/nccl/providers/nccl_provider.h" + +#include "brt/backends/common.h" +#include "brt/backends/cuda/device/common/cuda_call.h" +#include "brt/backends/cuda/device/cuda_allocator.h" +#include "brt/backends/nccl/providers/op_registration.h" +#include "brt/core/framework/kernel_registry.h" +#include +#include + +using namespace brt; +using namespace brt::common; + +namespace brt { + +namespace { + +// clang-format off +BRT_STATIC_KERNEL_REGISTRATION( + DeviceKind::CUDA, ProviderType::BRT, [](KernelRegistry *registry) { + cuda::RegisterNCCLOps(registry); + }); +// clang-format on + +} // namespace + +NCCLExecutionProvider::NCCLExecutionProvider(const std::string &name, + int nranks, int rank, + const std::string &ip, int port) + : ExecutionProvider(DeviceKind::CUDA, name) { + nccl_backend_ = std::make_unique(nranks, rank); + nccl_backend_->init(ip.c_str(), port); +} + +common::Status DefaultNCCLExecutionProviderFactory(DistributedSession *session, + int local_rank) { + BRT_CUDA_CHECK(cudaSetDevice(local_rank)); + // create a NCCL provider + int rank = session->GetRank(); + int nranks = session->GetNRanks(); + const std::string &host = session->GetHost(); + int port = session->GetPort(); + auto provider = std::make_unique( + ProviderType::BRT, nranks, rank, host, port); + session->SetDistributedBackend(provider->GetDistributedBackend()); + // give ownership to the session + return session->AddExecutionProvider(std::move(provider)); +} + +} // namespace brt diff --git a/runtime/lib/backends/nccl/providers/op_registration.cc b/runtime/lib/backends/nccl/providers/op_registration.cc new file mode 100644 index 000000000..88b562bca --- /dev/null +++ b/runtime/lib/backends/nccl/providers/op_registration.cc @@ -0,0 +1,43 @@ +//===- op_registration.cc -------------------------------------*--- C++ -*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#include "brt/backends/nccl/providers/op_registration.h" + +#include "./recv.h" +#include "./send.h" +#include "brt/core/framework/kernel_registry.h" + +namespace brt { +namespace cuda { + +void RegisterNCCLOps(KernelRegistry *registry) { + registry->Register( + "NCCLRecv_f32", + [](const brt::OpKernelInfo &info) -> std::shared_ptr { + auto kernel = std::shared_ptr(new cuda::Recv(info)); + return kernel; + }); + + registry->Register( + "NCCLSend_f32", + [](const brt::OpKernelInfo &info) -> std::shared_ptr { + auto kernel = std::shared_ptr(new cuda::Send(info)); + return kernel; + }); +} +} // namespace cuda +} // namespace brt diff --git a/runtime/lib/backends/nccl/providers/recv.cc b/runtime/lib/backends/nccl/providers/recv.cc new file mode 100644 index 000000000..4e55ec516 --- /dev/null +++ b/runtime/lib/backends/nccl/providers/recv.cc @@ -0,0 +1,71 @@ +//===- recv.cc ------------------------------------------------*--- C++ -*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#include "./recv.h" +#include "brt/backends/cuda/device/common/util.h" +#include "brt/backends/cuda/device/cuda_work_queue.h" +#include "brt/backends/nccl/device/d_context_nccl.h" +#include "brt/backends/nccl/device/distributed_backend_nccl.h" +#include "brt/core/context/execution_context.h" +#include "brt/core/context/execution_frame.h" +#include "brt/core/context/work_queue.h" +#include "brt/core/framework/dtype.h" +#include "brt/core/framework/op_accessor.h" +#include "brt/core/ir/ir.h" +#include "brt/core/ir/util.h" +#include "byteir/Dialect/Byre/ByreDialect.h" +#include +#include +#include +#include + +using namespace brt; +using namespace brt::common; +using namespace brt::ir; +using namespace mlir; + +namespace brt { +namespace cuda { + +// TODO: refine code and support various dtypes +template +common::Status Recv::RunImpl(const ExecutionContext &ctx) { + DistributedBackend *backend = ctx.distributed_backend; + assert(backend != nullptr); + DistributedBackendNCCL *nccl_backend = + static_cast(backend); + + OpAccessor accessor(info_, ctx.exec_frame); + const auto src_shape = accessor.GetArgShape(0); + auto elem_num = std::accumulate(src_shape.begin(), src_shape.end(), 1, + std::multiplies()); + T *src = reinterpret_cast(accessor.GetArgAsyncValueRef(0)); + int64_t rank = accessor.GetAttrAsInt("rank"); + + cudaStream_t stream = + static_cast(ctx.work_queue)->GetComputeStream(); + std::shared_ptr d_context = std::make_shared(stream); + nccl_backend->recv(src, elem_num, DTypeEnum::Float32, rank, d_context); + + return Status::OK(); +} + +// instantiate +template class Recv; + +} // namespace cuda +} // namespace brt diff --git a/runtime/include/brt/backends/cuda/providers/default/tensor_generate/rng_state_context.h b/runtime/lib/backends/nccl/providers/recv.h similarity index 66% rename from runtime/include/brt/backends/cuda/providers/default/tensor_generate/rng_state_context.h rename to runtime/lib/backends/nccl/providers/recv.h index 1f23f4377..bbcf0b4f8 100644 --- a/runtime/include/brt/backends/cuda/providers/default/tensor_generate/rng_state_context.h +++ b/runtime/lib/backends/nccl/providers/recv.h @@ -1,5 +1,4 @@ -//===- rng_state_context.h --------------------------------------*--- C++ -//-*-===// +//===- recv.h -------------------------------------------------*--- C++ -*-===// // // Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. // Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,25 +17,18 @@ #pragma once +#include "brt/core/framework/op_kernel.h" + namespace brt { namespace cuda { -class rngStateContext { -private: - int64_t seed; - int64_t offset; - +// recv synchronously +template class Recv final : public OpKernel { public: - explicit rngStateContext() : seed(0), offset(0) {} - - int64_t getSeed() { return seed; } + explicit Recv(const OpKernelInfo &info) : OpKernel(info) {} - int64_t nextOffset() { return offset++; } - - void setSeed(int64_t seed_) { seed = seed_; } + common::Status RunImpl(const ExecutionContext &) override; }; -using rngStateHandle_t = rngStateContext *; - } // namespace cuda -} // namespace brt \ No newline at end of file +} // namespace brt diff --git a/runtime/lib/backends/nccl/providers/send.cc b/runtime/lib/backends/nccl/providers/send.cc new file mode 100644 index 000000000..990adf56e --- /dev/null +++ b/runtime/lib/backends/nccl/providers/send.cc @@ -0,0 +1,71 @@ +//===- send.cc ------------------------------------------------*--- C++ -*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#include "./send.h" +#include "brt/backends/cuda/device/common/util.h" +#include "brt/backends/cuda/device/cuda_work_queue.h" +#include "brt/backends/nccl/device/d_context_nccl.h" +#include "brt/backends/nccl/device/distributed_backend_nccl.h" +#include "brt/core/context/execution_context.h" +#include "brt/core/context/execution_frame.h" +#include "brt/core/context/work_queue.h" +#include "brt/core/framework/dtype.h" +#include "brt/core/framework/op_accessor.h" +#include "brt/core/ir/ir.h" +#include "brt/core/ir/util.h" +#include "byteir/Dialect/Byre/ByreDialect.h" +#include +#include +#include +#include + +using namespace brt; +using namespace brt::common; +using namespace brt::ir; +using namespace mlir; + +namespace brt { +namespace cuda { + +// TODO: refine code and support various dtypes +template +common::Status Send::RunImpl(const ExecutionContext &ctx) { + DistributedBackend *backend = ctx.distributed_backend; + assert(backend != nullptr); + DistributedBackendNCCL *nccl_backend = + static_cast(backend); + + OpAccessor accessor(info_, ctx.exec_frame); + const auto src_shape = accessor.GetArgShape(0); + auto elem_num = std::accumulate(src_shape.begin(), src_shape.end(), 1, + std::multiplies()); + T *src = reinterpret_cast(accessor.GetArgAsyncValueRef(0)); + int64_t rank = accessor.GetAttrAsInt("rank"); + + cudaStream_t stream = + static_cast(ctx.work_queue)->GetComputeStream(); + std::shared_ptr d_context = std::make_shared(stream); + nccl_backend->send(src, elem_num, DTypeEnum::Float32, rank, d_context); + + return Status::OK(); +} + +// instantiate +template class Send; + +} // namespace cuda +} // namespace brt diff --git a/runtime/lib/backends/nccl/providers/send.h b/runtime/lib/backends/nccl/providers/send.h new file mode 100644 index 000000000..b8c730c5f --- /dev/null +++ b/runtime/lib/backends/nccl/providers/send.h @@ -0,0 +1,34 @@ +//===- send.h -------------------------------------------------*--- C++ -*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "brt/core/framework/op_kernel.h" + +namespace brt { +namespace cuda { + +// send synchronously +template class Send final : public OpKernel { +public: + explicit Send(const OpKernelInfo &info) : OpKernel(info) {} + + common::Status RunImpl(const ExecutionContext &) override; +}; + +} // namespace cuda +} // namespace brt diff --git a/runtime/lib/core/distributed/distributed_backend.cc b/runtime/lib/core/distributed/distributed_backend.cc new file mode 100644 index 000000000..7f8249db1 --- /dev/null +++ b/runtime/lib/core/distributed/distributed_backend.cc @@ -0,0 +1,35 @@ +// Copyright (c) Megvii Inc. +// Licensed under Apache License, Version 2.0 +// =========================================================================== +// Modification Copyright 2022 ByteDance Ltd. and/or its affiliates. + +#include "brt/core/distributed/distributed_backend.h" + +using namespace brt::common; + +namespace brt { + +Status DistributedBackend::init(const char *master_ip, int port) { + m_client = std::make_shared(m_nranks, m_rank); + auto status = m_client->connect(master_ip, port); + if (status != Status::OK()) + return status; + return do_init(); +} + +Status DistributedBackend::init(BcastCallback cb) { return do_init(cb); } + +Status DistributedBackend::recv(void *recvbuf, size_t len, DTypeEnum dtype, + uint32_t rank, std::shared_ptr ctx) { + size_t type_size = GetDTypeByte(dtype); + return _recv(recvbuf, len * type_size, rank, ctx); +} + +Status DistributedBackend::send(const void *sendbuff, size_t len, + DTypeEnum dtype, uint32_t rank, + std::shared_ptr ctx) { + size_t type_size = GetDTypeByte(dtype); + return _send(sendbuff, len * type_size, rank, ctx); +} + +} // namespace brt diff --git a/runtime/lib/core/distributed/distributed_session.cc b/runtime/lib/core/distributed/distributed_session.cc new file mode 100644 index 000000000..c72716ccb --- /dev/null +++ b/runtime/lib/core/distributed/distributed_session.cc @@ -0,0 +1,92 @@ +//===- distributed_session.cc ---------------------------------*--- C++ -*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#include "brt/core/distributed/distributed_session.h" + +#include "brt/core/context/execution_context.h" +#include "brt/core/context/execution_frame.h" +#include "brt/core/context/work_queue.h" +#include "brt/core/framework/execution_plan.h" +#include "brt/core/framework/execution_provider.h" +#include "brt/core/ir/ir.h" +#include "brt/core/session/request_context.h" +#include +#include + +using namespace brt; +using namespace brt::common; +using namespace brt::ir; + +namespace brt { + +DistributedSession::DistributedSession(int rank, int nranks, + const std::string &host, int port) + : rank_(rank), nranks_(nranks), host_(host), port_(port), Session() {} + +DistributedSession::~DistributedSession() {} + +common::Status DistributedSession::Run(RequestContext &request) { + // Create ExecutionContext + ExecutionContext ctx(request.frame_.get(), request.wq_.get(), + execution_plan_->GetFrameStateInfo(), + request.events_.get(), distributed_backend_); + + using State = ExecutionFrame::InternalState; + Status status = + request.frame_->GetIStateTransition() + .Edge(State::BeforePrologue, State::MainLoop, + [&] { return execution_plan_->ProloguePerFrame(ctx); }) + .Invariant(State::MainLoop) + .Apply(); + + if (!status.IsOK()) { + return status; + } + + return request.frame_->GetIStateTransition() + .Edge(State::MainLoop, State::MainLoop, + [&] { return execution_plan_->Run(ctx); }) + .Apply(); +} + +common::Status +DistributedSession::NewRequestContext(std::unique_ptr *request, + WorkQueue *work_queue) { + *request = std::unique_ptr(new RequestContext(*this)); + // allocate Frame but not allocate Intermediate + BRT_ENFORCE(execution_plan_ != nullptr); + + execution_plan_->CreateExecutinFrame(&((*request)->frame_)); + + if (work_queue) { + (*request)->SetWorkQueue(work_queue); + } else { + execution_plan_->CreateWorkQueue(&((*request)->wq_), rank_); + } + + return Status::OK(); +} + +common::Status +DistributedSession::LoadConfig(const std::vector &config, + std::string &ir_url) { + assert(config.size() > rank_); + ir_url = config[rank_]; + return Status::OK(); +} + +} // namespace brt \ No newline at end of file diff --git a/runtime/lib/core/distributed/rendezvous_socket.cc b/runtime/lib/core/distributed/rendezvous_socket.cc new file mode 100644 index 000000000..9c6d8412c --- /dev/null +++ b/runtime/lib/core/distributed/rendezvous_socket.cc @@ -0,0 +1,337 @@ +// Copyright (c) Megvii Inc. +// Licensed under Apache License, Version 2.0 +// =========================================================================== +// Modification Copyright 2022 ByteDance Ltd. and/or its affiliates. + +#include "brt/core/distributed/rendezvous_socket.h" +#include "brt/core/common/logging/logging.h" +#include "brt/core/common/logging/macros.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace brt::logging; +using namespace brt::common; + +namespace brt { + +//===----------------------------------------------------------------------===// +// GetFreePort +//===----------------------------------------------------------------------===// + +int GetFreePort() { + // create socket + int sock = socket(AF_INET, SOCK_STREAM, 0); + assert(sock != -1); + + // set address + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = htonl(INADDR_ANY); + addr.sin_port = htons(0); + + // bind + assert(bind(sock, (struct sockaddr *)&addr, sizeof(addr)) != -1); + + // get port + socklen_t len = sizeof(addr); + assert(getsockname(sock, (struct sockaddr *)&addr, &len) != -1); + int port = ntohs(addr.sin_port); + + // close + assert(close(sock) != -1); + + return port; +} + +//===----------------------------------------------------------------------===// +// CreateServer +//===----------------------------------------------------------------------===// + +namespace { + +void serve_barrier(uint32_t nranks, int *conns) { + uint32_t request_id; + + // recv other requests + for (uint32_t rank = 1; rank < nranks; rank++) { + assert(recv(conns[rank], &request_id, sizeof(uint32_t), MSG_WAITALL) != -1); + assert(request_id == 1); + } + + // send ack + uint32_t ack = 0; + for (uint32_t rank = 0; rank < nranks; rank++) { + assert(send(conns[rank], &ack, sizeof(uint32_t), 0) != -1); + } +} + +void serve_broadcast(uint32_t nranks, int *conns) { + uint32_t request_id, root, root0; + uint64_t len, len0; + + // recv request 0 + assert(recv(conns[0], &root0, sizeof(uint32_t), MSG_WAITALL) != -1); + assert(recv(conns[0], &len0, sizeof(uint64_t), MSG_WAITALL) != -1); + + // recv other requests + for (uint32_t rank = 1; rank < nranks; rank++) { + assert(recv(conns[rank], &request_id, sizeof(uint32_t), MSG_WAITALL) != -1); + assert(request_id == 2 && "inconsistent request_id from rank"); + + assert(recv(conns[rank], &root, sizeof(uint32_t), MSG_WAITALL) != -1); + assert(root == root0 && "inconsistent root from rank"); + + assert(recv(conns[rank], &len, sizeof(uint64_t), MSG_WAITALL) != -1); + assert(len == len0 && "inconsistent len from rank"); + } + + root = root0; + len = len0; + + // recv data from root + void *data = malloc(len); + assert(recv(conns[root], data, len, MSG_WAITALL) != -1); + + // send data to clients + for (uint32_t rank = 0; rank < nranks; rank++) { + assert(send(conns[rank], data, len, 0) != -1); + } + + free(data); +} + +void serve_allgather(uint32_t nranks, int *conns) { + uint32_t request_id; + uint64_t len, len0; + + // recv request 0 + assert(recv(conns[0], &len0, sizeof(uint64_t), MSG_WAITALL) != -1); + + // recv other requests + for (uint32_t rank = 1; rank < nranks; rank++) { + assert(recv(conns[rank], &request_id, sizeof(uint32_t), MSG_WAITALL) != -1); + assert(request_id == 3 && "inconsistent request_id from rank"); + + assert(recv(conns[rank], &len, sizeof(uint64_t), MSG_WAITALL) != -1); + assert(len == len0 && "inconsistent len from rank"); + } + + // recv data + void *data = malloc(len * nranks); + for (uint32_t rank = 0; rank < nranks; rank++) { + char *ptr = (char *)data + rank * len; + assert(recv(conns[rank], ptr, len, MSG_WAITALL) != -1); + } + + // send data to clients + for (uint32_t rank = 0; rank < nranks; rank++) { + assert(send(conns[rank], data, len * nranks, 0) != -1); + } + + free(data); +} + +void server_thread(int listenfd, uint32_t nranks) { + int conns[nranks]; + + for (uint32_t i = 0; i < nranks; i++) { + // establish connection + int conn = accept(listenfd, (struct sockaddr *)NULL, NULL); + assert(conn != -1); + + // recv rank and save into conns + uint32_t rank; + assert(recv(conn, &rank, sizeof(uint32_t), MSG_WAITALL) != -1); + conns[rank] = conn; + } + + // send ack to clients + uint32_t ack = 0; + for (uint32_t i = 0; i < nranks; i++) { + assert(send(conns[i], &ack, sizeof(uint32_t), 0) != -1); + } + + while (true) { + // receive a request from rank 0 + uint32_t request_id; + auto ret = recv(conns[0], &request_id, sizeof(uint32_t), MSG_WAITALL); + // recv 0 btyes means socket close + if (ret == 0) + break; + assert(ret != -1 && "socket recv msg error"); + + if (request_id == 1) { + serve_barrier(nranks, conns); + } else if (request_id == 2) { + serve_broadcast(nranks, conns); + } else if (request_id == 3) { + serve_allgather(nranks, conns); + } else { + BRT_LOGS_DEFAULT(ERROR) << "unexpected request id:" << request_id; + BRT_THROW("unexpected error"); + } + } +} + +} // namespace + +Status CreateServer(uint32_t nranks, int port) { + // create socket + int listenfd = socket(AF_INET, SOCK_STREAM, 0); + assert(listenfd != -1); + + // set server_addr + struct sockaddr_in server_addr; + memset(&server_addr, 0, sizeof(server_addr)); + server_addr.sin_family = AF_INET; + server_addr.sin_addr.s_addr = htonl(INADDR_ANY); + server_addr.sin_port = htons(port); + + // bind and listen + int opt = 1; + assert(setsockopt(listenfd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(int)) != + -1); + assert(bind(listenfd, (struct sockaddr *)&server_addr, sizeof(server_addr)) != + -1); + assert(listen(listenfd, nranks) != -1); + + // start server thread + std::thread th(server_thread, listenfd, nranks); + th.detach(); + + return Status::OK(); +} + +//===----------------------------------------------------------------------===// +// RendezvousSocket +//===----------------------------------------------------------------------===// + +RendezvousSocket::RendezvousSocket(uint32_t nranks, uint32_t rank) + : nranks_(nranks), rank_(rank), connected_(false) {} + +RendezvousSocket::~RendezvousSocket() {} + +Status RendezvousSocket::connect(const char *master_ip, int port) { + std::unique_lock lock(mutex_); + + if (connected_) { + return Status(common::StatusCategory::BRT, common::StatusCode::FAIL, + "Client already connected"); + } + + // create socket + conn_ = socket(AF_INET, SOCK_STREAM, 0); + assert(conn_ != -1); + + // set server_addr + struct sockaddr_in server_addr; + memset(&server_addr, 0, sizeof(server_addr)); + server_addr.sin_family = AF_INET; + server_addr.sin_port = htons(port); + assert(inet_pton(AF_INET, master_ip, &server_addr.sin_addr) != -1); + + // connect + int ret = + ::connect(conn_, (struct sockaddr *)&server_addr, sizeof(server_addr)); + while (ret == -1) { + usleep(100000); // 100ms + ret = + ::connect(conn_, (struct sockaddr *)&server_addr, sizeof(server_addr)); + } + + // send client rank + assert(send(conn_, &rank_, sizeof(uint32_t), 0) != -1); + + // recv ack from server + uint32_t ack; + assert(recv(conn_, &ack, sizeof(uint32_t), MSG_WAITALL) != -1); + + connected_ = true; + return Status::OK(); +} + +Status RendezvousSocket::barrier() { + std::unique_lock lock(mutex_); + + if (!connected_) { + return Status(common::StatusCategory::BRT, common::StatusCode::FAIL, + "Client not connected"); + } + + // send request_id + uint32_t request_id = 1; + assert(send(conn_, &request_id, sizeof(uint32_t), 0) != -1); + + // recv ack + uint32_t ack; + assert(recv(conn_, &ack, sizeof(uint32_t), MSG_WAITALL) != -1); + + return Status::OK(); +} + +Status RendezvousSocket::broadcast(const void *sendbuff, void *recvbuff, + size_t len, uint32_t root) { + std::unique_lock lock(mutex_); + + if (!connected_) { + return Status(common::StatusCategory::BRT, common::StatusCode::FAIL, + "Client not connected"); + } + + // send request_id + uint32_t request_id = 2; + assert(send(conn_, &request_id, sizeof(uint32_t), 0) != -1); + + // send root + assert(send(conn_, &root, sizeof(uint32_t), 0) != -1); + + // send len + uint64_t len64 = len; + assert(send(conn_, &len64, sizeof(uint64_t), 0) != -1); + + // send data + if (rank_ == root) { + assert(send(conn_, sendbuff, len, 0) != -1); + } + + // recv data + assert(recv(conn_, recvbuff, len, MSG_WAITALL) != -1); + + return Status::OK(); +} + +Status RendezvousSocket::allgather(const void *sendbuff, void *recvbuff, + size_t sendlen) { + std::unique_lock lock(mutex_); + + if (!connected_) { + return Status(common::StatusCategory::BRT, common::StatusCode::FAIL, + "Client not connected"); + } + + // send request_id + uint32_t request_id = 3; + assert(send(conn_, &request_id, sizeof(uint32_t), 0) != -1); + + // send sendlen + uint64_t sendlen64 = sendlen; + assert(send(conn_, &sendlen64, sizeof(uint64_t), 0) != -1); + + // send data + assert(send(conn_, sendbuff, sendlen, 0) != -1); + + // recv data + assert(recv(conn_, recvbuff, sendlen * nranks_, MSG_WAITALL) != -1); + + return Status::OK(); +} + +} // namespace brt \ No newline at end of file diff --git a/runtime/lib/core/framework/execution_plan.cc b/runtime/lib/core/framework/execution_plan.cc index 4901a6350..5cbe86131 100644 --- a/runtime/lib/core/framework/execution_plan.cc +++ b/runtime/lib/core/framework/execution_plan.cc @@ -506,13 +506,14 @@ common::Status StaticBRTExecutionPlan::EpiloguePerSession() { return common::Status::OK(); } -void StaticBRTExecutionPlan::CreateWorkQueue(std::unique_ptr *wq) { +void StaticBRTExecutionPlan::CreateWorkQueue(std::unique_ptr *wq, + int rank) { // create WQ // TODO remove this // TODO avoid using BRT_USE_CUDA #if BRT_USE_CUDA // wq_ = std::unique_ptr(new CUDAWorkQueue()); - *wq = std::unique_ptr(new CUDASingleStreamWorkQueue(0)); + *wq = std::unique_ptr(new CUDASingleStreamWorkQueue(rank)); #endif } diff --git a/runtime/test/backends/cpu/providers/default/kernel/rng_state_test.cc b/runtime/test/backends/cpu/providers/default/kernel/rng_state_test.cc new file mode 100644 index 000000000..72f10875f --- /dev/null +++ b/runtime/test/backends/cpu/providers/default/kernel/rng_state_test.cc @@ -0,0 +1,58 @@ +//===- rng_state_test.cc --------------------------------------*--- C++ -*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#include "brt/backends/cpu/device/cpu_work_queue.h" +#include "brt/backends/cpu/providers/default/cpu_provider.h" +#include "brt/core/common/status.h" +#include "brt/core/session/request_context.h" +#include "brt/core/session/session.h" +#include "brt/test/common/util.h" +#include "gtest/gtest.h" + +static std::string test_file_fill = "test/test_files/rng_state_cpu.mlir"; + +using namespace brt; +using namespace brt::test; + +TEST(CPUTestRngState, Basic) { + constexpr size_t length = 1; + + Session session; + auto status_allocator = CPUAllocatorFactory(&session); + BRT_TEST_CHECK_STATUS(status_allocator); + auto status_cpu = NaiveCPUExecutionProviderFactory(&session); + BRT_TEST_CHECK_STATUS(status_cpu); + + auto status_load = session.Load(test_file_fill, "byre"); + BRT_TEST_CHECK_STATUS(status_load); + + std::unique_ptr request; + auto status_request = session.NewRequestContext(&request); + BRT_TEST_CHECK_STATUS(status_request); + + request->FinishIOBinding(); + + auto status_run = session.Run(*request); + BRT_TEST_CHECK_STATUS(status_run); + auto status_sync = request->Sync(); + BRT_TEST_CHECK_STATUS(status_sync); + + CheckValues(static_cast(request->GetArg(0)), length, 0); + CheckValues(static_cast(request->GetArg(1)), length, 0); + CheckValues(static_cast(request->GetArg(2)), length, 1); + CheckValues(static_cast(request->GetArg(3)), length, 2); +} diff --git a/runtime/test/backends/cuda/providers/default/kernel/rng_state_test.cc b/runtime/test/backends/cuda/providers/default/kernel/rng_state_test.cc index 4e090b7d4..e288d5f5b 100644 --- a/runtime/test/backends/cuda/providers/default/kernel/rng_state_test.cc +++ b/runtime/test/backends/cuda/providers/default/kernel/rng_state_test.cc @@ -1,5 +1,4 @@ -//===- rng_state_test.cc -------------------------------------------*--- C++ -//-*-===// +//===- rng_state_test.cc --------------------------------------*--- C++ -*-===// // // Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. // Licensed under the Apache License, Version 2.0 (the "License"); @@ -26,7 +25,7 @@ #include #include -static std::string test_file_fill = "test/test_files/rng_state.mlir"; +static std::string test_file_fill = "test/test_files/rng_state_cuda.mlir"; using namespace brt; using namespace brt::cuda; diff --git a/runtime/test/backends/nccl/device/test_distributed_backend.cc b/runtime/test/backends/nccl/device/test_distributed_backend.cc new file mode 100644 index 000000000..505e73d76 --- /dev/null +++ b/runtime/test/backends/nccl/device/test_distributed_backend.cc @@ -0,0 +1,78 @@ +// Copyright (c) Megvii Inc. +// Licensed under Apache License, Version 2.0 +// =========================================================================== +// Modification Copyright 2022 ByteDance Ltd. and/or its affiliates. + +#include +#include +#include +#include +#include +#include +#include + +#include "brt/backends/nccl/device/distributed_backend_nccl.h" +#include "brt/core/common/common.h" +#include "brt/core/framework/dtype.h" +#include "brt/test/common/nccl/test_base.h" +#include "brt/test/common/nccl/test_utils.h" + +using namespace brt; + +TEST(TestDistributedBackendNCCL, Init) { + auto type = "BRT_CTX_CUDA"; + const int nranks = 2; + const int port = brt::GetFreePort(); + auto ret = brt::CreateServer(nranks, port); + ASSERT_EQ(Status::OK(), ret); + auto run = [&](int rank) { + get_context_trait(get_preferred_context(BackendType::BRT_NCCL)) + .set_device(rank); + auto backend = std::make_shared(nranks, rank); + ASSERT_EQ(Status::OK(), backend->init("localhost", port)); + }; + std::vector threads; + for (size_t i = 0; i < nranks; i++) { + threads.push_back(std::thread(run, i)); + } + for (size_t i = 0; i < nranks; i++) { + threads[i].join(); + } +} + +TEST(TestDistributedBackendNCCL, SendRecv) { + std::string msg("test_message"); + const int nranks = 2; + const size_t len = msg.size(); + + std::vector> inputs(nranks); + std::vector> expected_outputs(nranks); + + for (size_t i = 0; i < len; i++) { + inputs[0].push_back(msg[i]); + expected_outputs[1].push_back(msg[i]); + } + + auto run = [len](std::shared_ptr comm, ContextTrait trait, + int port, int rank, std::vector &input, + std::vector &output) -> void { + trait.set_device(rank); + comm->init("localhost", port); + + auto context = trait.make_context(); + + void *ptr = trait.alloc(len); + + if (rank == 0) { // send + trait.memcpy_h2d(ptr, input.data(), len, context); + comm->send(ptr, len * 1, DTypeEnum::UInt8, 1, context); + trait.sync_context(context); + } else { // recv + comm->recv(ptr, len * 1, DTypeEnum::UInt8, 0, context); + trait.memcpy_d2h(output.data(), ptr, len, context); + trait.sync_context(context); + } + }; + + run_test_for_all(nranks, inputs, expected_outputs, run); +} diff --git a/runtime/test/backends/nccl/device/test_utils.cc b/runtime/test/backends/nccl/device/test_utils.cc new file mode 100644 index 000000000..80be8b90e --- /dev/null +++ b/runtime/test/backends/nccl/device/test_utils.cc @@ -0,0 +1,54 @@ +// Copyright (c) Megvii Inc. +// Licensed under Apache License, Version 2.0 +// =========================================================================== +// Modification Copyright 2022 ByteDance Ltd. and/or its affiliates. + +#include "brt/test/common/nccl/test_utils.h" +#include "brt/backends/cuda/device/common/cuda_call.h" +#include "brt/backends/nccl/device/d_context_nccl.h" +#include +#include +#include + +namespace brt { + +void *alloc_cuda(size_t size) { + void *result; + BRT_CUDA_CHECK(cudaMalloc(&result, size)); + return result; +} + +void free_cuda(void *ptr) { BRT_CUDA_CHECK(cudaFree(ptr)); } + +void set_device_cuda(size_t device) { BRT_CUDA_CHECK(cudaSetDevice(device)); } + +std::shared_ptr make_context_cuda() { + cudaStream_t stream; + BRT_CUDA_CHECK(cudaStreamCreate(&stream)); + auto context = std::make_shared(stream); + return context; +} + +void sync_context_cuda(std::shared_ptr context) { + assert(context->type() == "BRT_CTX_CUDA" && "not a cuda context"); + BRT_CUDA_CHECK(cudaStreamSynchronize( + static_cast(context.get())->get_stream())); +} + +void memcpy_h2d_cuda(void *dst, void *src, size_t len, + std::shared_ptr ctx) { + cudaStream_t stream = static_cast(ctx.get())->get_stream(); + BRT_CUDA_CHECK( + cudaMemcpyAsync(dst, src, len, cudaMemcpyHostToDevice, stream)); + BRT_CUDA_CHECK(cudaStreamSynchronize(stream)); +} + +void memcpy_d2h_cuda(void *dst, void *src, size_t len, + std::shared_ptr ctx) { + cudaStream_t stream = static_cast(ctx.get())->get_stream(); + BRT_CUDA_CHECK( + cudaMemcpyAsync(dst, src, len, cudaMemcpyDeviceToHost, stream)); + BRT_CUDA_CHECK(cudaStreamSynchronize(stream)); +} + +} // namespace brt diff --git a/runtime/test/backends/nccl/providers/test_distributed_session.cc b/runtime/test/backends/nccl/providers/test_distributed_session.cc new file mode 100644 index 000000000..071776a9f --- /dev/null +++ b/runtime/test/backends/nccl/providers/test_distributed_session.cc @@ -0,0 +1,194 @@ +//===- test_distributed_session.cc ----------------------------*--- C++ -*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#include +#include +#include + +#include "brt/backends/cuda/device/cuda_allocator.h" +#include "brt/backends/nccl/providers/nccl_provider.h" +#include "brt/core/common/status.h" +#include "brt/core/distributed/distributed_session.h" +#include "brt/core/session/request_context.h" +#include "brt/test/common/cuda/util.h" +#include "brt/test/common/nccl/test_base.h" +#include "brt/test/common/nccl/test_utils.h" +#include "brt/test/common/util.h" + +using namespace brt; +using namespace brt::common; +using namespace brt::ir; +using namespace brt::test; + +namespace { + +static void CheckResult(float *d_ptr, size_t size, float val) { + CheckCUDABuffer((float *)d_ptr, size, [&](float *h_ptr) { + for (size_t i = 0; i < size; ++i) { + EXPECT_EQ(h_ptr[i], val); + } + }); +} + +} // namespace + +TEST(TestDistributedSession, NCCLProvider) { + const int nranks = 2; + const std::string host = "localhost"; + int port = brt::GetFreePort(); + auto ret = brt::CreateServer(nranks, port); + ASSERT_EQ(Status::OK(), ret); + + auto run = [nranks, host, port](int rank) { + int local_rank = rank; + DistributedSession d_session(rank, nranks, host, port); + auto status_cuda = + DefaultNCCLExecutionProviderFactory(&d_session, local_rank); + BRT_TEST_CHECK_STATUS(status_cuda); + }; + + std::vector threads; + for (size_t i = 0; i < nranks; i++) { + threads.push_back(std::thread(run, i)); + } + + for (size_t i = 0; i < nranks; i++) { + threads[i].join(); + } +} + +TEST(TestDistributedSession, NCCLSendRecv) { + const int nranks = 2; + const std::string host = "localhost"; + int port = brt::GetFreePort(); + auto ret = brt::CreateServer(nranks, port); + ASSERT_EQ(Status::OK(), ret); + + auto run = [nranks, host, port](int rank) { + int local_rank = rank; + DistributedSession d_session(rank, nranks, host, port); + auto status_allocator = CUDAAllocatorFactory(&d_session, local_rank); + BRT_TEST_CHECK_STATUS(status_allocator); + auto status_cuda = + DefaultNCCLExecutionProviderFactory(&d_session, local_rank); + BRT_TEST_CHECK_STATUS(status_cuda); + + std::vector config = {"test/test_files/Distributed/send.mlir", + "test/test_files/Distributed/recv.mlir"}; + std::string ir_url; + d_session.LoadConfig(config, ir_url); + auto status_load = d_session.Load(ir_url, "byre"); + BRT_TEST_CHECK_STATUS(status_load); + + std::unique_ptr request; + auto status_request = d_session.NewRequestContext(&request); + BRT_TEST_CHECK_STATUS(status_request); + + float *d_src = (float *)request->GetArg(0); + auto shape = d_session.GetStaticShape(0); + int64_t linearized_shape = LinearizedShape(shape); + EXPECT_GT(linearized_shape, 0); + size_t len = static_cast(linearized_shape); + if (rank == 0) + AssignCUDABuffer(d_src, len, 12345.0f); + if (rank == 1) + AssignCUDABuffer(d_src, len, 2.0f); + request->FinishIOBinding(); + + auto status_run = d_session.Run(*request); + BRT_TEST_CHECK_STATUS(status_run); + auto status_sync = request->Sync(); + BRT_TEST_CHECK_STATUS(status_sync); + + CheckResult(d_src, len, 12345.0f); + }; + + std::vector threads; + for (size_t i = 0; i < nranks; i++) { + threads.push_back(std::thread(run, i)); + } + + for (size_t i = 0; i < nranks; i++) { + threads[i].join(); + } +} + +TEST(TestDistributedSession, NCCLAddSendRecvAdd) { + const int nranks = 2; + const std::string host = "localhost"; + int port = brt::GetFreePort(); + auto ret = brt::CreateServer(nranks, port); + ASSERT_EQ(Status::OK(), ret); + + auto run = [nranks, host, port](int rank) { + int local_rank = rank; + DistributedSession d_session(rank, nranks, host, port); + auto status_allocator = CUDAAllocatorFactory(&d_session, local_rank); + BRT_TEST_CHECK_STATUS(status_allocator); + auto status_cuda = + DefaultNCCLExecutionProviderFactory(&d_session, local_rank); + BRT_TEST_CHECK_STATUS(status_cuda); + + std::vector config = { + "test/test_files/Distributed/add_send.mlir", + "test/test_files/Distributed/recv_add.mlir"}; + std::string ir_url; + d_session.LoadConfig(config, ir_url); + auto status_load = d_session.Load(ir_url, "byre"); + BRT_TEST_CHECK_STATUS(status_load); + + std::unique_ptr request; + auto status_request = d_session.NewRequestContext(&request); + BRT_TEST_CHECK_STATUS(status_request); + + auto shape = d_session.GetStaticShape(0); + int64_t linearized_shape = LinearizedShape(shape); + EXPECT_GT(linearized_shape, 0); + size_t len = static_cast(linearized_shape); + if (rank == 0) { + float *d_in0 = (float *)request->GetArg(0); + float *d_in1 = (float *)request->GetArg(1); + AssignCUDABuffer(d_in0, len, 1.0f); + AssignCUDABuffer(d_in1, len, 2.0f); + } else if (rank == 1) { + float *d_in0 = (float *)request->GetArg(0); + AssignCUDABuffer(d_in0, len, 3.0f); + } + request->FinishIOBinding(); + + auto status_run = d_session.Run(*request); + BRT_TEST_CHECK_STATUS(status_run); + auto status_sync = request->Sync(); + BRT_TEST_CHECK_STATUS(status_sync); + + if (rank == 1) { + float *d_out0 = (float *)request->GetArg(1); + float *d_out1 = (float *)request->GetArg(2); + CheckResult(d_out0, len, 3.0f); + CheckResult(d_out1, len, 6.0f); + } + }; + + std::vector threads; + for (size_t i = 0; i < nranks; i++) { + threads.push_back(std::thread(run, i)); + } + + for (size_t i = 0; i < nranks; i++) { + threads[i].join(); + } +} diff --git a/runtime/test/distributed/test_rendezvous_socket.cc b/runtime/test/distributed/test_rendezvous_socket.cc new file mode 100644 index 000000000..3893dc60b --- /dev/null +++ b/runtime/test/distributed/test_rendezvous_socket.cc @@ -0,0 +1,200 @@ +// Copyright (c) Megvii Inc. +// Licensed under Apache License, Version 2.0 +// =========================================================================== +// Modification Copyright 2022 ByteDance Ltd. and/or its affiliates. + +#include +#include +#include + +#include "brt/core/common/status.h" +#include "brt/core/distributed/rendezvous_socket.h" + +using namespace brt::common; + +TEST(TestRendezvousSocket, GetFreePort) { + int port = brt::GetFreePort(); + ASSERT_TRUE(port > 0); +} + +TEST(TestRendezvousSocket, Connect) { + const int nranks = 3; + + const int port = brt::GetFreePort(); + auto ret = brt::CreateServer(nranks, port); + ASSERT_EQ(Status::OK(), ret); + + auto run = [nranks, port](int rank) { + auto client = std::make_unique(nranks, rank); + auto ret = client->connect("localhost", port); + ASSERT_EQ(Status::OK(), ret); + }; + + std::vector threads; + for (size_t i = 0; i < nranks; i++) { + threads.push_back(std::thread(run, i)); + } + + for (size_t i = 0; i < nranks; i++) { + threads[i].join(); + } +} + +TEST(TestRendezvousSocket, Barrier) { + const int nranks = 3; + + const int port = brt::GetFreePort(); + auto ret = brt::CreateServer(nranks, port); + ASSERT_EQ(Status::OK(), ret); + + int counter = 0; + + auto run = [nranks, port, &counter](int rank) { + auto client = std::make_unique(nranks, rank); + auto ret = client->connect("localhost", port); + ASSERT_EQ(Status::OK(), ret); + + ret = client->barrier(); + ASSERT_EQ(Status::OK(), ret); + + sleep(rank); + ++counter; + + ret = client->barrier(); + ASSERT_EQ(Status::OK(), ret); + + // if the barrier is not working correctly, threads that sleep + // less seconds will arrive here earlier and counter might be + // less than nranks + ASSERT_EQ(nranks, counter); + }; + + std::vector threads; + for (size_t i = 0; i < nranks; i++) { + threads.push_back(std::thread(run, i)); + } + + for (size_t i = 0; i < nranks; i++) { + threads[i].join(); + } +} + +TEST(TestRendezvousSocket, Broadcast) { + const int nranks = 3; + const int root = 1; + const int chunk_size = 10; + + const int port = brt::GetFreePort(); + auto ret = brt::CreateServer(nranks, port); + ASSERT_EQ(Status::OK(), ret); + + std::string str(chunk_size * nranks, '\0'); + for (size_t i = 0; i < str.size(); i++) { + str[i] = 'a' + i % 26; + } + auto expected = str.substr(root * chunk_size, chunk_size); + + auto run = [nranks, port, &str, &expected](int rank) { + auto client = std::make_unique(nranks, rank); + auto ret = client->connect("localhost", port); + ASSERT_EQ(Status::OK(), ret); + + const char *input = str.data() + rank * chunk_size; + char *output = (char *)malloc(chunk_size); + ret = client->broadcast(input, output, chunk_size, root); + ASSERT_EQ(Status::OK(), ret); + + ASSERT_EQ(expected, std::string(output, chunk_size)); + free(output); + }; + + std::vector threads; + for (size_t i = 0; i < nranks; i++) { + threads.push_back(std::thread(run, i)); + } + + for (size_t i = 0; i < nranks; i++) { + threads[i].join(); + } +} + +TEST(TestServerClient, AllGather) { + const int nranks = 3; + const int chunk_size = 10; + + const int port = brt::GetFreePort(); + auto ret = brt::CreateServer(nranks, port); + ASSERT_EQ(Status::OK(), ret); + + std::string str(chunk_size * nranks, '\0'); + for (size_t i = 0; i < str.size(); i++) { + str[i] = 'a' + i % 26; + } + + auto run = [nranks, port, &str](int rank) { + auto client = std::make_unique(nranks, rank); + auto ret = client->connect("localhost", port); + ASSERT_EQ(Status::OK(), ret); + + const char *input = str.data() + rank * chunk_size; + char *output = (char *)malloc(str.size()); + ret = client->allgather(input, output, chunk_size); + ASSERT_EQ(Status::OK(), ret); + + ASSERT_EQ(str, std::string(output, str.size())); + free(output); + }; + + std::vector threads; + for (size_t i = 0; i < nranks; i++) { + threads.push_back(std::thread(run, i)); + } + + for (size_t i = 0; i < nranks; i++) { + threads[i].join(); + } +} + +TEST(TestServerClient, Sequence) { + const int nranks = 3; + const int chunk_size = 10; + + const int port = brt::GetFreePort(); + auto ret = brt::CreateServer(nranks, port); + ASSERT_EQ(Status::OK(), ret); + + std::string str(chunk_size * nranks, '\0'); + for (size_t i = 0; i < str.size(); i++) { + str[i] = 'a' + i % 26; + } + + auto run = [nranks, port, &str](int rank) { + auto client = std::make_unique(nranks, rank); + auto ret = client->connect("localhost", port); + ASSERT_EQ(Status::OK(), ret); + + const char *input = str.data() + rank * chunk_size; + char *output = (char *)malloc(str.size()); + + // send a sequence of requets without checking output + ASSERT_EQ(Status::OK(), client->barrier()); + ASSERT_EQ(Status::OK(), client->broadcast(input, output, chunk_size, 1)); + ASSERT_EQ(Status::OK(), client->allgather(input, output, chunk_size)); + ASSERT_EQ(Status::OK(), client->barrier()); + ASSERT_EQ(Status::OK(), client->allgather(input, output, chunk_size)); + ASSERT_EQ(Status::OK(), client->broadcast(input, output, chunk_size, 2)); + ASSERT_EQ(Status::OK(), client->allgather(input, output, chunk_size)); + ASSERT_EQ(Status::OK(), client->barrier()); + + free(output); + }; + + std::vector threads; + for (size_t i = 0; i < nranks; i++) { + threads.push_back(std::thread(run, i)); + } + + for (size_t i = 0; i < nranks; i++) { + threads[i].join(); + } +} diff --git a/runtime/test/include/brt/test/common/nccl/test_base.h b/runtime/test/include/brt/test/common/nccl/test_base.h new file mode 100644 index 000000000..b60b8b11a --- /dev/null +++ b/runtime/test/include/brt/test/common/nccl/test_base.h @@ -0,0 +1,65 @@ +// Copyright (c) Megvii Inc. +// Licensed under Apache License, Version 2.0 +// =========================================================================== +// Modification Copyright 2022 ByteDance Ltd. and/or its affiliates. + +#include +#include +#include +#include + +#include "brt/backends/nccl/device/distributed_backend_nccl.h" +#include "brt/core/common/common.h" +#include "test_utils.h" + +using namespace brt; + +template +void run_test( + int nranks, BackendType backend, std::vector> &inputs, + std::vector> &expect_outputs, + std::function, ContextTrait, int, + int, std::vector &, std::vector &)> + main_func) { + auto trait = get_context_trait(get_preferred_context(backend)); + std::vector> comms(nranks); + std::vector> outputs(nranks); + + int port = brt::GetFreePort(); + auto ret = brt::CreateServer(nranks, port); + ASSERT_EQ(Status::OK(), ret); + + for (int i = 0; i < nranks; i++) { + comms[i] = std::make_shared(nranks, i); + outputs[i].resize(expect_outputs[i].size()); + } + + std::vector threads; + for (int i = 0; i < nranks; i++) { + threads.push_back(std::thread(main_func, comms[i], trait, port, i, + std::ref(inputs[i]), std::ref(outputs[i]))); + } + + for (int i = 0; i < nranks; i++) { + threads[i].join(); + } + + for (int i = 0; i < nranks; i++) { + for (size_t j = 0; j < expect_outputs[i].size(); j++) { + ASSERT_FLOAT_EQ(expect_outputs[i][j], outputs[i][j]); + } + } +} + +template +void run_test_for_all( + int nranks, std::vector> &inputs, + std::vector> &expect_outputs, + std::function, ContextTrait, int, + int, std::vector &, std::vector &)> + main_func) { + std::vector backends = {BackendType::BRT_NCCL}; + for (auto &&backend : backends) { + run_test(nranks, backend, inputs, expect_outputs, main_func); + } +} \ No newline at end of file diff --git a/runtime/test/include/brt/test/common/nccl/test_utils.h b/runtime/test/include/brt/test/common/nccl/test_utils.h new file mode 100644 index 000000000..ff4e8d301 --- /dev/null +++ b/runtime/test/include/brt/test/common/nccl/test_utils.h @@ -0,0 +1,57 @@ +// Copyright (c) Megvii Inc. +// Licensed under Apache License, Version 2.0 +// =========================================================================== +// Modification Copyright 2022 ByteDance Ltd. and/or its affiliates. + +#pragma once + +#include "brt/backends/nccl/device/distributed_backend_nccl.h" +#include + +namespace brt { + +struct ContextTrait { + void *(*alloc)(size_t size); + void (*set_device)(size_t device); + void (*free)(void *ptr); + std::shared_ptr (*make_context)(); + void (*sync_context)(std::shared_ptr context); + void (*memcpy_h2d)(void *dst, void *src, size_t len, + std::shared_ptr context); + void (*memcpy_d2h)(void *dst, void *src, size_t len, + std::shared_ptr context); +}; + +void *alloc_cuda(size_t size); +void set_device_cuda(size_t device); +void free_cuda(void *ptr); +std::shared_ptr make_context_cuda(); +void sync_context_cuda(std::shared_ptr context); +void memcpy_h2d_cuda(void *dst, void *src, size_t len, + std::shared_ptr context); +void memcpy_d2h_cuda(void *dst, void *src, size_t len, + std::shared_ptr context); + +static std::map context_trait_map = { + {"BRT_CTX_CUDA", + {&alloc_cuda, &set_device_cuda, &free_cuda, &make_context_cuda, + &sync_context_cuda, &memcpy_h2d_cuda, &memcpy_d2h_cuda}}}; + +typedef enum { + BRT_NCCL = 0, +} BackendType; + +static std::string get_preferred_context(BackendType backend) { + switch (backend) { + case BRT_NCCL: + return "BRT_CTX_CUDA"; + default: + return ""; + } +} + +static ContextTrait get_context_trait(std::string type) { + return context_trait_map[type]; +} + +} // namespace brt diff --git a/runtime/test/test_files/Distributed/add_send.mlir b/runtime/test/test_files/Distributed/add_send.mlir new file mode 100644 index 000000000..24749deb6 --- /dev/null +++ b/runtime/test/test_files/Distributed/add_send.mlir @@ -0,0 +1,9 @@ +module attributes {byre.container_module} { + func.func @test_add_send(%arg0 : memref<4xf32, "cuda"> {byre.argname = "in0", byre.argtype = 1: i32}, + %arg1 : memref<4xf32, "cuda"> {byre.argname = "in1", byre.argtype = 1: i32}, + %arg2 : memref<4xf32, "cuda"> {byre.argname = "out", byre.argtype = 2: i32}) attributes {byre.entry_point} { + byre.compute @AddOp_f32f32_f32(%arg0, %arg1, %arg2) : memref<4xf32, "cuda">, memref<4xf32, "cuda">, memref<4xf32, "cuda"> + byre.compute @NCCLSend_f32(%arg2) {rank = 1 : i64} : memref<4xf32, "cuda"> + return + } +} \ No newline at end of file diff --git a/runtime/test/test_files/Distributed/recv.mlir b/runtime/test/test_files/Distributed/recv.mlir new file mode 100644 index 000000000..3e74e1958 --- /dev/null +++ b/runtime/test/test_files/Distributed/recv.mlir @@ -0,0 +1,6 @@ +module attributes {byre.container_module} { + func.func @test_recv(%arg0 : memref<4xf32, "cuda"> {byre.argname = "src", byre.argtype = 2: i32}) attributes {byre.entry_point} { + byre.compute @NCCLRecv_f32(%arg0) {rank = 0 : i64} : memref<4xf32, "cuda"> + return + } +} \ No newline at end of file diff --git a/runtime/test/test_files/Distributed/recv_add.mlir b/runtime/test/test_files/Distributed/recv_add.mlir new file mode 100644 index 000000000..d2947e698 --- /dev/null +++ b/runtime/test/test_files/Distributed/recv_add.mlir @@ -0,0 +1,9 @@ +module attributes {byre.container_module} { + func.func @test_recv_add(%arg0 : memref<4xf32, "cuda"> {byre.argname = "in", byre.argtype = 1: i32}, + %arg1 : memref<4xf32, "cuda"> {byre.argname = "out0", byre.argtype = 2: i32}, + %arg2 : memref<4xf32, "cuda"> {byre.argname = "out1", byre.argtype = 2: i32}) attributes {byre.entry_point} { + byre.compute @NCCLRecv_f32(%arg1) {rank = 0 : i64} : memref<4xf32, "cuda"> + byre.compute @AddOp_f32f32_f32(%arg0, %arg1, %arg2) : memref<4xf32, "cuda">, memref<4xf32, "cuda">, memref<4xf32, "cuda"> + return + } +} \ No newline at end of file diff --git a/runtime/test/test_files/Distributed/send.mlir b/runtime/test/test_files/Distributed/send.mlir new file mode 100644 index 000000000..85176209c --- /dev/null +++ b/runtime/test/test_files/Distributed/send.mlir @@ -0,0 +1,6 @@ +module attributes {byre.container_module} { + func.func @test_send(%arg0 : memref<4xf32, "cuda"> {byre.argname = "src", byre.argtype = 1: i32}) attributes {byre.entry_point} { + byre.compute @NCCLSend_f32(%arg0) {rank = 1 : i64} : memref<4xf32, "cuda"> + return + } +} \ No newline at end of file diff --git a/runtime/test/test_files/rng_state_cpu.mlir b/runtime/test/test_files/rng_state_cpu.mlir new file mode 100644 index 000000000..c55160601 --- /dev/null +++ b/runtime/test/test_files/rng_state_cpu.mlir @@ -0,0 +1,12 @@ +module attributes {byre.container_module} { + func.func @test_rng_state(%arg0 : memref {byre.argname = "seed", byre.argtype = 2: i32}, + %arg1 : memref {byre.argname = "offset0", byre.argtype = 2: i32}, + %arg2 : memref {byre.argname = "offset1", byre.argtype = 2: i32}, + %arg3 : memref {byre.argname = "offset2", byre.argtype = 2: i32}) attributes {byre.entry_point} { + byre.compute @GetSeed(%arg0) {device = "cpu", memory_effects = [2 : i32]} : memref + byre.compute @NextOffset(%arg1) {device = "cpu", memory_effects = [2 : i32]} : memref + byre.compute @NextOffset(%arg2) {device = "cpu", memory_effects = [2 : i32]} : memref + byre.compute @NextOffset(%arg3) {device = "cpu", memory_effects = [2 : i32]} : memref + return + } +} \ No newline at end of file diff --git a/runtime/test/test_files/rng_state.mlir b/runtime/test/test_files/rng_state_cuda.mlir similarity index 100% rename from runtime/test/test_files/rng_state.mlir rename to runtime/test/test_files/rng_state_cuda.mlir diff --git a/scripts/runtime/build_and_test.sh b/scripts/runtime/build_and_test.sh index 1cd9cc52a..3867fa9ca 100755 --- a/scripts/runtime/build_and_test.sh +++ b/scripts/runtime/build_and_test.sh @@ -8,6 +8,10 @@ while [[ $# -gt 1 ]]; do BRT_USE_CUDA=ON shift ;; + --nccl) + BRT_USE_NCCL=ON + shift + ;; --asan) BRT_ENABLE_ASAN=ON CMAKE_BUILD_TYPE=Debug @@ -65,6 +69,7 @@ cmake -GNinja \ -DLLVM_INSTALL_PATH="$LLVM_INSTALL_DIR" \ -DCMAKE_INSTALL_PREFIX="$BUILD_DIR/install" \ -Dbrt_USE_CUDA=${BRT_USE_CUDA} \ + -Dbrt_USE_NCCL=${BRT_USE_NCCL} \ -Dbrt_BUILD_FLASH_ATTN=${brt_BUILD_FLASH_ATTN} \ -Dbrt_ENABLE_ASAN=${BRT_ENABLE_ASAN} \ -Dbrt_ENABLE_PYTHON_BINDINGS=${BRT_ENABLE_PYTHON_BINDINGS}