Skip to content

Commit

Permalink
Allow for giving a single score for the whole object
Browse files Browse the repository at this point in the history
  • Loading branch information
Agustín Castro committed Feb 9, 2024
1 parent 009a1b1 commit 7514f3d
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
11 changes: 9 additions & 2 deletions norfair/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ class Detection:
Parameters
----------
points : np.ndarray
Points detected. Must be a rank 2 array with shape `(n_points, n_dimensions)` where n_dimensions is 2 or 3.
Points detected. Must be a rank 2 array with shape `(n_points, n_dimensions)`.
scores : np.ndarray, optional
An array of length `n_points` which assigns a score to each of the points defined in `points`.
Expand All @@ -770,12 +770,19 @@ class Detection:
def __init__(
self,
points: np.ndarray,
scores: np.ndarray = None,
scores: Union[float, int, np.ndarray] = None,
data: Any = None,
label: Hashable = None,
embedding=None,
):
self.points = validate_points(points)

if isinstance(scores, np.ndarray):
assert len(scores) == len(
self.points
), "scores should be a np.ndarray with it's length being equal to the amount of points."
elif scores is not None:
scores = np.zeros((len(points),)) + scores
self.scores = scores
self.data = data
self.label = label
Expand Down
2 changes: 1 addition & 1 deletion norfair/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def validate_points(points: np.ndarray) -> np.array:

def raise_detection_error_message(points):
message = "\n[red]INPUT ERROR:[/red]\n"
message += f"Each `Detection` object should have a property `points` of shape (num_of_points_to_track, 2), not {points.shape}. Check your `Detection` list creation code.\n"
message += f"Each `Detection` object should have a property `points` of shape (n_points, n_dimensions), not {points.shape}. Check your `Detection` list creation code.\n"
message += "You can read the documentation for the `Detection` class here:\n"
message += "https://tryolabs.github.io/norfair/reference/tracker/#norfair.tracker.Detection\n"
raise ValueError(message)
Expand Down

0 comments on commit 7514f3d

Please sign in to comment.