diff --git a/tensorflow/python/keras/callbacks_v1.py b/tensorflow/python/keras/callbacks_v1.py
index 5c0c3ff6e96..251fb3476dc 100644
--- a/tensorflow/python/keras/callbacks_v1.py
+++ b/tensorflow/python/keras/callbacks_v1.py
@@ -245,8 +245,8 @@ class TensorBoard(callbacks.TensorBoard):
     # visualize embeddings.
     if self.embeddings_freq and self.embeddings_data is not None:
       # Avoid circular dependency.
-      from tensorflow.python.keras.engine import training_utils_v1  # pylint: disable=g-import-not-at-top
-      self.embeddings_data = training_utils_v1.standardize_input_data(
+      from tensorflow.python.keras.engine import training_utils  # pylint: disable=g-import-not-at-top
+      self.embeddings_data = training_utils.standardize_input_data(
           self.embeddings_data, model.input_names)
 
       # If embedding_layer_names are not provided, get all of the embedding
diff --git a/tensorflow/python/keras/distribute/distributed_training_utils_v1.py b/tensorflow/python/keras/distribute/distributed_training_utils_v1.py
index 2e7a8299e43..83426016412 100644
--- a/tensorflow/python/keras/distribute/distributed_training_utils_v1.py
+++ b/tensorflow/python/keras/distribute/distributed_training_utils_v1.py
@@ -36,9 +36,9 @@ from tensorflow.python.framework import tensor_util
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import callbacks
 from tensorflow.python.keras import metrics as metrics_module
-from tensorflow.python.keras import optimizers
+from tensorflow.python.keras import optimizer_v1
 from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils
-from tensorflow.python.keras.engine import training_utils_v1
+from tensorflow.python.keras.engine import training_utils
 from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.keras.utils import tf_contextlib
 from tensorflow.python.keras.utils.mode_keys import ModeKeys
@@ -639,9 +639,9 @@ def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
     # TODO(b/124535720): Remove once this data standardization logic is shared
     # with the main flow.
     inputs, targets = nest.map_structure(
-        training_utils_v1.standardize_single_array, (inputs, targets))
+        training_utils.standardize_single_array, (inputs, targets))
   else:
-    inputs = training_utils_v1.ModelInputs(inputs).as_list()
+    inputs = training_utils.ModelInputs(inputs).as_list()
 
   if mode == ModeKeys.PREDICT:
     sample_weights = []
@@ -779,7 +779,7 @@ def _clone_and_build_model(model, mode, inputs=None, targets=None):
   cloned_model = models.clone_model(model, input_tensors=inputs)
 
   # Compile and build model.
-  if isinstance(model.optimizer, optimizers.TFOptimizer):
+  if isinstance(model.optimizer, optimizer_v1.TFOptimizer):
     optimizer = model.optimizer
   else:
     optimizer_config = model.optimizer.get_config()
diff --git a/tensorflow/python/keras/engine/BUILD b/tensorflow/python/keras/engine/BUILD
index eeeb55cd623..4eef969bedf 100644
--- a/tensorflow/python/keras/engine/BUILD
+++ b/tensorflow/python/keras/engine/BUILD
@@ -38,7 +38,6 @@ py_library(
         "training_eager_v1.py",
         "training_generator_v1.py",
         "training_utils.py",
-        "training_utils_v1.py",
         "training_v1.py",
     ],
     srcs_version = "PY2AND3",
@@ -542,9 +541,9 @@ tf_py_test(
 )
 
 tf_py_test(
-    name = "training_utils_v1_test",
+    name = "training_utils_test",
     size = "medium",
-    srcs = ["training_utils_v1_test.py"],
+    srcs = ["training_utils_test.py"],
     python_version = "PY3",
     tags = [
         "no_oss",  # TODO(b/135021748) reenable
diff --git a/tensorflow/python/keras/engine/training_arrays_v1.py b/tensorflow/python/keras/engine/training_arrays_v1.py
index df4b8cd6930..4cdbcc9a02f 100644
--- a/tensorflow/python/keras/engine/training_arrays_v1.py
+++ b/tensorflow/python/keras/engine/training_arrays_v1.py
@@ -30,7 +30,7 @@ from tensorflow.python.framework import errors
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import callbacks as cbks
 from tensorflow.python.keras.distribute import distributed_training_utils_v1
-from tensorflow.python.keras.engine import training_utils_v1
+from tensorflow.python.keras.engine import training_utils
 from tensorflow.python.keras.utils.generic_utils import make_batches
 from tensorflow.python.keras.utils.generic_utils import slice_arrays
 from tensorflow.python.keras.utils.mode_keys import ModeKeys
@@ -139,7 +139,7 @@ def model_iteration(model,
   if is_dataset:
     if steps_per_epoch is None:
       reset_dataset_after_each_epoch = True
-      steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
+      steps_per_epoch = training_utils.infer_steps_for_dataset(
           model, inputs, steps_per_epoch, epochs=epochs, steps_name=steps_name)
     input_iterator = _get_iterator(inputs, model._distribution_strategy)
 
@@ -154,7 +154,7 @@ def model_iteration(model,
   do_validation = val_inputs is not None
 
   # Convert Eager Tensors to NumPy arrays to support batching/shuffling.
-  inputs, targets, sample_weights = training_utils_v1. \
+  inputs, targets, sample_weights = training_utils. \
       convert_eager_tensors_to_numpy((inputs, targets, sample_weights))
 
   # Prepare input data.
@@ -197,7 +197,7 @@ def model_iteration(model,
       # model_iteration() call, it will not trigger the dataset-input path
       # that determines the number of steps required. To avoid this issue,
       # set validation_steps here if validation_steps is None.
-      validation_steps = training_utils_v1.infer_steps_for_dataset(
+      validation_steps = training_utils.infer_steps_for_dataset(
           model,
           val_inputs,
           validation_steps,
@@ -240,12 +240,12 @@ def model_iteration(model,
 
   # Select aggregation method.
   if mode == ModeKeys.PREDICT:
-    aggregator = training_utils_v1.OutputsAggregator(
+    aggregator = training_utils.OutputsAggregator(
         use_steps,
         num_samples=None if steps_per_epoch else num_samples_or_steps,
         steps=steps_per_epoch)
   else:
-    aggregator = training_utils_v1.MetricsAggregator(
+    aggregator = training_utils.MetricsAggregator(
         use_steps,
         num_samples=None if steps_per_epoch else num_samples_or_steps,
         steps=steps_per_epoch)
@@ -350,7 +350,7 @@ def model_iteration(model,
       # Sample-wise loop.
       index_array = np.arange(num_samples_or_steps)
       if shuffle == 'batch':
-        index_array = training_utils_v1.batch_shuffle(index_array, batch_size)
+        index_array = training_utils.batch_shuffle(index_array, batch_size)
       elif shuffle:
         np.random.shuffle(index_array)
       batches = make_batches(num_samples_or_steps, batch_size)
@@ -409,7 +409,7 @@ def model_iteration(model,
 
     # Run the test loop every `validation_freq` epochs during training.
     if (do_validation and
-        training_utils_v1.should_run_validation(validation_freq, epoch) and
+        training_utils.should_run_validation(validation_freq, epoch) and
         not callbacks.model.stop_training):
 
       if model._compile_distribution:
@@ -483,8 +483,8 @@ def _get_num_samples_or_steps(ins, batch_size, steps_per_epoch):
   """Returns total number of samples (when training in batch mode) or steps."""
   if steps_per_epoch:
     return steps_per_epoch
-  return training_utils_v1.check_num_samples(ins, batch_size, steps_per_epoch,
-                                             'steps_per_epoch')
+  return training_utils.check_num_samples(ins, batch_size, steps_per_epoch,
+                                          'steps_per_epoch')
 
 
 def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
@@ -527,7 +527,7 @@ def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
         inputs,
         extract_tensors_from_dataset=True)
 
-  inputs = training_utils_v1.ModelInputs(inputs).as_list()
+  inputs = training_utils.ModelInputs(inputs).as_list()
   targets = list(targets or [])
   sample_weights = list(sample_weights or [])
   ins = inputs + targets + sample_weights
@@ -541,7 +541,7 @@ def _get_iterator(inputs, distribution_strategy=None):
   if distribution_strategy:
     return distributed_training_utils_v1.get_iterator(
         inputs, distribution_strategy)
-  return training_utils_v1.get_iterator(inputs)
+  return training_utils.get_iterator(inputs)
 
 
 def _reinitialize_iterator(iterator, distribution_strategy=None):
@@ -549,7 +549,7 @@ def _reinitialize_iterator(iterator, distribution_strategy=None):
     distributed_training_utils_v1.initialize_iterator(
         iterator, distribution_strategy)
   else:
-    training_utils_v1.initialize_iterator(iterator)
+    training_utils.initialize_iterator(iterator)
 
 
 def _make_execution_function(model, mode):
@@ -593,7 +593,7 @@ predict_loop = functools.partial(
     model_iteration, mode=ModeKeys.PREDICT, shuffle=False)
 
 
-class ArrayLikeTrainingLoop(training_utils_v1.TrainingLoop):
+class ArrayLikeTrainingLoop(training_utils.TrainingLoop):
   """TrainingLoop that handle inputs like array.
 
   This is the default handler for most of the input data types, includes
@@ -639,9 +639,9 @@ class ArrayLikeTrainingLoop(training_utils_v1.TrainingLoop):
       val_x, val_y, val_sample_weights = model._prepare_validation_data(
           validation_data, batch_size, validation_steps)
     elif validation_split and 0. < validation_split < 1.:
-      (x, y, sample_weights, val_x, val_y, val_sample_weights
-      ) = training_utils_v1.split_training_and_validation_data(
-          x, y, sample_weights, validation_split)
+      (x, y, sample_weights, val_x, val_y,
+       val_sample_weights) = training_utils.split_training_and_validation_data(
+           x, y, sample_weights, validation_split)
     else:
       if validation_steps:
         raise ValueError('`validation_steps` should not be specified if '
diff --git a/tensorflow/python/keras/engine/training_distributed_v1.py b/tensorflow/python/keras/engine/training_distributed_v1.py
index 6bfcb40c935..22fc19c4f62 100644
--- a/tensorflow/python/keras/engine/training_distributed_v1.py
+++ b/tensorflow/python/keras/engine/training_distributed_v1.py
@@ -35,7 +35,7 @@ from tensorflow.python.keras.distribute import distributed_training_utils as dis
 from tensorflow.python.keras.distribute import distributed_training_utils_v1 as dist_utils
 from tensorflow.python.keras.engine import partial_batch_padding_handler as padding_util
 from tensorflow.python.keras.engine import training_arrays_v1
-from tensorflow.python.keras.engine import training_utils_v1
+from tensorflow.python.keras.engine import training_utils
 from tensorflow.python.keras.utils.generic_utils import Progbar
 from tensorflow.python.keras.utils.mode_keys import ModeKeys
 from tensorflow.python.ops import array_ops
@@ -258,7 +258,7 @@ def experimental_tpu_fit_loop(model,
         break
 
     if (do_validation and
-        training_utils_v1.should_run_validation(validation_freq, epoch)):
+        training_utils.should_run_validation(validation_freq, epoch)):
       logging.info('Running validation at fit epoch: %s', epoch)
 
       if model._compile_distribution:
@@ -575,7 +575,7 @@ def experimental_tpu_predict_loop(model,
   return prediction_result
 
 
-class DistributionSingleWorkerTrainingLoop(training_utils_v1.TrainingLoop):
+class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop):
   """Training loop for distribution strategy with single worker."""
 
   def fit(self,
@@ -630,8 +630,8 @@ class DistributionSingleWorkerTrainingLoop(training_utils_v1.TrainingLoop):
 
     val_dataset = None
     if validation_data:
-      val_x, val_y, val_sample_weights = (
-          training_utils_v1.unpack_validation_data(validation_data))
+      val_x, val_y, val_sample_weights = training_utils.unpack_validation_data(
+          validation_data)
       dist_utils.validate_inputs(val_x, val_y)
       _, validation_steps = dist_utils.process_batch_and_step_size(
           model._distribution_strategy, val_x, batch_size, validation_steps,
@@ -650,7 +650,7 @@ class DistributionSingleWorkerTrainingLoop(training_utils_v1.TrainingLoop):
                        'distribution strategies.')
 
     if dist_utils_v2.is_tpu_strategy(model._distribution_strategy):
-      steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
+      steps_per_epoch = training_utils.infer_steps_for_dataset(
           model, dataset, steps_per_epoch, epochs, steps_name='steps_per_epoch')
       if steps_per_epoch is None:
         raise ValueError('Number of steps could not be inferred from the data, '
@@ -707,7 +707,7 @@ class DistributionSingleWorkerTrainingLoop(training_utils_v1.TrainingLoop):
         allow_partial_batch=True)
 
     if dist_utils_v2.is_tpu_strategy(model._distribution_strategy):
-      steps = training_utils_v1.infer_steps_for_dataset(
+      steps = training_utils.infer_steps_for_dataset(
           model, dataset, steps, steps_name='steps')
       if steps is None:
         raise ValueError('Number of steps could not be inferred from the data, '
@@ -744,7 +744,7 @@ class DistributionSingleWorkerTrainingLoop(training_utils_v1.TrainingLoop):
         batch_size=batch_size,
         allow_partial_batch=True)
     if dist_utils_v2.is_tpu_strategy(model._distribution_strategy):
-      steps = training_utils_v1.infer_steps_for_dataset(
+      steps = training_utils.infer_steps_for_dataset(
           model, dataset, steps, steps_name='steps')
       if steps is None:
         raise ValueError('Number of steps could not be inferred from the data, '
@@ -780,7 +780,7 @@ def _train_with_multi_worker(method):
   return wrapper
 
 
-class DistributionMultiWorkerTrainingLoop(training_utils_v1.TrainingLoop):
+class DistributionMultiWorkerTrainingLoop(training_utils.TrainingLoop):
   """Training loop for distribution strategy with multiple worker."""
 
   def __init__(self, single_worker_loop):
diff --git a/tensorflow/python/keras/engine/training_eager_v1.py b/tensorflow/python/keras/engine/training_eager_v1.py
index 2acd7493cb0..09e6f0d1edd 100644
--- a/tensorflow/python/keras/engine/training_eager_v1.py
+++ b/tensorflow/python/keras/engine/training_eager_v1.py
@@ -25,7 +25,6 @@ from tensorflow.python.eager.backprop import GradientTape
 from tensorflow.python.framework import ops
 from tensorflow.python.keras import backend
 from tensorflow.python.keras.engine import training_utils
-from tensorflow.python.keras.engine import training_utils_v1
 from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
 from tensorflow.python.keras.utils import losses_utils
 from tensorflow.python.ops import math_ops
@@ -128,12 +127,11 @@ def _model_loss(model,
   outs = nest.flatten(outs)
 
   if targets:
-    targets = training_utils_v1.cast_if_floating_dtype_and_mismatch(
-        targets, outs)
+    targets = training_utils.cast_if_floating_dtype_and_mismatch(targets, outs)
   # TODO(sallymatson/psv): check if we should do same mismatch fix for weights
   if sample_weights:
     sample_weights = [
-        training_utils_v1.cast_if_floating_dtype(
+        training_utils.cast_if_floating_dtype(
             ops.convert_to_tensor_v2_with_dispatch(val))
         if val is not None else None for val in sample_weights
     ]
@@ -306,7 +304,7 @@ def train_on_batch(model,
           model output. Could be a empty list when model has only one output.
         'metrics': list of tensors for metric specified.
   """
-  inputs = training_utils_v1.cast_to_model_input_dtypes(inputs, model)
+  inputs = training_utils.cast_to_model_input_dtypes(inputs, model)
   outs, total_loss, output_losses, masks = (
       _process_single_batch(
           model,
@@ -347,7 +345,7 @@ def test_on_batch(model,
           model output. Could be a empty list when model has only one output.
         'metrics': list of tensors for metric specified.
   """
-  inputs = training_utils_v1.cast_to_model_input_dtypes(inputs, model)
+  inputs = training_utils.cast_to_model_input_dtypes(inputs, model)
 
   with backend.eager_learning_phase_scope(0):
     outs, total_loss, output_losses, masks = (
diff --git a/tensorflow/python/keras/engine/training_generator_v1.py b/tensorflow/python/keras/engine/training_generator_v1.py
index 9b6fc1577bb..1fcf3ef25e4 100644
--- a/tensorflow/python/keras/engine/training_generator_v1.py
+++ b/tensorflow/python/keras/engine/training_generator_v1.py
@@ -31,7 +31,6 @@ from tensorflow.python.framework import errors
 from tensorflow.python.keras import backend
 from tensorflow.python.keras import callbacks as cbks
 from tensorflow.python.keras.engine import training_utils
-from tensorflow.python.keras.engine import training_utils_v1
 from tensorflow.python.keras.utils import data_utils
 from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.keras.utils.mode_keys import ModeKeys
@@ -133,7 +132,7 @@ def model_iteration(model,
     original_dataset = data
     if steps_per_epoch is None:
       reset_dataset_after_each_epoch = True
-      steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
+      steps_per_epoch = training_utils.infer_steps_for_dataset(
           model, data, steps_per_epoch, epochs=epochs, steps_name=steps_name)
 
   # Convert to a format that supports `next(generator)`.
@@ -180,11 +179,9 @@ def model_iteration(model,
       mode=mode)
 
   if mode == ModeKeys.PREDICT:
-    aggregator = training_utils_v1.OutputsAggregator(
-        True, steps=steps_per_epoch)
+    aggregator = training_utils.OutputsAggregator(True, steps=steps_per_epoch)
   else:
-    aggregator = training_utils_v1.MetricsAggregator(
-        True, steps=steps_per_epoch)
+    aggregator = training_utils.MetricsAggregator(True, steps=steps_per_epoch)
 
   should_set_learning_phase = context.executing_eagerly() and model.run_eagerly
   if should_set_learning_phase:
@@ -296,7 +293,7 @@ def model_iteration(model,
 
     # Run the test loop every epoch during training.
     if (do_validation and
-        training_utils_v1.should_run_validation(validation_freq, epoch) and
+        training_utils.should_run_validation(validation_freq, epoch) and
         not callbacks.model.stop_training):
       val_results = model_iteration(
           model,
@@ -541,7 +538,7 @@ def _get_num_samples_or_steps(data, steps_per_epoch):
   return steps_per_epoch, True
 
 
-class GeneratorOrSequenceTrainingLoop(training_utils_v1.TrainingLoop):
+class GeneratorOrSequenceTrainingLoop(training_utils.TrainingLoop):
   """Generator-like.
 
   Input is Python generator, or Sequence object.
@@ -572,7 +569,7 @@ class GeneratorOrSequenceTrainingLoop(training_utils_v1.TrainingLoop):
           workers=1,
           use_multiprocessing=False):
     model._validate_or_infer_batch_size(batch_size, steps_per_epoch, x)
-    training_utils_v1.check_generator_arguments(
+    training_utils.check_generator_arguments(
         y, sample_weight, validation_split=validation_split)
     return fit_generator(
         model,
@@ -605,7 +602,7 @@ class GeneratorOrSequenceTrainingLoop(training_utils_v1.TrainingLoop):
                workers=1,
                use_multiprocessing=False):
     model._validate_or_infer_batch_size(batch_size, steps, x)
-    training_utils_v1.check_generator_arguments(y, sample_weight)
+    training_utils.check_generator_arguments(y, sample_weight)
     return evaluate_generator(
         model,
         x,
@@ -638,7 +635,7 @@ class GeneratorOrSequenceTrainingLoop(training_utils_v1.TrainingLoop):
         use_multiprocessing=use_multiprocessing)
 
 
-class EagerDatasetOrIteratorTrainingLoop(training_utils_v1.TrainingLoop):
+class EagerDatasetOrIteratorTrainingLoop(training_utils.TrainingLoop):
   """A non-distributed Dataset or iterator in eager execution."""
 
   def fit(self,
@@ -661,11 +658,10 @@ class EagerDatasetOrIteratorTrainingLoop(training_utils_v1.TrainingLoop):
           **kwargs):
     model._validate_or_infer_batch_size(batch_size, steps_per_epoch, x)
     # Make sure that y, sample_weights, validation_split are not passed.
-    training_utils_v1.validate_dataset_input(x, y, sample_weight,
-                                             validation_split)
+    training_utils.validate_dataset_input(x, y, sample_weight, validation_split)
     if (isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)) and
         shuffle):
-      training_utils_v1.verify_dataset_shuffled(x)
+      training_utils.verify_dataset_shuffled(x)
 
     return fit_generator(
         model,
@@ -695,7 +691,7 @@ class EagerDatasetOrIteratorTrainingLoop(training_utils_v1.TrainingLoop):
                **kwargs):
     model._validate_or_infer_batch_size(batch_size, steps, x)
     # Make sure that y, sample_weights, validation_split are not passed.
-    training_utils_v1.validate_dataset_input(x, y, sample_weight)
+    training_utils.validate_dataset_input(x, y, sample_weight)
     return evaluate_generator(
         model, x, steps=steps, verbose=verbose, workers=0, callbacks=callbacks)
 
@@ -712,7 +708,7 @@ class EagerDatasetOrIteratorTrainingLoop(training_utils_v1.TrainingLoop):
         model, x, steps=steps, verbose=verbose, workers=0, callbacks=callbacks)
 
 
-class GeneratorLikeTrainingLoop(training_utils_v1.TrainingLoop):
+class GeneratorLikeTrainingLoop(training_utils.TrainingLoop):
   """TrainingLoop that handle inputs like python generator.
 
   This is the default handler for most of the input data types, includes
@@ -759,9 +755,8 @@ class GeneratorLikeTrainingLoop(training_utils_v1.TrainingLoop):
                                                        validation_steps)
     elif validation_split and 0. < validation_split < 1.:
       (x, y, sample_weights, val_x, val_y,
-       val_sample_weights) = (
-           training_utils_v1.split_training_and_validation_data(
-               x, y, sample_weights, validation_split))
+       val_sample_weights) = training_utils.split_training_and_validation_data(
+           x, y, sample_weights, validation_split)
       validation_data = (val_x, val_y, val_sample_weights)
     else:
       if validation_steps:
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 2d7d57559a6..1f8f8cb1b52 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -44,7 +44,7 @@ from tensorflow.python.keras.callbacks import Callback
 from tensorflow.python.keras.engine import input_layer
 from tensorflow.python.keras.engine import sequential
 from tensorflow.python.keras.engine import training as training_module
-from tensorflow.python.keras.engine import training_utils_v1
+from tensorflow.python.keras.engine import training_utils
 from tensorflow.python.keras.utils import data_utils
 from tensorflow.python.keras.utils import np_utils
 from tensorflow.python.ops import array_ops
@@ -2019,7 +2019,7 @@ class LossWeightingTest(keras_parameterized.TestCase):
           [[0, .4, 1, 1], [2, .4, .3, 1]])
       dataset = dataset_ops.Dataset.from_tensor_slices(sample_weights)
       sample_weights = dataset_ops.make_one_shot_iterator(dataset).get_next()
-      sample_weights = training_utils_v1.standardize_sample_weights(
+      sample_weights = training_utils.standardize_sample_weights(
           sample_weights, model.output_names)
 
       # Update model loss with sample weight tensor.
diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py
index 4180c0b7e1d..1df48401f33 100644
--- a/tensorflow/python/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/engine/training_utils.py
@@ -17,13 +17,350 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import numpy as np
+import abc
+import atexit
+from collections import OrderedDict
+import functools
+import multiprocessing.pool
+import threading
+import time
 
+import numpy as np
+import six
+from six.moves import zip  # pylint: disable=redefined-builtin
+
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python import tf2
+from tensorflow.python.data.experimental.ops import cardinality
+from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.eager import context
+from tensorflow.python.framework import composite_tensor_utils
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import smart_cond
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import tensor_util
+from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import callbacks as cbks
+from tensorflow.python.keras import losses
+from tensorflow.python.keras import metrics as metrics_module
+from tensorflow.python.keras.utils import data_utils
 from tensorflow.python.keras.utils import generic_utils
+from tensorflow.python.keras.utils import losses_utils
+from tensorflow.python.keras.utils import tf_inspect
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.ragged import ragged_tensor
+from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import nest
+from tensorflow.python.util.compat import collections_abc
+
+
+@six.add_metaclass(abc.ABCMeta)
+class Aggregator(object):
+  """Abstract base class used to aggregate batch-level outputs of a loop.
+
+  Attributes:
+    use_steps: Whether the loop is using `step` or `batch_size`.
+    num_samples: Total number of samples: `batch_size * num_batches`.
+    steps: Total number of steps.
+    batch_size: Batch size. It is used for validation checks between inputs and
+      outputs.
+    results: What to return at the end of the aggregation loop.
+  """
+
+  def __init__(self, use_steps, num_samples=None, steps=None, batch_size=None):
+    self.use_steps = use_steps
+    self.num_samples = num_samples
+    self.steps = steps
+    self.batch_size = batch_size
+    self.results = []
+
+  @abc.abstractmethod
+  def create(self, batch_outs):
+    """Creates the initial results from the first batch outputs.
+
+    Arguments:
+      batch_outs: A list of batch-level outputs.
+    """
+    raise NotImplementedError('Must be implemented in subclasses.')
+
+  @abc.abstractmethod
+  def aggregate(self, batch_outs, batch_start=None, batch_end=None):
+    """Aggregates batch-level results into total results.
+
+    Arguments:
+      batch_outs: A list of batch-level outputs.
+      batch_start: The start index of this batch. Always `None` if `use_steps`
+        is `True`.
+      batch_end: The end index of this batch. Always `None` if `use_steps` is
+        `True`.
+    """
+    raise NotImplementedError('Must be implemented in subclasses.')
+
+  @abc.abstractmethod
+  def finalize(self):
+    """Prepares the total results to be returned."""
+    raise NotImplementedError('Must be implemented in subclasses.')
+
+
+class MetricsAggregator(Aggregator):
+  """Aggregator that calculates loss and metrics info.
+
+  Attributes:
+    use_steps: Whether the loop is using `step` or `batch_size`.
+    num_samples: Total number of samples: `batch_size*num_batches`.
+    steps: Total number of steps, i.e. number of times to iterate over a
+      dataset to cover all samples.
+  """
+
+  def __init__(self, use_steps, num_samples=None, steps=None):
+    super(MetricsAggregator, self).__init__(
+        use_steps=use_steps,
+        num_samples=num_samples,
+        steps=steps,
+        batch_size=None)
+
+  def create(self, batch_outs):
+    self.results = [0.] * len(batch_outs)
+
+  def aggregate(self, batch_outs, batch_start=None, batch_end=None):
+    # Loss.
+    if self.use_steps:
+      self.results[0] += batch_outs[0]
+    else:
+      self.results[0] += batch_outs[0] * (batch_end - batch_start)
+    # Metrics (always stateful, just grab current values).
+    self.results[1:] = batch_outs[1:]
+
+  def finalize(self):
+    if not self.results:
+      raise ValueError('Empty training data.')
+    self.results[0] /= (self.num_samples or self.steps)
+
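+# A minimal usage sketch (hypothetical numbers) of the create/aggregate/
+# finalize cycle for the class above, assuming batch outputs [loss, metric]:
+#
+#   agg = MetricsAggregator(use_steps=False, num_samples=64)
+#   agg.create([0.3, 0.8])            # first batch: [loss, acc]
+#   agg.aggregate([0.3, 0.8], 0, 32)  # loss is weighted by the batch size
+#   agg.aggregate([0.5, 0.7], 32, 64)
+#   agg.finalize()                    # agg.results -> [0.4, 0.7]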
+
+class ConcatAggregator(Aggregator):
+  """Combine tensor-likes which cannot be merged on the fly.
+
+  This class expects to aggregate a single tensor-like rather than a nested
+  structure of tensor-likes.
+  """
+
+  def __init__(self, batch_size):
+    self.composite = None
+    super(ConcatAggregator, self).__init__(
+        use_steps=True, num_samples=None, steps=None, batch_size=batch_size)
+
+  def create(self, batch_element):
+    self.composite = composite_tensor_utils.is_composite_or_composite_value(
+        batch_element)
+
+  def aggregate(self, batch_element, batch_start=None, batch_end=None):
+
+    # TODO(psv): Add num_samples check here to detect when output batch
+    # #samples is < batch size and != input batch #samples.
+    if self.batch_size and self.batch_size < batch_element.shape[0]:
+      raise ValueError(
+          'Mismatch between expected batch size and model output batch size. '
+          'Output shape = {}, expected output shape = {}'.format(
+              batch_element.shape,
+              (self.batch_size,) + batch_element.shape[1:]))
+    self.results.append(batch_element)
+
+  def finalize(self):
+    # Special case of single batch inference which skips a copy.
+    if len(self.results) == 1:
+      self.results = self.results[0]
+
+    elif self.composite:
+      # TODO(taylorrobie): efficiently concatenate.
+      results = self.results[0]
+      for r in self.results[1:]:
+        results = composite_tensor_utils.append_composite_tensor(results, r)
+      self.results = results
+
+    else:
+      self.results = np.concatenate(self.results, axis=0)
+
+    if isinstance(self.results, ops.EagerTensor):
+      self.results = self.results._numpy()  # pylint: disable=protected-access
+
+
+_COPY_THREADS = 4
+_COPY_POOL = None
+
+
+def get_copy_pool():
+  """Shared threadpool for copying arrays.
+
+  Pool instantiation takes ~ 2ms, so a singleton pool is used rather than
+  creating a pool per SliceAggregator.
+
+  Returns:
+    The global copy threadpool.
+  """
+  global _COPY_POOL
+  if _COPY_POOL is None:
+    _COPY_POOL = multiprocessing.pool.ThreadPool(_COPY_THREADS)
+    atexit.register(_COPY_POOL.close)
+  return _COPY_POOL
+
+
+class SliceAggregator(Aggregator):
+  """Combine arrays where the final size is known.
+
+  This class expects to aggregate a single tensor-like rather than a nested
+  structure of tensor-likes.
+
+  NumPy copies are an operation that threads handle quite well because all of
+  the heavy lifting is in C and does not need the GIL. Moreover, we can perform
+  lock-free writes to the same buffer in multiple threads because the nature of
+  result aggregation guarantees that either the indices are disjoint or the
+  aggregator will throw an exception in finalize. Finally, because aggregation
+  is performed on the slowest varying dimension, assignments for a given batch
+  will write to contiguous blocks of memory, further minimizing contention.
+
+  There is, however, some scheduling and context switching overhead which will
+  offset the gains from pipelining the slice assignment. Below a given threshold
+  it is faster to simply assign in the main thread rather than enqueue the
+  assignment in a side thread. The exact threshold will vary from system to
+  system, but the time is not very sensitive to the exact transition so a value
+  of 2 ** 14 was chosen which should be reasonable on most systems.
+  """
+
+  _BINARY_SIZE_THRESHOLD = 2 ** 14
+  _MAX_COPY_SECONDS = 300
+
+  def __init__(self, num_samples, batch_size):
+    self._async_copies = []
+    self._pool = get_copy_pool()
+    self._errors = []
+    super(SliceAggregator, self).__init__(
+        use_steps=False,
+        num_samples=num_samples,
+        steps=None,
+        batch_size=batch_size)
+
+  def create(self, batch_element):
+    # This step does not need to be pipelined because NumPy empty array
+    # initialization is effectively instantaneous.
+    shape = (self.num_samples,) + batch_element.shape[1:]
+    dtype = batch_element.dtype
+    if isinstance(batch_element, ops.EagerTensor):
+      dtype = dtype.as_numpy_dtype
+
+    self.results = np.empty(shape=shape, dtype=dtype)
+
+  def aggregate(self, batch_element, batch_start, batch_end):
+    # Fail early.
+    if self._errors:
+      six.reraise(type(self._errors[0]), self._errors[0])
+
+    # In the special case of single batch inference, no copy is needed.
+    if batch_end - batch_start == self.num_samples:
+      if self.num_samples != batch_element.shape[0]:
+        raise ValueError(
+            'Mismatch between expected batch size and model output batch size. '
+            'Output shape = {}, expected output shape = {}'.format(
+                batch_element.shape, self.results.shape))
+
+      self.results = batch_element
+      return
+
+    # This is an approximate threshold, so we don't need to consider the number
+    # of bytes per element.
+    num_elements = np.prod(batch_element.shape)
+    if num_elements < self._BINARY_SIZE_THRESHOLD:
+      self.results[batch_start:batch_end] = batch_element
+    else:
+      is_finished = threading.Event()
+      self._pool.apply_async(
+          self._slice_assign,
+          args=(batch_element, batch_start, batch_end, is_finished))
+      self._async_copies.append(is_finished)
+
+  def _slice_assign(self, batch_element, batch_start, batch_end, is_finished):
+    try:
+      self.results[batch_start:batch_end] = batch_element
+
+    except Exception as e:  # pylint: disable=broad-except
+      # `_slice_assign` should only be called in threads and exceptions raised
+      # in threads do not carry over to the main thread. So instead we perform
+      # a broad catch in the thread and then store the exception to be re-raised
+      # in the main thread.
+      self._errors.append(e)
+
+    finally:
+      is_finished.set()
+
+  def finalize(self):
+    start_time = time.time()
+    for is_finished in self._async_copies:
+      timeout = max([0., self._MAX_COPY_SECONDS - (time.time() - start_time)])
+      if not is_finished.wait(timeout):
+        raise ValueError('Timed out waiting for copy to complete.')
+
+    if self._errors:
+      six.reraise(self._errors[0].__class__, self._errors[0])
+
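+# A minimal sketch of the inline slice-assignment path described above, with
+# hypothetical NumPy shapes (small copies stay in the main thread):
+#
+#   agg = SliceAggregator(num_samples=4, batch_size=2)
+#   agg.create(np.zeros((2, 8)))          # allocates results of shape (4, 8)
+#   agg.aggregate(np.ones((2, 8)), 0, 2)  # below _BINARY_SIZE_THRESHOLD
+#   agg.aggregate(np.ones((2, 8)), 2, 4)
+#   agg.finalize()                        # waits for any outstanding copies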
+
+class OutputsAggregator(Aggregator):
+  """Aggregator that concatenates outputs."""
+
+  _structure = None
+
+  def create(self, batch_outs):
+    # SparseTensorValue is a named tuple which nest will flatten, so we need
+    # to guard it to properly handle the structure.
+    self._structure = nest.get_traverse_shallow_structure(
+        lambda x: not composite_tensor_utils.is_composite_or_composite_value(x),
+        batch_outs)
+    batch_outs = nest.flatten_up_to(self._structure, batch_outs)
+
+    for batch_element in batch_outs:
+      if composite_tensor_utils.is_composite_or_composite_value(batch_element):
+        # If the output is not an ndarray, it will be either a composite tensor
+        # or a composite tensor's Value object. In either case, we can't
+        # allocate an array to hold the object - we'll handle it later.
+        self.results.append(ConcatAggregator(self.batch_size))
+      elif isinstance(batch_element, (np.ndarray, ops.EagerTensor)):
+        self.results.append(
+            (ConcatAggregator(self.batch_size) if self.use_steps else
+             SliceAggregator(self.num_samples, self.batch_size)))
+      else:
+        # This is not an ndarray, a CompositeTensor, or a CompositeTensorValue.
+        # Fail fast rather than trying to concatenate it.
+        raise RuntimeError('Attempted to aggregate unsupported object {}.'
+                           .format(batch_element))
+
+      self.results[-1].create(batch_element)
+
+  def aggregate(self, batch_outs, batch_start=None, batch_end=None):
+    batch_outs = nest.flatten_up_to(self._structure, batch_outs)
+    for batch_element, result in zip(batch_outs, self.results):
+      result.aggregate(batch_element, batch_start, batch_end)
+
+  def finalize(self):
+    for result in self.results:
+      result.finalize()
+    self.results = [i.results for i in self.results]
+    self.results = nest.pack_sequence_as(self._structure, self.results)
+
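+# A minimal sketch for OutputsAggregator, assuming two ndarray model outputs
+# and a known sample count (hypothetical shapes); each output gets its own
+# sub-aggregator:
+#
+#   agg = OutputsAggregator(use_steps=False, num_samples=4, batch_size=2)
+#   agg.create([np.zeros((2, 3)), np.zeros((2,))])
+#   agg.aggregate([np.ones((2, 3)), np.ones((2,))], 0, 2)
+#   agg.aggregate([np.ones((2, 3)), np.ones((2,))], 2, 4)
+#   agg.finalize()  # results are packed with the same structure as the outputs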
+
+def get_progbar(model, count_mode, include_metrics=True):
+  """Get Progbar."""
+  if include_metrics:
+    stateful_metric_names = getattr(model, 'metrics_names', None)
+    if stateful_metric_names:
+      stateful_metric_names = stateful_metric_names[1:]  # Exclude `loss`
+  else:
+    stateful_metric_names = None
+  return cbks.ProgbarLogger(count_mode, stateful_metrics=stateful_metric_names)
 
 
 def slice_arrays(arrays, indices, contiguous=True):
@@ -62,6 +399,245 @@ def slice_arrays(arrays, indices, contiguous=True):
   return slices
 
 
+def check_num_samples(ins, batch_size=None, steps=None, steps_name='steps'):
+  """Determine the number of samples provided for training and evaluation.
+
+  The number of samples is not defined when running with `steps`,
+  in which case the number of samples is set to `None`.
+
+  Arguments:
+      ins: List of tensors to be fed to the Keras function.
+      batch_size: Integer batch size or `None` if not defined.
+      steps: Total number of steps (batches of samples) before declaring
+        `_predict_loop` finished. Ignored with the default value of `None`.
+      steps_name: The public API's parameter name for `steps`.
+
+  Raises:
+      ValueError: when `steps` is `None` and the attribute `ins.shape`
+      does not exist. Also raises ValueError when `steps` is not `None`
+      and `batch_size` is not `None` because they are mutually
+      exclusive.
+
+  Returns:
+      When steps is `None`, returns the number of samples to be
+      processed based on the size of the first dimension of the
+      first input numpy array. When steps is not `None` and
+      `batch_size` is `None`, returns `None`.
+  """
+  if steps is not None and batch_size is not None:
+    raise ValueError('If ' + steps_name +
+                     ' is set, the `batch_size` must be None.')
+  if check_steps_argument(ins, steps, steps_name):
+    return None
+
+  if hasattr(ins[0], 'shape'):
+    return int(ins[0].shape[0])
+  return None  # Edge case where ins == [static_learning_phase]
+
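+# For example (hypothetical arrays): check_num_samples([np.zeros((10, 4))],
+# batch_size=2) returns 10, while a call with `steps` set (and
+# `batch_size=None`) returns None.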
+
+def standardize_single_array(x, expected_shape=None):
+  """Expand data of shape (x,) to (x, 1), unless len(expected_shape)==1."""
+  if x is None:
+    return None
+
+  if composite_tensor_utils.is_composite_or_composite_value(x):
+    return x
+
+  if isinstance(x, int):
+    raise ValueError(
+        'Expected an array data type but received an integer: {}'.format(x))
+
+  if (x.shape is not None and len(x.shape) == 1 and
+      (expected_shape is None or len(expected_shape) != 1)):
+    if tensor_util.is_tensor(x):
+      x = array_ops.expand_dims(x, axis=1)
+    else:
+      x = np.expand_dims(x, 1)
+  return x
+
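+# For example: standardize_single_array(np.array([1., 2., 3.])) returns an
+# array of shape (3, 1), while passing expected_shape=(3,) leaves it 1-D.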
+
+def standardize_input_data(data,
+                           names,
+                           shapes=None,
+                           check_batch_axis=True,
+                           exception_prefix=''):
+  """Normalizes inputs and targets provided by users.
+
+  Users may pass data as a list of arrays, dictionary of arrays,
+  or as a single array. We normalize this to an ordered list of
+  arrays (same order as `names`), while checking that the provided
+  arrays have shapes that match the network's expectations.
+
+  Arguments:
+      data: User-provided input data (polymorphic).
+      names: List of expected array names.
+      shapes: Optional list of expected array shapes.
+      check_batch_axis: Boolean; whether to check that the batch axis of the
+        arrays matches the expected value found in `shapes`.
+      exception_prefix: String prefix used for exception formatting.
+
+  Returns:
+      List of standardized input arrays (one array per model input).
+
+  Raises:
+      ValueError: in case of improperly formatted user-provided data.
+  """
+  try:
+    data_len = len(data)
+  except TypeError:
+    # For instance if data is `None` or a symbolic Tensor.
+    data_len = None
+
+  if not names:
+    if data_len and not isinstance(data, dict):
+      raise ValueError(
+          'Error when checking model ' + exception_prefix + ': '
+          'expected no data, but got:', data)
+    return []
+  if data is None:
+    return [None for _ in range(len(names))]
+
+  if isinstance(data, dict):
+    try:
+      data = [
+          data[x].values
+          if data[x].__class__.__name__ == 'DataFrame' else data[x]
+          for x in names
+      ]
+    except KeyError as e:
+      raise ValueError('No data provided for "' + e.args[0] + '". Need data '
+                       'for each key in: ' + str(names))
+  elif isinstance(data, (list, tuple)):
+    if isinstance(data[0], (list, tuple)):
+      data = [np.asarray(d) for d in data]
+    elif len(names) == 1 and isinstance(data[0], (float, int)):
+      data = [np.asarray(data)]
+    else:
+      data = [
+          x.values if x.__class__.__name__ == 'DataFrame' else x for x in data
+      ]
+  else:
+    data = data.values if data.__class__.__name__ == 'DataFrame' else data
+    data = [data]
+
+  if shapes is not None:
+    data = [
+        standardize_single_array(x, shape) for (x, shape) in zip(data, shapes)
+    ]
+  else:
+    data = [standardize_single_array(x) for x in data]
+
+  if len(data) != len(names):
+    if data and hasattr(data[0], 'shape'):
+      raise ValueError('Error when checking model ' + exception_prefix +
+                       ': the list of Numpy arrays that you are passing to '
+                       'your model is not the size the model expected. '
+                       'Expected to see ' + str(len(names)) + ' array(s), ' +
+                       'for inputs ' + str(names) + ' but instead got the '
+                       'following list of ' + str(len(data)) + ' arrays: ' +
+                       str(data)[:200] + '...')
+    elif len(names) > 1:
+      raise ValueError('Error when checking model ' + exception_prefix +
+                       ': you are passing a list as input to your model, '
+                       'but the model expects a list of ' + str(len(names)) +
+                       ' Numpy arrays instead. The list you passed was: ' +
+                       str(data)[:200])
+    elif len(data) == 1 and not hasattr(data[0], 'shape'):
+      raise TypeError('Error when checking model ' + exception_prefix +
+                      ': data should be a Numpy array, or list/dict of '
+                      'Numpy arrays. Found: ' + str(data)[:200] + '...')
+    elif len(names) == 1:
+      data = [np.asarray(data)]
+
+  # Check shapes compatibility.
+  if shapes:
+    for i in range(len(names)):
+      if shapes[i] is not None:
+        if tensor_util.is_tensor(data[i]):
+          tensorshape = data[i].shape
+          if not tensorshape:
+            continue
+          data_shape = tuple(tensorshape.as_list())
+        elif composite_tensor_utils.is_composite_or_composite_value(data[i]):
+          tensorshape = composite_tensor_utils.get_shape(data[i])
+          data_shape = tuple(tensorshape.as_list())
+        else:
+          data_shape = data[i].shape
+
+        shape = shapes[i]
+        if len(data_shape) != len(shape):
+          raise ValueError('Error when checking ' + exception_prefix +
+                           ': expected ' + names[i] + ' to have ' +
+                           str(len(shape)) + ' dimensions, but got array '
+                           'with shape ' + str(data_shape))
+        if not check_batch_axis:
+          data_shape = data_shape[1:]
+          shape = shape[1:]
+        for dim, ref_dim in zip(data_shape, shape):
+          if ref_dim != dim and ref_dim is not None and dim is not None:
+            raise ValueError('Error when checking ' + exception_prefix +
+                             ': expected ' + names[i] + ' to have shape ' +
+                             str(shape) + ' but got array with shape ' +
+                             str(data_shape))
+  return data
+
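+# For example (hypothetical input names): for a model with inputs ['a', 'b'],
+# standardize_input_data({'b': y, 'a': x}, ['a', 'b']) returns [x, y],
+# expanding 1-D arrays to shape (n, 1) and checking shapes when `shapes` is
+# provided.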
+
+def standardize_sample_or_class_weights(x_weight, output_names, weight_type):
+  """Maps `sample_weight` or `class_weight` to model outputs.
+
+  Arguments:
+      x_weight: User-provided `sample_weight` or `class_weight` argument.
+      output_names: List of output names (strings) in the model.
+      weight_type: A string used purely for exception printing.
+
+  Returns:
+      A list of `sample_weight` or `class_weight` where there is exactly
+          one element per model output.
+
+  Raises:
+      ValueError: In case of invalid user-provided argument.
+  """
+  if x_weight is None or (isinstance(x_weight, (list, tuple)) and
+                          len(x_weight) == 0):  # pylint: disable=g-explicit-length-test
+    return [None for _ in output_names]
+  if len(output_names) == 1:
+    if isinstance(x_weight, (list, tuple)) and len(x_weight) == 1:
+      return x_weight
+    if isinstance(x_weight, dict) and output_names[0] in x_weight:
+      return [x_weight[output_names[0]]]
+    else:
+      return [x_weight]
+  if isinstance(x_weight, (list, tuple)):
+    if len(x_weight) != len(output_names):
+      raise ValueError('Provided `' + weight_type + '` was a list of ' +
+                       str(len(x_weight)) + ' elements, but the model has ' +
+                       str(len(output_names)) + ' outputs. '
+                       'You should provide one `' + weight_type + '` '
+                       'array per model output.')
+    return x_weight
+  if isinstance(x_weight, collections_abc.Mapping):
+    generic_utils.check_for_unexpected_keys(weight_type, x_weight, output_names)
+    x_weights = []
+    for name in output_names:
+      x_weights.append(x_weight.get(name))
+    return x_weights
+  else:
+    raise TypeError('The model has multiple outputs, so `' + weight_type + '` '
+                    'should be either a list or a dict. '
+                    'Provided `' + weight_type + '` type not understood: ' +
+                    str(x_weight))
+
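+# For example (hypothetical output names): with a single output, a bare array
+# w is returned as [w]; with outputs ['a', 'b'], the dict {'a': wa, 'b': wb}
+# is mapped to [wa, wb] in output order.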
+
+def standardize_class_weights(class_weight, output_names):
+  return standardize_sample_or_class_weights(class_weight, output_names,
+                                             'class_weight')
+
+
+def standardize_sample_weights(sample_weight, output_names):
+  return standardize_sample_or_class_weights(sample_weight, output_names,
+                                             'sample_weight')
+
+
 def handle_partial_sample_weights(outputs, sample_weights, sample_weight_modes,
                                   check_all_flat=False):
   """Adds 1.0 as sample weights for the outputs for which there is no weight.
@@ -121,6 +697,506 @@ def handle_partial_sample_weights(outputs, sample_weights, sample_weight_modes,
           any_sample_weight, partial_sample_weight)
 
 
+def check_array_lengths(inputs, targets, weights=None):
+  """Does user input validation for numpy arrays.
+
+  Arguments:
+      inputs: list of Numpy arrays of inputs.
+      targets: list of Numpy arrays of targets.
+      weights: list of Numpy arrays of sample weights.
+
+  Raises:
+      ValueError: in case of incorrectly formatted data.
+  """
+
+  def is_tensor_or_composite_tensor(x):
+    return tensor_util.is_tensor(
+        x) or composite_tensor_utils.is_composite_or_composite_value(x)
+
+  def set_of_lengths(x):
+    # Returns the set of first-dimension lengths of the arrays in x,
+    # or an empty set if x is None.
+    if x is None:
+      return {}
+    else:
+      return set([
+          y.shape[0]
+          for y in x
+          if y is not None and not is_tensor_or_composite_tensor(y)
+      ])
+
+  set_x = set_of_lengths(inputs)
+  set_y = set_of_lengths(targets)
+  set_w = set_of_lengths(weights)
+  if len(set_x) > 1:
+    raise ValueError('All input arrays (x) should have '
+                     'the same number of samples. Got array shapes: ' +
+                     str([x.shape for x in inputs]))
+  if len(set_y) > 1:
+    raise ValueError('All target arrays (y) should have '
+                     'the same number of samples. Got array shapes: ' +
+                     str([y.shape for y in targets]))
+  if set_x and set_y and list(set_x)[0] != list(set_y)[0]:
+    raise ValueError('Input arrays should have '
+                     'the same number of samples as target arrays. '
+                     'Found ' + str(list(set_x)[0]) + ' input samples '
+                     'and ' + str(list(set_y)[0]) + ' target samples.')
+  if len(set_w) > 1:
+    raise ValueError('All sample_weight arrays should have '
+                     'the same number of samples. Got array shapes: ' +
+                     str([w.shape for w in weights]))
+  if set_y and set_w and list(set_y)[0] != list(set_w)[0]:
+    raise ValueError('Sample_weight arrays should have '
+                     'the same number of samples as target arrays. Got ' +
+                     str(list(set_y)[0]) + ' target samples and ' +
+                     str(list(set_w)[0]) + ' sample_weight samples.')
+
+
+def check_loss_and_target_compatibility(targets, loss_fns, output_shapes):
+  """Does validation on the compatibility of targets and loss functions.
+
+  This helps prevent users from using loss functions incorrectly. This check
+  is purely for UX purposes.
+
+  Arguments:
+      targets: list of Numpy arrays of targets.
+      loss_fns: list of loss functions.
+      output_shapes: list of shapes of model outputs.
+
+  Raises:
+      ValueError: if a loss function or target array
+          is incompatible with an output.
+  """
+  key_loss_fns = {
+      losses.mean_squared_error, losses.binary_crossentropy,
+      losses.categorical_crossentropy
+  }
+  key_loss_classes = (losses.MeanSquaredError, losses.BinaryCrossentropy,
+                      losses.CategoricalCrossentropy)
+  for y, loss, shape in zip(targets, loss_fns, output_shapes):
+    if y is None or loss is None or tensor_util.is_tensor(y):
+      continue
+    if losses.is_categorical_crossentropy(loss):
+      if y.shape[-1] == 1:
+        raise ValueError('You are passing a target array of shape ' +
+                         str(y.shape) +
+                         ' while using as loss `categorical_crossentropy`. '
+                         '`categorical_crossentropy` expects '
+                         'targets to be binary matrices (1s and 0s) '
+                         'of shape (samples, classes). '
+                         'If your targets are integer classes, '
+                         'you can convert them to the expected format via:\n'
+                         '```\n'
+                         'from keras.utils import to_categorical\n'
+                         'y_binary = to_categorical(y_int)\n'
+                         '```\n'
+                         '\n'
+                         'Alternatively, you can use the loss function '
+                         '`sparse_categorical_crossentropy` instead, '
+                         'which does expect integer targets.')
+
+    is_loss_wrapper = isinstance(loss, losses.LossFunctionWrapper)
+    if (isinstance(loss, key_loss_classes) or (is_loss_wrapper and
+                                               (loss.fn in key_loss_fns))):
+      for target_dim, out_dim in zip(y.shape[1:], shape[1:]):
+        if out_dim is not None and target_dim != out_dim:
+          loss_name = loss.name
+          if loss_name is None:
+            loss_type = loss.fn if is_loss_wrapper else type(loss)
+            loss_name = loss_type.__name__
+          raise ValueError('A target array with shape ' + str(y.shape) +
+                           ' was passed for an output of shape ' + str(shape) +
+                           ' while using as loss `' + loss_name + '`. '
+                           'This loss expects targets to have the same shape '
+                           'as the output.')
+
+
+def collect_per_output_metric_info(metrics,
+                                   output_names,
+                                   output_shapes,
+                                   loss_fns,
+                                   is_weighted=False):
+  """Maps metric names and functions to model outputs.
+
+  Arguments:
+      metrics: a list or a list of lists or a dict of metric functions.
+      output_names: a list of the names (strings) of model outputs.
+      output_shapes: a list of the shapes of model outputs.
+      loss_fns: a list of the loss functions corresponding to the model outputs.
+      is_weighted: Boolean indicating whether the given metrics are weighted.
+
+  Returns:
+      A list (one entry per model output) of dicts.
+      For instance, if the model has 2 outputs, and for the first output
+      we want to compute "binary_accuracy" and "binary_crossentropy",
+      and just "binary_accuracy" for the second output,
+      the list would look like: `[{
+          'acc': binary_accuracy(),
+          'ce': binary_crossentropy(),
+        }, {
+          'acc': binary_accuracy(),
+        }]`
+
+  Raises:
+      TypeError: if an incorrect type is passed for the `metrics` argument.
+  """
+  if not metrics:
+    return [{} for _ in output_names]
+
+  if isinstance(metrics, list):
+    any_sub_list = any(isinstance(m, list) for m in metrics)
+    if any_sub_list:
+      if len(metrics) != len(output_names):
+        raise ValueError('When passing a list of lists as `metrics`, '
+                         'it should have one entry per model output. '
+                         'The model has ' + str(len(output_names)) +
+                         ' outputs, but you passed metrics=' + str(metrics))
+      # User has provided a list of len = len(outputs).
+      nested_metrics = [generic_utils.to_list(m) for m in metrics]
+    else:
+      # If it is a single list we then apply all metrics to all outputs.
+      if len(output_names) > 1:
+        nested_metrics = []
+        for _ in output_names:
+          nested_metrics.append(
+              [metrics_module.clone_metric(m) for m in metrics])
+      else:
+        nested_metrics = [metrics]
+  elif isinstance(metrics, collections_abc.Mapping):
+    generic_utils.check_for_unexpected_keys('metrics', metrics, output_names)
+    nested_metrics = []
+    for name in output_names:
+      output_metrics = generic_utils.to_list(metrics.get(name, []))
+      nested_metrics.append(output_metrics)
+  else:
+    raise TypeError('Type of `metrics` argument not understood. '
+                    'Expected a list or dictionary, found: ' + str(metrics))
+
+  per_output_metrics = []
+  for i, metrics in enumerate(nested_metrics):
+    metrics_dict = OrderedDict()
+    for metric in metrics:
+      metric_name = get_metric_name(metric, is_weighted)
+      metric_fn = get_metric_function(
+          metric, output_shape=output_shapes[i], loss_fn=loss_fns[i])
+
+      # If the metric function is not stateful, we create a stateful version.
+      if not isinstance(metric_fn, metrics_module.Metric):
+        metric_fn = metrics_module.MeanMetricWrapper(
+            metric_fn, name=metric_name)
+      metrics_dict[metric_name] = metric_fn
+    per_output_metrics.append(metrics_dict)
+
+  return per_output_metrics
+
+
+def batch_shuffle(index_array, batch_size):
+  """Shuffles an array in a batch-wise fashion.
+
+  Useful for shuffling HDF5 arrays
+  (where one cannot access arbitrary indices).
+
+  Arguments:
+      index_array: array of indices to be shuffled.
+      batch_size: integer.
+
+  Returns:
+      The `index_array` array, shuffled in a batch-wise fashion.
+  """
+  batch_count = int(len(index_array) / batch_size)
+  # To reshape, the length must be cleanly divisible by the batch size, so we
+  # stash the extra items and re-append them after shuffling.
+  last_batch = index_array[batch_count * batch_size:]
+  index_array = index_array[:batch_count * batch_size]
+  index_array = index_array.reshape((batch_count, batch_size))
+  np.random.shuffle(index_array)
+  index_array = index_array.flatten()
+  return np.append(index_array, last_batch)
+
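+# For example (hypothetical sizes): batch_shuffle(np.arange(7), 3) shuffles
+# the full batches [0, 1, 2] and [3, 4, 5] as units (order within each batch
+# is preserved) and re-appends the leftover [6] at the end.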
+
+def standardize_weights(y,
+                        sample_weight=None,
+                        class_weight=None,
+                        sample_weight_mode=None):
+  """Performs sample weight validation and standardization.
+
+  Everything gets normalized to a single sample-wise (or timestep-wise)
+  weight array. If both `sample_weight` and `class_weight` are provided,
+  the weights are multiplied.
+
+  Arguments:
+      y: Numpy array or Tensor of model targets to be weighted.
+      sample_weight: User-provided `sample_weight` argument.
+      class_weight: User-provided `class_weight` argument.
+      sample_weight_mode: One of `None` or `"temporal"`. `"temporal"` indicates
+        that we expect 2D weight data that will be applied to the last 2
+        dimensions of the targets (i.e. we are weighting timesteps, not
+        samples).
+
+  Returns:
+      A numpy array of target weights, one entry per sample to weight.
+
+  Raises:
+      ValueError: In case of invalid user-provided arguments.
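+
+  Example (mirrors the class-weight case in the unit tests):
+
+  ```python
+  y = np.array([0, 1, 0, 0, 2])
+  class_weights = {0: 0.5, 1: 1., 2: 1.5}
+  # Each sample receives the weight of its class:
+  # array([0.5, 1., 0.5, 0.5, 1.5])
+  standardize_weights(y, class_weight=class_weights)
+  ```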
+  """
+  # Iterator may return sample_weight as 1-tuple
+  if isinstance(sample_weight, tuple):
+    sample_weight = sample_weight[0]
+  if sample_weight_mode is not None and sample_weight_mode != 'samplewise':
+    if sample_weight_mode != 'temporal':
+      raise ValueError('`sample_weight_mode` '
+                       'should be None or "temporal". '
+                       'Found: ' + str(sample_weight_mode))
+    if len(y.shape) < 3:
+      raise ValueError('Found a sample_weight array for '
+                       'an input with shape ' + str(y.shape) + '. '
+                       'Timestep-wise sample weighting (use of '
+                       'sample_weight_mode="temporal") is restricted to '
+                       'outputs that are at least 3D, i.e. that have '
+                       'a time dimension.')
+    if sample_weight is not None and len(sample_weight.shape) != 2:
+      raise ValueError('Found a sample_weight array with shape ' +
+                       str(sample_weight.shape) + '. '
+                       'In order to use timestep-wise sample weighting, '
+                       'you should pass a 2D sample_weight array.')
+  else:
+    if sample_weight is not None and len(sample_weight.shape) != 1:
+      raise ValueError('Found a sample_weight array with shape {}. In order to '
+                       'use timestep-wise sample weights, you should specify '
+                       'sample_weight_mode="temporal" in compile(); found "{}" '
+                       'instead. If you just mean to use sample-wise weights, '
+                       'make sure your sample_weight array is 1D.'
+                       .format(sample_weight.shape, sample_weight_mode))
+
+  if sample_weight is not None:
+    if len(sample_weight.shape) > len(y.shape):
+      raise ValueError('Found a sample_weight with shape ' +
+                       str(sample_weight.shape) + '. '
+                       'Expected sample_weight with rank '
+                       'less than or equal to ' + str(len(y.shape)))
+
+    if (not tensor_util.is_tensor(sample_weight) and
+        y.shape[:sample_weight.ndim] != sample_weight.shape):
+      raise ValueError('Found a sample_weight array with shape ' +
+                       str(sample_weight.shape) + ' for an input with shape ' +
+                       str(y.shape) + '. '
+                       'sample_weight cannot be broadcast.')
+
+  # Class weights applied per-sample.
+  class_sample_weight = None
+  if isinstance(class_weight, dict):
+    if len(y.shape) > 2:
+      raise ValueError('`class_weight` not supported for '
+                       '3+ dimensional targets.')
+
+    if tensor_util.is_tensor(y):
+      # Few classes are expected, so densifying is reasonable.
+      keys = np.array(sorted(class_weight.keys()))
+      values = np.array([class_weight[i] for i in keys])
+      weight_vector = np.zeros(np.max(keys) + 1)
+      weight_vector[:] = np.nan
+      weight_vector[keys] = values
+
+      y_classes = smart_cond.smart_cond(
+          len(y.shape.as_list()) == 2 and K.shape(y)[1] > 1,
+          lambda: K.argmax(y, axis=1),
+          lambda: math_ops.cast(K.reshape(y, (-1,)), dtypes.int64))
+      class_sample_weight = array_ops.gather(weight_vector, y_classes)
+      gen_array_ops.check_numerics(
+          class_sample_weight,
+          'Invalid classes or class weights detected. NaN values indicate that '
+          'an appropriate class weight could not be determined.')
+      class_sample_weight = math_ops.cast(class_sample_weight, K.floatx())
+      if sample_weight is not None:
+        sample_weight = math_ops.cast(
+            ops.convert_to_tensor_v2_with_dispatch(sample_weight), K.floatx())
+    else:
+      y_classes = y
+      if len(y.shape) == 2:
+        if y.shape[1] > 1:
+          y_classes = np.argmax(y, axis=1)
+        elif y.shape[1] == 1:
+          y_classes = np.reshape(y, y.shape[0])
+
+      class_sample_weight = np.asarray(
+          [class_weight[cls] for cls in y_classes if cls in class_weight])
+
+      if len(class_sample_weight) != len(y_classes):
+        # subtract the sets to pick all missing classes
+        existing_classes = set(y_classes)
+        existing_class_weight = set(class_weight.keys())
+        raise ValueError(
+            '`class_weight` must contain all classes in the data.'
+            ' The classes %s exist in the data but not in '
+            '`class_weight`.' % (existing_classes - existing_class_weight))
+
+  if class_sample_weight is not None and sample_weight is not None:
+    # Multiply weights if both are provided.
+    return class_sample_weight * sample_weight
+  if sample_weight is not None:
+    return sample_weight
+  if class_sample_weight is not None:
+    return class_sample_weight
+  return None
+
+
+def has_symbolic_tensors(ls):
+  if context.executing_eagerly():
+    return False
+  return has_tensors(ls)
+
+
+def has_tensors(ls):
+  """Returns true if `ls` contains tensors."""
+  # Note: at some point in time ragged tensors didn't count as tensors, so this
+  # returned false for ragged tensors. Making this return true fails some tests
+  # which would then require a steps_per_epoch argument.
+  if isinstance(ls, (list, tuple)):
+    return any(
+        tensor_util.is_tensor(v) and
+        not isinstance(v, ragged_tensor.RaggedTensor) for v in ls)
+  if isinstance(ls, dict):
+    return any(
+        tensor_util.is_tensor(v) and
+        not isinstance(v, ragged_tensor.RaggedTensor)
+        for _, v in six.iteritems(ls))
+  return tensor_util.is_tensor(ls) and not isinstance(
+      ls, ragged_tensor.RaggedTensor)
+
+
+def get_metric_name(metric, weighted=False):
+  """Returns the name corresponding to the given metric input.
+
+  Arguments:
+    metric: Metric function name or reference.
+    weighted: Boolean indicating if the given metric is weighted.
+
+  Returns:
+      The metric name.
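+
+  Example:
+
+  ```python
+  get_metric_name('acc')                 # 'acc' if TF2 behavior is enabled.
+  get_metric_name('acc', weighted=True)  # 'weighted_acc' if it is disabled.
+  ```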
+  """
+  if tf2.enabled():
+    # We keep the string that the user has set in compile as the metric name.
+    if isinstance(metric, six.string_types):
+      return metric
+
+    metric = metrics_module.get(metric)
+    return metric.name if hasattr(metric, 'name') else metric.__name__
+  else:
+    metric_name_prefix = 'weighted_' if weighted else ''
+    if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
+      if metric in ('accuracy', 'acc'):
+        suffix = 'acc'
+      elif metric in ('crossentropy', 'ce'):
+        suffix = 'ce'
+    else:
+      metric_fn = metrics_module.get(metric)
+      # Get metric name as string
+      if hasattr(metric_fn, 'name'):
+        suffix = metric_fn.name
+      else:
+        suffix = metric_fn.__name__
+    metric_name = metric_name_prefix + suffix
+    return metric_name
+
+
+def get_metric_function(metric, output_shape=None, loss_fn=None):
+  """Returns the metric function corresponding to the given metric input.
+
+  Arguments:
+      metric: Metric function name or reference.
+      output_shape: The shape of the output that this metric will be calculated
+        for.
+      loss_fn: The loss function used.
+
+  Returns:
+      The metric function.
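+
+  Example (with illustrative output shapes):
+
+  ```python
+  # 'accuracy' resolves to binary accuracy for a single-unit output...
+  get_metric_function('accuracy', output_shape=(None, 1))
+  # ...and to categorical accuracy for a multi-unit (one-hot) output.
+  get_metric_function('accuracy', output_shape=(None, 10))
+  ```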
+  """
+  if metric not in ['accuracy', 'acc', 'crossentropy', 'ce']:
+    return metrics_module.get(metric)
+
+  is_sparse_categorical_crossentropy = (
+      isinstance(loss_fn, losses.SparseCategoricalCrossentropy) or
+      (isinstance(loss_fn, losses.LossFunctionWrapper) and
+       loss_fn.fn == losses.sparse_categorical_crossentropy))
+
+  is_binary_crossentropy = (
+      isinstance(loss_fn, losses.BinaryCrossentropy) or
+      (isinstance(loss_fn, losses.LossFunctionWrapper) and
+       loss_fn.fn == losses.binary_crossentropy))
+
+  if metric in ['accuracy', 'acc']:
+    if output_shape[-1] == 1 or is_binary_crossentropy:
+      return metrics_module.binary_accuracy
+    elif is_sparse_categorical_crossentropy:
+      return metrics_module.sparse_categorical_accuracy
+    # If the output_shape[-1] is not 1, then we know output is `categorical`.
+    # We assume it is sparse categorical only if loss is explicitly given
+    # as sparse categorical crossentropy loss.
+    return metrics_module.categorical_accuracy
+  else:
+    if output_shape[-1] == 1 or is_binary_crossentropy:
+      return metrics_module.binary_crossentropy
+    elif is_sparse_categorical_crossentropy:
+      return metrics_module.sparse_categorical_crossentropy
+    return metrics_module.categorical_crossentropy
+
+
+def call_metric_function(metric_fn,
+                         y_true,
+                         y_pred=None,
+                         weights=None,
+                         mask=None):
+  """Invokes metric function and returns the metric result tensor."""
+  if mask is not None:
+    mask = math_ops.cast(mask, y_pred.dtype)
+    if weights is None:
+      # Use mask as sample weight.
+      weights = mask
+    else:
+      # Update dimensions of weights to match with mask.
+      weights = math_ops.cast(weights, dtype=y_pred.dtype)
+      mask, _, weights = losses_utils.squeeze_or_expand_dimensions(
+          mask, sample_weight=weights)
+      weights *= mask
+
+  if y_pred is not None:
+    return metric_fn(y_true, y_pred, sample_weight=weights)
+  # `Mean` metric only takes a single value.
+  return metric_fn(y_true, sample_weight=weights)
+
+
+def get_loss_function(loss):
+  """Returns the loss corresponding to the loss input in `compile` API."""
+  if loss is None or isinstance(loss, losses.Loss):
+    return loss
+
+  if tf_inspect.isclass(loss) and issubclass(loss, losses.Loss):
+    # It is not safe to assume that the loss takes no constructor arguments.
+    raise ValueError(
+        'Received uninstantiated Loss class: {}\nPlease call loss classes '
+        'before passing them to Model.compile.'.format(loss))
+
+  # Deserialize loss configuration, if needed.
+  if isinstance(loss, collections_abc.Mapping):
+    loss = losses.get(loss)
+
+  # Custom callable class.
+  if callable(loss) and not hasattr(loss, '__name__'):
+    return loss
+
+  # Wrap loss function with signature `(y_true, y_pred, **kwargs)`
+  # in `LossFunctionWrapper` class.
+  loss_fn = losses.get(loss)
+
+  # For losses which are given as strings/functions in the compile API,
+  # we always set the loss reduction type to be `SUM_OVER_BATCH_SIZE`
+  # (both in distribution strategy context and otherwise).
+  return losses.LossFunctionWrapper(
+      loss_fn,
+      name=loss_fn.__name__,
+      reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE)
+
+
 class RespectCompiledTrainableState(object):
   """Set and restore trainable state if it has changed since compile.
 
@@ -167,6 +1243,568 @@ class RespectCompiledTrainableState(object):
     return False  # False values do not suppress exceptions
 
 
+def validate_dataset_input(x, y, sample_weight, validation_split=None):
+  """Validates user input arguments when a dataset iterator is passed.
+
+  Arguments:
+    x: Input data. A `tf.data` dataset or iterator.
+    y: Target data. It could be either Numpy array(s) or TensorFlow tensor(s).
+      Expected to be `None` when `x` is a dataset iterator.
+    sample_weight: An optional sample-weight array passed by the user to weight
+      the importance of each sample in `x`. Expected to be `None` when `x` is a
+      dataset iterator.
+    validation_split: Float between 0 and 1. Fraction of the training data to be
+      used as validation data. Expected to be `None` when `x` is a dataset
+      iterator.
+
+  Raises:
+    ValueError: if any of the arguments `y`, `sample_weight` or
+        `validation_split` is provided by the user.
+  """
+  if y is not None:
+    raise ValueError('You passed a dataset or dataset iterator (%s) as '
+                     'input `x` to your model. In that case, you should '
+                     'not specify a target (`y`) argument, since the dataset '
+                     'or dataset iterator generates both input data and '
+                     'target data. '
+                     'Received: %s' % (x, y))
+  if sample_weight is not None:
+    raise ValueError('`sample_weight` argument is not supported when input '
+                     '`x` is a dataset or a dataset iterator. Instead, you '
+                     'can provide sample_weight as the third element of your '
+                     'dataset, i.e. (inputs, targets, sample_weight). '
+                     'Received: x=%s, sample_weight=%s' % (x, sample_weight))
+  if validation_split is not None and validation_split != 0.0:
+    raise ValueError(
+        '`validation_split` argument is not supported when '
+        'input `x` is a dataset or a dataset iterator. '
+        'Received: x=%s, validation_split=%f' % (x, validation_split))
+
+
+def validate_input_types(inp, orig_inp, allow_dict=True, field_name='inputs'):
+  """Helper function to validate either inputs or targets."""
+  if isinstance(inp, (list, tuple)):
+    if not all(isinstance(v, np.ndarray) or
+               tensor_util.is_tensor(v) for v in inp):
+      raise ValueError(
+          'Please provide as model inputs either a single array or a list of '
+          'arrays. You passed: {}={}'.format(field_name, str(orig_inp)))
+  elif isinstance(inp, dict):
+    if not allow_dict:
+      raise ValueError(
+          'You cannot pass a dictionary as model {}.'.format(field_name))
+  elif not isinstance(inp, np.ndarray) and not tensor_util.is_tensor(inp):
+    raise ValueError(
+        'Please provide as model inputs either a single array or a list of '
+        'arrays. You passed: {}={}'.format(field_name, orig_inp))
+
+
+def check_generator_arguments(y=None, sample_weight=None,
+                              validation_split=None):
+  """Validates arguments passed when using a generator."""
+  if y is not None:
+    raise ValueError('`y` argument is not supported when data is '
+                     'a generator or Sequence instance. Instead pass targets '
+                     'as the second element of the generator.')
+  if sample_weight is not None:
+    raise ValueError('`sample_weight` argument is not supported when data is '
+                     'a generator or Sequence instance. Instead pass sample'
+                     ' weights as the third element of the generator.')
+  if validation_split:
+    raise ValueError('If your data is in the form of a Python generator, '
+                     'you cannot use `validation_split`.')
+
+
+def check_steps_argument(input_data, steps, steps_name):
+  """Validates `steps` argument based on input data's type.
+
+  The cases when `steps` value must be provided are when
+    1. input data passed is an iterator.
+    2. model was built on top of symbolic tensors, input data is not
+       required and is `None`.
+    3. input data passed is a symbolic tensor.
+
+  Arguments:
+      input_data: Input data. Can be Numpy array(s) or TensorFlow tensor(s) or
+        tf.data.Dataset iterator or `None`.
+      steps: Integer or `None`. Total number of steps (batches of samples) to
+        execute.
+      steps_name: The public API's parameter name for `steps`.
+
+  Returns:
+    boolean, True if `steps` argument is required, else False.
+
+  Raises:
+      ValueError: if `steps` argument is required for given input data type
+        but not provided.
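+
+  Example (with an illustrative Numpy input):
+
+  ```python
+  x = np.ones((10, 3))
+  # Numpy input does not require `steps`, so this returns False.
+  check_steps_argument(x, steps=None, steps_name='steps')
+  ```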
+  """
+  is_x_iterator = isinstance(
+      input_data, (iterator_ops.Iterator, iterator_ops.OwnedIterator))
+  if (input_data is None or is_x_iterator or has_symbolic_tensors(input_data) or
+      (isinstance(input_data, list) and not input_data)):
+    if steps is None:
+      input_type_str = 'a Dataset iterator' if is_x_iterator else 'data tensors'
+      raise ValueError('When using {input_type} as input to a model, you should'
+                       ' specify the `{steps_name}` argument.'.format(
+                           input_type=input_type_str, steps_name=steps_name))
+    return True
+
+  if isinstance(input_data, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
+    return True
+
+  if steps is not None:
+    list_types = (np.ndarray, list, tuple)
+    if (isinstance(input_data, list_types) or
+        (isinstance(input_data, dict) and
+         any(isinstance(v, list_types) for v in input_data.values()))):
+      logging.warning('When passing input data as arrays, do not specify '
+                      '`steps_per_epoch`/`steps` argument. '
+                      'Please use `batch_size` instead.')
+  return False
+
+
+def cast_single_tensor(x, dtype=None):
+  if isinstance(x, np.ndarray):
+    x = ops.convert_to_tensor_v2_with_dispatch(x)
+  dtype = dtype or K.floatx()
+  if x.dtype.is_floating:
+    return math_ops.cast(x, dtype=dtype)
+  return x
+
+
+def cast_if_floating_dtype_and_mismatch(targets, outputs):
+  """Returns target data tensors using correct datatype.
+
+  Checks that each target and output pair have the same datatype. If not, casts
+  the target to the output's datatype.
+
+  Args:
+    targets: tensor or list of targets.
+    outputs: tensor or list of outputs.
+
+  Returns:
+    Targets in appropriate datatype.
+  """
+  if tensor_util.is_tensor(targets):
+    # There is one target, so output[0] should be the only output.
+    return cast_single_tensor(targets, dtype=outputs[0].dtype)
+  new_targets = []
+  for target, out in zip(targets, outputs):
+    if isinstance(target, np.ndarray):
+      target = ops.convert_to_tensor_v2_with_dispatch(target)
+    if target.dtype != out.dtype:
+      new_targets.append(cast_single_tensor(target, dtype=out.dtype))
+    else:
+      new_targets.append(target)
+  return new_targets
+
+
+def cast_if_floating_dtype(x, dtype=None):
+  """Casts the given data tensors to the default floating point type.
+
+  Casts only if the input is already a floating point type.
+
+  Args:
+    x: tensor or list/tuple of tensors.
+    dtype: The dtype to which Tensors should be cast.
+
+  Returns:
+    Converted input.
+  """
+  return nest.map_structure(functools.partial(cast_single_tensor, dtype=dtype),
+                            x)
+
+
+def cast_to_model_input_dtypes(x, model):
+  """Casts the given data tensors to the dtypes of the model inputs.
+
+  Args:
+    x: tensor or list/tuple of tensors.
+    model: The model.
+
+  Returns:
+    Converted input. Each tensor is casted to the corresponding input in
+    `model.inputs`.
+  """
+  input_dtypes = nest.map_structure(lambda t: t.dtype, model.inputs)
+  return nest.map_structure(math_ops.cast, x, input_dtypes)
+
+
+def prepare_sample_weight_modes(training_endpoints, sample_weight_mode):
+  """Prepares sample weight modes for the model.
+
+  Args:
+    training_endpoints: List of model _TrainingEndpoints.
+    sample_weight_mode: sample weight mode user input passed from compile API.
+
+  Raises:
+    ValueError: In case of invalid `sample_weight_mode` input.
+  """
+
+  if isinstance(sample_weight_mode, collections_abc.Mapping):
+    generic_utils.check_for_unexpected_keys(
+        'sample_weight_mode', sample_weight_mode,
+        [e.output_name for e in training_endpoints])
+
+    for end_point in training_endpoints:
+      if not end_point.should_skip_target_weights():
+        if end_point.output_name not in sample_weight_mode:
+          raise ValueError('Output ' + end_point.output_name +
+                           ' missing from `sample_weight_mode` dictionary.')
+        else:
+          end_point.sample_weight_mode = sample_weight_mode.get(
+              end_point.output_name)
+  elif isinstance(sample_weight_mode, (list, tuple)):
+    if len(sample_weight_mode) != len(training_endpoints):
+      raise ValueError('When passing a list as sample_weight_mode, '
+                       'it should have one entry per model output. '
+                       'The model has ' + str(len(training_endpoints)) +
+                       ' outputs, but you passed ' +
+                       str(len(sample_weight_mode)) + ' sample_weight_modes.')
+    for mode, endpoint in zip(sample_weight_mode, training_endpoints):
+      if not endpoint.should_skip_target_weights():
+        endpoint.sample_weight_mode = mode
+  else:
+    for endpoint in training_endpoints:
+      if not endpoint.should_skip_target_weights():
+        endpoint.sample_weight_mode = sample_weight_mode
+
+
+def prepare_loss_functions(loss, output_names):
+  """Converts loss to a list of loss functions.
+
+  Arguments:
+      loss: String (name of objective function), objective function or
+        `tf.losses.Loss` instance. See `tf.losses`. If the model has multiple
+        outputs, you can use a different loss on each output by passing a
+        dictionary or a list of losses. The loss value that will be minimized by
+        the model will then be the sum of all individual losses.
+      output_names: List of model output names.
+
+  Returns:
+      A list of loss objective functions.
+
+  Raises:
+      ValueError: If loss is a dict with keys not in model output names,
+          or if loss is a list with len not equal to model outputs.
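+
+  Example (the output names below are hypothetical):
+
+  ```python
+  # A single loss string is broadcast to every output; each entry wraps
+  # `mean_squared_error` in a `LossFunctionWrapper`.
+  loss_fns = prepare_loss_functions('mse', ['out_a', 'out_b'])
+  len(loss_fns)  # 2
+  ```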
+  """
+  if isinstance(loss, collections_abc.Mapping):
+    generic_utils.check_for_unexpected_keys('loss', loss, output_names)
+    loss_functions = []
+    for name in output_names:
+      if name not in loss:
+        logging.warning(
+            'Output {0} missing from loss dictionary. We assume '
+            'this was done on purpose. The fit and evaluate APIs will not be '
+            'expecting any data to be passed to {0}.'.format(name))
+      loss_functions.append(get_loss_function(loss.get(name, None)))
+  elif isinstance(loss, six.string_types):
+    loss_functions = [get_loss_function(loss) for _ in output_names]
+  elif isinstance(loss, collections_abc.Sequence):
+    if len(loss) != len(output_names):
+      raise ValueError('When passing a list as loss, it should have one entry '
+                       'per model output. The model has {} outputs, but you '
+                       'passed loss={}'.format(len(output_names), loss))
+    loss_functions = nest.map_structure(get_loss_function, loss)
+  else:
+    loss_functions = [get_loss_function(loss) for _ in range(len(output_names))]
+
+  return loss_functions
+
+
+def prepare_loss_weights(training_endpoints, loss_weights=None):
+  """Converts loss weights to a list of loss weights.
+
+  The result loss weights will be populated on the training endpoint.
+
+  Arguments:
+      training_endpoints: List of model training endpoints.
+      loss_weights: Optional list or dictionary specifying scalar coefficients
+        (Python floats) to weight the loss contributions of different model
+        outputs. The loss value that will be minimized by the model will then be
+        the *weighted sum* of all individual losses, weighted by the
+        `loss_weights` coefficients. If a list, it is expected to have a 1:1
+        mapping to the model's outputs. If a dict, it is expected to map
+        output names (strings) to scalar coefficients.
+
+  Raises:
+      ValueError: If loss weight is a dict with key not in model output names,
+          or if loss is a list with len not equal to model outputs.
+  """
+  if loss_weights is None:
+    for e in training_endpoints:
+      e.loss_weight = 1.
+  elif isinstance(loss_weights, collections_abc.Mapping):
+    generic_utils.check_for_unexpected_keys(
+        'loss_weights', loss_weights,
+        [e.output_name for e in training_endpoints])
+    for e in training_endpoints:
+      e.loss_weight = loss_weights.get(e.output_name, 1.)
+  elif isinstance(loss_weights, list):
+    if len(loss_weights) != len(training_endpoints):
+      raise ValueError('When passing a list as loss_weights, '
+                       'it should have one entry per model output. '
+                       'The model has ' + str(len(training_endpoints)) +
+                       ' outputs, but you passed loss_weights=' +
+                       str(loss_weights))
+    for w, e in zip(loss_weights, training_endpoints):
+      e.loss_weight = w
+  else:
+    raise TypeError('Could not interpret loss_weights argument: ' +
+                    str(loss_weights) + ' - expected a list or a dict.')
+
+
+# TODO(rohanj): This is a hack to get around not depending on feature_column and
+# create a cyclical dependency. Figure out a cleaner solution
+def is_feature_layer(layer):
+  """Returns whether `layer` is a FeatureLayer or not."""
+  return getattr(layer, '_is_feature_layer', False)
+
+
+def is_eager_dataset_or_iterator(data):
+  return context.executing_eagerly() and isinstance(
+      data, (dataset_ops.DatasetV1, dataset_ops.DatasetV2,
+             iterator_ops.OwnedIterator))
+
+
+# pylint: disable=protected-access
+def get_dataset_graph_def(dataset):
+  if context.executing_eagerly():
+    graph_def_str = dataset._as_serialized_graph().numpy()
+  else:
+    graph_def_str = K.get_value(dataset._as_serialized_graph())
+  return graph_pb2.GraphDef().FromString(graph_def_str)
+
+
+def verify_dataset_shuffled(x):
+  """Verifies that the dataset is shuffled.
+
+  Args:
+    x: Dataset passed as an input to the model.
+
+  Returns:
+    boolean, whether the input dataset is shuffled or not.
+  """
+  assert isinstance(x, dataset_ops.DatasetV2)
+  graph_def = get_dataset_graph_def(x)
+  for node in graph_def.node:
+    if node.op.startswith('ShuffleDataset'):
+      return True
+  # Also check graph_def.library.function for ds.interleave or ds.flat_map
+  for function in graph_def.library.function:
+    for node in function.node_def:
+      if node.op.startswith('ShuffleDataset'):
+        return True
+  logging.warning('Expected a shuffled dataset but input dataset `x` is '
+                  'not shuffled. Please invoke `shuffle()` on input dataset.')
+  return False
+
+
+def is_dataset_or_iterator(data):
+  return isinstance(data, (dataset_ops.DatasetV1, dataset_ops.DatasetV2,
+                           iterator_ops.Iterator, iterator_ops.OwnedIterator))
+
+
+def get_iterator(dataset):
+  """Create and initialize an iterator from a dataset."""
+  if context.executing_eagerly():
+    iterator = dataset_ops.make_one_shot_iterator(dataset)
+  else:
+    iterator = dataset_ops.make_initializable_iterator(dataset)
+  initialize_iterator(iterator)
+  return iterator
+
+
+def initialize_iterator(iterator):
+  if not context.executing_eagerly():
+    init_op = iterator.initializer
+    K.get_session((init_op,)).run(init_op)
+
+
+def extract_tensors_from_dataset(dataset):
+  """Extract a tuple of tensors `inputs, targets, sample_weight` from a dataset.
+
+  Arguments:
+    dataset: Dataset instance.
+
+  Returns:
+    Tuple of tensors `x, y, weights`. `y` and `weights` entry may be None.
+  """
+  iterator = get_iterator(dataset)
+  inputs, targets, sample_weight = unpack_iterator_input(iterator)
+  return inputs, targets, sample_weight
+
+
+def unpack_iterator_input(iterator):
+  """Convert a dataset iterator to a tuple of tensors `x, y, sample_weights`.
+
+  Arguments:
+    iterator: Instance of a dataset iterator.
+
+  Returns:
+    Tuple of tensors `x, y, weights`. `y` and `weights` entry may be None.
+  """
+  try:
+    next_element = iterator.get_next()
+  except errors.OutOfRangeError:
+    raise RuntimeError('Your dataset iterator ran out of data. '
+                       'Make sure that your dataset can generate '
+                       'the required number of samples.')
+
+  if isinstance(next_element, (list, tuple)):
+    if len(next_element) not in [2, 3]:
+      raise ValueError(
+          'Please provide model inputs as a list or tuple of 2 or 3 '
+          'elements: (input, target) or (input, target, sample_weights). '
+          'Received %s' % next_element)
+    if len(next_element) == 2:
+      x, y = next_element
+      weights = None
+    else:
+      x, y, weights = next_element
+  else:
+    x = next_element
+    y = None
+    weights = None
+  return x, y, weights
+
+
+def infer_steps_for_dataset(model,
+                            dataset,
+                            steps,
+                            epochs=1,
+                            steps_name='steps'):
+  """Infers steps_per_epoch needed to loop through a dataset.
+
+  Arguments:
+      model: Keras model instance.
+      dataset: Input data of type tf.data.Dataset.
+      steps: Number of steps to draw from the dataset (may be None if unknown).
+      epochs: Number of times to iterate over the dataset.
+      steps_name: The string name of the steps argument, either `steps`,
+        `validation_steps`, or `steps_per_epoch`. Only used for error message
+        formatting.
+
+  Returns:
+    Integer or `None`. Inferred number of steps to loop through the dataset.
+    `None` is returned if 1) the size of the dataset is unknown and `steps` was
+    not specified, or 2) this is multi-worker training and auto sharding is
+    enabled.
+
+  Raises:
+    ValueError: In case of invalid argument values.
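+
+  Example (illustrative; assumes `model` is a Keras model not running in
+  multi-worker mode):
+
+  ```python
+  dataset = dataset_ops.DatasetV2.range(100).batch(10)
+  # The dataset yields 10 batches, so 10 steps are inferred.
+  infer_steps_for_dataset(model, dataset, steps=None)
+  ```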
+  """
+  assert isinstance(dataset, dataset_ops.DatasetV2)
+  if (model._in_multi_worker_mode() and
+      (dataset.options().experimental_distribute.auto_shard_policy !=
+       AutoShardPolicy.OFF)):
+    # If the dataset would be auto-sharded, we should not infer a local
+    # steps_per_epoch due to possibly imbalanced sharding between workers.
+    return None
+
+  size = K.get_value(cardinality.cardinality(dataset))
+  if size == cardinality.INFINITE and steps is None:
+    raise ValueError('When passing an infinitely repeating dataset, you '
+                     'must specify the `%s` argument.' % (steps_name,))
+  if size >= 0:
+    if steps is not None and steps * epochs > size:
+      if epochs > 1:
+        raise ValueError('The dataset you passed contains %s batches, but you '
+                         'passed `epochs=%s` and `%s=%s`, which is a total of '
+                         '%s steps. We cannot draw that many steps from this '
+                         'dataset. We suggest setting `%s=%s`.' %
+                         (size, epochs, steps_name, steps, steps * epochs,
+                          steps_name, size // epochs))
+      else:
+        raise ValueError('The dataset you passed contains %s batches, but you '
+                         'passed `%s=%s`. We cannot draw that many steps from '
+                         'this dataset. We suggest setting `%s=%s`.' %
+                         (size, steps_name, steps, steps_name, size))
+  if steps is None:
+    if size >= 0:
+      return size
+    return None
+  return steps
+
+
+class ModelInputs(object):
+  """Encapsulates model inputs.
+
+  Allows for transforming model inputs while keeping the same structure.
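+
+  Example:
+
+  ```python
+  model_inputs = ModelInputs({'b': np.ones(10), 'a': np.ones(20)})
+  model_inputs.get_input_names()  # ['a', 'b'] (dict keys, sorted).
+  ```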
+  """
+
+  def __init__(self, inputs):
+    self._inputs = inputs
+    self._is_dict = isinstance(self._inputs, dict)
+    self._is_single_input = not isinstance(self._inputs, (list, tuple, dict))
+
+    self._flattened_inputs = []
+    self._input_names = []
+
+    if self._is_dict:
+      for k in sorted(self._inputs.keys()):
+        self._flattened_inputs.append(self._inputs[k])
+        self._input_names.append(k)
+    else:
+      self._flattened_inputs = nest.flatten(self._inputs)
+      self._input_names = [
+          'input_%d' % (i + 1) for i in range(len(self._flattened_inputs))
+      ]
+
+  def get_input_names(self):
+    """Returns keys to name inputs by.
+
+    If the inputs were provided as a list, tuple or single entry, we generate
+    keys of the form 'input_%d'. For a dictionary, we return its sorted keys.
+    """
+    return self._input_names
+
+  def get_symbolic_inputs(self, return_single_as_list=False):
+    """Returns inputs to be set as self.inputs for a model."""
+    # TODO(karmel): There is a side-effect here where what you get
+    # with as_list and as_dict depends on whether you have called this
+    # method first, since it modifies in place.
+    for i, (k, v) in enumerate(zip(self._input_names, self._flattened_inputs)):
+      if isinstance(v, (list, float, int)):
+        v = np.asarray(v)
+        if v.ndim == 1:
+          v = np.expand_dims(v, 1)
+
+      if isinstance(v, (np.ndarray, ops.EagerTensor)):
+        # We fix the placeholder shape except the batch size.
+        # This is suboptimal, but it is the best we can do with the info
+        # we have. The user should call `model._set_inputs(placeholders)`
+        # to specify custom placeholders if the need arises.
+        shape = (None,) + tuple(v.shape[1:])
+        if shape == (None,):
+          shape = (None, 1)
+        dtype = dtypes.as_dtype(v.dtype)
+        if dtype.is_floating:
+          dtype = K.floatx()
+        v = K.placeholder(shape=shape, name=k, dtype=dtype)
+      elif isinstance(v, tensor_spec.TensorSpec):
+        shape = (None,) + tuple(v.shape.as_list()[1:])
+        if shape == (None,):
+          shape = (None, 1)
+        v = K.placeholder(shape=shape, name=k, dtype=v.dtype)
+
+      self._flattened_inputs[i] = v
+
+    if self._is_dict:
+      return dict(zip(self._input_names, self._flattened_inputs))
+    if self._is_single_input and not return_single_as_list:
+      return self._flattened_inputs[0]
+    return self._flattened_inputs
+
+  def as_dict(self):
+    """Yields `(name, value)` pairs over the inputs."""
+    for k, v in zip(self._input_names, self._flattened_inputs):
+      yield k, v
+
+  def as_list(self):
+    """Returns the inputs as a list."""
+    return self._flattened_inputs
+
+
 # Allow use of methods not exposed to the user.
 # pylint: disable=protected-access
 def get_input_shape_and_dtype(layer):
@@ -218,8 +1856,187 @@ def get_static_batch_size(layer):
   return None
 
 
+def generic_output_names(outputs_list):
+  return ['output_%d' % (i + 1) for i in range(len(outputs_list))]
+
+
+def convert_eager_tensors_to_numpy(structure):
+  """Convert every EagerTensor in `structure` to NumPy.
+
+  Arguments:
+    structure: An arbitrary structure of elements to be converted to NumPy
+      arrays.
+
+  Returns:
+    An identical structure with EagerTensors converted to NumPy arrays.
+  """
+
+  def _convert(element):
+    if isinstance(element, ops.EagerTensor):
+      return element.numpy()
+    return element
+
+  return nest.map_structure(_convert, structure)
+
+
 def list_to_tuple(maybe_list):
   """Datasets will stack the list of tensor, so switch them to tuples."""
   if isinstance(maybe_list, list):
     return tuple(maybe_list)
   return maybe_list
+
+
+def should_run_validation(validation_freq, epoch):
+  """Checks if validation should be run this epoch.
+
+  Arguments:
+    validation_freq: Integer or list. If an integer, specifies how many training
+      epochs to run before a new validation run is performed. If a list,
+      specifies the epochs on which to run validation.
+    epoch: Integer, the number of the training epoch just completed.
+
+  Returns:
+    Bool, True if validation should be run.
+
+  Raises:
+    ValueError: if `validation_freq` is an Integer and less than 1, or if
+    it is neither an Integer nor a Container (e.g. a list or tuple).
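+
+  Example:
+
+  ```python
+  # `epoch` is 0-indexed, so `epoch=1` is the second training epoch.
+  should_run_validation(validation_freq=2, epoch=0)       # False
+  should_run_validation(validation_freq=2, epoch=1)       # True
+  should_run_validation(validation_freq=[1, 5], epoch=4)  # True (epoch 5)
+  ```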
+  """
+  # `epoch` is 0-indexed internally but 1-indexed in the public API.
+  one_indexed_epoch = epoch + 1
+
+  if isinstance(validation_freq, int):
+    if validation_freq < 1:
+      raise ValueError('`validation_freq` can not be less than 1.')
+    return one_indexed_epoch % validation_freq == 0
+
+  if not isinstance(validation_freq, collections_abc.Container):
+    raise ValueError('`validation_freq` must be an Integer or '
+                     '`collections_abc.Container` (e.g. list, tuple, etc.)')
+  return one_indexed_epoch in validation_freq
+
+
+def split_training_and_validation_data(x, y, sample_weights, validation_split):
+  """Split input data into train/eval section based on validation_split."""
+  if has_symbolic_tensors(x):
+    raise ValueError('If your data is in the form of symbolic tensors, '
+                     'you cannot use `validation_split`.')
+  if hasattr(x[0], 'shape'):
+    split_at = int(x[0].shape[0] * (1. - validation_split))
+  else:
+    split_at = int(len(x[0]) * (1. - validation_split))
+  x, val_x = (generic_utils.slice_arrays(x, 0, split_at),
+              generic_utils.slice_arrays(x, split_at))
+  y, val_y = (generic_utils.slice_arrays(y, 0, split_at),
+              generic_utils.slice_arrays(y, split_at))
+  if sample_weights:
+    sample_weights, val_sample_weights = (
+        generic_utils.slice_arrays(sample_weights, 0, split_at),
+        generic_utils.slice_arrays(sample_weights, split_at),
+    )
+  else:
+    val_sample_weights = None
+  return x, y, sample_weights, val_x, val_y, val_sample_weights
+
+
+def unpack_validation_data(validation_data, raise_if_ambiguous=True):
+  """Unpacks validation data based on its input type.
+
+  The validation data is left untouched if it is a dataset or a dataset
+  iterator. For other input types (Numpy arrays or tensors), it is unpacked
+  into a tuple of 3 elements: x, y and sample weights.
+
+  Args:
+    validation_data: dataset, dataset iterator, or numpy, tensor tuple.
+    raise_if_ambiguous: boolean on whether to fail if validation_data cannot be
+      parsed. Otherwise simply return validation_data, None, None and defer the
+      decision to the caller.
+
+  Returns:
+    tuple of 3, (x, y, sample_weights) for numpy and tensor input.
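+
+  Example (with arbitrary Numpy arrays):
+
+  ```python
+  x_val = np.ones((4, 3))
+  y_val = np.ones((4, 1))
+  val_x, val_y, val_w = unpack_validation_data((x_val, y_val))
+  # val_w is None because no sample weights were provided.
+  ```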
+  """
+  if (isinstance(validation_data, (iterator_ops.Iterator,
+                                   iterator_ops.OwnedIterator,
+                                   dataset_ops.DatasetV2,
+                                   data_utils.Sequence))
+      or not hasattr(validation_data, '__len__')):
+    val_x = validation_data
+    val_y = None
+    val_sample_weight = None
+  elif len(validation_data) == 2:
+    try:
+      val_x, val_y = validation_data  # pylint: disable=unpacking-non-sequence
+      val_sample_weight = None
+    except ValueError:
+      val_x, val_y, val_sample_weight = validation_data, None, None
+  elif len(validation_data) == 3:
+    try:
+      val_x, val_y, val_sample_weight = validation_data  # pylint: disable=unpacking-non-sequence
+    except ValueError:
+      val_x, val_y, val_sample_weight = validation_data, None, None
+  else:
+    if raise_if_ambiguous:
+      raise ValueError(
+          'When passing a `validation_data` argument, '
+          'it must contain either 2 items (x_val, y_val), '
+          'or 3 items (x_val, y_val, val_sample_weights), '
+          'or alternatively it could be a dataset or a '
+          'dataset iterator. '
+          'However we received `validation_data=%s`' % validation_data)
+    val_x, val_y, val_sample_weight = validation_data, None, None
+  return val_x, val_y, val_sample_weight
+
+
+class TrainingLoop(object):
+  """TrainingLoop is a wrapper class around the training logic.
+
+  This class encapsulates the fit/eval/predict logic for different kinds of
+  data inputs and model configurations.
+
+  Note that TrainingLoop is stateless: it holds no internal fields and can be
+  reused with different models and inputs.
+  """
+
+  def fit(self,
+          model,
+          x=None,
+          y=None,
+          batch_size=None,
+          epochs=1,
+          verbose=1,
+          callbacks=None,
+          validation_split=0.,
+          validation_data=None,
+          shuffle=True,
+          class_weight=None,
+          sample_weight=None,
+          initial_epoch=0,
+          steps_per_epoch=None,
+          validation_steps=None,
+          validation_freq=1,
+          **kwargs):
+    """Train the model with the inputs and targets."""
+    raise NotImplementedError()
+
+  def evaluate(self,
+               model,
+               x=None,
+               y=None,
+               batch_size=None,
+               verbose=1,
+               sample_weight=None,
+               steps=None,
+               callbacks=None,
+               **kwargs):
+    """Returns the loss value & metrics values for the model in test mode."""
+    raise NotImplementedError()
+
+  def predict(self,
+              model,
+              x,
+              batch_size=None,
+              verbose=0,
+              steps=None,
+              callbacks=None,
+              **kwargs):
+    raise NotImplementedError()
diff --git a/tensorflow/python/keras/engine/training_utils_v1_test.py b/tensorflow/python/keras/engine/training_utils_test.py
similarity index 83%
rename from tensorflow/python/keras/engine/training_utils_v1_test.py
rename to tensorflow/python/keras/engine/training_utils_test.py
index 64d44cb7955..8a3fd3926cf 100644
--- a/tensorflow/python/keras/engine/training_utils_v1_test.py
+++ b/tensorflow/python/keras/engine/training_utils_test.py
@@ -35,7 +35,7 @@ from tensorflow.python.keras import backend
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.engine import keras_tensor
-from tensorflow.python.keras.engine import training_utils_v1
+from tensorflow.python.keras.engine import training_utils
 from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging as logging
@@ -45,7 +45,7 @@ class ModelInputsTest(test.TestCase):
 
   def test_single_thing(self):
     a = np.ones(10)
-    model_inputs = training_utils_v1.ModelInputs(a)
+    model_inputs = training_utils.ModelInputs(a)
     self.assertEqual(['input_1'], model_inputs.get_input_names())
     vals = model_inputs.get_symbolic_inputs()
     self.assertTrue(tensor_util.is_tensor(vals))
@@ -59,7 +59,7 @@ class ModelInputsTest(test.TestCase):
       self.skipTest('Run in eager mode only.')
     with testing_utils.use_keras_tensors_scope(False):
       a = np.ones(10, dtype=np.int32)
-      model_inputs = training_utils_v1.ModelInputs(a)
+      model_inputs = training_utils.ModelInputs(a)
       self.assertEqual(['input_1'], model_inputs.get_input_names())
       val = model_inputs.get_symbolic_inputs()
       self.assertTrue(tf_utils.is_symbolic_tensor(val))
@@ -69,7 +69,7 @@ class ModelInputsTest(test.TestCase):
       self.assertEqual(dtypes.int32, vals[0].dtype)
     with testing_utils.use_keras_tensors_scope(True):
       a = np.ones(10, dtype=np.int32)
-      model_inputs = training_utils_v1.ModelInputs(a)
+      model_inputs = training_utils.ModelInputs(a)
       self.assertEqual(['input_1'], model_inputs.get_input_names())
       val = model_inputs.get_symbolic_inputs()
       self.assertIsInstance(val, keras_tensor.KerasTensor)
@@ -80,7 +80,7 @@ class ModelInputsTest(test.TestCase):
 
   def test_list(self):
     a = [np.ones(10), np.ones(20)]
-    model_inputs = training_utils_v1.ModelInputs(a)
+    model_inputs = training_utils.ModelInputs(a)
     self.assertEqual(['input_1', 'input_2'], model_inputs.get_input_names())
     vals = model_inputs.get_symbolic_inputs()
     self.assertTrue(tensor_util.is_tensor(vals[0]))
@@ -91,14 +91,14 @@ class ModelInputsTest(test.TestCase):
       self.skipTest('Run in eager mode only.')
     with testing_utils.use_keras_tensors_scope(False):
       a = [np.ones(10), np.ones(20)]
-      model_inputs = training_utils_v1.ModelInputs(a)
+      model_inputs = training_utils.ModelInputs(a)
       self.assertEqual(['input_1', 'input_2'], model_inputs.get_input_names())
       vals = model_inputs.get_symbolic_inputs()
       self.assertTrue(tf_utils.is_symbolic_tensor(vals[0]))
       self.assertTrue(tf_utils.is_symbolic_tensor(vals[1]))
     with testing_utils.use_keras_tensors_scope(True):
       a = [np.ones(10), np.ones(20)]
-      model_inputs = training_utils_v1.ModelInputs(a)
+      model_inputs = training_utils.ModelInputs(a)
       self.assertEqual(['input_1', 'input_2'], model_inputs.get_input_names())
       vals = model_inputs.get_symbolic_inputs()
       self.assertIsInstance(vals[0], keras_tensor.KerasTensor)
@@ -106,7 +106,7 @@ class ModelInputsTest(test.TestCase):
 
   def test_dict(self):
     a = {'b': np.ones(10), 'a': np.ones(20)}
-    model_inputs = training_utils_v1.ModelInputs(a)
+    model_inputs = training_utils.ModelInputs(a)
     self.assertEqual(['a', 'b'], model_inputs.get_input_names())
     vals = model_inputs.get_symbolic_inputs()
     self.assertTrue(tensor_util.is_tensor(vals['a']))
@@ -117,14 +117,14 @@ class ModelInputsTest(test.TestCase):
       self.skipTest('Run in eager mode only.')
     with testing_utils.use_keras_tensors_scope(False):
       a = {'b': np.ones(10), 'a': np.ones(20)}
-      model_inputs = training_utils_v1.ModelInputs(a)
+      model_inputs = training_utils.ModelInputs(a)
       self.assertEqual(['a', 'b'], model_inputs.get_input_names())
       vals = model_inputs.get_symbolic_inputs()
       self.assertTrue(tf_utils.is_symbolic_tensor(vals['a']))
       self.assertTrue(tf_utils.is_symbolic_tensor(vals['b']))
     with testing_utils.use_keras_tensors_scope(True):
       a = {'b': np.ones(10), 'a': np.ones(20)}
-      model_inputs = training_utils_v1.ModelInputs(a)
+      model_inputs = training_utils.ModelInputs(a)
       self.assertEqual(['a', 'b'], model_inputs.get_input_names())
       vals = model_inputs.get_symbolic_inputs()
       self.assertIsInstance(vals['a'], keras_tensor.KerasTensor)
@@ -182,12 +182,12 @@ class DatasetUtilsTest(test.TestCase, parameterized.TestCase):
 
     if not expect_shuffled:
       with test.mock.patch.object(logging, 'warning') as mock_log:
-        shuffled = training_utils_v1.verify_dataset_shuffled(dataset)
+        shuffled = training_utils.verify_dataset_shuffled(dataset)
         self.assertRegex(
             str(mock_log.call_args), 'input dataset `x` is not shuffled.')
         self.assertFalse(shuffled)
     else:
-      self.assertTrue(training_utils_v1.verify_dataset_shuffled(dataset))
+      self.assertTrue(training_utils.verify_dataset_shuffled(dataset))
 
 
 class StandardizeWeightsTest(keras_parameterized.TestCase):
@@ -195,22 +195,21 @@ class StandardizeWeightsTest(keras_parameterized.TestCase):
   def test_sample_weights(self):
     y = np.array([0, 1, 0, 0, 2])
     sample_weights = np.array([0.5, 1., 1., 0., 2.])
-    weights = training_utils_v1.standardize_weights(y, sample_weights)
+    weights = training_utils.standardize_weights(y, sample_weights)
     self.assertAllClose(weights, sample_weights)
 
   def test_class_weights(self):
     y = np.array([0, 1, 0, 0, 2])
     class_weights = {0: 0.5, 1: 1., 2: 1.5}
-    weights = training_utils_v1.standardize_weights(
-        y, class_weight=class_weights)
+    weights = training_utils.standardize_weights(y, class_weight=class_weights)
     self.assertAllClose(weights, np.array([0.5, 1., 0.5, 0.5, 1.5]))
 
   def test_sample_weights_and_class_weights(self):
     y = np.array([0, 1, 0, 0, 2])
     sample_weights = np.array([0.5, 1., 1., 0., 2.])
     class_weights = {0: 0.5, 1: 1., 2: 1.5}
-    weights = training_utils_v1.standardize_weights(y, sample_weights,
-                                                    class_weights)
+    weights = training_utils.standardize_weights(y, sample_weights,
+                                                 class_weights)
     expected = sample_weights * np.array([0.5, 1., 0.5, 0.5, 1.5])
     self.assertAllClose(weights, expected)
 
@@ -277,35 +276,32 @@ class AggregationTest(keras_parameterized.TestCase):
 
   def setUp(self):
     super(AggregationTest, self).setUp()
-    self._old_pool = training_utils_v1._COPY_POOL
-    self._old_threshold = (
-        training_utils_v1.SliceAggregator._BINARY_SIZE_THRESHOLD)
-    self._old_timeout = training_utils_v1.SliceAggregator._MAX_COPY_SECONDS
-    training_utils_v1._COPY_POOL = MonitoredPool(
-        training_utils_v1._COPY_THREADS)
+    self._old_pool = training_utils._COPY_POOL
+    self._old_threshold = training_utils.SliceAggregator._BINARY_SIZE_THRESHOLD
+    self._old_timeout = training_utils.SliceAggregator._MAX_COPY_SECONDS
+    training_utils._COPY_POOL = MonitoredPool(training_utils._COPY_THREADS)
 
   def tearDown(self):
     super(AggregationTest, self).tearDown()
-    training_utils_v1._COPY_POOL = self._old_pool
-    training_utils_v1.SliceAggregator._BINARY_SIZE_THRESHOLD = (
-        self._old_threshold)
-    training_utils_v1.SliceAggregator._MAX_COPY_SECONDS = self._old_timeout
+    training_utils._COPY_POOL = self._old_pool
+    training_utils.SliceAggregator._BINARY_SIZE_THRESHOLD = self._old_threshold
+    training_utils.SliceAggregator._MAX_COPY_SECONDS = self._old_timeout
 
   def _run_with_steps(self):
-    aggregator = training_utils_v1.OutputsAggregator(use_steps=True)
+    aggregator = training_utils.OutputsAggregator(use_steps=True)
     for i, batch in enumerate(np.array_split(_TEST_DATA, 4)):
       if i == 0:
         aggregator.create(batch)
       aggregator.aggregate(batch)
 
     assert len(aggregator.results) == 1
-    assert isinstance(aggregator.results[0], training_utils_v1.ConcatAggregator)
+    assert isinstance(aggregator.results[0], training_utils.ConcatAggregator)
 
     aggregator.finalize()
     return aggregator.results
 
   def _run_without_steps(self):
-    aggregator = training_utils_v1.OutputsAggregator(
+    aggregator = training_utils.OutputsAggregator(
         use_steps=False, num_samples=6)
 
     batch_start = 0
@@ -318,7 +314,7 @@ class AggregationTest(keras_parameterized.TestCase):
       batch_start = batch_end
 
     assert len(aggregator.results) == 1
-    assert isinstance(aggregator.results[0], training_utils_v1.SliceAggregator)
+    assert isinstance(aggregator.results[0], training_utils.SliceAggregator)
 
     aggregator.finalize()
     return aggregator.results
@@ -330,7 +326,7 @@ class AggregationTest(keras_parameterized.TestCase):
     self.assertAllEqual(self._run_without_steps(), _TEST_DATA)
 
   def test_nested_aggregation(self):
-    aggregator = training_utils_v1.OutputsAggregator(
+    aggregator = training_utils.OutputsAggregator(
         use_steps=False, num_samples=6)
 
     batches = np.array_split(_TEST_DATA, 4)
@@ -348,46 +344,46 @@ class AggregationTest(keras_parameterized.TestCase):
     self.assertAllEqual(aggregator.results, (_TEST_DATA, _TEST_DATA))
 
   def test_concat_single_batch(self):
-    aggregator = training_utils_v1.OutputsAggregator(use_steps=True)
+    aggregator = training_utils.OutputsAggregator(use_steps=True)
     data = _TEST_DATA.copy()
     aggregator.create(data)
     assert len(aggregator.results) == 1
-    assert isinstance(aggregator.results[0], training_utils_v1.ConcatAggregator)
+    assert isinstance(aggregator.results[0], training_utils.ConcatAggregator)
 
     aggregator.aggregate(data)
     aggregator.finalize()
     assert aggregator.results is data  # No copy.
 
   def test_slice_single_batch(self):
-    aggregator = training_utils_v1.OutputsAggregator(
+    aggregator = training_utils.OutputsAggregator(
         use_steps=False, num_samples=6)
     data = _TEST_DATA.copy()
     aggregator.create(data)
     assert len(aggregator.results) == 1
-    assert isinstance(aggregator.results[0], training_utils_v1.SliceAggregator)
+    assert isinstance(aggregator.results[0], training_utils.SliceAggregator)
 
     aggregator.aggregate(data, 0, 6)
     aggregator.finalize()
     assert aggregator.results is data  # No copy.
 
   def test_async_copy(self):
-    training_utils_v1.SliceAggregator._BINARY_SIZE_THRESHOLD = 15
+    training_utils.SliceAggregator._BINARY_SIZE_THRESHOLD = 15
     self.assertAllEqual(self._run_without_steps(), _TEST_DATA)
 
     # Two of the four batches will have 20 elements and two will have 10.
-    self.assertEqual(training_utils_v1._COPY_POOL._apply_counter, 2)
+    self.assertEqual(training_utils._COPY_POOL._apply_counter, 2)
 
   def test_async_copy_timeout(self):
-    training_utils_v1.SliceAggregator._BINARY_SIZE_THRESHOLD = 15
-    training_utils_v1.SliceAggregator._MAX_COPY_SECONDS = 0.1
-    training_utils_v1._COPY_POOL._func_wrapper = add_sleep
+    training_utils.SliceAggregator._BINARY_SIZE_THRESHOLD = 15
+    training_utils.SliceAggregator._MAX_COPY_SECONDS = 0.1
+    training_utils._COPY_POOL._func_wrapper = add_sleep
     with self.assertRaisesRegex(ValueError, 'Timed out waiting for copy'):
       self._run_without_steps()
 
   def test_async_copy_reraise(self):
-    training_utils_v1.SliceAggregator._BINARY_SIZE_THRESHOLD = 15
-    training_utils_v1.SliceAggregator._MAX_COPY_SECONDS = 1.
-    training_utils_v1._COPY_POOL._func_wrapper = cause_error
+    training_utils.SliceAggregator._BINARY_SIZE_THRESHOLD = 15
+    training_utils.SliceAggregator._MAX_COPY_SECONDS = 1.
+    training_utils._COPY_POOL._func_wrapper = cause_error
     with self.assertRaisesRegex(TypeError, 'NoneType'):
       self._run_without_steps()
 
diff --git a/tensorflow/python/keras/engine/training_utils_v1.py b/tensorflow/python/keras/engine/training_utils_v1.py
deleted file mode 100644
index 724882d6787..00000000000
--- a/tensorflow/python/keras/engine/training_utils_v1.py
+++ /dev/null
@@ -1,1849 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Training-related utilities."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import abc
-import atexit
-from collections import OrderedDict
-import functools
-import multiprocessing.pool
-import threading
-import time
-
-import numpy as np
-import six
-from six.moves import zip  # pylint: disable=redefined-builtin
-
-from tensorflow.core.framework import graph_pb2
-from tensorflow.python import tf2
-from tensorflow.python.data.experimental.ops import cardinality
-from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.eager import context
-from tensorflow.python.framework import composite_tensor_utils
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import smart_cond
-from tensorflow.python.framework import tensor_spec
-from tensorflow.python.framework import tensor_util
-from tensorflow.python.keras import backend as K
-from tensorflow.python.keras import callbacks as cbks
-from tensorflow.python.keras import losses
-from tensorflow.python.keras import metrics as metrics_module
-from tensorflow.python.keras.utils import data_utils
-from tensorflow.python.keras.utils import generic_utils
-from tensorflow.python.keras.utils import losses_utils
-from tensorflow.python.keras.utils import tf_inspect
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops.ragged import ragged_tensor
-from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.util import nest
-from tensorflow.python.util.compat import collections_abc
-
-
-@six.add_metaclass(abc.ABCMeta)
-class Aggregator(object):
-  """Abstract base class used to aggregate batch-level outputs of a loop.
-
-  Attributes:
-    use_steps: Whether the loop is using `step` or `batch_size`.
-    num_samples: Total number of samples: `batch_size * num_batches`.
-    steps: Total number of steps.
-    batch_size: Batch size. It is used for validation checks between inputs and
-      outputs.
-    results: What to return at the end of the aggregation loop.
-  """
-
-  def __init__(self, use_steps, num_samples=None, steps=None, batch_size=None):
-    self.use_steps = use_steps
-    self.num_samples = num_samples
-    self.steps = steps
-    self.batch_size = batch_size
-    self.results = []
-
-  @abc.abstractmethod
-  def create(self, batch_outs):
-    """Creates the initial results from the first batch outputs.
-
-    Arguments:
-      batch_outs: A list of batch-level outputs.
-    """
-    raise NotImplementedError('Must be implemented in subclasses.')
-
-  @abc.abstractmethod
-  def aggregate(self, batch_outs, batch_start=None, batch_end=None):
-    """Aggregates batch-level results into total results.
-
-    Arguments:
-      batch_outs: A list of batch-level outputs.
-      batch_start: The start index of this batch. Always `None` if `use_steps`
-        is `True`.
-      batch_end: The end index of this batch. Always `None` if `use_steps` is
-        `True`.
-    """
-    raise NotImplementedError('Must be implemented in subclasses.')
-
-  @abc.abstractmethod
-  def finalize(self):
-    """Prepares the total results to be returned."""
-    raise NotImplementedError('Must be implemented in subclasses.')
-
-
-class MetricsAggregator(Aggregator):
-  """Aggregator that calculates loss and metrics info.
-
-  Attributes:
-    use_steps: Whether the loop is using `step` or `batch_size`.
-    num_samples: Total number of samples: `batch_size*num_batches`.
-    steps: Total number of steps, i.e. number of times to iterate over a dataset
-      to cover all samples.
-  """
-
-  def __init__(self, use_steps, num_samples=None, steps=None):
-    super(MetricsAggregator, self).__init__(
-        use_steps=use_steps,
-        num_samples=num_samples,
-        steps=steps,
-        batch_size=None)
-
-  def create(self, batch_outs):
-    self.results = [0.] * len(batch_outs)
-
-  def aggregate(self, batch_outs, batch_start=None, batch_end=None):
-    # Loss.
-    if self.use_steps:
-      self.results[0] += batch_outs[0]
-    else:
-      self.results[0] += batch_outs[0] * (batch_end - batch_start)
-    # Metrics (always stateful, just grab current values).
-    self.results[1:] = batch_outs[1:]
-
-  def finalize(self):
-    if not self.results:
-      raise ValueError('Empty training data.')
-    self.results[0] /= (self.num_samples or self.steps)
-
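-# Illustrative usage sketch (the `_example_*` helper and batch values below are
-# hypothetical, not part of the original module): the loss entry is a running
-# sample-weighted sum that `finalize` divides by the total sample count, while
-# metric entries simply report their latest value.
-def _example_metrics_aggregation():
-  agg = MetricsAggregator(use_steps=False, num_samples=10)
-  agg.create([0.5, 0.8])  # [loss, accuracy]
-  agg.aggregate([0.5, 0.8], batch_start=0, batch_end=6)
-  agg.aggregate([0.3, 0.9], batch_start=6, batch_end=10)
-  agg.finalize()
-  assert abs(agg.results[0] - 0.42) < 1e-7  # (0.5 * 6 + 0.3 * 4) / 10
-  assert agg.results[1] == 0.9  # Metrics report their latest value.
-  return agg.results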
-
-class ConcatAggregator(Aggregator):
-  """Combine tensor-likes which cannot be merged on the fly.
-
-  This class expects to aggregate a single tensor-like rather than a nested
-  structure of tensor-likes.
-  """
-
-  def __init__(self, batch_size):
-    self.composite = None
-    super(ConcatAggregator, self).__init__(
-        use_steps=True, num_samples=None, steps=None, batch_size=batch_size)
-
-  def create(self, batch_element):
-    self.composite = composite_tensor_utils.is_composite_or_composite_value(
-        batch_element)
-
-  def aggregate(self, batch_element, batch_start=None, batch_end=None):
-
-    # TODO(psv): Add num_samples check here to detect when output batch
-    # #samples is < batch size and != input batch #samples.
-    if self.batch_size and self.batch_size < batch_element.shape[0]:
-      raise ValueError(
-          'Mismatch between expected batch size and model output batch size. '
-          'Output shape = {}, expected output shape = {}'.format(
-              batch_element.shape,
-              (self.batch_size,) + batch_element.shape[1:]))
-    self.results.append(batch_element)
-
-  def finalize(self):
-    # Special case of single batch inference which skips a copy.
-    if len(self.results) == 1:
-      self.results = self.results[0]
-
-    elif self.composite:
-      # TODO(taylorrobie): efficiently concatenate.
-      results = self.results[0]
-      for r in self.results[1:]:
-        results = composite_tensor_utils.append_composite_tensor(results, r)
-      self.results = results
-
-    else:
-      self.results = np.concatenate(self.results, axis=0)
-
-    if isinstance(self.results, ops.EagerTensor):
-      self.results = self.results._numpy()  # pylint: disable=protected-access
-
-
-_COPY_THREADS = 4
-_COPY_POOL = None
-
-
-def get_copy_pool():
-  """Shared threadpool for copying arrays.
-
-  Pool instantiation takes ~ 2ms, so a singleton pool is used rather than
-  creating a pool per SliceAggregator.
-
-  Returns:
-    The global copy threadpool.
-  """
-  global _COPY_POOL
-  if _COPY_POOL is None:
-    _COPY_POOL = multiprocessing.pool.ThreadPool(_COPY_THREADS)
-    atexit.register(_COPY_POOL.close)
-  return _COPY_POOL
-
-
-class SliceAggregator(Aggregator):
-  """Combine arrays where the final size is known.
-
-  This class expects to aggregate a single tensor-like rather than a nested
-  structure of tensor-likes.
-
-  NumPy copies are an operation that threads handle quite well because all of
-  the heavy lifting is in C and does not need the GIL. Moreover, we can perform
-  lock-free writes to the same buffer in multiple threads because the nature of
-  result aggregation guarantees that either the indices are disjoint or the
-  aggregator will throw an exception in finalize. Finally, because aggregation
-  is performed on the slowest varying dimension, assignments for a given batch
-  will write to contiguous blocks of memory, further minimizing contention.
-
-  There is, however, some scheduling and context switching overhead which will
-  offset the gains from pipelining the slice assignment. Below a given threshold
-  it is faster to simply assign in the main thread rather than enqueue the
-  assignment in a side thread. The exact threshold will vary from system to
-  system, but the time is not very sensitive to the exact transition, so a
-  value of 2 ** 14 was chosen, which should be reasonable on most systems.
-  """
-
-  _BINARY_SIZE_THRESHOLD = 2 ** 14
-  _MAX_COPY_SECONDS = 300
-
-  def __init__(self, num_samples, batch_size):
-    self._async_copies = []
-    self._pool = get_copy_pool()
-    self._errors = []
-    super(SliceAggregator, self).__init__(
-        use_steps=False,
-        num_samples=num_samples,
-        steps=None,
-        batch_size=batch_size)
-
-  def create(self, batch_element):
-    # This step does not need to be pipelined because NumPy empty array
-    # initialization is effectively instantaneous.
-    shape = (self.num_samples,) + batch_element.shape[1:]
-    dtype = batch_element.dtype
-    if isinstance(batch_element, ops.EagerTensor):
-      dtype = dtype.as_numpy_dtype
-
-    self.results = np.empty(shape=shape, dtype=dtype)
-
-  def aggregate(self, batch_element, batch_start, batch_end):
-    # Fail early.
-    if self._errors:
-      six.reraise(type(self._errors[0]), self._errors[0])
-
-    # In the special case of single batch inference, no copy is needed.
-    if batch_end - batch_start == self.num_samples:
-      if self.num_samples != batch_element.shape[0]:
-        raise ValueError(
-            'Mismatch between expected batch size and model output batch size. '
-            'Output shape = {}, expected output shape = {}'.format(
-                batch_element.shape, self.results.shape))
-
-      self.results = batch_element
-      return
-
-    # This is an approximate threshold, so we don't need to consider the number
-    # of bytes per element.
-    num_elements = np.prod(batch_element.shape)
-    if num_elements < self._BINARY_SIZE_THRESHOLD:
-      self.results[batch_start:batch_end] = batch_element
-    else:
-      is_finished = threading.Event()
-      self._pool.apply_async(
-          self._slice_assign,
-          args=(batch_element, batch_start, batch_end, is_finished))
-      self._async_copies.append(is_finished)
-
-  def _slice_assign(self, batch_element, batch_start, batch_end, is_finished):
-    try:
-      self.results[batch_start:batch_end] = batch_element
-
-    except Exception as e:  # pylint: disable=broad-except
-      # `_slice_assign` should only be called in threads and exceptions raised
-      # in threads do not carry over to the main thread. So instead we perform
-      # a broad catch in the thread and then store the exception to be re-raised
-      # in the main thread.
-      self._errors.append(e)
-
-    finally:
-      is_finished.set()
-
-  def finalize(self):
-    start_time = time.time()
-    for is_finished in self._async_copies:
-      timeout = max([0., self._MAX_COPY_SECONDS - (time.time() - start_time)])
-      if not is_finished.wait(timeout):
-        raise ValueError('Timed out waiting for copy to complete.')
-
-    if self._errors:
-      six.reraise(self._errors[0].__class__, self._errors[0])
-
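-# Standalone sketch of the pattern `SliceAggregator` relies on (the helper name,
-# array sizes and thread count below are made up for illustration): disjoint
-# slice assignments into a preallocated NumPy buffer can safely run on worker
-# threads because the copies do not hold the GIL and never overlap.
-def _example_threaded_slice_assign():
-  results = np.empty((100, 8), dtype=np.float32)
-  batches = [(np.full((25, 8), i, dtype=np.float32), 25 * i, 25 * (i + 1))
-             for i in range(4)]
-  pool = multiprocessing.pool.ThreadPool(2)
-  try:
-    # Each copy writes to a disjoint, contiguous block of `results`.
-    pool.map(lambda b: results.__setitem__(slice(b[1], b[2]), b[0]), batches)
-  finally:
-    pool.close()
-    pool.join()
-  return results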
-
-class OutputsAggregator(Aggregator):
-  """Aggregator that concatenates outputs."""
-
-  _structure = None
-
-  def create(self, batch_outs):
-    # SparseTensorValue is a named tuple which nest will flatten, so we need
-    # to guard it to properly handle the structure.
-    self._structure = nest.get_traverse_shallow_structure(
-        lambda x: not composite_tensor_utils.is_composite_or_composite_value(x),
-        batch_outs)
-    batch_outs = nest.flatten_up_to(self._structure, batch_outs)
-
-    for batch_element in batch_outs:
-      if composite_tensor_utils.is_composite_or_composite_value(batch_element):
-        # If the output is not a ndarray, it will be either a composite tensor
-        # or a composite tensor's Value object. In either case, we can't
-        # allocate an array to hold the object - we'll handle it later.
-        self.results.append(ConcatAggregator(self.batch_size))
-      elif isinstance(batch_element, (np.ndarray, ops.EagerTensor)):
-        self.results.append(
-            (ConcatAggregator(self.batch_size) if self.use_steps else
-             SliceAggregator(self.num_samples, self.batch_size)))
-      else:
-        # This is not a ndarray, a CompositeTensor, or a CompositeTensorValue.
-        # Fail fast rather than trying to concatenate it.
-        raise RuntimeError('Attempted to aggregate unsupported object {}.'
-                           .format(batch_element))
-
-      self.results[-1].create(batch_element)
-
-  def aggregate(self, batch_outs, batch_start=None, batch_end=None):
-    batch_outs = nest.flatten_up_to(self._structure, batch_outs)
-    for batch_element, result in zip(batch_outs, self.results):
-      result.aggregate(batch_element, batch_start, batch_end)
-
-  def finalize(self):
-    for result in self.results:
-      result.finalize()
-    self.results = [i.results for i in self.results]
-    self.results = nest.pack_sequence_as(self._structure, self.results)
-
-
-def get_progbar(model, count_mode, include_metrics=True):
-  """Get Progbar."""
-  if include_metrics:
-    stateful_metric_names = getattr(model, 'metrics_names', None)
-    if stateful_metric_names:
-      stateful_metric_names = stateful_metric_names[1:]  # Exclude `loss`
-  else:
-    stateful_metric_names = None
-  return cbks.ProgbarLogger(count_mode, stateful_metrics=stateful_metric_names)
-
-
-def check_num_samples(ins, batch_size=None, steps=None, steps_name='steps'):
-  """Determine the number of samples provided for training and evaluation.
-
-  The number of samples is not defined when running with `steps`,
-  in which case the number of samples is set to `None`.
-
-  Arguments:
-      ins: List of tensors to be fed to the Keras function.
-      batch_size: Integer batch size or `None` if not defined.
-      steps: Total number of steps (batches of samples) before declaring
-        `_predict_loop` finished. Ignored with the default value of `None`.
-      steps_name: The public API's parameter name for `steps`.
-
-  Raises:
-      ValueError: when `steps` is `None` and the attribute `ins.shape`
-      does not exist. Also raises ValueError when `steps` is not `None`
-      and `batch_size` is not `None` because they are mutually
-      exclusive.
-
-  Returns:
-      When steps is `None`, returns the number of samples to be
-      processed based on the size of the first dimension of the
-      first input numpy array. When steps is not `None` and
-      `batch_size` is `None`, returns `None`.
-  """
-  if steps is not None and batch_size is not None:
-    raise ValueError('If ' + steps_name +
-                     ' is set, the `batch_size` must be None.')
-  if check_steps_argument(ins, steps, steps_name):
-    return None
-
-  if hasattr(ins[0], 'shape'):
-    return int(ins[0].shape[0])
-  return None  # Edge case where ins == [static_learning_phase]
-
-
-def standardize_single_array(x, expected_shape=None):
-  """Expand data of shape (x,) to (x, 1), unless len(expected_shape)==1."""
-  if x is None:
-    return None
-
-  if composite_tensor_utils.is_composite_or_composite_value(x):
-    return x
-
-  if isinstance(x, int):
-    raise ValueError(
-        'Expected an array data type but received an integer: {}'.format(x))
-
-  if (x.shape is not None and len(x.shape) == 1 and
-      (expected_shape is None or len(expected_shape) != 1)):
-    if tensor_util.is_tensor(x):
-      x = array_ops.expand_dims(x, axis=1)
-    else:
-      x = np.expand_dims(x, 1)
-  return x
-
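-# Illustrative usage (hypothetical `_example_*` helper; shapes are made up):
-# rank-1 data is expanded to a column vector unless the expected shape is
-# itself rank 1.
-def _example_standardize_single_array():
-  assert standardize_single_array(np.arange(4)).shape == (4, 1)
-  assert standardize_single_array(
-      np.arange(4), expected_shape=(None,)).shape == (4,)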
-
-def standardize_input_data(data,
-                           names,
-                           shapes=None,
-                           check_batch_axis=True,
-                           exception_prefix=''):
-  """Normalizes inputs and targets provided by users.
-
-  Users may pass data as a list of arrays, dictionary of arrays,
-  or as a single array. We normalize this to an ordered list of
-  arrays (same order as `names`), while checking that the provided
-  arrays have shapes that match the network's expectations.
-
-  Arguments:
-      data: User-provided input data (polymorphic).
-      names: List of expected array names.
-      shapes: Optional list of expected array shapes.
-      check_batch_axis: Boolean; whether to check that the batch axis of the
-        arrays matches the expected value found in `shapes`.
-      exception_prefix: String prefix used for exception formatting.
-
-  Returns:
-      List of standardized input arrays (one array per model input).
-
-  Raises:
-      ValueError: in case of improperly formatted user-provided data.
-  """
-  try:
-    data_len = len(data)
-  except TypeError:
-    # For instance if data is `None` or a symbolic Tensor.
-    data_len = None
-
-  if not names:
-    if data_len and not isinstance(data, dict):
-      raise ValueError(
-          'Error when checking model ' + exception_prefix + ': '
-          'expected no data, but got:', data)
-    return []
-  if data is None:
-    return [None for _ in range(len(names))]
-
-  if isinstance(data, dict):
-    try:
-      data = [
-          data[x].values
-          if data[x].__class__.__name__ == 'DataFrame' else data[x]
-          for x in names
-      ]
-    except KeyError as e:
-      raise ValueError('No data provided for "' + e.args[0] + '". Need data '
-                       'for each key in: ' + str(names))
-  elif isinstance(data, (list, tuple)):
-    if isinstance(data[0], (list, tuple)):
-      data = [np.asarray(d) for d in data]
-    elif len(names) == 1 and isinstance(data[0], (float, int)):
-      data = [np.asarray(data)]
-    else:
-      data = [
-          x.values if x.__class__.__name__ == 'DataFrame' else x for x in data
-      ]
-  else:
-    data = data.values if data.__class__.__name__ == 'DataFrame' else data
-    data = [data]
-
-  if shapes is not None:
-    data = [
-        standardize_single_array(x, shape) for (x, shape) in zip(data, shapes)
-    ]
-  else:
-    data = [standardize_single_array(x) for x in data]
-
-  if len(data) != len(names):
-    if data and hasattr(data[0], 'shape'):
-      raise ValueError('Error when checking model ' + exception_prefix +
-                       ': the list of Numpy arrays that you are passing to '
-                       'your model is not the size the model expected. '
-                       'Expected to see ' + str(len(names)) + ' array(s), ' +
-                       'for inputs ' + str(names) + ' but instead got the '
-                       'following list of ' + str(len(data)) + ' arrays: ' +
-                       str(data)[:200] + '...')
-    elif len(names) > 1:
-      raise ValueError('Error when checking model ' + exception_prefix +
-                       ': you are passing a list as input to your model, '
-                       'but the model expects a list of ' + str(len(names)) +
-                       ' Numpy arrays instead. The list you passed was: ' +
-                       str(data)[:200])
-    elif len(data) == 1 and not hasattr(data[0], 'shape'):
-      raise TypeError('Error when checking model ' + exception_prefix +
-                      ': data should be a Numpy array, or list/dict of '
-                      'Numpy arrays. Found: ' + str(data)[:200] + '...')
-    elif len(names) == 1:
-      data = [np.asarray(data)]
-
-  # Check shapes compatibility.
-  if shapes:
-    for i in range(len(names)):
-      if shapes[i] is not None:
-        if tensor_util.is_tensor(data[i]):
-          tensorshape = data[i].shape
-          if not tensorshape:
-            continue
-          data_shape = tuple(tensorshape.as_list())
-        elif composite_tensor_utils.is_composite_or_composite_value(data[i]):
-          tensorshape = composite_tensor_utils.get_shape(data[i])
-          data_shape = tuple(tensorshape.as_list())
-        else:
-          data_shape = data[i].shape
-
-        shape = shapes[i]
-        if len(data_shape) != len(shape):
-          raise ValueError('Error when checking ' + exception_prefix +
-                           ': expected ' + names[i] + ' to have ' +
-                           str(len(shape)) + ' dimensions, but got array '
-                           'with shape ' + str(data_shape))
-        if not check_batch_axis:
-          data_shape = data_shape[1:]
-          shape = shape[1:]
-        for dim, ref_dim in zip(data_shape, shape):
-          if ref_dim != dim and ref_dim is not None and dim is not None:
-            raise ValueError('Error when checking ' + exception_prefix +
-                             ': expected ' + names[i] + ' to have shape ' +
-                             str(shape) + ' but got array with shape ' +
-                             str(data_shape))
-  return data
-
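-# Illustrative usage (hypothetical `_example_*` helper; names and arrays are
-# made up): dict, list and single-array inputs are all normalized to a list
-# ordered like `names`, with shapes validated against `shapes`.
-def _example_standardize_input_data():
-  data = {'a': np.zeros((2, 3)), 'b': np.ones((2, 5))}
-  outs = standardize_input_data(
-      data, ['a', 'b'], shapes=[(None, 3), (None, 5)])
-  assert [o.shape for o in outs] == [(2, 3), (2, 5)]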
-
-def standardize_sample_or_class_weights(x_weight, output_names, weight_type):
-  """Maps `sample_weight` or `class_weight` to model outputs.
-
-  Arguments:
-      x_weight: User-provided `sample_weight` or `class_weight` argument.
-      output_names: List of output names (strings) in the model.
-      weight_type: A string used purely for exception printing.
-
-  Returns:
-      A list of `sample_weight` or `class_weight` with exactly one element
-          per model output.
-
-  Raises:
-      ValueError: In case of invalid user-provided argument.
-  """
-  if x_weight is None or (isinstance(x_weight, (list, tuple)) and
-                          len(x_weight) == 0):  # pylint: disable=g-explicit-length-test
-    return [None for _ in output_names]
-  if len(output_names) == 1:
-    if isinstance(x_weight, (list, tuple)) and len(x_weight) == 1:
-      return x_weight
-    if isinstance(x_weight, dict) and output_names[0] in x_weight:
-      return [x_weight[output_names[0]]]
-    else:
-      return [x_weight]
-  if isinstance(x_weight, (list, tuple)):
-    if len(x_weight) != len(output_names):
-      raise ValueError('Provided `' + weight_type + '` was a list of ' +
-                       str(len(x_weight)) + ' elements, but the model has ' +
-                       str(len(output_names)) + ' outputs. '
-                       'You should provide one `' + weight_type + '` '
-                       'array per model output.')
-    return x_weight
-  if isinstance(x_weight, collections_abc.Mapping):
-    generic_utils.check_for_unexpected_keys(weight_type, x_weight, output_names)
-    x_weights = []
-    for name in output_names:
-      x_weights.append(x_weight.get(name))
-    return x_weights
-  else:
-    raise TypeError('The model has multiple outputs, so `' + weight_type + '` '
-                    'should be either a list or a dict. '
-                    'Provided `' + weight_type + '` type not understood: ' +
-                    str(x_weight))
-
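-# Illustrative usage (hypothetical `_example_*` helper; names are made up):
-# weights given as a dict are mapped onto the model's output order, with
-# `None` for outputs that have no entry.
-def _example_standardize_sample_or_class_weights():
-  out = standardize_sample_or_class_weights(
-      {'a': np.ones(4)}, ['a', 'b'], 'sample_weight')
-  assert out[1] is None and list(out[0]) == [1.0, 1.0, 1.0, 1.0]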
-
-def standardize_class_weights(class_weight, output_names):
-  return standardize_sample_or_class_weights(class_weight, output_names,
-                                             'class_weight')
-
-
-def standardize_sample_weights(sample_weight, output_names):
-  return standardize_sample_or_class_weights(sample_weight, output_names,
-                                             'sample_weight')
-
-
-def check_array_lengths(inputs, targets, weights=None):
-  """Does user input validation for numpy arrays.
-
-  Arguments:
-      inputs: list of Numpy arrays of inputs.
-      targets: list of Numpy arrays of targets.
-      weights: list of Numpy arrays of sample weights.
-
-  Raises:
-      ValueError: in case of incorrectly formatted data.
-  """
-
-  def is_tensor_or_composite_tensor(x):
-    return tensor_util.is_tensor(
-        x) or composite_tensor_utils.is_composite_or_composite_value(x)
-
-  def set_of_lengths(x):
-    # Returns a set with the variation between
-    # different shapes, with None => 0
-    if x is None:
-      return {}
-    else:
-      return set([
-          y.shape[0]
-          for y in x
-          if y is not None and not is_tensor_or_composite_tensor(y)
-      ])
-
-  set_x = set_of_lengths(inputs)
-  set_y = set_of_lengths(targets)
-  set_w = set_of_lengths(weights)
-  if len(set_x) > 1:
-    raise ValueError('All input arrays (x) should have '
-                     'the same number of samples. Got array shapes: ' +
-                     str([x.shape for x in inputs]))
-  if len(set_y) > 1:
-    raise ValueError('All target arrays (y) should have '
-                     'the same number of samples. Got array shapes: ' +
-                     str([y.shape for y in targets]))
-  if set_x and set_y and list(set_x)[0] != list(set_y)[0]:
-    raise ValueError('Input arrays should have '
-                     'the same number of samples as target arrays. '
-                     'Found ' + str(list(set_x)[0]) + ' input samples '
-                     'and ' + str(list(set_y)[0]) + ' target samples.')
-  if len(set_w) > 1:
-    raise ValueError('All sample_weight arrays should have '
-                     'the same number of samples. Got array shapes: ' +
-                     str([w.shape for w in weights]))
-  if set_y and set_w and list(set_y)[0] != list(set_w)[0]:
-    raise ValueError('Sample_weight arrays should have '
-                     'the same number of samples as target arrays. Got ' +
-                     str(list(set_y)[0]) + ' target samples and ' +
-                     str(list(set_w)[0]) + ' sample_weight samples.')
-
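-# Illustrative usage (hypothetical `_example_*` helper; arrays are made up):
-# matching sample counts pass silently, mismatched input/target counts raise.
-def _example_check_array_lengths():
-  check_array_lengths([np.zeros((8, 2))], [np.zeros((8, 1))])  # Passes.
-  try:
-    check_array_lengths([np.zeros((8, 2))], [np.zeros((4, 1))])
-    raise AssertionError('Expected a ValueError for mismatched counts.')
-  except ValueError:
-    pass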
-
-def check_loss_and_target_compatibility(targets, loss_fns, output_shapes):
-  """Does validation on the compatibility of targets and loss functions.
-
-  This helps prevent users from using loss functions incorrectly. This check
-  is purely for UX purposes.
-
-  Arguments:
-      targets: list of Numpy arrays of targets.
-      loss_fns: list of loss functions.
-      output_shapes: list of shapes of model outputs.
-
-  Raises:
-      ValueError: if a loss function or target array
-          is incompatible with an output.
-  """
-  key_loss_fns = {
-      losses.mean_squared_error, losses.binary_crossentropy,
-      losses.categorical_crossentropy
-  }
-  key_loss_classes = (losses.MeanSquaredError, losses.BinaryCrossentropy,
-                      losses.CategoricalCrossentropy)
-  for y, loss, shape in zip(targets, loss_fns, output_shapes):
-    if y is None or loss is None or tensor_util.is_tensor(y):
-      continue
-    if losses.is_categorical_crossentropy(loss):
-      if y.shape[-1] == 1:
-        raise ValueError('You are passing a target array of shape ' +
-                         str(y.shape) +
-                         ' while using as loss `categorical_crossentropy`. '
-                         '`categorical_crossentropy` expects '
-                         'targets to be binary matrices (1s and 0s) '
-                         'of shape (samples, classes). '
-                         'If your targets are integer classes, '
-                         'you can convert them to the expected format via:\n'
-                         '```\n'
-                         'from keras.utils import to_categorical\n'
-                         'y_binary = to_categorical(y_int)\n'
-                         '```\n'
-                         '\n'
-                         'Alternatively, you can use the loss function '
-                         '`sparse_categorical_crossentropy` instead, '
-                         'which does expect integer targets.')
-
-    is_loss_wrapper = isinstance(loss, losses.LossFunctionWrapper)
-    if (isinstance(loss, key_loss_classes) or (is_loss_wrapper and
-                                               (loss.fn in key_loss_fns))):
-      for target_dim, out_dim in zip(y.shape[1:], shape[1:]):
-        if out_dim is not None and target_dim != out_dim:
-          loss_name = loss.name
-          if loss_name is None:
-            loss_type = loss.fn if is_loss_wrapper else type(loss)
-            loss_name = loss_type.__name__
-          raise ValueError('A target array with shape ' + str(y.shape) +
-                           ' was passed for an output of shape ' + str(shape) +
-                           ' while using as loss `' + loss_name + '`. '
-                           'This loss expects targets to have the same shape '
-                           'as the output.')
-
-
-def collect_per_output_metric_info(metrics,
-                                   output_names,
-                                   output_shapes,
-                                   loss_fns,
-                                   is_weighted=False):
-  """Maps metric names and functions to model outputs.
-
-  Arguments:
-      metrics: a list or a list of lists or a dict of metric functions.
-      output_names: a list of the names (strings) of model outputs.
-      output_shapes: a list of the shapes of model outputs.
-      loss_fns: a list of the loss functions corresponding to the model outputs.
-      is_weighted: Boolean indicating whether the given metrics are weighted.
-
-  Returns:
-      A list (one entry per model output) of dicts.
-      For instance, if the model has 2 outputs, and for the first output
-      we want to compute "binary_accuracy" and "binary_crossentropy",
-      and just "binary_accuracy" for the second output,
-      the list would look like: `[{
-          'acc': binary_accuracy(),
-          'ce': binary_crossentropy(),
-        }, {
-          'acc': binary_accuracy(),
-        }]`
-
-  Raises:
-      TypeError: if an incorrect type is passed for the `metrics` argument.
-  """
-  if not metrics:
-    return [{} for _ in output_names]
-
-  if isinstance(metrics, list):
-    any_sub_list = any(isinstance(m, list) for m in metrics)
-    if any_sub_list:
-      if len(metrics) != len(output_names):
-        raise ValueError('When passing a list of lists as `metrics`, '
-                         'it should have one entry per model output. '
-                         'The model has ' + str(len(output_names)) +
-                         ' outputs, but you passed metrics=' + str(metrics))
-      # User has provided a list of len = len(outputs).
-      nested_metrics = [generic_utils.to_list(m) for m in metrics]
-    else:
-      # If it is a single list we then apply all metrics to all outputs.
-      if len(output_names) > 1:
-        nested_metrics = []
-        for _ in output_names:
-          nested_metrics.append(
-              [metrics_module.clone_metric(m) for m in metrics])
-      else:
-        nested_metrics = [metrics]
-  elif isinstance(metrics, collections_abc.Mapping):
-    generic_utils.check_for_unexpected_keys('metrics', metrics, output_names)
-    nested_metrics = []
-    for name in output_names:
-      output_metrics = generic_utils.to_list(metrics.get(name, []))
-      nested_metrics.append(output_metrics)
-  else:
-    raise TypeError('Type of `metrics` argument not understood. '
-                    'Expected a list or dictionary, found: ' + str(metrics))
-
-  per_output_metrics = []
-  for i, metrics in enumerate(nested_metrics):
-    metrics_dict = OrderedDict()
-    for metric in metrics:
-      metric_name = get_metric_name(metric, is_weighted)
-      metric_fn = get_metric_function(
-          metric, output_shape=output_shapes[i], loss_fn=loss_fns[i])
-
-      # If the metric function is not stateful, we create a stateful version.
-      if not isinstance(metric_fn, metrics_module.Metric):
-        metric_fn = metrics_module.MeanMetricWrapper(
-            metric_fn, name=metric_name)
-      metrics_dict[metric_name] = metric_fn
-    per_output_metrics.append(metrics_dict)
-
-  return per_output_metrics
-
-
-def batch_shuffle(index_array, batch_size):
-  """Shuffles an array in a batch-wise fashion.
-
-  Useful for shuffling HDF5 arrays
-  (where one cannot access arbitrary indices).
-
-  Arguments:
-      index_array: array of indices to be shuffled.
-      batch_size: integer.
-
-  Returns:
-      The `index_array` array, shuffled in a batch-wise fashion.
-  """
-  batch_count = int(len(index_array) / batch_size)
-  # to reshape we need to be cleanly divisible by batch size
-  # we stash extra items and reappend them after shuffling
-  last_batch = index_array[batch_count * batch_size:]
-  index_array = index_array[:batch_count * batch_size]
-  index_array = index_array.reshape((batch_count, batch_size))
-  np.random.shuffle(index_array)
-  index_array = index_array.flatten()
-  return np.append(index_array, last_batch)
-
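-# Illustrative usage (hypothetical `_example_*` helper; sizes are made up):
-# indices are shuffled in whole-batch blocks and the incomplete tail batch is
-# re-appended in its original order.
-def _example_batch_shuffle():
-  shuffled = batch_shuffle(np.arange(10), batch_size=4)
-  assert len(shuffled) == 10
-  assert list(shuffled[-2:]) == [8, 9]  # Tail keeps its original order.
-  return shuffled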
-
-def standardize_weights(y,
-                        sample_weight=None,
-                        class_weight=None,
-                        sample_weight_mode=None):
-  """Performs sample weight validation and standardization.
-
-  Everything gets normalized to a single sample-wise (or timestep-wise)
-  weight array. If both `sample_weight` and `class_weight` are provided,
-  the weights are multiplied.
-
-  Arguments:
-      y: Numpy array or Tensor of model targets to be weighted.
-      sample_weight: User-provided `sample_weight` argument.
-      class_weight: User-provided `class_weight` argument.
-      sample_weight_mode: One of `None` or `"temporal"`. `"temporal"` indicated
-        that we expect 2D weight data that will be applied to the last 2
-        dimensions of the targets (i.e. we are weighting timesteps, not
-        samples).
-
-  Returns:
-      A numpy array of target weights, one entry per sample to weight.
-
-  Raises:
-      ValueError: In case of invalid user-provided arguments.
-  """
-  # Iterator may return sample_weight as 1-tuple
-  if isinstance(sample_weight, tuple):
-    sample_weight = sample_weight[0]
-  if sample_weight_mode is not None and sample_weight_mode != 'samplewise':
-    if sample_weight_mode != 'temporal':
-      raise ValueError('`sample_weight_mode` '
-                       'should be None or "temporal". '
-                       'Found: ' + str(sample_weight_mode))
-    if len(y.shape) < 3:
-      raise ValueError('Found a sample_weight array for '
-                       'an input with shape ' + str(y.shape) + '. '
-                       'Timestep-wise sample weighting (use of '
-                       'sample_weight_mode="temporal") is restricted to '
-                       'outputs that are at least 3D, i.e. that have '
-                       'a time dimension.')
-    if sample_weight is not None and len(sample_weight.shape) != 2:
-      raise ValueError('Found a sample_weight array with shape ' +
-                       str(sample_weight.shape) + '. '
-                       'In order to use timestep-wise sample weighting, '
-                       'you should pass a 2D sample_weight array.')
-  else:
-    if sample_weight is not None and len(sample_weight.shape) != 1:
-      raise ValueError('Found a sample_weight array with shape {}. In order to '
-                       'use timestep-wise sample weights, you should specify '
-                       'sample_weight_mode="temporal" in compile(); found "{}" '
-                       'instead. If you just mean to use sample-wise weights, '
-                       'make sure your sample_weight array is 1D.'
-                       .format(sample_weight.shape, sample_weight_mode))
-
-  if sample_weight is not None:
-    if len(sample_weight.shape) > len(y.shape):
-      raise ValueError('Found a sample_weight with shape ' +
-                       str(sample_weight.shape) + '. '
-                       'Expected sample_weight with rank '
-                       'less than or equal to ' + str(len(y.shape)))
-
-    if (not tensor_util.is_tensor(sample_weight) and
-        y.shape[:sample_weight.ndim] != sample_weight.shape):
-      raise ValueError('Found a sample_weight array with shape ' +
-                       str(sample_weight.shape) + ' for an input with shape ' +
-                       str(y.shape) + '. '
-                       'sample_weight cannot be broadcast.')
-
-  # Class weights applied per-sample.
-  class_sample_weight = None
-  if isinstance(class_weight, dict):
-    if len(y.shape) > 2:
-      raise ValueError('`class_weight` not supported for '
-                       '3+ dimensional targets.')
-
-    if tensor_util.is_tensor(y):
-      # Few classes are expected, so densifying is reasonable.
-      keys = np.array(sorted(class_weight.keys()))
-      values = np.array([class_weight[i] for i in keys])
-      weight_vector = np.zeros(np.max(keys) + 1)
-      weight_vector[:] = np.nan
-      weight_vector[keys] = values
-
-      y_classes = smart_cond.smart_cond(
-          len(y.shape.as_list()) == 2 and K.shape(y)[1] > 1,
-          lambda: K.argmax(y, axis=1),
-          lambda: math_ops.cast(K.reshape(y, (-1,)), dtypes.int64))
-      class_sample_weight = array_ops.gather(weight_vector, y_classes)
-      gen_array_ops.check_numerics(
-          class_sample_weight,
-          'Invalid classes or class weights detected. NaN values indicate that '
-          'an appropriate class weight could not be determined.')
-      class_sample_weight = math_ops.cast(class_sample_weight, K.floatx())
-      if sample_weight is not None:
-        sample_weight = math_ops.cast(
-            ops.convert_to_tensor_v2_with_dispatch(sample_weight), K.floatx())
-    else:
-      y_classes = y
-      if len(y.shape) == 2:
-        if y.shape[1] > 1:
-          y_classes = np.argmax(y, axis=1)
-        elif y.shape[1] == 1:
-          y_classes = np.reshape(y, y.shape[0])
-
-      class_sample_weight = np.asarray(
-          [class_weight[cls] for cls in y_classes if cls in class_weight])
-
-      if len(class_sample_weight) != len(y_classes):
-        # subtract the sets to pick all missing classes
-        existing_classes = set(y_classes)
-        existing_class_weight = set(class_weight.keys())
-        raise ValueError(
-            '`class_weight` must contain all classes in the data.'
-            ' The classes %s exist in the data but not in '
-            '`class_weight`.' % (existing_classes - existing_class_weight))
-
-  if class_sample_weight is not None and sample_weight is not None:
-    # Multiply weights if both are provided.
-    return class_sample_weight * sample_weight
-  if sample_weight is not None:
-    return sample_weight
-  if class_sample_weight is not None:
-    return class_sample_weight
-  return None
-
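-# Illustrative usage (hypothetical `_example_*` helper; labels and weights are
-# made up): a `class_weight` dict is expanded into one weight per sample; when
-# both `class_weight` and `sample_weight` are given the two are multiplied.
-def _example_standardize_weights():
-  y = np.array([0, 1, 1, 0])
-  weights = standardize_weights(y, class_weight={0: 1.0, 1: 2.0})
-  assert list(weights) == [1.0, 2.0, 2.0, 1.0]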
-
-def has_symbolic_tensors(ls):
-  if context.executing_eagerly():
-    return False
-  return has_tensors(ls)
-
-
-def has_tensors(ls):
-  """Returns true if `ls` contains tensors."""
-  # Note: at some point in time ragged tensors didn't count as tensors, so this
-  # returned false for ragged tensors. Making this return true fails some tests
-  # which would then require a steps_per_epoch argument.
-  if isinstance(ls, (list, tuple)):
-    return any(
-        tensor_util.is_tensor(v) and
-        not isinstance(v, ragged_tensor.RaggedTensor) for v in ls)
-  if isinstance(ls, dict):
-    return any(
-        tensor_util.is_tensor(v) and
-        not isinstance(v, ragged_tensor.RaggedTensor)
-        for _, v in six.iteritems(ls))
-  return tensor_util.is_tensor(ls) and not isinstance(
-      ls, ragged_tensor.RaggedTensor)
-
-
-def get_metric_name(metric, weighted=False):
-  """Returns the name corresponding to the given metric input.
-
-  Arguments:
-    metric: Metric function name or reference.
-    weighted: Boolean indicating if the given metric is weighted.
-
-  Returns:
-      The metric name.
-  """
-  if tf2.enabled():
-    # We keep the string that the user has set in compile as the metric name.
-    if isinstance(metric, six.string_types):
-      return metric
-
-    metric = metrics_module.get(metric)
-    return metric.name if hasattr(metric, 'name') else metric.__name__
-  else:
-    metric_name_prefix = 'weighted_' if weighted else ''
-    if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
-      if metric in ('accuracy', 'acc'):
-        suffix = 'acc'
-      elif metric in ('crossentropy', 'ce'):
-        suffix = 'ce'
-    else:
-      metric_fn = metrics_module.get(metric)
-      # Get metric name as string
-      if hasattr(metric_fn, 'name'):
-        suffix = metric_fn.name
-      else:
-        suffix = metric_fn.__name__
-    metric_name = metric_name_prefix + suffix
-    return metric_name
-
-
-def get_metric_function(metric, output_shape=None, loss_fn=None):
-  """Returns the metric function corresponding to the given metric input.
-
-  Arguments:
-      metric: Metric function name or reference.
-      output_shape: The shape of the output that this metric will be calculated
-        for.
-      loss_fn: The loss function used.
-
-  Returns:
-      The metric function.
-  """
-  if metric not in ['accuracy', 'acc', 'crossentropy', 'ce']:
-    return metrics_module.get(metric)
-
-  is_sparse_categorical_crossentropy = (
-      isinstance(loss_fn, losses.SparseCategoricalCrossentropy) or
-      (isinstance(loss_fn, losses.LossFunctionWrapper) and
-       loss_fn.fn == losses.sparse_categorical_crossentropy))
-
-  is_binary_crossentropy = (
-      isinstance(loss_fn, losses.BinaryCrossentropy) or
-      (isinstance(loss_fn, losses.LossFunctionWrapper) and
-       loss_fn.fn == losses.binary_crossentropy))
-
-  if metric in ['accuracy', 'acc']:
-    if output_shape[-1] == 1 or is_binary_crossentropy:
-      return metrics_module.binary_accuracy
-    elif is_sparse_categorical_crossentropy:
-      return metrics_module.sparse_categorical_accuracy
-    # If the output_shape[-1] is not 1, then we know output is `categorical`.
-    # We assume it is sparse categorical only if loss is explicitly given
-    # as sparse categorical crossentropy loss.
-    return metrics_module.categorical_accuracy
-  else:
-    if output_shape[-1] == 1 or is_binary_crossentropy:
-      return metrics_module.binary_crossentropy
-    elif is_sparse_categorical_crossentropy:
-      return metrics_module.sparse_categorical_crossentropy
-    return metrics_module.categorical_crossentropy
-
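-# Illustrative usage (hypothetical `_example_*` helper; shapes and losses are
-# made up): the generic 'acc' string resolves to a concrete accuracy metric
-# based on the output shape and the configured loss.
-def _example_get_metric_function():
-  assert get_metric_function(
-      'acc', output_shape=(None, 1)) is metrics_module.binary_accuracy
-  assert get_metric_function(
-      'acc', output_shape=(None, 10),
-      loss_fn=losses.SparseCategoricalCrossentropy()
-  ) is metrics_module.sparse_categorical_accuracy
-  assert get_metric_function(
-      'acc', output_shape=(None, 10)) is metrics_module.categorical_accuracy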
-
-def call_metric_function(metric_fn,
-                         y_true,
-                         y_pred=None,
-                         weights=None,
-                         mask=None):
-  """Invokes metric function and returns the metric result tensor."""
-  if mask is not None:
-    mask = math_ops.cast(mask, y_pred.dtype)
-    if weights is None:
-      # Use mask as sample weight.
-      weights = mask
-    else:
-      # Update dimensions of weights to match with mask.
-      weights = math_ops.cast(weights, dtype=y_pred.dtype)
-      mask, _, weights = losses_utils.squeeze_or_expand_dimensions(
-          mask, sample_weight=weights)
-      weights *= mask
-
-  if y_pred is not None:
-    return metric_fn(y_true, y_pred, sample_weight=weights)
-  # `Mean` metric only takes a single value.
-  return metric_fn(y_true, sample_weight=weights)
-
-
-def get_loss_function(loss):
-  """Returns the loss corresponding to the loss input in `compile` API."""
-  if loss is None or isinstance(loss, losses.Loss):
-    return loss
-
-  if tf_inspect.isclass(loss) and issubclass(loss, losses.Loss):
-    # It is not safe to assume that the loss takes no constructor arguments.
-    raise ValueError(
-        'Received uninstantiated Loss class: {}\nPlease call loss classes '
-        'before passing them to Model.compile.'.format(loss))
-
-  # Deserialize loss configuration, if needed.
-  if isinstance(loss, collections_abc.Mapping):
-    loss = losses.get(loss)
-
-  # Custom callable class.
-  if callable(loss) and not hasattr(loss, '__name__'):
-    return loss
-
-  # Wrap loss function with signature `(y_true, y_pred, **kwargs)`
-  # in `LossFunctionWrapper` class.
-  loss_fn = losses.get(loss)
-
-  # For losses which are given as strings/functions in the compile API,
-  # we always set the loss reduction type to be `SUM_OVER_BATCH_SIZE`
-  # (both in distribution strategy context and otherwise).
-  return losses.LossFunctionWrapper(
-      loss_fn,
-      name=loss_fn.__name__,
-      reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE)
-
-
-def validate_dataset_input(x, y, sample_weight, validation_split=None):
-  """Validates user input arguments when a dataset iterator is passed.
-
-  Arguments:
-    x: Input data. A `tf.data` dataset or iterator.
-    y: Target data. It could be either Numpy array(s) or TensorFlow tensor(s).
-      Expected to be `None` when `x` is a dataset iterator.
-    sample_weight: An optional sample-weight array passed by the user to weight
-      the importance of each sample in `x`. Expected to be `None` when `x` is a
-      dataset iterator.
-    validation_split: Float between 0 and 1. Fraction of the training data to be
-      used as validation data. Expected to be `None` when `x` is a dataset
-      iterator.
-
-  Raises:
-    ValueError: if argument `y` or `sample_weight` or `validation_split` are
-        provided by user.
-  """
-  if y is not None:
-    raise ValueError('You passed a dataset or dataset iterator (%s) as '
-                     'input `x` to your model. In that case, you should '
-                     'not specify a target (`y`) argument, since the dataset '
-                     'or dataset iterator generates both input data and '
-                     'target data. '
-                     'Received: %s' % (x, y))
-  if sample_weight is not None:
-    raise ValueError('`sample_weight` argument is not supported when input '
-                     '`x` is a dataset or a dataset iterator. Instead, you '
-                     'can provide sample_weight as the third element of your '
-                     'dataset, i.e. (inputs, targets, sample_weight). '
-                     'Received: x=%s, sample_weight=%s' % (x, sample_weight))
-  if validation_split is not None and validation_split != 0.0:
-    raise ValueError(
-        '`validation_split` argument is not supported when '
-        'input `x` is a dataset or a dataset iterator. '
-        'Received: x=%s, validation_split=%f' % (x, validation_split))
-
-
-def validate_input_types(inp, orig_inp, allow_dict=True, field_name='inputs'):
-  """Helper function to validate either inputs or targets."""
-  if isinstance(inp, (list, tuple)):
-    if not all(isinstance(v, np.ndarray) or
-               tensor_util.is_tensor(v) for v in inp):
-      raise ValueError(
-          'Please provide as model inputs either a single array or a list of '
-          'arrays. You passed: {}={}'.format(field_name, str(orig_inp)))
-  elif isinstance(inp, dict):
-    if not allow_dict:
-      raise ValueError(
-          'You cannot pass a dictionary as model {}.'.format(field_name))
-  elif not isinstance(inp, np.ndarray) and not tensor_util.is_tensor(inp):
-    raise ValueError(
-        'Please provide as model inputs either a single array or a list of '
-        'arrays. You passed: {}={}'.format(field_name, orig_inp))
-
-
-def check_generator_arguments(y=None, sample_weight=None,
-                              validation_split=None):
-  """Validates arguments passed when using a generator."""
-  if y is not None:
-    raise ValueError('`y` argument is not supported when data is '
-                     'a generator or Sequence instance. Instead pass targets'
-                     ' as the second element of the generator.')
-  if sample_weight is not None:
-    raise ValueError('`sample_weight` argument is not supported when data is '
-                     'a generator or Sequence instance. Instead pass sample'
-                     ' weights as the third element of the generator.')
-  if validation_split:
-    raise ValueError('If your data is in the form of a Python generator, '
-                     'you cannot use `validation_split`.')
-
-
-def check_steps_argument(input_data, steps, steps_name):
-  """Validates `steps` argument based on input data's type.
-
-  The cases when `steps` value must be provided are when
-    1. input data passed is an iterator.
-    2. model was built on top of symbolic tensors, input data is not
-       required and is `None`.
-    3. input data passed is a symbolic tensor.
-
-  Arguments:
-      input_data: Input data. Can be Numpy array(s) or TensorFlow tensor(s) or
-        tf.data.Dataset iterator or `None`.
-      steps: Integer or `None`. Total number of steps (batches of samples) to
-        execute.
-      steps_name: The public API's parameter name for `steps`.
-
-  Returns:
-    boolean, True if `steps` argument is required, else False.
-
-  Raises:
-      ValueError: if `steps` argument is required for given input data type
-        but not provided.
-  """
-  is_x_iterator = isinstance(
-      input_data, (iterator_ops.Iterator, iterator_ops.OwnedIterator))
-  if (input_data is None or is_x_iterator or has_symbolic_tensors(input_data) or
-      (isinstance(input_data, list) and not input_data)):
-    if steps is None:
-      input_type_str = 'a Dataset iterator' if is_x_iterator else 'data tensors'
-      raise ValueError('When using {input_type} as input to a model, you should'
-                       ' specify the `{steps_name}` argument.'.format(
-                           input_type=input_type_str, steps_name=steps_name))
-    return True
-
-  if isinstance(input_data, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
-    return True
-
-  if steps is not None:
-    list_types = (np.ndarray, list, tuple)
-    if (isinstance(input_data, list_types) or
-        (isinstance(input_data, dict) and
-         any(isinstance(v, list_types) for v in input_data.values()))):
-      logging.warning('When passing input data as arrays, do not specify '
-                      '`steps_per_epoch`/`steps` argument. '
-                      'Please use `batch_size` instead.')
-  return False
-
-
-def cast_single_tensor(x, dtype=None):
-  if isinstance(x, np.ndarray):
-    x = ops.convert_to_tensor_v2_with_dispatch(x)
-  dtype = dtype or K.floatx()
-  if x.dtype.is_floating:
-    return math_ops.cast(x, dtype=dtype)
-  return x
-
-
-def cast_if_floating_dtype_and_mismatch(targets, outputs):
-  """Returns target data tensors using correct datatype.
-
-  Checks that each target and output pair are the same datatype. If not, casts
-  the target to the output's datatype.
-
-  Args:
-    targets: tensor or list of targets.
-    outputs: tensor or list of outputs.
-
-  Returns:
-    Targets in appropriate datatype.
-  """
-  if tensor_util.is_tensor(targets):
-    # There is one target, so output[0] should be the only output.
-    return cast_single_tensor(targets, dtype=outputs[0].dtype)
-  new_targets = []
-  for target, out in zip(targets, outputs):
-    if isinstance(target, np.ndarray):
-      target = ops.convert_to_tensor_v2_with_dispatch(target)
-    if target.dtype != out.dtype:
-      new_targets.append(cast_single_tensor(target, dtype=out.dtype))
-    else:
-      new_targets.append(target)
-  return new_targets
-
-
-def cast_if_floating_dtype(x, dtype=None):
-  """Casts the given data tensors to the default floating point type.
-
-  Casts only if the input is already a floating point type.
-  Args:
-    x: tensor or list/tuple of tensors.
-    dtype: The dtype to which Tensors should be cast.
-
-  Returns:
-    Converted input.
-  """
-  return nest.map_structure(functools.partial(cast_single_tensor, dtype=dtype),
-                            x)
-
-
-def cast_to_model_input_dtypes(x, model):
-  """Casts the given data tensors to the dtypes of the model inputs.
-
-  Args:
-    x: tensor or list/tuple of tensors.
-    model: The model.
-
-  Returns:
-    Converted input. Each tensor is casted to the corresponding input in
-    `model.inputs`.
-  """
-  input_dtypes = nest.map_structure(lambda t: t.dtype, model.inputs)
-  return nest.map_structure(math_ops.cast, x, input_dtypes)
-
-
-def prepare_sample_weight_modes(training_endpoints, sample_weight_mode):
-  """Prepares sample weight modes for the model.
-
-  Args:
-    training_endpoints: List of model _TrainingEndpoints.
-    sample_weight_mode: sample weight mode user input passed from compile API.
-
-  Raises:
-    ValueError: In case of invalid `sample_weight_mode` input.
-  """
-
-  if isinstance(sample_weight_mode, collections_abc.Mapping):
-    generic_utils.check_for_unexpected_keys(
-        'sample_weight_mode', sample_weight_mode,
-        [e.output_name for e in training_endpoints])
-
-    for end_point in training_endpoints:
-      if not end_point.should_skip_target_weights():
-        if end_point.output_name not in sample_weight_mode:
-          raise ValueError('Output ' + end_point.output_name +
-                           ' missing from `_sample_weight_modes` dictionary')
-        else:
-          end_point.sample_weight_mode = sample_weight_mode.get(
-              end_point.output_name)
-  elif isinstance(sample_weight_mode, (list, tuple)):
-    if len(sample_weight_mode) != len(training_endpoints):
-      raise ValueError('When passing a list as sample_weight_mode, '
-                       'it should have one entry per model output. '
-                       'The model has ' + str(len(training_endpoints)) +
-                       ' outputs, but you passed ' +
-                       str(len(sample_weight_mode)) + ' sample_weight_modes.')
-    for mode, endpoint in zip(sample_weight_mode, training_endpoints):
-      if not endpoint.should_skip_target_weights():
-        endpoint.sample_weight_mode = mode
-  else:
-    for endpoint in training_endpoints:
-      if not endpoint.should_skip_target_weights():
-        endpoint.sample_weight_mode = sample_weight_mode
-
-
-def prepare_loss_functions(loss, output_names):
-  """Converts loss to a list of loss functions.
-
-  Arguments:
-      loss: String (name of objective function), objective function or
-        `tf.losses.Loss` instance. See `tf.losses`. If the model has multiple
-        outputs, you can use a different loss on each output by passing a
-        dictionary or a list of losses. The loss value that will be minimized by
-        the model will then be the sum of all individual losses.
-      output_names: List of model output names.
-
-  Returns:
-      A list of loss objective functions.
-
-  Raises:
-      ValueError: If loss is a dict with keys not in model output names,
-          or if loss is a list with len not equal to model outputs.
-  """
-  if isinstance(loss, collections_abc.Mapping):
-    generic_utils.check_for_unexpected_keys('loss', loss, output_names)
-    loss_functions = []
-    for name in output_names:
-      if name not in loss:
-        logging.warning(
-            'Output {0} missing from loss dictionary. We assume '
-            'this was done on purpose. The fit and evaluate APIs will not be '
-            'expecting any data to be passed to {0}.'.format(name))
-      loss_functions.append(get_loss_function(loss.get(name, None)))
-  elif isinstance(loss, six.string_types):
-    loss_functions = [get_loss_function(loss) for _ in output_names]
-  elif isinstance(loss, collections_abc.Sequence):
-    if len(loss) != len(output_names):
-      raise ValueError('When passing a list as loss, it should have one entry '
-                       'per model output. The model has {} outputs, but you '
-                       'passed loss={}'.format(len(output_names), loss))
-    loss_functions = nest.map_structure(get_loss_function, loss)
-  else:
-    loss_functions = [get_loss_function(loss) for _ in range(len(output_names))]
-
-  return loss_functions
-
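-# Illustrative usage (hypothetical `_example_*` helper; output names are made
-# up): a single loss string is replicated across outputs and wrapped so that
-# it reduces with SUM_OVER_BATCH_SIZE.
-def _example_prepare_loss_functions():
-  fns = prepare_loss_functions('mse', ['out_a', 'out_b'])
-  assert len(fns) == 2
-  assert all(isinstance(f, losses.LossFunctionWrapper) for f in fns)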
-
-def prepare_loss_weights(training_endpoints, loss_weights=None):
-  """Converts loss weights to a list of loss weights.
-
-  The result loss weights will be populated on the training endpoint.
-
-  Arguments:
-      training_endpoints: List of model training endpoints.
-      loss_weights: Optional list or dictionary specifying scalar coefficients
-        (Python floats) to weight the loss contributions of different model
-        outputs. The loss value that will be minimized by the model will then be
-        the *weighted sum* of all individual losses, weighted by the
-        `loss_weights` coefficients. If a list, it is expected to have a 1:1
-        mapping to the model's outputs. If a dict, it is expected to map
-        output names (strings) to scalar coefficients.
-
-  Raises:
-      ValueError: If loss weight is a dict with key not in model output names,
-          or if loss is a list with len not equal to model outputs.
-  """
-  if loss_weights is None:
-    for e in training_endpoints:
-      e.loss_weight = 1.
-  elif isinstance(loss_weights, collections_abc.Mapping):
-    generic_utils.check_for_unexpected_keys(
-        'loss_weights', loss_weights,
-        [e.output_name for e in training_endpoints])
-    for e in training_endpoints:
-      e.loss_weight = loss_weights.get(e.output_name, 1.)
-  elif isinstance(loss_weights, list):
-    if len(loss_weights) != len(training_endpoints):
-      raise ValueError('When passing a list as loss_weights, '
-                       'it should have one entry per model output. '
-                       'The model has ' + str(len(training_endpoints)) +
-                       ' outputs, but you passed loss_weights=' +
-                       str(loss_weights))
-    for w, e in zip(loss_weights, training_endpoints):
-      e.loss_weight = w
-  else:
-    raise TypeError('Could not interpret loss_weights argument: ' +
-                    str(loss_weights) + ' - expected a list or a dict.')
-
-
-# TODO(rohanj): This is a hack to get around not depending on feature_column and
-# create a cyclical dependency. Figure out a cleaner solution
-def is_feature_layer(layer):
-  """Returns whether `layer` is a FeatureLayer or not."""
-  return getattr(layer, '_is_feature_layer', False)
-
-
-def is_eager_dataset_or_iterator(data):
-  return context.executing_eagerly() and isinstance(
-      data, (dataset_ops.DatasetV1, dataset_ops.DatasetV2,
-             iterator_ops.OwnedIterator))
-
-
-# pylint: disable=protected-access
-def get_dataset_graph_def(dataset):
-  if context.executing_eagerly():
-    graph_def_str = dataset._as_serialized_graph().numpy()
-  else:
-    graph_def_str = K.get_value(dataset._as_serialized_graph())
-  return graph_pb2.GraphDef().FromString(graph_def_str)
-
-
-def verify_dataset_shuffled(x):
-  """Verifies that the dataset is shuffled.
-
-  Args:
-    x: Dataset passed as an input to the model.
-
-  Returns:
-    boolean, whether the input dataset is shuffled or not.
-  """
-  assert isinstance(x, dataset_ops.DatasetV2)
-  graph_def = get_dataset_graph_def(x)
-  for node in graph_def.node:
-    if node.op.startswith('ShuffleDataset'):
-      return True
-  # Also check graph_def.library.function for ds.interleave or ds.flat_map
-  for function in graph_def.library.function:
-    for node in function.node_def:
-      if node.op.startswith('ShuffleDataset'):
-        return True
-  logging.warning('Expected a shuffled dataset but input dataset `x` is '
-                  'not shuffled. Please invoke `shuffle()` on input dataset.')
-  return False
-
-
-def is_dataset_or_iterator(data):
-  return isinstance(data, (dataset_ops.DatasetV1, dataset_ops.DatasetV2,
-                           iterator_ops.Iterator, iterator_ops.OwnedIterator))
-
-
-def get_iterator(dataset):
-  """Create and initialize an iterator from a dataset."""
-  if context.executing_eagerly():
-    iterator = dataset_ops.make_one_shot_iterator(dataset)
-  else:
-    iterator = dataset_ops.make_initializable_iterator(dataset)
-  initialize_iterator(iterator)
-  return iterator
-
-
-def initialize_iterator(iterator):
-  if not context.executing_eagerly():
-    init_op = iterator.initializer
-    K.get_session((init_op,)).run(init_op)
-
-
-def extract_tensors_from_dataset(dataset):
-  """Extract a tuple of tensors `inputs, targets, sample_weight` from a dataset.
-
-  Arguments:
-    dataset: Dataset instance.
-
-  Returns:
-    Tuple of tensors `x, y, weights`. `y` and `weights` entry may be None.
-  """
-  iterator = get_iterator(dataset)
-  inputs, targets, sample_weight = unpack_iterator_input(iterator)
-  return inputs, targets, sample_weight
-
-
-def unpack_iterator_input(iterator):
-  """Convert a dataset iterator to a tuple of tensors `x, y, sample_weights`.
-
-  Arguments:
-    iterator: Instance of a dataset iterator.
-
-  Returns:
-    Tuple of tensors `x, y, weights`. `y` and `weights` entry may be None.
-  """
-  try:
-    next_element = iterator.get_next()
-  except errors.OutOfRangeError:
-    raise RuntimeError('Your dataset iterator ran out of data; '
-                       'make sure that your dataset can generate '
-                       'the required number of samples.')
-
-  if isinstance(next_element, (list, tuple)):
-    if len(next_element) not in [2, 3]:
-      raise ValueError(
-          'Please provide model inputs as a list or tuple of 2 or 3 '
-          'elements: (input, target) or (input, target, sample_weights). '
-          'Received %s' % next_element)
-    if len(next_element) == 2:
-      x, y = next_element
-      weights = None
-    else:
-      x, y, weights = next_element
-  else:
-    x = next_element
-    y = None
-    weights = None
-  return x, y, weights
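For illustration (editor's sketch, not part of the original module), the results this unpacking produces for common dataset element structures; the commented calls assume eager mode and access to the `get_iterator`/`unpack_iterator_input` helpers above:

import tensorflow.compat.v2 as tf

# Element is (x, y): weights comes back as None.
ds_xy = tf.data.Dataset.from_tensor_slices(([1., 2.], [0, 1])).batch(2)
# Element is (x, y, w): all three entries are returned.
ds_xyw = tf.data.Dataset.from_tensor_slices(([1., 2.], [0, 1], [.5, .5])).batch(2)
# Element is a single tensor: y and weights come back as None.
ds_x = tf.data.Dataset.from_tensor_slices([1., 2.]).batch(2)

# x, y, w = unpack_iterator_input(get_iterator(ds_xy))
# x, y, w = unpack_iterator_input(get_iterator(ds_xyw))
# x, y, w = unpack_iterator_input(get_iterator(ds_x))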
-
-
-def infer_steps_for_dataset(model,
-                            dataset,
-                            steps,
-                            epochs=1,
-                            steps_name='steps'):
-  """Infers steps_per_epoch needed to loop through a dataset.
-
-  Arguments:
-      model: Keras model instance.
-      dataset: Input data of type tf.data.Dataset.
-      steps: Number of steps to draw from the dataset (may be None if unknown).
-      epochs: Number of times to iterate over the dataset.
-      steps_name: The string name of the steps argument, either `steps`,
-        `validation_steps`, or `steps_per_epoch`. Only used for error message
-        formatting.
-
-  Returns:
-    Integer or `None`. Inferred number of steps to loop through the dataset.
-    `None` is returned if 1) the size of the dataset is unknown and `steps` was
-    not specified, or 2) this is multi-worker training and auto sharding is
-    enabled.
-
-  Raises:
-    ValueError: In case of invalid argument values.
-  """
-  assert isinstance(dataset, dataset_ops.DatasetV2)
-  if (model._in_multi_worker_mode() and
-      (dataset.options().experimental_distribute.auto_shard_policy !=
-       AutoShardPolicy.OFF)):
-    # If the dataset would be auto-sharded, we should not infer a local
-    # steps_per_epoch due to the possibly imbalanced sharding between workers.
-    return None
-
-  size = K.get_value(cardinality.cardinality(dataset))
-  if size == cardinality.INFINITE and steps is None:
-    raise ValueError('When passing an infinitely repeating dataset, you '
-                     'must specify the `%s` argument.' % (steps_name,))
-  if size >= 0:
-    if steps is not None and steps * epochs > size:
-      if epochs > 1:
-        raise ValueError('The dataset you passed contains %s batches, but you '
-                         'passed `epochs=%s` and `%s=%s`, which is a total of '
-                         '%s steps. We cannot draw that many steps from this '
-                         'dataset. We suggest to set `%s=%s`.' %
-                         (size, epochs, steps_name, steps, steps * epochs,
-                          steps_name, size // epochs))
-      else:
-        raise ValueError('The dataset you passed contains %s batches, but you '
-                         'passed `%s=%s`. We cannot draw that many steps from '
-                         'this dataset. We suggest to set `%s=%s`.' %
-                         (size, steps_name, steps, steps_name, size))
-  if steps is None:
-    if size >= 0:
-      return size
-    return None
-  return steps
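A worked example of the step inference above (editor's sketch, not part of the original module; assumes a model that is not in multi-worker mode):

import tensorflow.compat.v2 as tf

ds = tf.data.Dataset.range(100).batch(10)
print(tf.data.experimental.cardinality(ds).numpy())  # 10 batches

# With cardinality 10:
#   steps=None, epochs=1 -> 10 is returned (one full pass per epoch).
#   steps=4              -> 4 is returned (the caller's value is kept).
#   steps=6, epochs=2    -> 6 * 2 = 12 > 10, so a ValueError suggests
#                           steps = 10 // 2 = 5.
# An infinitely repeating dataset with steps=None raises a ValueError.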
-
-
-class ModelInputs(object):
-  """Encapsulates model inputs.
-
-  Allows for transforming model inputs while keeping the same structure.
-  """
-
-  def __init__(self, inputs):
-    self._inputs = inputs
-    self._is_dict = isinstance(self._inputs, dict)
-    self._is_single_input = not isinstance(self._inputs, (list, tuple, dict))
-
-    self._flattened_inputs = []
-    self._input_names = []
-
-    if self._is_dict:
-      for k in sorted(self._inputs.keys()):
-        self._flattened_inputs.append(self._inputs[k])
-        self._input_names.append(k)
-    else:
-      self._flattened_inputs = nest.flatten(self._inputs)
-      self._input_names = [
-          'input_%d' % (i + 1) for i in range(len(self._flattened_inputs))
-      ]
-
-  def get_input_names(self):
-    """Returns keys to name inputs by.
-
-    If the inputs were provided as a list, tuple or single entry, we make up
-    keys of the form 'input_%d'. For the dictionary case, we return a sorted
-    list of keys.
-    """
-    return self._input_names
-
-  def get_symbolic_inputs(self, return_single_as_list=False):
-    """Returns inputs to be set as self.inputs for a model."""
-    # TODO(karmel): There is a side-effect here where what you get
-    # with as_list and as_dict depends on whether you have called this
-    # method first, since it modifies in place.
-    for i, (k, v) in enumerate(zip(self._input_names, self._flattened_inputs)):
-      if isinstance(v, (list, float, int)):
-        v = np.asarray(v)
-        if v.ndim == 1:
-          v = np.expand_dims(v, 1)
-
-      if isinstance(v, (np.ndarray, ops.EagerTensor)):
-        # We fix the placeholder shape except the batch size.
-        # This is suboptimal, but it is the best we can do with the info
-        # we have. The user should call `model._set_inputs(placeholders)`
-        # to specify custom placeholders if the need arises.
-        shape = (None,) + tuple(v.shape[1:])
-        if shape == (None,):
-          shape = (None, 1)
-        dtype = dtypes.as_dtype(v.dtype)
-        if dtype.is_floating:
-          dtype = K.floatx()
-        v = K.placeholder(shape=shape, name=k, dtype=dtype)
-      elif isinstance(v, tensor_spec.TensorSpec):
-        shape = (None,) + tuple(v.shape.as_list()[1:])
-        if shape == (None,):
-          shape = (None, 1)
-        v = K.placeholder(shape=shape, name=k, dtype=v.dtype)
-
-      self._flattened_inputs[i] = v
-
-    if self._is_dict:
-      return dict(zip(self._input_names, self._flattened_inputs))
-    if self._is_single_input and not return_single_as_list:
-      return self._flattened_inputs[0]
-    return self._flattened_inputs
-
-  def as_dict(self):
-    """An iterable over a dictionary version of inputs."""
-    for k, v in zip(self._input_names, self._flattened_inputs):
-      yield k, v
-
-  def as_list(self):
-    """Returning the inputs as a list."""
-    return self._flattened_inputs
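A brief sketch of the naming behaviour (editor's illustration, not part of the original module; the commented calls assume `ModelInputs` is importable from this module, and `np` stands for NumPy):

# Single array input: a made-up name is assigned.
#   mi = ModelInputs(np.zeros((8, 3)))
#   mi.get_input_names()  -> ['input_1']
# Dict input: sorted keys become the input names.
#   mi = ModelInputs({'b': np.zeros((8, 1)), 'a': np.zeros((8, 2))})
#   mi.get_input_names()  -> ['a', 'b']
#   mi.as_list()          -> values flattened in the same sorted-key order
print(['input_%d' % (i + 1) for i in range(2)])  # the naming scheme: ['input_1', 'input_2']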
-
-
-# Allow use of methods not exposed to the user.
-# pylint: disable=protected-access
-
-
-# pylint: enable=protected-access
-
-
-def generic_output_names(outputs_list):
-  return ['output_%d' % (i + 1) for i in range(len(outputs_list))]
-
-
-def convert_eager_tensors_to_numpy(structure):
-  """Convert every EagerTensor in `structure` to NumPy.
-
-  Arguments:
-    structure: An arbitrary structure of elements to be converted to NumPy
-      arrays.
-
-  Returns:
-    An identical structure with EagerTensors converted to NumPy arrays.
-  """
-
-  def _convert(element):
-    if isinstance(element, ops.EagerTensor):
-      return element.numpy()
-    return element
-
-  return nest.map_structure(_convert, structure)
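An equivalent standalone sketch (editor's illustration, not part of the original module), using the public `tf.nest` API in eager mode:

import tensorflow.compat.v2 as tf

structure = {'a': tf.constant([1., 2.]), 'b': [tf.constant(3), 'keep-me']}
# Note: the helper above checks for EagerTensor specifically; in eager mode the
# tf.Tensor check below covers the same tensors for this example.
converted = tf.nest.map_structure(
    lambda e: e.numpy() if isinstance(e, tf.Tensor) else e, structure)
# converted['a'] is now a NumPy array; the string leaf is left untouched.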
-
-
-def should_run_validation(validation_freq, epoch):
-  """Checks if validation should be run this epoch.
-
-  Arguments:
-    validation_freq: Integer or list. If an integer, specifies how many training
-      epochs to run before a new validation run is performed. If a list,
-      specifies the epochs on which to run validation.
-    epoch: Integer, the number of the training epoch just completed.
-
-  Returns:
-    Bool, True if validation should be run.
-
-  Raises:
-    ValueError: if `validation_freq` is an integer less than 1, or if it is
-    neither an integer nor a Container (e.g. list, tuple).
-  """
-  # `epoch` is 0-indexed internally but 1-indexed in the public API.
-  one_indexed_epoch = epoch + 1
-
-  if isinstance(validation_freq, int):
-    if validation_freq < 1:
-      raise ValueError('`validation_freq` can not be less than 1.')
-    return one_indexed_epoch % validation_freq == 0
-
-  if not isinstance(validation_freq, collections_abc.Container):
-    raise ValueError('`validation_freq` must be an Integer or '
-                     '`collections_abc.Container` (e.g. list, tuple, etc.)')
-  return one_indexed_epoch in validation_freq
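A quick worked example of the frequency check (editor's sketch, not part of the original module):

# With validation_freq=2, validation runs after 1-indexed epochs 2, 4, 6, ...
# (internal 0-indexed epochs 1, 3, 5, ...). With validation_freq=[1, 5] it
# runs only when the 1-indexed epoch is 1 or 5.
freq = 2
runs_after = [epoch for epoch in range(6) if (epoch + 1) % freq == 0]
print(runs_after)  # [1, 3, 5] -- 0-indexed epochs, i.e. the 2nd, 4th and 6th epochs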
-
-
-def split_training_and_validation_data(x, y, sample_weights, validation_split):
-  """Split input data into train/eval section based on validation_split."""
-  if has_symbolic_tensors(x):
-    raise ValueError('If your data is in the form of symbolic tensors, '
-                     'you cannot use `validation_split`.')
-  if hasattr(x[0], 'shape'):
-    split_at = int(x[0].shape[0] * (1. - validation_split))
-  else:
-    split_at = int(len(x[0]) * (1. - validation_split))
-  x, val_x = (generic_utils.slice_arrays(x, 0, split_at),
-              generic_utils.slice_arrays(x, split_at))
-  y, val_y = (generic_utils.slice_arrays(y, 0, split_at),
-              generic_utils.slice_arrays(y, split_at))
-  if sample_weights:
-    sample_weights, val_sample_weights = (
-        generic_utils.slice_arrays(sample_weights, 0, split_at),
-        generic_utils.slice_arrays(sample_weights, split_at),
-    )
-  else:
-    val_sample_weights = None
-  return x, y, sample_weights, val_x, val_y, val_sample_weights
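A worked example of the split point (editor's sketch, not part of the original module):

import numpy as np

# With 100 samples and validation_split=0.2, split_at = int(100 * 0.8) = 80:
# the first 80 samples are kept for training, the last 20 for validation.
x = [np.arange(100).reshape(100, 1)]  # the function expects a list of arrays
split_at = int(x[0].shape[0] * (1. - 0.2))
print(split_at)  # 80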
-
-
-def unpack_validation_data(validation_data, raise_if_ambiguous=True):
-  """Unpack validation data based input type.
-
-  The validation data is not touched if its dataset or dataset iterator.
-  For other type of input (Numpy or tensor), it will be unpacked into tuple of
-  3 which is x, y and sample weights.
-
-  Args:
-    validation_data: dataset, dataset iterator, or numpy, tensor tuple.
-    raise_if_ambiguous: boolean on whether to fail if validation_data cannot be
-      parsed. Otherwise simply return validation_data, None, None and defer the
-      decision to the caller.
-
-  Returns:
-    tuple of 3, (x, y, sample_weights) for numpy and tensor input.
-  """
-  if (isinstance(validation_data, (iterator_ops.Iterator,
-                                   iterator_ops.OwnedIterator,
-                                   dataset_ops.DatasetV2,
-                                   data_utils.Sequence))
-      or not hasattr(validation_data, '__len__')):
-    val_x = validation_data
-    val_y = None
-    val_sample_weight = None
-  elif len(validation_data) == 2:
-    try:
-      val_x, val_y = validation_data  # pylint: disable=unpacking-non-sequence
-      val_sample_weight = None
-    except ValueError:
-      val_x, val_y, val_sample_weight = validation_data, None, None
-  elif len(validation_data) == 3:
-    try:
-      val_x, val_y, val_sample_weight = validation_data  # pylint: disable=unpacking-non-sequence
-    except ValueError:
-      val_x, val_y, val_sample_weight = validation_data, None, None
-  else:
-    if raise_if_ambiguous:
-      raise ValueError(
-          'When passing a `validation_data` argument, '
-          'it must contain either 2 items (x_val, y_val), '
-          'or 3 items (x_val, y_val, val_sample_weights), '
-          'or alternatively it could be a dataset or a '
-          'dataset iterator. '
-          'However we received `validation_data=%s`' % validation_data)
-    val_x, val_y, val_sample_weight = validation_data, None, None
-  return val_x, val_y, val_sample_weight
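For reference, the common return shapes (editor's sketch, not part of the original module; the commented calls assume the helper above is importable):

import numpy as np

x_val, y_val, w_val = np.zeros((4, 2)), np.zeros((4, 1)), np.ones((4,))
# unpack_validation_data((x_val, y_val))         -> (x_val, y_val, None)
# unpack_validation_data((x_val, y_val, w_val))  -> (x_val, y_val, w_val)
# A tf.data.Dataset (or iterator) is passed through untouched:
# unpack_validation_data(some_dataset)           -> (some_dataset, None, None)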
-
-
-class TrainingLoop(object):
-  """TrainingLoop is a wrapper class around the training logic.
-
-  This class encapsulates the fit/eval/predict logic for different kinds of
-  data input and model configuration.
-
-  Note that TrainingLoop is stateless: it holds no internal fields and can be
-  reused with different models and inputs.
-  """
-
-  def fit(self,
-          model,
-          x=None,
-          y=None,
-          batch_size=None,
-          epochs=1,
-          verbose=1,
-          callbacks=None,
-          validation_split=0.,
-          validation_data=None,
-          shuffle=True,
-          class_weight=None,
-          sample_weight=None,
-          initial_epoch=0,
-          steps_per_epoch=None,
-          validation_steps=None,
-          validation_freq=1,
-          **kwargs):
-    """Train the model with the inputs and targets."""
-    raise NotImplementedError()
-
-  def evaluate(self,
-               model,
-               x=None,
-               y=None,
-               batch_size=None,
-               verbose=1,
-               sample_weight=None,
-               steps=None,
-               callbacks=None,
-               **kwargs):
-    """Returns the loss value & metrics values for the model in test mode."""
-    raise NotImplementedError()
-
-  def predict(self,
-              model,
-              x,
-              batch_size=None,
-              verbose=0,
-              steps=None,
-              callbacks=None,
-              **kwargs):
-    raise NotImplementedError()
diff --git a/tensorflow/python/keras/engine/training_v1.py b/tensorflow/python/keras/engine/training_v1.py
index 69a60e05531..7dca6ae3da7 100644
--- a/tensorflow/python/keras/engine/training_v1.py
+++ b/tensorflow/python/keras/engine/training_v1.py
@@ -51,7 +51,6 @@ from tensorflow.python.keras.engine import training_distributed_v1
 from tensorflow.python.keras.engine import training_eager_v1
 from tensorflow.python.keras.engine import training_generator_v1
 from tensorflow.python.keras.engine import training_utils
-from tensorflow.python.keras.engine import training_utils_v1
 from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
 from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.keras.saving.saved_model import model_serialization
@@ -414,7 +413,7 @@ class Model(training_lib.Model):
     base_layer.keras_api_gauge.get_cell('compile').set(True)
 
     # Prepare list of loss functions, same size of model outputs.
-    self.loss_functions = training_utils_v1.prepare_loss_functions(
+    self.loss_functions = training_utils.prepare_loss_functions(
         self.loss, self.output_names)
 
     target_tensors = self._process_target_tensor_for_compile(target_tensors)
@@ -426,8 +425,7 @@ class Model(training_lib.Model):
       self._training_endpoints.append(endpoint)
 
     # Prepare list loss weights, same size of model outputs.
-    training_utils_v1.prepare_loss_weights(self._training_endpoints,
-                                           loss_weights)
+    training_utils.prepare_loss_weights(self._training_endpoints, loss_weights)
 
     # Initialization for Eager mode execution.
     if self.run_eagerly:
@@ -449,7 +447,7 @@ class Model(training_lib.Model):
           masks=self._prepare_output_masks())
 
       # Prepare sample weight modes. List with the same length as model outputs.
-      training_utils_v1.prepare_sample_weight_modes(
+      training_utils.prepare_sample_weight_modes(
           self._training_endpoints, sample_weight_mode)
 
       # Creates the model loss and weighted metrics sub-graphs.
@@ -595,7 +593,7 @@ class Model(training_lib.Model):
     # or a non-distributed Dataset or iterator in eager execution.
     if data_utils.is_generator_or_sequence(inputs):
       return training_generator_v1.GeneratorOrSequenceTrainingLoop()
-    if training_utils_v1.is_eager_dataset_or_iterator(inputs):
+    if training_utils.is_eager_dataset_or_iterator(inputs):
       return training_generator_v1.EagerDatasetOrIteratorTrainingLoop()
 
     # Case 3: Symbolic tensors or Numpy array-like.
@@ -1076,7 +1074,7 @@ class Model(training_lib.Model):
                  + output_dict['metrics'])
       outputs = [_non_none_constant_value(v) for v in outputs]  # pylint: disable=protected-access
     else:
-      x = training_utils_v1.ModelInputs(x).as_list()
+      x = training_utils.ModelInputs(x).as_list()
       ins = x + list(y or []) + list(sample_weights or [])
 
       if not isinstance(K.symbolic_learning_phase(), int):
@@ -1155,7 +1153,7 @@ class Model(training_lib.Model):
                  + output_dict['metrics'])
       outputs = [_non_none_constant_value(v) for v in outputs]  # pylint: disable=protected-access
     else:
-      x = training_utils_v1.ModelInputs(x).as_list()
+      x = training_utils.ModelInputs(x).as_list()
       inputs = x + list(y or []) + list(sample_weights or [])
 
       self._update_sample_weight_modes(sample_weights=sample_weights)
@@ -1200,7 +1198,7 @@ class Model(training_lib.Model):
     # If `self._distribution_strategy` is True, then we are in a replica context
     # at this point.
     if self.run_eagerly or self._distribution_strategy:
-      inputs = training_utils_v1.cast_if_floating_dtype(inputs)
+      inputs = training_utils.cast_if_floating_dtype(inputs)
       if isinstance(inputs, collections_abc.Sequence):
         # Unwrap lists with only one input, as we do when training on batch
         if len(inputs) == 1:
@@ -1372,7 +1370,7 @@ class Model(training_lib.Model):
   def _prepare_validation_data(self, validation_data, batch_size,
                                validation_steps):
     """Unpack and check the validation data."""
-    val_x, val_y, val_sample_weights = training_utils_v1.unpack_validation_data(
+    val_x, val_y, val_sample_weights = training_utils.unpack_validation_data(
         validation_data)
     return self._standardize_user_data(
         val_x,
@@ -1451,7 +1449,7 @@ class Model(training_lib.Model):
 
   def _compile_eagerly(self, metrics, weighted_metrics, sample_weight_mode):
     # Prepare sample weight modes. List with the same length as model outputs.
-    training_utils_v1.prepare_sample_weight_modes(
+    training_utils.prepare_sample_weight_modes(
         self._training_endpoints, sample_weight_mode)
     # Prepare sample weights.
     self._prepare_sample_weights()
@@ -1790,10 +1788,10 @@ class Model(training_lib.Model):
         output_shapes.append(None)
       else:
         output_shapes.append(output.shape.as_list())
-    self._per_output_metrics = training_utils_v1.collect_per_output_metric_info(
+    self._per_output_metrics = training_utils.collect_per_output_metric_info(
         metrics, self.output_names, output_shapes, self.loss_functions)
     self._per_output_weighted_metrics = (
-        training_utils_v1.collect_per_output_metric_info(
+        training_utils.collect_per_output_metric_info(
             weighted_metrics,
             self.output_names,
             output_shapes,
@@ -1903,7 +1901,7 @@ class Model(training_lib.Model):
     metric_results = []
     for metric_name, metric_fn in metrics_dict.items():
       with K.name_scope(metric_name):
-        metric_result = training_utils_v1.call_metric_function(
+        metric_result = training_utils.call_metric_function(
             metric_fn, y_true, y_pred, weights=weights, mask=mask)
         metric_results.append(metric_result)
     return metric_results
@@ -2140,7 +2138,7 @@ class Model(training_lib.Model):
     # in the codebase.
     if isinstance(x, dataset_ops.DatasetV2):
       if shuffle:
-        training_utils_v1.verify_dataset_shuffled(x)
+        training_utils.verify_dataset_shuffled(x)
 
     strategy = self._distribution_strategy
     with strategy.scope():
@@ -2192,8 +2190,8 @@ class Model(training_lib.Model):
         x = ds.batch(batch_size, drop_remainder=drop_remainder)
       else:
         assert isinstance(x, dataset_ops.DatasetV2)
-        training_utils_v1.validate_dataset_input(x, y, sample_weight,
-                                                 validation_split)
+        training_utils.validate_dataset_input(x, y, sample_weight,
+                                              validation_split)
     return x
 
   def _standardize_user_data(self,
@@ -2270,28 +2268,28 @@ class Model(training_lib.Model):
       # Graph mode dataset. We'll pass the dataset as-is (unless
       # `extract_tensors_from_dataset` is True, in which case we extract
       # the tensors from the dataset and we output them.
-      training_utils_v1.validate_dataset_input(x, y, sample_weight,
-                                               validation_split)
+      training_utils.validate_dataset_input(x, y, sample_weight,
+                                            validation_split)
       if shuffle:
-        training_utils_v1.verify_dataset_shuffled(x)
+        training_utils.verify_dataset_shuffled(x)
 
       is_dataset = True
       if extract_tensors_from_dataset:
         # We do this for `train_on_batch`/etc.
-        x, y, sample_weight = training_utils_v1.extract_tensors_from_dataset(x)
+        x, y, sample_weight = training_utils.extract_tensors_from_dataset(x)
     elif isinstance(x, iterator_ops.Iterator):
       # Graph mode iterator. We extract the symbolic tensors.
-      training_utils_v1.validate_dataset_input(x, y, sample_weight,
-                                               validation_split)
+      training_utils.validate_dataset_input(x, y, sample_weight,
+                                            validation_split)
       iterator = x
-      x, y, sample_weight = training_utils_v1.unpack_iterator_input(iterator)
+      x, y, sample_weight = training_utils.unpack_iterator_input(iterator)
       is_dataset = True
     else:
       is_dataset = False
 
     # Validates `steps` argument based on x's type.
     if check_steps:
-      training_utils_v1.check_steps_argument(x, steps, steps_name)
+      training_utils.check_steps_argument(x, steps, steps_name)
 
     # First, we build the model on the fly if necessary.
     if not self.inputs:
@@ -2354,7 +2352,7 @@ class Model(training_lib.Model):
     # Standardize the inputs.
     if not isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
       # TODO(fchollet): run static checks with dataset output shape(s).
-      x = training_utils_v1.standardize_input_data(
+      x = training_utils.standardize_input_data(
           x,
           feed_input_names,
           feed_input_shapes,
@@ -2401,8 +2399,8 @@ class Model(training_lib.Model):
     if y is not None:
       # Prepare self._sample_weight_modes. List with the same length as
       # model outputs.
-      training_utils_v1.prepare_sample_weight_modes(self._training_endpoints,
-                                                    self.sample_weight_mode)
+      training_utils.prepare_sample_weight_modes(self._training_endpoints,
+                                                 self.sample_weight_mode)
       feed_output_names = self._feed_output_names
       feed_sample_weight_modes = self._sample_weight_modes
       if not self._is_graph_network:
@@ -2411,7 +2409,7 @@ class Model(training_lib.Model):
         feed_output_shapes = self._feed_output_shapes
 
       # Standardize the outputs.
-      y = training_utils_v1.standardize_input_data(
+      y = training_utils.standardize_input_data(
           y,
           feed_output_names,
           # Don't enforce target shapes to match output shapes.
@@ -2422,22 +2420,22 @@ class Model(training_lib.Model):
 
       # Generate sample-wise weight values given the `sample_weight` and
       # `class_weight` arguments.
-      sample_weights = training_utils_v1.standardize_sample_weights(
+      sample_weights = training_utils.standardize_sample_weights(
           sample_weight, feed_output_names)
-      class_weights = training_utils_v1.standardize_class_weights(
+      class_weights = training_utils.standardize_class_weights(
           class_weight, feed_output_names)
 
       sample_weights = [
-          training_utils_v1.standardize_weights(ref, sw, cw, mode)
+          training_utils.standardize_weights(ref, sw, cw, mode)
           for (ref, sw, cw, mode) in zip(y, sample_weights, class_weights,
                                          feed_sample_weight_modes)
       ]
       # Check that all arrays have the same length.
       if not self._distribution_strategy:
-        training_utils_v1.check_array_lengths(x, y, sample_weights)
+        training_utils.check_array_lengths(x, y, sample_weights)
         if self._is_graph_network and not run_eagerly:
           # Additional checks to avoid users mistakenly using improper loss fns.
-          training_utils_v1.check_loss_and_target_compatibility(
+          training_utils.check_loss_and_target_compatibility(
               y, self._feed_loss_fns, feed_output_shapes)
 
       sample_weights, _, _ = training_utils.handle_partial_sample_weights(
@@ -2472,12 +2470,11 @@ class Model(training_lib.Model):
     # iterator and only one batch of samples is required, we fetch the data
     # tensors from the iterator and then standardize them.
     if isinstance(inputs, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
-      inputs, targets, _ = training_utils_v1.extract_tensors_from_dataset(
-          inputs)
+      inputs, targets, _ = training_utils.extract_tensors_from_dataset(inputs)
     # We type-check that `inputs` and `targets` are either single arrays
     # or lists of arrays, and extract a flat list of inputs from the passed
     # structure.
-    training_utils_v1.validate_input_types(inputs, orig_inputs)
+    training_utils.validate_input_types(inputs, orig_inputs)
 
     if isinstance(inputs, (list, tuple)):
       processed_inputs += list(inputs)
@@ -2512,14 +2509,14 @@ class Model(training_lib.Model):
       if not self.inputs:
         # For subclassed models, a robust input spec is not available so we
         # must cast to the model dtype.
-        inputs = training_utils_v1.cast_if_floating_dtype(inputs, self.dtype)
+        inputs = training_utils.cast_if_floating_dtype(inputs, self.dtype)
 
       def create_tensor_spec(t):
         return tensor_spec.TensorSpec(t.shape, t.dtype)
 
       cast_inputs = nest.map_structure(create_tensor_spec, inputs)
-    elif training_utils_v1.has_tensors(inputs):
-      cast_inputs = training_utils_v1.cast_if_floating_dtype(inputs)
+    elif training_utils.has_tensors(inputs):
+      cast_inputs = training_utils.cast_if_floating_dtype(inputs)
     else:
       cast_inputs = inputs
     self._set_inputs(cast_inputs)
@@ -2528,11 +2525,11 @@ class Model(training_lib.Model):
   def _compile_from_inputs(self, all_inputs, target, orig_inputs, orig_target):
     if target is not None:
       # We need to use `y` to set the model targets.
-      if training_utils_v1.has_tensors(target):
-        target = training_utils_v1.cast_if_floating_dtype_and_mismatch(
+      if training_utils.has_tensors(target):
+        target = training_utils.cast_if_floating_dtype_and_mismatch(
             target, self.outputs)
-      training_utils_v1.validate_input_types(
-          target, orig_target, allow_dict=False, field_name='target')
+      training_utils.validate_input_types(target, orig_target,
+                                          allow_dict=False, field_name='target')
       if isinstance(target, (list, tuple)):
         all_inputs += list(target)
       else:
@@ -2631,7 +2628,7 @@ class Model(training_lib.Model):
         input_shape = (None,) + tuple(inputs.as_list()[1:])
       elif isinstance(inputs, dict):
         # We assert that the first layer is a FeatureLayer.
-        if not training_utils_v1.is_feature_layer(self.layers[0]):
+        if not training_utils.is_feature_layer(self.layers[0]):
           raise ValueError('Passing a dictionary input to a Sequential Model '
                            'which doesn\'t have FeatureLayer as the first layer'
                            ' is an error.')
@@ -2646,7 +2643,7 @@ class Model(training_lib.Model):
 
     # On-the-fly setting of symbolic model inputs (either by using the tensor
     # provided, or by creating a placeholder if Numpy data was provided).
-    model_inputs = training_utils_v1.ModelInputs(inputs)
+    model_inputs = training_utils.ModelInputs(inputs)
     inputs = model_inputs.get_symbolic_inputs()
     self.inputs = model_inputs.get_symbolic_inputs(return_single_as_list=True)
     self.input_names = model_inputs.get_input_names()
@@ -2670,7 +2667,7 @@ class Model(training_lib.Model):
     #                    data adapter since it assumes nest.flatten ordering.
     outputs = nest.flatten(outputs)
     self.outputs = outputs
-    self.output_names = training_utils_v1.generic_output_names(outputs)
+    self.output_names = training_utils.generic_output_names(outputs)
     # TODO(scottzhu): Should we cleanup the self._training_endpoints here?
     self.built = True