我在TensorFlow领域相对较新,对您如何将CSV数据实际读入TensorFlow中可用的示例/标签张量的方式感到困惑。TensorFlow教程中有关读取CSV数据的示例相当分散,仅使您成为能够训练CSV数据的一部分。
这是我根据CSV教程整理而成的代码:
from __future__ import print_function
import tensorflow as tf
def file_len(fname):
with open(fname) as f:
for i, l in enumerate(f):
pass
return i + 1
filename = "csv_test_data.csv"
# setup text reader
file_length = file_len(filename)
filename_queue = tf.train.string_input_producer([filename])
reader = tf.TextLineReader(skip_header_lines=1)
_, csv_row = reader.read(filename_queue)
# setup CSV decoding
record_defaults = [[0],[0],[0],[0],[0]]
col1,col2,col3,col4,col5 = tf.decode_csv(csv_row, record_defaults=record_defaults)
# turn features back into a tensor
features = tf.stack([col1,col2,col3,col4])
print("loading, " + str(file_length) + " line(s)\n")
with tf.Session() as sess:
tf.initialize_all_variables().run()
# start populating filename queue
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(file_length):
# retrieve a single instance
example, label = sess.run([features, col5])
print(example, label)
coord.request_stop()
coord.join(threads)
print("\ndone loading")
这是我正在加载的CSV文件中的一个简短示例-基本数据-4个功能列和1个标签列:
0,0,0,0,0
0,15,0,0,0
0,30,0,0,0
0,45,0,0,0
上面的所有代码都是从CSV文件中逐个打印每个示例,虽然不错,但对于培训来说却毫无用处。
我在这里遇到的困难是如何将这些单独的示例(一个接一个地加载)变成训练数据集。例如,这是我在Udacity深度学习课程中正在使用的笔记本。我基本上想获取正在加载的CSV数据,并将其放入诸如train_dataset和train_labels之类的内容中:
def reformat(dataset, labels):
dataset = dataset.reshape((-1, image_size * image_size)).astype(np.float32)
# Map 2 to [0.0, 1.0, 0.0 ...], 3 to [0.0, 0.0, 1.0 ...]
labels = (np.arange(num_labels) == labels[:,None]).astype(np.float32)
return dataset, labels
train_dataset, train_labels = reformat(train_dataset, train_labels)
valid_dataset, valid_labels = reformat(valid_dataset, valid_labels)
test_dataset, test_labels = reformat(test_dataset, test_labels)
print('Training set', train_dataset.shape, train_labels.shape)
print('Validation set', valid_dataset.shape, valid_labels.shape)
print('Test set', test_dataset.shape, test_labels.shape)
我已经尝试过使用tf.train.shuffle_batch
,就像这样,但它莫名其妙地挂起了:
for i in range(file_length):
# retrieve a single instance
example, label = sess.run([features, colRelevant])
example_batch, label_batch = tf.train.shuffle_batch([example, label], batch_size=file_length, capacity=file_length, min_after_dequeue=10000)
print(example, label)
综上所述,这是我的问题:
- 我对这个过程缺少什么?
- 感觉到缺少一些关于如何正确构建输入管道的关键直觉。
- 有没有一种方法可以避免必须知道CSV文件的长度?
- 必须知道要处理的
for i in range(file_length)
行数(上面的代码行)感觉很不好
- 必须知道要处理的
编辑: 雅罗斯拉夫(Yaroslav)指出我很可能在这里混合了命令性和图形构造部分后,它变得越来越清晰。我能够整理以下代码,我认为这与从CSV训练模型时通常会执行的代码更接近(不包括任何模型训练代码):
from __future__ import print_function
import numpy as np
import tensorflow as tf
import math as math
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('dataset')
args = parser.parse_args()
def file_len(fname):
with open(fname) as f:
for i, l in enumerate(f):
pass
return i + 1
def read_from_csv(filename_queue):
reader = tf.TextLineReader(skip_header_lines=1)
_, csv_row = reader.read(filename_queue)
record_defaults = [[0],[0],[0],[0],[0]]
colHour,colQuarter,colAction,colUser,colLabel = tf.decode_csv(csv_row, record_defaults=record_defaults)
features = tf.stack([colHour,colQuarter,colAction,colUser])
label = tf.stack([colLabel])
return features, label
def input_pipeline(batch_size, num_epochs=None):
filename_queue = tf.train.string_input_producer([args.dataset], num_epochs=num_epochs, shuffle=True)
example, label = read_from_csv(filename_queue)
min_after_dequeue = 10000
capacity = min_after_dequeue + 3 * batch_size
example_batch, label_batch = tf.train.shuffle_batch(
[example, label], batch_size=batch_size, capacity=capacity,
min_after_dequeue=min_after_dequeue)
return example_batch, label_batch
file_length = file_len(args.dataset) - 1
examples, labels = input_pipeline(file_length, 1)
with tf.Session() as sess:
tf.initialize_all_variables().run()
# start populating filename queue
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
try:
while not coord.should_stop():
example_batch, label_batch = sess.run([examples, labels])
print(example_batch)
except tf.errors.OutOfRangeError:
print('Done training, epoch reached')
finally:
coord.request_stop()
coord.join(threads)