From bf7368f7a02db5055de09be13ac3ba0143749598 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Mon, 22 Jul 2019 13:25:39 -0700
Subject: [PATCH] Return eager tensors from the training_eager.* methods
 instead of numpy scalars.

This also moves the conversion to numpy() to the end of the distribution
strategy execution function in the v2 loops.

PiperOrigin-RevId: 259393774
---
 tensorflow/python/keras/engine/training.py    | 18 ++++++++++++++++--
 .../python/keras/engine/training_eager.py     |  9 ++-------
 .../python/keras/engine/training_v2_utils.py  | 19 +++++++++----------
 3 files changed, 27 insertions(+), 19 deletions(-)

diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index cdc06daae6a..718f3a582cf 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -946,9 +946,14 @@ class Model(network.Network):
         ValueError: In case of invalid user-provided arguments.
     """
     if self._run_distributed:
-      return training_v2_utils.train_on_batch(
+      outputs = training_v2_utils.train_on_batch(
           self, x, y=y, sample_weight=sample_weight,
           class_weight=class_weight, reset_metrics=reset_metrics)
+      outputs = [
+          training_v2_utils._non_none_constant_value(v) for v in outputs]  # pylint: disable=protected-access
+      if len(outputs) == 1:
+        outputs = outputs[0]
+      return outputs
 
     self._assert_compile_was_called()
     # If at this point we are in the replica context, then it is okay to execute
@@ -974,6 +979,8 @@ class Model(network.Network):
           y,
           sample_weights=sample_weights,
           output_loss_metrics=self._output_loss_metrics)
+      outputs = [
+          training_v2_utils._non_none_constant_value(v) for v in outputs]  # pylint: disable=protected-access
     else:
       x = training_utils.ModelInputs(x).as_list()
       ins = x + (y or []) + (sample_weights or [])
@@ -1031,9 +1038,14 @@ class Model(network.Network):
         ValueError: In case of invalid user-provided arguments.
     """
     if self._run_distributed:
-      return training_v2_utils.test_on_batch(
+      outputs = training_v2_utils.test_on_batch(
           self, x, y=y, sample_weight=sample_weight,
           reset_metrics=reset_metrics)
+      outputs = [
+          training_v2_utils._non_none_constant_value(v) for v in outputs]  # pylint: disable=protected-access
+      if len(outputs) == 1:
+        outputs = outputs[0]
+      return outputs
 
     self._assert_compile_was_called()
     if (self._distribution_strategy and
@@ -1053,6 +1065,8 @@ class Model(network.Network):
           y,
           sample_weights=sample_weights,
           output_loss_metrics=self._output_loss_metrics)
+      outputs = [
+          training_v2_utils._non_none_constant_value(v) for v in outputs]  # pylint: disable=protected-access
     else:
       x = training_utils.ModelInputs(x).as_list()
       inputs = x + (y or []) + (sample_weights or [])
diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py
index 6cbc6851a8e..c019238f48e 100644
--- a/tensorflow/python/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/engine/training_eager.py
@@ -307,12 +307,7 @@ def train_on_batch(model,
   total_loss = nest.flatten(total_loss)
   results = total_loss + output_losses + metrics_results
 
-  return [_non_none_constant_value(v) for v in results]
-
-
-def _non_none_constant_value(v):
-  constant_value = tensor_util.constant_value(v)
-  return constant_value if constant_value is not None else v
+  return results
 
 
 def test_on_batch(model,
@@ -365,4 +360,4 @@ def test_on_batch(model,
   total_loss = nest.flatten(total_loss)
   results = total_loss + output_losses + metrics_results
 
-  return [_non_none_constant_value(v) for v in results]
+  return results
diff --git a/tensorflow/python/keras/engine/training_v2_utils.py b/tensorflow/python/keras/engine/training_v2_utils.py
index 982ef2a71a1..e609559e5e8 100644
--- a/tensorflow/python/keras/engine/training_v2_utils.py
+++ b/tensorflow/python/keras/engine/training_v2_utils.py
@@ -70,19 +70,22 @@ def _make_execution_function(model, mode):
         strategy, outputs, with_loss_tensor=(mode != ModeKeys.PREDICT))
     return all_outputs
 
-  if model.run_eagerly:
-    execution_function = distributed_function
-  else:
+  if not model.run_eagerly:
     distributed_function = def_function.function(
         distributed_function, autograph=False)
 
-    def execution_function(input_fn):
-      # `numpy` translates Tensors to values in Eager mode.
-      return [out.numpy() for out in distributed_function(input_fn)]
+  def execution_function(input_fn):
+    # `numpy` translates Tensors to values in Eager mode.
+    return [out.numpy() for out in distributed_function(input_fn)]
 
   return execution_function
 
 
+def _non_none_constant_value(v):
+  constant_value = tensor_util.constant_value(v)
+  return constant_value if constant_value is not None else v
+
+
 def _prepare_feed_values(model, inputs, mode):
   """Prepare feed values to the model execution function.
 
@@ -232,8 +235,6 @@ def train_on_batch(
 
   if reset_metrics:
     model.reset_metrics()
-  if len(outputs) == 1:
-    return outputs[0]
   return outputs
 
 
@@ -295,8 +296,6 @@ def test_on_batch(model, x, y=None, sample_weight=None, reset_metrics=True):
 
   if reset_metrics:
     model.reset_metrics()
-  if len(outputs) == 1:
-    return outputs[0]
   return outputs
 
 