Fix issue where metric instances created in subclassed layer are not tracked on creation but only on call.

PiperOrigin-RevId: 311407078
Change-Id: I2cd7ecb675699a56a4b90e5a29ba80ce6ca59cac
This commit is contained in:
Pavithra Vijay 2020-05-13 14:39:12 -07:00 committed by TensorFlower Gardener
parent eb7af8c54d
commit 2560d6fd31
4 changed files with 49 additions and 7 deletions

View File

@ -2585,6 +2585,12 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
except AttributeError:
pass
# Keep track of metric instance created in subclassed layer.
from tensorflow.python.keras import metrics as metrics_module # pylint: disable=g-import-not-at-top
for val in nest.flatten(value):
if isinstance(val, metrics_module.Metric):
self._metrics.append(val)
# TODO(scottzhu): Need to track Module object as well for weight tracking.
# Be careful about metric if it becomes a Module in future.
# Append value to self._layers if relevant

View File

@ -2223,6 +2223,12 @@ class Layer(base_layer.Layer):
except AttributeError:
pass
# Keep track of metric instance created in subclassed layer.
from tensorflow.python.keras import metrics as metrics_module # pylint: disable=g-import-not-at-top
for val in nest.flatten(value):
if isinstance(val, metrics_module.Metric):
self._metrics.append(val)
# TODO(scottzhu): Need to track Module object as well for weight tracking.
# Be careful about metric if it becomes a Module in future.
# Append value to self._layers if relevant

View File

@ -334,13 +334,6 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
super(Model, self).__setattr__(name, value)
# Keep track of metric instance created in subclassed model/layer.
# We do this so that we can maintain the correct order of metrics by adding
# the instance to the `metrics` list as soon as it is created.
from tensorflow.python.keras import metrics as metrics_module # pylint: disable=g-import-not-at-top
if isinstance(value, metrics_module.Metric):
self._metrics.append(value)
@generic_utils.default
def build(self, input_shape):
"""Builds the model based on input shapes received.

View File

@ -2979,6 +2979,8 @@ class TestTrainingWithMetrics(keras_parameterized.TestCase):
return self.dense1(x)
model = TestModel()
self.assertListEqual([m.name for m in model.metrics],
['metric_1', 'metric_2'])
model.compile(
loss='mse',
optimizer=RMSPropOptimizer(0.01),
@ -2998,6 +3000,41 @@ class TestTrainingWithMetrics(keras_parameterized.TestCase):
model.train_on_batch(x, y)
model.test_on_batch(x, y)
@keras_parameterized.run_all_keras_modes
def test_multiple_add_metric_calls_layer(self):
class TestLayer(layers_module.Layer):
def __init__(self):
super(TestLayer, self).__init__(name='test_layer')
self.dense1 = layers_module.Dense(2, kernel_initializer='ones')
self.m1 = metrics_module.Mean(name='m_1')
self.m2 = [
metrics_module.Mean(name='m_2'),
metrics_module.Mean(name='m_3')
]
self.m3 = {
'mean4': metrics_module.Mean(name='m_4'),
'mean5': metrics_module.Mean(name='m_5')
}
def call(self, x):
self.add_metric(self.m2[0](x))
self.add_metric(self.m2[1](x))
self.add_metric(self.m1(x))
self.add_metric(self.m3['mean4'](x))
self.add_metric(self.m3['mean5'](x))
self.add_metric(math_ops.reduce_sum(x), name='m_6', aggregation='mean')
return self.dense1(x)
layer = TestLayer()
self.assertListEqual([m.name for m in layer.metrics],
['m_1', 'm_2', 'm_3', 'm_4', 'm_5'])
layer(np.ones((10, 10)))
self.assertListEqual([m.name for m in layer.metrics],
['m_1', 'm_2', 'm_3', 'm_4', 'm_5', 'm_6'])
@keras_parameterized.run_all_keras_modes
def test_duplicate_metric_name_in_add_metric(self):