Don't grab `learning_phase()` to pass to functional model construction when in v2 and the global learning phase has not been set, because it isn't needed.

PiperOrigin-RevId: 313463994
Change-Id: Ice070617d0a82451aaa9f1ea164404263a213c8e
Tomer Kaftan 2020-05-27 14:26:11 -07:00 committed by TensorFlower Gardener
parent 59cebfb711
commit 778973620a
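For context, here is a minimal sketch (not part of this commit) of the case the message describes: building a functional model under TF2 without ever calling tf.keras.backend.set_learning_phase(). With this change, layers leave `training` unset during construction instead of grabbing backend.learning_phase(), and the value is resolved when the model is actually called.

    import tensorflow as tf

    # Functional model construction in v2; the global learning phase is never set.
    inputs = tf.keras.Input(shape=(4,))
    x = tf.keras.layers.Dense(8, activation='relu')(inputs)
    outputs = tf.keras.layers.Dropout(0.5)(x)  # a training-dependent layer
    model = tf.keras.Model(inputs, outputs)

    # `training` is resolved at call time, not at construction time.
    model(tf.ones((2, 4)))                  # inference: dropout disabled
    model(tf.ones((2, 4)), training=True)   # training: dropout active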

@@ -873,11 +873,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
       # Priority 3a: `learning_phase()` has been set.
       elif backend.global_learning_phase_is_set():
         training_value = backend.learning_phase()
-      # Priority 3b: Pass the `learning_phase()` if in the Keras FuncGraph.
-      elif build_graph:
-        with backend.get_graph().as_default():
-          if base_layer_utils.is_in_keras_graph():
-            training_value = backend.learning_phase()
 
       if self._expects_training_arg and training_value is not None:
         # Force the training_value to be bool type which matches to the contract
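Read together, the surviving branches form a simple fall-through. Below is a standalone sketch (a hypothetical helper, not TensorFlow source) of the resolution order after this hunk; the earlier priorities (an explicitly passed `training` argument and the enclosing call context) are not shown in the diff and are paraphrased here, and the parameter names are made up for illustration.

    def resolve_training_value(explicit_training, call_ctx_training,
                               global_phase_is_set, global_phase):
      # Priority 1: `training` was passed explicitly to __call__.
      if explicit_training is not None:
        return explicit_training
      # Priority 2: the enclosing layer's call context supplies a value.
      if call_ctx_training is not None:
        return call_ctx_training
      # Priority 3a: the global learning phase, only if it was explicitly set.
      if global_phase_is_set:
        return global_phase
      # Priority 3b (removed above): the Keras FuncGraph's learning_phase()
      # used to be grabbed here during functional model construction; now the
      # value simply stays None and is resolved later, at call time.
      return None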
@@ -1491,7 +1486,9 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
         self._metrics.append(metric_obj)
       else:
         from tensorflow.python.keras import metrics as metrics_mod  # pylint:disable=g-import-not-at-top
-        metric_obj = metrics_mod.Mean(name=name, dtype=value.dtype)
+        # Build the metric object with the value's dtype if it defines one
+        metric_obj = metrics_mod.Mean(
+            name=name, dtype=getattr(value, 'dtype', None))
         self._metrics.append(metric_obj)
 
     if should_update_state:
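The second hunk replaces the direct `value.dtype` access with `getattr(value, 'dtype', None)`, so a value that does not define a `dtype` attribute no longer breaks construction of the implicit Mean metric. A minimal sketch of the effect, using the public tf.keras.metrics.Mean in place of the internal metrics_mod import and a plain Python float as the assumed dtype-less value:

    import tensorflow as tf

    value = 0.5  # a plain float has no `.dtype`; `value.dtype` would raise
    # AttributeError, while getattr falls back to None and Mean then uses its
    # default dtype (float32).
    metric_obj = tf.keras.metrics.Mean(name='my_metric',
                                       dtype=getattr(value, 'dtype', None))
    metric_obj.update_state(value)
    print(float(metric_obj.result()))  # 0.5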