测试numpy数组是否仅包含零


92

我们初始化一个具有零的numpy数组,如下所示:

np.zeros((N,N+1))

但是,如何检查给定n * n numpy数组矩阵中的所有元素是否为零。
如果所有值确实为零,则该方法仅需要返回True。

Answers:



161

此处发布的其他答案也可以,但是要使用的最清晰,最有效的功能是numpy.any()

>>> all_zeros = not np.any(a)

要么

>>> all_zeros = not a.any()
  • 这是首选,numpy.all(a==0)因为它使用较少的RAM。(它不需要该a==0术语创建的临时数组。)
  • 而且,它比numpy.count_nonzero(a)因为找到第一个非零元素后可以立即返回而更快。
    • 编辑:正如@Rachel在评论中指出的那样,np.any()不再使用“短路”逻辑,因此您不会看到小型阵列的速度优势。

2
由于在一分钟前,numpy的的anyall做的短路。我相信它们是logical_or.reduce和的糖logical_and.reduce。相互比较和我的短路is_inall_false = np.zeros(10**8) all_true = np.ones(10**8) %timeit np.any(all_false) 91.5 ms ± 1.82 ms per loop %timeit np.any(all_true) 93.7 ms ± 6.16 ms per loop %timeit is_in(1, all_true) 293 ns ± 1.65 ns per loop
Rachel

2
很好,谢谢。看起来曾经是短路的行为,但是在某些时候消失了。这个问题的答案中有一些有趣的讨论。
Stuart Berg '18


9

另一个答案是,如果您知道真实/虚假评估0是数组中唯一的虚假元素,则可以利用它。数组中的所有元素都是虚假的,前提是其中没有任何真实的元素。*

>>> a = np.zeros(10)
>>> not np.any(a)
True

但是,答案声称,any由于短路,它比其他选择要快。截至2018年,Numpy allany 都不短路

如果您经常做这种事情,使用numba以下方法制作自己的短路版本非常容易:

import numba as nb

# short-circuiting replacement for np.any()
@nb.jit(nopython=True)
def sc_any(array):
    for x in array.flat:
        if x:
            return True
    return False

# short-circuiting replacement for np.all()
@nb.jit(nopython=True)
def sc_all(array):
    for x in array.flat:
        if not x:
            return False
    return True

即使没有短路,它们也往往比Numpy的版本更快。count_nonzero是最慢的。

一些输入来检查性能:

import numpy as np

n = 10**8
middle = n//2
all_0 = np.zeros(n, dtype=int)
all_1 = np.ones(n, dtype=int)
mid_0 = np.ones(n, dtype=int)
mid_1 = np.zeros(n, dtype=int)
np.put(mid_0, middle, 0)
np.put(mid_1, middle, 1)
# mid_0 = [1 1 1 ... 1 0 1 ... 1 1 1]
# mid_1 = [0 0 0 ... 0 1 0 ... 0 0 0]

检查:

## count_nonzero
%timeit np.count_nonzero(all_0) 
# 220 ms ± 8.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.count_nonzero(all_1)
# 150 ms ± 4.56 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

### all
# np.all
%timeit np.all(all_1)
%timeit np.all(mid_0)
%timeit np.all(all_0)
# 56.8 ms ± 3.41 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 57.4 ms ± 1.76 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 55.9 ms ± 2.13 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# sc_all
%timeit sc_all(all_1)
%timeit sc_all(mid_0)
%timeit sc_all(all_0)
# 44.4 ms ± 2.49 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 22.7 ms ± 599 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 288 ns ± 6.36 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

### any
# np.any
%timeit np.any(all_0)
%timeit np.any(mid_1)
%timeit np.any(all_1)
# 60.7 ms ± 1.38 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 60 ms ± 287 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 57.7 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# sc_any
%timeit sc_any(all_0)
%timeit sc_any(mid_1)
%timeit sc_any(all_1)
# 41.7 ms ± 1.24 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 22.4 ms ± 1.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 287 ns ± 12.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

*有用allany等效的内容:

np.all(a) == np.logical_not(np.any(np.logical_not(a)))
np.any(a) == np.logical_not(np.all(np.logical_not(a)))
not np.all(a) == np.any(np.logical_not(a))
not np.any(a) == np.all(np.logical_not(a))

-9

如果您正在测试所有零,以避免在另一个numpy函数上发出警告,那么请将该行换成尝试,除非block可以省去在您感兴趣的操作之前对零进行测试的麻烦,即

try: # removes output noise for empty slice 
    mean = np.mean(array)
except:
    mean = 0
By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.