From 778973620a53a2426127b21dbe9167493a86c5fd Mon Sep 17 00:00:00 2001 From: Tomer Kaftan Date: Wed, 27 May 2020 14:26:11 -0700 Subject: [PATCH] Don't grab `learning_phase` to pass to functional model construction when in v2 & the global learning phase has not been set, because it isn't needed. PiperOrigin-RevId: 313463994 Change-Id: Ice070617d0a82451aaa9f1ea164404263a213c8e --- tensorflow/python/keras/engine/base_layer.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 4a43b0526f6..0eb6954d2a8 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -873,11 +873,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector): # Priority 3a: `learning_phase()` has been set. elif backend.global_learning_phase_is_set(): training_value = backend.learning_phase() - # Priority 3b: Pass the `learning_phase()` if in the Keras FuncGraph. - elif build_graph: - with backend.get_graph().as_default(): - if base_layer_utils.is_in_keras_graph(): - training_value = backend.learning_phase() if self._expects_training_arg and training_value is not None: # Force the training_value to be bool type which matches to the contract @@ -1491,7 +1486,9 @@ class Layer(module.Module, version_utils.LayerVersionSelector): self._metrics.append(metric_obj) else: from tensorflow.python.keras import metrics as metrics_mod # pylint:disable=g-import-not-at-top - metric_obj = metrics_mod.Mean(name=name, dtype=value.dtype) + # Build the metric object with the value's dtype if it defines one + metric_obj = metrics_mod.Mean( + name=name, dtype=getattr(value, 'dtype', None)) self._metrics.append(metric_obj) if should_update_state: