pytorch中torch.no_grad有什么用?


21

我是pytorch的新手,并从此 github代码开始。我不理解代码中60-61行中的注释"because weights have requires_grad=True, but we don't need to track this in autograd"。我了解我们提到requires_grad=True了使用autograd计算梯度所需的变量,但这意味着什么"tracked by autograd"呢?

Answers:


24

包装器“ with torch.no_grad()”将所有require_grad标志临时设置为false。官方PyTorch教程(https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#gradients)中的示例:

x = torch.randn(3, requires_grad=True)
print(x.requires_grad)
print((x ** 2).requires_grad)

with torch.no_grad():
    print((x ** 2).requires_grad)

出:

True
True
False

我建议您阅读上面网站上的所有教程。

在您的示例中:我想作者不希望PyTorch计算新定义的变量w1和w2的梯度,因为他只想更新它们的值。


6
with torch.no_grad()

将使块中的所有操作没有渐变。

在pytorch中,您无法更改w1和w2的位置,这是使用的两个变量require_grad = True。我认为避免w1和w2的位置变化是因为它将在反向传播计算中引起错误。由于位置变化将完全改变w1和w2。

但是,如果使用this no_grad(),则可以控制新的w1和新的w2没有渐变,因为它们是通过操作生成的,这意味着您仅更改w1和w2的值,而不是渐变部分,它们仍然具有先前定义的变量渐变信息反向传播可以继续。

By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.