Add reset_states method to LossesContainer and MetricsContainer.

PiperOrigin-RevId: 358259524
Change-Id: I875ad8434cccd6ca47fc41f9e3a7328cd2fcecb6
This commit is contained in:
Thomas O'Malley 2021-02-18 14:01:00 -08:00 committed by TensorFlower Gardener
parent c9240bd8ed
commit aed7a7b5e8
2 changed files with 59 additions and 0 deletions

View File

@ -250,6 +250,15 @@ class LossesContainer(Container):
# Ok for a model to have no compiled loss.
return array_ops.zeros(shape=())
def reset_states(self):
"""Resets the state of loss metrics."""
if not self._built:
return
metrics = [self._loss_metric] + nest.flatten(self._per_output_metrics)
for metric_obj in metrics:
if metric_obj is not None:
metric_obj.reset_states()
def _get_loss_object(self, loss):
"""Returns a `Loss` object.
@ -434,6 +443,21 @@ class MetricsContainer(Container):
continue
weighted_metric_obj.update_state(y_t, y_p, sample_weight=sw)
def reset_states(self):
"""Resets the state of all `Metric`s in this container."""
if self._built:
metrics = self._metrics_in_order
else:
# If the user supplied `Metric` objects directly, we should
# reset those. This could also contain `str`s or `function`s
# though.
metrics = nest.flatten(self._user_metrics) + nest.flatten(
self._user_weighted_metrics)
for metric_obj in metrics:
if isinstance(metric_obj, metrics_mod.Metric):
metric_obj.reset_states()
def _get_metric_objects(self, metrics, y_t, y_p):
"""Convert user-supplied metrics to `Metric` objects."""
metrics = nest.flatten(metrics)

View File

@ -50,6 +50,9 @@ class LossesContainerTest(keras_parameterized.TestCase):
self.assertEqual(loss_metric.name, 'loss')
self.assertEqual(loss_metric.result().numpy(), 1.)
loss_container.reset_states()
self.assertEqual(loss_metric.result().numpy(), 0.)
def test_loss_list(self):
loss_container = compile_utils.LossesContainer(['mse', 'mae'], [1, 0.5])
@ -76,6 +79,11 @@ class LossesContainerTest(keras_parameterized.TestCase):
self.assertEqual(output_2_metric.name, 'output_2_loss')
self.assertEqual(output_2_metric.result().numpy(), 0.5)
loss_container.reset_states()
self.assertEqual(loss_metric.result().numpy(), 0)
self.assertEqual(output_1_metric.result().numpy(), 0)
self.assertEqual(output_2_metric.result().numpy(), 0)
def test_loss_dict(self):
loss_container = compile_utils.LossesContainer(
{
@ -108,6 +116,11 @@ class LossesContainerTest(keras_parameterized.TestCase):
self.assertEqual(out2_metric.name, 'out2_loss')
self.assertEqual(out2_metric.result().numpy(), 0.5)
loss_container.reset_states()
self.assertEqual(loss_metric.result().numpy(), 0)
self.assertEqual(out1_metric.result().numpy(), 0)
self.assertEqual(out2_metric.result().numpy(), 0)
def test_loss_partial_dict_with_output_names(self):
loss_container = compile_utils.LossesContainer(
{'out2': 'mae'}, {'out2': 1.}, output_names=['out1', 'out2'])
@ -400,6 +413,9 @@ class MetricsContainerTest(keras_parameterized.TestCase):
self.assertEqual(metric.name, 'mse')
self.assertEqual(metric.result().numpy(), 1.)
metric_container.reset_states()
self.assertEqual(metric.result().numpy(), 0.)
def test_list_of_metrics_one_output(self):
metric_container = compile_utils.MetricsContainer(['mse', 'mae'])
y_t, y_p = 2 * array_ops.ones((10, 5)), array_ops.zeros((10, 5))
@ -414,6 +430,10 @@ class MetricsContainerTest(keras_parameterized.TestCase):
self.assertEqual(mae_metric.name, 'mae')
self.assertEqual(mae_metric.result().numpy(), 2.)
metric_container.reset_states()
self.assertEqual(mse_metric.result().numpy(), 0.)
self.assertEqual(mae_metric.result().numpy(), 0.)
def test_list_of_metrics_list_of_outputs(self):
metric_container = compile_utils.MetricsContainer(
metrics=['mse', 'mae'], # Should broadcast to both outputs.
@ -495,6 +515,12 @@ class MetricsContainerTest(keras_parameterized.TestCase):
self.assertEqual(weighted_mae_metric.name, 'out2_weighted_mae')
self.assertEqual(weighted_mae_metric.result().numpy(), 2.)
metric_container.reset_states()
self.assertEqual(mse_metric.result().numpy(), 0.)
self.assertEqual(weighted_mse_metric.result().numpy(), 0.)
self.assertEqual(mae_metric.result().numpy(), 0.)
self.assertEqual(weighted_mae_metric.result().numpy(), 0.)
def test_metric_partial_dict_with_output_names(self):
metric_container = compile_utils.MetricsContainer(
{'out2': 'mae'}, output_names=['out1', 'out2'])
@ -764,6 +790,15 @@ class MetricsContainerTest(keras_parameterized.TestCase):
self.assertEqual(metric_container.metrics[0].name, 'custom_metric_fn')
self.assertEqual(metric_container.metrics[1].name, 'custom_metric_class')
def test_reset_states_existing_metric_before_built(self):
metric = metrics_mod.Mean()
metric.update_state([2.0, 4.0])
self.assertEqual(metric.result().numpy(), 3.0)
metric_container = compile_utils.MetricsContainer(metric)
metric_container.reset_states()
self.assertEqual(metric.result().numpy(), 0.0)
if __name__ == '__main__':
ops.enable_eager_execution()