From 86a6d44352a7919d4b2417ca042a016c9fc6c3ca Mon Sep 17 00:00:00 2001 From: Thomas O'Malley Date: Thu, 11 Apr 2019 10:23:54 -0700 Subject: [PATCH] Fix to add_metric to not create duplicate metrics when a Model is called on new inputs. PiperOrigin-RevId: 243094192 --- .../distribute/distribute_strategy_test.py | 2 ++ tensorflow/python/keras/engine/base_layer.py | 20 +++++++++---------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/tensorflow/python/keras/distribute/distribute_strategy_test.py b/tensorflow/python/keras/distribute/distribute_strategy_test.py index 55f310d071f..079adb05dee 100644 --- a/tensorflow/python/keras/distribute/distribute_strategy_test.py +++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py @@ -1465,6 +1465,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase, validation_data=(x, y), validation_steps=2, epochs=2) + self.assertLen(ds_model.metrics, 1) self.assertAllClose(history.history, ds_history.history) @@ -1507,6 +1508,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase, validation_data=(x, y), validation_steps=2, epochs=2) + self.assertLen(ds_model.metrics, 1) self.assertAllClose(history.history, ds_history.history) diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 2e999822696..5eb4927e156 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -1533,24 +1533,20 @@ class Layer(module.Module): self._metrics.append(metric_obj) def _symbolic_add_metric(self, value, aggregation=None, name=None): + match = self._get_existing_metric(name) if aggregation is None: # Iterate over the metrics and check if the given metric exists already. # This can happen when a metric instance is created in subclassed model # layer `__init__` and we have tracked that instance already in # model.__setattr__. - match = self._get_existing_metric(name) if match: result_tensor = value - if match.name not in self._metrics_tensors: - self._metrics_tensors[match.name] = result_tensor - return - else: - raise ValueError( - 'We currently do not support reusing a metric instance.') + metric_obj = match elif hasattr(value, '_metric_obj'): # We track the instance using the metadata on the result tensor. result_tensor = value metric_obj = result_tensor._metric_obj + self._metrics.append(metric_obj) else: raise ValueError( 'We do not support adding an aggregated metric result tensor that ' @@ -1564,9 +1560,13 @@ class Layer(module.Module): else: # If a non-aggregated tensor is given as input (ie. `aggregation` is # explicitly set to `mean`), we wrap the tensor in `Mean` metric. - metric_obj, result_tensor = base_layer_utils.create_mean_metric( - value, name) - self._metrics.append(metric_obj) + if match: + result_tensor = match(value) + metric_obj = match + else: + metric_obj, result_tensor = base_layer_utils.create_mean_metric( + value, name) + self._metrics.append(metric_obj) self._metrics_tensors[metric_obj.name] = result_tensor def _handle_weight_regularization(self, name, variable, regularizer):