-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f12d5b6
commit 1e6b659
Showing
6 changed files
with
357 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
import cv2 | ||
from PIL import Image | ||
from ultralytics import YOLO | ||
import numpy as np | ||
from yolox.tracker.byte_tracker import BYTETracker, STrack | ||
from supervision.draw.color import ColorPalette | ||
from supervision.geometry.dataclasses import Point | ||
from supervision.video.dataclasses import VideoInfo | ||
from supervision.video.source import get_video_frames_generator | ||
from supervision.video.sink import VideoSink | ||
from supervision.notebook.utils import show_frame_in_notebook | ||
from supervision.tools.detections import Detections, BoxAnnotator | ||
from supervision.tools.line_counter import LineCounter, LineCounterAnnotator | ||
from typing import List | ||
from onemetric.cv.utils.iou import box_iou_batch | ||
|
||
class BYTETrackerArgs:
    """Settings object handed to ``BYTETracker`` when the tracker is built.

    The attribute names are the ones the tracker looks up on its ``args``
    object; only the values are chosen here.
    """

    # NOTE(review): the meanings below are inferred from the attribute names;
    # confirm against the BYTETracker implementation before tuning.
    track_thresh: float = 0.5         # presumably the detection-score threshold
    track_buffer: int = 1000          # presumably how long lost tracks are kept
    match_thresh: float = 0.9         # presumably the association threshold
    aspect_ratio_thresh: float = 3.0
    min_box_area: float = 1.0
    mot20: bool = False
|
||
class videoTest:
    """Detect objects in a video, track them and count line crossings.

    Each frame is run through a YOLO model, the detections are fed to a
    ByteTrack tracker, tracker ids are attached to the detections, and the
    annotated frame is shown in an OpenCV window.
    """

    def __init__(self, video_path, model_path):
        """Open the video, load the model and set up tracker and counter.

        Args:
            video_path: Path of the video file to read.
            model_path: Path of the YOLO weights file.
        """
        self.cap = cv2.VideoCapture(str(video_path))
        self.model = YOLO(str(model_path))
        self.byte_tracker = BYTETracker(BYTETrackerArgs())
        LINE_START = Point(561, 517)
        LINE_END = Point(906, 510)
        self.line_counter = LineCounter(start=LINE_START, end=LINE_END)
        # The annotators do not depend on the frame, so build them once
        # instead of on every call to deployModel.
        self.box_annotator = BoxAnnotator(
            color=ColorPalette(), thickness=4, text_thickness=4, text_scale=2
        )
        self.line_annotator = LineCounterAnnotator(
            thickness=4, text_thickness=4, text_scale=2
        )

    def detections2boxes(self, detections: Detections) -> np.ndarray:
        """Stack ``xyxy`` and confidence into one array per detection row."""
        return np.hstack((
            detections.xyxy,
            detections.confidence[:, np.newaxis]
        ))

    def tracks2boxes(self, tracks: List[STrack]) -> np.ndarray:
        """Collect the ``tlbr`` box of every track into a float array."""
        return np.array([track.tlbr for track in tracks], dtype=float)

    def match_detections_with_tracks(self, detections: Detections, tracks: List[STrack]) -> list:
        """Return one tracker id (or ``None``) per detection.

        Every track is paired with the detection it overlaps most (highest
        IoU); a pairing with zero IoU is ignored.

        Returns:
            A list with ``len(detections)`` entries. Entries stay ``None``
            for detections that no track was matched to.
        """
        tracker_ids = [None] * len(detections)
        # Nothing to match: still return one slot per detection so that the
        # result lines up with the detections when it is assigned to
        # ``detections.tracker_id`` (the previous empty array did not).
        if not np.any(detections.xyxy) or len(tracks) == 0:
            return tracker_ids

        tracks_boxes = self.tracks2boxes(tracks=tracks)
        iou = box_iou_batch(tracks_boxes, detections.xyxy)
        track2detection = np.argmax(iou, axis=1)

        for tracker_index, detection_index in enumerate(track2detection):
            if iou[tracker_index, detection_index] != 0:
                tracker_ids[detection_index] = tracks[tracker_index].track_id

        return tracker_ids

    def processVideo(self):
        """Play the video with annotations until it ends or 'q' is pressed."""
        if not self.cap.isOpened():
            print("Error opening video stream or file")
        try:
            # Read until video is completed
            while self.cap.isOpened():
                ret, frame = self.cap.read()
                if not ret:
                    break

                cv2.imshow("FRAME", self.deployModel(frame))
                # Press Q on keyboard to exit
                if cv2.waitKey(25) & 0xFF == ord('q'):
                    break
        finally:
            # Release the capture and close the windows even when inference
            # or display raises part-way through the video.
            self.cap.release()
            cv2.destroyAllWindows()

    def deployModel(self, img):
        """Detect, track and annotate a single frame.

        Args:
            img: Frame as read from ``cv2.VideoCapture``.

        Returns:
            The frame with boxes, labels and the line counter drawn on it.
        """
        class_names = self.model.names
        results = self.model.predict(img)
        boxes = results[0].boxes
        detections = Detections(
            xyxy=boxes.xyxy.cpu().numpy(),
            confidence=boxes.conf.cpu().numpy(),
            class_id=boxes.cls.cpu().numpy().astype(int))

        tracks = self.byte_tracker.update(
            output_results=self.detections2boxes(detections=detections),
            img_info=img.shape,
            img_size=img.shape
        )

        tracker_id = self.match_detections_with_tracks(detections=detections, tracks=tracks)
        detections.tracker_id = np.array(tracker_id)

        # format custom labels
        labels = [
            f"#{track_id} {class_names[class_id]} {confidence:0.2f}"
            for _, confidence, class_id, track_id
            in detections
        ]

        self.line_counter.update(detections=detections)

        # annotate and display frame
        frame = self.box_annotator.annotate(frame=img, detections=detections, labels=labels)
        self.line_annotator.annotate(frame=frame, line_counter=self.line_counter)

        return frame
|
||
|
||
|
||
def main():
    """Run the tracking demo on the hard-coded video and weights."""
    video_path = "/home/hari/cement/datasets/172.20.6.226_Truck Loading PP - 1_main_20230722112747.mp4"
    weights_path = "/home/hari/cement/runs/detect/train/weights/best.pt"
    videoTest(video_path, weights_path).processVideo()


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import cv2 | ||
import os | ||
|
||
class frameExtractor:
    """Preview a video and save every N-th frame as a numbered JPEG."""

    def __init__(self, video_path, write_path):
        """Open the video and remember where extracted frames are written.

        Args:
            video_path: Path of the video file to read.
            write_path: Directory that receives the extracted images.
        """
        self.cap = cv2.VideoCapture(str(video_path))
        self.write_path = str(write_path)

    def showVideo(self, frame_interval=50, start_count=29500):
        """Show the video and write one frame out of every ``frame_interval``.

        Args:
            frame_interval: A frame is saved whenever the running counter is
                a multiple of this value.
            start_count: Initial value of the running frame counter. The
                default keeps the previously hard-coded offset.
                NOTE(review): presumably chosen so file numbering continues
                after an earlier extraction run - confirm.

        The image file name is the counter divided by ``frame_interval``,
        zero-padded to five digits. Playback stops at the end of the video
        or when 'q' is pressed.
        """
        # cv2.imwrite does not create missing directories, so make sure the
        # target exists before the first frame is written.
        os.makedirs(self.write_path, exist_ok=True)
        frame_count = start_count
        try:
            while self.cap.isOpened():
                ret, frame = self.cap.read()
                if not ret:
                    break
                frame_count += 1
                if frame_count % frame_interval == 0:
                    image_name = "{:05d}.jpg".format(frame_count // frame_interval)
                    image_path = os.path.join(self.write_path, image_name)
                    print(image_path)
                    # imwrite reports failure through its return value only.
                    if not cv2.imwrite(image_path, frame):
                        print("Failed to write {}".format(image_path))
                # NOTE(review): the scraped source lost its indentation; the
                # preview is assumed to run for every frame, not only for
                # the saved ones.
                preview = cv2.resize(frame, (0, 0), None, fx=0.5, fy=0.4)
                cv2.imshow("Frame", preview)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        finally:
            # Always free the capture and close the preview window.
            self.cap.release()
            cv2.destroyAllWindows()
|
||
def main():
    """Extract frames from the hard-coded video into the images folder."""
    source_video = "datasets/bag_upright.mp4"
    output_dir = "/home/hari/cement/datasets/images"
    frameExtractor(source_video, output_dir).showVideo()


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import cv2 | ||
from PIL import Image | ||
from ultralytics import YOLO | ||
import numpy as np | ||
from yolox.tracker.byte_tracker import BYTETracker, STrack | ||
from supervision.draw.color import ColorPalette | ||
from supervision.geometry.dataclasses import Point | ||
from supervision.video.dataclasses import VideoInfo | ||
from supervision.video.source import get_video_frames_generator | ||
from supervision.video.sink import VideoSink | ||
from supervision.notebook.utils import show_frame_in_notebook | ||
from supervision.tools.detections import Detections, BoxAnnotator | ||
from supervision.tools.line_counter import LineCounter, LineCounterAnnotator | ||
from typing import List | ||
|
||
class videoTest:
    """Run a YOLO model on a video and show the annotated frames.

    The box/track helper methods are kept for callers that want to combine
    the detections with a tracker; ``processVideo`` itself only needs
    ``deployModel``.
    """

    def __init__(self, video_path, model_path):
        """Open the video and load the model.

        Args:
            video_path: Path of the video file to read.
            model_path: Path of the YOLO weights file.
        """
        self.cap = cv2.VideoCapture(str(video_path))
        self.model = YOLO(str(model_path))

    def detections2boxes(self, detections: Detections) -> np.ndarray:
        """Stack ``xyxy`` and confidence into one array per detection row."""
        return np.hstack((
            detections.xyxy,
            detections.confidence[:, np.newaxis]
        ))

    def tracks2boxes(self, tracks: List[STrack]) -> np.ndarray:
        """Collect the ``tlbr`` box of every track into a float array."""
        return np.array([track.tlbr for track in tracks], dtype=float)

    @staticmethod
    def box_iou_batch(boxes_true: np.ndarray, boxes_detection: np.ndarray) -> np.ndarray:
        """Return the pairwise IoU of two sets of ``x1, y1, x2, y2`` boxes.

        This method was called by ``match_detections_with_tracks`` but was
        neither defined on the class nor imported in this file.

        Args:
            boxes_true: Array of shape ``(N, 4)``.
            boxes_detection: Array of shape ``(M, 4)``.

        Returns:
            Array of shape ``(N, M)``; pairs whose union has no area get 0.
        """
        first = np.asarray(boxes_true, dtype=float).reshape(-1, 4)
        second = np.asarray(boxes_detection, dtype=float).reshape(-1, 4)
        area_first = (first[:, 2] - first[:, 0]) * (first[:, 3] - first[:, 1])
        area_second = (second[:, 2] - second[:, 0]) * (second[:, 3] - second[:, 1])

        top_left = np.maximum(first[:, None, :2], second[None, :, :2])
        bottom_right = np.minimum(first[:, None, 2:], second[None, :, 2:])
        # Negative extents mean the boxes do not overlap on that axis.
        overlap = np.clip(bottom_right - top_left, 0, None).prod(axis=2)

        union = area_first[:, None] + area_second[None, :] - overlap
        iou = np.zeros_like(overlap)
        np.divide(overlap, union, out=iou, where=union > 0)
        return iou

    def match_detections_with_tracks(self, detections: Detections, tracks: List[STrack]) -> list:
        """Return one tracker id (or ``None``) per detection.

        Every track is paired with the detection it overlaps most (highest
        IoU); a pairing with zero IoU is ignored.

        Returns:
            A list with ``len(detections)`` entries. Entries stay ``None``
            for detections that no track was matched to.
        """
        tracker_ids = [None] * len(detections)
        # Keep one slot per detection even when there is nothing to match,
        # so the result always lines up with the detections.
        if not np.any(detections.xyxy) or len(tracks) == 0:
            return tracker_ids

        tracks_boxes = self.tracks2boxes(tracks=tracks)
        iou = self.box_iou_batch(tracks_boxes, detections.xyxy)
        track2detection = np.argmax(iou, axis=1)

        for tracker_index, detection_index in enumerate(track2detection):
            if iou[tracker_index, detection_index] != 0:
                tracker_ids[detection_index] = tracks[tracker_index].track_id

        return tracker_ids

    def processVideo(self):
        """Play the video with annotations until it ends or 'q' is pressed."""
        if not self.cap.isOpened():
            print("Error opening video stream or file")
        try:
            # Read until video is completed
            while self.cap.isOpened():
                ret, frame = self.cap.read()
                if not ret:
                    break

                cv2.imshow('Result', self.deployModel(frame))

                # Press Q on keyboard to exit
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        finally:
            # Release the capture and close the windows even when inference
            # or display raises part-way through the video.
            self.cap.release()
            cv2.destroyAllWindows()

    def deployModel(self, img):
        """Run the model on one frame and return the annotated image.

        The earlier BGR -> RGB -> PIL -> BGR round trip produced the same
        pixels as ``plot()`` itself, so the plotted array is returned
        directly.
        NOTE(review): relies on ``Results.plot()`` returning a BGR array
        that ``cv2.imshow`` can display - confirm for the installed
        ultralytics version.
        """
        result = self.model.predict(img)[0]
        return result.plot()
|
||
|
||
|
||
def main():
    """Run the detection demo on the hard-coded video and weights."""
    video_path = "/home/hari/cement/datasets/172.20.6.226_Truck Loading PP - 1_main_20230722095635.mp4"
    weights_path = "/home/hari/cement/runs/detect/train/weights/best.pt"
    videoTest(video_path, weights_path).processVideo()


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import os | ||
import random | ||
import shutil | ||
|
||
class splitter:
    """Randomly copy the files of one directory into train/val/test folders."""

    def __init__(self, dir_path, train_path, val_path, test_path, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
        """Store the source directory, the three targets and the ratios.

        Args:
            dir_path: Directory holding the files to split.
            train_path: Destination directory of the training subset.
            val_path: Destination directory of the validation subset.
            test_path: Destination directory of the test subset.
            train_ratio: Fraction of files copied to ``train_path``.
            val_ratio: Fraction of files copied to ``val_path``.
            test_ratio: Stored for reference only; the test subset is
                whatever remains after the train and val subsets.
        """
        self.dirPath = str(dir_path)
        self.trainPath = str(train_path)
        self.valPath = str(val_path)
        self.testPath = str(test_path)
        self.trainRatio = train_ratio
        self.valRatio = val_ratio
        self.testRatio = test_ratio

    def splitData(self):
        """Shuffle the source files and copy them into the three subsets.

        Only regular files are considered. Sub-directories are skipped,
        which matters when the destination folders live inside the source
        directory: they would otherwise be listed and make ``shutil.copy``
        fail on a directory.
        """
        image_files = sorted(
            name
            for name in os.listdir(self.dirPath)
            if os.path.isfile(os.path.join(self.dirPath, name))
        )
        random.shuffle(image_files)

        num_images = len(image_files)
        num_train = int(num_images * self.trainRatio)
        num_val = int(num_images * self.valRatio)
        val_end = num_train + num_val

        # The test subset takes the remainder, so no file is left out by
        # the rounding of the two ratios above.
        subsets = (
            (self.trainPath, image_files[:num_train]),
            (self.valPath, image_files[num_train:val_end]),
            (self.testPath, image_files[val_end:]),
        )
        for destination, filenames in subsets:
            self._copy_files(filenames, destination)

    def _copy_files(self, filenames, destination):
        """Copy ``filenames`` from the source directory into ``destination``."""
        os.makedirs(destination, exist_ok=True)
        for filename in filenames:
            shutil.copy(
                os.path.join(self.dirPath, filename),
                os.path.join(destination, filename),
            )
|
||
def main():
    """Split the hard-coded images folder into train/val/test sub-folders."""
    images_dir = "/home/hari/cement/datasets/images"
    split = splitter(
        dir_path=images_dir,
        train_path="/home/hari/cement/datasets/images/train",
        val_path="/home/hari/cement/datasets/images/val",
        test_path="/home/hari/cement/datasets/images/test",
    )
    split.splitData()


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from ultralytics import YOLO | ||
from PIL import Image | ||
import cv2 | ||
import numpy as np | ||
|
||
class testData:
    """Load the trained weights and display predictions for single images."""

    def __init__(self):
        """Load the YOLO model from the hard-coded weights file."""
        self.model = YOLO('/home/hari/cement/runs/detect/train/weights/best.pt')

    def predict(self, img):
        """Run the model on ``img`` and show the plotted result in a window.

        The window stays open until a key is pressed.
        """
        first_result = self.model.predict(img)[0]
        # Reverse the channel axis for PIL, then reverse it back for OpenCV.
        reversed_channels = first_result.plot()[:, :, ::-1]
        pil_image = Image.fromarray(reversed_channels).convert('RGB')
        display_image = np.array(pil_image)[:, :, ::-1].copy()
        cv2.imshow("RESULT", display_image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
|
||
def main():
    """Show the model's prediction for one hard-coded sample image."""
    sample_image = "/home/hari/cement/datasets/images/00118.jpg"
    testData().predict(sample_image)


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from ultralytics import YOLO | ||
|
||
class modelTraining:
    """Build a YOLOv8-nano model and train it on the project dataset."""

    def __init__(self):
        """Build the model from its YAML definition and load the weights.

        The constructor used to create the model three times in a row and
        keep only the last one; the two discarded constructions are gone.
        """
        self.model = YOLO('yolov8n.yaml').load('yolov8n.pt')

    def train(self, data='/home/hari/cement/data.yaml', epochs=1000, imgsz=640):
        """Train the model.

        Args:
            data: Path of the dataset YAML file.
            epochs: Number of training epochs.
            imgsz: Training image size.

        The defaults are the values that were previously hard-coded, so
        ``train()`` without arguments behaves as before.
        """
        self.model.train(data=data, epochs=epochs, imgsz=imgsz)
|
||
def main():
    """Create the trainer and start training with its default settings."""
    modelTraining().train()


if __name__ == "__main__":
    main()