Fix to add_metric to not create duplicate metrics when a Model is called on new inputs.
PiperOrigin-RevId: 243094192
This commit is contained in:
parent
72194852a2
commit
86a6d44352
@ -1465,6 +1465,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
|
||||
validation_data=(x, y),
|
||||
validation_steps=2,
|
||||
epochs=2)
|
||||
self.assertLen(ds_model.metrics, 1)
|
||||
|
||||
self.assertAllClose(history.history, ds_history.history)
|
||||
|
||||
@ -1507,6 +1508,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase,
|
||||
validation_data=(x, y),
|
||||
validation_steps=2,
|
||||
epochs=2)
|
||||
self.assertLen(ds_model.metrics, 1)
|
||||
|
||||
self.assertAllClose(history.history, ds_history.history)
|
||||
|
||||
|
@ -1533,24 +1533,20 @@ class Layer(module.Module):
|
||||
self._metrics.append(metric_obj)
|
||||
|
||||
def _symbolic_add_metric(self, value, aggregation=None, name=None):
|
||||
match = self._get_existing_metric(name)
|
||||
if aggregation is None:
|
||||
# Iterate over the metrics and check if the given metric exists already.
|
||||
# This can happen when a metric instance is created in subclassed model
|
||||
# layer `__init__` and we have tracked that instance already in
|
||||
# model.__setattr__.
|
||||
match = self._get_existing_metric(name)
|
||||
if match:
|
||||
result_tensor = value
|
||||
if match.name not in self._metrics_tensors:
|
||||
self._metrics_tensors[match.name] = result_tensor
|
||||
return
|
||||
else:
|
||||
raise ValueError(
|
||||
'We currently do not support reusing a metric instance.')
|
||||
metric_obj = match
|
||||
elif hasattr(value, '_metric_obj'):
|
||||
# We track the instance using the metadata on the result tensor.
|
||||
result_tensor = value
|
||||
metric_obj = result_tensor._metric_obj
|
||||
self._metrics.append(metric_obj)
|
||||
else:
|
||||
raise ValueError(
|
||||
'We do not support adding an aggregated metric result tensor that '
|
||||
@ -1564,9 +1560,13 @@ class Layer(module.Module):
|
||||
else:
|
||||
# If a non-aggregated tensor is given as input (ie. `aggregation` is
|
||||
# explicitly set to `mean`), we wrap the tensor in `Mean` metric.
|
||||
metric_obj, result_tensor = base_layer_utils.create_mean_metric(
|
||||
value, name)
|
||||
self._metrics.append(metric_obj)
|
||||
if match:
|
||||
result_tensor = match(value)
|
||||
metric_obj = match
|
||||
else:
|
||||
metric_obj, result_tensor = base_layer_utils.create_mean_metric(
|
||||
value, name)
|
||||
self._metrics.append(metric_obj)
|
||||
self._metrics_tensors[metric_obj.name] = result_tensor
|
||||
|
||||
def _handle_weight_regularization(self, name, variable, regularizer):
|
||||
|
Loading…
Reference in New Issue
Block a user