Skip to content

Commit

Permalink
Fix :top_k limit in classification servings
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Sep 18, 2023
1 parent 7430d64 commit 355ad10
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion lib/bumblebee/text/text_classification.ex
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ defmodule Bumblebee.Text.TextClassification do
scores_fun = fn params, input ->
outputs = predict_fun.(params, input)
scores = Shared.logits_to_scores(outputs.logits, scores_function)
k = min(top_k, Nx.size(scores))
k = min(top_k, Nx.axis_size(scores, 1))
{top_scores, top_indices} = Nx.top_k(scores, k: k)
{top_scores, top_indices}
end
Expand Down
2 changes: 1 addition & 1 deletion lib/bumblebee/vision/image_classification.ex
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ defmodule Bumblebee.Vision.ImageClassification do
input = Bumblebee.Featurizer.process_batch(featurizer, input)
outputs = predict_fun.(params, input)
scores = Shared.logits_to_scores(outputs.logits, scores_function)
k = min(top_k, Nx.size(scores))
k = min(top_k, Nx.axis_size(scores, 1))
{top_scores, top_indices} = Nx.top_k(scores, k: k)
{top_scores, top_indices}
end
Expand Down

0 comments on commit 355ad10

Please sign in to comment.