PyAPI: add optional filter argument to KDTree.find
This commit is contained in:
@@ -240,17 +240,23 @@ class QuaternionTesting(unittest.TestCase):
|
||||
|
||||
|
||||
class KDTreeTesting(unittest.TestCase):
|
||||
|
||||
@staticmethod
|
||||
def kdtree_create_grid_3d(tot):
|
||||
k = kdtree.KDTree(tot * tot * tot)
|
||||
def kdtree_create_grid_3d_data(tot):
|
||||
index = 0
|
||||
mul = 1.0 / (tot - 1)
|
||||
for x in range(tot):
|
||||
for y in range(tot):
|
||||
for z in range(tot):
|
||||
k.insert((x * mul, y * mul, z * mul), index)
|
||||
yield (x * mul, y * mul, z * mul), index
|
||||
index += 1
|
||||
|
||||
@staticmethod
|
||||
def kdtree_create_grid_3d(tot, *, filter_fn=None):
|
||||
k = kdtree.KDTree(tot * tot * tot)
|
||||
for co, index in KDTreeTesting.kdtree_create_grid_3d_data(tot):
|
||||
if (filter_fn is not None) and (not filter_fn(co, index)):
|
||||
continue
|
||||
k.insert(co, index)
|
||||
k.balance()
|
||||
return k
|
||||
|
||||
@@ -327,6 +333,49 @@ class KDTreeTesting(unittest.TestCase):
|
||||
ret = k.find_n((1.0,) * 3, tot)
|
||||
self.assertEqual(len(ret), tot)
|
||||
|
||||
def test_kdtree_grid_filter_simple(self):
|
||||
size = 10
|
||||
k = self.kdtree_create_grid_3d(size)
|
||||
|
||||
# filter exact index
|
||||
ret_regular = k.find((1.0,) * 3)
|
||||
ret_filter = k.find((1.0,) * 3, filter=lambda i: i == ret_regular[1])
|
||||
self.assertEqual(ret_regular, ret_filter)
|
||||
ret_filter = k.find((-1.0,) * 3, filter=lambda i: i == ret_regular[1])
|
||||
self.assertEqual(ret_regular[:2], ret_filter[:2]) # ignore distance
|
||||
|
||||
def test_kdtree_grid_filter_pairs(self):
|
||||
size = 10
|
||||
k_all = self.kdtree_create_grid_3d(size)
|
||||
k_odd = self.kdtree_create_grid_3d(size, filter_fn=lambda co, i: (i % 2) == 1)
|
||||
k_evn = self.kdtree_create_grid_3d(size, filter_fn=lambda co, i: (i % 2) == 0)
|
||||
|
||||
samples = 5
|
||||
mul = 1 / (samples - 1)
|
||||
for x in range(samples):
|
||||
for y in range(samples):
|
||||
for z in range(samples):
|
||||
co = (x * mul, y * mul, z * mul)
|
||||
|
||||
ret_regular = k_odd.find(co)
|
||||
self.assertEqual(ret_regular[1] % 2, 1)
|
||||
ret_filter = k_all.find(co, lambda i: (i % 2) == 1)
|
||||
self.assertEqual(ret_regular, ret_filter)
|
||||
|
||||
ret_regular = k_evn.find(co)
|
||||
self.assertEqual(ret_regular[1] % 2, 0)
|
||||
ret_filter = k_all.find(co, lambda i: (i % 2) == 0)
|
||||
self.assertEqual(ret_regular, ret_filter)
|
||||
|
||||
|
||||
# filter out all values (search odd tree for even values and the reverse)
|
||||
co = (0,) * 3
|
||||
ret_filter = k_odd.find(co, lambda i: (i % 2) == 0)
|
||||
self.assertEqual(ret_filter[1], None)
|
||||
|
||||
ret_filter = k_evn.find(co, lambda i: (i % 2) == 1)
|
||||
self.assertEqual(ret_filter[1], None)
|
||||
|
||||
def test_kdtree_invalid_size(self):
|
||||
with self.assertRaises(ValueError):
|
||||
kdtree.KDTree(-1)
|
||||
@@ -342,6 +391,21 @@ class KDTreeTesting(unittest.TestCase):
|
||||
with self.assertRaises(RuntimeError):
|
||||
k.find(co)
|
||||
|
||||
def test_kdtree_invalid_filter(self):
|
||||
k = kdtree.KDTree(1)
|
||||
k.insert((0,) * 3, 0)
|
||||
k.balance()
|
||||
# not callable
|
||||
with self.assertRaises(TypeError):
|
||||
k.find((0,) * 3, filter=None)
|
||||
# no args
|
||||
with self.assertRaises(TypeError):
|
||||
k.find((0,) * 3, filter=lambda: None)
|
||||
# bad return value
|
||||
with self.assertRaises(ValueError):
|
||||
k.find((0,) * 3, filter=lambda i: None)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import sys
|
||||
sys.argv = [__file__] + (sys.argv[sys.argv.index("--") + 1:] if "--" in sys.argv else [])
|
||||
|
Reference in New Issue
Block a user