From 2560d6fd31b20e81a5a98a73f325fb1dcf0c68a7 Mon Sep 17 00:00:00 2001 From: Pavithra Vijay Date: Wed, 13 May 2020 14:39:12 -0700 Subject: [PATCH] Fix issue where metric instances created in subclassed layer are not tracked on creation but only on call. PiperOrigin-RevId: 311407078 Change-Id: I2cd7ecb675699a56a4b90e5a29ba80ce6ca59cac --- tensorflow/python/keras/engine/base_layer.py | 6 +++ .../python/keras/engine/base_layer_v1.py | 6 +++ tensorflow/python/keras/engine/training.py | 7 ---- .../python/keras/engine/training_test.py | 37 +++++++++++++++++++ 4 files changed, 49 insertions(+), 7 deletions(-) diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index f6fa17df5c2..94b696d842b 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -2585,6 +2585,12 @@ class Layer(module.Module, version_utils.LayerVersionSelector): except AttributeError: pass + # Keep track of metric instance created in subclassed layer. + from tensorflow.python.keras import metrics as metrics_module # pylint: disable=g-import-not-at-top + for val in nest.flatten(value): + if isinstance(val, metrics_module.Metric): + self._metrics.append(val) + # TODO(scottzhu): Need to track Module object as well for weight tracking. # Be careful about metric if it becomes a Module in future. # Append value to self._layers if relevant diff --git a/tensorflow/python/keras/engine/base_layer_v1.py b/tensorflow/python/keras/engine/base_layer_v1.py index 24d12ae4d59..4a277ec3a3e 100644 --- a/tensorflow/python/keras/engine/base_layer_v1.py +++ b/tensorflow/python/keras/engine/base_layer_v1.py @@ -2223,6 +2223,12 @@ class Layer(base_layer.Layer): except AttributeError: pass + # Keep track of metric instance created in subclassed layer. + from tensorflow.python.keras import metrics as metrics_module # pylint: disable=g-import-not-at-top + for val in nest.flatten(value): + if isinstance(val, metrics_module.Metric): + self._metrics.append(val) + # TODO(scottzhu): Need to track Module object as well for weight tracking. # Be careful about metric if it becomes a Module in future. # Append value to self._layers if relevant diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 52bf42a099d..d8c95b2a972 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -334,13 +334,6 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): super(Model, self).__setattr__(name, value) - # Keep track of metric instance created in subclassed model/layer. - # We do this so that we can maintain the correct order of metrics by adding - # the instance to the `metrics` list as soon as it is created. - from tensorflow.python.keras import metrics as metrics_module # pylint: disable=g-import-not-at-top - if isinstance(value, metrics_module.Metric): - self._metrics.append(value) - @generic_utils.default def build(self, input_shape): """Builds the model based on input shapes received. diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index e4c1ff6b1f8..c1c498b207b 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -2979,6 +2979,8 @@ class TestTrainingWithMetrics(keras_parameterized.TestCase): return self.dense1(x) model = TestModel() + self.assertListEqual([m.name for m in model.metrics], + ['metric_1', 'metric_2']) model.compile( loss='mse', optimizer=RMSPropOptimizer(0.01), @@ -2998,6 +3000,41 @@ class TestTrainingWithMetrics(keras_parameterized.TestCase): model.train_on_batch(x, y) model.test_on_batch(x, y) + @keras_parameterized.run_all_keras_modes + def test_multiple_add_metric_calls_layer(self): + + class TestLayer(layers_module.Layer): + + def __init__(self): + super(TestLayer, self).__init__(name='test_layer') + self.dense1 = layers_module.Dense(2, kernel_initializer='ones') + self.m1 = metrics_module.Mean(name='m_1') + self.m2 = [ + metrics_module.Mean(name='m_2'), + metrics_module.Mean(name='m_3') + ] + self.m3 = { + 'mean4': metrics_module.Mean(name='m_4'), + 'mean5': metrics_module.Mean(name='m_5') + } + + def call(self, x): + self.add_metric(self.m2[0](x)) + self.add_metric(self.m2[1](x)) + self.add_metric(self.m1(x)) + self.add_metric(self.m3['mean4'](x)) + self.add_metric(self.m3['mean5'](x)) + self.add_metric(math_ops.reduce_sum(x), name='m_6', aggregation='mean') + return self.dense1(x) + + layer = TestLayer() + self.assertListEqual([m.name for m in layer.metrics], + ['m_1', 'm_2', 'm_3', 'm_4', 'm_5']) + + layer(np.ones((10, 10))) + self.assertListEqual([m.name for m in layer.metrics], + ['m_1', 'm_2', 'm_3', 'm_4', 'm_5', 'm_6']) + @keras_parameterized.run_all_keras_modes def test_duplicate_metric_name_in_add_metric(self):