Truncate steps_per_execution to run at most one epoch.
PiperOrigin-RevId: 303187360 Change-Id: I3ab3eb46f0a8e60c37b0bf351aef6dff1a868265
This commit is contained in:
parent
e460cd4973
commit
6d947c837a
@ -1811,6 +1811,27 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
|
||||
self.assertEqual(bc.test_begin_batches, [0, 20, 40])
|
||||
self.assertEqual(bc.test_end_batches, [19, 39, 49])
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(distribution=all_strategies, mode=['eager']))
|
||||
def test_host_training_loop_truncate_to_epoch(self, distribution):
|
||||
with distribution.scope():
|
||||
inputs = keras.Input(10)
|
||||
outputs = keras.layers.Dense(1)(inputs)
|
||||
model = keras.Model(inputs, outputs)
|
||||
|
||||
model.compile('sgd', 'mse', experimental_steps_per_execution=500)
|
||||
|
||||
x, y = np.ones((100, 10)), np.ones((100, 1))
|
||||
bc = BatchCountingCB()
|
||||
model.fit(x, y, batch_size=2, epochs=2, callbacks=[bc])
|
||||
self.assertEqual(bc.train_begin_batches, [0, 0])
|
||||
self.assertEqual(bc.train_end_batches, [49, 49])
|
||||
|
||||
x, y = np.ones((50, 10)), np.ones((50, 1))
|
||||
model.evaluate(x, y, batch_size=2, callbacks=[bc])
|
||||
self.assertEqual(bc.test_begin_batches, [0])
|
||||
self.assertEqual(bc.test_end_batches, [24])
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
all_strategy_combinations_minus_default()))
|
||||
|
||||
@ -1155,19 +1155,37 @@ class DataHandler(object):
|
||||
|
||||
def enumerate_epochs(self):
|
||||
"""Yields `(epoch, tf.data.Iterator)`."""
|
||||
data_iterator = iter(self._dataset)
|
||||
for epoch in range(self._initial_epoch, self._epochs):
|
||||
if self._insufficient_data: # Set by `catch_stop_iteration`.
|
||||
break
|
||||
if self._adapter.should_recreate_iterator():
|
||||
if ds_context.has_strategy():
|
||||
# TODO(b/138326910): remove this when MultiDeviceIterator is a
|
||||
# CompositeTensor (unless this is more efficient)
|
||||
data_iterator._initializer # pylint: disable=pointless-statement, protected-access
|
||||
else:
|
||||
data_iterator = iter(self._dataset)
|
||||
yield epoch, data_iterator
|
||||
self._adapter.on_epoch_end()
|
||||
with self._truncate_execution_to_epoch():
|
||||
data_iterator = iter(self._dataset)
|
||||
for epoch in range(self._initial_epoch, self._epochs):
|
||||
if self._insufficient_data: # Set by `catch_stop_iteration`.
|
||||
break
|
||||
if self._adapter.should_recreate_iterator():
|
||||
if ds_context.has_strategy():
|
||||
# TODO(b/138326910): remove this when MultiDeviceIterator is a
|
||||
# CompositeTensor (unless this is more efficient)
|
||||
data_iterator._initializer # pylint: disable=pointless-statement, protected-access
|
||||
else:
|
||||
data_iterator = iter(self._dataset)
|
||||
yield epoch, data_iterator
|
||||
self._adapter.on_epoch_end()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _truncate_execution_to_epoch(self):
|
||||
"""Truncates steps per execution to at most one epoch."""
|
||||
should_truncate = (
|
||||
self._inferred_steps is not None and
|
||||
self._steps_per_execution_value > self._inferred_steps)
|
||||
original_value = self._steps_per_execution_value
|
||||
try:
|
||||
if should_truncate:
|
||||
self._steps_per_execution.assign(self._inferred_steps)
|
||||
self._steps_per_execution_value = self._inferred_steps
|
||||
yield
|
||||
finally:
|
||||
if should_truncate:
|
||||
self._steps_per_execution.assign(original_value)
|
||||
self._steps_per_execution_value = original_value
|
||||
|
||||
@contextlib.contextmanager
|
||||
def catch_stop_iteration(self):
|
||||
|
||||
@ -341,7 +341,9 @@ class Model(network.Network, version_utils.ModelVersionSelector):
|
||||
on TPUs or small models with a large Python overhead. Note that if
|
||||
this value is set to `N`, `Callback.on_batch` methods will only be
|
||||
called every `N` batches. This currently defaults to `1`. At most,
|
||||
one full epoch can be run each execution.
|
||||
one full epoch will be run each execution. If a number larger than
|
||||
the size of the epoch is passed, the execution will be truncated
|
||||
to the size of the epoch.
|
||||
|
||||
Raises:
|
||||
ValueError: In case of invalid arguments for
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user