diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index da6099aa186..13bf1fc121d 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -86,6 +86,12 @@ from tensorflow.python.util import object_identity from tensorflow.python.util.tf_export import keras_export from tensorflow.tools.docs import doc_controls +# pylint: disable=g-inconsistent-quotes +metrics_mod = generic_utils.LazyLoader( + "metrics_mod", globals(), + "tensorflow.python.keras.metrics") +# pylint: enable=g-inconsistent-quotes + # Prefix that is added to the TF op layer names. _TF_OP_LAYER_NAME_PREFIX = 'tf_op_layer_' @@ -1607,7 +1613,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector): ['max', 'min'] Returns: - A list of tensors. + A list of `Metric` objects. """ collected_metrics = [] for layer in self._flatten_layers(): @@ -1625,11 +1631,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector): class MyMetricLayer(tf.keras.layers.Layer): def __init__(self): super(MyMetricLayer, self).__init__(name='my_metric_layer') - self.mean = metrics_module.Mean(name='metric_1') + self.mean = tf.keras.metrics.Mean(name='metric_1') def call(self, inputs): self.add_metric(self.mean(x)) - self.add_metric(math_ops.reduce_sum(x), name='metric_2') + self.add_metric(tf.reduce_sum(x), name='metric_2') return inputs ``` @@ -1721,7 +1727,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector): elif metric_obj: self._metrics.append(metric_obj) else: - from tensorflow.python.keras import metrics as metrics_mod # pylint:disable=g-import-not-at-top # Build the metric object with the value's dtype if it defines one metric_obj = metrics_mod.Mean( name=name, dtype=getattr(value, 'dtype', None)) @@ -2803,9 +2808,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector): pass # Keep track of metric instance created in subclassed layer. - from tensorflow.python.keras import metrics as metrics_module # pylint: disable=g-import-not-at-top for val in nest.flatten(value): - if isinstance(val, metrics_module.Metric) and hasattr(self, '_metrics'): + if isinstance(val, metrics_mod.Metric) and hasattr(self, '_metrics'): self._metrics.append(val) # TODO(scottzhu): Need to track Module object as well for weight tracking. @@ -2882,7 +2886,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector): continue seen_object_ids.add(layer_or_container_id) - if isinstance(layer_or_container, Layer): + if (isinstance(layer_or_container, Layer) and + not isinstance(layer_or_container, metrics_mod.Metric)): yield layer_or_container # Introspect recursively through sublayers. if recursive: diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py index b297063e0d3..1ce86e0f355 100644 --- a/tensorflow/python/keras/metrics_test.py +++ b/tensorflow/python/keras/metrics_test.py @@ -37,6 +37,8 @@ from tensorflow.python.keras import layers from tensorflow.python.keras import metrics from tensorflow.python.keras import Model from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.engine import base_layer +from tensorflow.python.keras.engine import training as training_mod from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables @@ -2059,6 +2061,47 @@ class CustomMetricsTest(test.TestCase): metric_result = tf_functioned_metric_fn(sum_metric, y_true, y_pred) self.assertAllClose(self.evaluate(metric_result), 10, 1e-2) + def test_metric_not_tracked_as_sublayer_in_layer(self): + + class MyLayer(base_layer.Layer): + + def __init__(self, **kwargs): + super(MyLayer, self).__init__(**kwargs) + self.mean_obj = metrics.Mean(name='my_mean_obj') + + def call(self, x): + self.add_metric( + math_ops.reduce_sum(x), aggregation='mean', name='my_mean_tensor') + self.add_metric(self.mean_obj(x)) + return x + + layer = MyLayer() + x = np.ones((1, 1)) + layer(x) + self.assertLen(list(layer._flatten_layers(include_self=False)), 0) + self.assertLen(layer.metrics, 2) + + def test_metric_not_tracked_as_sublayer_in_model(self): + + class MyModel(training_mod.Model): + + def __init__(self, **kwargs): + super(MyModel, self).__init__(**kwargs) + self.mean_obj = metrics.Mean(name='my_mean_obj') + + def call(self, x): + self.add_metric( + math_ops.reduce_sum(x), aggregation='mean', name='my_mean_tensor') + self.add_metric(self.mean_obj(x)) + return x + + model = MyModel() + x = np.ones((1, 1)) + model(x) + self.assertLen(list(model._flatten_layers(include_self=False)), 0) + self.assertLen(model.layers, 0) + self.assertLen(model.metrics, 2) + def _get_model(compile_metrics): model_layers = [