Add reset_states method to LossesContainer and MetricsContainer.
PiperOrigin-RevId: 358259524 Change-Id: I875ad8434cccd6ca47fc41f9e3a7328cd2fcecb6
This commit is contained in:
parent
c9240bd8ed
commit
aed7a7b5e8
@ -250,6 +250,15 @@ class LossesContainer(Container):
|
|||||||
# Ok for a model to have no compiled loss.
|
# Ok for a model to have no compiled loss.
|
||||||
return array_ops.zeros(shape=())
|
return array_ops.zeros(shape=())
|
||||||
|
|
||||||
|
def reset_states(self):
|
||||||
|
"""Resets the state of loss metrics."""
|
||||||
|
if not self._built:
|
||||||
|
return
|
||||||
|
metrics = [self._loss_metric] + nest.flatten(self._per_output_metrics)
|
||||||
|
for metric_obj in metrics:
|
||||||
|
if metric_obj is not None:
|
||||||
|
metric_obj.reset_states()
|
||||||
|
|
||||||
def _get_loss_object(self, loss):
|
def _get_loss_object(self, loss):
|
||||||
"""Returns a `Loss` object.
|
"""Returns a `Loss` object.
|
||||||
|
|
||||||
@ -434,6 +443,21 @@ class MetricsContainer(Container):
|
|||||||
continue
|
continue
|
||||||
weighted_metric_obj.update_state(y_t, y_p, sample_weight=sw)
|
weighted_metric_obj.update_state(y_t, y_p, sample_weight=sw)
|
||||||
|
|
||||||
|
def reset_states(self):
|
||||||
|
"""Resets the state of all `Metric`s in this container."""
|
||||||
|
if self._built:
|
||||||
|
metrics = self._metrics_in_order
|
||||||
|
else:
|
||||||
|
# If the user supplied `Metric` objects directly, we should
|
||||||
|
# reset those. This could also contain `str`s or `function`s
|
||||||
|
# though.
|
||||||
|
metrics = nest.flatten(self._user_metrics) + nest.flatten(
|
||||||
|
self._user_weighted_metrics)
|
||||||
|
|
||||||
|
for metric_obj in metrics:
|
||||||
|
if isinstance(metric_obj, metrics_mod.Metric):
|
||||||
|
metric_obj.reset_states()
|
||||||
|
|
||||||
def _get_metric_objects(self, metrics, y_t, y_p):
|
def _get_metric_objects(self, metrics, y_t, y_p):
|
||||||
"""Convert user-supplied metrics to `Metric` objects."""
|
"""Convert user-supplied metrics to `Metric` objects."""
|
||||||
metrics = nest.flatten(metrics)
|
metrics = nest.flatten(metrics)
|
||||||
|
@ -50,6 +50,9 @@ class LossesContainerTest(keras_parameterized.TestCase):
|
|||||||
self.assertEqual(loss_metric.name, 'loss')
|
self.assertEqual(loss_metric.name, 'loss')
|
||||||
self.assertEqual(loss_metric.result().numpy(), 1.)
|
self.assertEqual(loss_metric.result().numpy(), 1.)
|
||||||
|
|
||||||
|
loss_container.reset_states()
|
||||||
|
self.assertEqual(loss_metric.result().numpy(), 0.)
|
||||||
|
|
||||||
def test_loss_list(self):
|
def test_loss_list(self):
|
||||||
loss_container = compile_utils.LossesContainer(['mse', 'mae'], [1, 0.5])
|
loss_container = compile_utils.LossesContainer(['mse', 'mae'], [1, 0.5])
|
||||||
|
|
||||||
@ -76,6 +79,11 @@ class LossesContainerTest(keras_parameterized.TestCase):
|
|||||||
self.assertEqual(output_2_metric.name, 'output_2_loss')
|
self.assertEqual(output_2_metric.name, 'output_2_loss')
|
||||||
self.assertEqual(output_2_metric.result().numpy(), 0.5)
|
self.assertEqual(output_2_metric.result().numpy(), 0.5)
|
||||||
|
|
||||||
|
loss_container.reset_states()
|
||||||
|
self.assertEqual(loss_metric.result().numpy(), 0)
|
||||||
|
self.assertEqual(output_1_metric.result().numpy(), 0)
|
||||||
|
self.assertEqual(output_2_metric.result().numpy(), 0)
|
||||||
|
|
||||||
def test_loss_dict(self):
|
def test_loss_dict(self):
|
||||||
loss_container = compile_utils.LossesContainer(
|
loss_container = compile_utils.LossesContainer(
|
||||||
{
|
{
|
||||||
@ -108,6 +116,11 @@ class LossesContainerTest(keras_parameterized.TestCase):
|
|||||||
self.assertEqual(out2_metric.name, 'out2_loss')
|
self.assertEqual(out2_metric.name, 'out2_loss')
|
||||||
self.assertEqual(out2_metric.result().numpy(), 0.5)
|
self.assertEqual(out2_metric.result().numpy(), 0.5)
|
||||||
|
|
||||||
|
loss_container.reset_states()
|
||||||
|
self.assertEqual(loss_metric.result().numpy(), 0)
|
||||||
|
self.assertEqual(out1_metric.result().numpy(), 0)
|
||||||
|
self.assertEqual(out2_metric.result().numpy(), 0)
|
||||||
|
|
||||||
def test_loss_partial_dict_with_output_names(self):
|
def test_loss_partial_dict_with_output_names(self):
|
||||||
loss_container = compile_utils.LossesContainer(
|
loss_container = compile_utils.LossesContainer(
|
||||||
{'out2': 'mae'}, {'out2': 1.}, output_names=['out1', 'out2'])
|
{'out2': 'mae'}, {'out2': 1.}, output_names=['out1', 'out2'])
|
||||||
@ -400,6 +413,9 @@ class MetricsContainerTest(keras_parameterized.TestCase):
|
|||||||
self.assertEqual(metric.name, 'mse')
|
self.assertEqual(metric.name, 'mse')
|
||||||
self.assertEqual(metric.result().numpy(), 1.)
|
self.assertEqual(metric.result().numpy(), 1.)
|
||||||
|
|
||||||
|
metric_container.reset_states()
|
||||||
|
self.assertEqual(metric.result().numpy(), 0.)
|
||||||
|
|
||||||
def test_list_of_metrics_one_output(self):
|
def test_list_of_metrics_one_output(self):
|
||||||
metric_container = compile_utils.MetricsContainer(['mse', 'mae'])
|
metric_container = compile_utils.MetricsContainer(['mse', 'mae'])
|
||||||
y_t, y_p = 2 * array_ops.ones((10, 5)), array_ops.zeros((10, 5))
|
y_t, y_p = 2 * array_ops.ones((10, 5)), array_ops.zeros((10, 5))
|
||||||
@ -414,6 +430,10 @@ class MetricsContainerTest(keras_parameterized.TestCase):
|
|||||||
self.assertEqual(mae_metric.name, 'mae')
|
self.assertEqual(mae_metric.name, 'mae')
|
||||||
self.assertEqual(mae_metric.result().numpy(), 2.)
|
self.assertEqual(mae_metric.result().numpy(), 2.)
|
||||||
|
|
||||||
|
metric_container.reset_states()
|
||||||
|
self.assertEqual(mse_metric.result().numpy(), 0.)
|
||||||
|
self.assertEqual(mae_metric.result().numpy(), 0.)
|
||||||
|
|
||||||
def test_list_of_metrics_list_of_outputs(self):
|
def test_list_of_metrics_list_of_outputs(self):
|
||||||
metric_container = compile_utils.MetricsContainer(
|
metric_container = compile_utils.MetricsContainer(
|
||||||
metrics=['mse', 'mae'], # Should broadcast to both outputs.
|
metrics=['mse', 'mae'], # Should broadcast to both outputs.
|
||||||
@ -495,6 +515,12 @@ class MetricsContainerTest(keras_parameterized.TestCase):
|
|||||||
self.assertEqual(weighted_mae_metric.name, 'out2_weighted_mae')
|
self.assertEqual(weighted_mae_metric.name, 'out2_weighted_mae')
|
||||||
self.assertEqual(weighted_mae_metric.result().numpy(), 2.)
|
self.assertEqual(weighted_mae_metric.result().numpy(), 2.)
|
||||||
|
|
||||||
|
metric_container.reset_states()
|
||||||
|
self.assertEqual(mse_metric.result().numpy(), 0.)
|
||||||
|
self.assertEqual(weighted_mse_metric.result().numpy(), 0.)
|
||||||
|
self.assertEqual(mae_metric.result().numpy(), 0.)
|
||||||
|
self.assertEqual(weighted_mae_metric.result().numpy(), 0.)
|
||||||
|
|
||||||
def test_metric_partial_dict_with_output_names(self):
|
def test_metric_partial_dict_with_output_names(self):
|
||||||
metric_container = compile_utils.MetricsContainer(
|
metric_container = compile_utils.MetricsContainer(
|
||||||
{'out2': 'mae'}, output_names=['out1', 'out2'])
|
{'out2': 'mae'}, output_names=['out1', 'out2'])
|
||||||
@ -764,6 +790,15 @@ class MetricsContainerTest(keras_parameterized.TestCase):
|
|||||||
self.assertEqual(metric_container.metrics[0].name, 'custom_metric_fn')
|
self.assertEqual(metric_container.metrics[0].name, 'custom_metric_fn')
|
||||||
self.assertEqual(metric_container.metrics[1].name, 'custom_metric_class')
|
self.assertEqual(metric_container.metrics[1].name, 'custom_metric_class')
|
||||||
|
|
||||||
|
def test_reset_states_existing_metric_before_built(self):
|
||||||
|
metric = metrics_mod.Mean()
|
||||||
|
metric.update_state([2.0, 4.0])
|
||||||
|
self.assertEqual(metric.result().numpy(), 3.0)
|
||||||
|
|
||||||
|
metric_container = compile_utils.MetricsContainer(metric)
|
||||||
|
metric_container.reset_states()
|
||||||
|
self.assertEqual(metric.result().numpy(), 0.0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
ops.enable_eager_execution()
|
ops.enable_eager_execution()
|
||||||
|
Loading…
Reference in New Issue
Block a user