Skip to content

Commit

Permalink
Merge pull request #590 from cbegeman/add-bsf-function
Browse files Browse the repository at this point in the history
Port barotropic streamfunction from MPAS-Analysis
  • Loading branch information
xylar authored Nov 13, 2024
2 parents 3e1cab5 + 3b69401 commit 2dbd9de
Show file tree
Hide file tree
Showing 6 changed files with 197 additions and 15 deletions.
2 changes: 2 additions & 0 deletions conda_package/docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ Ocean Tools
depth.compute_depth
depth.compute_zmid

compute_barotropic_streamfunction

.. currentmodule:: mpas_tools.ocean.inject_bathymetry

.. autosummary::
Expand Down
1 change: 1 addition & 0 deletions conda_package/docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ analyzing simulations, and in other MPAS-related workflows.
ocean/coastline_alteration
ocean/moc
ocean/depth
ocean/streamfunction
ocean/visualization

.. toctree::
Expand Down
14 changes: 14 additions & 0 deletions conda_package/docs/ocean/streamfunction.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
.. _ocean_streamfunction:

Computing streamfunctions
=========================

Computing the barotropic streamfunction
---------------------------------------

The function :py:func:`mpas_tools.ocean.compute_barotropic_streamfunction()`
computes the barotproic streamfunction at vertices on the MPAS-Ocean grid.
The function takes a dataset containing an MPAS-Ocean mesh and another with
``normalVelocity`` and ``layerThickness`` variables (possibly with a
``timeMonthly_avg_`` prefix). The streamfunction is computed only over the
range of (positive-down) depths provided and at the given time index.
21 changes: 15 additions & 6 deletions conda_package/mpas_tools/ocean/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
from mpas_tools.ocean.build_mesh import build_spherical_mesh, \
build_planar_mesh
from mpas_tools.ocean.build_mesh import (
build_spherical_mesh,
build_planar_mesh,
)
from mpas_tools.ocean.barotropic_streamfunction import (
compute_barotropic_streamfunction,
)
from mpas_tools.ocean.inject_bathymetry import inject_bathymetry
from mpas_tools.ocean.inject_meshDensity import inject_meshDensity_from_file, \
inject_spherical_meshDensity, inject_planar_meshDensity
from mpas_tools.ocean.inject_preserve_floodplain import \
inject_preserve_floodplain
from mpas_tools.ocean.inject_meshDensity import (
inject_meshDensity_from_file,
inject_spherical_meshDensity,
inject_planar_meshDensity,
)
from mpas_tools.ocean.inject_preserve_floodplain import (
inject_preserve_floodplain,
)
158 changes: 158 additions & 0 deletions conda_package/mpas_tools/ocean/barotropic_streamfunction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import xarray as xr
import numpy as np
import scipy.sparse
import scipy.sparse.linalg

import logging
import sys
from mpas_tools.ocean.depth import compute_zmid


def compute_barotropic_streamfunction(ds_mesh, ds, logger=None,
min_depth=-5., max_depth=1.e4,
prefix='timeMonthly_avg_',
time_index=0):
"""
Compute barotropic streamfunction. Returns BSF in Sv on vertices.
Parameters
----------
ds_mesh : ``xarray.Dataset``
A dataset containing MPAS mesh variables
ds : ``xarray.Dataset``
A dataset containing MPAS output variables ``normalVelocity`` and
``layerThickness`` (possibly with a ``prefix``)
logger : ``logging.Logger``, optional
A logger for the output if not stdout
min_depth : float, optional
The minimum depth (positive down) to compute transport over
max_depth : float, optional
The maximum depth (positive down) to compute transport over
prefix : str, optional
The prefix on the ``normalVelocity`` and ``layerThickness`` variables
time_index : int, optional
The time at which to index ``ds`` (if it has ``Time`` as a dimension)
"""

useStdout = logger is None
if useStdout:
logger = logging.getLogger()
logger.addHandler(logging.StreamHandler(sys.stdout))
logger.setLevel(logging.INFO)

inner_edges, transport = _compute_transport(
ds_mesh, ds, min_depth=min_depth, max_depth=max_depth, prefix=prefix,
time_index=time_index)
logger.info('transport computed.')

nvertices = ds_mesh.sizes['nVertices']

cells_on_vertex = ds_mesh.cellsOnVertex - 1
vertices_on_edge = ds_mesh.verticesOnEdge - 1
is_boundary_cov = cells_on_vertex == -1
boundary_vertices = np.logical_or(is_boundary_cov.isel(vertexDegree=0),
is_boundary_cov.isel(vertexDegree=1))
boundary_vertices = np.logical_or(boundary_vertices,
is_boundary_cov.isel(vertexDegree=2))

# convert from boolean mask to indices
boundary_vertices = np.flatnonzero(boundary_vertices.values)

n_boundary_vertices = len(boundary_vertices)
n_inner_edges = len(inner_edges)

indices = np.zeros((2, 2 * n_inner_edges + n_boundary_vertices), dtype=int)
data = np.zeros(2 * n_inner_edges + n_boundary_vertices, dtype=float)

# The difference between the streamfunction at vertices on an inner
# edge should be equal to the transport
v0 = vertices_on_edge.isel(nEdges=inner_edges, TWO=0).values
v1 = vertices_on_edge.isel(nEdges=inner_edges, TWO=1).values

ind = np.arange(n_inner_edges)
indices[0, 2 * ind] = ind
indices[1, 2 * ind] = v1
data[2 * ind] = 1.

indices[0, 2 * ind + 1] = ind
indices[1, 2 * ind + 1] = v0
data[2 * ind + 1] = -1.

# the streamfunction should be zero at all boundary vertices
ind = np.arange(n_boundary_vertices)
indices[0, 2 * n_inner_edges + ind] = n_inner_edges + ind
indices[1, 2 * n_inner_edges + ind] = boundary_vertices
data[2 * n_inner_edges + ind] = 1.

rhs = np.zeros(n_inner_edges + n_boundary_vertices, dtype=float)

# convert to Sv
ind = np.arange(n_inner_edges)
rhs[ind] = 1e-6 * transport

ind = np.arange(n_boundary_vertices)
rhs[n_inner_edges + ind] = 0.

matrix = scipy.sparse.csr_matrix(
(data, indices),
shape=(n_inner_edges + n_boundary_vertices, nvertices))

solution = scipy.sparse.linalg.lsqr(matrix, rhs)
bsf_vertex = xr.DataArray(-solution[0],
dims=('nVertices',))

return bsf_vertex

def _compute_transport(ds_mesh, ds, min_depth, max_depth, prefix,
time_index):

cells_on_edge = ds_mesh.cellsOnEdge - 1
inner_edges = np.logical_and(cells_on_edge.isel(TWO=0) >= 0,
cells_on_edge.isel(TWO=1) >= 0)

if 'Time' in ds.dims:
ds = ds.isel(Time=time_index)

# convert from boolean mask to indices
inner_edges = np.flatnonzero(inner_edges.values)

cell0 = cells_on_edge.isel(nEdges=inner_edges, TWO=0)
cell1 = cells_on_edge.isel(nEdges=inner_edges, TWO=1)

normal_velocity = \
ds[f'{prefix}normalVelocity'].isel(nEdges=inner_edges)
layer_thickness = ds[f'{prefix}layerThickness']
layer_thickness_edge = 0.5 * (layer_thickness.isel(nCells=cell0) +
layer_thickness.isel(nCells=cell1))

n_vert_levels = ds.sizes['nVertLevels']

vert_index = xr.DataArray.from_dict(
{'dims': ('nVertLevels',), 'data': np.arange(n_vert_levels)})
mask_bottom = (vert_index < ds_mesh.maxLevelCell).T
mask_bottom_edge = 0.5 * (mask_bottom.isel(nCells=cell0) +
mask_bottom.isel(nCells=cell1))

if 'zMid' not in ds.keys():
z_mid = compute_zmid(ds_mesh.bottomDepth, ds_mesh.maxLevelCell,
ds_mesh.layerThickness)
else:
z_mid = ds.zMid
z_mid_edge = 0.5 * (z_mid.isel(nCells=cell0) +
z_mid.isel(nCells=cell1))

mask = np.logical_and(np.logical_and(z_mid_edge >= -max_depth,
z_mid_edge <= -min_depth),
mask_bottom_edge)
normal_velocity = normal_velocity.where(mask)
layer_thickness_edge = layer_thickness_edge.where(mask)
transport = ds_mesh.dvEdge[inner_edges] * \
(layer_thickness_edge * normal_velocity).sum(dim='nVertLevels')

return inner_edges, transport
16 changes: 7 additions & 9 deletions conda_package/mpas_tools/ocean/depth.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def compute_depth(refBottomDepth):
depth_bnds[0, 0] = 0.
depth_bnds[1:, 0] = refBottomDepth[0:-1]
depth_bnds[:, 1] = refBottomDepth
depth = 0.5*(depth_bnds[:, 0] + depth_bnds[:, 1])
depth = 0.5 * (depth_bnds[:, 0] + depth_bnds[:, 1])

return depth, depth_bnds

Expand Down Expand Up @@ -82,11 +82,11 @@ def compute_zmid(bottomDepth, maxLevelCell, layerThickness,

thicknessSum = layerThickness.sum(dim=depth_dim)
thicknessCumSum = layerThickness.cumsum(dim=depth_dim)
zSurface = -bottomDepth+thicknessSum
zSurface = -bottomDepth + thicknessSum

zLayerBot = zSurface - thicknessCumSum

zMid = zLayerBot + 0.5*layerThickness
zMid = zLayerBot + 0.5 * layerThickness

zMid = zMid.where(vertIndex < maxLevelCell)
if 'Time' in zMid.dims:
Expand Down Expand Up @@ -150,8 +150,7 @@ def add_depth(inFileName, outFileName, coordFileName=None):
history = '{}: {}'.format(time, ' '.join(sys.argv))

if 'history' in ds.attrs:
ds.attrs['history'] = '{}\n{}'.format(history,
ds.attrs['history'])
ds.attrs['history'] = f'{history}\n{ds.attrs["history"]}'
else:
ds.attrs['history'] = history

Expand Down Expand Up @@ -229,8 +228,7 @@ def add_zmid(inFileName, outFileName, coordFileName=None):
history = '{}: {}'.format(time, ' '.join(sys.argv))

if 'history' in ds.attrs:
ds.attrs['history'] = '{}\n{}'.format(history,
ds.attrs['history'])
ds.attrs['history'] = f'{history}\n{ds.attrs["history"]}'
else:
ds.attrs['history'] = history

Expand Down Expand Up @@ -288,8 +286,8 @@ def write_time_varying_zmid(inFileName, outFileName, coordFileName=None,

dsIn = xarray.open_dataset(inFileName)
dsIn = dsIn.rename({'nVertLevels': 'depth'})
inVarName = '{}layerThickness'.format(prefix)
outVarName = '{}zMid'.format(prefix)
inVarName = f'{prefix}layerThickness'
outVarName = f'{prefix}zMid'
layerThickness = dsIn[inVarName]

zMid = compute_zmid(dsCoord.bottomDepth, dsCoord.maxLevelCell,
Expand Down

0 comments on commit 2dbd9de

Please sign in to comment.