diff --git a/requirements.txt b/requirements.txt index e3404cb53..4a5424b60 100644 --- a/requirements.txt +++ b/requirements.txt @@ -36,7 +36,7 @@ tqdm == 4.67.0; python_version>="3.9" xgboost == 0.90 yellowbrick==0.9.1; python_version=="3.6" yellowbrick==1.5.0; python_version>="3.9" -kaleido == 0.1.0.post1; python_version=="3.6" +kaleido == 0.1.0; python_version=="3.6" kaleido; python_version>="3.9" psutil == 5.9.8; python_version=="3.6" psutil; python_version>="3.9" diff --git a/simba/mixins/image_mixin.py b/simba/mixins/image_mixin.py index 6033af21e..bcd71ac46 100644 --- a/simba/mixins/image_mixin.py +++ b/simba/mixins/image_mixin.py @@ -948,9 +948,8 @@ def img_sliding_mse(imgs: np.ndarray, return results.astype(int64) @staticmethod - def _read_img_batch_from_video_helper( - frm_idx: np.ndarray, video_path: Union[str, os.PathLike], verbose: bool, - ): + def _read_img_batch_from_video_helper(frm_idx: np.ndarray, video_path: Union[str, os.PathLike], greyscale: bool, verbose: bool): + """Multiprocess helper used by read_img_batch_from_video to read in images from video file.""" start_idx, end_frm, current_frm = frm_idx[0], frm_idx[-1] + 1, frm_idx[0] results = {} @@ -959,18 +958,20 @@ def _read_img_batch_from_video_helper( while current_frm < end_frm: if verbose: print(f'Reading frame idx {current_frm}...') - results[current_frm] = cap.read()[1] + img = cap.read()[1] + if greyscale: + img = ImageMixin.img_to_greyscale(img=img) + results[current_frm] = img current_frm += 1 return results @staticmethod - def read_img_batch_from_video( - video_path: Union[str, os.PathLike], - start_frm: int, - end_frm: int, - core_cnt: Optional[int] = -1, - verbose: Optional[bool] = False, - ) -> Dict[int, np.ndarray]: + def read_img_batch_from_video(video_path: Union[str, os.PathLike], + start_frm: int, + end_frm: int, + greyscale: Optional[bool] = False, + core_cnt: Optional[int] = -1, + verbose: Optional[bool] = False) -> Dict[int, np.ndarray]: """ Read a batch of frames from a video file. This method reads frames from a specified range of frames within a video file using multiprocessing. @@ -981,7 +982,7 @@ def read_img_batch_from_video( :param int start_frm: Starting frame index. :param int end_frm: Ending frame index. :param Optionalint] core_cnt: Number of CPU cores to use for parallel processing. Default is -1, indicating using all available cores. - :param Optional[bool] greyscale: If True, reads the images as greyscale. If False, then as original color scale. Default: True. + :param Optional[bool] greyscale: If True, reads the images as greyscale. If False, then as original color scale. Default: False. :returns: A dictionary containing frame indices as keys and corresponding frame arrays as values. :rtype: Dict[int, np.ndarray] @@ -990,34 +991,18 @@ def read_img_batch_from_video( """ check_file_exist_and_readable(file_path=video_path) video_meta_data = get_video_meta_data(video_path=video_path) - check_int( - name=ImageMixin().__class__.__name__, - value=start_frm, - min_value=0, - max_value=video_meta_data["frame_count"], - ) - check_int( - name=ImageMixin().__class__.__name__, - value=end_frm, - min_value=0, - max_value=video_meta_data["frame_count"], - ) + check_int(name=ImageMixin().__class__.__name__,value=start_frm, min_value=0,max_value=video_meta_data["frame_count"]) + check_int(name=ImageMixin().__class__.__name__, value=end_frm, min_value=start_frm+1, max_value=video_meta_data["frame_count"]) check_int(name=ImageMixin().__class__.__name__, value=core_cnt, min_value=-1) + check_valid_boolean(value=[greyscale], source=f'{ImageMixin().__class__.__name__} greyscale') if core_cnt < 0: core_cnt = multiprocessing.cpu_count() if end_frm <= start_frm: - FrameRangeError( - msg=f"Start frame ({start_frm}) has to be before end frame ({end_frm})", - source=ImageMixin().__class__.__name__, - ) + FrameRangeError(msg=f"Start frame ({start_frm}) has to be before end frame ({end_frm})", source=ImageMixin().__class__.__name__) frm_lst = np.array_split(np.arange(start_frm, end_frm + 1), core_cnt) results = {} - with multiprocessing.Pool( - core_cnt, maxtasksperchild=Defaults.LARGE_MAX_TASK_PER_CHILD.value - ) as pool: - constants = functools.partial( - ImageMixin()._read_img_batch_from_video_helper, video_path=video_path, verbose=verbose - ) + with multiprocessing.Pool(core_cnt, maxtasksperchild=Defaults.LARGE_MAX_TASK_PER_CHILD.value) as pool: + constants = functools.partial(ImageMixin()._read_img_batch_from_video_helper, video_path=video_path, greyscale=greyscale, verbose=verbose) for cnt, result in enumerate(pool.imap(constants, frm_lst, chunksize=1)): results.update(result) diff --git a/simba/utils/enums.py b/simba/utils/enums.py index 480c0312f..bbd95dc58 100644 --- a/simba/utils/enums.py +++ b/simba/utils/enums.py @@ -129,7 +129,7 @@ class Formats(Enum): MP4_CODEC = "mp4v" AVI_CODEC = "XVID" BATCH_CODEC = "libx264" - NUMERIC_DTYPES = (np.float32, np.float64, np.int64, np.int32, np.int8, int, float) + NUMERIC_DTYPES = (np.float32, np.float64, np.int64, np.int32, np.int8, np.uint8, int, float) LABELFRAME_HEADER_FORMAT = ("Helvetica", 12, "bold") LABELFRAME_HEADER_CLICKABLE_FORMAT = ("Helvetica", 12, "bold", "underline") LABELFRAME_HEADER_CLICKABLE_COLOR = f"#{5:02x}{99:02x}{193:02x}"