Add MetricsContainer.weighted_metrics and MetricsContainer.unweighted_metrics

properties to help distinguish between metrics that should and shouldn't be
passed sample_weight argument.
Note these properties are set to None before Model.fit is called, since metrics
are potentially broadcast to match the structure of data seen in Model.fit.

PiperOrigin-RevId: 339892649
Change-Id: I0abffae08efde2b8adc58014ef205d318d66a9ab
This commit is contained in:
Thomas O'Malley 2020-10-30 10:12:08 -07:00 committed by TensorFlower Gardener
parent 6f3ef0032a
commit cb043911fe
2 changed files with 27 additions and 1 deletions
tensorflow/python/keras/engine

View File

@ -292,11 +292,25 @@ class MetricsContainer(Container):
@property
def metrics(self):
"""Metrics created by this container."""
"""All metrics in this container."""
if not self._built:
return []
return self._metrics_in_order
@property
def unweighted_metrics(self):
"""Metrics in this container that should not be passed `sample_weight`."""
if not self._built:
return None
return nest.flatten(self._metrics)
@property
def weighted_metrics(self):
"""Metrics in this container that should be passed `sample_weight`."""
if not self._built:
return None
return nest.flatten(self._weighted_metrics)
def build(self, y_pred, y_true):
"""One-time setup of metric objects."""
super(MetricsContainer, self).build(y_pred)

View File

@ -420,6 +420,18 @@ class MetricsContainerTest(keras_parameterized.TestCase):
self.assertEqual(acc_metric_2.result().numpy(), 0.)
self.assertEqual(acc_metric_2._fn, metrics_mod.binary_accuracy)
weighted_metrics = metric_container.weighted_metrics
self.assertLen(weighted_metrics, 2)
self.assertEqual(weighted_metrics[0].name, 'output_1_accuracy')
self.assertEqual(weighted_metrics[1].name, 'output_2_accuracy')
unweighted_metrics = metric_container.unweighted_metrics
self.assertLen(unweighted_metrics, 4)
self.assertEqual(unweighted_metrics[0].name, 'output_1_mse')
self.assertEqual(unweighted_metrics[1].name, 'output_1_mae')
self.assertEqual(unweighted_metrics[2].name, 'output_2_mse')
self.assertEqual(unweighted_metrics[3].name, 'output_2_mae')
def test_metric_dict(self):
metric_container = compile_utils.MetricsContainer(
metrics={