Truncate steps_per_execution to run at most one epoch.

PiperOrigin-RevId: 303187360
Change-Id: I3ab3eb46f0a8e60c37b0bf351aef6dff1a868265
Author: Thomas O'Malley (committed by TensorFlower Gardener)
Date: 2020-03-26 13:54:12 -07:00
parent e460cd4973
commit 6d947c837a
3 changed files with 55 additions and 14 deletions

@@ -1811,6 +1811,27 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
     self.assertEqual(bc.test_begin_batches, [0, 20, 40])
     self.assertEqual(bc.test_end_batches, [19, 39, 49])

+  @combinations.generate(
+      combinations.combine(distribution=all_strategies, mode=['eager']))
+  def test_host_training_loop_truncate_to_epoch(self, distribution):
+    with distribution.scope():
+      inputs = keras.Input(10)
+      outputs = keras.layers.Dense(1)(inputs)
+      model = keras.Model(inputs, outputs)
+
+    model.compile('sgd', 'mse', experimental_steps_per_execution=500)
+
+    x, y = np.ones((100, 10)), np.ones((100, 1))
+    bc = BatchCountingCB()
+    model.fit(x, y, batch_size=2, epochs=2, callbacks=[bc])
+    self.assertEqual(bc.train_begin_batches, [0, 0])
+    self.assertEqual(bc.train_end_batches, [49, 49])
+
+    x, y = np.ones((50, 10)), np.ones((50, 1))
+    model.evaluate(x, y, batch_size=2, callbacks=[bc])
+    self.assertEqual(bc.test_begin_batches, [0])
+    self.assertEqual(bc.test_end_batches, [24])
+
   @combinations.generate(
       combinations.times(
           all_strategy_combinations_minus_default()))
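For readers unfamiliar with the test fixture: BatchCountingCB is a helper callback defined earlier in this test file and not shown in the hunk. A minimal sketch of what such a callback could look like, built only on the public `keras.callbacks.Callback` hooks (the class body below is an illustrative assumption, not the committed helper):

import tensorflow as tf


class BatchCountingCB(tf.keras.callbacks.Callback):
  """Records the batch indices at which train/test batch callbacks fire."""

  def __init__(self):
    super(BatchCountingCB, self).__init__()
    self.train_begin_batches = []
    self.train_end_batches = []
    self.test_begin_batches = []
    self.test_end_batches = []

  def on_train_batch_begin(self, batch, logs=None):
    self.train_begin_batches.append(batch)

  def on_train_batch_end(self, batch, logs=None):
    self.train_end_batches.append(batch)

  def on_test_batch_begin(self, batch, logs=None):
    self.test_begin_batches.append(batch)

  def on_test_batch_end(self, batch, logs=None):
    self.test_end_batches.append(batch)

With `experimental_steps_per_execution=500` and only 50 batches per epoch, the value is truncated to 50, so these hooks fire once per epoch: begin at batch 0 and end at batch 49, which is exactly what the assertions above check.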

@@ -1155,19 +1155,37 @@ class DataHandler(object):

   def enumerate_epochs(self):
     """Yields `(epoch, tf.data.Iterator)`."""
-    data_iterator = iter(self._dataset)
-    for epoch in range(self._initial_epoch, self._epochs):
-      if self._insufficient_data:  # Set by `catch_stop_iteration`.
-        break
-      if self._adapter.should_recreate_iterator():
-        if ds_context.has_strategy():
-          # TODO(b/138326910): remove this when MultiDeviceIterator is a
-          # CompositeTensor (unless this is more efficient)
-          data_iterator._initializer  # pylint: disable=pointless-statement, protected-access
-        else:
-          data_iterator = iter(self._dataset)
-      yield epoch, data_iterator
-      self._adapter.on_epoch_end()
+    with self._truncate_execution_to_epoch():
+      data_iterator = iter(self._dataset)
+      for epoch in range(self._initial_epoch, self._epochs):
+        if self._insufficient_data:  # Set by `catch_stop_iteration`.
+          break
+        if self._adapter.should_recreate_iterator():
+          if ds_context.has_strategy():
+            # TODO(b/138326910): remove this when MultiDeviceIterator is a
+            # CompositeTensor (unless this is more efficient)
+            data_iterator._initializer  # pylint: disable=pointless-statement, protected-access
+          else:
+            data_iterator = iter(self._dataset)
+        yield epoch, data_iterator
+        self._adapter.on_epoch_end()
+
+  @contextlib.contextmanager
+  def _truncate_execution_to_epoch(self):
+    """Truncates steps per execution to at most one epoch."""
+    should_truncate = (
+        self._inferred_steps is not None and
+        self._steps_per_execution_value > self._inferred_steps)
+    original_value = self._steps_per_execution_value
+    try:
+      if should_truncate:
+        self._steps_per_execution.assign(self._inferred_steps)
+        self._steps_per_execution_value = self._inferred_steps
+      yield
+    finally:
+      if should_truncate:
+        self._steps_per_execution.assign(original_value)
+        self._steps_per_execution_value = original_value

   @contextlib.contextmanager
   def catch_stop_iteration(self):
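The new `_truncate_execution_to_epoch` helper is a standard `contextlib.contextmanager` built on the try/finally pattern: it clamps the steps-per-execution variable before yielding and restores the original value on exit, even if the enclosed loop raises. A self-contained sketch of the same pattern outside of DataHandler (the `clamp_variable` name and its arguments are illustrative, not part of the commit):

import contextlib

import tensorflow as tf


@contextlib.contextmanager
def clamp_variable(var, ceiling):
  """Temporarily caps `var` at `ceiling`; restores the old value on exit."""
  original = int(var.numpy())
  should_clamp = ceiling is not None and original > ceiling
  try:
    if should_clamp:
      var.assign(ceiling)
    yield var
  finally:
    if should_clamp:
      var.assign(original)


steps_per_execution = tf.Variable(500, dtype=tf.int64)
with clamp_variable(steps_per_execution, ceiling=50):
  print(int(steps_per_execution.numpy()))  # 50 inside the block.
print(int(steps_per_execution.numpy()))    # Back to 500 afterwards.

The committed code additionally tracks a plain Python `_steps_per_execution_value` alongside the `tf.Variable`, presumably so the handler can consult the current value without reading the variable back from the device.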

@@ -341,7 +341,9 @@ class Model(network.Network, version_utils.ModelVersionSelector):
         on TPUs or small models with a large Python overhead. Note that if
         this value is set to `N`, `Callback.on_batch` methods will only be
         called every `N` batches. This currently defaults to `1`. At most,
-        one full epoch can be run each execution.
+        one full epoch will be run each execution. If a number larger than
+        the size of the epoch is passed, the execution will be truncated
+        to the size of the epoch.

     Raises:
       ValueError: In case of invalid arguments for
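In user-facing terms, the docstring change above says: if `experimental_steps_per_execution` exceeds the number of steps in an epoch, the value is capped at the epoch size rather than running past the epoch boundary. A sketch of that behavior from the caller's side (assumes a TF 2.2-era build in which `Model.compile` still accepts the `experimental_steps_per_execution` argument; it was later renamed `steps_per_execution`):

import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(10,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs, outputs)

# 100 samples with batch_size=2 gives 50 steps per epoch, so the requested
# 500 steps per execution is truncated to 50 for each fit/evaluate call.
model.compile('sgd', 'mse', experimental_steps_per_execution=500)

x, y = np.ones((100, 10)), np.ones((100, 1))
model.fit(x, y, batch_size=2, epochs=2, verbose=0)
model.evaluate(x, y, batch_size=2, verbose=0)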