diff --git a/newsfragments/612.bugfix b/newsfragments/612.bugfix new file mode 100644 index 000000000..6f4e30c86 --- /dev/null +++ b/newsfragments/612.bugfix @@ -0,0 +1 @@ +NXmx files with multidimensional arrays (images, modules, or both) are now handled. diff --git a/src/dxtbx/nexus/__init__.py b/src/dxtbx/nexus/__init__.py index ccd84a7ec..6fc0d20c6 100644 --- a/src/dxtbx/nexus/__init__.py +++ b/src/dxtbx/nexus/__init__.py @@ -2,7 +2,7 @@ import itertools import logging -from typing import Literal, Optional, Tuple, cast +from typing import Literal, Optional import h5py import numpy as np @@ -404,9 +404,12 @@ def equipment_component_key(dependency): origin -= nxdetector.beam_center_y.magnitude * pixel_size[1] * slow_axis # dxtbx requires image size in the order fast, slow - which is the reverse of what - # is stored in module.data_size - image_size = cast(Tuple[int, int], tuple(map(int, module.data_size[::-1]))) - assert len(image_size) == 2 + # is stored in module.data_size. Additionally, data_size can have more than 2 + # dimensions, for multi-module detectors. So take the last two dimensions and reverse + # them. Examples: + # [1,2,3] --> (3, 2) + # [1,2] --> (2, 1) + image_size = (int(module.data_size[-1]), int(module.data_size[-2])) underload = ( float(nxdetector.underload_value) if nxdetector.underload_value is not None @@ -475,13 +478,17 @@ def get_static_mask(nxdetector: nxmx.NXdetector) -> tuple[flex.bool, ...] | None pixel_mask = nxdetector.pixel_mask except KeyError: return None - if pixel_mask is None or not pixel_mask.size or pixel_mask.ndim != 2: + if pixel_mask is None or not pixel_mask.size: return None all_slices = get_detector_module_slices(nxdetector) - return tuple( - flumpy.from_numpy(np.ascontiguousarray(pixel_mask[slices])) == 0 - for slices in all_slices - ) + all_mask_slices = [] + for slices in all_slices: + mask_slice = flumpy.from_numpy(np.ascontiguousarray(pixel_mask[slices])) == 0 + mask_slice.reshape( + flex.grid(mask_slice.all()[-2:]) + ) # handle 3 or 4 dimension arrays + all_mask_slices.append(mask_slice) + return tuple(all_mask_slices) def _dataset_as_flex( @@ -562,5 +569,8 @@ def get_raw_data( data_as_flex = _dataset_as_flex( sliced_outer, tuple(module_slices), bit_depth=bit_depth ) + data_as_flex.reshape( + flex.grid(data_as_flex.all()[-2:]) + ) # handle 3 or 4 dimension arrays all_data.append(data_as_flex) return tuple(all_data) diff --git a/tests/nexus/test_build_dxtbx_models.py b/tests/nexus/test_build_dxtbx_models.py index 8b3b684ac..82c33ca86 100644 --- a/tests/nexus/test_build_dxtbx_models.py +++ b/tests/nexus/test_build_dxtbx_models.py @@ -306,6 +306,109 @@ def test_get_dxtbx_detector_beam_center_fallback(nxmx_example): ) +@pytest.fixture +def detector_with_multiple_modules(): + + with h5py.File(" ", "w", **pytest.h5_in_memory) as f: + + detector = f.create_group("/entry/instrument/detector") + detector.attrs["NX_class"] = "NXdetector" + detector["beam_center_x"] = 2079.79727597266 + detector["beam_center_y"] = 2225.38773853771 + detector["count_time"] = 0.00285260857097799 + detector["depends_on"] = "/entry/instrument/detector/transformations/det_z" + detector["description"] = "Eiger 16M" + detector["distance"] = 0.237015940260233 + detector.create_dataset("data", data=np.zeros((100, 100))) + detector["sensor_material"] = "Silicon" + detector["sensor_thickness"] = 0.00045 + detector["sensor_thickness"].attrs["units"] = b"m" + detector["x_pixel_size"] = 7.5e-05 + detector["y_pixel_size"] = 7.5e-05 + detector["underload_value"] = 0 + detector["saturation_value"] = 9266 + detector["frame_time"] = 0.1 + detector["frame_time"].attrs["units"] = "s" + detector["bit_depth_readout"] = np.array(32) + mask = np.zeros((2, 100, 200), dtype="i8") + detector.create_dataset("pixel_mask", data=mask) + + detector_transformations = detector.create_group("transformations") + detector_transformations.attrs["NX_class"] = "NXtransformations" + det_z = detector_transformations.create_dataset("det_z", data=np.array([289.3])) + det_z.attrs["depends_on"] = b"." + det_z.attrs["transformation_type"] = b"translation" + det_z.attrs["units"] = b"mm" + det_z.attrs["vector"] = np.array([0.0, 0.0, 1.0]) + + def make_module(name, depends_on, data_origin, fast_direction, slow_direction): + module = detector.create_group(name) + module.attrs["NX_class"] = "NXdetector_module" + module.create_dataset("data_size", data=np.array([1, 100, 200])) + module.create_dataset("data_origin", data=np.array(data_origin)) + fast = module.create_dataset("fast_pixel_direction", data=0.075) + fast.attrs["transformation_type"] = "translation" + fast.attrs["depends_on"] = depends_on + fast.attrs["vector"] = np.array(fast_direction) + fast.attrs["units"] = "mm" + slow = module.create_dataset("slow_pixel_direction", data=0.075) + slow.attrs["transformation_type"] = "translation" + slow.attrs["depends_on"] = depends_on + slow.attrs["vector"] = np.array(slow_direction) + slow.attrs["units"] = "mm" + + make_module( + name="m0", + depends_on="/entry/instrument/detector/transformations/det_z", + data_origin=[0, 0, 0], + fast_direction=[-0.999998, -0.001781, 0], + slow_direction=[-0.001781, 0.999998, 0], + ) + make_module( + name="m1", + depends_on="/entry/instrument/detector/transformations/det_z", + data_origin=[1, 0, 0], + fast_direction=[-0.999998, -0.001781, 0], + slow_direction=[-0.001781, 0.999998, 0], + ) + + nxdata = f.create_group("/entry/data") + nxdata.attrs["NX_class"] = "NXdata" + nxdata.create_dataset( + "data", + data=np.array( + [np.full((2, 100, 200), i, dtype=np.int32) for i in range(3)] + ), + ) + nxdata.attrs["signal"] = "/entry/data/data" + + yield f + + +def test_get_dxtbx_detector_with_multiple_modules(detector_with_multiple_modules): + det = nxmx.NXdetector(detector_with_multiple_modules["/entry/instrument/detector"]) + wavelength = 1 + + detector = dxtbx.nexus.get_dxtbx_detector(det, wavelength) + assert len(detector) == 2 + for panel in detector: + assert panel.get_image_size() == (200, 100) + + nxdata = nxmx.NXdata(detector_with_multiple_modules["/entry/data"]) + for i in range(3): + raw_data = dxtbx.nexus.get_raw_data(nxdata, det, i) + assert len(raw_data) == 2 + for module_data in raw_data: + assert module_data.all() == (100, 200) + assert module_data.all_eq(i) + + mask = dxtbx.nexus.get_static_mask(det) + assert len(mask) == 2 + for module_mask in mask: + assert isinstance(module_mask, flex.bool) + assert module_mask.all() == (100, 200) + + @pytest.fixture def detector_with_two_theta(): with h5py.File(" ", "w", **pytest.h5_in_memory) as f: