在Keras中拟合模型时,批处理大小和纪元数应为多少?


73

我正在训练970个样本,并验证243个样本。

在Keras中拟合模型以优化val_acc时,批处理大小和纪元数应为多少?是否有基于数据输入大小的经验法则?


7
我会说这在很大程度上取决于您的数据。如果您只是在处理一些简单的任务,例如XOR分类器,那么几百个批处理大小为1的纪元就足以达到99.9%的准确性。对于MNIST,我通常会获得合理的结果,批量大小大约在10到100之间,不到100个纪元。如果没有问题的详细信息,您的体系结构,学习规则/成本函数,数据等,就无法准确回答这一问题。
daniel451 '16

有没有办法在每个训练时期都包含所有数据?
kRazzy R

1
@kRazzyR。实际上,对于每次培训,所有数据都会分批考虑。如果要一次包含所有数据,请使用数据长度的batchsize。
Vivek Ananthan '19

Answers:


66

由于您的数据集非常小(约有1000个样本),因此使用32的批量大小(这很标准)可能会很安全。除非您正在接受数十万或数百万个观测值的训练,否则不会对您的问题产生太大的影响。

要回答有关批量大小和纪元的问题:

一般而言:批量越大,训练进度越快,但不一定总是收敛得那么快。小批量训练速度较慢,但收敛速度更快。这绝对取决于问题。

总的来说,模型会随着更多的训练时间而改善。随着它们的融合,它们的准确性将开始达到平稳状态。尝试使用类似50的值,并绘制历元数(x轴)与精度(y轴)的关系。您会看到它变平的地方。

数据的类型和/或形状是什么?这些是图像还是表格数据?这是一个重要的细节。


4
批处理大小应尽可能大,而不会超出内存。限制批处理大小的唯一另一个原因是,如果您同时获取下一个批处理并在当前批处理上训练模型,则可能会浪费时间来获取下一个批处理(因为它太大了,并且内存分配可能占用大量内存)。模型完成对当前批次的拟合),在这种情况下,最好更快地获取批次以减少模型的停机时间。
BallpointBenpoint,

4
我经常看到批量大小的值是8的倍数。这种选择是否有正式的理由?
彼得

更大的时期会导致过度拟合吗?是否有更多的数据和更少的纪元会导致拟合不足?
Dom045 '20

18

上面的好答案。每个人都提供了很好的意见。

理想情况下,这是应该使用的批量大小的顺序:

{1, 2, 4, 8, 16} - slow 

{ [32, 64],[ 128, 256] }- Good starters

[32, 64] - CPU

[128, 256] - GPU for more boost

1
对我而言,这些价值观非常糟糕。我最终在模型中使用了3000个批处理大小,这比您在此处提出的要多得多。
伊恩·瑞温克尔

6
嗯,有什么资料可以解释为事实吗?
Markus

这是在CNN模型上使用这些批次大小的引用来源。希望这对您有好处。〜干杯arxiv.org/pdf/1606.02228.pdf#page=3&zoom=150,0,125
Beltino Goncalves,

1
这似乎是过于简单化。批处理大小通常取决于输入集的每个项目的复杂性以及正在使用的内存量。根据我的经验,通过逐步缩放批处理大小可以获得最佳结果。对我来说,我最幸运的是,从1开始,每n训练一个小时,我的批处理量就增加一倍,n具体取决于数据集的复杂性或大小,直到达到机器的内存限制,然后继续进行最大的训练批次大小尽可能长。
IanCZane

10

我使用Keras对语音数据执行非线性回归。我的每个语音文件都为我提供了文本文件中25000行的功能,每行包含257个实数值。我使用批处理大小为100,纪元为50的方法Sequential在Keras中使用1个隐藏层来训练 模型。经过50个星期的训练,它收敛得很好val_loss


4

我使用Keras对市场组合建模进行了非线性回归。在Keras中使用3个隐藏层训练序列模型时,批处理大小为32,历时= 100,我得到了最佳结果。通常,批处理大小为32或25是好的,除非您有大型数据集,否则时期= 100。如果是大型数据集,则可以批量处理10个,历时b / w为50到100。上述数据对我来说也很好。


5
对于批量大小值应(优选)在2的幂stackoverflow.com/questions/44483233/...
阿努拉格古普塔

“对于大型数据集,批大小为10 ...”,这不是正确的理解,批大小越大越好,因为渐变是对一个批进行平均的
Prasanjit Rath

-8

Epochs完全可以满足您的要求,具体取决于验证损失何时停止进一步改善。这个应该是批量大小:


# To define function to find batch size for training the model
# use this function to find out the batch size

    def FindBatchSize(model):
        """#model: model architecture, that is yet to be trained"""
        import os, sys, psutil, gc, tensorflow, keras
        import numpy as np
        from keras import backend as K
        BatchFound= 16

        try:
            total_params= int(model.count_params());    GCPU= "CPU"
            #find whether gpu is available
            try:
                if K.tensorflow_backend._get_available_gpus()== []:
                    GCPU= "CPU";    #CPU and Cuda9GPU
                else:
                    GCPU= "GPU"
            except:
                from tensorflow.python.client import device_lib;    #Cuda8GPU
                def get_available_gpus():
                    local_device_protos= device_lib.list_local_devices()
                    return [x.name for x in local_device_protos if x.device_type == 'GPU']
                if "gpu" not in str(get_available_gpus()).lower():
                    GCPU= "CPU"
                else:
                    GCPU= "GPU"

            #decide batch size on the basis of GPU availability and model complexity
            if (GCPU== "GPU") and (os.cpu_count() >15) and (total_params <1000000):
                BatchFound= 64    
            if (os.cpu_count() <16) and (total_params <500000):
                BatchFound= 64  
            if (GCPU== "GPU") and (os.cpu_count() >15) and (total_params <2000000) and (total_params >=1000000):
                BatchFound= 32      
            if (GCPU== "GPU") and (os.cpu_count() >15) and (total_params >=2000000) and (total_params <10000000):
                BatchFound= 16  
            if (GCPU== "GPU") and (os.cpu_count() >15) and (total_params >=10000000):
                BatchFound= 8       
            if (os.cpu_count() <16) and (total_params >5000000):
                BatchFound= 8    
            if total_params >100000000:
                BatchFound= 1

        except:
            pass
        try:

            #find percentage of memory used
            memoryused= psutil.virtual_memory()
            memoryused= float(str(memoryused).replace(" ", "").split("percent=")[1].split(",")[0])
            if memoryused >75.0:
                BatchFound= 8
            if memoryused >85.0:
                BatchFound= 4
            if memoryused >90.0:
                BatchFound= 2
            if total_params >100000000:
                BatchFound= 1
            print("Batch Size:  "+ str(BatchFound));    gc.collect()
        except:
            pass

        memoryused= [];    total_params= [];    GCPU= "";
        del memoryused, total_params, GCPU;    gc.collect()
        return BatchFound
By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.