numpy dot()和Python 3.5+矩阵乘法@之间的区别


119

我最近使用Python 3.5,注意到新的矩阵乘法运算符(@)有时与numpy点运算符的行为有所不同。例如,对于3d阵列:

import numpy as np

a = np.random.rand(8,13,13)
b = np.random.rand(8,13,13)
c = a @ b  # Python 3.5+
d = np.dot(a, b)

@运算符返回形状的阵列:

c.shape
(8, 13, 13)

np.dot()函数返回时:

d.shape
(8, 13, 8, 13)

如何用numpy点重现相同的结果?还有其他重大区别吗?


5
您无法从点中得到该结果。我认为人们普遍认为dot对高维输入的处理是错误的设计决定。
user2357112支持Monica 2015年

他们为什么不在matmul几年前实现功能? @作为infix运算符,它是新的,但是没有它,该函数也可以正常工作。
hpaulj 2015年

Answers:


140

@运营商称阵列的__matmul__方法,而不是dot。此方法在API中也作为函数存在np.matmul

>>> a = np.random.rand(8,13,13)
>>> b = np.random.rand(8,13,13)
>>> np.matmul(a, b).shape
(8, 13, 13)

从文档中:

matmul区别在于dot两个重要方面。

  • 标量不能相乘。
  • 将矩阵堆栈一起广播,就好像矩阵是元素一样。

最后一点很清楚,当传递3D(或更高维)数组时,dotmatmul方法的行为会有所不同。从文档中引用更多内容:

对于matmul

如果任何一个参数为ND,N> 2,则将其视为驻留在最后两个索引中的一组矩阵,并进行相应广播。

对于np.dot

对于2-D数组,它等效于矩阵乘法,对于1-D数组,其等效于向量的内积(无复共轭)。对于N维,它是a的最后一个轴和b的倒数第二个轴的和积


13
这里的困惑可能是由于发行说明所引起的,该发行说明将示例代码中的“ @”符号直接等同于numpy的dot()函数。
亚历克斯K

12

@ajcr的答案说明了dotand matmul(由@符号调用)之间的区别。通过看一个简单的例子,可以清楚地看到两者在“矩阵堆栈”或张量上进行操作时的行为有何不同。

为了弄清差异,采用4x4数组,然后将dot乘积和matmul乘积返回3x4x2的“矩阵堆栈”或张量。

import numpy as np
fourbyfour = np.array([
                       [1,2,3,4],
                       [3,2,1,4],
                       [5,4,6,7],
                       [11,12,13,14]
                      ])


threebyfourbytwo = np.array([
                             [[2,3],[11,9],[32,21],[28,17]],
                             [[2,3],[1,9],[3,21],[28,7]],
                             [[2,3],[1,9],[3,21],[28,7]],
                            ])

print('4x4*3x4x2 dot:\n {}\n'.format(np.dot(fourbyfour,twobyfourbythree)))
print('4x4*3x4x2 matmul:\n {}\n'.format(np.matmul(fourbyfour,twobyfourbythree)))

每个操作的结果如下所示。注意点积如何

... a的最后一个轴与b的倒数第二个和的乘积

以及如何通过一起广播矩阵来形成矩阵乘积。

4x4*3x4x2 dot:
 [[[232 152]
  [125 112]
  [125 112]]

 [[172 116]
  [123  76]
  [123  76]]

 [[442 296]
  [228 226]
  [228 226]]

 [[962 652]
  [465 512]
  [465 512]]]

4x4*3x4x2 matmul:
 [[[232 152]
  [172 116]
  [442 296]
  [962 652]]

 [[125 112]
  [123  76]
  [228 226]
  [465 512]]

 [[125 112]
  [123  76]
  [228 226]
  [465 512]]]

2
dot(a,b)[i,j,k,m] = sum(a [i,j ,:] * b [k,:,m])-------像文档中所说:这是一个a的最后一个轴和b的倒数第二个轴上的总积:
Ronak Agrawal,

不错的收获,它的3x4x2。生成矩阵的另一种方法是a = np.arange(24).reshape(3, 4, 2)创建尺寸为3x4x2的数组。
内森

8

仅供参考,@其numpy的等价物dot,并matmul都大致一样快。(用我的一个项目perfplot创建的图。)

在此处输入图片说明

复制剧情的代码:

import perfplot
import numpy


def setup(n):
    A = numpy.random.rand(n, n)
    x = numpy.random.rand(n)
    return A, x


def at(data):
    A, x = data
    return A @ x


def numpy_dot(data):
    A, x = data
    return numpy.dot(A, x)


def numpy_matmul(data):
    A, x = data
    return numpy.matmul(A, x)


perfplot.show(
    setup=setup,
    kernels=[at, numpy_dot, numpy_matmul],
    n_range=[2 ** k for k in range(12)],
    logx=True,
    logy=True,
)

7

在数学上,我认为numpy中的更有意义

(a,b)_ {i,j,k,a,b,c} =式

因为当a和b是向量时它给出点积,或者当a和b是矩阵时给出矩阵乘积


对于numpy中的matmul操作,它由结果的一部分组成,可以定义为

> matmul(a,b)_ {i,j,k,c} =式

因此,您可以看到matmul(a,b)返回的数组形状较小,从而减少了内存消耗,并在应用程序中更有意义。特别是结合广播,您可以获得

matmul(a,b)_ {i,j,k,l} =式

例如。


从以上两个定义中,您可以看到使用这两个操作的要求。假设a.shape =(s1,s2,s3,s4)b.shape =(t1,t2,t3,t4)

  • 要使用点(a,b),您需要

    1. t3 = s4 ;
  • 要使用matmul(a,b),您需要

    1. t3 = s4
    2. t2 = s2或t2和s2之一为1
    3. t1 = s1或t1和s1之一为1

使用以下代码说服自己。

代码样例

import numpy as np
for it in xrange(10000):
    a = np.random.rand(5,6,2,4)
    b = np.random.rand(6,4,3)
    c = np.matmul(a,b)
    d = np.dot(a,b)
    #print 'c shape: ', c.shape,'d shape:', d.shape

    for i in range(5):
        for j in range(6):
            for k in range(2):
                for l in range(3):
                    if not c[i,j,k,l] == d[i,j,k,j,l]:
                        print it,i,j,k,l,c[i,j,k,l]==d[i,j,k,j,l] #you will not see them

np.matmul还给出了矢量的点积和矩阵的矩阵积。
Subhaneil Lahiri

2

这是与的比较,np.einsum以显示索引的投影方式

np.allclose(np.einsum('ijk,ijk->ijk', a,b), a*b)        # True 
np.allclose(np.einsum('ijk,ikl->ijl', a,b), a@b)        # True
np.allclose(np.einsum('ijk,lkm->ijlm',a,b), a.dot(b))   # True

0

我对MATMUL和DOT的经验

尝试使用MATMUL时,我经常收到“ ValueError:传递的值的形状为(200,1),索引表示(200,3)”。我想要一个快速的解决方法,并发现DOT可以提供相同的功能。使用DOT我没有任何错误。我得到正确的答案

与MATMUL

X.shape
>>>(200, 3)

type(X)

>>>pandas.core.frame.DataFrame

w

>>>array([0.37454012, 0.95071431, 0.73199394])

YY = np.matmul(X,w)

>>>  ValueError: Shape of passed values is (200, 1), indices imply (200, 3)"

与DOT

YY = np.dot(X,w)
# no error message
YY
>>>array([ 2.59206877,  1.06842193,  2.18533396,  2.11366346,  0.28505879, 

YY.shape

>>> (200, )
By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.