如果您使用的是自定义训练循环,则可以使用 `collections.deque`。它是一个可以追加元素的“滚动”列表:当列表长度超过 `maxlen` 时,最左侧的元素会被自动弹出。核心代码如下:
# Rolling window holding the most recent `early_stopping + 1` validation losses.
loss_history = deque(maxlen=early_stopping + 1)

for epoch in range(epochs):
    fit(epoch)
    loss_history.append(test_loss.result().numpy())
    # Once the window is full, pop the oldest loss and stop if every newer
    # loss in the window is strictly higher than it (i.e. no improvement).
    if len(loss_history) > early_stopping and loss_history.popleft() < min(loss_history):
        break
这是一个完整的示例:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow_datasets as tfds
import tensorflow as tf
from tensorflow.keras.layers import Dense
from collections import deque
# Load the Iris 'train' split as (features, label) pairs, plus dataset metadata.
data, info = tfds.load(
    'iris',
    split='train',
    as_supervised=True,
    with_info=True,
)
# NOTE(review): the features are cast to int32 here; if they are
# floating-point, the fractional part is discarded -- confirm this is intended.
data = data.map(lambda features, label: (tf.cast(features, tf.int32), label))

# First 120 examples are used for training, the following 30 for evaluation;
# both are served in batches of 4.
train_dataset = data.take(120).batch(4)
test_dataset = data.skip(120).take(30).batch(4)

# Small feed-forward classifier. The last layer has no activation argument and
# its width equals the number of label classes reported by the dataset info.
model = tf.keras.models.Sequential([
    Dense(8, activation='relu'),
    Dense(16, activation='relu'),
    Dense(info.features['label'].num_classes),
])

# The loss is configured with from_logits=True, i.e. it is fed raw model outputs.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Running metrics; they are reset and reported once per epoch by fit().
train_loss = tf.keras.metrics.Mean()
test_loss = tf.keras.metrics.Mean()
train_acc = tf.keras.metrics.SparseCategoricalAccuracy()
test_acc = tf.keras.metrics.SparseCategoricalAccuracy()

opt = tf.keras.optimizers.Adam(learning_rate=1e-3)
@tf.function
def train_step(inputs, labels):
    """Run one optimization step on a single batch.

    Performs a forward pass in training mode, computes the loss, applies the
    resulting gradients through the module-level optimizer ``opt``, and
    updates the ``train_loss`` and ``train_acc`` running metrics.

    Args:
        inputs: Batch of model inputs.
        labels: Batch of target labels matching ``inputs``.
    """
    # Record the forward pass so the loss can be differentiated afterwards.
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        batch_loss = loss_object(labels, predictions)

    # Read the variables only after the forward pass has run.
    weights = model.trainable_variables
    grads = tape.gradient(batch_loss, weights)
    opt.apply_gradients(zip(grads, weights))

    train_loss(batch_loss)
    train_acc(labels, predictions)
@tf.function
def test_step(inputs, labels):
    """Evaluate the model on a single batch without updating its weights.

    Performs a forward pass in inference mode, computes the loss, and updates
    the ``test_loss`` and ``test_acc`` running metrics.

    Args:
        inputs: Batch of model inputs.
        labels: Batch of target labels matching ``inputs``.
    """
    predictions = model(inputs, training=False)
    batch_loss = loss_object(labels, predictions)

    test_loss(batch_loss)
    test_acc(labels, predictions)
def fit(epoch):
    """Run one full epoch of training and evaluation, then print a summary.

    Resets the four running metrics, runs ``train_step`` over every batch of
    ``train_dataset`` and ``test_step`` over every batch of ``test_dataset``,
    and prints one line with the epoch's losses and accuracies.

    Args:
        epoch: Zero-based epoch index; it is printed as ``epoch + 1``.
    """
    # Start every epoch from clean metric state so results are per-epoch.
    # NOTE(review): newer Keras releases name this method `reset_state()` --
    # confirm against the installed TensorFlow/Keras version.
    for metric in (train_loss, test_loss, train_acc, test_acc):
        metric.reset_states()

    for X_train, y_train in train_dataset:
        train_step(X_train, y_train)

    for X_test, y_test in test_dataset:
        test_step(X_test, y_test)

    template = (
        'Epoch {:>2} Train Loss {:.3f} Test Loss {:.3f} '
        'Train Acc {:.2f} Test Acc {:.2f}'
    )
    summary = template.format(
        epoch + 1,
        train_loss.result(),
        test_loss.result(),
        train_acc.result(),
        test_acc.result(),
    )
    print(summary)
def main(epochs=50, early_stopping=10):
    """Train for up to ``epochs`` epochs, stopping early if validation stalls.

    After each epoch the validation loss is pushed into a rolling window of
    ``early_stopping + 1`` values. Once the window is full, the oldest loss is
    popped; if it is strictly lower than every newer loss in the window, no
    improvement has occurred in the last ``early_stopping`` epochs and
    training stops.

    Args:
        epochs: Maximum number of epochs to run.
        early_stopping: Number of consecutive epochs without validation-loss
            improvement that triggers an early stop. Must be at least 1.

    Raises:
        ValueError: If ``early_stopping`` is less than 1.
    """
    # Fail fast: with a patience below 1 the window would be empty after the
    # pop, and the comparison below would crash after a full training epoch.
    if early_stopping < 1:
        raise ValueError(
            f'early_stopping must be at least 1, got {early_stopping}'
        )

    loss_history = deque(maxlen=early_stopping + 1)
    for epoch in range(epochs):
        fit(epoch)
        loss_history.append(test_loss.result().numpy())

        # Not enough history yet to judge whether the loss has stalled.
        if len(loss_history) <= early_stopping:
            continue

        # Compare the oldest loss with the best of the newer ones.
        oldest_loss = loss_history.popleft()
        if oldest_loss < min(loss_history):
            print(f'\nEarly stopping. No validation loss '
                  f'improvement in {early_stopping} epochs.')
            break
if __name__ == "__main__":
    # Script entry point: allow up to 250 epochs, stopping early after
    # 10 epochs without validation-loss improvement.
    main(epochs=250, early_stopping=10)
Epoch 1 Train Loss 1.730 Test Loss 1.449 Train Acc 0.33 Test Acc 0.33
Epoch 2 Train Loss 1.405 Test Loss 1.220 Train Acc 0.33 Test Acc 0.33
Epoch 3 Train Loss 1.173 Test Loss 1.054 Train Acc 0.33 Test Acc 0.33
Epoch 4 Train Loss 1.006 Test Loss 0.935 Train Acc 0.33 Test Acc 0.33
Epoch 5 Train Loss 0.885 Test Loss 0.846 Train Acc 0.33 Test Acc 0.33
...
Epoch 89 Train Loss 0.196 Test Loss 0.240 Train Acc 0.89 Test Acc 0.87
Epoch 90 Train Loss 0.195 Test Loss 0.239 Train Acc 0.89 Test Acc 0.87
Epoch 91 Train Loss 0.195 Test Loss 0.239 Train Acc 0.89 Test Acc 0.87
Epoch 92 Train Loss 0.194 Test Loss 0.239 Train Acc 0.90 Test Acc 0.87
Early stopping. No validation loss improvement in 10 epochs.