Add MetricsContainer.weighted_metrics and MetricsContainer.unweighted_metrics
properties to help distinguish between metrics that should and shouldn't be passed sample_weight argument. Note these properties are set to None before Model.fit is called, since metrics are potentially broadcast to match the structure of data seen in Model.fit. PiperOrigin-RevId: 339892649 Change-Id: I0abffae08efde2b8adc58014ef205d318d66a9ab
This commit is contained in:
parent
6f3ef0032a
commit
cb043911fe
tensorflow/python/keras/engine
@ -292,11 +292,25 @@ class MetricsContainer(Container):
|
||||
|
||||
@property
|
||||
def metrics(self):
|
||||
"""Metrics created by this container."""
|
||||
"""All metrics in this container."""
|
||||
if not self._built:
|
||||
return []
|
||||
return self._metrics_in_order
|
||||
|
||||
@property
|
||||
def unweighted_metrics(self):
|
||||
"""Metrics in this container that should not be passed `sample_weight`."""
|
||||
if not self._built:
|
||||
return None
|
||||
return nest.flatten(self._metrics)
|
||||
|
||||
@property
|
||||
def weighted_metrics(self):
|
||||
"""Metrics in this container that should be passed `sample_weight`."""
|
||||
if not self._built:
|
||||
return None
|
||||
return nest.flatten(self._weighted_metrics)
|
||||
|
||||
def build(self, y_pred, y_true):
|
||||
"""One-time setup of metric objects."""
|
||||
super(MetricsContainer, self).build(y_pred)
|
||||
|
@ -420,6 +420,18 @@ class MetricsContainerTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(acc_metric_2.result().numpy(), 0.)
|
||||
self.assertEqual(acc_metric_2._fn, metrics_mod.binary_accuracy)
|
||||
|
||||
weighted_metrics = metric_container.weighted_metrics
|
||||
self.assertLen(weighted_metrics, 2)
|
||||
self.assertEqual(weighted_metrics[0].name, 'output_1_accuracy')
|
||||
self.assertEqual(weighted_metrics[1].name, 'output_2_accuracy')
|
||||
|
||||
unweighted_metrics = metric_container.unweighted_metrics
|
||||
self.assertLen(unweighted_metrics, 4)
|
||||
self.assertEqual(unweighted_metrics[0].name, 'output_1_mse')
|
||||
self.assertEqual(unweighted_metrics[1].name, 'output_1_mae')
|
||||
self.assertEqual(unweighted_metrics[2].name, 'output_2_mse')
|
||||
self.assertEqual(unweighted_metrics[3].name, 'output_2_mae')
|
||||
|
||||
def test_metric_dict(self):
|
||||
metric_container = compile_utils.MetricsContainer(
|
||||
metrics={
|
||||
|
Loading…
Reference in New Issue
Block a user