Return eager tensors from the training_eager.* methods instead of numpy scalars. This also moves the numpy() conversion to the end of the distribution-strategy execution function in the v2 loops.

PiperOrigin-RevId: 259393774
Author: A. Unique TensorFlower, 2019-07-22 13:25:39 -07:00 (committed by TensorFlower Gardener)
commit bf7368f7a0, parent 710d3113bf
3 changed files with 27 additions and 19 deletions
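
For context, the boundary this change draws can be sketched in a few lines, assuming TF 2.x eager execution (the toy tensor below is illustrative, not from the diff): the eager step functions hand back tf.Tensor objects, and tensor_util.constant_value() produces numpy values only where a caller needs them.

import tensorflow as tf
from tensorflow.python.framework import tensor_util

loss = tf.constant(0.25)                       # what the eager step now returns
host_value = tensor_util.constant_value(loss)  # numpy scalar, resolved at the boundary
print(type(loss), type(host_value))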

tensorflow/python/keras/engine/training.py

@@ -946,9 +946,14 @@ class Model(network.Network):
       ValueError: In case of invalid user-provided arguments.
     """
     if self._run_distributed:
-      return training_v2_utils.train_on_batch(
+      outputs = training_v2_utils.train_on_batch(
           self, x, y=y, sample_weight=sample_weight,
           class_weight=class_weight, reset_metrics=reset_metrics)
+      outputs = [
+          training_v2_utils._non_none_constant_value(v) for v in outputs]  # pylint: disable=protected-access
+      if len(outputs) == 1:
+        outputs = outputs[0]
+      return outputs
 
     self._assert_compile_was_called()
     # If at this point we are in the replica context, then it is okay to execute
@@ -974,6 +979,8 @@ class Model(network.Network):
           y,
           sample_weights=sample_weights,
           output_loss_metrics=self._output_loss_metrics)
+      outputs = [
+          training_v2_utils._non_none_constant_value(v) for v in outputs]  # pylint: disable=protected-access
     else:
       x = training_utils.ModelInputs(x).as_list()
       ins = x + (y or []) + (sample_weights or [])
@@ -1031,9 +1038,14 @@ class Model(network.Network):
       ValueError: In case of invalid user-provided arguments.
     """
     if self._run_distributed:
-      return training_v2_utils.test_on_batch(
+      outputs = training_v2_utils.test_on_batch(
           self, x, y=y, sample_weight=sample_weight,
           reset_metrics=reset_metrics)
+      outputs = [
+          training_v2_utils._non_none_constant_value(v) for v in outputs]  # pylint: disable=protected-access
+      if len(outputs) == 1:
+        outputs = outputs[0]
+      return outputs
 
     self._assert_compile_was_called()
     if (self._distribution_strategy and
@@ -1053,6 +1065,8 @@ class Model(network.Network):
           y,
           sample_weights=sample_weights,
           output_loss_metrics=self._output_loss_metrics)
+      outputs = [
+          training_v2_utils._non_none_constant_value(v) for v in outputs]  # pylint: disable=protected-access
     else:
       x = training_utils.ModelInputs(x).as_list()
       inputs = x + (y or []) + (sample_weights or [])
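
The public contract of Model.train_on_batch is preserved by the caller-side conversion above: a single loss still comes back as a plain scalar. A self-contained sketch (the toy model, optimizer, and shapes are illustrative, not from this diff):

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse')

loss = model.train_on_batch(np.ones((2, 4)), np.ones((2, 1)))
print(loss)  # a bare scalar: the caller unwraps length-1 output lists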

tensorflow/python/keras/engine/training_eager.py

@@ -307,12 +307,7 @@ def train_on_batch(model,
   total_loss = nest.flatten(total_loss)
   results = total_loss + output_losses + metrics_results
-  return [_non_none_constant_value(v) for v in results]
-
-
-def _non_none_constant_value(v):
-  constant_value = tensor_util.constant_value(v)
-  return constant_value if constant_value is not None else v
+  return results
 
 
 def test_on_batch(model,
@@ -365,4 +360,4 @@ def test_on_batch(model,
   total_loss = nest.flatten(total_loss)
   results = total_loss + output_losses + metrics_results
-  return [_non_none_constant_value(v) for v in results]
+  return results
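
With _non_none_constant_value removed from this file, the eager step functions return their result tensors as-is. A sketch of the before/after contract, using stand-in constants for the real losses and metrics (names mirror the diff; values are illustrative):

import tensorflow as tf
from tensorflow.python.framework import tensor_util

total_loss = [tf.constant(0.5)]
output_losses = [tf.constant(0.3)]
metrics_results = [tf.constant(0.9)]
results = total_loss + output_losses + metrics_results

before = [tensor_util.constant_value(v) for v in results]  # numpy scalars
after = results                                            # eager tensors, converted later by the caller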

tensorflow/python/keras/engine/training_v2_utils.py

@@ -70,19 +70,22 @@ def _make_execution_function(model, mode):
         strategy, outputs, with_loss_tensor=(mode != ModeKeys.PREDICT))
     return all_outputs
 
-  if model.run_eagerly:
-    execution_function = distributed_function
-  else:
+  if not model.run_eagerly:
     distributed_function = def_function.function(
         distributed_function, autograph=False)
 
-    def execution_function(input_fn):
-      # `numpy` translates Tensors to values in Eager mode.
-      return [out.numpy() for out in distributed_function(input_fn)]
+  def execution_function(input_fn):
+    # `numpy` translates Tensors to values in Eager mode.
+    return [out.numpy() for out in distributed_function(input_fn)]
 
   return execution_function
 
 
+def _non_none_constant_value(v):
+  constant_value = tensor_util.constant_value(v)
+  return constant_value if constant_value is not None else v
+
+
 def _prepare_feed_values(model, inputs, mode):
   """Prepare feed values to the model execution function.
@@ -232,8 +235,6 @@ def train_on_batch(
   if reset_metrics:
     model.reset_metrics()
 
-  if len(outputs) == 1:
-    return outputs[0]
   return outputs
@@ -295,8 +296,6 @@ def test_on_batch(model, x, y=None, sample_weight=None, reset_metrics=True):
   if reset_metrics:
     model.reset_metrics()
 
-  if len(outputs) == 1:
-    return outputs[0]
   return outputs
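
The reshaped tail of _make_execution_function, extracted as a self-contained sketch (the standalone wrapper name and signature are illustrative): only graph tracing is gated on run_eagerly, while the numpy() conversion now runs for both the eager and the traced path, which is the move the commit message describes.

from tensorflow.python.eager import def_function

def make_execution_function(distributed_function, run_eagerly):
  # Trace to a graph function unless running eagerly; either way the
  # numpy() conversion below is the last step before returning.
  if not run_eagerly:
    distributed_function = def_function.function(
        distributed_function, autograph=False)

  def execution_function(input_fn):
    # `numpy` translates Tensors to values in Eager mode.
    return [out.numpy() for out in distributed_function(input_fn)]

  return execution_function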