diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 4a43b0526f6..0eb6954d2a8 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -873,11 +873,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector): # Priority 3a: `learning_phase()` has been set. elif backend.global_learning_phase_is_set(): training_value = backend.learning_phase() - # Priority 3b: Pass the `learning_phase()` if in the Keras FuncGraph. - elif build_graph: - with backend.get_graph().as_default(): - if base_layer_utils.is_in_keras_graph(): - training_value = backend.learning_phase() if self._expects_training_arg and training_value is not None: # Force the training_value to be bool type which matches to the contract @@ -1491,7 +1486,9 @@ class Layer(module.Module, version_utils.LayerVersionSelector): self._metrics.append(metric_obj) else: from tensorflow.python.keras import metrics as metrics_mod # pylint:disable=g-import-not-at-top - metric_obj = metrics_mod.Mean(name=name, dtype=value.dtype) + # Build the metric object with the value's dtype if it defines one + metric_obj = metrics_mod.Mean( + name=name, dtype=getattr(value, 'dtype', None)) self._metrics.append(metric_obj) if should_update_state: