Skip to content

Commit

Permalink
[tmva][sofie] Fix parsing of ReduceSumSquare
Browse files Browse the repository at this point in the history
AN invalid operator name was used. This is now fixed and a test is also added for ReduceSumSquare
  • Loading branch information
lmoneta committed Nov 15, 2024
1 parent f812f98 commit 5027f07
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 13 deletions.
10 changes: 5 additions & 5 deletions tmva/sofie/inc/TMVA/ROperator_Reduce.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace TMVA{
namespace Experimental{
namespace SOFIE{

enum EReduceOpMode { ReduceMean, ReduceSum, ReduceSumsquare, ReduceProd, InvalidReduceOp };
enum EReduceOpMode { ReduceMean, ReduceSum, ReduceSumSquare, ReduceProd, InvalidReduceOp };

template <typename T, EReduceOpMode Op>
class ROperator_Reduce final : public ROperator
Expand All @@ -38,7 +38,7 @@ public:

std::string Name() {
if (fReduceOpMode == ReduceMean) return "ReduceMean";
else if (fReduceOpMode == ReduceSumsquare ) return "ReduceSumsquare";
else if (fReduceOpMode == ReduceSumSquare ) return "ReduceSumSquare";
else if (fReduceOpMode == ReduceProd ) return "ReduceProd";
else if (fReduceOpMode == ReduceSum) return "ReduceSum";
return "Invalid";
Expand Down Expand Up @@ -175,7 +175,7 @@ public:
out << SP << SP << SP << "tensor_" << fNY << "[i] *= tensor_" << fNX << "[i * " << reducedLength << " + j];\n";
else if (fReduceOpMode == ReduceSum || fReduceOpMode == ReduceMean)
out << SP << SP << SP << "tensor_" << fNY << "[i] += tensor_" << fNX << "[i * " << reducedLength << " + j];\n";
else if(fReduceOpMode == ReduceSumsquare)
else if(fReduceOpMode == ReduceSumSquare)
out << SP << SP << SP << "tensor_" << fNY << "[i] += tensor_" << fNX << "[i * " << reducedLength << " + j] * tensor_"
<< fNX << "[i * " << reducedLength << " + j];\n";
out << SP << SP << "}\n"; // end j loop
Expand All @@ -199,7 +199,7 @@ public:
out << SP << SP << SP << "tensor_" << fNY << "[j] *= tensor_" << fNX << "[i * " << outputLength << " + j];\n";
else if (fReduceOpMode == ReduceSum || fReduceOpMode == ReduceMean)
out << SP << SP << SP << "tensor_" << fNY << "[j] += tensor_" << fNX << "[i * " << outputLength << " + j];\n";
else if(fReduceOpMode == ReduceSumsquare)
else if(fReduceOpMode == ReduceSumSquare)
out << SP << SP << SP << "tensor_" << fNY << "[j] += tensor_" << fNX << "[i * " << outputLength << " + j] * tensor_"
<< fNX << "[i * " << outputLength << " + j];\n";
out << SP << SP << "}\n"; // end j loop
Expand Down Expand Up @@ -238,7 +238,7 @@ public:
out << SP << SP << "tensor_" << fNY << "[outputIndex] *= tensor_" << fNX << "[i];\n";
else if (fReduceOpMode == ReduceSum || fReduceOpMode == ReduceMean)
out << SP << SP << "tensor_" << fNY << "[outputIndex] += tensor_" << fNX << "[i];\n";
else if (fReduceOpMode == ReduceSumsquare) {
else if (fReduceOpMode == ReduceSumSquare) {
out << SP << SP << "tensor_" << fNY << "[outputIndex] += tensor_" << fNX << "[i] * tensor_" << fNX
<< "[i];\n";
}
Expand Down
32 changes: 31 additions & 1 deletion tmva/sofie/test/TestCustomModelsFromONNX.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@
#include "ReduceProd_FromONNX.hxx"
#include "input_models/references/ReduceProd.ref.hxx"

#include "ReduceSum_FromONNX.hxx" // hardcode reference
// hardcode reference
#include "ReduceSum_FromONNX.hxx"

#include "ReduceSumSquare_FromONNX.hxx"

#include "Shape_FromONNX.hxx"
#include "input_models/references/Shape.ref.hxx"
Expand Down Expand Up @@ -1208,6 +1211,33 @@ TEST(ONNX, ReduceSum){
}
}

TEST(ONNX, ReduceSumSquare){
constexpr float TOLERANCE = DEFAULT_TOLERANCE;


// Preparing the standard input
std::vector<float> input({
5, 2, 3,
5, 5, 4
});

// reduce on last axis and do not keep dimension
// output should be [1,2] and [25+4+9, 25+25+16]


TMVA_SOFIE_ReduceSumSquare::Session s("ReduceSumSquare_FromONNX.dat");
std::vector<float> output = s.infer(input.data());
// Checking output size
EXPECT_EQ(output.size(), 2);

float correct[] = {38, 66};

// Checking every output value, one by one
for (size_t i = 0; i < output.size(); ++i) {
EXPECT_LE(std::abs(output[i] - correct[i]), TOLERANCE);
}
}

TEST(ONNX, Max)
{
constexpr float TOLERANCE = DEFAULT_TOLERANCE;
Expand Down
Binary file added tmva/sofie/test/input_models/ReduceSumSquare.onnx
Binary file not shown.
10 changes: 5 additions & 5 deletions tmva/sofie_parsers/src/ParseReduce.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ std::unique_ptr<ROperator> ParseReduce(RModelParser_ONNX &parser, const onnx::No

if (nodeproto.op_type() == "ReduceMean")
op_mode = ReduceMean;
else if (nodeproto.op_type() == "ReduceSumsquare")
op_mode = ReduceSumsquare;
else if (nodeproto.op_type() == "ReduceSumSquare")
op_mode = ReduceSumSquare;
else if (nodeproto.op_type() == "ReduceProd")
op_mode = ReduceProd;
else if (nodeproto.op_type() == "ReduceSum")
Expand Down Expand Up @@ -77,9 +77,9 @@ ParserFuncSignature ParseReduceMean = [](RModelParser_ONNX &parser, const onnx::
return ParseReduce<EReduceOpMode::ReduceMean>(parser, nodeproto);
};

// Parse ReduceSumsquare
ParserFuncSignature ParseReduceSumsquare = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
return ParseReduce<EReduceOpMode::ReduceSumsquare>(parser, nodeproto);
// Parse ReduceSumSquare
ParserFuncSignature ParseReduceSumSquare = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) {
return ParseReduce<EReduceOpMode::ReduceSumSquare>(parser, nodeproto);
};

// Parse ReduceProd
Expand Down
4 changes: 2 additions & 2 deletions tmva/sofie_parsers/src/RModelParser_ONNX.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ extern ParserFuncSignature ParseGreaterEq;
// Reduce operators
extern ParserFuncSignature ParseReduceMean;
extern ParserFuncSignature ParseReduceSum;
extern ParserFuncSignature ParseReduceSumsquare;
extern ParserFuncSignature ParseReduceSumSquare;
extern ParserFuncSignature ParseReduceProd;
// Others
extern ParserFuncSignature ParseBatchNormalization;
Expand Down Expand Up @@ -170,7 +170,7 @@ RModelParser_ONNX::RModelParser_ONNX() noexcept : fOperatorsMapImpl(std::make_un
// Reduce operators
RegisterOperator("ReduceMean", ParseReduceMean);
RegisterOperator("ReduceSum", ParseReduceSum);
RegisterOperator("ReduceSumsquare", ParseReduceSumsquare);
RegisterOperator("ReduceSumSquare", ParseReduceSumSquare);
RegisterOperator("ReduceProd", ParseReduceProd);
// Others
RegisterOperator("BatchNormalization", ParseBatchNormalization);
Expand Down

0 comments on commit 5027f07

Please sign in to comment.