diff --git a/model_api/cpp/models/src/classification_model.cpp b/model_api/cpp/models/src/classification_model.cpp index c0beda0b..6009d84d 100644 --- a/model_api/cpp/models/src/classification_model.cpp +++ b/model_api/cpp/models/src/classification_model.cpp @@ -158,6 +158,21 @@ std::vector get_non_xai_names(const std::vector get_non_xai_output_indices(const std::vector>& outputs) { + std::vector outputIndices; + outputIndices.reserve(std::max(1, int(outputs.size()) - 2)); + size_t idx = 0; + for (const ov::Output& output : outputs) { + bool is_xai = output.get_names().count(saliency_map_name) > 0 + || output.get_names().count(feature_vector_name) > 0; + if (!is_xai) { + outputIndices.push_back(idx); + } + ++idx; + } + return outputIndices; +} + std::vector get_non_xai_names(const std::vector& outputs) { std::vector outputNames; outputNames.reserve(std::max(1, int(outputs.size()) - 2)); @@ -469,8 +484,9 @@ void ClassificationModel::prepareInputsOutputs(std::shared_ptr& model throw std::logic_error("Classification model wrapper supports topologies with up to 4 outputs"); } - if (model->outputs().size() == 1) { - const ov::Shape& outputShape = model->output().get_partial_shape().get_max_shape(); + auto non_xai_idx = get_non_xai_output_indices(model->outputs()); + if (non_xai_idx.size() == 1) { + const ov::Shape& outputShape = model->outputs()[non_xai_idx[0]].get_partial_shape().get_max_shape(); if (outputShape.size() != 2 && outputShape.size() != 4) { throw std::logic_error("Classification model wrapper supports topologies only with" " 2-dimensional or 4-dimensional output"); @@ -488,10 +504,7 @@ void ClassificationModel::prepareInputsOutputs(std::shared_ptr& model throw std::logic_error("The model provides " + std::to_string(classesNum) + " classes, but " + std::to_string(topk) + " labels are requested to be predicted"); } - if (classesNum == labels.size() + 1) { - labels.insert(labels.begin(), "other"); - slog::warn << "Inserted 'other' label as first." << slog::endl; - } else if (classesNum != labels.size()) { + if (classesNum != labels.size()) { throw std::logic_error("Model's number of classes and parsed labels must match (" + std::to_string(outputShape[1]) + " and " + std::to_string(labels.size()) + ')'); }