From 5010b8ff1492853c89a7ee0320f49cc22162f654 Mon Sep 17 00:00:00 2001
From: Lars Briem
Date: Fri, 20 Sep 2024 13:23:37 +0200
Subject: [PATCH] Use pyav to fix frame shifts in OTAnalytics

Decode the video with PyAV and run the YOLO model on one decoded frame
at a time (stream=False) instead of handing the video path to
model.predict, so that one result is yielded per decoded frame.

Frames are converted with format="bgr24": Ultralytics treats numpy
array sources as HWC images with BGR channel order, which is also what
the previous path-based prediction fed to the model.

Also pin av==13.0.0 in requirements.txt and drop the types-psutil and
types-seaborn entries from .pre-commit-config.yaml.
---
 .pre-commit-config.yaml |  2 --
 OTVision/detect/yolo.py | 27 ++++++++++++++++-----------
 requirements.txt        |  1 +
 3 files changed, 17 insertions(+), 13 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 29830646..ef709b61 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -47,8 +47,6 @@ repos:
           - types-PyYAML
           - types-flake8
           - types-jsonschema
-          - types-psutil
-          - types-seaborn
           - types-setuptools
           - types-tqdm
           - types-ujson
diff --git a/OTVision/detect/yolo.py b/OTVision/detect/yolo.py
index 3744187c..2dd41ad2 100644
--- a/OTVision/detect/yolo.py
+++ b/OTVision/detect/yolo.py
@@ -25,6 +25,7 @@
 from time import perf_counter
 from typing import Generator
 
+import av
 import torch
 from tqdm import tqdm
 from ultralytics import YOLO as YOLOv8
@@ -145,17 +146,21 @@ def _load_model(self) -> YOLOv8:
         return model
 
     def _predict(self, video: Path) -> Generator[Results, None, None]:
-        return self.model.predict(
-            source=video,
-            conf=self.confidence,
-            iou=self.iou,
-            half=self.half_precision,
-            imgsz=self.img_size,
-            device=0 if torch.cuda.is_available() else "cpu",
-            stream=True,
-            verbose=False,
-            agnostic_nms=True,
-        )
+        with av.open(str(video.absolute())) as container:
+            for frame in container.decode(video=0):
+                results = self.model.predict(
+                    source=frame.to_ndarray(format="bgr24"),
+                    conf=self.confidence,
+                    iou=self.iou,
+                    half=self.half_precision,
+                    imgsz=self.img_size,
+                    device=0 if torch.cuda.is_available() else "cpu",
+                    stream=False,
+                    verbose=False,
+                    agnostic_nms=True,
+                )
+                for result in results:
+                    yield result
 
     def _parse_detections(self, detection_result: Boxes) -> list[Detection]:
         bboxes = detection_result.xywhn if self.normalized else detection_result.xywh
diff --git a/requirements.txt b/requirements.txt
index 21d6360d..ed2f5c8a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
+av==13.0.0
 geopandas==1.0.1
 ijson==3.3.0
 moviepy==1.0.3