Fix example by deleting bad steps_per_epoch value and associated TODO.
PiperOrigin-RevId: 238426188
This commit is contained in:
parent
869788b0e6
commit
66df637441
@ -103,16 +103,13 @@ def main(argv):
|
|||||||
fake_tiny_data=FLAGS.fast_test_mode)
|
fake_tiny_data=FLAGS.fast_test_mode)
|
||||||
model.compile(loss=tf.keras.losses.categorical_crossentropy,
|
model.compile(loss=tf.keras.losses.categorical_crossentropy,
|
||||||
optimizer=tf.keras.optimizers.SGD(),
|
optimizer=tf.keras.optimizers.SGD(),
|
||||||
metrics=['accuracy'],
|
metrics=['accuracy'])
|
||||||
# TODO(arnoegw): Remove after investigating huge allocs.
|
|
||||||
run_eagerly=True)
|
|
||||||
print('Training on %s with %d trainable and %d untrainable variables.' %
|
print('Training on %s with %d trainable and %d untrainable variables.' %
|
||||||
('Fashion MNIST' if FLAGS.use_fashion_mnist else 'MNIST',
|
('Fashion MNIST' if FLAGS.use_fashion_mnist else 'MNIST',
|
||||||
len(model.trainable_variables), len(model.non_trainable_variables)))
|
len(model.trainable_variables), len(model.non_trainable_variables)))
|
||||||
model.fit(x_train, y_train,
|
model.fit(x_train, y_train,
|
||||||
batch_size=128,
|
batch_size=128,
|
||||||
epochs=FLAGS.epochs,
|
epochs=FLAGS.epochs,
|
||||||
steps_per_epoch=3,
|
|
||||||
verbose=1,
|
verbose=1,
|
||||||
validation_data=(x_test, y_test))
|
validation_data=(x_test, y_test))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user