在Tensorflow中分批训练


11

我目前正在尝试在大型csv文件(> 70GB,超过6000万行)上训练模型。为此,我正在使用tf.contrib.learn.read_batch_examples。我正在努力了解此函数实际上是如何读取数据的。如果我使用的批大小为例如50.000,它会读取文件的前50.000行吗?如果我想遍历整个文件(1个纪元),是否必须对estimator.fit方法使用num_rows / batch_size = 1.200的步数?

这是我当前正在使用的输入函数:

def input_fn(file_names, batch_size):
    # Read csv files and create examples dict
    examples_dict = read_csv_examples(file_names, batch_size)

    # Continuous features
    feature_cols = {k: tf.string_to_number(examples_dict[k],
                                           out_type=tf.float32) for k in CONTINUOUS_COLUMNS}

    # Categorical features
    feature_cols.update({
                            k: tf.SparseTensor(
                                indices=[[i, 0] for i in range(examples_dict[k].get_shape()[0])],
                                values=examples_dict[k],
                                shape=[int(examples_dict[k].get_shape()[0]), 1])
                            for k in CATEGORICAL_COLUMNS})

    label = tf.string_to_number(examples_dict[LABEL_COLUMN], out_type=tf.int32)

    return feature_cols, label


def read_csv_examples(file_names, batch_size):
    def parse_fn(record):
        record_defaults = [tf.constant([''], dtype=tf.string)] * len(COLUMNS)

        return tf.decode_csv(record, record_defaults)

    examples_op = tf.contrib.learn.read_batch_examples(
        file_names,
        batch_size=batch_size,
        queue_capacity=batch_size*2.5,
        reader=tf.TextLineReader,
        parse_fn=parse_fn,
        #read_batch_size= batch_size,
        #randomize_input=True,
        num_threads=8
    )

    # Important: convert examples to dict for ease of use in `input_fn`
    # Map each header to its respective column (COLUMNS order
    # matters!
    examples_dict_op = {}
    for i, header in enumerate(COLUMNS):
        examples_dict_op[header] = examples_op[:, i]

    return examples_dict_op

这是im用于训练模型的代码:

def train_and_eval():
"""Train and evaluate the model."""

m = build_estimator(model_dir)
m.fit(input_fn=lambda: input_fn(train_file_name, batch_size), steps=steps)

如果我使用相同的input_fn再次调用fit函数会发生什么。它是否再次从文件开头开始,还是会记住上次停止的行?


我发现medium.com/@ilblackdragon / ...对在tensorflow input_fn中进行批处理很有帮助
fistynuts

Answers:


1

由于尚无答案,因此我想尝试至少给出一个有用的答案。包括常量定义将有助于理解所提供的代码。

一般而言,一批使用n次记录或项。如何定义项目取决于您的问题。在张量流中,批次在张量的第一维中编码。就您的csv文件而言,它可能是逐行(reader=tf.TextLineReader)。它可以按列学习,但我认为这在您的代码中没有发生。如果您想训练整个数据集(= 一个纪元),可以使用numBatches=numItems/batchSize

By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.