在numpy中,我们ndarray.reshape()
用于重塑数组。
我注意到在pytorch中,人们torch.view(...)
出于相同的目的而使用,但同时也torch.reshape(...)
存在着。
所以我想知道它们之间有什么区别,什么时候应该使用它们中的任何一个?
Answers:
torch.view
已经存在了很长时间。它将返回具有新形状的张量。返回的张量将与原始张量共享基础数据。请参阅此处的文档。
另一方面,似乎torch.reshape
最近已在版本0.4中引入。根据文档,此方法将
返回具有与输入相同的数据和元素数量,但具有指定形状的张量。如果可能,返回的张量将是输入视图。否则,它将是副本。连续输入和具有兼容步幅的输入可以在不复制的情况下进行重塑,但是您不应该依赖复制与查看行为。
这意味着torch.reshape
可能会返回原始张量的副本或视图。您不能指望它返回视图或副本。根据开发商的说法:
如果需要复制,请使用clone();如果需要相同的存储,请使用view()。reshape()的语义是它可能共享或可能不共享存储,并且您事先不知道。
另一个区别是reshape()
可以同时在连续和非连续张量上view()
运行,而只能在连续张量上运行。另请参阅此处有关的含义contiguous
。
尽管两者torch.view
和torch.reshape
都用于重整张量,但这是它们之间的区别。
torch.view
仅创建原始张量的视图。新的张量将始终与原始张量共享其数据。这意味着,如果您更改原始张量,则重塑后的张量将会更改,反之亦然。>>> z = torch.zeros(3, 2)
>>> x = z.view(2, 3)
>>> z.fill_(1)
>>> x
tensor([[1., 1., 1.],
[1., 1., 1.]])
torch.view
规定了两个张量[的形状有些连续性约束文档。通常这不是问题,但torch.view
即使两个张量的形状兼容,有时也会引发错误。这是一个著名的反例。>>> z = torch.zeros(3, 2)
>>> y = z.t()
>>> y.size()
torch.Size([2, 3])
>>> y.view(6)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: invalid argument 2: view size is not compatible with input tensor's
size and stride (at least one dimension spans across two contiguous subspaces).
Call .contiguous() before .view().
torch.reshape
不强加任何连续性约束,但也不保证数据共享。新张量可以是原始张量的视图,也可以是全新的张量。>>> z = torch.zeros(3, 2)
>>> y = z.reshape(6)
>>> x = z.t().reshape(6)
>>> z.fill_(1)
tensor([[1., 1.],
[1., 1.],
[1., 1.]])
>>> y
tensor([1., 1., 1., 1., 1., 1.])
>>> x
tensor([0., 0., 0., 0., 0., 0.])
TL; DR:
如果只想重塑张量,请使用torch.reshape
。如果您还担心内存使用情况,并想要确保两个张量共享相同的数据,请使用torch.view
。
x
和y
以上都是连续的)。也许可以澄清一下?也许在重塑何时可以复制时发表评论会有所帮助?