训练期间难治的常见原因


85

我注意到在培训期间经常发生这种情况NAN

通常,它似乎是通过权重引入内部产品/完全连接或卷积层中的。

这是因为梯度计算正在爆炸吗?还是因为权重初始化(如果是这样,为什么权重初始化会产生这种效果)?还是可能是由于输入数据的性质引起的?

这里的首要问题很简单:在训练过程中发生NAN的最常见原因是什么?其次,有什么方法可以解决这个问题(为什么它们起作用)?


您在调用特定的MATLAB函数吗?这都是您自己的代码吗?
马修·冈恩

2
@MatthewGunn我不认为这个问题是特定于Matlab的,而是caffe相关的。
Shai 2015年

Answers:


134

好问题。
我多次遇到这种现象。这是我的观察结果:


渐变爆炸

原因:较大的梯度会使学习过程偏离轨道。

您应该期望的是:查看运行时日志,应查看每个迭代的损失值。您会注意到,损失在每次迭代之间都开始显着增长,最终损失将太大而无法用浮点变量表示,它将变为nan

您可以做什么:base_lr(在solver.prototxt中)减少一个数量级(至少)。如果您有多个损耗层,则应检查日志以查看是哪个层造成了梯度爆炸,并减少了loss_weight该特定层(而不是一般层)的(在train_val.prototxt中)base_lr


不良的学习率政策和参数

原因: caffe无法计算出有效的学习率,而得到'inf''nan'取而代之的是,该无效率乘以所有更新,从而使所有参数无效。

您应该期望的是:查看运行时日志,您应该看到学习率本身变为'nan',例如:

... sgd_solver.cpp:106] Iteration 0, lr = -nan

您可以做什么:修复所有影响'solver.prototxt'文件学习率的参数。
例如,如果使用却lr_policy: "poly"忘记定义max_iter参数,最终将得到 lr = nan...
有关caffe学习率的更多信息,请参见此线程


故障损失功能

原因:有时,损耗层中损耗的计算会导致nans出现。例如,InfogainLoss使用非标准化值的Feed,使用带有bug的自定义损失层等。

您应该期望的是:查看运行时日志,您可能不会注意到任何异常情况:损耗在逐渐减少,并且突然nan出现。

您可以做什么:查看是否可以重现错误,将打印输出添加到损失层并调试错误。

例如:一旦我使用了损失,就可以按批次中标签出现的频率归一化惩罚。碰巧的是,如果其中一个培训标签根本没有出现在批次中,那么所计算的损失将产生nans。在那种情况下,处理足够大的批次(相对于标签中的标签数量)足以避免此错误。


输入错误

原因:您有输入nan

您应该期望的是:一旦学习过程“达到”,该错误的输入输出就会变成nan。查看运行时日志,您可能不会注意到任何异常情况:损耗在逐渐减少,并且突然nan出现。

您可以做什么:重新构建输入数据集(lmdb / leveldn / hdf5 ...),以确保训练/验证集中没有不良的图像文件。对于调试,您可以构建一个简单的网络,该网络读取输入层,在其上具有虚拟损耗并遍历所有输入:如果其中一个输入有故障,则该虚拟网也应产生nan


步幅比内核大小较大"Pooling"

由于某些原因,选择stride>kernel_size进行池化可能会导致nans。例如:

layer {
  name: "faulty_pooling"
  type: "Pooling"
  bottom: "x"
  top: "y"
  pooling_param {
    pool: AVE
    stride: 5
    kernel: 3
  }
}

nan在中有s个结果y


不稳定 "BatchNorm"

据报道,由于数值的不稳定性,在某些设置"BatchNorm"层下可能会输出nan。bvlc / caffe中提出了
问题PR#5136正在尝试修复它。


最近,我意识到了debug_info标志:设置debug_info: true该值'solver.prototxt'将使caffe打印在训练过程中记录更多调试信息(包括梯度幅度和激活值):此信息可帮助发现训练过程中的梯度爆炸和其他问题


谢谢,如何解释这些数字?这些数字是多少?pastebin.com/DLYgXK5v为什么每层输出只有一个数字!这些数字应该如何显示,以便有人知道有问题或没有问题!
里卡

@Hossein这正是这篇文章的全部内容。
Shai 2016年

感谢您的回答。我在经过DICE损失训练的图像分割应用程序中遇到了NAN损失(即使在添加了一个小epsilon /平滑度常数之后)。我的数据集包含一些图像,这些图像的对应地面真相不包含任何前景标签,当我从训练中删除这些图像时,损失得以稳定。我不确定为什么会这样?
萨姆拉·艾尔沙德

@samrairshad您是否尝试过增加DICE损失中的epsilon?
Shai

是的,我做到了。我在堆栈溢出处打开了帖子,并在某些时期粘贴了损失的演变过程。这里的参考:stackoverflow.com/questions/62259112/...
萨姆拉伊尔沙德

5

就我而言,未在卷积/解卷积层中设置偏差是原因。

解决方案:将以下内容添加到卷积层参数中。

bias_filler {类型:“常量”值:0}


在matconvnet中看起来如何?我有类似'biases'.init_bias * ones(1,4,single)
h612

4

这个答案与nans的原因无关,而是提出了一种有助于对其进行调试的方法。您可以拥有以下python层:

class checkFiniteLayer(caffe.Layer):
  def setup(self, bottom, top):
    self.prefix = self.param_str
  def reshape(self, bottom, top):
    pass
  def forward(self, bottom, top):
    for i in xrange(len(bottom)):
      isbad = np.sum(1-np.isfinite(bottom[i].data[...]))
      if isbad>0:
        raise Exception("checkFiniteLayer: %s forward pass bottom %d has %.2f%% non-finite elements" %
                        (self.prefix,i,100*float(isbad)/bottom[i].count))
  def backward(self, top, propagate_down, bottom):
    for i in xrange(len(top)):
      if not propagate_down[i]:
        continue
      isf = np.sum(1-np.isfinite(top[i].diff[...]))
        if isf>0:
          raise Exception("checkFiniteLayer: %s backward pass top %d has %.2f%% non-finite elements" %
                          (self.prefix,i,100*float(isf)/top[i].count))

train_val.prototxt在您怀疑会在某些时候在您的某些点添加此层可能会导致麻烦:

layer {
  type: "Python"
  name: "check_loss"
  bottom: "fc2"
  top: "fc2"  # "in-place" layer
  python_param {
    module: "/path/to/python/file/check_finite_layer.py" # must be in $PYTHONPATH
    layer: "checkFiniteLayer"
    param_str: "prefix-check_loss" # string for printouts
  }
}


-1

我试图构建一个稀疏的自动编码器,并在其中添加了几层来引起稀疏。在运行网络时,我遇到了NaN。在删除某些图层(在我的情况下,实际上我必须删除1)后,我发现NaN消失了。因此,我想太多的稀疏性也可能导致NaN(可能已调用了一些0/0计算!!)


您能更具体一点吗?您能否提供有关具有nans的配置和固定配置的详细信息?什么类型的图层?什么参数?

1
@shai我使用了几层InnerProduct(lr_mult 1,delay_mult 1,lr_mult 2,衰减_mult 0,xavier,std:0.01),每层都紧随ReLU(最后一层除外)。我在与MNIST一起工作,如果我没记错的话,架构是784-> 1000-> 500-> 250-> 100-> 30(和对称解码器阶段)。去除30层及其ReLU使NaN消失。
LKB
By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.