Return eager tensors from the training_eager.* methods instead of numpy scalars. This also moves the conversion to numpy() to the end of the dist strat strategy execution function in the v2 loops.
PiperOrigin-RevId: 259393774
This commit is contained in:
parent
710d3113bf
commit
bf7368f7a0
@ -946,9 +946,14 @@ class Model(network.Network):
|
|||||||
ValueError: In case of invalid user-provided arguments.
|
ValueError: In case of invalid user-provided arguments.
|
||||||
"""
|
"""
|
||||||
if self._run_distributed:
|
if self._run_distributed:
|
||||||
return training_v2_utils.train_on_batch(
|
outputs = training_v2_utils.train_on_batch(
|
||||||
self, x, y=y, sample_weight=sample_weight,
|
self, x, y=y, sample_weight=sample_weight,
|
||||||
class_weight=class_weight, reset_metrics=reset_metrics)
|
class_weight=class_weight, reset_metrics=reset_metrics)
|
||||||
|
outputs = [
|
||||||
|
training_v2_utils._non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access
|
||||||
|
if len(outputs) == 1:
|
||||||
|
outputs = outputs[0]
|
||||||
|
return outputs
|
||||||
|
|
||||||
self._assert_compile_was_called()
|
self._assert_compile_was_called()
|
||||||
# If at this point we are in the replica context, then it is okay to execute
|
# If at this point we are in the replica context, then it is okay to execute
|
||||||
@ -974,6 +979,8 @@ class Model(network.Network):
|
|||||||
y,
|
y,
|
||||||
sample_weights=sample_weights,
|
sample_weights=sample_weights,
|
||||||
output_loss_metrics=self._output_loss_metrics)
|
output_loss_metrics=self._output_loss_metrics)
|
||||||
|
outputs = [
|
||||||
|
training_v2_utils._non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access
|
||||||
else:
|
else:
|
||||||
x = training_utils.ModelInputs(x).as_list()
|
x = training_utils.ModelInputs(x).as_list()
|
||||||
ins = x + (y or []) + (sample_weights or [])
|
ins = x + (y or []) + (sample_weights or [])
|
||||||
@ -1031,9 +1038,14 @@ class Model(network.Network):
|
|||||||
ValueError: In case of invalid user-provided arguments.
|
ValueError: In case of invalid user-provided arguments.
|
||||||
"""
|
"""
|
||||||
if self._run_distributed:
|
if self._run_distributed:
|
||||||
return training_v2_utils.test_on_batch(
|
outputs = training_v2_utils.test_on_batch(
|
||||||
self, x, y=y, sample_weight=sample_weight,
|
self, x, y=y, sample_weight=sample_weight,
|
||||||
reset_metrics=reset_metrics)
|
reset_metrics=reset_metrics)
|
||||||
|
outputs = [
|
||||||
|
training_v2_utils._non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access
|
||||||
|
if len(outputs) == 1:
|
||||||
|
outputs = outputs[0]
|
||||||
|
return outputs
|
||||||
|
|
||||||
self._assert_compile_was_called()
|
self._assert_compile_was_called()
|
||||||
if (self._distribution_strategy and
|
if (self._distribution_strategy and
|
||||||
@ -1053,6 +1065,8 @@ class Model(network.Network):
|
|||||||
y,
|
y,
|
||||||
sample_weights=sample_weights,
|
sample_weights=sample_weights,
|
||||||
output_loss_metrics=self._output_loss_metrics)
|
output_loss_metrics=self._output_loss_metrics)
|
||||||
|
outputs = [
|
||||||
|
training_v2_utils._non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access
|
||||||
else:
|
else:
|
||||||
x = training_utils.ModelInputs(x).as_list()
|
x = training_utils.ModelInputs(x).as_list()
|
||||||
inputs = x + (y or []) + (sample_weights or [])
|
inputs = x + (y or []) + (sample_weights or [])
|
||||||
|
@ -307,12 +307,7 @@ def train_on_batch(model,
|
|||||||
total_loss = nest.flatten(total_loss)
|
total_loss = nest.flatten(total_loss)
|
||||||
results = total_loss + output_losses + metrics_results
|
results = total_loss + output_losses + metrics_results
|
||||||
|
|
||||||
return [_non_none_constant_value(v) for v in results]
|
return results
|
||||||
|
|
||||||
|
|
||||||
def _non_none_constant_value(v):
|
|
||||||
constant_value = tensor_util.constant_value(v)
|
|
||||||
return constant_value if constant_value is not None else v
|
|
||||||
|
|
||||||
|
|
||||||
def test_on_batch(model,
|
def test_on_batch(model,
|
||||||
@ -365,4 +360,4 @@ def test_on_batch(model,
|
|||||||
total_loss = nest.flatten(total_loss)
|
total_loss = nest.flatten(total_loss)
|
||||||
results = total_loss + output_losses + metrics_results
|
results = total_loss + output_losses + metrics_results
|
||||||
|
|
||||||
return [_non_none_constant_value(v) for v in results]
|
return results
|
||||||
|
@ -70,19 +70,22 @@ def _make_execution_function(model, mode):
|
|||||||
strategy, outputs, with_loss_tensor=(mode != ModeKeys.PREDICT))
|
strategy, outputs, with_loss_tensor=(mode != ModeKeys.PREDICT))
|
||||||
return all_outputs
|
return all_outputs
|
||||||
|
|
||||||
if model.run_eagerly:
|
if not model.run_eagerly:
|
||||||
execution_function = distributed_function
|
|
||||||
else:
|
|
||||||
distributed_function = def_function.function(
|
distributed_function = def_function.function(
|
||||||
distributed_function, autograph=False)
|
distributed_function, autograph=False)
|
||||||
|
|
||||||
def execution_function(input_fn):
|
def execution_function(input_fn):
|
||||||
# `numpy` translates Tensors to values in Eager mode.
|
# `numpy` translates Tensors to values in Eager mode.
|
||||||
return [out.numpy() for out in distributed_function(input_fn)]
|
return [out.numpy() for out in distributed_function(input_fn)]
|
||||||
|
|
||||||
return execution_function
|
return execution_function
|
||||||
|
|
||||||
|
|
||||||
|
def _non_none_constant_value(v):
|
||||||
|
constant_value = tensor_util.constant_value(v)
|
||||||
|
return constant_value if constant_value is not None else v
|
||||||
|
|
||||||
|
|
||||||
def _prepare_feed_values(model, inputs, mode):
|
def _prepare_feed_values(model, inputs, mode):
|
||||||
"""Prepare feed values to the model execution function.
|
"""Prepare feed values to the model execution function.
|
||||||
|
|
||||||
@ -232,8 +235,6 @@ def train_on_batch(
|
|||||||
if reset_metrics:
|
if reset_metrics:
|
||||||
model.reset_metrics()
|
model.reset_metrics()
|
||||||
|
|
||||||
if len(outputs) == 1:
|
|
||||||
return outputs[0]
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@ -295,8 +296,6 @@ def test_on_batch(model, x, y=None, sample_weight=None, reset_metrics=True):
|
|||||||
if reset_metrics:
|
if reset_metrics:
|
||||||
model.reset_metrics()
|
model.reset_metrics()
|
||||||
|
|
||||||
if len(outputs) == 1:
|
|
||||||
return outputs[0]
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user