Skip to content

Commit

Permalink
added option for unique/non-unique returns (#16)
Browse files Browse the repository at this point in the history
  • Loading branch information
kip-hart authored Jun 1, 2021
1 parent b07ee9e commit dcadb9b
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 7 deletions.
18 changes: 12 additions & 6 deletions aabbtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ def does_overlap(self, aabb, method='DFS', closed=False):

return len(_overlap_pairs(self, aabb, method, True, closed)) > 0

def overlap_aabbs(self, aabb, method='DFS', closed=False):
def overlap_aabbs(self, aabb, method='DFS', closed=False, unique=True):
"""Get overlapping AABBs
This function gets each overlapping AABB.
Expand All @@ -491,17 +491,19 @@ def overlap_aabbs(self, aabb, method='DFS', closed=False):
closed (bool): Option to specify closed or open box intersection.
If open, there must be a non-zero amount of overlap. If closed,
boxes can be touching.
unique (bool): Return only unique pairs. Defaults to True.
Returns:
list: AABB objects in AABBTree that overlap with the input.
"""
pairs = _overlap_pairs(self, aabb, method, closed=closed)
pairs = _overlap_pairs(self, aabb, method, closed=closed,
unique=unique)
if len(pairs) == 0:
return []
boxes, _ = zip(*pairs)
return list(boxes)

def overlap_values(self, aabb, method='DFS', closed=False):
def overlap_values(self, aabb, method='DFS', closed=False, unique=True):
"""Get values of overlapping AABBs
This function gets the value field of each overlapping AABB.
Expand All @@ -519,11 +521,13 @@ def overlap_values(self, aabb, method='DFS', closed=False):
closed (bool): Option to specify closed or open box intersection.
If open, there must be a non-zero amount of overlap. If closed,
boxes can be touching.
unique (bool): Return only unique pairs. Defaults to True.
Returns:
list: Value fields of each node that overlaps.
"""
pairs = _overlap_pairs(self, aabb, method, closed=closed)
pairs = _overlap_pairs(self, aabb, method, closed=closed,
unique=unique)
if len(pairs) == 0:
return []
_, values = zip(*pairs)
Expand All @@ -537,7 +541,8 @@ def _merge(lims1, lims2):
return (lower, upper)


def _overlap_pairs(in_tree, aabb, method='DFS', halt=False, closed=False):
def _overlap_pairs(in_tree, aabb, method='DFS', halt=False, closed=False,
unique=True):
"""Get overlapping AABBs and values in (AABB, value) pairs
*New in version 2.6.0*
Expand All @@ -553,6 +558,7 @@ def _overlap_pairs(in_tree, aabb, method='DFS', halt=False, closed=False):
halt (bool): Return the list immediately once a pair has been
added.
closed (bool): Check for closed box intersection. Defaults to False.
unique (bool): Return only unique pairs. Defaults to True.
Returns:
list: (AABB, value) pairs in AABBTree that overlap with the input.
Expand All @@ -571,7 +577,7 @@ def _overlap_pairs(in_tree, aabb, method='DFS', halt=False, closed=False):
e_str = "method should be 'DFS' or 'BFS', not " + str(method)
raise ValueError(e_str)

if len(pairs) < 2:
if len(pairs) < 2 or not unique:
return pairs
return _unique_pairs(pairs)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def read(fname):

setup(
name='aabbtree',
version='2.7.0',
version='2.8.0',
license='MIT',
description='Pure Python implementation of d-dimensional AABB tree.',
long_description=read('README.rst'),
Expand Down
16 changes: 16 additions & 0 deletions tests/test_aabbtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,22 @@ def test_depth():
assert standard_tree().depth == 2


def test_unique():
tree = AABBTree()
aabb1 = AABB([(0, 1)])
aabb2 = AABB([(0, 1)])
aabb3 = AABB([(0, 1)])
tree.add(aabb1, 'box 1')
tree.add(aabb2, 'box 2')
vals = tree.overlap_values(aabb3, unique=True)
assert len(vals) == 1

vals = tree.overlap_values(aabb3, unique=False)
assert len(vals) == 2
assert 'box 1' in vals
assert 'box 2' in vals


def standard_aabbs():
aabb1 = AABB([(0, 1), (0, 1)])
aabb2 = AABB([(3, 4), (0, 1)])
Expand Down

0 comments on commit dcadb9b

Please sign in to comment.