diff --git a/cellpose/metrics.py b/cellpose/metrics.py index a33aabf5..ad233eac 100644 --- a/cellpose/metrics.py +++ b/cellpose/metrics.py @@ -118,8 +118,8 @@ def average_precision(masks_true, masks_pred, threshold=[0.5, 0.75, 0.9]): tp = np.zeros((len(masks_true), len(threshold)), np.float32) fp = np.zeros((len(masks_true), len(threshold)), np.float32) fn = np.zeros((len(masks_true), len(threshold)), np.float32) - n_true = np.array(list(map(np.max, masks_true))) - n_pred = np.array(list(map(np.max, masks_pred))) + n_true = np.array([len(np.unique(mt)) - 1 for mt in masks_true]) + n_pred = np.array([len(np.unique(mp)) - 1 for mp in masks_pred]) for n in range(len(masks_true)): #_,mt = np.reshape(np.unique(masks_true[n], return_index=True), masks_pred[n].shape)