diff --git a/tensorflow/python/keras/engine/compile_utils.py b/tensorflow/python/keras/engine/compile_utils.py index cc81c91dfb6..48e6312dbb4 100644 --- a/tensorflow/python/keras/engine/compile_utils.py +++ b/tensorflow/python/keras/engine/compile_utils.py @@ -250,6 +250,15 @@ class LossesContainer(Container): # Ok for a model to have no compiled loss. return array_ops.zeros(shape=()) + def reset_states(self): + """Resets the state of loss metrics.""" + if not self._built: + return + metrics = [self._loss_metric] + nest.flatten(self._per_output_metrics) + for metric_obj in metrics: + if metric_obj is not None: + metric_obj.reset_states() + def _get_loss_object(self, loss): """Returns a `Loss` object. @@ -434,6 +443,21 @@ class MetricsContainer(Container): continue weighted_metric_obj.update_state(y_t, y_p, sample_weight=sw) + def reset_states(self): + """Resets the state of all `Metric`s in this container.""" + if self._built: + metrics = self._metrics_in_order + else: + # If the user supplied `Metric` objects directly, we should + # reset those. This could also contain `str`s or `function`s + # though. + metrics = nest.flatten(self._user_metrics) + nest.flatten( + self._user_weighted_metrics) + + for metric_obj in metrics: + if isinstance(metric_obj, metrics_mod.Metric): + metric_obj.reset_states() + def _get_metric_objects(self, metrics, y_t, y_p): """Convert user-supplied metrics to `Metric` objects.""" metrics = nest.flatten(metrics) diff --git a/tensorflow/python/keras/engine/compile_utils_test.py b/tensorflow/python/keras/engine/compile_utils_test.py index fdf81ee35ad..21ff55f8b03 100644 --- a/tensorflow/python/keras/engine/compile_utils_test.py +++ b/tensorflow/python/keras/engine/compile_utils_test.py @@ -50,6 +50,9 @@ class LossesContainerTest(keras_parameterized.TestCase): self.assertEqual(loss_metric.name, 'loss') self.assertEqual(loss_metric.result().numpy(), 1.) + loss_container.reset_states() + self.assertEqual(loss_metric.result().numpy(), 0.) + def test_loss_list(self): loss_container = compile_utils.LossesContainer(['mse', 'mae'], [1, 0.5]) @@ -76,6 +79,11 @@ class LossesContainerTest(keras_parameterized.TestCase): self.assertEqual(output_2_metric.name, 'output_2_loss') self.assertEqual(output_2_metric.result().numpy(), 0.5) + loss_container.reset_states() + self.assertEqual(loss_metric.result().numpy(), 0) + self.assertEqual(output_1_metric.result().numpy(), 0) + self.assertEqual(output_2_metric.result().numpy(), 0) + def test_loss_dict(self): loss_container = compile_utils.LossesContainer( { @@ -108,6 +116,11 @@ class LossesContainerTest(keras_parameterized.TestCase): self.assertEqual(out2_metric.name, 'out2_loss') self.assertEqual(out2_metric.result().numpy(), 0.5) + loss_container.reset_states() + self.assertEqual(loss_metric.result().numpy(), 0) + self.assertEqual(out1_metric.result().numpy(), 0) + self.assertEqual(out2_metric.result().numpy(), 0) + def test_loss_partial_dict_with_output_names(self): loss_container = compile_utils.LossesContainer( {'out2': 'mae'}, {'out2': 1.}, output_names=['out1', 'out2']) @@ -400,6 +413,9 @@ class MetricsContainerTest(keras_parameterized.TestCase): self.assertEqual(metric.name, 'mse') self.assertEqual(metric.result().numpy(), 1.) + metric_container.reset_states() + self.assertEqual(metric.result().numpy(), 0.) + def test_list_of_metrics_one_output(self): metric_container = compile_utils.MetricsContainer(['mse', 'mae']) y_t, y_p = 2 * array_ops.ones((10, 5)), array_ops.zeros((10, 5)) @@ -414,6 +430,10 @@ class MetricsContainerTest(keras_parameterized.TestCase): self.assertEqual(mae_metric.name, 'mae') self.assertEqual(mae_metric.result().numpy(), 2.) + metric_container.reset_states() + self.assertEqual(mse_metric.result().numpy(), 0.) + self.assertEqual(mae_metric.result().numpy(), 0.) + def test_list_of_metrics_list_of_outputs(self): metric_container = compile_utils.MetricsContainer( metrics=['mse', 'mae'], # Should broadcast to both outputs. @@ -495,6 +515,12 @@ class MetricsContainerTest(keras_parameterized.TestCase): self.assertEqual(weighted_mae_metric.name, 'out2_weighted_mae') self.assertEqual(weighted_mae_metric.result().numpy(), 2.) + metric_container.reset_states() + self.assertEqual(mse_metric.result().numpy(), 0.) + self.assertEqual(weighted_mse_metric.result().numpy(), 0.) + self.assertEqual(mae_metric.result().numpy(), 0.) + self.assertEqual(weighted_mae_metric.result().numpy(), 0.) + def test_metric_partial_dict_with_output_names(self): metric_container = compile_utils.MetricsContainer( {'out2': 'mae'}, output_names=['out1', 'out2']) @@ -764,6 +790,15 @@ class MetricsContainerTest(keras_parameterized.TestCase): self.assertEqual(metric_container.metrics[0].name, 'custom_metric_fn') self.assertEqual(metric_container.metrics[1].name, 'custom_metric_class') + def test_reset_states_existing_metric_before_built(self): + metric = metrics_mod.Mean() + metric.update_state([2.0, 4.0]) + self.assertEqual(metric.result().numpy(), 3.0) + + metric_container = compile_utils.MetricsContainer(metric) + metric_container.reset_states() + self.assertEqual(metric.result().numpy(), 0.0) + if __name__ == '__main__': ops.enable_eager_execution()