Fix to add_metric to not create duplicate metrics when a Model is called on new inputs.

PiperOrigin-RevId: 243094192
This commit is contained in:
Thomas O'Malley 2019-04-11 10:23:54 -07:00 committed by TensorFlower Gardener
parent 72194852a2
commit 86a6d44352
2 changed files with 12 additions and 10 deletions

View File

@ -1465,6 +1465,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
validation_data=(x, y),
validation_steps=2,
epochs=2)
self.assertLen(ds_model.metrics, 1)
self.assertAllClose(history.history, ds_history.history)
@ -1507,6 +1508,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
validation_data=(x, y),
validation_steps=2,
epochs=2)
self.assertLen(ds_model.metrics, 1)
self.assertAllClose(history.history, ds_history.history)

View File

@ -1533,24 +1533,20 @@ class Layer(module.Module):
self._metrics.append(metric_obj)
def _symbolic_add_metric(self, value, aggregation=None, name=None):
match = self._get_existing_metric(name)
if aggregation is None:
# Iterate over the metrics and check if the given metric exists already.
# This can happen when a metric instance is created in subclassed model
# layer `__init__` and we have tracked that instance already in
# model.__setattr__.
match = self._get_existing_metric(name)
if match:
result_tensor = value
if match.name not in self._metrics_tensors:
self._metrics_tensors[match.name] = result_tensor
return
else:
raise ValueError(
'We currently do not support reusing a metric instance.')
metric_obj = match
elif hasattr(value, '_metric_obj'):
# We track the instance using the metadata on the result tensor.
result_tensor = value
metric_obj = result_tensor._metric_obj
self._metrics.append(metric_obj)
else:
raise ValueError(
'We do not support adding an aggregated metric result tensor that '
@ -1564,6 +1560,10 @@ class Layer(module.Module):
else:
# If a non-aggregated tensor is given as input (ie. `aggregation` is
# explicitly set to `mean`), we wrap the tensor in `Mean` metric.
if match:
result_tensor = match(value)
metric_obj = match
else:
metric_obj, result_tensor = base_layer_utils.create_mean_metric(
value, name)
self._metrics.append(metric_obj)