为什么我们需要在PyTorch中调用zero_grad()?


Answers:


143

在中PyTorch,我们需要在开始进行反向传播之前将梯度设置为零,因为PyTorch在随后的反向传递中累积梯度。在训练RNN时这很方便。因此,默认操作是在每次调用时累积(即求和)梯度loss.backward()

因此,理想情况下,当您开始训练循环时,应该zero out the gradients正确进行参数更新。否则,梯度将指向预期方向以外的其他方向,即朝向最小值(或最大化,如果达到最大化目标)。

这是一个简单的示例:

import torch
from torch.autograd import Variable
import torch.optim as optim

def linear_model(x, W, b):
    return torch.matmul(x, W) + b

data, targets = ...

W = Variable(torch.randn(4, 3), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)

optimizer = optim.Adam([W, b])

for sample, target in zip(data, targets):
    # clear out the gradients of all Variables 
    # in this optimizer (i.e. W, b)
    optimizer.zero_grad()
    output = linear_model(sample, W, b)
    loss = (output - target) ** 2
    loss.backward()
    optimizer.step()

或者,如果您要进行香草梯度下降,则:

W = Variable(torch.randn(4, 3), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)

for sample, target in zip(data, targets):
    # clear out the gradients of Variables 
    # (i.e. W, b)
    W.grad.data.zero_()
    b.grad.data.zero_()

    output = linear_model(sample, W, b)
    loss = (output - target) ** 2
    loss.backward()

    W -= learning_rate * W.grad.data
    b -= learning_rate * b.grad.data

注意:当张量上调用时,会发生梯度的累积(即求和)。.backward()loss


3
非常感谢,这真的很有帮助!您是否知道张量流是否具有该行为?
散播机

只是要确保..如果您不这样做,那么您将遇到爆炸性的梯度问题,对吗?
zwep

2
@zwep如果我们累积渐变,并不意味着渐变的幅度会增加:例如,渐变的符号会不断翻转。因此,这不能保证您会遇到爆炸性梯度问题。此外,即使您正确调零,也存在爆炸梯度。
汤姆·罗斯

当您运行香草梯度下降时,尝试更新权重时是否出现“需要在现场操作中使用grad的叶子变量”错误?
MUAS

1

如果您使用渐变方法来减少错误(或损失),zero_grad()将重新启动循环而不会损失上一步

如果您不使用zero_grad(),那么损失将减少而不是按需增加

例如,如果您使用zero_grad(),则会发现以下输出:

model training loss is 1.5
model training loss is 1.4
model training loss is 1.3
model training loss is 1.2

如果不使用zero_grad(),则会发现以下输出:

model training loss is 1.4
model training loss is 1.9
model training loss is 2
model training loss is 2.8
model training loss is 3.5
By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.