diff --git a/lib/bumblebee/text/text_classification.ex b/lib/bumblebee/text/text_classification.ex index 3647a330..a31e2c8d 100644 --- a/lib/bumblebee/text/text_classification.ex +++ b/lib/bumblebee/text/text_classification.ex @@ -36,7 +36,7 @@ defmodule Bumblebee.Text.TextClassification do scores_fun = fn params, input -> outputs = predict_fun.(params, input) scores = Shared.logits_to_scores(outputs.logits, scores_function) - k = min(top_k, Nx.size(scores)) + k = min(top_k, Nx.axis_size(scores, 1)) {top_scores, top_indices} = Nx.top_k(scores, k: k) {top_scores, top_indices} end diff --git a/lib/bumblebee/vision/image_classification.ex b/lib/bumblebee/vision/image_classification.ex index b2c45006..9a3be721 100644 --- a/lib/bumblebee/vision/image_classification.ex +++ b/lib/bumblebee/vision/image_classification.ex @@ -40,7 +40,7 @@ defmodule Bumblebee.Vision.ImageClassification do input = Bumblebee.Featurizer.process_batch(featurizer, input) outputs = predict_fun.(params, input) scores = Shared.logits_to_scores(outputs.logits, scores_function) - k = min(top_k, Nx.size(scores)) + k = min(top_k, Nx.axis_size(scores, 1)) {top_scores, top_indices} = Nx.top_k(scores, k: k) {top_scores, top_indices} end