diff --git a/tensorflow/python/keras/tests/BUILD b/tensorflow/python/keras/tests/BUILD
index d270e6f638c..4bb7d5358e5 100644
--- a/tensorflow/python/keras/tests/BUILD
+++ b/tensorflow/python/keras/tests/BUILD
@@ -335,6 +335,29 @@ tf_py_test(
     ],
 )
 
+cuda_py_test(
+    name = "saver_test",
+    size = "medium",
+    srcs = ["saver_test.py"],
+    python_version = "PY3",
+    deps = [
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:errors",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:nn_grad",
+        "//tensorflow/python:resource_variable_ops",
+        "//tensorflow/python:saver",
+        "//tensorflow/python:training_lib",
+        "//tensorflow/python:training_util",
+        "//tensorflow/python/eager:context",
+        "//tensorflow/python/keras/engine",
+        "//tensorflow/python/keras/layers:core",
+        "//tensorflow/python/training/tracking",
+        "//tensorflow/python/training/tracking:util",
+    ],
+)
+
 tf_py_test(
     name = "temporal_sample_weights_correctness_test",
     srcs = ["temporal_sample_weights_correctness_test.py"],
diff --git a/tensorflow/python/keras/tests/saver_test.py b/tensorflow/python/keras/tests/saver_test.py
new file mode 100644
index 00000000000..f425414a932
--- /dev/null
+++ b/tensorflow/python/keras/tests/saver_test.py
@@ -0,0 +1,158 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Tests for tensorflow.python.training.saver.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import os
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops as ops_lib
+from tensorflow.python.keras.engine import training
+from tensorflow.python.keras.layers import core
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.platform import test
+from tensorflow.python.training import adam
+from tensorflow.python.training import saver as saver_module
+from tensorflow.python.training import training_util
+from tensorflow.python.training.tracking import tracking as trackable_tracking
+from tensorflow.python.training.tracking import util as trackable_utils
+
+
+class NonLayerTrackable(trackable_tracking.AutoTrackable):
+
+  def __init__(self):
+    super(NonLayerTrackable, self).__init__()
+    self.a_variable = trackable_utils.add_variable(
+        self, name="a_variable", shape=[])
+
+
+class MyModel(training.Model):
+  """A concrete Model for testing."""
+
+  def __init__(self):
+    super(MyModel, self).__init__()
+    self._named_dense = core.Dense(1, use_bias=True)
+    self._second = core.Dense(1, use_bias=False)
+    # We can still track Trackables which aren't Layers.
+    self._non_layer = NonLayerTrackable()
+
+  def call(self, values):
+    ret = self._second(self._named_dense(values))
+    return ret
+
+
+class TrackableCompatibilityTests(test.TestCase):
+
+  def _initialized_model(self):
+    input_value = constant_op.constant([[3.]])
+    model = MyModel()
+    optimizer = adam.AdamOptimizer(0.001)
+    optimizer_step = training_util.get_or_create_global_step()
+    root_trackable = trackable_utils.Checkpoint(
+        optimizer=optimizer, model=model, optimizer_step=optimizer_step)
+    train_op = optimizer.minimize(
+        functools.partial(model, input_value),
+        global_step=optimizer_step)
+    self.evaluate(trackable_utils.gather_initializers(
+        root_trackable))
+    self.evaluate(train_op)
+    # A regular variable, a slot variable, and a non-slot Optimizer variable
+    # with known values to check when loading.
+    self.evaluate(model._named_dense.bias.assign([1.]))
+    self.evaluate(optimizer.get_slot(
+        var=model._named_dense.bias, name="m").assign([2.]))
+    beta1_power, _ = optimizer._get_beta_accumulators()
+    self.evaluate(beta1_power.assign(3.))
+    return root_trackable
+
+  def _set_sentinels(self, root_trackable):
+    self.evaluate(root_trackable.model._named_dense.bias.assign([101.]))
+    self.evaluate(
+        root_trackable.optimizer.get_slot(
+            var=root_trackable.model._named_dense.bias, name="m")
+        .assign([102.]))
+    beta1_power, _ = root_trackable.optimizer._get_beta_accumulators()
+    self.evaluate(beta1_power.assign(103.))
+
+  def _check_sentinels(self, root_trackable):
+    self.assertAllEqual(
+        [1.], self.evaluate(root_trackable.model._named_dense.bias))
+    self.assertAllEqual([2.], self.evaluate(
+        root_trackable.optimizer.get_slot(
+            var=root_trackable.model._named_dense.bias, name="m")))
+    beta1_power, _ = root_trackable.optimizer._get_beta_accumulators()
+    self.assertAllEqual(3., self.evaluate(beta1_power))
+
+  def testLoadFromObjectBasedGraph(self):
+    checkpoint_directory = self.get_temp_dir()
+    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+
+    save_graph = ops_lib.Graph()
+    with save_graph.as_default(), self.session(graph=save_graph) as sess:
+      root = self._initialized_model()
+      object_saver = trackable_utils.Checkpoint(root=root)
+      save_path = object_saver.save(file_prefix=checkpoint_prefix)
+
+      # An incompatible object-based checkpoint to check error messages
+      var = resource_variable_ops.ResourceVariable(1., name="a")
+      self.evaluate(var.initializer)
+      second_saver = trackable_utils.Checkpoint(v=var)
+      second_path = second_saver.save(file_prefix=os.path.join(
+          checkpoint_directory, "second"))
+
+    restore_graph = ops_lib.Graph()
+    with restore_graph.as_default(), self.session(
+        graph=restore_graph) as sess:
+      root = self._initialized_model()
+      self._set_sentinels(root)
+      saver = saver_module.Saver()
+      saver.restore(sess=sess, save_path=save_path)
+      self._check_sentinels(root)
+      before_second_restore_ops = restore_graph.get_operations()
+      # Test that multiple restores do not pollute the graph
+      saver.restore(sess=sess, save_path=save_path)
+      self.assertEqual(before_second_restore_ops,
+                       restore_graph.get_operations())
+      with self.assertRaisesRegexp(errors.NotFoundError,
+                                   "Could not find some variables"):
+        saver.restore(sess=sess, save_path=second_path)
+
+  def testLoadFromObjectBasedEager(self):
+    checkpoint_directory = self.get_temp_dir()
+    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+
+    save_graph = ops_lib.Graph()
+    with save_graph.as_default(), self.session(graph=save_graph):
+      root = self._initialized_model()
+      object_saver = trackable_utils.Checkpoint(root=root)
+      save_path = object_saver.save(file_prefix=checkpoint_prefix)
+
+    with context.eager_mode():
+      root = self._initialized_model()
+      self._set_sentinels(root)
+      saver = saver_module.Saver(
+          root.model.variables + root.optimizer.variables())
+      saver.restore(sess=None, save_path=save_path)
+      self._check_sentinels(root)
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index 2c8bdadd5d7..5c87be37e4c 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -18,7 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import functools
 import glob
 import math
 import os
@@ -48,8 +47,6 @@ from tensorflow.python.framework import graph_io
 from tensorflow.python.framework import meta_graph
 from tensorflow.python.framework import ops as ops_lib
 from tensorflow.python.framework import test_util
-from tensorflow.python.keras.engine import training
-from tensorflow.python.keras.layers import core
 from tensorflow.python.lib.io import file_io
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -74,10 +71,7 @@ from tensorflow.python.training import py_checkpoint_reader
 from tensorflow.python.training import queue_runner_impl
 from tensorflow.python.training import saver as saver_module
 from tensorflow.python.training import saver_test_utils
-from tensorflow.python.training import training_util
 from tensorflow.python.training.tracking import base as trackable_base
-from tensorflow.python.training.tracking import tracking as trackable_tracking
-from tensorflow.python.training.tracking import util as trackable_utils
 from tensorflow.python.util import compat
 
 
@@ -3024,29 +3018,6 @@ class _OwnsMirroredVariables(trackable_base.Trackable):
     return self.non_dep_variable.name
 
 
-class NonLayerTrackable(trackable_tracking.AutoTrackable):
-
-  def __init__(self):
-    super(NonLayerTrackable, self).__init__()
-    self.a_variable = trackable_utils.add_variable(
-        self, name="a_variable", shape=[])
-
-
-class MyModel(training.Model):
-  """A concrete Model for testing."""
-
-  def __init__(self):
-    super(MyModel, self).__init__()
-    self._named_dense = core.Dense(1, use_bias=True)
-    self._second = core.Dense(1, use_bias=False)
-    # We can still track Trackables which aren't Layers.
-    self._non_layer = NonLayerTrackable()
-
-  def call(self, values):
-    ret = self._second(self._named_dense(values))
-    return ret
-
-
 class TrackableCompatibilityTests(test.TestCase):
 
   # TODO(allenl): Track down python3 reference cycles in these tests.
@@ -3112,46 +3083,6 @@ class TrackableCompatibilityTests(test.TestCase):
         saver.restore(sess, save_path)
         self.assertEqual(1, v.eval_count)
 
-  def _initialized_model(self):
-    input_value = constant_op.constant([[3.]])
-    model = MyModel()
-    optimizer = adam.AdamOptimizer(0.001)
-    optimizer_step = training_util.get_or_create_global_step()
-    root_trackable = trackable_utils.Checkpoint(
-        optimizer=optimizer, model=model, optimizer_step=optimizer_step)
-    train_op = optimizer.minimize(
-        functools.partial(model, input_value),
-        global_step=optimizer_step)
-    self.evaluate(trackable_utils.gather_initializers(
-        root_trackable))
-    self.evaluate(train_op)
-    # A regular variable, a slot variable, and a non-slot Optimizer variable
-    # with known values to check when loading.
-    self.evaluate(model._named_dense.bias.assign([1.]))
-    self.evaluate(optimizer.get_slot(
-        var=model._named_dense.bias, name="m").assign([2.]))
-    beta1_power, _ = optimizer._get_beta_accumulators()
-    self.evaluate(beta1_power.assign(3.))
-    return root_trackable
-
-  def _set_sentinels(self, root_trackable):
-    self.evaluate(root_trackable.model._named_dense.bias.assign([101.]))
-    self.evaluate(
-        root_trackable.optimizer.get_slot(
-            var=root_trackable.model._named_dense.bias, name="m")
-        .assign([102.]))
-    beta1_power, _ = root_trackable.optimizer._get_beta_accumulators()
-    self.evaluate(beta1_power.assign(103.))
-
-  def _check_sentinels(self, root_trackable):
-    self.assertAllEqual(
-        [1.], self.evaluate(root_trackable.model._named_dense.bias))
-    self.assertAllEqual([2.], self.evaluate(
-        root_trackable.optimizer.get_slot(
-            var=root_trackable.model._named_dense.bias, name="m")))
-    beta1_power, _ = root_trackable.optimizer._get_beta_accumulators()
-    self.assertAllEqual(3., self.evaluate(beta1_power))
-
   def testVariableNotFoundErrorRaised(self):
     # Restore does some tricky exception handling to figure out if it should
     # load an object-based checkpoint. Tests that the exception handling isn't
@@ -3199,58 +3130,6 @@ class TrackableCompatibilityTests(test.TestCase):
             "a mismatch between the current graph and the graph"):
           a_saver.restore(sess=sess, save_path=save_path)
 
-  def testLoadFromObjectBasedGraph(self):
-    checkpoint_directory = self.get_temp_dir()
-    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
-
-    save_graph = ops_lib.Graph()
-    with save_graph.as_default(), self.session(graph=save_graph) as sess:
-      root = self._initialized_model()
-      object_saver = trackable_utils.Checkpoint(root=root)
-      save_path = object_saver.save(file_prefix=checkpoint_prefix)
-
-      # An incompatible object-based checkpoint to check error messages
-      var = resource_variable_ops.ResourceVariable(1., name="a")
-      self.evaluate(var.initializer)
-      second_saver = trackable_utils.Checkpoint(v=var)
-      second_path = second_saver.save(file_prefix=os.path.join(
-          checkpoint_directory, "second"))
-
-    restore_graph = ops_lib.Graph()
-    with restore_graph.as_default(), self.session(
-        graph=restore_graph) as sess:
-      root = self._initialized_model()
-      self._set_sentinels(root)
-      saver = saver_module.Saver()
-      saver.restore(sess=sess, save_path=save_path)
-      self._check_sentinels(root)
-      before_second_restore_ops = restore_graph.get_operations()
-      # Test that multiple restores do not pollute the graph
-      saver.restore(sess=sess, save_path=save_path)
-      self.assertEqual(before_second_restore_ops,
-                       restore_graph.get_operations())
-      with self.assertRaisesRegexp(errors.NotFoundError,
-                                   "Could not find some variables"):
-        saver.restore(sess=sess, save_path=second_path)
-
-  def testLoadFromObjectBasedEager(self):
-    checkpoint_directory = self.get_temp_dir()
-    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
-
-    save_graph = ops_lib.Graph()
-    with save_graph.as_default(), self.session(graph=save_graph):
-      root = self._initialized_model()
-      object_saver = trackable_utils.Checkpoint(root=root)
-      save_path = object_saver.save(file_prefix=checkpoint_prefix)
-
-    with context.eager_mode():
-      root = self._initialized_model()
-      self._set_sentinels(root)
-      saver = saver_module.Saver(
-          root.model.variables + root.optimizer.variables())
-      saver.restore(sess=None, save_path=save_path)
-      self._check_sentinels(root)
-
 
 if __name__ == "__main__":
   test.main()