Skip to content

Commit

Permalink
Merge branch 'master' into vis_prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
sovrasov authored Jun 25, 2024
2 parents a8a083a + eb9fcfb commit f303e8e
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions model_api/cpp/models/src/classification_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,21 @@ std::vector<std::string> get_non_xai_names(const std::vector<ov::Output<ov::Node
return outputNames;
}

std::vector<size_t> get_non_xai_output_indices(const std::vector<ov::Output<ov::Node>>& outputs) {
std::vector<size_t> outputIndices;
outputIndices.reserve(std::max(1, int(outputs.size()) - 2));
size_t idx = 0;
for (const ov::Output<ov::Node>& output : outputs) {
bool is_xai = output.get_names().count(saliency_map_name) > 0
|| output.get_names().count(feature_vector_name) > 0;
if (!is_xai) {
outputIndices.push_back(idx);
}
++idx;
}
return outputIndices;
}

std::vector<std::string> get_non_xai_names(const std::vector<std::string>& outputs) {
std::vector<std::string> outputNames;
outputNames.reserve(std::max(1, int(outputs.size()) - 2));
Expand Down Expand Up @@ -469,8 +484,9 @@ void ClassificationModel::prepareInputsOutputs(std::shared_ptr<ov::Model>& model
throw std::logic_error("Classification model wrapper supports topologies with up to 4 outputs");
}

if (model->outputs().size() == 1) {
const ov::Shape& outputShape = model->output().get_partial_shape().get_max_shape();
auto non_xai_idx = get_non_xai_output_indices(model->outputs());
if (non_xai_idx.size() == 1) {
const ov::Shape& outputShape = model->outputs()[non_xai_idx[0]].get_partial_shape().get_max_shape();
if (outputShape.size() != 2 && outputShape.size() != 4) {
throw std::logic_error("Classification model wrapper supports topologies only with"
" 2-dimensional or 4-dimensional output");
Expand All @@ -488,10 +504,7 @@ void ClassificationModel::prepareInputsOutputs(std::shared_ptr<ov::Model>& model
throw std::logic_error("The model provides " + std::to_string(classesNum) + " classes, but " +
std::to_string(topk) + " labels are requested to be predicted");
}
if (classesNum == labels.size() + 1) {
labels.insert(labels.begin(), "other");
slog::warn << "Inserted 'other' label as first." << slog::endl;
} else if (classesNum != labels.size()) {
if (classesNum != labels.size()) {
throw std::logic_error("Model's number of classes and parsed labels must match (" +
std::to_string(outputShape[1]) + " and " + std::to_string(labels.size()) + ')');
}
Expand Down

0 comments on commit f303e8e

Please sign in to comment.