Skip to content

Commit

Permalink
refactored functions for loading ds from nwb
Browse files Browse the repository at this point in the history
  • Loading branch information
niksirbi committed Dec 12, 2024
1 parent 70efdd0 commit 1371fde
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 80 deletions.
15 changes: 9 additions & 6 deletions examples/nwb_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
# %% Load the sample data
import datetime

import xarray as xr
from pynwb import NWBHDF5IO, NWBFile

from movement import sample_data
from movement.io.nwb import (
convert_nwb_to_movement,
ds_from_nwb_file,
ds_to_nwb,
)

Expand Down Expand Up @@ -48,8 +49,10 @@
# This will create a movement dataset with the same data as
# the original dataset from the NWBFiles

# Convert the NWBFiles to a movement dataset
ds_from_nwb = convert_nwb_to_movement(
nwb_filepaths=["individual1.nwb", "individual2.nwb"]
)
ds_from_nwb
# Convert each NWB file to a single-individual movement dataset
datasets = [
ds_from_nwb_file(f) for f in ["individual1.nwb", "individual2.nwb"]
]

# Combine into a multi-individual dataset
ds = xr.merge(datasets)
169 changes: 105 additions & 64 deletions movement/io/nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from movement.io.load_poses import from_numpy
from movement.utils.logging import log_error
from movement.validators.files import ValidFile

# Default keyword arguments for Skeletons,
# PoseEstimation and PoseEstimationSeries objects
Expand Down Expand Up @@ -146,7 +147,7 @@ def ds_to_nwb(
Parameters
----------
movement_dataset : xr.Dataset
``movement`` poses dataset containing the data to be converted to NWB.
``movement`` poses dataset containing the data to be converted to NWB
nwb_files : list[pynwb.NWBFile] | pynwb.NWBFile
An NWBFile object or a list of such objects to which the data
will be added.
Expand Down Expand Up @@ -207,32 +208,119 @@ def ds_to_nwb(
print("PoseEstimation already exists. Skipping...")

Check warning on line 208 in movement/io/nwb.py

View check run for this annotation

Codecov / codecov/patch

movement/io/nwb.py#L207-L208

Added lines #L207 - L208 were not covered by tests


def _convert_pose_estimation_series(
def ds_from_nwb_file(
file: str | Path | pynwb.NWBFile,
key_name: str = "PoseEstimation",
) -> xr.Dataset:
"""Create a ``movement`` poses dataset from an NWB file.
The input can be a path to an NWB file on disk or an open NWB file object.
The data will be extracted from the NWB file's behavior processing module,
which is assumed to contain a ``PoseEstimation`` object formatted according
to the ``ndx-pose`` NWB extension [1]_.
Parameters
----------
file : str | Path | NWBFile
Path to the NWB file on disk (ending in ".nwb"),
or an open NWBFile object.
key_name: str, optional
Name of the ``PoseEstimation`` object in the NWB "behavior"
processing module, by default "PoseEstimation".
Returns
-------
movement_ds : xr.Dataset
A single-individual ``movement`` poses dataset
References
----------
.. [1] https://github.com/rly/ndx-pose
"""
if isinstance(file, str | Path):
valid_file = ValidFile(
file, expected_permission="r", expected_suffix=[".nwb"]
)
with pynwb.NWBHDF5IO(valid_file.path, mode="r") as io:
nwb_file = io.read()
ds = _ds_from_nwb_object(nwb_file, key_name=key_name)
ds.attrs["source_file"] = valid_file.path
elif isinstance(file, pynwb.NWBFile):
ds = _ds_from_nwb_object(file, key_name=key_name)

Check warning on line 250 in movement/io/nwb.py

View check run for this annotation

Codecov / codecov/patch

movement/io/nwb.py#L249-L250

Added lines #L249 - L250 were not covered by tests

ds.attrs["source_file"] = getattr(file, "path", None)

Check warning on line 252 in movement/io/nwb.py

View check run for this annotation

Codecov / codecov/patch

movement/io/nwb.py#L252

Added line #L252 was not covered by tests
else:
raise log_error(

Check warning on line 254 in movement/io/nwb.py

View check run for this annotation

Codecov / codecov/patch

movement/io/nwb.py#L254

Added line #L254 was not covered by tests
TypeError,
"Expected file to be one of following types: str, Path, "
f"pynwb.NWBFile. Got {type(file)} instead.",
)
return ds


def _ds_from_nwb_object(
nwb_file: pynwb.NWBFile,
key_name: str = "PoseEstimation",
) -> xr.Dataset:
"""Extract a ``movement`` poses dataset from an open NWB file object.
Parameters
----------
nwb_file : pynwb.NWBFile
An open NWB file object.
key_name: str
Name of the ``PoseEstimation`` object in the NWB "behavior"
processing module, by default "PoseEstimation".
Returns
-------
movement_ds : xr.Dataset
A single-individual ``movement`` poses dataset
"""
pose_estimation = nwb_file.processing["behavior"][key_name]
source_software = pose_estimation.fields["source_software"]
pose_estimation_series = pose_estimation.fields["pose_estimation_series"]
single_keypoint_datasets = [
_ds_from_pose_estimation_series(
pes,
keypoint,
subject_name=nwb_file.identifier,
source_software=source_software,
)
for keypoint, pes in pose_estimation_series.items()
]

return xr.merge(single_keypoint_datasets)


def _ds_from_pose_estimation_series(
pose_estimation_series: ndx_pose.PoseEstimationSeries,
keypoint: str,
subject_name: str,
source_software: str,
source_file: str | None = None,
) -> xr.Dataset:
"""Convert to single-keypoint, single-individual ``movement`` dataset.
"""Create a ``movement`` poses dataset from a pose estimation series.
Note that the ``PoseEstimationSeries`` object only contains data for a
single point over time.
Parameters
----------
pose_estimation_series : ndx_pose.PoseEstimationSeries
PoseEstimationSeries NWB object to be converted.
PoseEstimationSeries NWB object to be converted
keypoint : str
Name of the keypoint - body part.
Name of the keypoint (body part)
subject_name : str
Name of the subject (individual).
Name of the subject (individual)
source_software : str
Name of the software used to estimate the pose.
source_file : Optional[str], optional
File from which the data was extracted, by default None
Name of the software used to estimate the pose
Returns
-------
movement_dataset : xr.Dataset
``movement`` compatible dataset containing the pose estimation data.
A single-individual ``movement`` poses dataset with one keypoint
"""
# extract pose_estimation series data (n_time, n_space)
Expand All @@ -244,63 +332,16 @@ def _convert_pose_estimation_series(
else:
confidence_array = np.asarray(pose_estimation_series.confidence)

# Compute fps from the time differences between timestamps
fps = np.nanmedian(1 / np.diff(pose_estimation_series.timestamps))

# create movement dataset with 1 keypoint and 1 individual
ds = from_numpy(
# reshape to (n_time, n_space, 1, 1)
position_array[:, :, np.newaxis, np.newaxis],
# reshape to (n_time, 1, 1)
confidence_array[:, np.newaxis, np.newaxis],
position_array[:, :, np.newaxis, np.newaxis], # n_time, n_space, 1, 1
confidence_array[:, np.newaxis, np.newaxis], # n_time, 1, 1
individual_names=[subject_name],
keypoint_names=[keypoint],
fps=np.nanmedian(1 / np.diff(pose_estimation_series.timestamps)),
fps=round(fps, 6),
source_software=source_software,
)
ds.attrs["source_file"] = source_file
return ds


def convert_nwb_to_movement(
nwb_filepaths: str | list[str] | list[Path],
key_name: str = "PoseEstimation",
) -> xr.Dataset:
"""Convert a list of NWB files to a single ``movement`` dataset.
Parameters
----------
nwb_filepaths : str | Path | list[str] | list[Path]
List of paths to NWB files to be converted.
key_name: str, optional
Name of the PoseEstimation object in the NWB "behavior"
processing module, by default "PoseEstimation".
Returns
-------
movement_ds : xr.Dataset
``movement`` dataset containing the pose estimation data.
"""
if isinstance(nwb_filepaths, str | Path):
nwb_filepaths = [nwb_filepaths]

datasets = []
for path in nwb_filepaths:
with pynwb.NWBHDF5IO(path, mode="r") as io:
nwbfile = io.read()
pose_estimation = nwbfile.processing["behavior"][key_name]
source_software = pose_estimation.fields["source_software"]
pose_estimation_series = pose_estimation.fields[
"pose_estimation_series"
]

for keypoint, pes in pose_estimation_series.items():
datasets.append(
_convert_pose_estimation_series(
pes,
keypoint,
subject_name=nwbfile.identifier,
source_software=source_software,
source_file=None,
)
)

return xr.merge(datasets)
26 changes: 16 additions & 10 deletions tests/test_unit/test_nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@

from movement import sample_data
from movement.io.nwb import (
_convert_pose_estimation_series,
_create_pose_and_skeleton_objects,
convert_nwb_to_movement,
_ds_from_pose_estimation_series,
ds_from_nwb_file,
ds_to_nwb,
)

Expand Down Expand Up @@ -50,7 +50,8 @@ def create_test_pose_estimation_series(
):
rng = np.random.default_rng(42)
data = rng.random((n_time, n_dims)) # num_frames x n_space_dims (2 or 3)
timestamps = np.linspace(0, 10, num=n_time) # a timestamp for every frame
# Create an array of timestamps in seconds, assuming fps=10.0
timestamps = np.arange(n_time) / 10.0
confidence = np.ones((n_time,)) # a confidence value for every frame
reference_frame = "(0,0,0) corresponds to ..."
confidence_definition = "Softmax output of the deep neural network."
Expand All @@ -74,19 +75,18 @@ def create_test_pose_estimation_series(
(50, 3), # 3D data
],
)
def test_convert_pose_estimation_series(n_time, n_dims):
def test_ds_from_pose_estimation_series(n_time, n_dims):
# Create a sample PoseEstimationSeries object
pose_estimation_series = create_test_pose_estimation_series(
n_time=n_time, n_dims=n_dims, keypoint="leftear"
)

# Call the function
movement_dataset = _convert_pose_estimation_series(
movement_dataset = _ds_from_pose_estimation_series(
pose_estimation_series,
keypoint="leftear",
subject_name="individual1",
source_software="software1",
source_file="file1",
)

# Assert the dimensions of the movement dataset
Expand Down Expand Up @@ -118,10 +118,10 @@ def test_convert_pose_estimation_series(n_time, n_dims):
print(movement_dataset.attrs)
assert movement_dataset.attrs == {
"ds_type": "poses",
"fps": np.nanmedian(1 / np.diff(pose_estimation_series.timestamps)),
"fps": 10.0,
"time_unit": pose_estimation_series.timestamps_unit,
"source_software": "software1",
"source_file": "file1",
"source_file": None,
}


Expand Down Expand Up @@ -242,12 +242,18 @@ def test_convert_nwb_to_movement(tmp_path):
with NWBHDF5IO(file_path, mode="w") as io:
io.write(nwb_file)

nwb_filepaths = [file_path.as_posix()]
movement_dataset = convert_nwb_to_movement(nwb_filepaths)
movement_dataset = ds_from_nwb_file(file_path)

assert movement_dataset.sizes == {
"time": 100,
"individuals": 1,
"keypoints": 3,
"space": 2,
}
assert movement_dataset.attrs == {
"ds_type": "poses",
"fps": 10.0,
"time_unit": "seconds",
"source_software": "DeepLabCut",
"source_file": file_path,
}

0 comments on commit 1371fde

Please sign in to comment.