diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 7c2e8b003b9..91c7fd16c51 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3265,6 +3265,7 @@ py_library(
         "@six_archive//:six",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/distribute:distribute_coordinator_context",
         "//tensorflow/python/eager:backprop",
         "//tensorflow/python/eager:context",
         # `layers` dependency only exists due to the use of a small utility.
@@ -4658,7 +4659,10 @@ py_test(
     size = "medium",
     srcs = ["training/monitored_session_test.py"],
     srcs_version = "PY2AND3",
-    tags = ["notsan"],  # b/67945581
+    tags = [
+        "no_pip",
+        "notsan",  # b/67945581
+    ],
     deps = [
         ":array_ops",
         ":checkpoint_management",
@@ -4676,6 +4680,7 @@ py_test(
         "//tensorflow/contrib/framework:framework_py",
         "//tensorflow/contrib/testing:testing_py",
         "//tensorflow/core:protos_all_py",
+        "//tensorflow/python/distribute:distribute_coordinator",
     ],
 )
 
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 68d8b8d13b1..16fbe3f4b55 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -41,3 +41,12 @@ py_test(
         "//tensorflow/python:variables",
     ],
 )
+
+py_library(
+    name = "distribute_coordinator_context",
+    srcs = [
+        "distribute_coordinator_context.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [],
+)
diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py
index fc9ca4ac4a3..eb081b65fc7 100644
--- a/tensorflow/python/distribute/distribute_coordinator.py
+++ b/tensorflow/python/distribute/distribute_coordinator.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""A unified and split coordinator for distributed TensorFlow."""
+"""A component for running distributed TensorFlow."""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -24,6 +24,8 @@ import os
 import threading
 
 from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.python.distribute import distribute_coordinator_context
+from tensorflow.python.training import monitored_session
 from tensorflow.python.training import server_lib
 
 
@@ -43,23 +45,12 @@ class CoordinatorMode(object):
   # client and connects to remote servers for training.  Each remote server can
   # use the distribute coordinator binary with task_type set correctly which
   # will then turn into standard servers.
-  SPLIT_CLIENT = 0
+  STANDALONE_CLIENT = "standalone_client"
 
   # The distribute coordinator runs on each worker. It will run a standard
   # server on each worker and optionally run the `worker_fn` that is configured
   # to talk to its standard server.
-  INDEPENDENT_WORKER = 1
-
-
-_worker_context = threading.local()
-
-
-def get_current_worker_context():
-  """Returns the current task context."""
-  try:
-    return _worker_context.current
-  except AttributeError:
-    return None
+  INDEPENDENT_WORKER = "independent_worker"
 
 
 class _Barrier(object):
@@ -113,14 +104,17 @@ class _WorkerContext(object):
   """
 
   def __init__(self,
+               strategy,
                cluster_spec,
                task_type,
                task_id,
+               session_config=None,
                rpc_layer="grpc",
                worker_barrier=None):
     """Initialize the worker context object.
 
     Args:
+      strategy: a `DistributionStrategy` object.
       cluster_spec: a ClusterSpec object. It can be empty or None in the local
         training case.
       task_type: a string indicating the role of the corresponding task, such as
@@ -128,14 +122,17 @@ class _WorkerContext(object):
         replicated training.
       task_id: an integer indicating id of the corresponding task. It can be
         None if it is local training or in-graph replicated training.
+      session_config: an optional @{tf.ConfigProto} object.
       rpc_layer: optional string specifying the RPC protocol for communication
         with worker masters. If None or empty, hosts in the `cluster_spec` will
         be used directly.
       worker_barrier: optional, the barrier object for worker synchronization.
     """
+    self._strategy = strategy
     self._cluster_spec = cluster_spec
     self._task_type = task_type
     self._task_id = task_id
+    self._session_config = session_config
     self._worker_barrier = worker_barrier
     self._rpc_layer = rpc_layer
     self._master_target = self._get_master_target()
@@ -143,26 +140,31 @@ class _WorkerContext(object):
     self._is_chief_node = self._is_chief()
 
   def _debug_message(self):
-    return "[cluster_spec: %r, task_type: %r, task_id: %r]" % (
-        self._cluster_spec, self.task_type, self.task_id)
+    if self._cluster_spec:
+      return "[cluster_spec: %r, task_type: %r, task_id: %r]" % (
+          self._cluster_spec, self.task_type, self.task_id)
+    else:
+      return "[local]"
 
   def __enter__(self):
-    old_context = get_current_worker_context()
+    old_context = distribute_coordinator_context.get_current_worker_context()
     if old_context:
       raise ValueError(
           "You cannot run distribute coordinator in a `worker_fn`.\t" +
           self._debug_message())
-    _worker_context.current = self
+    # pylint: disable=protected-access
+    distribute_coordinator_context._worker_context.current = self
 
   def __exit__(self, unused_exception_type, unused_exception_value,
                unused_traceback):
-    _worker_context.current = None
+    # pylint: disable=protected-access
+    distribute_coordinator_context._worker_context.current = None
 
   def _get_master_target(self):
     """Return the master target for a task."""
     # If cluster_spec is None or empty, we use local master.
     if not self._cluster_spec:
-      return "local"
+      return ""
 
     # If task_type is None, then it is in-graph replicated training. In this
     # case we use the chief or first worker's master target.
@@ -207,6 +209,47 @@ class _WorkerContext(object):
                        self._debug_message())
     self._worker_barrier.wait()
 
+  def session_creator(self,
+                      scaffold=None,
+                      config=None,
+                      checkpoint_dir=None,
+                      checkpoint_filename_with_path=None,
+                      max_wait_secs=7200):
+    """Returns a session creator.
+
+    The returned session creator will be configured with the correct master
+    target and session configs. It will also run either init ops or ready ops
+    by querying the `strategy` object when `create_session` is called on it.
+
+    Args:
+      scaffold: A `Scaffold` used for gathering or building supportive ops. If
+        not specified a default one is created. It's used to finalize the graph.
+      config: `ConfigProto` proto used to configure the session.
+      checkpoint_dir: A string. Optional path to a directory where to restore
+        variables.
+      checkpoint_filename_with_path: Full file name path to the checkpoint file.
+        Only one of `checkpoint_dir` or `checkpoint_filename_with_path` can be
+        specified.
+      max_wait_secs: Maximum time to wait for the session to become available.
+
+    Returns:
+      a descendant of SessionCreator.
+    """
+    # TODO(yuefengz): merge session config.
+    if self._strategy.should_init:
+      return monitored_session.ChiefSessionCreator(
+          scaffold,
+          master=self.master_target,
+          config=config or self._session_config,
+          checkpoint_dir=checkpoint_dir,
+          checkpoint_filename_with_path=checkpoint_filename_with_path)
+    else:
+      return monitored_session.WorkerSessionCreator(
+          scaffold,
+          master=self.master_target,
+          config=config or self._session_config,
+          max_wait_secs=max_wait_secs)
+
   @property
   def has_barrier(self):
     """Whether the barrier is set or not."""
@@ -247,21 +290,38 @@ class _WorkerContext(object):
     """Returns number of workers in the cluster, including chief."""
     return self._num_workers
 
+  @property
+  def should_checkpoint(self):
+    """Whether to save checkpoint."""
+    return self._strategy.should_checkpoint
+
+  @property
+  def should_save_summary(self):
+    """Whether to save summaries."""
+    return self._strategy.should_save_summary
+
 
 def _run_single_worker(worker_fn,
+                       strategy,
                        cluster_spec,
                        task_type,
                        task_id,
-                       rpc_layer,
+                       session_config,
+                       rpc_layer="",
                        worker_barrier=None):
   """Runs a single worker by calling `worker_fn` under context."""
-  with _WorkerContext(
+  strategy = copy.deepcopy(strategy)
+  strategy.configure(session_config, cluster_spec, task_type, task_id)
+  context = _WorkerContext(
+      strategy,
       cluster_spec,
       task_type,
       task_id,
+      session_config=session_config,
       rpc_layer=rpc_layer,
-      worker_barrier=worker_barrier):
-    worker_fn()
+      worker_barrier=worker_barrier)
+  with context:
+    worker_fn(strategy)
 
 
 def _run_std_server(cluster_spec=None,
@@ -280,13 +340,15 @@ def _run_std_server(cluster_spec=None,
   return server
 
 
-def _run_between_graph_client(worker_fn, cluster_spec, rpc_layer):
+def _run_between_graph_client(worker_fn, strategy, cluster_spec, session_config,
+                              rpc_layer):
   """Runs a standalone client for between-graph replication."""
   eval_thread = None
   if _TaskType.EVALUATOR in cluster_spec.jobs:
     eval_thread = threading.Thread(
         target=_run_single_worker,
-        args=(worker_fn, cluster_spec, _TaskType.EVALUATOR, 0),
+        args=(worker_fn, strategy, cluster_spec, _TaskType.EVALUATOR, 0,
+              session_config),
         kwargs={
             "rpc_layer": rpc_layer,
         })
@@ -298,7 +360,8 @@ def _run_between_graph_client(worker_fn, cluster_spec, rpc_layer):
     for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
       t = threading.Thread(
           target=_run_single_worker,
-          args=(worker_fn, cluster_spec, task_type, task_id),
+          args=(worker_fn, strategy, cluster_spec, task_type, task_id,
+                session_config),
           kwargs={
               "rpc_layer": rpc_layer,
               "worker_barrier": worker_barrier
@@ -315,43 +378,53 @@ def _run_between_graph_client(worker_fn, cluster_spec, rpc_layer):
     eval_thread.join()
 
 
-def _run_in_graph_client(worker_fn, cluster_spec, rpc_layer):
+def _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config,
+                         rpc_layer):
   """Runs a standalone client for in-graph replication."""
   eval_thread = None
   if _TaskType.EVALUATOR in cluster_spec.jobs:
     eval_thread = threading.Thread(
         target=_run_single_worker,
-        args=(worker_fn, cluster_spec, _TaskType.EVALUATOR, 0),
+        args=(worker_fn, strategy, cluster_spec, _TaskType.EVALUATOR, 0,
+              session_config),
         kwargs={
             "rpc_layer": rpc_layer,
         })
     eval_thread.start()
 
-  _run_single_worker(worker_fn, cluster_spec, None, None, rpc_layer)
+  _run_single_worker(
+      worker_fn,
+      strategy,
+      cluster_spec,
+      None,
+      None,
+      session_config,
+      rpc_layer=rpc_layer)
   if eval_thread:
     eval_thread.join()
 
-
-# TODO(yuefengz): propagate cluster_spec in the SPLIT_CLIENT mode.
+# TODO(yuefengz): propagate cluster_spec in the STANDALONE_CLIENT mode.
 # TODO(yuefengz): we may need a smart way to figure out whether the current task
 # is the special task when we support cluster_spec propagation.
 def run_distribute_coordinator(worker_fn,
-                               mode=CoordinatorMode.SPLIT_CLIENT,
+                               strategy,
+                               mode=CoordinatorMode.STANDALONE_CLIENT,
                                cluster_spec=None,
                                task_type=None,
                                task_id=None,
-                               between_graph=False,
+                               session_config=None,
                                rpc_layer="grpc"):
   """Runs the coordinator for distributed TensorFlow.
 
   This function runs a split coordinator for distributed TensorFlow in its
-  default mode, i.e the SPLIT_CLIENT mode. Given a `cluster_spec` specifying
-  server addresses and their roles in a cluster, this coordinator will figure
-  out how to set them up, give the underlying function the right targets for
-  master sessions via a scope object and coordinate their training. The cluster
-  consisting of standard servers needs to be brought up either with the standard
-  server binary or with a binary running distribute coordinator with `task_type`
-  set to non-client type which will then turn into standard servers.
+  default mode, i.e the STANDALONE_CLIENT mode. Given a `cluster_spec`
+  specifying server addresses and their roles in a cluster, this coordinator
+  will figure out how to set them up, give the underlying function the right
+  targets for master sessions via a scope object and coordinate their training.
+  The cluster consisting of standard servers needs to be brought up either with
+  the standard server binary or with a binary running distribute coordinator
+  with `task_type` set to non-client type which will then turn into standard
+  servers.
 
   In addition to be the distribute coordinator, this is also the source of
   configurations for each job in the distributed training. As there are multiple
@@ -370,6 +443,14 @@ def run_distribute_coordinator(worker_fn,
   `worker_fn` depending whether it is between-graph training or in-graph
   replicated training.
 
+  The `strategy` object is expected to be a DistributionStrategy object which
+  has implemented methods needed by distributed coordinator such as
+  `configure(session_config, cluster_spec, task_type, task_id)` which configures
+  the strategy object for a specific task and `should_init` property which
+  instructs the distribute coordinator whether to run init ops for a task. The
+  distribute coordinator will make a copy of the `strategy` object, call its
+  `configure` method and pass it to `worker_fn` as an argument.
+
   The `worker_fn` defines the training logic and is called under a its own
   worker context which can be accessed to via `get_current_worker_context`. A
   worker context provides access to configurations for each task, e.g. the
@@ -413,16 +494,20 @@ def run_distribute_coordinator(worker_fn,
   evaluation.
 
   Args:
-    worker_fn: the function to be called and given the access to a coordinator
-      context object.
+    worker_fn: the function to be called. The function should accept a
+      `strategy` object and will be given access to a context object via a
+      context manager scope.
+    strategy: a DistributionStrategy object which specifying whether it should
+      run between-graph replicated training or not, whether to run init ops,
+      etc. This object will also be configured given `session_config`,
+      `cluster_spc`, `task_type` and `task_id`.
     mode: in which mode this distribute coordinator runs.
     cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles
       in a cluster. If not set or empty, fall back to local training.
     task_type: the current task type, optional if this is a client.
     task_id: the current task id, optional if this is a client.
-    between_graph: a boolean. It is only useful when `cluster_spec` is set and
-      not empty. If true, it will use between-graph replicated training;
-      otherwise it will use in-graph replicated training.
+    session_config: an optional @{tf.ConfigProto} object which will be passed
+      to `strategy`'s `configure` method and used to create a session.
     rpc_layer: optional string, the protocol for RPC, e.g. "grpc".
 
   Raises:
@@ -448,15 +533,18 @@ def run_distribute_coordinator(worker_fn,
 
   if not cluster_spec:
     # `mode` is ignored in the local case.
-    _run_single_worker(worker_fn, None, None, None, rpc_layer)
-  elif mode == CoordinatorMode.SPLIT_CLIENT:
+    _run_single_worker(worker_fn, strategy, None, None, None, session_config,
+                       rpc_layer)
+  elif mode == CoordinatorMode.STANDALONE_CLIENT:
     # The client must know the cluster but servers in the cluster don't have to
     # know the client.
     if task_type in [_TaskType.CLIENT, None]:
-      if between_graph:
-        _run_between_graph_client(worker_fn, cluster_spec, rpc_layer)
+      if strategy.between_graph:
+        _run_between_graph_client(worker_fn, strategy, cluster_spec,
+                                  session_config, rpc_layer)
       else:
-        _run_in_graph_client(worker_fn, cluster_spec, rpc_layer)
+        _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config,
+                             rpc_layer)
     else:
       # If not a client job, run the standard server.
       server = _run_std_server(
@@ -471,19 +559,21 @@ def run_distribute_coordinator(worker_fn,
         cluster_spec=cluster_spec, task_type=task_type, task_id=task_id)
 
     if task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
-      if between_graph:
+      if strategy.between_graph:
         # All jobs run `worker_fn` if between-graph.
-        _run_single_worker(worker_fn, cluster_spec, task_type, task_id,
-                           rpc_layer)
+        _run_single_worker(worker_fn, strategy, cluster_spec, task_type,
+                           task_id, session_config, rpc_layer)
       else:
         # Only one node runs `worker_fn` if in-graph.
-        context = _WorkerContext(cluster_spec, task_type, task_id, rpc_layer)
+        context = _WorkerContext(strategy, cluster_spec, task_type, task_id)
         if context.is_chief:
-          _run_single_worker(worker_fn, cluster_spec, None, None, rpc_layer)
+          _run_single_worker(worker_fn, strategy, cluster_spec, None, None,
+                             session_config, rpc_layer)
         else:
           server.join()
     elif task_type == _TaskType.EVALUATOR:
-      _run_single_worker(worker_fn, cluster_spec, task_type, task_id, rpc_layer)
+      _run_single_worker(worker_fn, strategy, cluster_spec, task_type, task_id,
+                         session_config, rpc_layer)
     else:
       if task_type != _TaskType.PS:
         raise ValueError("Unexpected task_type: %r" % task_type)
diff --git a/tensorflow/python/distribute/distribute_coordinator_context.py b/tensorflow/python/distribute/distribute_coordinator_context.py
new file mode 100644
index 00000000000..dee65ce8839
--- /dev/null
+++ b/tensorflow/python/distribute/distribute_coordinator_context.py
@@ -0,0 +1,31 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The context retrieval method for distribute coordinator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import threading
+
+_worker_context = threading.local()
+
+
+def get_current_worker_context():
+  """Returns the current task context."""
+  try:
+    return _worker_context.current
+  except AttributeError:
+    return None
diff --git a/tensorflow/python/distribute/distribute_coordinator_test.py b/tensorflow/python/distribute/distribute_coordinator_test.py
index 319c29ba2fa..97c6bdd15a5 100644
--- a/tensorflow/python/distribute/distribute_coordinator_test.py
+++ b/tensorflow/python/distribute/distribute_coordinator_test.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for distribute coordinator."""
+"""Tests for Distribute Coordinator."""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -37,6 +37,7 @@ except ImportError as _error:
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session
 from tensorflow.python.distribute import distribute_coordinator
+from tensorflow.python.distribute import distribute_coordinator_context
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import control_flow_ops
@@ -44,17 +45,17 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
+from tensorflow.python.training import monitored_session
+
 
 CHIEF = distribute_coordinator._TaskType.CHIEF
 WORKER = distribute_coordinator._TaskType.WORKER
 PS = distribute_coordinator._TaskType.PS
 EVALUATOR = distribute_coordinator._TaskType.EVALUATOR
 
-SPLIT_CLIENT = distribute_coordinator.CoordinatorMode.SPLIT_CLIENT
+STANDALONE_CLIENT = distribute_coordinator.CoordinatorMode.STANDALONE_CLIENT
 INDEPENDENT_WORKER = distribute_coordinator.CoordinatorMode.INDEPENDENT_WORKER
 
-RUN_STD_SERVER_METHOD = "tensorflow.python.distribute.distribute_coordinator._run_std_server"
-
 NUM_WORKERS = 3
 NUM_PS = 2
 
@@ -74,6 +75,57 @@ def _strip_protocol(target):
     return target
 
 
+class MockStrategy(object):
+
+  def __init__(self,
+               between_graph=False,
+               should_init=None,
+               should_checkpoint=None,
+               should_save_summary=None):
+    self._between_graph = between_graph
+    self._should_init = should_init
+    self._should_checkpoint = should_checkpoint
+    self._should_save_summary = should_save_summary
+
+  @property
+  def between_graph(self):
+    return self._between_graph
+
+  def configure(self,
+                session_options=None,
+                cluster_spec=None,
+                task_type=None,
+                task_id=None):
+    del session_options, cluster_spec, task_type
+    if self._should_init is None:
+      if task_id == 0:
+        self._should_init = True
+      else:
+        self._should_init = False
+    if self._should_checkpoint is None:
+      if task_id == 0:
+        self._should_checkpoint = True
+      else:
+        self._should_checkpoint = False
+    if self._should_save_summary is None:
+      if task_id == 0:
+        self._should_save_summary = True
+      else:
+        self._should_save_summary = False
+
+  @property
+  def should_init(self):
+    return self._should_init
+
+  @property
+  def should_checkpoint(self):
+    return self._should_checkpoint
+
+  @property
+  def should_save_summary(self):
+    return self._should_save_summary
+
+
 class MockServer(object):
 
   def __init__(self):
@@ -108,6 +160,7 @@ class DistributeCoordinatorTestBase(test.TestCase):
     self._result_correct = 0
     self._lock = threading.Lock()
     self._worker_context = {}
+    self._strategy_property = {}
     self._std_servers = {}
     self._barrier = distribute_coordinator._Barrier(NUM_WORKERS)
 
@@ -142,8 +195,8 @@ class DistributeCoordinatorTestBase(test.TestCase):
       cluster_spec[EVALUATOR] = ["localhost:%s" % portpicker.pick_unused_port()]
     return cluster_spec
 
-  def _in_graph_worker_fn(self):
-    context = distribute_coordinator.get_current_worker_context()
+  def _in_graph_worker_fn(self, strategy):
+    context = distribute_coordinator_context.get_current_worker_context()
     self.assertTrue(context is not None)
     with self._test_session(target=context.master_target) as sess:
       xs = []
@@ -164,22 +217,23 @@ class DistributeCoordinatorTestBase(test.TestCase):
     if result_value == expected:
       self._result_correct += 1
 
-  def _run_coordinator_in_thread(self, worker_fn, **kwargs):
+  def _run_coordinator_in_thread(self, worker_fn, strategy, **kwargs):
     t = threading.Thread(
         target=distribute_coordinator.run_distribute_coordinator,
-        args=(worker_fn,),
+        args=(worker_fn, strategy),
         kwargs=kwargs)
     t.start()
     return t
 
-  def _run_multiple_coordinator_in_threads(self, worker_fn, cluster_spec,
-                                           **kwargs):
+  def _run_multiple_coordinator_in_threads(self, worker_fn, strategy,
+                                           cluster_spec, **kwargs):
     threads = {}
     for task_type in cluster_spec.keys():
       threads[task_type] = []
       for task_id in range(len(cluster_spec[task_type])):
         t = self._run_coordinator_in_thread(
             worker_fn,
+            strategy,
             cluster_spec=cluster_spec,
             task_type=task_type,
             task_id=task_id,
@@ -187,8 +241,8 @@ class DistributeCoordinatorTestBase(test.TestCase):
         threads[task_type].append(t)
     return threads
 
-  def _between_graph_worker_fn(self):
-    context = distribute_coordinator.get_current_worker_context()
+  def _between_graph_worker_fn(self, strategy):
+    context = distribute_coordinator_context.get_current_worker_context()
     self.assertTrue(context is not None)
     with self._test_session(target=context.master_target) as sess:
       with ops.device("/job:ps/task:0"):
@@ -234,14 +288,50 @@ class DistributeCoordinatorTestBase(test.TestCase):
         with self._lock:
           self._result_correct += 1
 
-  def _dump_worker_context(self):
+  def _between_graph_with_monitored_session(self, strategy):
+    context = distribute_coordinator_context.get_current_worker_context()
+    self.assertTrue(context is not None)
+    with ops.device("/job:ps/task:0"):
+      # TODO(yuefengz): investigate why not using resource variable will make
+      # the test flaky.
+      x = variable_scope.get_variable("x", initializer=10.0, use_resource=True)
+    with ops.device("/job:ps/task:1"):
+      y = variable_scope.get_variable("y", initializer=20.0, use_resource=True)
+
+    x_add = x.assign_add(2.0)
+    y_sub = y.assign_sub(2.0)
+    train_op = control_flow_ops.group([x_add, y_sub])
+
+    # The monitored session will run init or ready ops.
+    with monitored_session.MonitoredSession() as sess:
+      sess.run(train_op)
+
+      # Synchronize workers after one step to make sure they all have finished
+      # training.
+      if context.has_barrier:
+        context.wait_for_other_workers()
+      else:
+        self._barrier.wait()
+
+      x_val, y_val = sess.run([x, y])
+
+    self.assertEqual(x_val, 16.0)
+    self.assertEqual(y_val, 14.0)
+    if x_val == 16.0 and y_val == 14.0:
+      with self._lock:
+        self._result_correct += 1
+
+  def _dump_worker_context(self, strategy):
     """Dumps the propoerties of each worker context.
 
     It dumps the context properties to a dict mapping from task_type to a list
     of tuples of master_target, num_workers, is_chief and distribute_mode, where
     the list is indexed by the task_id.
+
+    Args:
+      strategy: a `DistributionStrategy` object.
     """
-    context = distribute_coordinator.get_current_worker_context()
+    context = distribute_coordinator_context.get_current_worker_context()
     self.assertTrue(context is not None)
     task_type = str(context.task_type)
     task_id = context.task_id or 0
@@ -255,6 +345,25 @@ class DistributeCoordinatorTestBase(test.TestCase):
                                                   context.is_chief,
                                                   context.distributed_mode)
 
+  def _dump_strategy_property(self, strategy):
+    context = distribute_coordinator_context.get_current_worker_context()
+    self.assertTrue(context is not None)
+
+    self.assertEqual(context._strategy.should_init, strategy.should_init)
+    self.assertEqual(context.should_checkpoint, strategy.should_checkpoint)
+    self.assertEqual(context.should_save_summary, strategy.should_save_summary)
+
+    task_type = str(context.task_type)
+    task_id = context.task_id or 0
+    with self._lock:
+      if task_type not in self._strategy_property:
+        self._strategy_property[task_type] = []
+      while len(self._strategy_property[task_type]) <= task_id:
+        self._strategy_property[task_type].append(None)
+      self._strategy_property[task_type][task_id] = (
+          context._strategy.should_init, context.should_checkpoint,
+          context.should_save_summary)
+
   def _run_mock_std_server(self,
                            session_config=None,
                            cluster_spec=None,
@@ -274,22 +383,32 @@ class DistributeCoordinatorTestBase(test.TestCase):
     return server
 
 
-class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
+class DistributeCoordinatorTestStandaloneMode(DistributeCoordinatorTestBase):
 
-  def testInGraphSplitMode(self):
-    """Test it runs in-graph replication in split client mode."""
+  def testInGraphStandaloneMode(self):
+    """Test it runs in-graph replication in standalone client mode."""
     distribute_coordinator.run_distribute_coordinator(
         self._in_graph_worker_fn,
-        cluster_spec=self._cluster_spec,
-        between_graph=False)
+        MockStrategy(between_graph=False),
+        cluster_spec=self._cluster_spec)
     self.assertEqual(self._result_correct, 1)
 
   def testBetweenGraph(self):
-    """Test it runs between-graph replication in split client mode."""
+    """Test it runs between-graph replication in standalone client mode."""
     distribute_coordinator.run_distribute_coordinator(
         self._between_graph_worker_fn,
-        cluster_spec=self._cluster_spec,
-        between_graph=True)
+        MockStrategy(between_graph=True),
+        cluster_spec=self._cluster_spec)
+
+    # Each finished worker will increment self._result_correct.
+    self.assertEqual(self._result_correct, NUM_WORKERS)
+
+  def testBetweenGraphWithMonitoredSession(self):
+    """Test monitored session in standalone client mode."""
+    distribute_coordinator.run_distribute_coordinator(
+        self._between_graph_with_monitored_session,
+        MockStrategy(between_graph=True),
+        cluster_spec=self._cluster_spec)
 
     # Each finished worker will increment self._result_correct.
     self.assertEqual(self._result_correct, NUM_WORKERS)
@@ -298,8 +417,8 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
     # Dumps the task contexts to the self._worker_context dict.
     distribute_coordinator.run_distribute_coordinator(
         self._dump_worker_context,
-        cluster_spec=self._cluster_spec,
-        between_graph=True)
+        MockStrategy(between_graph=True),
+        cluster_spec=self._cluster_spec)
 
     # There is only one type of task and there three such tasks.
     self.assertEqual(len(self._worker_context), 1)
@@ -318,12 +437,30 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
         self._worker_context[WORKER][2],
         (_bytes_to_str(self._workers[2].target), NUM_WORKERS, False, True))
 
+  def testBetweenGraphStrategyProperties(self):
+    # Dumps properties of the strategy objects.
+    distribute_coordinator.run_distribute_coordinator(
+        self._dump_strategy_property,
+        MockStrategy(between_graph=True, should_init=True),
+        cluster_spec=self._cluster_spec)
+
+    # There is only one type of task and there three such tasks.
+    self.assertEqual(len(self._strategy_property), 1)
+    self.assertTrue(WORKER in self._strategy_property)
+    self.assertEqual(len(self._strategy_property[WORKER]), NUM_WORKERS)
+
+    # Check whether each task has the right properties of should_init,
+    # should_checkpoint and should_save_summary.
+    self.assertEqual(self._strategy_property[WORKER][0], (True, True, True))
+    self.assertEqual(self._strategy_property[WORKER][1], (True, False, False))
+    self.assertEqual(self._strategy_property[WORKER][2], (True, False, False))
+
   def testInGraphContext(self):
     # Dumps the task contexts to the self._worker_context dict.
     distribute_coordinator.run_distribute_coordinator(
         self._dump_worker_context,
-        cluster_spec=self._cluster_spec,
-        between_graph=False)
+        MockStrategy(between_graph=False),
+        cluster_spec=self._cluster_spec)
 
     # There is only a "None" task in the dumped task context.
     self.assertEqual(len(self._worker_context), 1)
@@ -339,7 +476,9 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
   def testLocalContext(self):
     # Dumps the task contexts to the self._worker_context dict.
     distribute_coordinator.run_distribute_coordinator(
-        self._dump_worker_context, cluster_spec=None, between_graph=True)
+        self._dump_worker_context,
+        MockStrategy(between_graph=False),
+        cluster_spec=None)
 
     # There is only a "None" task.
     self.assertEqual(len(self._worker_context), 1)
@@ -348,7 +487,7 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
 
     # Check whether each task has the right master_target, num_workers, is_chief
     # and distributed_mode.
-    self.assertEqual(self._worker_context["None"][0], ("local", 0, True, False))
+    self.assertEqual(self._worker_context["None"][0], ("", 0, True, False))
 
   def testBetweenGraphContextWithChief(self):
     # Adds a chief node, so there are NUM_WORKERS + 1 workers in total.
@@ -358,8 +497,8 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
     # Dumps the task contexts to the self._worker_context dict.
     distribute_coordinator.run_distribute_coordinator(
         self._dump_worker_context,
+        MockStrategy(between_graph=True),
         cluster_spec=cluster_spec,
-        between_graph=True,
         rpc_layer="grpc")
 
     # There are one CHIEF and three workers.
@@ -391,8 +530,8 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
     # Dumps the task contexts to the self._worker_context dict.
     distribute_coordinator.run_distribute_coordinator(
         self._dump_worker_context,
+        MockStrategy(between_graph=False),
         cluster_spec=cluster_spec,
-        between_graph=False,
         rpc_layer=None)
 
     # There are one "None" task and one EVALUATOR task.
@@ -417,8 +556,8 @@ class DistributeCoordinatorTestInpendentWorkerMode(
     cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
     threads = self._run_multiple_coordinator_in_threads(
         self._in_graph_worker_fn,
+        MockStrategy(between_graph=False),
         cluster_spec,
-        between_graph=False,
         mode=INDEPENDENT_WORKER)
     threads[WORKER][0].join()
     self.assertEqual(self._result_correct, 1)
@@ -428,8 +567,22 @@ class DistributeCoordinatorTestInpendentWorkerMode(
         num_workers=NUM_WORKERS, num_ps=NUM_PS)
     threads = self._run_multiple_coordinator_in_threads(
         self._between_graph_worker_fn,
+        MockStrategy(between_graph=True),
+        cluster_spec,
+        mode=INDEPENDENT_WORKER)
+    for task_id in range(NUM_WORKERS):
+      threads[WORKER][task_id].join()
+
+    # Each finished worker will increment self._result_correct.
+    self.assertEqual(self._result_correct, NUM_WORKERS)
+
+  def testBetweenGraphWithMonitoredSession(self):
+    cluster_spec = self._create_cluster_spec(
+        num_workers=NUM_WORKERS, num_ps=NUM_PS)
+    threads = self._run_multiple_coordinator_in_threads(
+        self._between_graph_with_monitored_session,
+        MockStrategy(between_graph=True),
         cluster_spec,
-        between_graph=True,
         mode=INDEPENDENT_WORKER)
     for task_id in range(NUM_WORKERS):
       threads[WORKER][task_id].join()
@@ -444,9 +597,9 @@ class DistributeCoordinatorTestInpendentWorkerMode(
                                 self._run_mock_std_server):
       threads = self._run_multiple_coordinator_in_threads(
           self._dump_worker_context,
+          MockStrategy(between_graph=True),
           cluster_spec,
           mode=INDEPENDENT_WORKER,
-          between_graph=True,
           rpc_layer=None)
       for task_id in range(NUM_WORKERS):
         threads[WORKER][task_id].join()
@@ -476,6 +629,31 @@ class DistributeCoordinatorTestInpendentWorkerMode(
     self.assertFalse(self._std_servers[WORKER][1].joined)
     self.assertFalse(self._std_servers[WORKER][2].joined)
 
+  def testBetweenGraphStrategyProperties(self):
+    cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
+    # Dumps properties of the strategy objects.
+    with test.mock.patch.object(distribute_coordinator, "_run_std_server",
+                                self._run_mock_std_server):
+      threads = self._run_multiple_coordinator_in_threads(
+          self._dump_strategy_property,
+          MockStrategy(between_graph=True, should_init=True),
+          cluster_spec,
+          mode=INDEPENDENT_WORKER,
+          rpc_layer=None)
+      for task_id in range(NUM_WORKERS):
+        threads[WORKER][task_id].join()
+
+    # There is only one type of task and there three such tasks.
+    self.assertEqual(len(self._strategy_property), 1)
+    self.assertTrue(WORKER in self._strategy_property)
+    self.assertEqual(len(self._strategy_property[WORKER]), NUM_WORKERS)
+
+    # Check whether each task has the right properties of should_init,
+    # should_checkpoint and should_save_summary.
+    self.assertEqual(self._strategy_property[WORKER][0], (True, True, True))
+    self.assertEqual(self._strategy_property[WORKER][1], (True, False, False))
+    self.assertEqual(self._strategy_property[WORKER][2], (True, False, False))
+
   def testInGraphContext(self):
     cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
     # Dumps the task contexts and std server arguments.
@@ -483,9 +661,9 @@ class DistributeCoordinatorTestInpendentWorkerMode(
                                 self._run_mock_std_server):
       threads = self._run_multiple_coordinator_in_threads(
           self._dump_worker_context,
+          MockStrategy(between_graph=False),
           cluster_spec,
           mode=INDEPENDENT_WORKER,
-          between_graph=False,
           rpc_layer=None)
       for task_id in range(NUM_WORKERS):
         threads[WORKER][task_id].join()
@@ -519,9 +697,9 @@ class DistributeCoordinatorTestInpendentWorkerMode(
                                 self._run_mock_std_server):
       threads = self._run_multiple_coordinator_in_threads(
           self._dump_worker_context,
+          MockStrategy(between_graph=False),
           cluster_spec,
           mode=INDEPENDENT_WORKER,
-          between_graph=False,
           rpc_layer=None)
       for task_id in range(NUM_WORKERS):
         threads[WORKER][task_id].join()
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index 7b06bffa4b2..c077630de2b 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -25,6 +25,7 @@ import sys
 import six
 
 from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.distribute import distribute_coordinator_context
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
@@ -284,6 +285,63 @@ class Scaffold(object):
         resources.initialize_resources(resources.local_resources()))
 
 
+def _create_monitored_session_with_worker_context(worker_context,  # pylint: disable=missing-docstring
+                                                  scaffold,
+                                                  checkpoint_dir=None,
+                                                  hooks=None,
+                                                  chief_only_hooks=None,
+                                                  save_checkpoint_secs=None,
+                                                  save_summaries_steps=None,
+                                                  save_summaries_secs=None,
+                                                  config=None,
+                                                  stop_grace_period_secs=120,
+                                                  log_step_count_steps=100,
+                                                  max_wait_secs=7200,
+                                                  save_checkpoint_steps=None,
+                                                  summary_dir=None):
+  all_hooks = []
+  if hooks:
+    all_hooks.extend(hooks)
+  if chief_only_hooks and worker_context.is_chief:
+    all_hooks.extend(chief_only_hooks)
+
+  summary_dir = summary_dir or checkpoint_dir
+  if summary_dir and worker_context.should_save_summary:
+    if log_step_count_steps and log_step_count_steps > 0:
+      all_hooks.append(
+          basic_session_run_hooks.StepCounterHook(
+              output_dir=summary_dir, every_n_steps=log_step_count_steps))
+
+    if (save_summaries_steps and save_summaries_steps > 0) or (
+        save_summaries_secs and save_summaries_secs > 0):
+      all_hooks.append(
+          basic_session_run_hooks.SummarySaverHook(
+              scaffold=scaffold,
+              save_steps=save_summaries_steps,
+              save_secs=save_summaries_secs,
+              output_dir=summary_dir))
+
+  if checkpoint_dir and worker_context.should_checkpoint:
+    if (save_checkpoint_secs and save_checkpoint_secs > 0) or (
+        save_checkpoint_steps and save_checkpoint_steps > 0):
+      all_hooks.append(
+          basic_session_run_hooks.CheckpointSaverHook(
+              checkpoint_dir,
+              save_steps=save_checkpoint_steps,
+              save_secs=save_checkpoint_secs,
+              scaffold=scaffold))
+
+  session_creator = worker_context.session_creator(
+      scaffold,
+      config=config,
+      checkpoint_dir=checkpoint_dir,
+      max_wait_secs=max_wait_secs)
+  return MonitoredSession(
+      session_creator=session_creator,
+      hooks=all_hooks,
+      stop_grace_period_secs=stop_grace_period_secs)
+
+
 @tf_export('train.MonitoredTrainingSession')
 def MonitoredTrainingSession(master='',  # pylint: disable=invalid-name
                              is_chief=True,
@@ -373,14 +431,35 @@ def MonitoredTrainingSession(master='',  # pylint: disable=invalid-name
     save_checkpoint_steps = None
 
   scaffold = scaffold or Scaffold()
+  worker_context = distribute_coordinator_context.get_current_worker_context()
+
+  if worker_context:
+    return _create_monitored_session_with_worker_context(
+        worker_context,
+        scaffold,
+        checkpoint_dir=checkpoint_dir,
+        hooks=hooks,
+        chief_only_hooks=chief_only_hooks,
+        save_checkpoint_secs=save_checkpoint_secs,
+        save_summaries_steps=save_summaries_steps,
+        save_summaries_secs=save_summaries_secs,
+        config=config,
+        stop_grace_period_secs=stop_grace_period_secs,
+        log_step_count_steps=log_step_count_steps,
+        max_wait_secs=max_wait_secs,
+        save_checkpoint_steps=save_checkpoint_steps,
+        summary_dir=summary_dir)
+
   if not is_chief:
     session_creator = WorkerSessionCreator(
         scaffold=scaffold,
         master=master,
         config=config,
         max_wait_secs=max_wait_secs)
-    return MonitoredSession(session_creator=session_creator, hooks=hooks or [],
-                            stop_grace_period_secs=stop_grace_period_secs)
+    return MonitoredSession(
+        session_creator=session_creator,
+        hooks=hooks or [],
+        stop_grace_period_secs=stop_grace_period_secs)
 
   all_hooks = []
   if chief_only_hooks:
@@ -400,25 +479,29 @@ def MonitoredTrainingSession(master='',  # pylint: disable=invalid-name
 
     if (save_summaries_steps and save_summaries_steps > 0) or (
         save_summaries_secs and save_summaries_secs > 0):
-      all_hooks.append(basic_session_run_hooks.SummarySaverHook(
-          scaffold=scaffold,
-          save_steps=save_summaries_steps,
-          save_secs=save_summaries_secs,
-          output_dir=summary_dir))
+      all_hooks.append(
+          basic_session_run_hooks.SummarySaverHook(
+              scaffold=scaffold,
+              save_steps=save_summaries_steps,
+              save_secs=save_summaries_secs,
+              output_dir=summary_dir))
 
   if checkpoint_dir:
     if (save_checkpoint_secs and save_checkpoint_secs > 0) or (
         save_checkpoint_steps and save_checkpoint_steps > 0):
-      all_hooks.append(basic_session_run_hooks.CheckpointSaverHook(
-          checkpoint_dir,
-          save_steps=save_checkpoint_steps,
-          save_secs=save_checkpoint_secs,
-          scaffold=scaffold))
+      all_hooks.append(
+          basic_session_run_hooks.CheckpointSaverHook(
+              checkpoint_dir,
+              save_steps=save_checkpoint_steps,
+              save_secs=save_checkpoint_secs,
+              scaffold=scaffold))
 
   if hooks:
     all_hooks.extend(hooks)
-  return MonitoredSession(session_creator=session_creator, hooks=all_hooks,
-                          stop_grace_period_secs=stop_grace_period_secs)
+  return MonitoredSession(
+      session_creator=session_creator,
+      hooks=all_hooks,
+      stop_grace_period_secs=stop_grace_period_secs)
 
 
 @tf_export('train.SessionCreator')
@@ -546,6 +629,11 @@ class _MonitoredSession(object):
     self._hooks = hooks or []
     for h in self._hooks:
       h.begin()
+
+    worker_context = distribute_coordinator_context.get_current_worker_context()
+    if not session_creator and worker_context:
+      session_creator = worker_context.session_creator()
+
     # Create the session.
     self._coordinated_creator = self._CoordinatedSessionCreator(
         session_creator=session_creator or ChiefSessionCreator(),
diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py
index 92533ca4f3b..ff586b6c03f 100644
--- a/tensorflow/python/training/monitored_session_test.py
+++ b/tensorflow/python/training/monitored_session_test.py
@@ -32,6 +32,7 @@ from tensorflow.contrib.testing.python.framework import util_test
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.core.protobuf import debug_pb2
 from tensorflow.python.client import session as session_lib
+from tensorflow.python.distribute import distribute_coordinator
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
@@ -381,6 +382,119 @@ class MonitoredTrainingSessionTest(test.TestCase):
         self.assertEqual(0, session.run(gstep))
 
 
+class MockStrategy(object):
+
+  def __init__(self,
+               between_graph=False,
+               should_init=True,
+               should_checkpoint=None,
+               should_save_summary=None):
+    self._between_graph = between_graph
+    self._should_init = should_init
+    self._should_checkpoint = should_checkpoint
+    self._should_save_summary = should_save_summary
+
+  @property
+  def between_graph(self):
+    return self._between_graph
+
+  @property
+  def should_init(self):
+    return self._should_init
+
+  @property
+  def should_checkpoint(self):
+    return self._should_checkpoint
+
+  @property
+  def should_save_summary(self):
+    return self._should_save_summary
+
+
+class MonitoredTrainingSessionWithDistributeCoordinatorTest(test.TestCase):
+  """Test distribute coordinator controls summary saving and checkpointing."""
+
+  def test_summary_hook_enabled(self):
+    context = distribute_coordinator._WorkerContext(
+        MockStrategy(should_save_summary=True), None, None, None)
+
+    logdir = _test_dir(self.get_temp_dir(), 'test_summaries_enabled')
+    with ops.Graph().as_default():
+      gstep = variables_lib.get_or_create_global_step()
+      new_gstep = state_ops.assign_add(gstep, 1)
+      summary.scalar('my_summary_tag', new_gstep * 2)
+      with context, monitored_session.MonitoredTrainingSession(
+          checkpoint_dir=logdir,
+          save_summaries_steps=100,
+          log_step_count_steps=10) as session:
+        for _ in range(101):
+          session.run(new_gstep)
+
+    summaries = util_test.latest_summaries(logdir)
+    tags = [s.summary.value[0].tag for s in summaries]
+    self.assertIn('my_summary_tag', tags)
+    self.assertIn('global_step/sec', tags)
+
+  def test_summary_hook_disabled(self):
+    context = distribute_coordinator._WorkerContext(
+        MockStrategy(should_save_summary=False), None, None, None)
+
+    logdir = _test_dir(self.get_temp_dir(), 'test_summaries_disabled')
+    with ops.Graph().as_default():
+      gstep = variables_lib.get_or_create_global_step()
+      new_gstep = state_ops.assign_add(gstep, 1)
+      summary.scalar('my_summary_tag', new_gstep * 2)
+      with context, monitored_session.MonitoredTrainingSession(
+          checkpoint_dir=logdir,
+          save_summaries_steps=100,
+          log_step_count_steps=10) as session:
+        for _ in range(101):
+          session.run(new_gstep)
+
+    # No summary is saved.
+    summaries = util_test.latest_summaries(logdir)
+    self.assertEqual(len(summaries), 0)
+
+  def test_checkpoint_hook_enabled(self):
+    context = distribute_coordinator._WorkerContext(
+        MockStrategy(should_checkpoint=True), None, None, None)
+
+    logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_enabled')
+    with ops.Graph().as_default():
+      gstep = variables_lib.get_or_create_global_step()
+      new_gstep = state_ops.assign_add(gstep, 1)
+      with context, monitored_session.MonitoredTrainingSession(
+          checkpoint_dir=logdir,
+          save_checkpoint_steps=100,
+          log_step_count_steps=10) as session:
+        for _ in range(100):
+          session.run(new_gstep)
+
+      # A restart will find the checkpoint and recover automatically.
+      with monitored_session.MonitoredTrainingSession(
+          is_chief=True, checkpoint_dir=logdir) as session:
+        self.assertEqual(100, session.run(gstep))
+
+  def test_checkpoint_hook_disabled(self):
+    context = distribute_coordinator._WorkerContext(
+        MockStrategy(should_checkpoint=False), None, None, None)
+
+    logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_disabled')
+    with ops.Graph().as_default():
+      gstep = variables_lib.get_or_create_global_step()
+      new_gstep = state_ops.assign_add(gstep, 1)
+      with context, monitored_session.MonitoredTrainingSession(
+          checkpoint_dir=logdir,
+          save_checkpoint_steps=100,
+          log_step_count_steps=10) as session:
+        for _ in range(100):
+          session.run(new_gstep)
+
+    # No checkpoint is saved.
+    checkpoint = checkpoint_management.latest_checkpoint(logdir)
+    self.assertIsNone(checkpoint)
+
+
 class StopAtNSession(monitored_session._WrappedSession):
   """A wrapped session that stops at the N-th call to _check_stop."""
 
@@ -1365,8 +1479,8 @@ class MonitoredSessionTest(test.TestCase):
       with monitored_session.MonitoredSession(
           session_creator=monitored_session.ChiefSessionCreator(
               scaffold,
-              checkpoint_filename_with_path=
-              checkpoint_management.latest_checkpoint(logdir))) as session:
+              checkpoint_filename_with_path=checkpoint_management.
+              latest_checkpoint(logdir))) as session:
         self.assertEqual(2, session.run(gstep))
 
   def test_retry_initialization_on_aborted_error(self):