Add reset_states method to LossesContainer and MetricsContainer.
PiperOrigin-RevId: 358259524 Change-Id: I875ad8434cccd6ca47fc41f9e3a7328cd2fcecb6
This commit is contained in:
parent
c9240bd8ed
commit
aed7a7b5e8
@ -250,6 +250,15 @@ class LossesContainer(Container):
|
||||
# Ok for a model to have no compiled loss.
|
||||
return array_ops.zeros(shape=())
|
||||
|
||||
def reset_states(self):
|
||||
"""Resets the state of loss metrics."""
|
||||
if not self._built:
|
||||
return
|
||||
metrics = [self._loss_metric] + nest.flatten(self._per_output_metrics)
|
||||
for metric_obj in metrics:
|
||||
if metric_obj is not None:
|
||||
metric_obj.reset_states()
|
||||
|
||||
def _get_loss_object(self, loss):
|
||||
"""Returns a `Loss` object.
|
||||
|
||||
@ -434,6 +443,21 @@ class MetricsContainer(Container):
|
||||
continue
|
||||
weighted_metric_obj.update_state(y_t, y_p, sample_weight=sw)
|
||||
|
||||
def reset_states(self):
|
||||
"""Resets the state of all `Metric`s in this container."""
|
||||
if self._built:
|
||||
metrics = self._metrics_in_order
|
||||
else:
|
||||
# If the user supplied `Metric` objects directly, we should
|
||||
# reset those. This could also contain `str`s or `function`s
|
||||
# though.
|
||||
metrics = nest.flatten(self._user_metrics) + nest.flatten(
|
||||
self._user_weighted_metrics)
|
||||
|
||||
for metric_obj in metrics:
|
||||
if isinstance(metric_obj, metrics_mod.Metric):
|
||||
metric_obj.reset_states()
|
||||
|
||||
def _get_metric_objects(self, metrics, y_t, y_p):
|
||||
"""Convert user-supplied metrics to `Metric` objects."""
|
||||
metrics = nest.flatten(metrics)
|
||||
|
@ -50,6 +50,9 @@ class LossesContainerTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(loss_metric.name, 'loss')
|
||||
self.assertEqual(loss_metric.result().numpy(), 1.)
|
||||
|
||||
loss_container.reset_states()
|
||||
self.assertEqual(loss_metric.result().numpy(), 0.)
|
||||
|
||||
def test_loss_list(self):
|
||||
loss_container = compile_utils.LossesContainer(['mse', 'mae'], [1, 0.5])
|
||||
|
||||
@ -76,6 +79,11 @@ class LossesContainerTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(output_2_metric.name, 'output_2_loss')
|
||||
self.assertEqual(output_2_metric.result().numpy(), 0.5)
|
||||
|
||||
loss_container.reset_states()
|
||||
self.assertEqual(loss_metric.result().numpy(), 0)
|
||||
self.assertEqual(output_1_metric.result().numpy(), 0)
|
||||
self.assertEqual(output_2_metric.result().numpy(), 0)
|
||||
|
||||
def test_loss_dict(self):
|
||||
loss_container = compile_utils.LossesContainer(
|
||||
{
|
||||
@ -108,6 +116,11 @@ class LossesContainerTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(out2_metric.name, 'out2_loss')
|
||||
self.assertEqual(out2_metric.result().numpy(), 0.5)
|
||||
|
||||
loss_container.reset_states()
|
||||
self.assertEqual(loss_metric.result().numpy(), 0)
|
||||
self.assertEqual(out1_metric.result().numpy(), 0)
|
||||
self.assertEqual(out2_metric.result().numpy(), 0)
|
||||
|
||||
def test_loss_partial_dict_with_output_names(self):
|
||||
loss_container = compile_utils.LossesContainer(
|
||||
{'out2': 'mae'}, {'out2': 1.}, output_names=['out1', 'out2'])
|
||||
@ -400,6 +413,9 @@ class MetricsContainerTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(metric.name, 'mse')
|
||||
self.assertEqual(metric.result().numpy(), 1.)
|
||||
|
||||
metric_container.reset_states()
|
||||
self.assertEqual(metric.result().numpy(), 0.)
|
||||
|
||||
def test_list_of_metrics_one_output(self):
|
||||
metric_container = compile_utils.MetricsContainer(['mse', 'mae'])
|
||||
y_t, y_p = 2 * array_ops.ones((10, 5)), array_ops.zeros((10, 5))
|
||||
@ -414,6 +430,10 @@ class MetricsContainerTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(mae_metric.name, 'mae')
|
||||
self.assertEqual(mae_metric.result().numpy(), 2.)
|
||||
|
||||
metric_container.reset_states()
|
||||
self.assertEqual(mse_metric.result().numpy(), 0.)
|
||||
self.assertEqual(mae_metric.result().numpy(), 0.)
|
||||
|
||||
def test_list_of_metrics_list_of_outputs(self):
|
||||
metric_container = compile_utils.MetricsContainer(
|
||||
metrics=['mse', 'mae'], # Should broadcast to both outputs.
|
||||
@ -495,6 +515,12 @@ class MetricsContainerTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(weighted_mae_metric.name, 'out2_weighted_mae')
|
||||
self.assertEqual(weighted_mae_metric.result().numpy(), 2.)
|
||||
|
||||
metric_container.reset_states()
|
||||
self.assertEqual(mse_metric.result().numpy(), 0.)
|
||||
self.assertEqual(weighted_mse_metric.result().numpy(), 0.)
|
||||
self.assertEqual(mae_metric.result().numpy(), 0.)
|
||||
self.assertEqual(weighted_mae_metric.result().numpy(), 0.)
|
||||
|
||||
def test_metric_partial_dict_with_output_names(self):
|
||||
metric_container = compile_utils.MetricsContainer(
|
||||
{'out2': 'mae'}, output_names=['out1', 'out2'])
|
||||
@ -764,6 +790,15 @@ class MetricsContainerTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(metric_container.metrics[0].name, 'custom_metric_fn')
|
||||
self.assertEqual(metric_container.metrics[1].name, 'custom_metric_class')
|
||||
|
||||
def test_reset_states_existing_metric_before_built(self):
|
||||
metric = metrics_mod.Mean()
|
||||
metric.update_state([2.0, 4.0])
|
||||
self.assertEqual(metric.result().numpy(), 3.0)
|
||||
|
||||
metric_container = compile_utils.MetricsContainer(metric)
|
||||
metric_container.reset_states()
|
||||
self.assertEqual(metric.result().numpy(), 0.0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
|
Loading…
Reference in New Issue
Block a user