diff --git a/tensorflow/python/keras/engine/compile_utils.py b/tensorflow/python/keras/engine/compile_utils.py
index cc81c91dfb6..48e6312dbb4 100644
--- a/tensorflow/python/keras/engine/compile_utils.py
+++ b/tensorflow/python/keras/engine/compile_utils.py
@@ -250,6 +250,15 @@ class LossesContainer(Container):
       # Ok for a model to have no compiled loss.
       return array_ops.zeros(shape=())
 
+  def reset_states(self):
+    """Resets the state of loss metrics."""
+    if not self._built:
+      return
+    metrics = [self._loss_metric] + nest.flatten(self._per_output_metrics)
+    for metric_obj in metrics:
+      if metric_obj is not None:
+        metric_obj.reset_states()
+
   def _get_loss_object(self, loss):
     """Returns a `Loss` object.
 
@@ -434,6 +443,21 @@ class MetricsContainer(Container):
           continue
         weighted_metric_obj.update_state(y_t, y_p, sample_weight=sw)
 
+  def reset_states(self):
+    """Resets the state of all `Metric`s in this container."""
+    if self._built:
+      metrics = self._metrics_in_order
+    else:
+      # If the user supplied `Metric` objects directly, we should
+      # reset those. This could also contain `str`s or `function`s
+      # though.
+      metrics = nest.flatten(self._user_metrics) + nest.flatten(
+          self._user_weighted_metrics)
+
+    for metric_obj in metrics:
+      if isinstance(metric_obj, metrics_mod.Metric):
+        metric_obj.reset_states()
+
   def _get_metric_objects(self, metrics, y_t, y_p):
     """Convert user-supplied metrics to `Metric` objects."""
     metrics = nest.flatten(metrics)
diff --git a/tensorflow/python/keras/engine/compile_utils_test.py b/tensorflow/python/keras/engine/compile_utils_test.py
index fdf81ee35ad..21ff55f8b03 100644
--- a/tensorflow/python/keras/engine/compile_utils_test.py
+++ b/tensorflow/python/keras/engine/compile_utils_test.py
@@ -50,6 +50,9 @@ class LossesContainerTest(keras_parameterized.TestCase):
     self.assertEqual(loss_metric.name, 'loss')
     self.assertEqual(loss_metric.result().numpy(), 1.)
 
+    loss_container.reset_states()
+    self.assertEqual(loss_metric.result().numpy(), 0.)
+
   def test_loss_list(self):
     loss_container = compile_utils.LossesContainer(['mse', 'mae'], [1, 0.5])
 
@@ -76,6 +79,11 @@ class LossesContainerTest(keras_parameterized.TestCase):
     self.assertEqual(output_2_metric.name, 'output_2_loss')
     self.assertEqual(output_2_metric.result().numpy(), 0.5)
 
+    loss_container.reset_states()
+    self.assertEqual(loss_metric.result().numpy(), 0)
+    self.assertEqual(output_1_metric.result().numpy(), 0)
+    self.assertEqual(output_2_metric.result().numpy(), 0)
+
   def test_loss_dict(self):
     loss_container = compile_utils.LossesContainer(
         {
@@ -108,6 +116,11 @@ class LossesContainerTest(keras_parameterized.TestCase):
     self.assertEqual(out2_metric.name, 'out2_loss')
     self.assertEqual(out2_metric.result().numpy(), 0.5)
 
+    loss_container.reset_states()
+    self.assertEqual(loss_metric.result().numpy(), 0)
+    self.assertEqual(out1_metric.result().numpy(), 0)
+    self.assertEqual(out2_metric.result().numpy(), 0)
+
   def test_loss_partial_dict_with_output_names(self):
     loss_container = compile_utils.LossesContainer(
         {'out2': 'mae'}, {'out2': 1.}, output_names=['out1', 'out2'])
@@ -400,6 +413,9 @@ class MetricsContainerTest(keras_parameterized.TestCase):
     self.assertEqual(metric.name, 'mse')
     self.assertEqual(metric.result().numpy(), 1.)
 
+    metric_container.reset_states()
+    self.assertEqual(metric.result().numpy(), 0.)
+
   def test_list_of_metrics_one_output(self):
     metric_container = compile_utils.MetricsContainer(['mse', 'mae'])
     y_t, y_p = 2 * array_ops.ones((10, 5)), array_ops.zeros((10, 5))
@@ -414,6 +430,10 @@ class MetricsContainerTest(keras_parameterized.TestCase):
     self.assertEqual(mae_metric.name, 'mae')
     self.assertEqual(mae_metric.result().numpy(), 2.)
 
+    metric_container.reset_states()
+    self.assertEqual(mse_metric.result().numpy(), 0.)
+    self.assertEqual(mae_metric.result().numpy(), 0.)
+
   def test_list_of_metrics_list_of_outputs(self):
     metric_container = compile_utils.MetricsContainer(
         metrics=['mse', 'mae'],  # Should broadcast to both outputs.
@@ -495,6 +515,12 @@ class MetricsContainerTest(keras_parameterized.TestCase):
     self.assertEqual(weighted_mae_metric.name, 'out2_weighted_mae')
     self.assertEqual(weighted_mae_metric.result().numpy(), 2.)
 
+    metric_container.reset_states()
+    self.assertEqual(mse_metric.result().numpy(), 0.)
+    self.assertEqual(weighted_mse_metric.result().numpy(), 0.)
+    self.assertEqual(mae_metric.result().numpy(), 0.)
+    self.assertEqual(weighted_mae_metric.result().numpy(), 0.)
+
   def test_metric_partial_dict_with_output_names(self):
     metric_container = compile_utils.MetricsContainer(
         {'out2': 'mae'}, output_names=['out1', 'out2'])
@@ -764,6 +790,15 @@ class MetricsContainerTest(keras_parameterized.TestCase):
     self.assertEqual(metric_container.metrics[0].name, 'custom_metric_fn')
     self.assertEqual(metric_container.metrics[1].name, 'custom_metric_class')
 
+  def test_reset_states_existing_metric_before_built(self):
+    metric = metrics_mod.Mean()
+    metric.update_state([2.0, 4.0])
+    self.assertEqual(metric.result().numpy(), 3.0)
+
+    metric_container = compile_utils.MetricsContainer(metric)
+    metric_container.reset_states()
+    self.assertEqual(metric.result().numpy(), 0.0)
+
 
 if __name__ == '__main__':
   ops.enable_eager_execution()