Fix issue of Metrics being tracked as Layers.
PiperOrigin-RevId: 335486677 Change-Id: Ib1724aaf1490ff60f3dcc20998e8a5ad4df3c367
This commit is contained in:
parent
0931bdf88f
commit
d6323766ba
@ -86,6 +86,12 @@ from tensorflow.python.util import object_identity
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
from tensorflow.tools.docs import doc_controls
|
||||
|
||||
# pylint: disable=g-inconsistent-quotes
|
||||
metrics_mod = generic_utils.LazyLoader(
|
||||
"metrics_mod", globals(),
|
||||
"tensorflow.python.keras.metrics")
|
||||
# pylint: enable=g-inconsistent-quotes
|
||||
|
||||
# Prefix that is added to the TF op layer names.
|
||||
_TF_OP_LAYER_NAME_PREFIX = 'tf_op_layer_'
|
||||
|
||||
@ -1607,7 +1613,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
['max', 'min']
|
||||
|
||||
Returns:
|
||||
A list of tensors.
|
||||
A list of `Metric` objects.
|
||||
"""
|
||||
collected_metrics = []
|
||||
for layer in self._flatten_layers():
|
||||
@ -1625,11 +1631,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
class MyMetricLayer(tf.keras.layers.Layer):
|
||||
def __init__(self):
|
||||
super(MyMetricLayer, self).__init__(name='my_metric_layer')
|
||||
self.mean = metrics_module.Mean(name='metric_1')
|
||||
self.mean = tf.keras.metrics.Mean(name='metric_1')
|
||||
|
||||
def call(self, inputs):
|
||||
self.add_metric(self.mean(x))
|
||||
self.add_metric(math_ops.reduce_sum(x), name='metric_2')
|
||||
self.add_metric(tf.reduce_sum(x), name='metric_2')
|
||||
return inputs
|
||||
```
|
||||
|
||||
@ -1721,7 +1727,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
elif metric_obj:
|
||||
self._metrics.append(metric_obj)
|
||||
else:
|
||||
from tensorflow.python.keras import metrics as metrics_mod # pylint:disable=g-import-not-at-top
|
||||
# Build the metric object with the value's dtype if it defines one
|
||||
metric_obj = metrics_mod.Mean(
|
||||
name=name, dtype=getattr(value, 'dtype', None))
|
||||
@ -2803,9 +2808,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
pass
|
||||
|
||||
# Keep track of metric instance created in subclassed layer.
|
||||
from tensorflow.python.keras import metrics as metrics_module # pylint: disable=g-import-not-at-top
|
||||
for val in nest.flatten(value):
|
||||
if isinstance(val, metrics_module.Metric) and hasattr(self, '_metrics'):
|
||||
if isinstance(val, metrics_mod.Metric) and hasattr(self, '_metrics'):
|
||||
self._metrics.append(val)
|
||||
|
||||
# TODO(scottzhu): Need to track Module object as well for weight tracking.
|
||||
@ -2882,7 +2886,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
continue
|
||||
seen_object_ids.add(layer_or_container_id)
|
||||
|
||||
if isinstance(layer_or_container, Layer):
|
||||
if (isinstance(layer_or_container, Layer) and
|
||||
not isinstance(layer_or_container, metrics_mod.Metric)):
|
||||
yield layer_or_container
|
||||
# Introspect recursively through sublayers.
|
||||
if recursive:
|
||||
|
@ -37,6 +37,8 @@ from tensorflow.python.keras import layers
|
||||
from tensorflow.python.keras import metrics
|
||||
from tensorflow.python.keras import Model
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.keras.engine import base_layer
|
||||
from tensorflow.python.keras.engine import training as training_mod
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
@ -2059,6 +2061,47 @@ class CustomMetricsTest(test.TestCase):
|
||||
metric_result = tf_functioned_metric_fn(sum_metric, y_true, y_pred)
|
||||
self.assertAllClose(self.evaluate(metric_result), 10, 1e-2)
|
||||
|
||||
def test_metric_not_tracked_as_sublayer_in_layer(self):
|
||||
|
||||
class MyLayer(base_layer.Layer):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(MyLayer, self).__init__(**kwargs)
|
||||
self.mean_obj = metrics.Mean(name='my_mean_obj')
|
||||
|
||||
def call(self, x):
|
||||
self.add_metric(
|
||||
math_ops.reduce_sum(x), aggregation='mean', name='my_mean_tensor')
|
||||
self.add_metric(self.mean_obj(x))
|
||||
return x
|
||||
|
||||
layer = MyLayer()
|
||||
x = np.ones((1, 1))
|
||||
layer(x)
|
||||
self.assertLen(list(layer._flatten_layers(include_self=False)), 0)
|
||||
self.assertLen(layer.metrics, 2)
|
||||
|
||||
def test_metric_not_tracked_as_sublayer_in_model(self):
|
||||
|
||||
class MyModel(training_mod.Model):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(MyModel, self).__init__(**kwargs)
|
||||
self.mean_obj = metrics.Mean(name='my_mean_obj')
|
||||
|
||||
def call(self, x):
|
||||
self.add_metric(
|
||||
math_ops.reduce_sum(x), aggregation='mean', name='my_mean_tensor')
|
||||
self.add_metric(self.mean_obj(x))
|
||||
return x
|
||||
|
||||
model = MyModel()
|
||||
x = np.ones((1, 1))
|
||||
model(x)
|
||||
self.assertLen(list(model._flatten_layers(include_self=False)), 0)
|
||||
self.assertLen(model.layers, 0)
|
||||
self.assertLen(model.metrics, 2)
|
||||
|
||||
|
||||
def _get_model(compile_metrics):
|
||||
model_layers = [
|
||||
|
Loading…
Reference in New Issue
Block a user