PyAPI: add optional filter argument to KDTree.find

This commit is contained in:
Campbell Barton
2015-12-06 21:33:39 +11:00
parent 54b95c30ae
commit 9964eed9ac
2 changed files with 122 additions and 11 deletions

View File

@@ -240,17 +240,23 @@ class QuaternionTesting(unittest.TestCase):
class KDTreeTesting(unittest.TestCase):
@staticmethod
def kdtree_create_grid_3d(tot):
k = kdtree.KDTree(tot * tot * tot)
def kdtree_create_grid_3d_data(tot):
index = 0
mul = 1.0 / (tot - 1)
for x in range(tot):
for y in range(tot):
for z in range(tot):
k.insert((x * mul, y * mul, z * mul), index)
yield (x * mul, y * mul, z * mul), index
index += 1
@staticmethod
def kdtree_create_grid_3d(tot, *, filter_fn=None):
k = kdtree.KDTree(tot * tot * tot)
for co, index in KDTreeTesting.kdtree_create_grid_3d_data(tot):
if (filter_fn is not None) and (not filter_fn(co, index)):
continue
k.insert(co, index)
k.balance()
return k
@@ -327,6 +333,49 @@ class KDTreeTesting(unittest.TestCase):
ret = k.find_n((1.0,) * 3, tot)
self.assertEqual(len(ret), tot)
def test_kdtree_grid_filter_simple(self):
size = 10
k = self.kdtree_create_grid_3d(size)
# filter exact index
ret_regular = k.find((1.0,) * 3)
ret_filter = k.find((1.0,) * 3, filter=lambda i: i == ret_regular[1])
self.assertEqual(ret_regular, ret_filter)
ret_filter = k.find((-1.0,) * 3, filter=lambda i: i == ret_regular[1])
self.assertEqual(ret_regular[:2], ret_filter[:2]) # ignore distance
def test_kdtree_grid_filter_pairs(self):
size = 10
k_all = self.kdtree_create_grid_3d(size)
k_odd = self.kdtree_create_grid_3d(size, filter_fn=lambda co, i: (i % 2) == 1)
k_evn = self.kdtree_create_grid_3d(size, filter_fn=lambda co, i: (i % 2) == 0)
samples = 5
mul = 1 / (samples - 1)
for x in range(samples):
for y in range(samples):
for z in range(samples):
co = (x * mul, y * mul, z * mul)
ret_regular = k_odd.find(co)
self.assertEqual(ret_regular[1] % 2, 1)
ret_filter = k_all.find(co, lambda i: (i % 2) == 1)
self.assertEqual(ret_regular, ret_filter)
ret_regular = k_evn.find(co)
self.assertEqual(ret_regular[1] % 2, 0)
ret_filter = k_all.find(co, lambda i: (i % 2) == 0)
self.assertEqual(ret_regular, ret_filter)
# filter out all values (search odd tree for even values and the reverse)
co = (0,) * 3
ret_filter = k_odd.find(co, lambda i: (i % 2) == 0)
self.assertEqual(ret_filter[1], None)
ret_filter = k_evn.find(co, lambda i: (i % 2) == 1)
self.assertEqual(ret_filter[1], None)
def test_kdtree_invalid_size(self):
with self.assertRaises(ValueError):
kdtree.KDTree(-1)
@@ -342,6 +391,21 @@ class KDTreeTesting(unittest.TestCase):
with self.assertRaises(RuntimeError):
k.find(co)
def test_kdtree_invalid_filter(self):
k = kdtree.KDTree(1)
k.insert((0,) * 3, 0)
k.balance()
# not callable
with self.assertRaises(TypeError):
k.find((0,) * 3, filter=None)
# no args
with self.assertRaises(TypeError):
k.find((0,) * 3, filter=lambda: None)
# bad return value
with self.assertRaises(ValueError):
k.find((0,) * 3, filter=lambda i: None)
if __name__ == '__main__':
import sys
sys.argv = [__file__] + (sys.argv[sys.argv.index("--") + 1:] if "--" in sys.argv else [])