From 52ec7ca1265930b894ee3b93a0e932f4c6850405 Mon Sep 17 00:00:00 2001 From: Bobby Jackson Date: Mon, 22 Apr 2024 16:39:55 -0500 Subject: [PATCH 1/6] FIX: Some documentation fixes and fixes in Jax point cost functions. --- doc/source/contributors_guide/index.rst | 1 + examples/README.txt | 8 +++----- pydda/cost_functions/_cost_functions_jax.py | 2 +- pydda/cost_functions/cost_functions.py | 4 ---- pydda/io/__init__.py | 1 + pydda/io/read_grid.py | 1 - pydda/retrieval/nesting.py | 21 +++++++++++++++------ pydda/retrieval/wind_retrieve.py | 12 ++++++------ 8 files changed, 27 insertions(+), 23 deletions(-) diff --git a/doc/source/contributors_guide/index.rst b/doc/source/contributors_guide/index.rst index 91dbe4c..4581e83 100644 --- a/doc/source/contributors_guide/index.rst +++ b/doc/source/contributors_guide/index.rst @@ -48,6 +48,7 @@ Examples of unacceptable behavior by participants include: advances Trolling, insulting/derogatory comments, and personal or political attacks + Public or private harassment Publishing others' private information, such as a physical or electronic diff --git a/examples/README.txt b/examples/README.txt index ecfe837..dc4bc49 100644 --- a/examples/README.txt +++ b/examples/README.txt @@ -1,8 +1,6 @@ PyDDA Example Gallery ==================== -Different examples are given on how to retrieve winds using HRRR and radar data. - -Example grid data files for Hurricane Florence are available at: - -https://drive.google.com/drive/folders/1pcQxWRJV78xuJePTZnlXPPpMe1qut0ie +In this section, we show different examples on: + * How to use HRRR to initalize your wind retrieval + * How to adjust the variational retrieval parameters diff --git a/pydda/cost_functions/_cost_functions_jax.py b/pydda/cost_functions/_cost_functions_jax.py index 5dde4c9..a009501 100644 --- a/pydda/cost_functions/_cost_functions_jax.py +++ b/pydda/cost_functions/_cost_functions_jax.py @@ -390,7 +390,7 @@ def calculate_point_cost(u, v, x, y, z, point_list, Cp=1e-3, roi=500.0): jnp.abs(z - the_point["z"]) < roi, ) J += jnp.sum( - ((u[the_box] - the_point["u"]) ** 2 + (v[the_box] - the_point["v"]) ** 2) + ((u * the_box - the_point["u"]) ** 2 + (v * the_box - the_point["v"]) ** 2) ) return J * Cp diff --git a/pydda/cost_functions/cost_functions.py b/pydda/cost_functions/cost_functions.py index a74230b..b537bbd 100644 --- a/pydda/cost_functions/cost_functions.py +++ b/pydda/cost_functions/cost_functions.py @@ -9,9 +9,6 @@ TENSORFLOW_AVAILABLE = False try: - from jax.config import config - - config.update("jax_enable_x64", True) import jax.numpy as jnp JAX_AVAILABLE = True @@ -858,7 +855,6 @@ def grad_jax(winds, parameters): parameters.point_list, Cp=parameters.Cpoint, roi=parameters.roi, - upper_bc=parameters.upper_bc, ) return grad diff --git a/pydda/io/__init__.py b/pydda/io/__init__.py index fb87d12..1746769 100644 --- a/pydda/io/__init__.py +++ b/pydda/io/__init__.py @@ -12,6 +12,7 @@ read_grid read_from_pyart_grid + read_hpl """ from .read_grid import read_grid, read_from_pyart_grid diff --git a/pydda/io/read_grid.py b/pydda/io/read_grid.py index ed2b37a..d4378c8 100644 --- a/pydda/io/read_grid.py +++ b/pydda/io/read_grid.py @@ -1,5 +1,4 @@ import xarray as xr -import xradar as xd import numpy as np from glob import glob diff --git a/pydda/retrieval/nesting.py b/pydda/retrieval/nesting.py index 23874e3..13a24e0 100644 --- a/pydda/retrieval/nesting.py +++ b/pydda/retrieval/nesting.py @@ -11,14 +11,23 @@ def get_dd_wind_field_nested(grid_tree: DataTree, **kwargs): """ Does a wind retrieval over nested grids. The nested grids are created using PyART's :func:`pyart.map.grid_from_radars` function and then placed into a tree structure using - dictionaries. Each node of the tree has three parameters: - 'input_grids': The list of PyART grids for the given level of the grid - 'kwargs': The list of key word arguments for input to the get_dd_wind_field function for the set of grids. - If this is None, then the default keyword arguments are carried from the keyword arguments of this function. - 'children': The list of trees that are the children of this node. + :func:`dataTree`s. Each node of the tree has three parameters: + .. list-table:: Title + :widths: 25 100 + :header-rows: 1 + + * - Dictionary key + - Description + * - input_grids + - The list of PyART grids for the given level of the grid + * - kwargs + - The list of key word arguments for input to the :py:func:`pydda.retrieval.get_dd_wind_field` function for the set of grids. + * - children + - The list of trees that are the children of this node. The function will output the same tree, with the list of output grids of each level output to the 'output_grids' - member of the tree structure. + member of the tree structure. If *kwargs* is set to None, then the input keyword arguments will be + used throughout the retrieval. """ # Look for radars in current level diff --git a/pydda/retrieval/wind_retrieve.py b/pydda/retrieval/wind_retrieve.py index f88cf10..3a75520 100644 --- a/pydda/retrieval/wind_retrieve.py +++ b/pydda/retrieval/wind_retrieve.py @@ -1326,7 +1326,7 @@ def get_dd_wind_field( Using Tensorflow or Jax expands PyDDA's capabiability to take advantage of GPU-based systems. In addition, these two implementations use automatic differentation to calculate the gradient of the cost function in order to optimize the gradient calculation. - TensorFlow 2.6 and tensorflow-probability are required for the TensorFlow-basedengine. + TensorFlow 2.6 and tensorflow-probability are required for the TensorFlow-based engine. The latest version of Jax is required for the Jax-based engine. points: None or list of dicts Point observations as returned by :func:`pydda.constraints.get_iem_obs`. Set @@ -1413,9 +1413,9 @@ def get_dd_wind_field( The list of fields in the first grid in Grids that contain the custom data interpolated to the Grid's grid specification. Helper functions to create such gridded fields for HRRR and NetCDF WRF data exist - in ::pydda.constraints::. PyDDA will look for fields named U_(model - field name), V_(model field name), and W_(model field name). For - example, if you have U_hrrr, V_hrrr, and W_hrrr, then specify ["hrrr"] + in :py:func:`pydda.constraints`. PyDDA will look for fields named *U_(model + field name)*, *V_(model field name)*, and *W_(model field name)*. For + example, if you have *U_hrrr*, *V_hrrr*, and *W_hrrr*, then specify *["hrrr"]* into model_fields. output_cost_functions: bool Set to True to output the value of each cost function every @@ -1429,9 +1429,9 @@ def get_dd_wind_field( wind_tol: float Stop iterations after maximum change in winds is less than this value. tolerance: float - Tolerance for L2 norm of gradient before stopping. + Tolerance for :math:`L_{2}` norm of gradient before stopping. max_wind_magnitude: float - Constrain the optimization to have :math:`|u|, :math:`|w|`, and :math:`|w| < x` m/s. + Constrain the optimization to have :math:`|u|`, :math:`|w|`, and :math:`|w| < x` m/s. Returns ======= From 72934aaaf9ee7d23183718860135581eaf883147 Mon Sep 17 00:00:00 2001 From: Bobby Jackson Date: Mon, 22 Apr 2024 16:48:18 -0500 Subject: [PATCH 2/6] FIX: Yaml file to add datatree. --- doc/environment_docs.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/environment_docs.yml b/doc/environment_docs.yml index 11b6309..11b5135 100644 --- a/doc/environment_docs.yml +++ b/doc/environment_docs.yml @@ -27,3 +27,4 @@ dependencies: - sphinx-gallery - sphinx-copybutton - sphinx-design + - datatree From 3c979365aa822a7780745e6b89c221de370b5b59 Mon Sep 17 00:00:00 2001 From: Bobby Jackson Date: Mon, 22 Apr 2024 16:56:33 -0500 Subject: [PATCH 3/6] FIX: Datatree in REQUIREMENTS.txt --- REQUIREMENTS.txt | 2 +- doc/environment_docs.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/REQUIREMENTS.txt b/REQUIREMENTS.txt index cb8a06c..5902dc2 100644 --- a/REQUIREMENTS.txt +++ b/REQUIREMENTS.txt @@ -9,4 +9,4 @@ pooch cmweather cdsapi xarray -datatree +xarray-datatree diff --git a/doc/environment_docs.yml b/doc/environment_docs.yml index 11b5135..f0df2aa 100644 --- a/doc/environment_docs.yml +++ b/doc/environment_docs.yml @@ -27,4 +27,4 @@ dependencies: - sphinx-gallery - sphinx-copybutton - sphinx-design - - datatree + - xarray-datatree From 31e4f6652e724ca4bb48e3a281a4022b6d1596e0 Mon Sep 17 00:00:00 2001 From: Bobby Jackson Date: Mon, 22 Apr 2024 18:42:20 -0500 Subject: [PATCH 4/6] ADD: Datatree to test environment --- continuous_integration/environment-actions.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/continuous_integration/environment-actions.yml b/continuous_integration/environment-actions.yml index dd28266..f8e6c5d 100644 --- a/continuous_integration/environment-actions.yml +++ b/continuous_integration/environment-actions.yml @@ -23,3 +23,4 @@ dependencies: - jaxopt - tensorflow>=2.6 - tensorflow-probability + - xarray-datatree From 165f72348bb1b135e78ea45c61df785fd92e02bc Mon Sep 17 00:00:00 2001 From: Bobby Jackson Date: Wed, 24 Apr 2024 13:52:55 -0500 Subject: [PATCH 5/6] FIX: Jax cost function --- pydda/cost_functions/_cost_functions_jax.py | 5 ++--- pydda/tests/test_cost_functions.py | 4 ---- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/pydda/cost_functions/_cost_functions_jax.py b/pydda/cost_functions/_cost_functions_jax.py index a009501..aa54a03 100644 --- a/pydda/cost_functions/_cost_functions_jax.py +++ b/pydda/cost_functions/_cost_functions_jax.py @@ -389,9 +389,8 @@ def calculate_point_cost(u, v, x, y, z, point_list, Cp=1e-3, roi=500.0): ), jnp.abs(z - the_point["z"]) < roi, ) - J += jnp.sum( - ((u * the_box - the_point["u"]) ** 2 + (v * the_box - the_point["v"]) ** 2) - ) + the_box = jnp.where(the_box, 1.0, 0.0) + J += jnp.sum(((u - the_point["u"]) ** 2 + (v - the_point["v"]) ** 2) * the_box) return J * Cp diff --git a/pydda/tests/test_cost_functions.py b/pydda/tests/test_cost_functions.py index 2b96537..ec49098 100644 --- a/pydda/tests/test_cost_functions.py +++ b/pydda/tests/test_cost_functions.py @@ -501,8 +501,6 @@ def test_vert_vorticity_tf(): def test_point_cost(): u = 1 * np.ones((10, 10, 10)) v = 1 * np.ones((10, 10, 10)) - 0 * np.ones((10, 10, 10)) - x = np.linspace(-10, 10, 10) y = np.linspace(-10, 10, 10) z = np.linspace(-10, 10, 10) @@ -556,8 +554,6 @@ def test_point_cost(): def test_point_cost_jax(): u = 1 * np.ones((10, 10, 10)) v = 1 * np.ones((10, 10, 10)) - 0 * np.ones((10, 10, 10)) - x = np.linspace(-10, 10, 10) y = np.linspace(-10, 10, 10) z = np.linspace(-10, 10, 10) From 88c6b3185f02517c273d371d470ab982b6b7a2f6 Mon Sep 17 00:00:00 2001 From: Bobby Jackson Date: Wed, 24 Apr 2024 14:22:21 -0500 Subject: [PATCH 6/6] VER: Drop Python 3.9 support. Currently no support for Py3.12. --- .github/workflows/python-package-conda.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-package-conda.yml b/.github/workflows/python-package-conda.yml index dca9abe..c3757da 100644 --- a/.github/workflows/python-package-conda.yml +++ b/.github/workflows/python-package-conda.yml @@ -17,7 +17,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11"] os: [macOS, ubuntu] inlcude: - os: macos-latest