Return eager tensors from the training_eager.* methods instead of numpy scalars. This also moves the numpy() conversion to the end of the distribution strategy execution function in the v2 loops.
PiperOrigin-RevId: 259393774
parent 710d3113bf
commit bf7368f7a0
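For context, a minimal standalone sketch of the conversion helper that this commit relocates into training_v2_utils (see the hunks below). The helper mirrors the diff; the demo tensor and the print are made up for illustration only:

import tensorflow as tf
from tensorflow.python.framework import tensor_util


def _non_none_constant_value(v):
  # constant_value() resolves eager/constant tensors to a numpy value and
  # returns None otherwise, in which case the tensor is passed through.
  constant_value = tensor_util.constant_value(v)
  return constant_value if constant_value is not None else v


loss = tf.constant(0.25)                 # stands in for a per-batch loss tensor
print(_non_none_constant_value(loss))    # -> 0.25 as a numpy scalar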
@@ -946,9 +946,14 @@ class Model(network.Network):
       ValueError: In case of invalid user-provided arguments.
     """
     if self._run_distributed:
-      return training_v2_utils.train_on_batch(
+      outputs = training_v2_utils.train_on_batch(
           self, x, y=y, sample_weight=sample_weight,
           class_weight=class_weight, reset_metrics=reset_metrics)
+      outputs = [
+          training_v2_utils._non_none_constant_value(v) for v in outputs]  # pylint: disable=protected-access
+      if len(outputs) == 1:
+        outputs = outputs[0]
+      return outputs
 
     self._assert_compile_was_called()
     # If at this point we are in the replica context, then it is okay to execute
@@ -974,6 +979,8 @@ class Model(network.Network):
           y,
           sample_weights=sample_weights,
           output_loss_metrics=self._output_loss_metrics)
+      outputs = [
+          training_v2_utils._non_none_constant_value(v) for v in outputs]  # pylint: disable=protected-access
     else:
       x = training_utils.ModelInputs(x).as_list()
       ins = x + (y or []) + (sample_weights or [])
@@ -1031,9 +1038,14 @@ class Model(network.Network):
       ValueError: In case of invalid user-provided arguments.
     """
     if self._run_distributed:
-      return training_v2_utils.test_on_batch(
+      outputs = training_v2_utils.test_on_batch(
           self, x, y=y, sample_weight=sample_weight,
           reset_metrics=reset_metrics)
+      outputs = [
+          training_v2_utils._non_none_constant_value(v) for v in outputs]  # pylint: disable=protected-access
+      if len(outputs) == 1:
+        outputs = outputs[0]
+      return outputs
 
     self._assert_compile_was_called()
     if (self._distribution_strategy and
@@ -1053,6 +1065,8 @@ class Model(network.Network):
           y,
           sample_weights=sample_weights,
           output_loss_metrics=self._output_loss_metrics)
+      outputs = [
+          training_v2_utils._non_none_constant_value(v) for v in outputs]  # pylint: disable=protected-access
     else:
       x = training_utils.ModelInputs(x).as_list()
       inputs = x + (y or []) + (sample_weights or [])
@@ -307,12 +307,7 @@ def train_on_batch(model,
   total_loss = nest.flatten(total_loss)
   results = total_loss + output_losses + metrics_results
 
-  return [_non_none_constant_value(v) for v in results]
-
-
-def _non_none_constant_value(v):
-  constant_value = tensor_util.constant_value(v)
-  return constant_value if constant_value is not None else v
+  return results
 
 
 def test_on_batch(model,
@@ -365,4 +360,4 @@ def test_on_batch(model,
   total_loss = nest.flatten(total_loss)
   results = total_loss + output_losses + metrics_results
 
-  return [_non_none_constant_value(v) for v in results]
+  return results
@@ -70,9 +70,7 @@ def _make_execution_function(model, mode):
         strategy, outputs, with_loss_tensor=(mode != ModeKeys.PREDICT))
     return all_outputs
 
-  if model.run_eagerly:
-    execution_function = distributed_function
-  else:
+  if not model.run_eagerly:
     distributed_function = def_function.function(
         distributed_function, autograph=False)
 
@@ -83,6 +81,11 @@ def _make_execution_function(model, mode):
   return execution_function
 
 
+def _non_none_constant_value(v):
+  constant_value = tensor_util.constant_value(v)
+  return constant_value if constant_value is not None else v
+
+
 def _prepare_feed_values(model, inputs, mode):
   """Prepare feed values to the model execution function.
 
@@ -232,8 +235,6 @@ def train_on_batch(
   if reset_metrics:
     model.reset_metrics()
 
-  if len(outputs) == 1:
-    return outputs[0]
   return outputs
@@ -295,8 +296,6 @@ def test_on_batch(model, x, y=None, sample_weight=None, reset_metrics=True):
   if reset_metrics:
     model.reset_metrics()
 
-  if len(outputs) == 1:
-    return outputs[0]
   return outputs
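From the caller's side the public per-batch API is meant to behave as before, since the numpy conversion now happens at the end of the v2 execution path rather than inside training_eager. A hedged, illustrative check with a toy model and data (not part of this commit):

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse')

# train_on_batch is still expected to hand back a plain scalar rather than
# an EagerTensor, because the conversion is applied before returning.
loss = model.train_on_batch(np.ones((2, 4)), np.ones((2, 1)))
print(float(loss))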