Support steps_per_execution in Model.predict
PiperOrigin-RevId: 308715987 Change-Id: I2bfe838c890ccb5ab3fe5d980255dc6eab20cc7e
parent a59f9bac82
commit 690afb0e68
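For orientation, an illustrative sketch (not part of this commit) of how the feature is exercised from user code, mirroring the tests below. It assumes the compile argument of this era, `experimental_steps_per_execution` (later renamed `steps_per_execution`), and the hypothetical `BatchCounter` callback stands in for the test's `BatchCountingCB`:

import numpy as np
import tensorflow as tf


class BatchCounter(tf.keras.callbacks.Callback):
  """Records which batch indices the predict-batch callbacks report."""

  def __init__(self):
    super(BatchCounter, self).__init__()
    self.predict_begin_batches = []
    self.predict_end_batches = []

  def on_predict_batch_begin(self, batch, logs=None):
    self.predict_begin_batches.append(batch)

  def on_predict_batch_end(self, batch, logs=None):
    self.predict_end_batches.append(batch)


model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(10,))])
# Assumed compile argument name for this era of Keras; steps_per_execution
# groups several batches into a single tf.function call.
model.compile('sgd', 'mse', experimental_steps_per_execution=10)

bc = BatchCounter()
x = np.ones((100, 10))
model.predict(x, batch_size=2, callbacks=[bc])
# 50 batches grouped into executions of 10 steps: the batch-level callbacks
# fire once per execution, begin at [0, 10, 20, 30, 40], end at [9, 19, 29, 39, 49].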
@@ -334,6 +334,8 @@ class BatchCountingCB(keras.callbacks.Callback):
     self.train_end_batches = []
     self.test_begin_batches = []
     self.test_end_batches = []
+    self.predict_begin_batches = []
+    self.predict_end_batches = []
 
   def on_train_batch_begin(self, batch, logs=None):
     self.train_begin_batches.append(batch)
@@ -347,6 +349,12 @@ class BatchCountingCB(keras.callbacks.Callback):
   def on_test_batch_end(self, batch, logs=None):
     self.test_end_batches.append(batch)
 
+  def on_predict_batch_begin(self, batch, logs=None):
+    self.predict_begin_batches.append(batch)
+
+  def on_predict_batch_end(self, batch, logs=None):
+    self.predict_end_batches.append(batch)
+
 
 class TestDistributionStrategyWithNumpyArrays(test.TestCase,
                                               parameterized.TestCase):
@@ -1763,6 +1771,10 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
     self.assertEqual(bc.test_begin_batches, [0, 10, 20, 30, 40])
     self.assertEqual(bc.test_end_batches, [9, 19, 29, 39, 49])
 
+    model.predict(x, batch_size=2, callbacks=[bc])
+    self.assertEqual(bc.predict_begin_batches, [0, 10, 20, 30, 40])
+    self.assertEqual(bc.predict_end_batches, [9, 19, 29, 39, 49])
+
   @combinations.generate(
       combinations.combine(distribution=all_strategies, mode=['eager']))
   def test_host_training_loop_last_partial_execution(self, distribution):
@@ -1783,6 +1795,10 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
     self.assertEqual(bc.test_begin_batches, [0, 20, 40])
     self.assertEqual(bc.test_end_batches, [19, 39, 49])
 
+    model.predict(x, batch_size=2, callbacks=[bc])
+    self.assertEqual(bc.predict_begin_batches, [0, 20, 40])
+    self.assertEqual(bc.predict_end_batches, [19, 39, 49])
+
   @combinations.generate(
       combinations.combine(distribution=all_strategies, mode=['eager']))
   def test_host_training_loop_dataset_unknown_size(self, distribution):
@@ -1814,6 +1830,11 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
     self.assertEqual(bc.test_begin_batches, [0, 20, 40])
     self.assertEqual(bc.test_end_batches, [19, 39, 49])
 
+    predict_ds = ds.repeat(2)
+    model.predict(predict_ds, steps=50, callbacks=[bc])
+    self.assertEqual(bc.predict_begin_batches, [0, 20, 40])
+    self.assertEqual(bc.predict_end_batches, [19, 39, 49])
+
   @combinations.generate(
       combinations.combine(distribution=all_strategies, mode=['eager']))
   def test_host_training_loop_truncate_to_epoch(self, distribution):
@@ -1835,6 +1856,11 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
     self.assertEqual(bc.test_begin_batches, [0])
     self.assertEqual(bc.test_end_batches, [24])
 
+    x = np.ones((50, 10))
+    model.predict(x, batch_size=2, callbacks=[bc])
+    self.assertEqual(bc.predict_begin_batches, [0])
+    self.assertEqual(bc.predict_end_batches, [24])
+
   @combinations.generate(
       combinations.times(
           all_strategy_combinations_minus_default()))
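The assertions in these tests follow one pattern: `on_predict_batch_begin` reports the first step of each execution and `on_predict_batch_end` the last, with the final execution truncated when the remaining data cannot fill it. A small illustrative helper (not part of the change) reproduces the expected indices:

def expected_callback_batches(total_steps, steps_per_execution):
  """Begin/end batch indices when batches are grouped into executions."""
  begin, end = [], []
  step = 0
  while step < total_steps:
    run = min(steps_per_execution, total_steps - step)  # last one may be partial
    begin.append(step)
    end.append(step + run - 1)
    step += run
  return begin, end


# e.g. 50 batches with steps_per_execution of 10 and of 20:
assert expected_callback_batches(50, 10) == ([0, 10, 20, 30, 40],
                                             [9, 19, 29, 39, 49])
assert expected_callback_batches(50, 20) == ([0, 20, 40], [19, 39, 49])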
@@ -1101,20 +1101,21 @@ class DataHandler(object):
                workers=1,
                use_multiprocessing=False,
                model=None,
-               steps_per_execution=1):
+               steps_per_execution=None):
 
     self._initial_epoch = initial_epoch
     self._epochs = epochs
     self._insufficient_data = False
     self._model = model
-    self._steps_per_execution = steps_per_execution
 
-    # This `Variable` is assigned to by `DataHandler` to allow partial
-    # executions. Save its original value here to reset after a partial
-    # execution.
-    if isinstance(steps_per_execution, int):
-      self._steps_per_execution_value = steps_per_execution
+    # `steps_per_execution_value` is the cached initial value.
+    # `steps_per_execution` is mutable and may be changed by the DataAdapter
+    # to handle partial executions.
+    if steps_per_execution is None:
+      self._steps_per_execution = 1
+      self._steps_per_execution_value = 1
     else:
+      self._steps_per_execution = steps_per_execution
       self._steps_per_execution_value = steps_per_execution.numpy().item()
 
     adapter_cls = select_data_adapter(x, y)
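The cached `_steps_per_execution_value` lets the handler temporarily shrink the shared `steps_per_execution` variable for a partial execution (fewer batches left than one full execution) and restore it afterwards. A rough, self-contained sketch of that pattern, with placeholder names and under the assumption that the value is a `tf.Variable` read by the compiled predict function:

import contextlib

import tensorflow as tf

# Hypothetical stand-ins: the shared variable driving the multi-step function
# and the Python value cached at construction time.
steps_per_execution = tf.Variable(10)
steps_per_execution_value = steps_per_execution.numpy().item()


@contextlib.contextmanager
def partial_execution(remaining_steps):
  """Temporarily run fewer steps per call, then restore the cached value."""
  if remaining_steps < steps_per_execution_value:
    steps_per_execution.assign(remaining_steps)
  try:
    yield
  finally:
    steps_per_execution.assign(steps_per_execution_value)


# Only 5 batches left at the end of the data:
with partial_execution(5):
  pass  # a compiled predict function reading the variable would run 5 steps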
@@ -21,6 +21,7 @@ from __future__ import print_function
 import copy
 import itertools
 
+from tensorflow.python.autograph.lang import directives
 from tensorflow.python.distribute import distribute_coordinator as dc
 from tensorflow.python.distribute import distribute_coordinator_context as dc_context
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
@@ -199,6 +200,8 @@ class Model(network.Network, version_utils.ModelVersionSelector):
     self.compiled_loss = None
     self.compiled_metrics = None
 
+    self._steps_per_execution = None
+
     self._init_batch_counters()
 
   @trackable.no_automatic_dependency_tracking
@@ -1231,22 +1234,44 @@ class Model(network.Network, version_utils.ModelVersionSelector):
     if self.predict_function is not None:
       return self.predict_function
 
-    def predict_function(iterator):
-      """Runs one call to `self.predict_function`."""
+    def step_function(model, iterator):
+      """Runs a single evaluation step."""
 
       def run_step(data):
-        outputs = self.predict_step(data)
-        # Ensure counter is updated only if `predict_step` succeeds.
+        outputs = model.predict_step(data)
+        # Ensure counter is updated only if `test_step` succeeds.
         with ops.control_dependencies(_minimum_control_deps(outputs)):
-          self._predict_counter.assign_add(1)
+          model._predict_counter.assign_add(1)  # pylint: disable=protected-access
         return outputs
 
       data = next(iterator)
-      outputs = self.distribute_strategy.run(run_step, args=(data,))
+      outputs = model.distribute_strategy.run(run_step, args=(data,))
       outputs = reduce_per_replica(
           outputs, self.distribute_strategy, reduction='concat')
       return outputs
 
+    if (self._steps_per_execution is None or
+        self._steps_per_execution.numpy().item() == 1):
+
+      def predict_function(iterator):
+        """Runs an evaluation execution with one step."""
+        return step_function(self, iterator)
+
+    else:
+
+      def predict_function(iterator):
+        """Runs an evaluation execution with multiple steps."""
+        outputs = step_function(self, iterator)
+        for _ in math_ops.range(self._steps_per_execution - 1):
+          directives.set_loop_options(
+              shape_invariants=[(
+                  t, tf_utils.get_tensor_spec(t, dynamic_batch=True).shape)
+                                for t in nest.flatten(outputs)])
+          step_outputs = step_function(self, iterator)
+          outputs = nest.map_structure(lambda t1, t2: concat([t1, t2]), outputs,
+                                       step_outputs)
+        return outputs
+
     if not self.run_eagerly:
       predict_function = def_function.function(
           predict_function, experimental_relax_shapes=True)
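Unlike a single-step execution, the multi-step `predict_function` has to keep every batch's outputs, so each extra step's results are concatenated onto the running outputs; the `set_loop_options` shape invariants relax the leading (batch) dimension so the traced loop accepts a tensor that grows between iterations. A stripped-down sketch of the same pattern outside Keras (placeholder names, fixed feature size of 2):

import tensorflow as tf


@tf.function
def predict_many(iterator, extra_steps):
  # First step establishes the output structure.
  outputs = tf.square(next(iterator))
  for _ in tf.range(extra_steps):
    # Let the leading (batch) dimension of `outputs` vary across iterations;
    # without this, the traced while loop rejects the growing tensor.
    tf.autograph.experimental.set_loop_options(
        shape_invariants=[(outputs, tf.TensorShape([None, 2]))])
    step_out = tf.square(next(iterator))
    outputs = tf.concat([outputs, step_out], axis=0)
  return outputs


ds = tf.data.Dataset.from_tensor_slices(tf.ones([12, 2])).batch(4)
print(predict_many(iter(ds), tf.constant(2)).shape)  # (12, 2): 3 steps of 4 rows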
@@ -1345,7 +1370,8 @@ class Model(network.Network, version_utils.ModelVersionSelector):
           max_queue_size=max_queue_size,
           workers=workers,
           use_multiprocessing=use_multiprocessing,
-          model=self)
+          model=self,
+          steps_per_execution=self._steps_per_execution)
 
       # Container that configures and calls `tf.keras.Callback`s.
       if not isinstance(callbacks, callbacks_module.CallbackList):
@@ -1377,7 +1403,8 @@ class Model(network.Network, version_utils.ModelVersionSelector):
                   batch_outputs,
                   lambda output, batch_output: output.append(batch_output),
                   outputs, batch_outputs)
-            callbacks.on_predict_batch_end(step, {'outputs': batch_outputs})
+            end_step = step + data_handler.step_increment
+            callbacks.on_predict_batch_end(end_step, {'outputs': batch_outputs})
       callbacks.on_predict_end()
     all_outputs = nest.map_structure_up_to(batch_outputs, concat, outputs)
     return tf_utils.to_numpy_or_python_type(all_outputs)