我需要过滤一个数组以删除低于某个阈值的元素。我当前的代码是这样的:
threshold = 5
a = numpy.array(range(10)) # testing data
b = numpy.array(filter(lambda x: x >= threshold, a))
问题在于,这会使用带有lambda函数(慢速)的过滤器来创建一个临时列表。
由于这是一个非常简单的操作,因此也许有一个numpy函数可以高效地完成此操作,但是我一直找不到它。
我以为实现此目的的另一种方法可能是对数组进行排序,找到阈值的索引,然后从该索引开始返回一个切片,但是即使对于较小的输入这会更快(而且无论如何也不会引起注意) ),随着输入大小的增加,它的渐近渐近效率降低。
有任何想法吗?谢谢!
更新:我也进行了一些测量,当输入为100.000.000条目时,sorting + slicing仍比纯python过滤器快两倍。
In [321]: r = numpy.random.uniform(0, 1, 100000000)
In [322]: %timeit test1(r) # filter
1 loops, best of 3: 21.3 s per loop
In [323]: %timeit test2(r) # sort and slice
1 loops, best of 3: 11.1 s per loop
In [324]: %timeit test3(r) # boolean indexing
1 loops, best of 3: 1.26 s per loop