Use the distribution strategy to configure the distribute coordinator.

Add a session_creator method and a couple of properties to the worker context, which are then used to configure monitored sessions.

PiperOrigin-RevId: 209026599
Authored by Yuefeng Zhou on 2018-08-16 12:25:17 -07:00; committed by TensorFlower Gardener
parent 5360d73687
commit 1326f33515
7 changed files with 624 additions and 109 deletions
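Taken together, the change means a `worker_fn` now receives a per-task copy of the strategy, and `MonitoredTrainingSession` derives its session creator and saving hooks from the worker context instead of `is_chief`/`master`. Below is a minimal local-mode sketch of the intended call pattern; it is not part of the commit, and the `_NoOpStrategy` stand-in and checkpoint directory are illustrative assumptions mirroring the `MockStrategy` used in the tests further down.

    from tensorflow.python.distribute import distribute_coordinator
    from tensorflow.python.distribute import distribute_coordinator_context
    from tensorflow.python.ops import state_ops
    from tensorflow.python.training import monitored_session
    from tensorflow.python.training import training_util


    class _NoOpStrategy(object):
      """Illustrative stand-in exposing the interface the coordinator expects."""
      between_graph = False
      should_init = True
      should_checkpoint = True
      should_save_summary = True

      def configure(self, session_config=None, cluster_spec=None,
                    task_type=None, task_id=None):
        pass  # a real DistributionStrategy would reconfigure itself per task


    def worker_fn(strategy):
      # The coordinator has already deep-copied and configure()d `strategy`.
      context = distribute_coordinator_context.get_current_worker_context()
      assert context is not None  # we are inside the coordinator's _WorkerContext
      step = training_util.get_or_create_global_step()
      train_op = state_ops.assign_add(step, 1)
      # With a worker context active, MonitoredTrainingSession picks its session
      # creator from strategy.should_init and gates the checkpoint/summary hooks
      # on should_checkpoint/should_save_summary.
      with monitored_session.MonitoredTrainingSession(
          checkpoint_dir="/tmp/coordinator_demo") as sess:  # illustrative path
        sess.run(train_op)


    # No cluster_spec -> local mode; `mode` is ignored and worker_fn runs inline.
    distribute_coordinator.run_distribute_coordinator(
        worker_fn, _NoOpStrategy(), cluster_spec=None)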

tensorflow/python/BUILD

@@ -3265,6 +3265,7 @@ py_library(
        "@six_archive//:six",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python/data/ops:dataset_ops",
+       "//tensorflow/python/distribute:distribute_coordinator_context",
        "//tensorflow/python/eager:backprop",
        "//tensorflow/python/eager:context",
        # `layers` dependency only exists due to the use of a small utility.
@@ -4658,7 +4659,10 @@ py_test(
    size = "medium",
    srcs = ["training/monitored_session_test.py"],
    srcs_version = "PY2AND3",
-   tags = ["notsan"],  # b/67945581
+   tags = [
+       "no_pip",
+       "notsan",  # b/67945581
+   ],
    deps = [
        ":array_ops",
        ":checkpoint_management",
@@ -4676,6 +4680,7 @@ py_test(
        "//tensorflow/contrib/framework:framework_py",
        "//tensorflow/contrib/testing:testing_py",
        "//tensorflow/core:protos_all_py",
+       "//tensorflow/python/distribute:distribute_coordinator",
    ],
)

tensorflow/python/distribute/BUILD

@@ -41,3 +41,12 @@ py_test(
        "//tensorflow/python:variables",
    ],
)
+
+py_library(
+    name = "distribute_coordinator_context",
+    srcs = [
+        "distribute_coordinator_context.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [],
+)

tensorflow/python/distribute/distribute_coordinator.py

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""A unified and split coordinator for distributed TensorFlow."""
+"""A component for running distributed TensorFlow."""

from __future__ import absolute_import
from __future__ import division
@@ -24,6 +24,8 @@ import os
import threading

from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.python.distribute import distribute_coordinator_context
+from tensorflow.python.training import monitored_session
from tensorflow.python.training import server_lib
@@ -43,23 +45,12 @@ class CoordinatorMode(object):
  # client and connects to remote servers for training. Each remote server can
  # use the distribute coordinator binary with task_type set correctly which
  # will then turn into standard servers.
-  SPLIT_CLIENT = 0
+  STANDALONE_CLIENT = "standalone_client"

  # The distribute coordinator runs on each worker. It will run a standard
  # server on each worker and optionally run the `worker_fn` that is configured
  # to talk to its standard server.
-  INDEPENDENT_WORKER = 1
+  INDEPENDENT_WORKER = "independent_worker"


-_worker_context = threading.local()
-
-
-def get_current_worker_context():
-  """Returns the current task context."""
-  try:
-    return _worker_context.current
-  except AttributeError:
-    return None


class _Barrier(object):
@@ -113,14 +104,17 @@ class _WorkerContext(object):
  """

  def __init__(self,
+               strategy,
               cluster_spec,
               task_type,
               task_id,
+               session_config=None,
               rpc_layer="grpc",
               worker_barrier=None):
    """Initialize the worker context object.

    Args:
+      strategy: a `DistributionStrategy` object.
      cluster_spec: a ClusterSpec object. It can be empty or None in the local
        training case.
      task_type: a string indicating the role of the corresponding task, such as
@@ -128,14 +122,17 @@ class _WorkerContext(object):
        replicated training.
      task_id: an integer indicating id of the corresponding task. It can be
        None if it is local training or in-graph replicated training.
+      session_config: an optional @{tf.ConfigProto} object.
      rpc_layer: optional string specifying the RPC protocol for communication
        with worker masters. If None or empty, hosts in the `cluster_spec` will
        be used directly.
      worker_barrier: optional, the barrier object for worker synchronization.
    """
+    self._strategy = strategy
    self._cluster_spec = cluster_spec
    self._task_type = task_type
    self._task_id = task_id
+    self._session_config = session_config
    self._worker_barrier = worker_barrier
    self._rpc_layer = rpc_layer
    self._master_target = self._get_master_target()
@@ -143,26 +140,31 @@
    self._is_chief_node = self._is_chief()

  def _debug_message(self):
-    return "[cluster_spec: %r, task_type: %r, task_id: %r]" % (
-        self._cluster_spec, self.task_type, self.task_id)
+    if self._cluster_spec:
+      return "[cluster_spec: %r, task_type: %r, task_id: %r]" % (
+          self._cluster_spec, self.task_type, self.task_id)
+    else:
+      return "[local]"

  def __enter__(self):
-    old_context = get_current_worker_context()
+    old_context = distribute_coordinator_context.get_current_worker_context()
    if old_context:
      raise ValueError(
          "You cannot run distribute coordinator in a `worker_fn`.\t" +
          self._debug_message())
-    _worker_context.current = self
+    # pylint: disable=protected-access
+    distribute_coordinator_context._worker_context.current = self

  def __exit__(self, unused_exception_type, unused_exception_value,
               unused_traceback):
-    _worker_context.current = None
+    # pylint: disable=protected-access
+    distribute_coordinator_context._worker_context.current = None

  def _get_master_target(self):
    """Return the master target for a task."""
    # If cluster_spec is None or empty, we use local master.
    if not self._cluster_spec:
-      return "local"
+      return ""

    # If task_type is None, then it is in-graph replicated training. In this
    # case we use the chief or first worker's master target.
@@ -207,6 +209,47 @@
          self._debug_message())
    self._worker_barrier.wait()

+  def session_creator(self,
+                      scaffold=None,
+                      config=None,
+                      checkpoint_dir=None,
+                      checkpoint_filename_with_path=None,
+                      max_wait_secs=7200):
+    """Returns a session creator.
+
+    The returned session creator will be configured with the correct master
+    target and session configs. It will also run either init ops or ready ops
+    by querying the `strategy` object when `create_session` is called on it.
+
+    Args:
+      scaffold: A `Scaffold` used for gathering or building supportive ops. If
+        not specified a default one is created. It's used to finalize the graph.
+      config: `ConfigProto` proto used to configure the session.
+      checkpoint_dir: A string. Optional path to a directory where to restore
+        variables.
+      checkpoint_filename_with_path: Full file name path to the checkpoint file.
+        Only one of `checkpoint_dir` or `checkpoint_filename_with_path` can be
+        specified.
+      max_wait_secs: Maximum time to wait for the session to become available.
+
+    Returns:
+      a descendant of SessionCreator.
+    """
+    # TODO(yuefengz): merge session config.
+    if self._strategy.should_init:
+      return monitored_session.ChiefSessionCreator(
+          scaffold,
+          master=self.master_target,
+          config=config or self._session_config,
+          checkpoint_dir=checkpoint_dir,
+          checkpoint_filename_with_path=checkpoint_filename_with_path)
+    else:
+      return monitored_session.WorkerSessionCreator(
+          scaffold,
+          master=self.master_target,
+          config=config or self._session_config,
+          max_wait_secs=max_wait_secs)
+
  @property
  def has_barrier(self):
    """Whether the barrier is set or not."""
@@ -247,21 +290,38 @@
    """Returns number of workers in the cluster, including chief."""
    return self._num_workers

+  @property
+  def should_checkpoint(self):
+    """Whether to save checkpoint."""
+    return self._strategy.should_checkpoint
+
+  @property
+  def should_save_summary(self):
+    """Whether to save summaries."""
+    return self._strategy.should_save_summary
+

def _run_single_worker(worker_fn,
+                       strategy,
                       cluster_spec,
                       task_type,
                       task_id,
-                       rpc_layer,
+                       session_config,
+                       rpc_layer="",
                       worker_barrier=None):
  """Runs a single worker by calling `worker_fn` under context."""
-  with _WorkerContext(
+  strategy = copy.deepcopy(strategy)
+  strategy.configure(session_config, cluster_spec, task_type, task_id)
+  context = _WorkerContext(
+      strategy,
      cluster_spec,
      task_type,
      task_id,
+      session_config=session_config,
      rpc_layer=rpc_layer,
-      worker_barrier=worker_barrier):
-    worker_fn()
+      worker_barrier=worker_barrier)
+  with context:
+    worker_fn(strategy)


def _run_std_server(cluster_spec=None,
@@ -280,13 +340,15 @@ def _run_std_server(cluster_spec=None,
  return server


-def _run_between_graph_client(worker_fn, cluster_spec, rpc_layer):
+def _run_between_graph_client(worker_fn, strategy, cluster_spec, session_config,
+                              rpc_layer):
  """Runs a standalone client for between-graph replication."""
  eval_thread = None
  if _TaskType.EVALUATOR in cluster_spec.jobs:
    eval_thread = threading.Thread(
        target=_run_single_worker,
-        args=(worker_fn, cluster_spec, _TaskType.EVALUATOR, 0),
+        args=(worker_fn, strategy, cluster_spec, _TaskType.EVALUATOR, 0,
+              session_config),
        kwargs={
            "rpc_layer": rpc_layer,
        })
@@ -298,7 +360,8 @@ def _run_between_graph_client(worker_fn, cluster_spec, rpc_layer):
  for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
    t = threading.Thread(
        target=_run_single_worker,
-        args=(worker_fn, cluster_spec, task_type, task_id),
+        args=(worker_fn, strategy, cluster_spec, task_type, task_id,
+              session_config),
        kwargs={
            "rpc_layer": rpc_layer,
            "worker_barrier": worker_barrier
@@ -315,43 +378,53 @@
    eval_thread.join()


-def _run_in_graph_client(worker_fn, cluster_spec, rpc_layer):
+def _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config,
+                         rpc_layer):
  """Runs a standalone client for in-graph replication."""
  eval_thread = None
  if _TaskType.EVALUATOR in cluster_spec.jobs:
    eval_thread = threading.Thread(
        target=_run_single_worker,
-        args=(worker_fn, cluster_spec, _TaskType.EVALUATOR, 0),
+        args=(worker_fn, strategy, cluster_spec, _TaskType.EVALUATOR, 0,
+              session_config),
        kwargs={
            "rpc_layer": rpc_layer,
        })
    eval_thread.start()

-  _run_single_worker(worker_fn, cluster_spec, None, None, rpc_layer)
+  _run_single_worker(
+      worker_fn,
+      strategy,
+      cluster_spec,
+      None,
+      None,
+      session_config,
+      rpc_layer=rpc_layer)
  if eval_thread:
    eval_thread.join()


-# TODO(yuefengz): propagate cluster_spec in the SPLIT_CLIENT mode.
+# TODO(yuefengz): propagate cluster_spec in the STANDALONE_CLIENT mode.
# TODO(yuefengz): we may need a smart way to figure out whether the current task
# is the special task when we support cluster_spec propagation.
def run_distribute_coordinator(worker_fn,
-                               mode=CoordinatorMode.SPLIT_CLIENT,
+                               strategy,
+                               mode=CoordinatorMode.STANDALONE_CLIENT,
                               cluster_spec=None,
                               task_type=None,
                               task_id=None,
-                               between_graph=False,
+                               session_config=None,
                               rpc_layer="grpc"):
  """Runs the coordinator for distributed TensorFlow.

  This function runs a split coordinator for distributed TensorFlow in its
-  default mode, i.e the SPLIT_CLIENT mode. Given a `cluster_spec` specifying
-  server addresses and their roles in a cluster, this coordinator will figure
-  out how to set them up, give the underlying function the right targets for
-  master sessions via a scope object and coordinate their training. The cluster
-  consisting of standard servers needs to be brought up either with the standard
-  server binary or with a binary running distribute coordinator with `task_type`
-  set to non-client type which will then turn into standard servers.
+  default mode, i.e the STANDALONE_CLIENT mode. Given a `cluster_spec`
+  specifying server addresses and their roles in a cluster, this coordinator
+  will figure out how to set them up, give the underlying function the right
+  targets for master sessions via a scope object and coordinate their training.
+  The cluster consisting of standard servers needs to be brought up either with
+  the standard server binary or with a binary running distribute coordinator
+  with `task_type` set to non-client type which will then turn into standard
+  servers.

  In addition to be the distribute coordinator, this is also the source of
  configurations for each job in the distributed training. As there are multiple
@@ -370,6 +443,14 @@ def run_distribute_coordinator(worker_fn,
  `worker_fn` depending whether it is between-graph training or in-graph
  replicated training.

+  The `strategy` object is expected to be a DistributionStrategy object which
+  has implemented methods needed by distributed coordinator such as
+  `configure(session_config, cluster_spec, task_type, task_id)` which configures
+  the strategy object for a specific task and `should_init` property which
+  instructs the distribute coordinator whether to run init ops for a task. The
+  distribute coordinator will make a copy of the `strategy` object, call its
+  `configure` method and pass it to `worker_fn` as an argument.
+
  The `worker_fn` defines the training logic and is called under a its own
  worker context which can be accessed to via `get_current_worker_context`. A
  worker context provides access to configurations for each task, e.g. the
@@ -413,16 +494,20 @@
    evaluation.

  Args:
-    worker_fn: the function to be called and given the access to a coordinator
-      context object.
+    worker_fn: the function to be called. The function should accept a
+      `strategy` object and will be given access to a context object via a
+      context manager scope.
+    strategy: a DistributionStrategy object which specifying whether it should
+      run between-graph replicated training or not, whether to run init ops,
+      etc. This object will also be configured given `session_config`,
+      `cluster_spc`, `task_type` and `task_id`.
    mode: in which mode this distribute coordinator runs.
    cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles
      in a cluster. If not set or empty, fall back to local training.
    task_type: the current task type, optional if this is a client.
    task_id: the current task id, optional if this is a client.
-    between_graph: a boolean. It is only useful when `cluster_spec` is set and
-      not empty. If true, it will use between-graph replicated training;
-      otherwise it will use in-graph replicated training.
+    session_config: an optional @{tf.ConfigProto} object which will be passed
+      to `strategy`'s `configure` method and used to create a session.
    rpc_layer: optional string, the protocol for RPC, e.g. "grpc".

  Raises:
@@ -448,15 +533,18 @@
  if not cluster_spec:
    # `mode` is ignored in the local case.
-    _run_single_worker(worker_fn, None, None, None, rpc_layer)
-  elif mode == CoordinatorMode.SPLIT_CLIENT:
+    _run_single_worker(worker_fn, strategy, None, None, None, session_config,
+                       rpc_layer)
+  elif mode == CoordinatorMode.STANDALONE_CLIENT:
    # The client must know the cluster but servers in the cluster don't have to
    # know the client.
    if task_type in [_TaskType.CLIENT, None]:
-      if between_graph:
-        _run_between_graph_client(worker_fn, cluster_spec, rpc_layer)
+      if strategy.between_graph:
+        _run_between_graph_client(worker_fn, strategy, cluster_spec,
+                                  session_config, rpc_layer)
      else:
-        _run_in_graph_client(worker_fn, cluster_spec, rpc_layer)
+        _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config,
+                             rpc_layer)
    else:
      # If not a client job, run the standard server.
      server = _run_std_server(
@@ -471,19 +559,21 @@
        cluster_spec=cluster_spec, task_type=task_type, task_id=task_id)

    if task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
-      if between_graph:
+      if strategy.between_graph:
        # All jobs run `worker_fn` if between-graph.
-        _run_single_worker(worker_fn, cluster_spec, task_type, task_id,
-                           rpc_layer)
+        _run_single_worker(worker_fn, strategy, cluster_spec, task_type,
+                           task_id, session_config, rpc_layer)
      else:
        # Only one node runs `worker_fn` if in-graph.
-        context = _WorkerContext(cluster_spec, task_type, task_id, rpc_layer)
+        context = _WorkerContext(strategy, cluster_spec, task_type, task_id)
        if context.is_chief:
-          _run_single_worker(worker_fn, cluster_spec, None, None, rpc_layer)
+          _run_single_worker(worker_fn, strategy, cluster_spec, None, None,
+                             session_config, rpc_layer)
        else:
          server.join()
    elif task_type == _TaskType.EVALUATOR:
-      _run_single_worker(worker_fn, cluster_spec, task_type, task_id, rpc_layer)
+      _run_single_worker(worker_fn, strategy, cluster_spec, task_type, task_id,
+                         session_config, rpc_layer)
    else:
      if task_type != _TaskType.PS:
        raise ValueError("Unexpected task_type: %r" % task_type)
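As the updated docstring above describes, the coordinator only relies on a small duck-typed surface of the strategy object. A sketch of that implied contract (illustrative only; it mirrors the `MockStrategy` defined in the test file further below):

    class MinimalStrategy(object):
      """Anything with this shape can be passed to run_distribute_coordinator."""

      between_graph = True        # chooses between-graph vs. in-graph replication
      should_init = True          # read by _WorkerContext.session_creator()
      should_checkpoint = True    # surfaced as worker_context.should_checkpoint
      should_save_summary = True  # surfaced as worker_context.should_save_summary

      def configure(self, session_config=None, cluster_spec=None,
                    task_type=None, task_id=None):
        """Called by the coordinator on a deep copy before worker_fn(strategy)."""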

tensorflow/python/distribute/distribute_coordinator_context.py (new file)

@@ -0,0 +1,31 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The context retrieval method for distribute coordinator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import threading
+
+_worker_context = threading.local()
+
+
+def get_current_worker_context():
+  """Returns the current task context."""
+  try:
+    return _worker_context.current
+  except AttributeError:
+    return None
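The new module holds only the thread-local lookup, so `monitored_session.py` can depend on it without pulling in the whole coordinator. A brief illustration of the lookup pattern inside a hypothetical `worker_fn`; the context properties shown are the ones the tests exercise:

    from tensorflow.python.distribute import distribute_coordinator_context


    def worker_fn(strategy):
      context = distribute_coordinator_context.get_current_worker_context()
      if context is None:
        # Not running under a distribute coordinator (e.g. plain local training);
        # callers such as MonitoredTrainingSession then fall back to the old path.
        return
      print(context.master_target, context.is_chief, context.num_workers)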

tensorflow/python/distribute/distribute_coordinator_test.py

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for distribute coordinator."""
+"""Tests for Distribute Coordinator."""

from __future__ import absolute_import
from __future__ import division
@@ -37,6 +37,7 @@ except ImportError as _error:
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator
+from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
@@ -44,17 +45,17 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
+from tensorflow.python.training import monitored_session

CHIEF = distribute_coordinator._TaskType.CHIEF
WORKER = distribute_coordinator._TaskType.WORKER
PS = distribute_coordinator._TaskType.PS
EVALUATOR = distribute_coordinator._TaskType.EVALUATOR

-SPLIT_CLIENT = distribute_coordinator.CoordinatorMode.SPLIT_CLIENT
+STANDALONE_CLIENT = distribute_coordinator.CoordinatorMode.STANDALONE_CLIENT
INDEPENDENT_WORKER = distribute_coordinator.CoordinatorMode.INDEPENDENT_WORKER

-RUN_STD_SERVER_METHOD = "tensorflow.python.distribute.distribute_coordinator._run_std_server"
-
NUM_WORKERS = 3
NUM_PS = 2
@@ -74,6 +75,57 @@ def _strip_protocol(target):
  return target


+class MockStrategy(object):
+
+  def __init__(self,
+               between_graph=False,
+               should_init=None,
+               should_checkpoint=None,
+               should_save_summary=None):
+    self._between_graph = between_graph
+    self._should_init = should_init
+    self._should_checkpoint = should_checkpoint
+    self._should_save_summary = should_save_summary
+
+  @property
+  def between_graph(self):
+    return self._between_graph
+
+  def configure(self,
+                session_options=None,
+                cluster_spec=None,
+                task_type=None,
+                task_id=None):
+    del session_options, cluster_spec, task_type
+    if self._should_init is None:
+      if task_id == 0:
+        self._should_init = True
+      else:
+        self._should_init = False
+    if self._should_checkpoint is None:
+      if task_id == 0:
+        self._should_checkpoint = True
+      else:
+        self._should_checkpoint = False
+    if self._should_save_summary is None:
+      if task_id == 0:
+        self._should_save_summary = True
+      else:
+        self._should_save_summary = False
+
+  @property
+  def should_init(self):
+    return self._should_init
+
+  @property
+  def should_checkpoint(self):
+    return self._should_checkpoint
+
+  @property
+  def should_save_summary(self):
+    return self._should_save_summary
+
+
class MockServer(object):

  def __init__(self):
@@ -108,6 +160,7 @@ class DistributeCoordinatorTestBase(test.TestCase):
    self._result_correct = 0
    self._lock = threading.Lock()
    self._worker_context = {}
+    self._strategy_property = {}
    self._std_servers = {}
    self._barrier = distribute_coordinator._Barrier(NUM_WORKERS)
@@ -142,8 +195,8 @@
      cluster_spec[EVALUATOR] = ["localhost:%s" % portpicker.pick_unused_port()]
    return cluster_spec

-  def _in_graph_worker_fn(self):
-    context = distribute_coordinator.get_current_worker_context()
+  def _in_graph_worker_fn(self, strategy):
+    context = distribute_coordinator_context.get_current_worker_context()
    self.assertTrue(context is not None)
    with self._test_session(target=context.master_target) as sess:
      xs = []
@@ -164,22 +217,23 @@
      if result_value == expected:
        self._result_correct += 1

-  def _run_coordinator_in_thread(self, worker_fn, **kwargs):
+  def _run_coordinator_in_thread(self, worker_fn, strategy, **kwargs):
    t = threading.Thread(
        target=distribute_coordinator.run_distribute_coordinator,
-        args=(worker_fn,),
+        args=(worker_fn, strategy),
        kwargs=kwargs)
    t.start()
    return t

-  def _run_multiple_coordinator_in_threads(self, worker_fn, cluster_spec,
-                                           **kwargs):
+  def _run_multiple_coordinator_in_threads(self, worker_fn, strategy,
+                                           cluster_spec, **kwargs):
    threads = {}
    for task_type in cluster_spec.keys():
      threads[task_type] = []
      for task_id in range(len(cluster_spec[task_type])):
        t = self._run_coordinator_in_thread(
            worker_fn,
+            strategy,
            cluster_spec=cluster_spec,
            task_type=task_type,
            task_id=task_id,
@@ -187,8 +241,8 @@
        threads[task_type].append(t)
    return threads

-  def _between_graph_worker_fn(self):
-    context = distribute_coordinator.get_current_worker_context()
+  def _between_graph_worker_fn(self, strategy):
+    context = distribute_coordinator_context.get_current_worker_context()
    self.assertTrue(context is not None)
    with self._test_session(target=context.master_target) as sess:
      with ops.device("/job:ps/task:0"):
@@ -234,14 +288,50 @@
      with self._lock:
        self._result_correct += 1

-  def _dump_worker_context(self):
+  def _between_graph_with_monitored_session(self, strategy):
+    context = distribute_coordinator_context.get_current_worker_context()
+    self.assertTrue(context is not None)
+    with ops.device("/job:ps/task:0"):
+      # TODO(yuefengz): investigate why not using resource variable will make
+      # the test flaky.
+      x = variable_scope.get_variable("x", initializer=10.0, use_resource=True)
+    with ops.device("/job:ps/task:1"):
+      y = variable_scope.get_variable("y", initializer=20.0, use_resource=True)
+    x_add = x.assign_add(2.0)
+    y_sub = y.assign_sub(2.0)
+    train_op = control_flow_ops.group([x_add, y_sub])
+
+    # The monitored session will run init or ready ops.
+    with monitored_session.MonitoredSession() as sess:
+      sess.run(train_op)
+
+      # Synchronize workers after one step to make sure they all have finished
+      # training.
+      if context.has_barrier:
+        context.wait_for_other_workers()
+      else:
+        self._barrier.wait()
+
+      x_val, y_val = sess.run([x, y])
+      self.assertEqual(x_val, 16.0)
+      self.assertEqual(y_val, 14.0)
+      if x_val == 16.0 and y_val == 14.0:
+        with self._lock:
+          self._result_correct += 1
+
+  def _dump_worker_context(self, strategy):
    """Dumps the propoerties of each worker context.

    It dumps the context properties to a dict mapping from task_type to a list
    of tuples of master_target, num_workers, is_chief and distribute_mode, where
    the list is indexed by the task_id.

+    Args:
+      strategy: a `DistributionStrategy` object.
    """
-    context = distribute_coordinator.get_current_worker_context()
+    context = distribute_coordinator_context.get_current_worker_context()
    self.assertTrue(context is not None)
    task_type = str(context.task_type)
    task_id = context.task_id or 0
@@ -255,6 +345,25 @@
          context.is_chief,
          context.distributed_mode)

+  def _dump_strategy_property(self, strategy):
+    context = distribute_coordinator_context.get_current_worker_context()
+    self.assertTrue(context is not None)
+
+    self.assertEqual(context._strategy.should_init, strategy.should_init)
+    self.assertEqual(context.should_checkpoint, strategy.should_checkpoint)
+    self.assertEqual(context.should_save_summary, strategy.should_save_summary)
+
+    task_type = str(context.task_type)
+    task_id = context.task_id or 0
+    with self._lock:
+      if task_type not in self._strategy_property:
+        self._strategy_property[task_type] = []
+      while len(self._strategy_property[task_type]) <= task_id:
+        self._strategy_property[task_type].append(None)
+      self._strategy_property[task_type][task_id] = (
+          context._strategy.should_init, context.should_checkpoint,
+          context.should_save_summary)
+
  def _run_mock_std_server(self,
                           session_config=None,
                           cluster_spec=None,
@@ -274,22 +383,32 @@
    return server


-class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
+class DistributeCoordinatorTestStandaloneMode(DistributeCoordinatorTestBase):

-  def testInGraphSplitMode(self):
-    """Test it runs in-graph replication in split client mode."""
+  def testInGraphStandaloneMode(self):
+    """Test it runs in-graph replication in standalone client mode."""
    distribute_coordinator.run_distribute_coordinator(
        self._in_graph_worker_fn,
-        cluster_spec=self._cluster_spec,
-        between_graph=False)
+        MockStrategy(between_graph=False),
+        cluster_spec=self._cluster_spec)
    self.assertEqual(self._result_correct, 1)

  def testBetweenGraph(self):
-    """Test it runs between-graph replication in split client mode."""
+    """Test it runs between-graph replication in standalone client mode."""
    distribute_coordinator.run_distribute_coordinator(
        self._between_graph_worker_fn,
-        cluster_spec=self._cluster_spec,
-        between_graph=True)
+        MockStrategy(between_graph=True),
+        cluster_spec=self._cluster_spec)
+    # Each finished worker will increment self._result_correct.
+    self.assertEqual(self._result_correct, NUM_WORKERS)
+
+  def testBetweenGraphWithMonitoredSession(self):
+    """Test monitored session in standalone client mode."""
+    distribute_coordinator.run_distribute_coordinator(
+        self._between_graph_with_monitored_session,
+        MockStrategy(between_graph=True),
+        cluster_spec=self._cluster_spec)

    # Each finished worker will increment self._result_correct.
    self.assertEqual(self._result_correct, NUM_WORKERS)
@@ -298,8 +417,8 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
    # Dumps the task contexts to the self._worker_context dict.
    distribute_coordinator.run_distribute_coordinator(
        self._dump_worker_context,
-        cluster_spec=self._cluster_spec,
-        between_graph=True)
+        MockStrategy(between_graph=True),
+        cluster_spec=self._cluster_spec)

    # There is only one type of task and there three such tasks.
    self.assertEqual(len(self._worker_context), 1)
@@ -318,12 +437,30 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
        self._worker_context[WORKER][2],
        (_bytes_to_str(self._workers[2].target), NUM_WORKERS, False, True))

+  def testBetweenGraphStrategyProperties(self):
+    # Dumps properties of the strategy objects.
+    distribute_coordinator.run_distribute_coordinator(
+        self._dump_strategy_property,
+        MockStrategy(between_graph=True, should_init=True),
+        cluster_spec=self._cluster_spec)
+
+    # There is only one type of task and there three such tasks.
+    self.assertEqual(len(self._strategy_property), 1)
+    self.assertTrue(WORKER in self._strategy_property)
+    self.assertEqual(len(self._strategy_property[WORKER]), NUM_WORKERS)
+
+    # Check whether each task has the right properties of should_init,
+    # should_checkpoint and should_save_summary.
+    self.assertEqual(self._strategy_property[WORKER][0], (True, True, True))
+    self.assertEqual(self._strategy_property[WORKER][1], (True, False, False))
+    self.assertEqual(self._strategy_property[WORKER][2], (True, False, False))
+
  def testInGraphContext(self):
    # Dumps the task contexts to the self._worker_context dict.
    distribute_coordinator.run_distribute_coordinator(
        self._dump_worker_context,
-        cluster_spec=self._cluster_spec,
-        between_graph=False)
+        MockStrategy(between_graph=False),
+        cluster_spec=self._cluster_spec)

    # There is only a "None" task in the dumped task context.
    self.assertEqual(len(self._worker_context), 1)
@@ -339,7 +476,9 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
  def testLocalContext(self):
    # Dumps the task contexts to the self._worker_context dict.
    distribute_coordinator.run_distribute_coordinator(
-        self._dump_worker_context, cluster_spec=None, between_graph=True)
+        self._dump_worker_context,
+        MockStrategy(between_graph=False),
+        cluster_spec=None)

    # There is only a "None" task.
    self.assertEqual(len(self._worker_context), 1)
@@ -348,7 +487,7 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
    # Check whether each task has the right master_target, num_workers, is_chief
    # and distributed_mode.
-    self.assertEqual(self._worker_context["None"][0], ("local", 0, True, False))
+    self.assertEqual(self._worker_context["None"][0], ("", 0, True, False))

  def testBetweenGraphContextWithChief(self):
    # Adds a chief node, so there are NUM_WORKERS + 1 workers in total.
@@ -358,8 +497,8 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
    # Dumps the task contexts to the self._worker_context dict.
    distribute_coordinator.run_distribute_coordinator(
        self._dump_worker_context,
+        MockStrategy(between_graph=True),
        cluster_spec=cluster_spec,
-        between_graph=True,
        rpc_layer="grpc")

    # There are one CHIEF and three workers.
@@ -391,8 +530,8 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
    # Dumps the task contexts to the self._worker_context dict.
    distribute_coordinator.run_distribute_coordinator(
        self._dump_worker_context,
+        MockStrategy(between_graph=False),
        cluster_spec=cluster_spec,
-        between_graph=False,
        rpc_layer=None)

    # There are one "None" task and one EVALUATOR task.
@@ -417,8 +556,8 @@ class DistributeCoordinatorTestInpendentWorkerMode(
    cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
    threads = self._run_multiple_coordinator_in_threads(
        self._in_graph_worker_fn,
+        MockStrategy(between_graph=False),
        cluster_spec,
-        between_graph=False,
        mode=INDEPENDENT_WORKER)
    threads[WORKER][0].join()
    self.assertEqual(self._result_correct, 1)
@@ -428,8 +567,22 @@
        num_workers=NUM_WORKERS, num_ps=NUM_PS)
    threads = self._run_multiple_coordinator_in_threads(
        self._between_graph_worker_fn,
+        MockStrategy(between_graph=True),
+        cluster_spec,
+        mode=INDEPENDENT_WORKER)
+    for task_id in range(NUM_WORKERS):
+      threads[WORKER][task_id].join()
+
+    # Each finished worker will increment self._result_correct.
+    self.assertEqual(self._result_correct, NUM_WORKERS)
+
+  def testBetweenGraphWithMonitoredSession(self):
+    cluster_spec = self._create_cluster_spec(
+        num_workers=NUM_WORKERS, num_ps=NUM_PS)
+    threads = self._run_multiple_coordinator_in_threads(
+        self._between_graph_with_monitored_session,
+        MockStrategy(between_graph=True),
        cluster_spec,
-        between_graph=True,
        mode=INDEPENDENT_WORKER)
    for task_id in range(NUM_WORKERS):
      threads[WORKER][task_id].join()
@@ -444,9 +597,9 @@ class DistributeCoordinatorTestInpendentWorkerMode(
                                self._run_mock_std_server):
      threads = self._run_multiple_coordinator_in_threads(
          self._dump_worker_context,
+          MockStrategy(between_graph=True),
          cluster_spec,
          mode=INDEPENDENT_WORKER,
-          between_graph=True,
          rpc_layer=None)
      for task_id in range(NUM_WORKERS):
        threads[WORKER][task_id].join()
@@ -476,6 +629,31 @@ class DistributeCoordinatorTestInpendentWorkerMode(
    self.assertFalse(self._std_servers[WORKER][1].joined)
    self.assertFalse(self._std_servers[WORKER][2].joined)

+  def testBetweenGraphStrategyProperties(self):
+    cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
+    # Dumps properties of the strategy objects.
+    with test.mock.patch.object(distribute_coordinator, "_run_std_server",
+                                self._run_mock_std_server):
+      threads = self._run_multiple_coordinator_in_threads(
+          self._dump_strategy_property,
+          MockStrategy(between_graph=True, should_init=True),
+          cluster_spec,
+          mode=INDEPENDENT_WORKER,
+          rpc_layer=None)
+      for task_id in range(NUM_WORKERS):
+        threads[WORKER][task_id].join()
+
+    # There is only one type of task and there three such tasks.
+    self.assertEqual(len(self._strategy_property), 1)
+    self.assertTrue(WORKER in self._strategy_property)
+    self.assertEqual(len(self._strategy_property[WORKER]), NUM_WORKERS)
+
+    # Check whether each task has the right properties of should_init,
+    # should_checkpoint and should_save_summary.
+    self.assertEqual(self._strategy_property[WORKER][0], (True, True, True))
+    self.assertEqual(self._strategy_property[WORKER][1], (True, False, False))
+    self.assertEqual(self._strategy_property[WORKER][2], (True, False, False))
+
  def testInGraphContext(self):
    cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
    # Dumps the task contexts and std server arguments.
@@ -483,9 +661,9 @@ class DistributeCoordinatorTestInpendentWorkerMode(
                                self._run_mock_std_server):
      threads = self._run_multiple_coordinator_in_threads(
          self._dump_worker_context,
+          MockStrategy(between_graph=False),
          cluster_spec,
          mode=INDEPENDENT_WORKER,
-          between_graph=False,
          rpc_layer=None)
      for task_id in range(NUM_WORKERS):
        threads[WORKER][task_id].join()
@@ -519,9 +697,9 @@ class DistributeCoordinatorTestInpendentWorkerMode(
                                self._run_mock_std_server):
      threads = self._run_multiple_coordinator_in_threads(
          self._dump_worker_context,
+          MockStrategy(between_graph=False),
          cluster_spec,
          mode=INDEPENDENT_WORKER,
-          between_graph=False,
          rpc_layer=None)
      for task_id in range(NUM_WORKERS):
        threads[WORKER][task_id].join()

tensorflow/python/training/monitored_session.py

@@ -25,6 +25,7 @@ import sys
import six

from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -284,6 +285,63 @@ class Scaffold(object):
            resources.initialize_resources(resources.local_resources()))


+def _create_monitored_session_with_worker_context(worker_context,  # pylint: disable=missing-docstring
+                                                  scaffold,
+                                                  checkpoint_dir=None,
+                                                  hooks=None,
+                                                  chief_only_hooks=None,
+                                                  save_checkpoint_secs=None,
+                                                  save_summaries_steps=None,
+                                                  save_summaries_secs=None,
+                                                  config=None,
+                                                  stop_grace_period_secs=120,
+                                                  log_step_count_steps=100,
+                                                  max_wait_secs=7200,
+                                                  save_checkpoint_steps=None,
+                                                  summary_dir=None):
+  all_hooks = []
+  if hooks:
+    all_hooks.extend(hooks)
+  if chief_only_hooks and worker_context.is_chief:
+    all_hooks.extend(chief_only_hooks)
+
+  summary_dir = summary_dir or checkpoint_dir
+  if summary_dir and worker_context.should_save_summary:
+    if log_step_count_steps and log_step_count_steps > 0:
+      all_hooks.append(
+          basic_session_run_hooks.StepCounterHook(
+              output_dir=summary_dir, every_n_steps=log_step_count_steps))
+
+    if (save_summaries_steps and save_summaries_steps > 0) or (
+        save_summaries_secs and save_summaries_secs > 0):
+      all_hooks.append(
+          basic_session_run_hooks.SummarySaverHook(
+              scaffold=scaffold,
+              save_steps=save_summaries_steps,
+              save_secs=save_summaries_secs,
+              output_dir=summary_dir))
+
+  if checkpoint_dir and worker_context.should_checkpoint:
+    if (save_checkpoint_secs and save_checkpoint_secs > 0) or (
+        save_checkpoint_steps and save_checkpoint_steps > 0):
+      all_hooks.append(
+          basic_session_run_hooks.CheckpointSaverHook(
+              checkpoint_dir,
+              save_steps=save_checkpoint_steps,
+              save_secs=save_checkpoint_secs,
+              scaffold=scaffold))
+
+  session_creator = worker_context.session_creator(
+      scaffold,
+      config=config,
+      checkpoint_dir=checkpoint_dir,
+      max_wait_secs=max_wait_secs)
+  return MonitoredSession(
+      session_creator=session_creator,
+      hooks=all_hooks,
+      stop_grace_period_secs=stop_grace_period_secs)
+
+
@tf_export('train.MonitoredTrainingSession')
def MonitoredTrainingSession(master='',  # pylint: disable=invalid-name
                             is_chief=True,
@@ -373,14 +431,35 @@ def MonitoredTrainingSession(master='',  # pylint: disable=invalid-name
      save_checkpoint_steps = None

  scaffold = scaffold or Scaffold()
+  worker_context = distribute_coordinator_context.get_current_worker_context()
+
+  if worker_context:
+    return _create_monitored_session_with_worker_context(
+        worker_context,
+        scaffold,
+        checkpoint_dir=checkpoint_dir,
+        hooks=hooks,
+        chief_only_hooks=chief_only_hooks,
+        save_checkpoint_secs=save_checkpoint_secs,
+        save_summaries_steps=save_summaries_steps,
+        save_summaries_secs=save_summaries_secs,
+        config=config,
+        stop_grace_period_secs=stop_grace_period_secs,
+        log_step_count_steps=log_step_count_steps,
+        max_wait_secs=max_wait_secs,
+        save_checkpoint_steps=save_checkpoint_steps,
+        summary_dir=summary_dir)
+
  if not is_chief:
    session_creator = WorkerSessionCreator(
        scaffold=scaffold,
        master=master,
        config=config,
        max_wait_secs=max_wait_secs)
-    return MonitoredSession(session_creator=session_creator, hooks=hooks or [],
-                            stop_grace_period_secs=stop_grace_period_secs)
+    return MonitoredSession(
+        session_creator=session_creator,
+        hooks=hooks or [],
+        stop_grace_period_secs=stop_grace_period_secs)

  all_hooks = []
  if chief_only_hooks:
@@ -400,25 +479,29 @@ def MonitoredTrainingSession(master='',  # pylint: disable=invalid-name

  if (save_summaries_steps and save_summaries_steps > 0) or (
      save_summaries_secs and save_summaries_secs > 0):
-    all_hooks.append(basic_session_run_hooks.SummarySaverHook(
-        scaffold=scaffold,
-        save_steps=save_summaries_steps,
-        save_secs=save_summaries_secs,
-        output_dir=summary_dir))
+    all_hooks.append(
+        basic_session_run_hooks.SummarySaverHook(
+            scaffold=scaffold,
+            save_steps=save_summaries_steps,
+            save_secs=save_summaries_secs,
+            output_dir=summary_dir))

  if checkpoint_dir:
    if (save_checkpoint_secs and save_checkpoint_secs > 0) or (
        save_checkpoint_steps and save_checkpoint_steps > 0):
-      all_hooks.append(basic_session_run_hooks.CheckpointSaverHook(
-          checkpoint_dir,
-          save_steps=save_checkpoint_steps,
-          save_secs=save_checkpoint_secs,
-          scaffold=scaffold))
+      all_hooks.append(
+          basic_session_run_hooks.CheckpointSaverHook(
+              checkpoint_dir,
+              save_steps=save_checkpoint_steps,
+              save_secs=save_checkpoint_secs,
+              scaffold=scaffold))

  if hooks:
    all_hooks.extend(hooks)
-  return MonitoredSession(session_creator=session_creator, hooks=all_hooks,
-                          stop_grace_period_secs=stop_grace_period_secs)
+  return MonitoredSession(
+      session_creator=session_creator,
+      hooks=all_hooks,
+      stop_grace_period_secs=stop_grace_period_secs)


@tf_export('train.SessionCreator')
@@ -546,6 +629,11 @@ class _MonitoredSession(object):
    self._hooks = hooks or []
    for h in self._hooks:
      h.begin()
+
+    worker_context = distribute_coordinator_context.get_current_worker_context()
+    if not session_creator and worker_context:
+      session_creator = worker_context.session_creator()
+
    # Create the session.
    self._coordinated_creator = self._CoordinatedSessionCreator(
        session_creator=session_creator or ChiefSessionCreator(),
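For callers, the calling convention of `MonitoredTrainingSession` is unchanged; the difference is that under a worker context the `master` and `is_chief` arguments no longer decide the session creator or the saving hooks. A rough usage sketch (TF 1.x style, using the same internal modules as the diff; the checkpoint directory is an illustrative assumption):

    from tensorflow.python.framework import ops
    from tensorflow.python.ops import state_ops
    from tensorflow.python.training import monitored_session
    from tensorflow.python.training import training_util

    with ops.Graph().as_default():
      step = training_util.get_or_create_global_step()
      train_op = state_ops.assign_add(step, 1)

      # Outside a coordinator this behaves exactly as before; inside one, the
      # worker context supplies the Chief/WorkerSessionCreator and enables the
      # checkpoint and summary hooks only if the strategy allows it.
      with monitored_session.MonitoredTrainingSession(
          checkpoint_dir="/tmp/mts_demo",  # illustrative path
          save_checkpoint_steps=100,
          save_summaries_steps=100) as sess:
        for _ in range(100):
          sess.run(train_op)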

tensorflow/python/training/monitored_session_test.py

@@ -32,6 +32,7 @@ from tensorflow.contrib.testing.python.framework import util_test
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import debug_pb2
from tensorflow.python.client import session as session_lib
+from tensorflow.python.distribute import distribute_coordinator
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
@@ -381,6 +382,119 @@ class MonitoredTrainingSessionTest(test.TestCase):
      self.assertEqual(0, session.run(gstep))


+class MockStrategy(object):
+
+  def __init__(self,
+               between_graph=False,
+               should_init=True,
+               should_checkpoint=None,
+               should_save_summary=None):
+    self._between_graph = between_graph
+    self._should_init = should_init
+    self._should_checkpoint = should_checkpoint
+    self._should_save_summary = should_save_summary
+
+  @property
+  def between_graph(self):
+    return self._between_graph
+
+  @property
+  def should_init(self):
+    return self._should_init
+
+  @property
+  def should_checkpoint(self):
+    return self._should_checkpoint
+
+  @property
+  def should_save_summary(self):
+    return self._should_save_summary
+
+
+class MonitoredTrainingSessionWithDistributeCoordinatorTest(test.TestCase):
+  """Test distribute coordinator controls summary saving and checkpointing."""
+
+  def test_summary_hook_enabled(self):
+    context = distribute_coordinator._WorkerContext(
+        MockStrategy(should_save_summary=True), None, None, None)
+
+    logdir = _test_dir(self.get_temp_dir(), 'test_summaries_enabled')
+    with ops.Graph().as_default():
+      gstep = variables_lib.get_or_create_global_step()
+      new_gstep = state_ops.assign_add(gstep, 1)
+      summary.scalar('my_summary_tag', new_gstep * 2)
+      with context, monitored_session.MonitoredTrainingSession(
+          checkpoint_dir=logdir,
+          save_summaries_steps=100,
+          log_step_count_steps=10) as session:
+        for _ in range(101):
+          session.run(new_gstep)
+
+    summaries = util_test.latest_summaries(logdir)
+    tags = [s.summary.value[0].tag for s in summaries]
+    self.assertIn('my_summary_tag', tags)
+    self.assertIn('global_step/sec', tags)
+
+  def test_summary_hook_disabled(self):
+    context = distribute_coordinator._WorkerContext(
+        MockStrategy(should_save_summary=False), None, None, None)
+
+    logdir = _test_dir(self.get_temp_dir(), 'test_summaries_disabled')
+    with ops.Graph().as_default():
+      gstep = variables_lib.get_or_create_global_step()
+      new_gstep = state_ops.assign_add(gstep, 1)
+      summary.scalar('my_summary_tag', new_gstep * 2)
+      with context, monitored_session.MonitoredTrainingSession(
+          checkpoint_dir=logdir,
+          save_summaries_steps=100,
+          log_step_count_steps=10) as session:
+        for _ in range(101):
+          session.run(new_gstep)
+
+    # No summary is saved.
+    summaries = util_test.latest_summaries(logdir)
+    self.assertEqual(len(summaries), 0)
+
+  def test_checkpoint_hook_enabled(self):
+    context = distribute_coordinator._WorkerContext(
+        MockStrategy(should_checkpoint=True), None, None, None)
+
+    logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_enabled')
+    with ops.Graph().as_default():
+      gstep = variables_lib.get_or_create_global_step()
+      new_gstep = state_ops.assign_add(gstep, 1)
+      with context, monitored_session.MonitoredTrainingSession(
+          checkpoint_dir=logdir,
+          save_checkpoint_steps=100,
+          log_step_count_steps=10) as session:
+        for _ in range(100):
+          session.run(new_gstep)
+
+      # A restart will find the checkpoint and recover automatically.
+      with monitored_session.MonitoredTrainingSession(
+          is_chief=True, checkpoint_dir=logdir) as session:
+        self.assertEqual(100, session.run(gstep))
+
+  def test_checkpoint_hook_disabled(self):
+    context = distribute_coordinator._WorkerContext(
+        MockStrategy(should_checkpoint=False), None, None, None)
+
+    logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_disabled')
+    with ops.Graph().as_default():
+      gstep = variables_lib.get_or_create_global_step()
+      new_gstep = state_ops.assign_add(gstep, 1)
+      with context, monitored_session.MonitoredTrainingSession(
+          checkpoint_dir=logdir,
+          save_checkpoint_steps=100,
+          log_step_count_steps=10) as session:
+        for _ in range(100):
+          session.run(new_gstep)
+
+    # No checkpoint is saved.
+    checkpoint = checkpoint_management.latest_checkpoint(logdir)
+    self.assertIsNone(checkpoint)
+
+
class StopAtNSession(monitored_session._WrappedSession):
  """A wrapped session that stops at the N-th call to _check_stop."""
@@ -1365,8 +1479,8 @@ class MonitoredSessionTest(test.TestCase):
    with monitored_session.MonitoredSession(
        session_creator=monitored_session.ChiefSessionCreator(
            scaffold,
-            checkpoint_filename_with_path=
-            checkpoint_management.latest_checkpoint(logdir))) as session:
+            checkpoint_filename_with_path=checkpoint_management.
+            latest_checkpoint(logdir))) as session:
      self.assertEqual(2, session.run(gstep))

  def test_retry_initialization_on_aborted_error(self):