Skip to content

Commit

Permalink
Wrap GMT's standard data type GMT_CUBE for cubes
Browse files Browse the repository at this point in the history
  • Loading branch information
seisman committed Mar 31, 2024
1 parent 8a2a6e4 commit fe976e4
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 4 deletions.
11 changes: 7 additions & 4 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
vectors_to_arrays,
)
from pygmt.clib.loading import load_libgmt
from pygmt.datatypes import _GMT_DATASET, _GMT_GRID
from pygmt.datatypes import _GMT_CUBE, _GMT_DATASET, _GMT_GRID
from pygmt.exceptions import (
GMTCLibError,
GMTCLibNoSessionError,
Expand Down Expand Up @@ -1648,7 +1648,9 @@ def virtualfile_in( # noqa: PLR0912

@contextlib.contextmanager
def virtualfile_out(
self, kind: Literal["dataset", "grid"] = "dataset", fname: str | None = None
self,
kind: Literal["dataset", "grid", "cube"] = "dataset",
fname: str | None = None,
):
r"""
Create a virtual file or an actual file for storing output data.
Expand Down Expand Up @@ -1705,12 +1707,13 @@ def virtualfile_out(
family, geometry = {
"dataset": ("GMT_IS_DATASET", "GMT_IS_PLP"),
"grid": ("GMT_IS_GRID", "GMT_IS_SURFACE"),
"cube": ("GMT_IS_CUBE", "GMT_IS_VOLUME"),
}[kind]
with self.open_virtualfile(family, geometry, "GMT_OUT", None) as vfile:
yield vfile

def read_virtualfile(
self, vfname: str, kind: Literal["dataset", "grid", None] = None
self, vfname: str, kind: Literal["dataset", "grid", "cube", None] = None
):
"""
Read data from a virtual file and optionally cast into a GMT data container.
Expand Down Expand Up @@ -1769,7 +1772,7 @@ def read_virtualfile(
# _GMT_DATASET).
if kind is None: # Return the ctypes void pointer
return pointer
dtype = {"dataset": _GMT_DATASET, "grid": _GMT_GRID}[kind]
dtype = {"dataset": _GMT_DATASET, "grid": _GMT_GRID, "cube": _GMT_CUBE}[kind]
return ctp.cast(pointer, ctp.POINTER(dtype))

def virtualfile_to_dataset(
Expand Down
1 change: 1 addition & 0 deletions pygmt/datatypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
Wrappers for GMT data types.
"""

from pygmt.datatypes.cube import _GMT_CUBE
from pygmt.datatypes.dataset import _GMT_DATASET
from pygmt.datatypes.grid import _GMT_GRID
93 changes: 93 additions & 0 deletions pygmt/datatypes/cube.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""
Wrapper for the GMT_CUBE data type.
"""

import ctypes as ctp
from typing import ClassVar

import numpy as np
import xarray as xr
from pygmt.datatypes.header import (
_GMT_GRID_HEADER,
GMT_GRID_UNIT_LEN80,
GMT_GRID_VARNAME_LEN80,
_parse_nameunits,
gmt_grdfloat,
)


class _GMT_CUBE(ctp.structure): # noqa: N801
"""
GMT cube data structure for 3D data.
"""

_fields_: ClassVar = [
# Pointer to full GMT 2-D header for a layer (common to all layers)
("header", ctp.POINTER(_GMT_GRID_HEADER)),
# Pointer to the gmt_grdfloat 3-D cube - a stack of 2-D padded grids
("data", ctp.POINTER(gmt_grdfloat)),
# Vector of x coordinates common to all layers
("x", ctp.POINTER(ctp.c_double)),
# Vector of y coordinates common to all layers
("y", ctp.POINTER(ctp.c_double)),
# Low-level information for GMT use only
("hidden", ctp.c_void_p),
# GMT_CUBE_IS_STACK if input dataset was a list of 2-D grids rather than a
# single cube
("mode", ctp.c_uint),
# Minimum/max z values (complements header->wesn[4])
("z_range", ctp.c_double * 2),
# z increment (complements inc[2]) (0 if variable z spacing)
("z_inc", ctp.c_double),
# Array of z values (complements x, y)
("z", ctp.POINTER(ctp.c_double)),
# Name of the 3-D variable, if read from file (or empty if just one)
("name", ctp.c_char * GMT_GRID_VARNAME_LEN80),
# Units in 3rd direction (complements x_units, y_units, z_units)
("units", ctp.c_char * GMT_GRID_UNIT_LEN80),
]

def to_dataarray(self):
"""
Convert the GMT_CUBE to an xarray.DataArray.
Returns
-------
xarray.DataArray: The data array representation of the GMT_CUBE.
"""
# The grid header
header = self.header.contents

name = "cube"
attrs = header.get_data_attrs()
dims = header.get_dims()
dim_attrs = header.get_dim_attrs()

# Patch for the 3rd dimension
dims.append("z")
z_attrs = {"actual_range": np.array(self.z_range[:]), "axis": "Z"}
long_name, units = _parse_nameunits(self.units.decode())
if long_name:
z_attrs["long_name"] = long_name
if units:
z_attrs["units"] = units
dim_attrs.append(z_attrs)

# The coordinates, given as a tuple of the form (dims, data, attrs)
coords = [
(dims[0], self.y[: header.n_rows], dim_attrs[0]),
(dims[1], self.x[: header.n_columns], dim_attrs[1]),
# header->n_bands is used for the number of layers for 3-D cubes
(dims[2], self.z[: header.n_bands], dim_attrs[1]),
]

# The data array without paddings
pad = header.pad[:]
data = np.reshape(
self.data[: header.mx * header.my * header.n_bands],
(header.my, header.mx, header.n_bands),
)[pad[2] : header.my - pad[3], pad[0] : header.mx - pad[1], :]

# Create the xarray.DataArray object
grid = xr.DataArray(data, coords=coords, name=name, attrs=attrs)
return grid

0 comments on commit fe976e4

Please sign in to comment.