pytorch中的重塑和视图有什么区别?


83

在numpy中,我们ndarray.reshape()用于重塑数组。

我注意到在pytorch中,人们torch.view(...)出于相同的目的而使用,但同时也torch.reshape(...)存在着。

所以我想知道它们之间有什么区别,什么时候应该使用它们中的任何一个?

Answers:


88

torch.view已经存在了很长时间。它将返回具有新形状的张量。返回的张量将与原始张量共享基础数据。请参阅此处文档

另一方面,似乎torch.reshape 最近已在版本0.4中引入。根据文档,此方法将

返回具有与输入相同的数据和元素数量,但具有指定形状的张量。如果可能,返回的张量将是输入视图。否则,它将是副本。连续输入和具有兼容步幅的输入可以在不复制的情况下进行重塑,但是您不应该依赖复制与查看行为。

这意味着torch.reshape可能会返回原始张量的副本或视图。您不能指望它返回视图或副本。根据开发商的说法:

如果需要复制,请使用clone();如果需要相同的存储,请使用view()。reshape()的语义是它可能共享或可能不共享存储,并且您事先不知道。

另一个区别是reshape()可以同时在连续和非连续张量上view()运行,而只能在连续张量上运行。另请参阅此处有关的含义contiguous


28
也许强调torch.view只能在连续的张量上运行,而torch.reshape可以同时在两个张量上运行可能也会有所帮助。
p13rr0m

6
@pierrom连续是指存储在连续内存或其他内容中的张量?
gokul_uf

3
@gokul_uf是的,您可以看一下这里写的答案:stackoverflow.com/questions/48915810/pytorch-contiguous
MBT

pytorch中的“张量视图”是什么意思?
查理·帕克

42

尽管两者torch.viewtorch.reshape都用于重整张量,但这是它们之间的区别。

  1. 顾名思义,torch.view仅创建原始张量的视图。新的张量将始终与原始张量共享其数据。这意味着,如果您更改原始张量,则重塑后的张量将会更改,反之亦然。
>>> z = torch.zeros(3, 2)
>>> x = z.view(2, 3)
>>> z.fill_(1)
>>> x
tensor([[1., 1., 1.],
        [1., 1., 1.]])
  1. 为了确保新的张量始终共享其数据与原始,torch.view规定了两个张量[的形状有些连续性约束文档。通常这不是问题,但torch.view即使两个张量的形状兼容,有时也会引发错误。这是一个著名的反例。
>>> z = torch.zeros(3, 2)
>>> y = z.t()
>>> y.size()
torch.Size([2, 3])
>>> y.view(6)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: invalid argument 2: view size is not compatible with input tensor's
size and stride (at least one dimension spans across two contiguous subspaces).
Call .contiguous() before .view().
  1. torch.reshape不强加任何连续性约束,但也不保证数据共享。新张量可以是原始张量的视图,也可以是全新的张量。
>>> z = torch.zeros(3, 2)
>>> y = z.reshape(6)
>>> x = z.t().reshape(6)
>>> z.fill_(1)
tensor([[1., 1.],
        [1., 1.],
        [1., 1.]])
>>> y
tensor([1., 1., 1., 1., 1., 1.])
>>> x
tensor([0., 0., 0., 0., 0., 0.])

TL; DR:
如果只想重塑张量,请使用torch.reshape。如果您还担心内存使用情况,并想要确保两个张量共享相同的数据,请使用torch.view


也许只有我一个人,但是我困惑于认为连续性是重塑何时共享数据和不共享数据之间的决定因素。根据我自己的实验,似乎并非如此。(您xy以上都是连续的)。也许可以澄清一下?也许重塑何时可以复制发表评论会有所帮助?
RMurphy 16:09

6

Tensor.reshape()更强大。它适用于任何张量,而Tensor.view()仅适用于twhere的张量t.is_contiguous()==True

解释不连续和连续是另一回事,但是您始终可以在t调用时使张量连续t.contiguous(),然后可以view()无错误地调用。

By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.