From 1371fde84f3c049a3f0bcaaf9c7b98c3d76526bf Mon Sep 17 00:00:00 2001 From: niksirbi Date: Thu, 12 Dec 2024 16:24:30 +0000 Subject: [PATCH] refactored functions for loading ds from nwb --- examples/nwb_conversion.py | 15 ++-- movement/io/nwb.py | 169 ++++++++++++++++++++++-------------- tests/test_unit/test_nwb.py | 26 +++--- 3 files changed, 130 insertions(+), 80 deletions(-) diff --git a/examples/nwb_conversion.py b/examples/nwb_conversion.py index 505036c7..bc0c46d4 100644 --- a/examples/nwb_conversion.py +++ b/examples/nwb_conversion.py @@ -7,11 +7,12 @@ # %% Load the sample data import datetime +import xarray as xr from pynwb import NWBHDF5IO, NWBFile from movement import sample_data from movement.io.nwb import ( - convert_nwb_to_movement, + ds_from_nwb_file, ds_to_nwb, ) @@ -48,8 +49,10 @@ # This will create a movement dataset with the same data as # the original dataset from the NWBFiles -# Convert the NWBFiles to a movement dataset -ds_from_nwb = convert_nwb_to_movement( - nwb_filepaths=["individual1.nwb", "individual2.nwb"] -) -ds_from_nwb +# Convert each NWB file to a single-individual movement dataset +datasets = [ + ds_from_nwb_file(f) for f in ["individual1.nwb", "individual2.nwb"] +] + +# Combine into a multi-individual dataset +ds = xr.merge(datasets) diff --git a/movement/io/nwb.py b/movement/io/nwb.py index 8a01632a..c55b4fa2 100644 --- a/movement/io/nwb.py +++ b/movement/io/nwb.py @@ -13,6 +13,7 @@ from movement.io.load_poses import from_numpy from movement.utils.logging import log_error +from movement.validators.files import ValidFile # Default keyword arguments for Skeletons, # PoseEstimation and PoseEstimationSeries objects @@ -146,7 +147,7 @@ def ds_to_nwb( Parameters ---------- movement_dataset : xr.Dataset - ``movement`` poses dataset containing the data to be converted to NWB. + ``movement`` poses dataset containing the data to be converted to NWB nwb_files : list[pynwb.NWBFile] | pynwb.NWBFile An NWBFile object or a list of such objects to which the data will be added. @@ -207,32 +208,119 @@ def ds_to_nwb( print("PoseEstimation already exists. Skipping...") -def _convert_pose_estimation_series( +def ds_from_nwb_file( + file: str | Path | pynwb.NWBFile, + key_name: str = "PoseEstimation", +) -> xr.Dataset: + """Create a ``movement`` poses dataset from an NWB file. + + The input can be a path to an NWB file on disk or an open NWB file object. + The data will be extracted from the NWB file's behavior processing module, + which is assumed to contain a ``PoseEstimation`` object formatted according + to the ``ndx-pose`` NWB extension [1]_. + + Parameters + ---------- + file : str | Path | NWBFile + Path to the NWB file on disk (ending in ".nwb"), + or an open NWBFile object. + key_name: str, optional + Name of the ``PoseEstimation`` object in the NWB "behavior" + processing module, by default "PoseEstimation". + + Returns + ------- + movement_ds : xr.Dataset + A single-individual ``movement`` poses dataset + + References + ---------- + .. 
[1] https://github.com/rly/ndx-pose + + """ + if isinstance(file, str | Path): + valid_file = ValidFile( + file, expected_permission="r", expected_suffix=[".nwb"] + ) + with pynwb.NWBHDF5IO(valid_file.path, mode="r") as io: + nwb_file = io.read() + ds = _ds_from_nwb_object(nwb_file, key_name=key_name) + ds.attrs["source_file"] = valid_file.path + elif isinstance(file, pynwb.NWBFile): + ds = _ds_from_nwb_object(file, key_name=key_name) + + ds.attrs["source_file"] = getattr(file, "path", None) + else: + raise log_error( + TypeError, + "Expected file to be one of following types: str, Path, " + f"pynwb.NWBFile. Got {type(file)} instead.", + ) + return ds + + +def _ds_from_nwb_object( + nwb_file: pynwb.NWBFile, + key_name: str = "PoseEstimation", +) -> xr.Dataset: + """Extract a ``movement`` poses dataset from an open NWB file object. + + Parameters + ---------- + nwb_file : pynwb.NWBFile + An open NWB file object. + key_name: str + Name of the ``PoseEstimation`` object in the NWB "behavior" + processing module, by default "PoseEstimation". + + Returns + ------- + movement_ds : xr.Dataset + A single-individual ``movement`` poses dataset + + """ + pose_estimation = nwb_file.processing["behavior"][key_name] + source_software = pose_estimation.fields["source_software"] + pose_estimation_series = pose_estimation.fields["pose_estimation_series"] + single_keypoint_datasets = [ + _ds_from_pose_estimation_series( + pes, + keypoint, + subject_name=nwb_file.identifier, + source_software=source_software, + ) + for keypoint, pes in pose_estimation_series.items() + ] + + return xr.merge(single_keypoint_datasets) + + +def _ds_from_pose_estimation_series( pose_estimation_series: ndx_pose.PoseEstimationSeries, keypoint: str, subject_name: str, source_software: str, - source_file: str | None = None, ) -> xr.Dataset: - """Convert to single-keypoint, single-individual ``movement`` dataset. + """Create a ``movement`` poses dataset from a pose estimation series. + + Note that the ``PoseEstimationSeries`` object only contains data for a + single point over time. Parameters ---------- pose_estimation_series : ndx_pose.PoseEstimationSeries - PoseEstimationSeries NWB object to be converted. + PoseEstimationSeries NWB object to be converted keypoint : str - Name of the keypoint - body part. + Name of the keypoint (body part) subject_name : str - Name of the subject (individual). + Name of the subject (individual) source_software : str - Name of the software used to estimate the pose. - source_file : Optional[str], optional - File from which the data was extracted, by default None + Name of the software used to estimate the pose Returns ------- movement_dataset : xr.Dataset - ``movement`` compatible dataset containing the pose estimation data. 
+ A single-individual ``movement`` poses dataset with one keypoint """ # extract pose_estimation series data (n_time, n_space) @@ -244,63 +332,16 @@ def _convert_pose_estimation_series( else: confidence_array = np.asarray(pose_estimation_series.confidence) + # Compute fps from the time differences between timestamps + fps = np.nanmedian(1 / np.diff(pose_estimation_series.timestamps)) + # create movement dataset with 1 keypoint and 1 individual ds = from_numpy( - # reshape to (n_time, n_space, 1, 1) - position_array[:, :, np.newaxis, np.newaxis], - # reshape to (n_time, 1, 1) - confidence_array[:, np.newaxis, np.newaxis], + position_array[:, :, np.newaxis, np.newaxis], # n_time, n_space, 1, 1 + confidence_array[:, np.newaxis, np.newaxis], # n_time, 1, 1 individual_names=[subject_name], keypoint_names=[keypoint], - fps=np.nanmedian(1 / np.diff(pose_estimation_series.timestamps)), + fps=round(fps, 6), source_software=source_software, ) - ds.attrs["source_file"] = source_file return ds - - -def convert_nwb_to_movement( - nwb_filepaths: str | list[str] | list[Path], - key_name: str = "PoseEstimation", -) -> xr.Dataset: - """Convert a list of NWB files to a single ``movement`` dataset. - - Parameters - ---------- - nwb_filepaths : str | Path | list[str] | list[Path] - List of paths to NWB files to be converted. - key_name: str, optional - Name of the PoseEstimation object in the NWB "behavior" - processing module, by default "PoseEstimation". - - Returns - ------- - movement_ds : xr.Dataset - ``movement`` dataset containing the pose estimation data. - - """ - if isinstance(nwb_filepaths, str | Path): - nwb_filepaths = [nwb_filepaths] - - datasets = [] - for path in nwb_filepaths: - with pynwb.NWBHDF5IO(path, mode="r") as io: - nwbfile = io.read() - pose_estimation = nwbfile.processing["behavior"][key_name] - source_software = pose_estimation.fields["source_software"] - pose_estimation_series = pose_estimation.fields[ - "pose_estimation_series" - ] - - for keypoint, pes in pose_estimation_series.items(): - datasets.append( - _convert_pose_estimation_series( - pes, - keypoint, - subject_name=nwbfile.identifier, - source_software=source_software, - source_file=None, - ) - ) - - return xr.merge(datasets) diff --git a/tests/test_unit/test_nwb.py b/tests/test_unit/test_nwb.py index 4d95de42..e9e09f1b 100644 --- a/tests/test_unit/test_nwb.py +++ b/tests/test_unit/test_nwb.py @@ -9,9 +9,9 @@ from movement import sample_data from movement.io.nwb import ( - _convert_pose_estimation_series, _create_pose_and_skeleton_objects, - convert_nwb_to_movement, + _ds_from_pose_estimation_series, + ds_from_nwb_file, ds_to_nwb, ) @@ -50,7 +50,8 @@ def create_test_pose_estimation_series( ): rng = np.random.default_rng(42) data = rng.random((n_time, n_dims)) # num_frames x n_space_dims (2 or 3) - timestamps = np.linspace(0, 10, num=n_time) # a timestamp for every frame + # Create an array of timestamps in seconds, assuming fps=10.0 + timestamps = np.arange(n_time) / 10.0 confidence = np.ones((n_time,)) # a confidence value for every frame reference_frame = "(0,0,0) corresponds to ..." confidence_definition = "Softmax output of the deep neural network." 
@@ -74,19 +75,18 @@ def create_test_pose_estimation_series( (50, 3), # 3D data ], ) -def test_convert_pose_estimation_series(n_time, n_dims): +def test_ds_from_pose_estimation_series(n_time, n_dims): # Create a sample PoseEstimationSeries object pose_estimation_series = create_test_pose_estimation_series( n_time=n_time, n_dims=n_dims, keypoint="leftear" ) # Call the function - movement_dataset = _convert_pose_estimation_series( + movement_dataset = _ds_from_pose_estimation_series( pose_estimation_series, keypoint="leftear", subject_name="individual1", source_software="software1", - source_file="file1", ) # Assert the dimensions of the movement dataset @@ -118,10 +118,10 @@ def test_convert_pose_estimation_series(n_time, n_dims): print(movement_dataset.attrs) assert movement_dataset.attrs == { "ds_type": "poses", - "fps": np.nanmedian(1 / np.diff(pose_estimation_series.timestamps)), + "fps": 10.0, "time_unit": pose_estimation_series.timestamps_unit, "source_software": "software1", - "source_file": "file1", + "source_file": None, } @@ -242,8 +242,7 @@ def test_convert_nwb_to_movement(tmp_path): with NWBHDF5IO(file_path, mode="w") as io: io.write(nwb_file) - nwb_filepaths = [file_path.as_posix()] - movement_dataset = convert_nwb_to_movement(nwb_filepaths) + movement_dataset = ds_from_nwb_file(file_path) assert movement_dataset.sizes == { "time": 100, @@ -251,3 +250,10 @@ def test_convert_nwb_to_movement(tmp_path): "keypoints": 3, "space": 2, } + assert movement_dataset.attrs == { + "ds_type": "poses", + "fps": 10.0, + "time_unit": "seconds", + "source_software": "DeepLabCut", + "source_file": file_path, + }
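
A usage sketch (not part of the diff above) of the refactored loader, assuming
the two NWB files written by examples/nwb_conversion.py ("individual1.nwb" and
"individual2.nwb") exist in the working directory; everything else below is
illustrative only.

    import xarray as xr
    from pynwb import NWBHDF5IO

    from movement.io.nwb import ds_from_nwb_file

    # From a path on disk: the path is validated (it must end in ".nwb"),
    # the file is opened read-only, and the path is recorded in
    # ds1.attrs["source_file"].
    ds1 = ds_from_nwb_file("individual1.nwb")

    # From an already-open NWBFile object, e.g. when the caller manages I/O.
    # "PoseEstimation" is the default key in the "behavior" processing module.
    with NWBHDF5IO("individual2.nwb", mode="r") as io:
        ds2 = ds_from_nwb_file(io.read(), key_name="PoseEstimation")

    # Merge the two single-individual datasets into one multi-individual
    # movement poses dataset, as the updated example script does.
    ds = xr.merge([ds1, ds2])

Each single-individual dataset carries the "ds_type", "fps", "time_unit",
"source_software" and "source_file" attributes asserted in the updated tests.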
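
The fps attribute is inferred from the timestamps and now rounded to 6 decimal
places; a minimal sketch of that computation, mirroring the 10 Hz timestamps
generated by create_test_pose_estimation_series in the updated test fixture
(the 100-frame count here is illustrative):

    import numpy as np

    # Timestamps in seconds for 100 frames sampled at 10 Hz.
    timestamps = np.arange(100) / 10.0

    # Median of the inverse inter-frame intervals: 1 / 0.1 s = 10.0 fps.
    fps = np.nanmedian(1 / np.diff(timestamps))

    # _ds_from_pose_estimation_series rounds to 6 decimal places before
    # passing fps to from_numpy, so floating-point noise in the intervals
    # does not leak into the dataset attributes.
    print(round(fps, 6))  # 10.0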