Skip to content

Commit

Permalink
Merge pull request #442 from xylar/switch_from_pyflann_to_ckdtree
Browse files Browse the repository at this point in the history
 Switch from `pyflann` to scipy `KDTree`
  • Loading branch information
xylar authored Sep 28, 2021
2 parents ff01aa7 + 9b6d6de commit 9484daa
Show file tree
Hide file tree
Showing 13 changed files with 77 additions and 77 deletions.
4 changes: 3 additions & 1 deletion conda_package/ci/linux_python3.7.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ cxx_compiler_version:
hdf5:
- 1.10.6
libnetcdf:
- 4.8.0
- 4.8.1
pin_run_as_build:
python:
min_pin: x.x
max_pin: x.x
python:
- 3.7.* *_cpython
target_platform:
- linux-64
4 changes: 3 additions & 1 deletion conda_package/ci/linux_python3.8.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ cxx_compiler_version:
hdf5:
- 1.10.6
libnetcdf:
- 4.8.0
- 4.8.1
pin_run_as_build:
python:
min_pin: x.x
max_pin: x.x
python:
- 3.8.* *_cpython
target_platform:
- linux-64
4 changes: 3 additions & 1 deletion conda_package/ci/linux_python3.9.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ cxx_compiler_version:
hdf5:
- 1.10.6
libnetcdf:
- 4.8.0
- 4.8.1
pin_run_as_build:
python:
min_pin: x.x
max_pin: x.x
python:
- 3.9.* *_cpython
target_platform:
- linux-64
4 changes: 3 additions & 1 deletion conda_package/ci/osx_python3.7.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ cxx_compiler_version:
hdf5:
- 1.10.6
libnetcdf:
- 4.8.0
- 4.8.1
pin_run_as_build:
python:
min_pin: x.x
max_pin: x.x
python:
- 3.7.* *_cpython
target_platform:
- osx-64
4 changes: 3 additions & 1 deletion conda_package/ci/osx_python3.8.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ cxx_compiler_version:
hdf5:
- 1.10.6
libnetcdf:
- 4.8.0
- 4.8.1
pin_run_as_build:
python:
min_pin: x.x
max_pin: x.x
python:
- 3.8.* *_cpython
target_platform:
- osx-64
4 changes: 3 additions & 1 deletion conda_package/ci/osx_python3.9.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ cxx_compiler_version:
hdf5:
- 1.10.6
libnetcdf:
- 4.8.0
- 4.8.1
pin_run_as_build:
python:
min_pin: x.x
max_pin: x.x
python:
- 3.9.* *_cpython
target_platform:
- osx-64
1 change: 0 additions & 1 deletion conda_package/dev-spec.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ numpy
progressbar2
pyamg
pyevtk
pyflann
pyproj
python-igraph
scikit-image
Expand Down
11 changes: 5 additions & 6 deletions conda_package/docs/ocean/coastal_tools.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ a starting point for ``params``:
"plot_box": North_America,
# Options
"nn_search": "flann",
"nn_search": "kdtree",
"plot_option": True
}
Expand Down Expand Up @@ -116,8 +116,7 @@ Next, the distance to the coastal contours is computed using
:py:func:`mpas_tools.ocean.coastal_tools.distance_to_coast()` with the
following values from ``params``:

* ``'nn_search'`` - Whether to use the ``'flann'`` or ``'kdtree'`` algorithm,
with the ``'flann'`` strongly recommended.
* ``'nn_search'`` - currently, only the ``'kdtree'`` algorithm is supported
* ``'smooth_coastline'`` - The number of neighboring cells along the coastline
over which to average locations to smooth the coastline
* ``'plot_option'`` - Whether to plot the distance function.
Expand Down Expand Up @@ -254,9 +253,9 @@ A key ingredient in defining resolution in coastal meshes is a field containing
the distance from each location in the field to the nearest point on the
coastline. This distance field ``D`` is computed with
:py:func:`mpas_tools.ocean.coastal_tools.distance_to_coast()`
The user can optionally control the search algorithm used via
``params['nn_search']`` (though ``'flann'``, the default, is highly
recommended). They can also decide to smooth the coastline as long as there is
The user could previously control the search algorithm used via
``params['nn_search']`` but ``'kdtree'`` is now the only option.
They can also decide to smooth the coastline as long as there is
a single coastline contour---with multiple contours, the current algorithm will
average the end of one contour with the start of the next---by specifying an
integer number of neighbors as ``params['smooth_coastline']``. The default is
Expand Down
47 changes: 21 additions & 26 deletions conda_package/mpas_tools/mesh/creation/signed_distance.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy as np
import pyflann
from scipy import spatial
from scipy.spatial import KDTree
import timeit
import shapely.geometry
import shapely.ops
Expand All @@ -13,7 +12,7 @@


def signed_distance_from_geojson(fc, lon_grd, lat_grd, earth_radius,
max_length=None):
max_length=None, workers=-1):
"""
Get the distance for each point on a lon/lat grid from the closest point
on the boundary of the geojson regions.
Expand All @@ -36,14 +35,19 @@ def signed_distance_from_geojson(fc, lon_grd, lat_grd, earth_radius,
The maximum distance (in degrees) between points on the boundary of the
geojson region. If the boundary is too coarse, it will be subdivided.
workers : int, optional
The number of threads used for finding nearest neighbors. The default
is all available threads (``workers=-1``)
Returns
-------
signed_distance : numpy.ndarray
A 2D field of distances (negative inside the region, positive outside)
to the shape boundary
"""
distance = distance_from_geojson(fc, lon_grd, lat_grd, earth_radius,
nn_search='flann', max_length=max_length)
nn_search='kdtree', max_length=max_length,
workers=workers)

mask = mask_from_geojson(fc, lon_grd, lat_grd)

Expand Down Expand Up @@ -102,7 +106,8 @@ def mask_from_geojson(fc, lon_grd, lat_grd):


def distance_from_geojson(fc, lon_grd, lat_grd, earth_radius,
nn_search='flann', max_length=None):
nn_search='kdtree', max_length=None,
workers=-1):
# {{{
"""
Get the distance for each point on a lon/lat grid from the closest point
Expand All @@ -122,18 +127,26 @@ def distance_from_geojson(fc, lon_grd, lat_grd, earth_radius,
earth_radius : float
Earth radius in meters
nn_search: {'kdtree', 'flann'}, optional
nn_search: {'kdtree'}, optional
The method used to find the nearest point on the shape boundary
max_length : float, optional
The maximum distance (in degrees) between points on the boundary of the
geojson region. If the boundary is too coarse, it will be subdivided.
workers : int, optional
The number of threads used for finding nearest neighbors. The default
is all available threads (``workers=-1``)
Returns
-------
distance : numpy.ndarray
A 2D field of distances to the shape boundary
"""

if nn_search != 'kdtree':
raise ValueError(f'nn_search method {nn_search} not available.')

print("Distance from geojson")
print("---------------------")

Expand Down Expand Up @@ -167,20 +180,7 @@ def distance_from_geojson(fc, lon_grd, lat_grd, earth_radius,
boundary_xyz = np.zeros((npoints, 3))
boundary_xyz[:, 0], boundary_xyz[:, 1], boundary_xyz[:, 2] = \
lonlat2xyz(boundary_lon, boundary_lat, earth_radius)
flann = None
tree = None
if nn_search == "kdtree":
tree = spatial.KDTree(boundary_xyz)
elif nn_search == "flann":
flann = pyflann.FLANN()
flann.build_index(
boundary_xyz,
algorithm='kdtree',
target_precision=1.0,
random_seed=0)
else:
raise ValueError('Bad nn_search: expected kdtree or flann, got '
'{}'.format(nn_search))
tree = KDTree(boundary_xyz)

# Convert background grid coordinates to x,y,z and put in a nx_grd x 3
# array for kd-tree query
Expand All @@ -191,12 +191,7 @@ def distance_from_geojson(fc, lon_grd, lat_grd, earth_radius,
# Find distances of background grid coordinates to the coast
print(" Finding distance")
start = timeit.default_timer()
distance = None
if nn_search == "kdtree":
distance, _ = tree.query(pts)
elif nn_search == "flann":
_, distance = flann.nn_index(pts, checks=2000, random_seed=0)
distance = np.sqrt(distance)
distance, _ = tree.query(pts, workers=workers)
end = timeit.default_timer()
print(" Done")
print(" {0:.0f} seconds".format(end-start))
Expand Down
20 changes: 12 additions & 8 deletions conda_package/mpas_tools/mesh/mask.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import xarray as xr
import numpy
import pyflann
from scipy.spatial import KDTree
import shapely.geometry
import shapely.ops
from shapely.geometry import box, Polygon, MultiPolygon, GeometryCollection
Expand Down Expand Up @@ -387,7 +387,7 @@ def entry_point_compute_mpas_transect_masks():
engine=args.engine)


def compute_mpas_flood_fill_mask(dsMesh, fcSeed, logger=None):
def compute_mpas_flood_fill_mask(dsMesh, fcSeed, logger=None, workers=-1):
"""
Flood fill from the given set of seed points to create a contiguous mask.
The flood fill operates using cellsOnCell, starting from the cells
Expand All @@ -404,6 +404,10 @@ def compute_mpas_flood_fill_mask(dsMesh, fcSeed, logger=None):
logger : logging.Logger, optional
A logger for the output if not stdout
workers : int, optional
The number of threads used for finding nearest neighbors. The default
is all available threads (``workers=-1``)
Returns
-------
dsMask : xarray.Dataset
Expand All @@ -422,7 +426,7 @@ def compute_mpas_flood_fill_mask(dsMesh, fcSeed, logger=None):
if logger is not None:
logger.info(' Computing flood fill mask on cells:')

mask = _compute_seed_mask(fcSeed, lon, lat)
mask = _compute_seed_mask(fcSeed, lon, lat, workers)

cellsOnCell = dsMesh.cellsOnCell.values - 1

Expand Down Expand Up @@ -1158,23 +1162,23 @@ def _copy_dateline_lon_lat_vertices(lonVertex, latVertex, lonCenter):
return lonVertex, latVertex, duplicatePolygons


def _compute_seed_mask(fcSeed, lon, lat):
def _compute_seed_mask(fcSeed, lon, lat, workers):
"""
Find the cell centers (points) closest to the given seed points and set
the resulting mask to 1 there
"""
points = numpy.vstack((lon, lat)).T
flann = pyflann.FLANN()
flann.build_index(points, algorithm='kmeans', target_precision=1.0,
random_seed=0)

tree = KDTree(points)

mask = numpy.zeros(len(lon), dtype=int)

points = numpy.zeros((len(fcSeed.features), 2))
for index, feature in enumerate(fcSeed.features):
points[index, :] = feature['geometry']['coordinates']

indices, distances = flann.nn_index(points, checks=2000, random_seed=0)
_, indices = tree.query(points, workers=workers)

for index in indices:
mask[index] = 1

Expand Down
Loading

0 comments on commit 9484daa

Please sign in to comment.