diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 94b696d842b..0f4bec92e39 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -2588,7 +2588,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector): # Keep track of metric instance created in subclassed layer. from tensorflow.python.keras import metrics as metrics_module # pylint: disable=g-import-not-at-top for val in nest.flatten(value): - if isinstance(val, metrics_module.Metric): + if isinstance(val, metrics_module.Metric) and hasattr(self, '_metrics'): self._metrics.append(val) # TODO(scottzhu): Need to track Module object as well for weight tracking. diff --git a/tensorflow/python/keras/engine/base_layer_v1.py b/tensorflow/python/keras/engine/base_layer_v1.py index 4a277ec3a3e..80e0b4be2f1 100644 --- a/tensorflow/python/keras/engine/base_layer_v1.py +++ b/tensorflow/python/keras/engine/base_layer_v1.py @@ -2226,7 +2226,7 @@ class Layer(base_layer.Layer): # Keep track of metric instance created in subclassed layer. from tensorflow.python.keras import metrics as metrics_module # pylint: disable=g-import-not-at-top for val in nest.flatten(value): - if isinstance(val, metrics_module.Metric): + if isinstance(val, metrics_module.Metric) and hasattr(self, '_metrics'): self._metrics.append(val) # TODO(scottzhu): Need to track Module object as well for weight tracking.