Fix issue of Metrics being tracked as Layers.

PiperOrigin-RevId: 335486677
Change-Id: Ib1724aaf1490ff60f3dcc20998e8a5ad4df3c367
This commit is contained in:
Thomas O'Malley 2020-10-05 13:26:25 -07:00 committed by TensorFlower Gardener
parent 0931bdf88f
commit d6323766ba
2 changed files with 55 additions and 7 deletions

View File

@ -86,6 +86,12 @@ from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls
# pylint: disable=g-inconsistent-quotes
metrics_mod = generic_utils.LazyLoader(
"metrics_mod", globals(),
"tensorflow.python.keras.metrics")
# pylint: enable=g-inconsistent-quotes
# Prefix that is added to the TF op layer names.
_TF_OP_LAYER_NAME_PREFIX = 'tf_op_layer_'
@ -1607,7 +1613,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
['max', 'min']
Returns:
A list of tensors.
A list of `Metric` objects.
"""
collected_metrics = []
for layer in self._flatten_layers():
@ -1625,11 +1631,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
class MyMetricLayer(tf.keras.layers.Layer):
def __init__(self):
super(MyMetricLayer, self).__init__(name='my_metric_layer')
self.mean = metrics_module.Mean(name='metric_1')
self.mean = tf.keras.metrics.Mean(name='metric_1')
def call(self, inputs):
self.add_metric(self.mean(x))
self.add_metric(math_ops.reduce_sum(x), name='metric_2')
self.add_metric(tf.reduce_sum(x), name='metric_2')
return inputs
```
@ -1721,7 +1727,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
elif metric_obj:
self._metrics.append(metric_obj)
else:
from tensorflow.python.keras import metrics as metrics_mod # pylint:disable=g-import-not-at-top
# Build the metric object with the value's dtype if it defines one
metric_obj = metrics_mod.Mean(
name=name, dtype=getattr(value, 'dtype', None))
@ -2803,9 +2808,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
pass
# Keep track of metric instance created in subclassed layer.
from tensorflow.python.keras import metrics as metrics_module # pylint: disable=g-import-not-at-top
for val in nest.flatten(value):
if isinstance(val, metrics_module.Metric) and hasattr(self, '_metrics'):
if isinstance(val, metrics_mod.Metric) and hasattr(self, '_metrics'):
self._metrics.append(val)
# TODO(scottzhu): Need to track Module object as well for weight tracking.
@ -2882,7 +2886,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
continue
seen_object_ids.add(layer_or_container_id)
if isinstance(layer_or_container, Layer):
if (isinstance(layer_or_container, Layer) and
not isinstance(layer_or_container, metrics_mod.Metric)):
yield layer_or_container
# Introspect recursively through sublayers.
if recursive:

View File

@ -37,6 +37,8 @@ from tensorflow.python.keras import layers
from tensorflow.python.keras import metrics
from tensorflow.python.keras import Model
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import training as training_mod
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
@ -2059,6 +2061,47 @@ class CustomMetricsTest(test.TestCase):
metric_result = tf_functioned_metric_fn(sum_metric, y_true, y_pred)
self.assertAllClose(self.evaluate(metric_result), 10, 1e-2)
def test_metric_not_tracked_as_sublayer_in_layer(self):
class MyLayer(base_layer.Layer):
def __init__(self, **kwargs):
super(MyLayer, self).__init__(**kwargs)
self.mean_obj = metrics.Mean(name='my_mean_obj')
def call(self, x):
self.add_metric(
math_ops.reduce_sum(x), aggregation='mean', name='my_mean_tensor')
self.add_metric(self.mean_obj(x))
return x
layer = MyLayer()
x = np.ones((1, 1))
layer(x)
self.assertLen(list(layer._flatten_layers(include_self=False)), 0)
self.assertLen(layer.metrics, 2)
def test_metric_not_tracked_as_sublayer_in_model(self):
class MyModel(training_mod.Model):
def __init__(self, **kwargs):
super(MyModel, self).__init__(**kwargs)
self.mean_obj = metrics.Mean(name='my_mean_obj')
def call(self, x):
self.add_metric(
math_ops.reduce_sum(x), aggregation='mean', name='my_mean_tensor')
self.add_metric(self.mean_obj(x))
return x
model = MyModel()
x = np.ones((1, 1))
model(x)
self.assertLen(list(model._flatten_layers(include_self=False)), 0)
self.assertLen(model.layers, 0)
self.assertLen(model.metrics, 2)
def _get_model(compile_metrics):
model_layers = [