Fix issue where metric instances created in subclassed layer are not tracked on creation but only on call.
PiperOrigin-RevId: 311407078 Change-Id: I2cd7ecb675699a56a4b90e5a29ba80ce6ca59cac
This commit is contained in:
parent
eb7af8c54d
commit
2560d6fd31
@ -2585,6 +2585,12 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
# Keep track of metric instance created in subclassed layer.
|
||||
from tensorflow.python.keras import metrics as metrics_module # pylint: disable=g-import-not-at-top
|
||||
for val in nest.flatten(value):
|
||||
if isinstance(val, metrics_module.Metric):
|
||||
self._metrics.append(val)
|
||||
|
||||
# TODO(scottzhu): Need to track Module object as well for weight tracking.
|
||||
# Be careful about metric if it becomes a Module in future.
|
||||
# Append value to self._layers if relevant
|
||||
|
@ -2223,6 +2223,12 @@ class Layer(base_layer.Layer):
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
# Keep track of metric instance created in subclassed layer.
|
||||
from tensorflow.python.keras import metrics as metrics_module # pylint: disable=g-import-not-at-top
|
||||
for val in nest.flatten(value):
|
||||
if isinstance(val, metrics_module.Metric):
|
||||
self._metrics.append(val)
|
||||
|
||||
# TODO(scottzhu): Need to track Module object as well for weight tracking.
|
||||
# Be careful about metric if it becomes a Module in future.
|
||||
# Append value to self._layers if relevant
|
||||
|
@ -334,13 +334,6 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
|
||||
super(Model, self).__setattr__(name, value)
|
||||
|
||||
# Keep track of metric instance created in subclassed model/layer.
|
||||
# We do this so that we can maintain the correct order of metrics by adding
|
||||
# the instance to the `metrics` list as soon as it is created.
|
||||
from tensorflow.python.keras import metrics as metrics_module # pylint: disable=g-import-not-at-top
|
||||
if isinstance(value, metrics_module.Metric):
|
||||
self._metrics.append(value)
|
||||
|
||||
@generic_utils.default
|
||||
def build(self, input_shape):
|
||||
"""Builds the model based on input shapes received.
|
||||
|
@ -2979,6 +2979,8 @@ class TestTrainingWithMetrics(keras_parameterized.TestCase):
|
||||
return self.dense1(x)
|
||||
|
||||
model = TestModel()
|
||||
self.assertListEqual([m.name for m in model.metrics],
|
||||
['metric_1', 'metric_2'])
|
||||
model.compile(
|
||||
loss='mse',
|
||||
optimizer=RMSPropOptimizer(0.01),
|
||||
@ -2998,6 +3000,41 @@ class TestTrainingWithMetrics(keras_parameterized.TestCase):
|
||||
model.train_on_batch(x, y)
|
||||
model.test_on_batch(x, y)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
def test_multiple_add_metric_calls_layer(self):
|
||||
|
||||
class TestLayer(layers_module.Layer):
|
||||
|
||||
def __init__(self):
|
||||
super(TestLayer, self).__init__(name='test_layer')
|
||||
self.dense1 = layers_module.Dense(2, kernel_initializer='ones')
|
||||
self.m1 = metrics_module.Mean(name='m_1')
|
||||
self.m2 = [
|
||||
metrics_module.Mean(name='m_2'),
|
||||
metrics_module.Mean(name='m_3')
|
||||
]
|
||||
self.m3 = {
|
||||
'mean4': metrics_module.Mean(name='m_4'),
|
||||
'mean5': metrics_module.Mean(name='m_5')
|
||||
}
|
||||
|
||||
def call(self, x):
|
||||
self.add_metric(self.m2[0](x))
|
||||
self.add_metric(self.m2[1](x))
|
||||
self.add_metric(self.m1(x))
|
||||
self.add_metric(self.m3['mean4'](x))
|
||||
self.add_metric(self.m3['mean5'](x))
|
||||
self.add_metric(math_ops.reduce_sum(x), name='m_6', aggregation='mean')
|
||||
return self.dense1(x)
|
||||
|
||||
layer = TestLayer()
|
||||
self.assertListEqual([m.name for m in layer.metrics],
|
||||
['m_1', 'm_2', 'm_3', 'm_4', 'm_5'])
|
||||
|
||||
layer(np.ones((10, 10)))
|
||||
self.assertListEqual([m.name for m in layer.metrics],
|
||||
['m_1', 'm_2', 'm_3', 'm_4', 'm_5', 'm_6'])
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
def test_duplicate_metric_name_in_add_metric(self):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user