diff --git a/tmva/sofie/inc/TMVA/ROperator_Reduce.hxx b/tmva/sofie/inc/TMVA/ROperator_Reduce.hxx index fd73765020be6..ee660f888ef43 100644 --- a/tmva/sofie/inc/TMVA/ROperator_Reduce.hxx +++ b/tmva/sofie/inc/TMVA/ROperator_Reduce.hxx @@ -16,7 +16,7 @@ namespace TMVA{ namespace Experimental{ namespace SOFIE{ -enum EReduceOpMode { ReduceMean, ReduceSum, ReduceSumsquare, ReduceProd, InvalidReduceOp }; +enum EReduceOpMode { ReduceMean, ReduceSum, ReduceSumSquare, ReduceProd, InvalidReduceOp }; template class ROperator_Reduce final : public ROperator @@ -38,7 +38,7 @@ public: std::string Name() { if (fReduceOpMode == ReduceMean) return "ReduceMean"; - else if (fReduceOpMode == ReduceSumsquare ) return "ReduceSumsquare"; + else if (fReduceOpMode == ReduceSumSquare ) return "ReduceSumSquare"; else if (fReduceOpMode == ReduceProd ) return "ReduceProd"; else if (fReduceOpMode == ReduceSum) return "ReduceSum"; return "Invalid"; @@ -175,7 +175,7 @@ public: out << SP << SP << SP << "tensor_" << fNY << "[i] *= tensor_" << fNX << "[i * " << reducedLength << " + j];\n"; else if (fReduceOpMode == ReduceSum || fReduceOpMode == ReduceMean) out << SP << SP << SP << "tensor_" << fNY << "[i] += tensor_" << fNX << "[i * " << reducedLength << " + j];\n"; - else if(fReduceOpMode == ReduceSumsquare) + else if(fReduceOpMode == ReduceSumSquare) out << SP << SP << SP << "tensor_" << fNY << "[i] += tensor_" << fNX << "[i * " << reducedLength << " + j] * tensor_" << fNX << "[i * " << reducedLength << " + j];\n"; out << SP << SP << "}\n"; // end j loop @@ -199,7 +199,7 @@ public: out << SP << SP << SP << "tensor_" << fNY << "[j] *= tensor_" << fNX << "[i * " << outputLength << " + j];\n"; else if (fReduceOpMode == ReduceSum || fReduceOpMode == ReduceMean) out << SP << SP << SP << "tensor_" << fNY << "[j] += tensor_" << fNX << "[i * " << outputLength << " + j];\n"; - else if(fReduceOpMode == ReduceSumsquare) + else if(fReduceOpMode == ReduceSumSquare) out << SP << SP << SP << "tensor_" << fNY << "[j] += tensor_" << fNX << "[i * " << outputLength << " + j] * tensor_" << fNX << "[i * " << outputLength << " + j];\n"; out << SP << SP << "}\n"; // end j loop @@ -238,7 +238,7 @@ public: out << SP << SP << "tensor_" << fNY << "[outputIndex] *= tensor_" << fNX << "[i];\n"; else if (fReduceOpMode == ReduceSum || fReduceOpMode == ReduceMean) out << SP << SP << "tensor_" << fNY << "[outputIndex] += tensor_" << fNX << "[i];\n"; - else if (fReduceOpMode == ReduceSumsquare) { + else if (fReduceOpMode == ReduceSumSquare) { out << SP << SP << "tensor_" << fNY << "[outputIndex] += tensor_" << fNX << "[i] * tensor_" << fNX << "[i];\n"; } diff --git a/tmva/sofie/test/TestCustomModelsFromONNX.cxx b/tmva/sofie/test/TestCustomModelsFromONNX.cxx index df9f9da916a17..8be28073c5ce0 100644 --- a/tmva/sofie/test/TestCustomModelsFromONNX.cxx +++ b/tmva/sofie/test/TestCustomModelsFromONNX.cxx @@ -33,7 +33,10 @@ #include "ReduceProd_FromONNX.hxx" #include "input_models/references/ReduceProd.ref.hxx" -#include "ReduceSum_FromONNX.hxx" // hardcode reference +// hardcode reference +#include "ReduceSum_FromONNX.hxx" + +#include "ReduceSumSquare_FromONNX.hxx" #include "Shape_FromONNX.hxx" #include "input_models/references/Shape.ref.hxx" @@ -1208,6 +1211,33 @@ TEST(ONNX, ReduceSum){ } } +TEST(ONNX, ReduceSumSquare){ + constexpr float TOLERANCE = DEFAULT_TOLERANCE; + + + // Preparing the standard input + std::vector input({ + 5, 2, 3, + 5, 5, 4 + }); + + // reduce on last axis and do not keep dimension + // output should be [1,2] and [25+4+9, 25+25+16] + + + TMVA_SOFIE_ReduceSumSquare::Session s("ReduceSumSquare_FromONNX.dat"); + std::vector output = s.infer(input.data()); + // Checking output size + EXPECT_EQ(output.size(), 2); + + float correct[] = {38, 66}; + + // Checking every output value, one by one + for (size_t i = 0; i < output.size(); ++i) { + EXPECT_LE(std::abs(output[i] - correct[i]), TOLERANCE); + } +} + TEST(ONNX, Max) { constexpr float TOLERANCE = DEFAULT_TOLERANCE; diff --git a/tmva/sofie/test/input_models/ReduceSumSquare.onnx b/tmva/sofie/test/input_models/ReduceSumSquare.onnx new file mode 100644 index 0000000000000..40ef9494f06c8 Binary files /dev/null and b/tmva/sofie/test/input_models/ReduceSumSquare.onnx differ diff --git a/tmva/sofie_parsers/src/ParseReduce.cxx b/tmva/sofie_parsers/src/ParseReduce.cxx index c9182176c1347..6c18a4371c342 100644 --- a/tmva/sofie_parsers/src/ParseReduce.cxx +++ b/tmva/sofie_parsers/src/ParseReduce.cxx @@ -16,8 +16,8 @@ std::unique_ptr ParseReduce(RModelParser_ONNX &parser, const onnx::No if (nodeproto.op_type() == "ReduceMean") op_mode = ReduceMean; - else if (nodeproto.op_type() == "ReduceSumsquare") - op_mode = ReduceSumsquare; + else if (nodeproto.op_type() == "ReduceSumSquare") + op_mode = ReduceSumSquare; else if (nodeproto.op_type() == "ReduceProd") op_mode = ReduceProd; else if (nodeproto.op_type() == "ReduceSum") @@ -77,9 +77,9 @@ ParserFuncSignature ParseReduceMean = [](RModelParser_ONNX &parser, const onnx:: return ParseReduce(parser, nodeproto); }; -// Parse ReduceSumsquare -ParserFuncSignature ParseReduceSumsquare = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { - return ParseReduce(parser, nodeproto); +// Parse ReduceSumSquare +ParserFuncSignature ParseReduceSumSquare = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { + return ParseReduce(parser, nodeproto); }; // Parse ReduceProd diff --git a/tmva/sofie_parsers/src/RModelParser_ONNX.cxx b/tmva/sofie_parsers/src/RModelParser_ONNX.cxx index d6d472426d769..81a4b23697898 100644 --- a/tmva/sofie_parsers/src/RModelParser_ONNX.cxx +++ b/tmva/sofie_parsers/src/RModelParser_ONNX.cxx @@ -42,7 +42,7 @@ extern ParserFuncSignature ParseGreaterEq; // Reduce operators extern ParserFuncSignature ParseReduceMean; extern ParserFuncSignature ParseReduceSum; -extern ParserFuncSignature ParseReduceSumsquare; +extern ParserFuncSignature ParseReduceSumSquare; extern ParserFuncSignature ParseReduceProd; // Others extern ParserFuncSignature ParseBatchNormalization; @@ -170,7 +170,7 @@ RModelParser_ONNX::RModelParser_ONNX() noexcept : fOperatorsMapImpl(std::make_un // Reduce operators RegisterOperator("ReduceMean", ParseReduceMean); RegisterOperator("ReduceSum", ParseReduceSum); - RegisterOperator("ReduceSumsquare", ParseReduceSumsquare); + RegisterOperator("ReduceSumSquare", ParseReduceSumSquare); RegisterOperator("ReduceProd", ParseReduceProd); // Others RegisterOperator("BatchNormalization", ParseBatchNormalization);