通常做法是将批次的平均损失而不是总和减到最小?


17

Tensorflow有一个有关对CIFAR-10进行分类的示例教程。在本教程中,批次中的平均交叉熵损失最小。

def loss(logits, labels):
  """Add L2Loss to all the trainable variables.
  Add summary for for "Loss" and "Loss/avg".
  Args:
    logits: Logits from inference().
    labels: Labels from distorted_inputs or inputs(). 1-D tensor
            of shape [batch_size]
  Returns:
    Loss tensor of type float.
  """
  # Calculate the average cross entropy loss across the batch.
  labels = tf.cast(labels, tf.int64)
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
      logits, labels, name='cross_entropy_per_example')
  cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
  tf.add_to_collection('losses', cross_entropy_mean)

  # The total loss is defined as the cross entropy loss plus all of the weight
  # decay terms (L2 loss).
  return tf.add_n(tf.get_collection('losses'), name='total_loss')

参见cifar10.py,第267行。

为什么不将批次中的总和最小化?这有什么不同吗?我不明白这将如何影响反向传播计算。


没有确切的总和/平均值相关,但是损耗选择是一种应用程序设计选择。例如,如果您擅长于平均水平,则可以优化平均值。如果您的应用程序对最坏的情况敏感(例如,汽车碰撞),则应优化最大值。
Alex Kreimer

Answers:


17

如pkubik所述,通常有一个不依赖输入的参数正则化项,例如在tensorflow中,就像

# Loss function using L2 Regularization
regularizer = tf.nn.l2_loss(weights)
loss = tf.reduce_mean(loss + beta * regularizer)

在这种情况下,对微型批次进行平均有助于在更改批次大小时在cross_entropy损失与regularizer损失之间保持固定的比率。

此外,学习率也对损失的大小(梯度)敏感,因此为了标准化不同批次大小的结果,取平均值似乎是一个更好的选择。


更新资料

Facebook的这篇论文(准确,大型Minibatch SGD:在1小时内训练ImageNet)显示,实际上根据批次大小来缩放学习率非常有效:

线性缩放规则:当最小批量大小乘以k时,将学习率乘以k。

这与将梯度乘以k并保持学习率不变基本相同,因此我认为没有必要取平均值。


8

我将专注于这一部分:

我不明白这将如何影响反向传播计算。

首先,您可能已经注意到,所得损失值之间的唯一区别是平均损失相对于总和按比例缩小了倍。1个大号小号ü中号=大号一种VGd大号小号ü中号dX=d大号一种VGdX

d大号dX=Δ0大号X+Δ-大号XΔ
dC大号dX=Δ0C大号X+Δ-C大号XΔ
dC大号dX=CΔ0大号X+Δ-大号XΔ=Cd大号dX

在SGD中,我们将使用权重的梯度乘以学习率来更新权重,并且可以清楚地看到我们可以选择此参数,以使最终权重更新相等。第一个更新规则: 和第二个更新规则(假设): λ

w ^:=w ^+λ1个d大号小号ü中号dw ^
λ1个=λ2
w ^:=w ^+λ1个d大号一种VGdw ^=w ^+λ2d大号小号ü中号dw ^


dontloo的出色发现可能表明,使用总和可能是更合适的方法。为了证明似乎更受欢迎的平均值,我补充道,使用和可能会导致权重正则化问题。针对不同批次大小的正则器调整比例因子可能与调整学习率一样烦人。

By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.