Support steps_per_execution in Model.predict

PiperOrigin-RevId: 308715987
Change-Id: I2bfe838c890ccb5ab3fe5d980255dc6eab20cc7e
This commit is contained in:
Thomas O'Malley 2020-04-27 15:58:17 -07:00 committed by TensorFlower Gardener
parent a59f9bac82
commit 690afb0e68
3 changed files with 69 additions and 15 deletions

View File

@ -334,6 +334,8 @@ class BatchCountingCB(keras.callbacks.Callback):
self.train_end_batches = []
self.test_begin_batches = []
self.test_end_batches = []
self.predict_begin_batches = []
self.predict_end_batches = []
def on_train_batch_begin(self, batch, logs=None):
self.train_begin_batches.append(batch)
@ -347,6 +349,12 @@ class BatchCountingCB(keras.callbacks.Callback):
def on_test_batch_end(self, batch, logs=None):
  """Appends `batch` to `self.test_end_batches`; `logs` is not used."""
  self.test_end_batches.append(batch)
def on_predict_batch_begin(self, batch, logs=None):
  """Appends `batch` to `self.predict_begin_batches`; `logs` is not used."""
  self.predict_begin_batches.append(batch)
def on_predict_batch_end(self, batch, logs=None):
  """Appends `batch` to `self.predict_end_batches`; `logs` is not used."""
  self.predict_end_batches.append(batch)
class TestDistributionStrategyWithNumpyArrays(test.TestCase,
parameterized.TestCase):
@ -1763,6 +1771,10 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
self.assertEqual(bc.test_begin_batches, [0, 10, 20, 30, 40])
self.assertEqual(bc.test_end_batches, [9, 19, 29, 39, 49])
model.predict(x, batch_size=2, callbacks=[bc])
self.assertEqual(bc.predict_begin_batches, [0, 10, 20, 30, 40])
self.assertEqual(bc.predict_end_batches, [9, 19, 29, 39, 49])
@combinations.generate(
combinations.combine(distribution=all_strategies, mode=['eager']))
def test_host_training_loop_last_partial_execution(self, distribution):
@ -1783,6 +1795,10 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
self.assertEqual(bc.test_begin_batches, [0, 20, 40])
self.assertEqual(bc.test_end_batches, [19, 39, 49])
model.predict(x, batch_size=2, callbacks=[bc])
self.assertEqual(bc.predict_begin_batches, [0, 20, 40])
self.assertEqual(bc.predict_end_batches, [19, 39, 49])
@combinations.generate(
combinations.combine(distribution=all_strategies, mode=['eager']))
def test_host_training_loop_dataset_unknown_size(self, distribution):
@ -1814,6 +1830,11 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
self.assertEqual(bc.test_begin_batches, [0, 20, 40])
self.assertEqual(bc.test_end_batches, [19, 39, 49])
predict_ds = ds.repeat(2)
model.predict(predict_ds, steps=50, callbacks=[bc])
self.assertEqual(bc.predict_begin_batches, [0, 20, 40])
self.assertEqual(bc.predict_end_batches, [19, 39, 49])
@combinations.generate(
combinations.combine(distribution=all_strategies, mode=['eager']))
def test_host_training_loop_truncate_to_epoch(self, distribution):
@ -1835,6 +1856,11 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
self.assertEqual(bc.test_begin_batches, [0])
self.assertEqual(bc.test_end_batches, [24])
x = np.ones((50, 10))
model.predict(x, batch_size=2, callbacks=[bc])
self.assertEqual(bc.predict_begin_batches, [0])
self.assertEqual(bc.predict_end_batches, [24])
@combinations.generate(
combinations.times(
all_strategy_combinations_minus_default()))

View File

@ -1101,20 +1101,21 @@ class DataHandler(object):
workers=1,
use_multiprocessing=False,
model=None,
steps_per_execution=1):
steps_per_execution=None):
self._initial_epoch = initial_epoch
self._epochs = epochs
self._insufficient_data = False
self._model = model
self._steps_per_execution = steps_per_execution
# This `Variable` is assigned to by `DataHandler` to allow partial
# executions. Save its original value here to reset after a partial
# execution.
if isinstance(steps_per_execution, int):
self._steps_per_execution_value = steps_per_execution
# `steps_per_execution_value` is the cached initial value.
# `steps_per_execution` is mutable and may be changed by the DataHandler
# to handle partial executions.
if steps_per_execution is None:
self._steps_per_execution = 1
self._steps_per_execution_value = 1
else:
self._steps_per_execution = steps_per_execution
self._steps_per_execution_value = steps_per_execution.numpy().item()
adapter_cls = select_data_adapter(x, y)

View File

@ -21,6 +21,7 @@ from __future__ import print_function
import copy
import itertools
from tensorflow.python.autograph.lang import directives
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
from tensorflow.python.distribute import distribution_strategy_context as ds_context
@ -199,6 +200,8 @@ class Model(network.Network, version_utils.ModelVersionSelector):
self.compiled_loss = None
self.compiled_metrics = None
self._steps_per_execution = None
self._init_batch_counters()
@trackable.no_automatic_dependency_tracking
@ -1231,22 +1234,44 @@ class Model(network.Network, version_utils.ModelVersionSelector):
if self.predict_function is not None:
return self.predict_function
def predict_function(iterator):
"""Runs one call to `self.predict_function`."""
def step_function(model, iterator):
"""Runs a single prediction step."""
def run_step(data):
outputs = self.predict_step(data)
# Ensure counter is updated only if `predict_step` succeeds.
outputs = model.predict_step(data)
# Ensure counter is updated only if `predict_step` succeeds.
with ops.control_dependencies(_minimum_control_deps(outputs)):
self._predict_counter.assign_add(1)
model._predict_counter.assign_add(1) # pylint: disable=protected-access
return outputs
data = next(iterator)
outputs = self.distribute_strategy.run(run_step, args=(data,))
outputs = model.distribute_strategy.run(run_step, args=(data,))
outputs = reduce_per_replica(
outputs, self.distribute_strategy, reduction='concat')
return outputs
if (self._steps_per_execution is None or
self._steps_per_execution.numpy().item() == 1):
def predict_function(iterator):
"""Runs a prediction execution with one step."""
return step_function(self, iterator)
else:
def predict_function(iterator):
"""Runs a prediction execution with multiple steps."""
outputs = step_function(self, iterator)
for _ in math_ops.range(self._steps_per_execution - 1):
directives.set_loop_options(
shape_invariants=[(
t, tf_utils.get_tensor_spec(t, dynamic_batch=True).shape)
for t in nest.flatten(outputs)])
step_outputs = step_function(self, iterator)
outputs = nest.map_structure(lambda t1, t2: concat([t1, t2]), outputs,
step_outputs)
return outputs
if not self.run_eagerly:
predict_function = def_function.function(
predict_function, experimental_relax_shapes=True)
@ -1345,7 +1370,8 @@ class Model(network.Network, version_utils.ModelVersionSelector):
max_queue_size=max_queue_size,
workers=workers,
use_multiprocessing=use_multiprocessing,
model=self)
model=self,
steps_per_execution=self._steps_per_execution)
# Container that configures and calls `tf.keras.Callback`s.
if not isinstance(callbacks, callbacks_module.CallbackList):
@ -1377,7 +1403,8 @@ class Model(network.Network, version_utils.ModelVersionSelector):
batch_outputs,
lambda output, batch_output: output.append(batch_output),
outputs, batch_outputs)
callbacks.on_predict_batch_end(step, {'outputs': batch_outputs})
end_step = step + data_handler.step_increment
callbacks.on_predict_batch_end(end_step, {'outputs': batch_outputs})
callbacks.on_predict_end()
all_outputs = nest.map_structure_up_to(batch_outputs, concat, outputs)
return tf_utils.to_numpy_or_python_type(all_outputs)