diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 7c2e8b003b9..91c7fd16c51 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -3265,6 +3265,7 @@ py_library( "@six_archive//:six", "//tensorflow/core:protos_all_py", "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/distribute:distribute_coordinator_context", "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", # `layers` dependency only exists due to the use of a small utility. @@ -4658,7 +4659,10 @@ py_test( size = "medium", srcs = ["training/monitored_session_test.py"], srcs_version = "PY2AND3", - tags = ["notsan"], # b/67945581 + tags = [ + "no_pip", + "notsan", # b/67945581 + ], deps = [ ":array_ops", ":checkpoint_management", @@ -4676,6 +4680,7 @@ py_test( "//tensorflow/contrib/framework:framework_py", "//tensorflow/contrib/testing:testing_py", "//tensorflow/core:protos_all_py", + "//tensorflow/python/distribute:distribute_coordinator", ], ) diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 68d8b8d13b1..16fbe3f4b55 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -41,3 +41,12 @@ py_test( "//tensorflow/python:variables", ], ) + +py_library( + name = "distribute_coordinator_context", + srcs = [ + "distribute_coordinator_context.py", + ], + srcs_version = "PY2AND3", + deps = [], +) diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py index fc9ca4ac4a3..eb081b65fc7 100644 --- a/tensorflow/python/distribute/distribute_coordinator.py +++ b/tensorflow/python/distribute/distribute_coordinator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""A unified and split coordinator for distributed TensorFlow.""" +"""A component for running distributed TensorFlow.""" from __future__ import absolute_import from __future__ import division @@ -24,6 +24,8 @@ import os import threading from tensorflow.core.protobuf import cluster_pb2 +from tensorflow.python.distribute import distribute_coordinator_context +from tensorflow.python.training import monitored_session from tensorflow.python.training import server_lib @@ -43,23 +45,12 @@ class CoordinatorMode(object): # client and connects to remote servers for training. Each remote server can # use the distribute coordinator binary with task_type set correctly which # will then turn into standard servers. - SPLIT_CLIENT = 0 + STANDALONE_CLIENT = "standalone_client" # The distribute coordinator runs on each worker. It will run a standard # server on each worker and optionally run the `worker_fn` that is configured # to talk to its standard server. - INDEPENDENT_WORKER = 1 - - -_worker_context = threading.local() - - -def get_current_worker_context(): - """Returns the current task context.""" - try: - return _worker_context.current - except AttributeError: - return None + INDEPENDENT_WORKER = "independent_worker" class _Barrier(object): @@ -113,14 +104,17 @@ class _WorkerContext(object): """ def __init__(self, + strategy, cluster_spec, task_type, task_id, + session_config=None, rpc_layer="grpc", worker_barrier=None): """Initialize the worker context object. Args: + strategy: a `DistributionStrategy` object. cluster_spec: a ClusterSpec object. It can be empty or None in the local training case. 
task_type: a string indicating the role of the corresponding task, such as @@ -128,14 +122,17 @@ class _WorkerContext(object): replicated training. task_id: an integer indicating id of the corresponding task. It can be None if it is local training or in-graph replicated training. + session_config: an optional @{tf.ConfigProto} object. rpc_layer: optional string specifying the RPC protocol for communication with worker masters. If None or empty, hosts in the `cluster_spec` will be used directly. worker_barrier: optional, the barrier object for worker synchronization. """ + self._strategy = strategy self._cluster_spec = cluster_spec self._task_type = task_type self._task_id = task_id + self._session_config = session_config self._worker_barrier = worker_barrier self._rpc_layer = rpc_layer self._master_target = self._get_master_target() @@ -143,26 +140,31 @@ class _WorkerContext(object): self._is_chief_node = self._is_chief() def _debug_message(self): - return "[cluster_spec: %r, task_type: %r, task_id: %r]" % ( - self._cluster_spec, self.task_type, self.task_id) + if self._cluster_spec: + return "[cluster_spec: %r, task_type: %r, task_id: %r]" % ( + self._cluster_spec, self.task_type, self.task_id) + else: + return "[local]" def __enter__(self): - old_context = get_current_worker_context() + old_context = distribute_coordinator_context.get_current_worker_context() if old_context: raise ValueError( "You cannot run distribute coordinator in a `worker_fn`.\t" + self._debug_message()) - _worker_context.current = self + # pylint: disable=protected-access + distribute_coordinator_context._worker_context.current = self def __exit__(self, unused_exception_type, unused_exception_value, unused_traceback): - _worker_context.current = None + # pylint: disable=protected-access + distribute_coordinator_context._worker_context.current = None def _get_master_target(self): """Return the master target for a task.""" # If cluster_spec is None or empty, we use local master. if not self._cluster_spec: - return "local" + return "" # If task_type is None, then it is in-graph replicated training. In this # case we use the chief or first worker's master target. @@ -207,6 +209,47 @@ class _WorkerContext(object): self._debug_message()) self._worker_barrier.wait() + def session_creator(self, + scaffold=None, + config=None, + checkpoint_dir=None, + checkpoint_filename_with_path=None, + max_wait_secs=7200): + """Returns a session creator. + + The returned session creator will be configured with the correct master + target and session configs. It will also run either init ops or ready ops + by querying the `strategy` object when `create_session` is called on it. + + Args: + scaffold: A `Scaffold` used for gathering or building supportive ops. If + not specified a default one is created. It's used to finalize the graph. + config: `ConfigProto` proto used to configure the session. + checkpoint_dir: A string. Optional path to a directory where to restore + variables. + checkpoint_filename_with_path: Full file name path to the checkpoint file. + Only one of `checkpoint_dir` or `checkpoint_filename_with_path` can be + specified. + max_wait_secs: Maximum time to wait for the session to become available. + + Returns: + a descendant of SessionCreator. + """ + # TODO(yuefengz): merge session config. 
+ if self._strategy.should_init: + return monitored_session.ChiefSessionCreator( + scaffold, + master=self.master_target, + config=config or self._session_config, + checkpoint_dir=checkpoint_dir, + checkpoint_filename_with_path=checkpoint_filename_with_path) + else: + return monitored_session.WorkerSessionCreator( + scaffold, + master=self.master_target, + config=config or self._session_config, + max_wait_secs=max_wait_secs) + @property def has_barrier(self): """Whether the barrier is set or not.""" @@ -247,21 +290,38 @@ class _WorkerContext(object): """Returns number of workers in the cluster, including chief.""" return self._num_workers + @property + def should_checkpoint(self): + """Whether to save checkpoint.""" + return self._strategy.should_checkpoint + + @property + def should_save_summary(self): + """Whether to save summaries.""" + return self._strategy.should_save_summary + def _run_single_worker(worker_fn, + strategy, cluster_spec, task_type, task_id, - rpc_layer, + session_config, + rpc_layer="", worker_barrier=None): """Runs a single worker by calling `worker_fn` under context.""" - with _WorkerContext( + strategy = copy.deepcopy(strategy) + strategy.configure(session_config, cluster_spec, task_type, task_id) + context = _WorkerContext( + strategy, cluster_spec, task_type, task_id, + session_config=session_config, rpc_layer=rpc_layer, - worker_barrier=worker_barrier): - worker_fn() + worker_barrier=worker_barrier) + with context: + worker_fn(strategy) def _run_std_server(cluster_spec=None, @@ -280,13 +340,15 @@ def _run_std_server(cluster_spec=None, return server -def _run_between_graph_client(worker_fn, cluster_spec, rpc_layer): +def _run_between_graph_client(worker_fn, strategy, cluster_spec, session_config, + rpc_layer): """Runs a standalone client for between-graph replication.""" eval_thread = None if _TaskType.EVALUATOR in cluster_spec.jobs: eval_thread = threading.Thread( target=_run_single_worker, - args=(worker_fn, cluster_spec, _TaskType.EVALUATOR, 0), + args=(worker_fn, strategy, cluster_spec, _TaskType.EVALUATOR, 0, + session_config), kwargs={ "rpc_layer": rpc_layer, }) @@ -298,7 +360,8 @@ def _run_between_graph_client(worker_fn, cluster_spec, rpc_layer): for task_id in range(len(cluster_spec.as_dict().get(task_type, []))): t = threading.Thread( target=_run_single_worker, - args=(worker_fn, cluster_spec, task_type, task_id), + args=(worker_fn, strategy, cluster_spec, task_type, task_id, + session_config), kwargs={ "rpc_layer": rpc_layer, "worker_barrier": worker_barrier @@ -315,43 +378,53 @@ def _run_between_graph_client(worker_fn, cluster_spec, rpc_layer): eval_thread.join() -def _run_in_graph_client(worker_fn, cluster_spec, rpc_layer): +def _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config, + rpc_layer): """Runs a standalone client for in-graph replication.""" eval_thread = None if _TaskType.EVALUATOR in cluster_spec.jobs: eval_thread = threading.Thread( target=_run_single_worker, - args=(worker_fn, cluster_spec, _TaskType.EVALUATOR, 0), + args=(worker_fn, strategy, cluster_spec, _TaskType.EVALUATOR, 0, + session_config), kwargs={ "rpc_layer": rpc_layer, }) eval_thread.start() - _run_single_worker(worker_fn, cluster_spec, None, None, rpc_layer) + _run_single_worker( + worker_fn, + strategy, + cluster_spec, + None, + None, + session_config, + rpc_layer=rpc_layer) if eval_thread: eval_thread.join() - -# TODO(yuefengz): propagate cluster_spec in the SPLIT_CLIENT mode. +# TODO(yuefengz): propagate cluster_spec in the STANDALONE_CLIENT mode. 
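As an illustration of how the pieces above fit together, the following is a minimal sketch of a `worker_fn` that consumes the per-worker strategy copy and the new `_WorkerContext.session_creator()` helper. The variable and train-op names are placeholders for illustration; only the coordinator and context APIs come from this module.

from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import monitored_session


def example_worker_fn(strategy):
  # `strategy` is the deep-copied object that _run_single_worker has already
  # passed through `configure()` for this particular task.
  context = distribute_coordinator_context.get_current_worker_context()
  assert context is not None

  v = variable_scope.get_variable("v", initializer=0.0, use_resource=True)
  train_op = state_ops.assign_add(v, 1.0)

  # A ChiefSessionCreator when `strategy.should_init` is true, otherwise a
  # WorkerSessionCreator that waits for the chief to initialize variables.
  creator = context.session_creator()
  with monitored_session.MonitoredSession(session_creator=creator) as sess:
    sess.run(train_op)

  # Between-graph workers can synchronize before returning.
  if context.has_barrier:
    context.wait_for_other_workers()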
# TODO(yuefengz): we may need a smart way to figure out whether the current task # is the special task when we support cluster_spec propagation. def run_distribute_coordinator(worker_fn, - mode=CoordinatorMode.SPLIT_CLIENT, + strategy, + mode=CoordinatorMode.STANDALONE_CLIENT, cluster_spec=None, task_type=None, task_id=None, - between_graph=False, + session_config=None, rpc_layer="grpc"): """Runs the coordinator for distributed TensorFlow. This function runs a split coordinator for distributed TensorFlow in its - default mode, i.e the SPLIT_CLIENT mode. Given a `cluster_spec` specifying - server addresses and their roles in a cluster, this coordinator will figure - out how to set them up, give the underlying function the right targets for - master sessions via a scope object and coordinate their training. The cluster - consisting of standard servers needs to be brought up either with the standard - server binary or with a binary running distribute coordinator with `task_type` - set to non-client type which will then turn into standard servers. + default mode, i.e. the STANDALONE_CLIENT mode. Given a `cluster_spec` + specifying server addresses and their roles in a cluster, this coordinator + will figure out how to set them up, give the underlying function the right + targets for master sessions via a scope object and coordinate their training. + The cluster consisting of standard servers needs to be brought up either with + the standard server binary or with a binary running distribute coordinator + with `task_type` set to non-client type which will then turn into standard + servers. In addition to be the distribute coordinator, this is also the source of configurations for each job in the distributed training. As there are multiple @@ -370,6 +443,14 @@ def run_distribute_coordinator(worker_fn, `worker_fn` depending whether it is between-graph training or in-graph replicated training. + The `strategy` object is expected to be a DistributionStrategy object which + has implemented methods needed by the distribute coordinator, such as + `configure(session_config, cluster_spec, task_type, task_id)` which configures + the strategy object for a specific task and a `should_init` property which + instructs the distribute coordinator whether to run init ops for a task. The + distribute coordinator will make a copy of the `strategy` object, call its + `configure` method and pass it to `worker_fn` as an argument. + The `worker_fn` defines the training logic and is called under a its own worker context which can be accessed to via `get_current_worker_context`. A worker context provides access to configurations for each task, e.g. the @@ -413,16 +494,20 @@ def run_distribute_coordinator(worker_fn, evaluation. Args: - worker_fn: the function to be called and given the access to a coordinator - context object. + worker_fn: the function to be called. The function should accept a + `strategy` object and will be given access to a context object via a + context manager scope. + strategy: a DistributionStrategy object which specifies whether it should + run between-graph replicated training or not, whether to run init ops, + etc. This object will also be configured given `session_config`, + `cluster_spec`, `task_type` and `task_id`. mode: in which mode this distribute coordinator runs. cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles in a cluster. If not set or empty, fall back to local training. task_type: the current task type, optional if this is a client.
task_id: the current task id, optional if this is a client. - between_graph: a boolean. It is only useful when `cluster_spec` is set and - not empty. If true, it will use between-graph replicated training; - otherwise it will use in-graph replicated training. + session_config: an optional @{tf.ConfigProto} object which will be passed + to `strategy`'s `configure` method and used to create a session. rpc_layer: optional string, the protocol for RPC, e.g. "grpc". Raises: @@ -448,15 +533,18 @@ def run_distribute_coordinator(worker_fn, if not cluster_spec: # `mode` is ignored in the local case. - _run_single_worker(worker_fn, None, None, None, rpc_layer) - elif mode == CoordinatorMode.SPLIT_CLIENT: + _run_single_worker(worker_fn, strategy, None, None, None, session_config, + rpc_layer) + elif mode == CoordinatorMode.STANDALONE_CLIENT: # The client must know the cluster but servers in the cluster don't have to # know the client. if task_type in [_TaskType.CLIENT, None]: - if between_graph: - _run_between_graph_client(worker_fn, cluster_spec, rpc_layer) + if strategy.between_graph: + _run_between_graph_client(worker_fn, strategy, cluster_spec, + session_config, rpc_layer) else: - _run_in_graph_client(worker_fn, cluster_spec, rpc_layer) + _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config, + rpc_layer) else: # If not a client job, run the standard server. server = _run_std_server( @@ -471,19 +559,21 @@ def run_distribute_coordinator(worker_fn, cluster_spec=cluster_spec, task_type=task_type, task_id=task_id) if task_type in [_TaskType.CHIEF, _TaskType.WORKER]: - if between_graph: + if strategy.between_graph: # All jobs run `worker_fn` if between-graph. - _run_single_worker(worker_fn, cluster_spec, task_type, task_id, - rpc_layer) + _run_single_worker(worker_fn, strategy, cluster_spec, task_type, + task_id, session_config, rpc_layer) else: # Only one node runs `worker_fn` if in-graph. - context = _WorkerContext(cluster_spec, task_type, task_id, rpc_layer) + context = _WorkerContext(strategy, cluster_spec, task_type, task_id) if context.is_chief: - _run_single_worker(worker_fn, cluster_spec, None, None, rpc_layer) + _run_single_worker(worker_fn, strategy, cluster_spec, None, None, + session_config, rpc_layer) else: server.join() elif task_type == _TaskType.EVALUATOR: - _run_single_worker(worker_fn, cluster_spec, task_type, task_id, rpc_layer) + _run_single_worker(worker_fn, strategy, cluster_spec, task_type, task_id, + session_config, rpc_layer) else: if task_type != _TaskType.PS: raise ValueError("Unexpected task_type: %r" % task_type) diff --git a/tensorflow/python/distribute/distribute_coordinator_context.py b/tensorflow/python/distribute/distribute_coordinator_context.py new file mode 100644 index 00000000000..dee65ce8839 --- /dev/null +++ b/tensorflow/python/distribute/distribute_coordinator_context.py @@ -0,0 +1,31 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""The context retrieval method for distribute coordinator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading + +_worker_context = threading.local() + + +def get_current_worker_context(): + """Returns the current task context.""" + try: + return _worker_context.current + except AttributeError: + return None diff --git a/tensorflow/python/distribute/distribute_coordinator_test.py b/tensorflow/python/distribute/distribute_coordinator_test.py index 319c29ba2fa..97c6bdd15a5 100644 --- a/tensorflow/python/distribute/distribute_coordinator_test.py +++ b/tensorflow/python/distribute/distribute_coordinator_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for distribute coordinator.""" +"""Tests for Distribute Coordinator.""" from __future__ import absolute_import from __future__ import division @@ -37,6 +37,7 @@ except ImportError as _error: from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.distribute import distribute_coordinator +from tensorflow.python.distribute import distribute_coordinator_context from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import control_flow_ops @@ -44,17 +45,17 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.training import monitored_session + CHIEF = distribute_coordinator._TaskType.CHIEF WORKER = distribute_coordinator._TaskType.WORKER PS = distribute_coordinator._TaskType.PS EVALUATOR = distribute_coordinator._TaskType.EVALUATOR -SPLIT_CLIENT = distribute_coordinator.CoordinatorMode.SPLIT_CLIENT +STANDALONE_CLIENT = distribute_coordinator.CoordinatorMode.STANDALONE_CLIENT INDEPENDENT_WORKER = distribute_coordinator.CoordinatorMode.INDEPENDENT_WORKER -RUN_STD_SERVER_METHOD = "tensorflow.python.distribute.distribute_coordinator._run_std_server" - NUM_WORKERS = 3 NUM_PS = 2 @@ -74,6 +75,57 @@ def _strip_protocol(target): return target +class MockStrategy(object): + + def __init__(self, + between_graph=False, + should_init=None, + should_checkpoint=None, + should_save_summary=None): + self._between_graph = between_graph + self._should_init = should_init + self._should_checkpoint = should_checkpoint + self._should_save_summary = should_save_summary + + @property + def between_graph(self): + return self._between_graph + + def configure(self, + session_options=None, + cluster_spec=None, + task_type=None, + task_id=None): + del session_options, cluster_spec, task_type + if self._should_init is None: + if task_id == 0: + self._should_init = True + else: + self._should_init = False + if self._should_checkpoint is None: + if task_id == 0: + self._should_checkpoint = True + else: + self._should_checkpoint = False + if self._should_save_summary is None: + if task_id == 0: + self._should_save_summary = True + else: + self._should_save_summary = False + + @property + def should_init(self): + return self._should_init + + @property + def should_checkpoint(self): + return self._should_checkpoint + + @property + def should_save_summary(self): + return 
self._should_save_summary + + class MockServer(object): def __init__(self): @@ -108,6 +160,7 @@ class DistributeCoordinatorTestBase(test.TestCase): self._result_correct = 0 self._lock = threading.Lock() self._worker_context = {} + self._strategy_property = {} self._std_servers = {} self._barrier = distribute_coordinator._Barrier(NUM_WORKERS) @@ -142,8 +195,8 @@ class DistributeCoordinatorTestBase(test.TestCase): cluster_spec[EVALUATOR] = ["localhost:%s" % portpicker.pick_unused_port()] return cluster_spec - def _in_graph_worker_fn(self): - context = distribute_coordinator.get_current_worker_context() + def _in_graph_worker_fn(self, strategy): + context = distribute_coordinator_context.get_current_worker_context() self.assertTrue(context is not None) with self._test_session(target=context.master_target) as sess: xs = [] @@ -164,22 +217,23 @@ class DistributeCoordinatorTestBase(test.TestCase): if result_value == expected: self._result_correct += 1 - def _run_coordinator_in_thread(self, worker_fn, **kwargs): + def _run_coordinator_in_thread(self, worker_fn, strategy, **kwargs): t = threading.Thread( target=distribute_coordinator.run_distribute_coordinator, - args=(worker_fn,), + args=(worker_fn, strategy), kwargs=kwargs) t.start() return t - def _run_multiple_coordinator_in_threads(self, worker_fn, cluster_spec, - **kwargs): + def _run_multiple_coordinator_in_threads(self, worker_fn, strategy, + cluster_spec, **kwargs): threads = {} for task_type in cluster_spec.keys(): threads[task_type] = [] for task_id in range(len(cluster_spec[task_type])): t = self._run_coordinator_in_thread( worker_fn, + strategy, cluster_spec=cluster_spec, task_type=task_type, task_id=task_id, @@ -187,8 +241,8 @@ class DistributeCoordinatorTestBase(test.TestCase): threads[task_type].append(t) return threads - def _between_graph_worker_fn(self): - context = distribute_coordinator.get_current_worker_context() + def _between_graph_worker_fn(self, strategy): + context = distribute_coordinator_context.get_current_worker_context() self.assertTrue(context is not None) with self._test_session(target=context.master_target) as sess: with ops.device("/job:ps/task:0"): @@ -234,14 +288,50 @@ class DistributeCoordinatorTestBase(test.TestCase): with self._lock: self._result_correct += 1 - def _dump_worker_context(self): + def _between_graph_with_monitored_session(self, strategy): + context = distribute_coordinator_context.get_current_worker_context() + self.assertTrue(context is not None) + with ops.device("/job:ps/task:0"): + # TODO(yuefengz): investigate why not using resource variable will make + # the test flaky. + x = variable_scope.get_variable("x", initializer=10.0, use_resource=True) + with ops.device("/job:ps/task:1"): + y = variable_scope.get_variable("y", initializer=20.0, use_resource=True) + + x_add = x.assign_add(2.0) + y_sub = y.assign_sub(2.0) + train_op = control_flow_ops.group([x_add, y_sub]) + + # The monitored session will run init or ready ops. + with monitored_session.MonitoredSession() as sess: + sess.run(train_op) + + # Synchronize workers after one step to make sure they all have finished + # training. + if context.has_barrier: + context.wait_for_other_workers() + else: + self._barrier.wait() + + x_val, y_val = sess.run([x, y]) + + self.assertEqual(x_val, 16.0) + self.assertEqual(y_val, 14.0) + if x_val == 16.0 and y_val == 14.0: + with self._lock: + self._result_correct += 1 + + def _dump_worker_context(self, strategy): """Dumps the propoerties of each worker context. 
It dumps the context properties to a dict mapping from task_type to a list of tuples of master_target, num_workers, is_chief and distribute_mode, where the list is indexed by the task_id. + + Args: + strategy: a `DistributionStrategy` object. """ - context = distribute_coordinator.get_current_worker_context() + context = distribute_coordinator_context.get_current_worker_context() self.assertTrue(context is not None) task_type = str(context.task_type) task_id = context.task_id or 0 @@ -255,6 +345,25 @@ class DistributeCoordinatorTestBase(test.TestCase): context.is_chief, context.distributed_mode) + def _dump_strategy_property(self, strategy): + context = distribute_coordinator_context.get_current_worker_context() + self.assertTrue(context is not None) + + self.assertEqual(context._strategy.should_init, strategy.should_init) + self.assertEqual(context.should_checkpoint, strategy.should_checkpoint) + self.assertEqual(context.should_save_summary, strategy.should_save_summary) + + task_type = str(context.task_type) + task_id = context.task_id or 0 + with self._lock: + if task_type not in self._strategy_property: + self._strategy_property[task_type] = [] + while len(self._strategy_property[task_type]) <= task_id: + self._strategy_property[task_type].append(None) + self._strategy_property[task_type][task_id] = ( + context._strategy.should_init, context.should_checkpoint, + context.should_save_summary) + def _run_mock_std_server(self, session_config=None, cluster_spec=None, @@ -274,22 +383,32 @@ class DistributeCoordinatorTestBase(test.TestCase): return server -class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase): +class DistributeCoordinatorTestStandaloneMode(DistributeCoordinatorTestBase): - def testInGraphSplitMode(self): - """Test it runs in-graph replication in split client mode.""" + def testInGraphStandaloneMode(self): + """Test it runs in-graph replication in standalone client mode.""" distribute_coordinator.run_distribute_coordinator( self._in_graph_worker_fn, - cluster_spec=self._cluster_spec, - between_graph=False) + MockStrategy(between_graph=False), + cluster_spec=self._cluster_spec) self.assertEqual(self._result_correct, 1) def testBetweenGraph(self): - """Test it runs between-graph replication in split client mode.""" + """Test it runs between-graph replication in standalone client mode.""" distribute_coordinator.run_distribute_coordinator( self._between_graph_worker_fn, - cluster_spec=self._cluster_spec, - between_graph=True) + MockStrategy(between_graph=True), + cluster_spec=self._cluster_spec) + + # Each finished worker will increment self._result_correct. + self.assertEqual(self._result_correct, NUM_WORKERS) + + def testBetweenGraphWithMonitoredSession(self): + """Test monitored session in standalone client mode.""" + distribute_coordinator.run_distribute_coordinator( + self._between_graph_with_monitored_session, + MockStrategy(between_graph=True), + cluster_spec=self._cluster_spec) # Each finished worker will increment self._result_correct. self.assertEqual(self._result_correct, NUM_WORKERS) @@ -298,8 +417,8 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase): # Dumps the task contexts to the self._worker_context dict. distribute_coordinator.run_distribute_coordinator( self._dump_worker_context, - cluster_spec=self._cluster_spec, - between_graph=True) + MockStrategy(between_graph=True), + cluster_spec=self._cluster_spec) # There is only one type of task and there three such tasks. 
self.assertEqual(len(self._worker_context), 1) @@ -318,12 +437,30 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase): self._worker_context[WORKER][2], (_bytes_to_str(self._workers[2].target), NUM_WORKERS, False, True)) + def testBetweenGraphStrategyProperties(self): + # Dumps properties of the strategy objects. + distribute_coordinator.run_distribute_coordinator( + self._dump_strategy_property, + MockStrategy(between_graph=True, should_init=True), + cluster_spec=self._cluster_spec) + + # There is only one type of task and there three such tasks. + self.assertEqual(len(self._strategy_property), 1) + self.assertTrue(WORKER in self._strategy_property) + self.assertEqual(len(self._strategy_property[WORKER]), NUM_WORKERS) + + # Check whether each task has the right properties of should_init, + # should_checkpoint and should_save_summary. + self.assertEqual(self._strategy_property[WORKER][0], (True, True, True)) + self.assertEqual(self._strategy_property[WORKER][1], (True, False, False)) + self.assertEqual(self._strategy_property[WORKER][2], (True, False, False)) + def testInGraphContext(self): # Dumps the task contexts to the self._worker_context dict. distribute_coordinator.run_distribute_coordinator( self._dump_worker_context, - cluster_spec=self._cluster_spec, - between_graph=False) + MockStrategy(between_graph=False), + cluster_spec=self._cluster_spec) # There is only a "None" task in the dumped task context. self.assertEqual(len(self._worker_context), 1) @@ -339,7 +476,9 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase): def testLocalContext(self): # Dumps the task contexts to the self._worker_context dict. distribute_coordinator.run_distribute_coordinator( - self._dump_worker_context, cluster_spec=None, between_graph=True) + self._dump_worker_context, + MockStrategy(between_graph=False), + cluster_spec=None) # There is only a "None" task. self.assertEqual(len(self._worker_context), 1) @@ -348,7 +487,7 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase): # Check whether each task has the right master_target, num_workers, is_chief # and distributed_mode. - self.assertEqual(self._worker_context["None"][0], ("local", 0, True, False)) + self.assertEqual(self._worker_context["None"][0], ("", 0, True, False)) def testBetweenGraphContextWithChief(self): # Adds a chief node, so there are NUM_WORKERS + 1 workers in total. @@ -358,8 +497,8 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase): # Dumps the task contexts to the self._worker_context dict. distribute_coordinator.run_distribute_coordinator( self._dump_worker_context, + MockStrategy(between_graph=True), cluster_spec=cluster_spec, - between_graph=True, rpc_layer="grpc") # There are one CHIEF and three workers. @@ -391,8 +530,8 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase): # Dumps the task contexts to the self._worker_context dict. distribute_coordinator.run_distribute_coordinator( self._dump_worker_context, + MockStrategy(between_graph=False), cluster_spec=cluster_spec, - between_graph=False, rpc_layer=None) # There are one "None" task and one EVALUATOR task. 
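To connect these standalone-client tests to end-to-end usage, here is a hedged sketch of launching the coordinator in standalone client mode. The host addresses and the `MyStrategy` class are placeholders standing in for a real `DistributionStrategy` that implements `configure`, `between_graph`, `should_init`, `should_checkpoint` and `should_save_summary` as described in `run_distribute_coordinator`'s docstring.

from tensorflow.python.distribute import distribute_coordinator


class MyStrategy(object):
  """Placeholder strategy; a real DistributionStrategy provides these members."""

  between_graph = True       # run `worker_fn` once per worker task
  should_init = True
  should_checkpoint = True
  should_save_summary = True

  def configure(self, session_config=None, cluster_spec=None, task_type=None,
                task_id=None):
    # A real strategy would rewrite `session_config` and record the task here,
    # typically enabling checkpointing and summaries only on task 0.
    pass


def worker_fn(strategy):
  pass  # training logic, e.g. the `example_worker_fn` sketched earlier


cluster_spec = {
    "worker": ["localhost:2222", "localhost:2223", "localhost:2224"],
    "ps": ["localhost:3333", "localhost:3334"],
}

# Standalone client mode: this one process is the client; it spawns a thread
# per worker task (plus an evaluator thread if the cluster defines one) and
# each thread calls `worker_fn` with its own configured strategy copy.
distribute_coordinator.run_distribute_coordinator(
    worker_fn,
    MyStrategy(),
    mode=distribute_coordinator.CoordinatorMode.STANDALONE_CLIENT,
    cluster_spec=cluster_spec,
    rpc_layer="grpc")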
@@ -417,8 +556,8 @@ class DistributeCoordinatorTestInpendentWorkerMode( cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS) threads = self._run_multiple_coordinator_in_threads( self._in_graph_worker_fn, + MockStrategy(between_graph=False), cluster_spec, - between_graph=False, mode=INDEPENDENT_WORKER) threads[WORKER][0].join() self.assertEqual(self._result_correct, 1) @@ -428,8 +567,22 @@ class DistributeCoordinatorTestInpendentWorkerMode( num_workers=NUM_WORKERS, num_ps=NUM_PS) threads = self._run_multiple_coordinator_in_threads( self._between_graph_worker_fn, + MockStrategy(between_graph=True), + cluster_spec, + mode=INDEPENDENT_WORKER) + for task_id in range(NUM_WORKERS): + threads[WORKER][task_id].join() + + # Each finished worker will increment self._result_correct. + self.assertEqual(self._result_correct, NUM_WORKERS) + + def testBetweenGraphWithMonitoredSession(self): + cluster_spec = self._create_cluster_spec( + num_workers=NUM_WORKERS, num_ps=NUM_PS) + threads = self._run_multiple_coordinator_in_threads( + self._between_graph_with_monitored_session, + MockStrategy(between_graph=True), cluster_spec, - between_graph=True, mode=INDEPENDENT_WORKER) for task_id in range(NUM_WORKERS): threads[WORKER][task_id].join() @@ -444,9 +597,9 @@ class DistributeCoordinatorTestInpendentWorkerMode( self._run_mock_std_server): threads = self._run_multiple_coordinator_in_threads( self._dump_worker_context, + MockStrategy(between_graph=True), cluster_spec, mode=INDEPENDENT_WORKER, - between_graph=True, rpc_layer=None) for task_id in range(NUM_WORKERS): threads[WORKER][task_id].join() @@ -476,6 +629,31 @@ class DistributeCoordinatorTestInpendentWorkerMode( self.assertFalse(self._std_servers[WORKER][1].joined) self.assertFalse(self._std_servers[WORKER][2].joined) + def testBetweenGraphStrategyProperties(self): + cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS) + # Dumps properties of the strategy objects. + with test.mock.patch.object(distribute_coordinator, "_run_std_server", + self._run_mock_std_server): + threads = self._run_multiple_coordinator_in_threads( + self._dump_strategy_property, + MockStrategy(between_graph=True, should_init=True), + cluster_spec, + mode=INDEPENDENT_WORKER, + rpc_layer=None) + for task_id in range(NUM_WORKERS): + threads[WORKER][task_id].join() + + # There is only one type of task and there three such tasks. + self.assertEqual(len(self._strategy_property), 1) + self.assertTrue(WORKER in self._strategy_property) + self.assertEqual(len(self._strategy_property[WORKER]), NUM_WORKERS) + + # Check whether each task has the right properties of should_init, + # should_checkpoint and should_save_summary. + self.assertEqual(self._strategy_property[WORKER][0], (True, True, True)) + self.assertEqual(self._strategy_property[WORKER][1], (True, False, False)) + self.assertEqual(self._strategy_property[WORKER][2], (True, False, False)) + def testInGraphContext(self): cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS) # Dumps the task contexts and std server arguments. 
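The independent-worker tests above mirror how each task would start itself in production. The sketch below assumes the cluster and task identity arrive via the TF_CONFIG environment variable, which is a convention of this sketch (borrowed from tf.estimator) rather than something this module parses for you; `worker_fn` and `MyStrategy` are the placeholders from the previous sketch.

import json
import os

from tensorflow.python.distribute import distribute_coordinator

tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))

# In INDEPENDENT_WORKER mode every task (chief, worker, ps, evaluator) runs
# this same program. Chief/worker tasks bring up a standard server and then
# run `worker_fn` (between-graph) or join the server (in-graph, non-chief);
# ps tasks just run the standard server.
distribute_coordinator.run_distribute_coordinator(
    worker_fn,
    MyStrategy(),
    mode=distribute_coordinator.CoordinatorMode.INDEPENDENT_WORKER,
    cluster_spec=tf_config.get("cluster", {}),
    task_type=tf_config.get("task", {}).get("type"),
    task_id=tf_config.get("task", {}).get("index"),
    rpc_layer="grpc")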
@@ -483,9 +661,9 @@ class DistributeCoordinatorTestInpendentWorkerMode( self._run_mock_std_server): threads = self._run_multiple_coordinator_in_threads( self._dump_worker_context, + MockStrategy(between_graph=False), cluster_spec, mode=INDEPENDENT_WORKER, - between_graph=False, rpc_layer=None) for task_id in range(NUM_WORKERS): threads[WORKER][task_id].join() @@ -519,9 +697,9 @@ class DistributeCoordinatorTestInpendentWorkerMode( self._run_mock_std_server): threads = self._run_multiple_coordinator_in_threads( self._dump_worker_context, + MockStrategy(between_graph=False), cluster_spec, mode=INDEPENDENT_WORKER, - between_graph=False, rpc_layer=None) for task_id in range(NUM_WORKERS): threads[WORKER][task_id].join() diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py index 7b06bffa4b2..c077630de2b 100644 --- a/tensorflow/python/training/monitored_session.py +++ b/tensorflow/python/training/monitored_session.py @@ -25,6 +25,7 @@ import sys import six from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.distribute import distribute_coordinator_context from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -284,6 +285,63 @@ class Scaffold(object): resources.initialize_resources(resources.local_resources())) +def _create_monitored_session_with_worker_context(worker_context, # pylint: disable=missing-docstring + scaffold, + checkpoint_dir=None, + hooks=None, + chief_only_hooks=None, + save_checkpoint_secs=None, + save_summaries_steps=None, + save_summaries_secs=None, + config=None, + stop_grace_period_secs=120, + log_step_count_steps=100, + max_wait_secs=7200, + save_checkpoint_steps=None, + summary_dir=None): + all_hooks = [] + if hooks: + all_hooks.extend(hooks) + if chief_only_hooks and worker_context.is_chief: + all_hooks.extend(chief_only_hooks) + + summary_dir = summary_dir or checkpoint_dir + if summary_dir and worker_context.should_save_summary: + if log_step_count_steps and log_step_count_steps > 0: + all_hooks.append( + basic_session_run_hooks.StepCounterHook( + output_dir=summary_dir, every_n_steps=log_step_count_steps)) + + if (save_summaries_steps and save_summaries_steps > 0) or ( + save_summaries_secs and save_summaries_secs > 0): + all_hooks.append( + basic_session_run_hooks.SummarySaverHook( + scaffold=scaffold, + save_steps=save_summaries_steps, + save_secs=save_summaries_secs, + output_dir=summary_dir)) + + if checkpoint_dir and worker_context.should_checkpoint: + if (save_checkpoint_secs and save_checkpoint_secs > 0) or ( + save_checkpoint_steps and save_checkpoint_steps > 0): + all_hooks.append( + basic_session_run_hooks.CheckpointSaverHook( + checkpoint_dir, + save_steps=save_checkpoint_steps, + save_secs=save_checkpoint_secs, + scaffold=scaffold)) + + session_creator = worker_context.session_creator( + scaffold, + config=config, + checkpoint_dir=checkpoint_dir, + max_wait_secs=max_wait_secs) + return MonitoredSession( + session_creator=session_creator, + hooks=all_hooks, + stop_grace_period_secs=stop_grace_period_secs) + + @tf_export('train.MonitoredTrainingSession') def MonitoredTrainingSession(master='', # pylint: disable=invalid-name is_chief=True, @@ -373,14 +431,35 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name save_checkpoint_steps = None scaffold = scaffold or Scaffold() + worker_context = distribute_coordinator_context.get_current_worker_context() + + if worker_context: + return 
_create_monitored_session_with_worker_context( + worker_context, + scaffold, + checkpoint_dir=checkpoint_dir, + hooks=hooks, + chief_only_hooks=chief_only_hooks, + save_checkpoint_secs=save_checkpoint_secs, + save_summaries_steps=save_summaries_steps, + save_summaries_secs=save_summaries_secs, + config=config, + stop_grace_period_secs=stop_grace_period_secs, + log_step_count_steps=log_step_count_steps, + max_wait_secs=max_wait_secs, + save_checkpoint_steps=save_checkpoint_steps, + summary_dir=summary_dir) + if not is_chief: session_creator = WorkerSessionCreator( scaffold=scaffold, master=master, config=config, max_wait_secs=max_wait_secs) - return MonitoredSession(session_creator=session_creator, hooks=hooks or [], - stop_grace_period_secs=stop_grace_period_secs) + return MonitoredSession( + session_creator=session_creator, + hooks=hooks or [], + stop_grace_period_secs=stop_grace_period_secs) all_hooks = [] if chief_only_hooks: @@ -400,25 +479,29 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name if (save_summaries_steps and save_summaries_steps > 0) or ( save_summaries_secs and save_summaries_secs > 0): - all_hooks.append(basic_session_run_hooks.SummarySaverHook( - scaffold=scaffold, - save_steps=save_summaries_steps, - save_secs=save_summaries_secs, - output_dir=summary_dir)) + all_hooks.append( + basic_session_run_hooks.SummarySaverHook( + scaffold=scaffold, + save_steps=save_summaries_steps, + save_secs=save_summaries_secs, + output_dir=summary_dir)) if checkpoint_dir: if (save_checkpoint_secs and save_checkpoint_secs > 0) or ( save_checkpoint_steps and save_checkpoint_steps > 0): - all_hooks.append(basic_session_run_hooks.CheckpointSaverHook( - checkpoint_dir, - save_steps=save_checkpoint_steps, - save_secs=save_checkpoint_secs, - scaffold=scaffold)) + all_hooks.append( + basic_session_run_hooks.CheckpointSaverHook( + checkpoint_dir, + save_steps=save_checkpoint_steps, + save_secs=save_checkpoint_secs, + scaffold=scaffold)) if hooks: all_hooks.extend(hooks) - return MonitoredSession(session_creator=session_creator, hooks=all_hooks, - stop_grace_period_secs=stop_grace_period_secs) + return MonitoredSession( + session_creator=session_creator, + hooks=all_hooks, + stop_grace_period_secs=stop_grace_period_secs) @tf_export('train.SessionCreator') @@ -546,6 +629,11 @@ class _MonitoredSession(object): self._hooks = hooks or [] for h in self._hooks: h.begin() + + worker_context = distribute_coordinator_context.get_current_worker_context() + if not session_creator and worker_context: + session_creator = worker_context.session_creator() + # Create the session. 
self._coordinated_creator = self._CoordinatedSessionCreator( session_creator=session_creator or ChiefSessionCreator(), diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py index 92533ca4f3b..ff586b6c03f 100644 --- a/tensorflow/python/training/monitored_session_test.py +++ b/tensorflow/python/training/monitored_session_test.py @@ -32,6 +32,7 @@ from tensorflow.contrib.testing.python.framework import util_test from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import debug_pb2 from tensorflow.python.client import session as session_lib +from tensorflow.python.distribute import distribute_coordinator from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl @@ -381,6 +382,119 @@ class MonitoredTrainingSessionTest(test.TestCase): self.assertEqual(0, session.run(gstep)) +class MockStrategy(object): + + def __init__(self, + between_graph=False, + should_init=True, + should_checkpoint=None, + should_save_summary=None): + self._between_graph = between_graph + self._should_init = should_init + self._should_checkpoint = should_checkpoint + self._should_save_summary = should_save_summary + + @property + def between_graph(self): + return self._between_graph + + @property + def should_init(self): + return self._should_init + + @property + def should_checkpoint(self): + return self._should_checkpoint + + @property + def should_save_summary(self): + return self._should_save_summary + + +class MonitoredTrainingSessionWithDistributeCoordinatorTest(test.TestCase): + """Test distribute coordinator controls summary saving and checkpointing.""" + + def test_summary_hook_enabled(self): + context = distribute_coordinator._WorkerContext( + MockStrategy(should_save_summary=True), None, None, None) + + logdir = _test_dir(self.get_temp_dir(), 'test_summaries_enabled') + with ops.Graph().as_default(): + gstep = variables_lib.get_or_create_global_step() + new_gstep = state_ops.assign_add(gstep, 1) + summary.scalar('my_summary_tag', new_gstep * 2) + with context, monitored_session.MonitoredTrainingSession( + checkpoint_dir=logdir, + save_summaries_steps=100, + log_step_count_steps=10) as session: + for _ in range(101): + session.run(new_gstep) + + summaries = util_test.latest_summaries(logdir) + tags = [s.summary.value[0].tag for s in summaries] + self.assertIn('my_summary_tag', tags) + self.assertIn('global_step/sec', tags) + + def test_summary_hook_disabled(self): + context = distribute_coordinator._WorkerContext( + MockStrategy(should_save_summary=False), None, None, None) + + logdir = _test_dir(self.get_temp_dir(), 'test_summaries_disabled') + with ops.Graph().as_default(): + gstep = variables_lib.get_or_create_global_step() + new_gstep = state_ops.assign_add(gstep, 1) + summary.scalar('my_summary_tag', new_gstep * 2) + with context, monitored_session.MonitoredTrainingSession( + checkpoint_dir=logdir, + save_summaries_steps=100, + log_step_count_steps=10) as session: + for _ in range(101): + session.run(new_gstep) + + # No summary is saved. 
+ summaries = util_test.latest_summaries(logdir) + self.assertEqual(len(summaries), 0) + + def test_checkpoint_hook_enabled(self): + context = distribute_coordinator._WorkerContext( + MockStrategy(should_checkpoint=True), None, None, None) + + logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_enabled') + with ops.Graph().as_default(): + gstep = variables_lib.get_or_create_global_step() + new_gstep = state_ops.assign_add(gstep, 1) + with context, monitored_session.MonitoredTrainingSession( + checkpoint_dir=logdir, + save_checkpoint_steps=100, + log_step_count_steps=10) as session: + for _ in range(100): + session.run(new_gstep) + + # A restart will find the checkpoint and recover automatically. + with monitored_session.MonitoredTrainingSession( + is_chief=True, checkpoint_dir=logdir) as session: + self.assertEqual(100, session.run(gstep)) + + def test_checkpoint_hook_disabled(self): + context = distribute_coordinator._WorkerContext( + MockStrategy(should_checkpoint=False), None, None, None) + + logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_disabled') + with ops.Graph().as_default(): + gstep = variables_lib.get_or_create_global_step() + new_gstep = state_ops.assign_add(gstep, 1) + with context, monitored_session.MonitoredTrainingSession( + checkpoint_dir=logdir, + save_checkpoint_steps=100, + log_step_count_steps=10) as session: + for _ in range(100): + session.run(new_gstep) + + # No checkpoint is saved. + checkpoint = checkpoint_management.latest_checkpoint(logdir) + self.assertIsNone(checkpoint) + + class StopAtNSession(monitored_session._WrappedSession): """A wrapped session that stops at the N-th call to _check_stop.""" @@ -1365,8 +1479,8 @@ class MonitoredSessionTest(test.TestCase): with monitored_session.MonitoredSession( session_creator=monitored_session.ChiefSessionCreator( scaffold, - checkpoint_filename_with_path= - checkpoint_management.latest_checkpoint(logdir))) as session: + checkpoint_filename_with_path=checkpoint_management. + latest_checkpoint(logdir))) as session: self.assertEqual(2, session.run(gstep)) def test_retry_initialization_on_aborted_error(self):
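Finally, to illustrate the monitored_session.py changes from a user's perspective: inside a `worker_fn` driven by the distribute coordinator, `MonitoredTrainingSession` now picks up the active worker context, builds its session creator from `_WorkerContext.session_creator()`, and attaches the summary and checkpoint hooks only when `should_save_summary` / `should_checkpoint` allow it, so `master` and `is_chief` no longer need to be passed explicitly. A minimal sketch with placeholder model code and checkpoint path:

from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.ops import state_ops
from tensorflow.python.training import monitored_session
from tensorflow.python.training import training_util


def example_worker_fn(strategy):
  # Only meaningful when called by run_distribute_coordinator, which installs
  # the worker context consulted by MonitoredTrainingSession below.
  assert distribute_coordinator_context.get_current_worker_context() is not None

  global_step = training_util.get_or_create_global_step()
  train_op = state_ops.assign_add(global_step, 1)

  # No `master` or `is_chief` arguments: the target, session config, init/ready
  # behavior and the decision to attach CheckpointSaverHook/SummarySaverHook
  # all come from the worker context and the configured strategy.
  with monitored_session.MonitoredTrainingSession(
      checkpoint_dir="/tmp/example_ckpt",  # placeholder directory
      save_checkpoint_steps=100,
      save_summaries_steps=100,
      log_step_count_steps=10) as sess:
    for _ in range(200):
      sess.run(train_op)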