Use distribution strategy to configure distribute coordinator.

Add a session_creator method and a couple of properties to the worker context, which are then used to configure monitored sessions.

PiperOrigin-RevId: 209026599
Yuefeng Zhou 2018-08-16 12:25:17 -07:00 committed by TensorFlower Gardener
parent 5360d73687
commit 1326f33515
7 changed files with 624 additions and 109 deletions
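
To orient the reader, here is a rough sketch of the calling convention this commit introduces: `run_distribute_coordinator` now takes a strategy object, configures a copy of it per task, and passes that copy to `worker_fn`. The tiny strategy class and cluster addresses below are illustrative placeholders that only mirror the `MockStrategy` used in the tests further down; they are not part of the commit.

# Illustrative sketch only: the strategy class and cluster below are
# placeholders mirroring the MockStrategy defined in the tests of this commit.
from tensorflow.python.distribute import distribute_coordinator
from tensorflow.python.distribute import distribute_coordinator_context


class SketchStrategy(object):
  """Bare-minimum object satisfying what the coordinator queries."""

  between_graph = True        # run between-graph replicated training
  should_init = True          # run init ops in each worker's session
  should_checkpoint = True    # allow this task to save checkpoints
  should_save_summary = True  # allow this task to save summaries

  def configure(self, session_config=None, cluster_spec=None,
                task_type=None, task_id=None):
    # A real DistributionStrategy adapts itself to the assigned task here.
    pass


def worker_fn(strategy):
  # The coordinator enters a _WorkerContext before invoking worker_fn, so the
  # context is reachable through the new distribute_coordinator_context module.
  context = distribute_coordinator_context.get_current_worker_context()
  print(context.master_target, context.is_chief, context.num_workers)


distribute_coordinator.run_distribute_coordinator(
    worker_fn,
    SketchStrategy(),
    cluster_spec={"worker": ["localhost:2222", "localhost:2223"]})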

View File

@ -3265,6 +3265,7 @@ py_library(
"@six_archive//:six",
"//tensorflow/core:protos_all_py",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute:distribute_coordinator_context",
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:context",
# `layers` dependency only exists due to the use of a small utility.
@ -4658,7 +4659,10 @@ py_test(
size = "medium",
srcs = ["training/monitored_session_test.py"],
srcs_version = "PY2AND3",
tags = ["notsan"], # b/67945581
tags = [
"no_pip",
"notsan", # b/67945581
],
deps = [
":array_ops",
":checkpoint_management",
@ -4676,6 +4680,7 @@ py_test(
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/testing:testing_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python/distribute:distribute_coordinator",
],
)

View File

@ -41,3 +41,12 @@ py_test(
"//tensorflow/python:variables",
],
)
py_library(
name = "distribute_coordinator_context",
srcs = [
"distribute_coordinator_context.py",
],
srcs_version = "PY2AND3",
deps = [],
)

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A unified and split coordinator for distributed TensorFlow."""
"""A component for running distributed TensorFlow."""
from __future__ import absolute_import
from __future__ import division
@ -24,6 +24,8 @@ import os
import threading
from tensorflow.core.protobuf import cluster_pb2
from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.training import monitored_session
from tensorflow.python.training import server_lib
@ -43,23 +45,12 @@ class CoordinatorMode(object):
# client and connects to remote servers for training. Each remote server can
# use the distribute coordinator binary with task_type set correctly which
# will then turn into standard servers.
SPLIT_CLIENT = 0
STANDALONE_CLIENT = "standalone_client"
# The distribute coordinator runs on each worker. It will run a standard
# server on each worker and optionally run the `worker_fn` that is configured
# to talk to its standard server.
INDEPENDENT_WORKER = 1
_worker_context = threading.local()
def get_current_worker_context():
"""Returns the current task context."""
try:
return _worker_context.current
except AttributeError:
return None
INDEPENDENT_WORKER = "independent_worker"
class _Barrier(object):
@ -113,14 +104,17 @@ class _WorkerContext(object):
"""
def __init__(self,
strategy,
cluster_spec,
task_type,
task_id,
session_config=None,
rpc_layer="grpc",
worker_barrier=None):
"""Initialize the worker context object.
Args:
strategy: a `DistributionStrategy` object.
cluster_spec: a ClusterSpec object. It can be empty or None in the local
training case.
task_type: a string indicating the role of the corresponding task, such as
@ -128,14 +122,17 @@ class _WorkerContext(object):
replicated training.
task_id: an integer indicating id of the corresponding task. It can be
None if it is local training or in-graph replicated training.
session_config: an optional @{tf.ConfigProto} object.
rpc_layer: optional string specifying the RPC protocol for communication
with worker masters. If None or empty, hosts in the `cluster_spec` will
be used directly.
worker_barrier: optional, the barrier object for worker synchronization.
"""
self._strategy = strategy
self._cluster_spec = cluster_spec
self._task_type = task_type
self._task_id = task_id
self._session_config = session_config
self._worker_barrier = worker_barrier
self._rpc_layer = rpc_layer
self._master_target = self._get_master_target()
@ -143,26 +140,31 @@ class _WorkerContext(object):
self._is_chief_node = self._is_chief()
def _debug_message(self):
return "[cluster_spec: %r, task_type: %r, task_id: %r]" % (
self._cluster_spec, self.task_type, self.task_id)
if self._cluster_spec:
return "[cluster_spec: %r, task_type: %r, task_id: %r]" % (
self._cluster_spec, self.task_type, self.task_id)
else:
return "[local]"
def __enter__(self):
old_context = get_current_worker_context()
old_context = distribute_coordinator_context.get_current_worker_context()
if old_context:
raise ValueError(
"You cannot run distribute coordinator in a `worker_fn`.\t" +
self._debug_message())
_worker_context.current = self
# pylint: disable=protected-access
distribute_coordinator_context._worker_context.current = self
def __exit__(self, unused_exception_type, unused_exception_value,
unused_traceback):
_worker_context.current = None
# pylint: disable=protected-access
distribute_coordinator_context._worker_context.current = None
def _get_master_target(self):
"""Return the master target for a task."""
# If cluster_spec is None or empty, we use local master.
if not self._cluster_spec:
return "local"
return ""
# If task_type is None, then it is in-graph replicated training. In this
# case we use the chief or first worker's master target.
@ -207,6 +209,47 @@ class _WorkerContext(object):
self._debug_message())
self._worker_barrier.wait()
def session_creator(self,
scaffold=None,
config=None,
checkpoint_dir=None,
checkpoint_filename_with_path=None,
max_wait_secs=7200):
"""Returns a session creator.
The returned session creator will be configured with the correct master
target and session configs. It will also run either init ops or ready ops
by querying the `strategy` object when `create_session` is called on it.
Args:
scaffold: A `Scaffold` used for gathering or building supportive ops. If
not specified a default one is created. It's used to finalize the graph.
config: `ConfigProto` proto used to configure the session.
checkpoint_dir: A string. Optional path to a directory where to restore
variables.
checkpoint_filename_with_path: Full file name path to the checkpoint file.
Only one of `checkpoint_dir` or `checkpoint_filename_with_path` can be
specified.
max_wait_secs: Maximum time to wait for the session to become available.
Returns:
a descendant of SessionCreator.
"""
# TODO(yuefengz): merge session config.
if self._strategy.should_init:
return monitored_session.ChiefSessionCreator(
scaffold,
master=self.master_target,
config=config or self._session_config,
checkpoint_dir=checkpoint_dir,
checkpoint_filename_with_path=checkpoint_filename_with_path)
else:
return monitored_session.WorkerSessionCreator(
scaffold,
master=self.master_target,
config=config or self._session_config,
max_wait_secs=max_wait_secs)
@property
def has_barrier(self):
"""Whether the barrier is set or not."""
@ -247,21 +290,38 @@ class _WorkerContext(object):
"""Returns number of workers in the cluster, including chief."""
return self._num_workers
@property
def should_checkpoint(self):
"""Whether to save checkpoint."""
return self._strategy.should_checkpoint
@property
def should_save_summary(self):
"""Whether to save summaries."""
return self._strategy.should_save_summary
def _run_single_worker(worker_fn,
strategy,
cluster_spec,
task_type,
task_id,
rpc_layer,
session_config,
rpc_layer="",
worker_barrier=None):
"""Runs a single worker by calling `worker_fn` under context."""
with _WorkerContext(
strategy = copy.deepcopy(strategy)
strategy.configure(session_config, cluster_spec, task_type, task_id)
context = _WorkerContext(
strategy,
cluster_spec,
task_type,
task_id,
session_config=session_config,
rpc_layer=rpc_layer,
worker_barrier=worker_barrier):
worker_fn()
worker_barrier=worker_barrier)
with context:
worker_fn(strategy)
def _run_std_server(cluster_spec=None,
@ -280,13 +340,15 @@ def _run_std_server(cluster_spec=None,
return server
def _run_between_graph_client(worker_fn, cluster_spec, rpc_layer):
def _run_between_graph_client(worker_fn, strategy, cluster_spec, session_config,
rpc_layer):
"""Runs a standalone client for between-graph replication."""
eval_thread = None
if _TaskType.EVALUATOR in cluster_spec.jobs:
eval_thread = threading.Thread(
target=_run_single_worker,
args=(worker_fn, cluster_spec, _TaskType.EVALUATOR, 0),
args=(worker_fn, strategy, cluster_spec, _TaskType.EVALUATOR, 0,
session_config),
kwargs={
"rpc_layer": rpc_layer,
})
@ -298,7 +360,8 @@ def _run_between_graph_client(worker_fn, cluster_spec, rpc_layer):
for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
t = threading.Thread(
target=_run_single_worker,
args=(worker_fn, cluster_spec, task_type, task_id),
args=(worker_fn, strategy, cluster_spec, task_type, task_id,
session_config),
kwargs={
"rpc_layer": rpc_layer,
"worker_barrier": worker_barrier
@ -315,43 +378,53 @@ def _run_between_graph_client(worker_fn, cluster_spec, rpc_layer):
eval_thread.join()
def _run_in_graph_client(worker_fn, cluster_spec, rpc_layer):
def _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config,
rpc_layer):
"""Runs a standalone client for in-graph replication."""
eval_thread = None
if _TaskType.EVALUATOR in cluster_spec.jobs:
eval_thread = threading.Thread(
target=_run_single_worker,
args=(worker_fn, cluster_spec, _TaskType.EVALUATOR, 0),
args=(worker_fn, strategy, cluster_spec, _TaskType.EVALUATOR, 0,
session_config),
kwargs={
"rpc_layer": rpc_layer,
})
eval_thread.start()
_run_single_worker(worker_fn, cluster_spec, None, None, rpc_layer)
_run_single_worker(
worker_fn,
strategy,
cluster_spec,
None,
None,
session_config,
rpc_layer=rpc_layer)
if eval_thread:
eval_thread.join()
# TODO(yuefengz): propagate cluster_spec in the SPLIT_CLIENT mode.
# TODO(yuefengz): propagate cluster_spec in the STANDALONE_CLIENT mode.
# TODO(yuefengz): we may need a smart way to figure out whether the current task
# is the special task when we support cluster_spec propagation.
def run_distribute_coordinator(worker_fn,
mode=CoordinatorMode.SPLIT_CLIENT,
strategy,
mode=CoordinatorMode.STANDALONE_CLIENT,
cluster_spec=None,
task_type=None,
task_id=None,
between_graph=False,
session_config=None,
rpc_layer="grpc"):
"""Runs the coordinator for distributed TensorFlow.
This function runs a split coordinator for distributed TensorFlow in its
default mode, i.e the SPLIT_CLIENT mode. Given a `cluster_spec` specifying
server addresses and their roles in a cluster, this coordinator will figure
out how to set them up, give the underlying function the right targets for
master sessions via a scope object and coordinate their training. The cluster
consisting of standard servers needs to be brought up either with the standard
server binary or with a binary running distribute coordinator with `task_type`
set to non-client type which will then turn into standard servers.
default mode, i.e. the STANDALONE_CLIENT mode. Given a `cluster_spec`
specifying server addresses and their roles in a cluster, this coordinator
will figure out how to set them up, give the underlying function the right
targets for master sessions via a scope object and coordinate their training.
The cluster consisting of standard servers needs to be brought up either with
the standard server binary or with a binary running distribute coordinator
with `task_type` set to non-client type which will then turn into standard
servers.
In addition to being the distribute coordinator, this is also the source of
configurations for each job in the distributed training. As there are multiple
@ -370,6 +443,14 @@ def run_distribute_coordinator(worker_fn,
`worker_fn` depending on whether it is between-graph training or in-graph
replicated training.
The `strategy` object is expected to be a DistributionStrategy object which
has implemented the methods needed by the distribute coordinator such as
`configure(session_config, cluster_spec, task_type, task_id)` which configures
the strategy object for a specific task and `should_init` property which
instructs the distribute coordinator whether to run init ops for a task. The
distribute coordinator will make a copy of the `strategy` object, call its
`configure` method and pass it to `worker_fn` as an argument.
The `worker_fn` defines the training logic and is called under its own
worker context which can be accessed via `get_current_worker_context`. A
worker context provides access to configurations for each task, e.g. the
@ -413,16 +494,20 @@ def run_distribute_coordinator(worker_fn,
evaluation.
Args:
worker_fn: the function to be called and given the access to a coordinator
context object.
worker_fn: the function to be called. The function should accept a
`strategy` object and will be given access to a context object via a
context manager scope.
strategy: a DistributionStrategy object specifying whether it should
run between-graph replicated training or not, whether to run init ops,
etc. This object will also be configured given `session_config`,
`cluster_spec`, `task_type` and `task_id`.
mode: in which mode this distribute coordinator runs.
cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles
in a cluster. If not set or empty, fall back to local training.
task_type: the current task type, optional if this is a client.
task_id: the current task id, optional if this is a client.
between_graph: a boolean. It is only useful when `cluster_spec` is set and
not empty. If true, it will use between-graph replicated training;
otherwise it will use in-graph replicated training.
session_config: an optional @{tf.ConfigProto} object which will be passed
to `strategy`'s `configure` method and used to create a session.
rpc_layer: optional string, the protocol for RPC, e.g. "grpc".
Raises:
@ -448,15 +533,18 @@ def run_distribute_coordinator(worker_fn,
if not cluster_spec:
# `mode` is ignored in the local case.
_run_single_worker(worker_fn, None, None, None, rpc_layer)
elif mode == CoordinatorMode.SPLIT_CLIENT:
_run_single_worker(worker_fn, strategy, None, None, None, session_config,
rpc_layer)
elif mode == CoordinatorMode.STANDALONE_CLIENT:
# The client must know the cluster but servers in the cluster don't have to
# know the client.
if task_type in [_TaskType.CLIENT, None]:
if between_graph:
_run_between_graph_client(worker_fn, cluster_spec, rpc_layer)
if strategy.between_graph:
_run_between_graph_client(worker_fn, strategy, cluster_spec,
session_config, rpc_layer)
else:
_run_in_graph_client(worker_fn, cluster_spec, rpc_layer)
_run_in_graph_client(worker_fn, strategy, cluster_spec, session_config,
rpc_layer)
else:
# If not a client job, run the standard server.
server = _run_std_server(
@ -471,19 +559,21 @@ def run_distribute_coordinator(worker_fn,
cluster_spec=cluster_spec, task_type=task_type, task_id=task_id)
if task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
if between_graph:
if strategy.between_graph:
# All jobs run `worker_fn` if between-graph.
_run_single_worker(worker_fn, cluster_spec, task_type, task_id,
rpc_layer)
_run_single_worker(worker_fn, strategy, cluster_spec, task_type,
task_id, session_config, rpc_layer)
else:
# Only one node runs `worker_fn` if in-graph.
context = _WorkerContext(cluster_spec, task_type, task_id, rpc_layer)
context = _WorkerContext(strategy, cluster_spec, task_type, task_id)
if context.is_chief:
_run_single_worker(worker_fn, cluster_spec, None, None, rpc_layer)
_run_single_worker(worker_fn, strategy, cluster_spec, None, None,
session_config, rpc_layer)
else:
server.join()
elif task_type == _TaskType.EVALUATOR:
_run_single_worker(worker_fn, cluster_spec, task_type, task_id, rpc_layer)
_run_single_worker(worker_fn, strategy, cluster_spec, task_type, task_id,
session_config, rpc_layer)
else:
if task_type != _TaskType.PS:
raise ValueError("Unexpected task_type: %r" % task_type)

View File

@ -0,0 +1,31 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The context retrieval method for distribute coordinator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
_worker_context = threading.local()
def get_current_worker_context():
"""Returns the current task context."""
try:
return _worker_context.current
except AttributeError:
return None
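
Because the context is stored in a `threading.local`, each worker thread started by the coordinator sees only its own context, and the getter returns None anywhere else. A minimal check, for illustration:

# Returns None when not running inside a coordinator-managed worker_fn.
from tensorflow.python.distribute import distribute_coordinator_context

context = distribute_coordinator_context.get_current_worker_context()
if context is None:
  print("not running under a distribute coordinator")
else:
  print("master target:", context.master_target)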

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for distribute coordinator."""
"""Tests for Distribute Coordinator."""
from __future__ import absolute_import
from __future__ import division
@ -37,6 +37,7 @@ except ImportError as _error:
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator
from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
@ -44,17 +45,17 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import monitored_session
CHIEF = distribute_coordinator._TaskType.CHIEF
WORKER = distribute_coordinator._TaskType.WORKER
PS = distribute_coordinator._TaskType.PS
EVALUATOR = distribute_coordinator._TaskType.EVALUATOR
SPLIT_CLIENT = distribute_coordinator.CoordinatorMode.SPLIT_CLIENT
STANDALONE_CLIENT = distribute_coordinator.CoordinatorMode.STANDALONE_CLIENT
INDEPENDENT_WORKER = distribute_coordinator.CoordinatorMode.INDEPENDENT_WORKER
RUN_STD_SERVER_METHOD = "tensorflow.python.distribute.distribute_coordinator._run_std_server"
NUM_WORKERS = 3
NUM_PS = 2
@ -74,6 +75,57 @@ def _strip_protocol(target):
return target
class MockStrategy(object):
def __init__(self,
between_graph=False,
should_init=None,
should_checkpoint=None,
should_save_summary=None):
self._between_graph = between_graph
self._should_init = should_init
self._should_checkpoint = should_checkpoint
self._should_save_summary = should_save_summary
@property
def between_graph(self):
return self._between_graph
def configure(self,
session_options=None,
cluster_spec=None,
task_type=None,
task_id=None):
del session_options, cluster_spec, task_type
if self._should_init is None:
if task_id == 0:
self._should_init = True
else:
self._should_init = False
if self._should_checkpoint is None:
if task_id == 0:
self._should_checkpoint = True
else:
self._should_checkpoint = False
if self._should_save_summary is None:
if task_id == 0:
self._should_save_summary = True
else:
self._should_save_summary = False
@property
def should_init(self):
return self._should_init
@property
def should_checkpoint(self):
return self._should_checkpoint
@property
def should_save_summary(self):
return self._should_save_summary
class MockServer(object):
def __init__(self):
@ -108,6 +160,7 @@ class DistributeCoordinatorTestBase(test.TestCase):
self._result_correct = 0
self._lock = threading.Lock()
self._worker_context = {}
self._strategy_property = {}
self._std_servers = {}
self._barrier = distribute_coordinator._Barrier(NUM_WORKERS)
@ -142,8 +195,8 @@ class DistributeCoordinatorTestBase(test.TestCase):
cluster_spec[EVALUATOR] = ["localhost:%s" % portpicker.pick_unused_port()]
return cluster_spec
def _in_graph_worker_fn(self):
context = distribute_coordinator.get_current_worker_context()
def _in_graph_worker_fn(self, strategy):
context = distribute_coordinator_context.get_current_worker_context()
self.assertTrue(context is not None)
with self._test_session(target=context.master_target) as sess:
xs = []
@ -164,22 +217,23 @@ class DistributeCoordinatorTestBase(test.TestCase):
if result_value == expected:
self._result_correct += 1
def _run_coordinator_in_thread(self, worker_fn, **kwargs):
def _run_coordinator_in_thread(self, worker_fn, strategy, **kwargs):
t = threading.Thread(
target=distribute_coordinator.run_distribute_coordinator,
args=(worker_fn,),
args=(worker_fn, strategy),
kwargs=kwargs)
t.start()
return t
def _run_multiple_coordinator_in_threads(self, worker_fn, cluster_spec,
**kwargs):
def _run_multiple_coordinator_in_threads(self, worker_fn, strategy,
cluster_spec, **kwargs):
threads = {}
for task_type in cluster_spec.keys():
threads[task_type] = []
for task_id in range(len(cluster_spec[task_type])):
t = self._run_coordinator_in_thread(
worker_fn,
strategy,
cluster_spec=cluster_spec,
task_type=task_type,
task_id=task_id,
@ -187,8 +241,8 @@ class DistributeCoordinatorTestBase(test.TestCase):
threads[task_type].append(t)
return threads
def _between_graph_worker_fn(self):
context = distribute_coordinator.get_current_worker_context()
def _between_graph_worker_fn(self, strategy):
context = distribute_coordinator_context.get_current_worker_context()
self.assertTrue(context is not None)
with self._test_session(target=context.master_target) as sess:
with ops.device("/job:ps/task:0"):
@ -234,14 +288,50 @@ class DistributeCoordinatorTestBase(test.TestCase):
with self._lock:
self._result_correct += 1
def _dump_worker_context(self):
def _between_graph_with_monitored_session(self, strategy):
context = distribute_coordinator_context.get_current_worker_context()
self.assertTrue(context is not None)
with ops.device("/job:ps/task:0"):
# TODO(yuefengz): investigate why not using resource variable will make
# the test flaky.
x = variable_scope.get_variable("x", initializer=10.0, use_resource=True)
with ops.device("/job:ps/task:1"):
y = variable_scope.get_variable("y", initializer=20.0, use_resource=True)
x_add = x.assign_add(2.0)
y_sub = y.assign_sub(2.0)
train_op = control_flow_ops.group([x_add, y_sub])
# The monitored session will run init or ready ops.
with monitored_session.MonitoredSession() as sess:
sess.run(train_op)
# Synchronize workers after one step to make sure they all have finished
# training.
if context.has_barrier:
context.wait_for_other_workers()
else:
self._barrier.wait()
x_val, y_val = sess.run([x, y])
self.assertEqual(x_val, 16.0)
self.assertEqual(y_val, 14.0)
if x_val == 16.0 and y_val == 14.0:
with self._lock:
self._result_correct += 1
def _dump_worker_context(self, strategy):
"""Dumps the propoerties of each worker context.
It dumps the context properties to a dict mapping from task_type to a list
of tuples of master_target, num_workers, is_chief and distributed_mode, where
the list is indexed by the task_id.
Args:
strategy: a `DistributionStrategy` object.
"""
context = distribute_coordinator.get_current_worker_context()
context = distribute_coordinator_context.get_current_worker_context()
self.assertTrue(context is not None)
task_type = str(context.task_type)
task_id = context.task_id or 0
@ -255,6 +345,25 @@ class DistributeCoordinatorTestBase(test.TestCase):
context.is_chief,
context.distributed_mode)
def _dump_strategy_property(self, strategy):
context = distribute_coordinator_context.get_current_worker_context()
self.assertTrue(context is not None)
self.assertEqual(context._strategy.should_init, strategy.should_init)
self.assertEqual(context.should_checkpoint, strategy.should_checkpoint)
self.assertEqual(context.should_save_summary, strategy.should_save_summary)
task_type = str(context.task_type)
task_id = context.task_id or 0
with self._lock:
if task_type not in self._strategy_property:
self._strategy_property[task_type] = []
while len(self._strategy_property[task_type]) <= task_id:
self._strategy_property[task_type].append(None)
self._strategy_property[task_type][task_id] = (
context._strategy.should_init, context.should_checkpoint,
context.should_save_summary)
def _run_mock_std_server(self,
session_config=None,
cluster_spec=None,
@ -274,22 +383,32 @@ class DistributeCoordinatorTestBase(test.TestCase):
return server
class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
class DistributeCoordinatorTestStandaloneMode(DistributeCoordinatorTestBase):
def testInGraphSplitMode(self):
"""Test it runs in-graph replication in split client mode."""
def testInGraphStandaloneMode(self):
"""Test it runs in-graph replication in standalone client mode."""
distribute_coordinator.run_distribute_coordinator(
self._in_graph_worker_fn,
cluster_spec=self._cluster_spec,
between_graph=False)
MockStrategy(between_graph=False),
cluster_spec=self._cluster_spec)
self.assertEqual(self._result_correct, 1)
def testBetweenGraph(self):
"""Test it runs between-graph replication in split client mode."""
"""Test it runs between-graph replication in standalone client mode."""
distribute_coordinator.run_distribute_coordinator(
self._between_graph_worker_fn,
cluster_spec=self._cluster_spec,
between_graph=True)
MockStrategy(between_graph=True),
cluster_spec=self._cluster_spec)
# Each finished worker will increment self._result_correct.
self.assertEqual(self._result_correct, NUM_WORKERS)
def testBetweenGraphWithMonitoredSession(self):
"""Test monitored session in standalone client mode."""
distribute_coordinator.run_distribute_coordinator(
self._between_graph_with_monitored_session,
MockStrategy(between_graph=True),
cluster_spec=self._cluster_spec)
# Each finished worker will increment self._result_correct.
self.assertEqual(self._result_correct, NUM_WORKERS)
@ -298,8 +417,8 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
# Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
self._dump_worker_context,
cluster_spec=self._cluster_spec,
between_graph=True)
MockStrategy(between_graph=True),
cluster_spec=self._cluster_spec)
# There is only one type of task and there are three such tasks.
self.assertEqual(len(self._worker_context), 1)
@ -318,12 +437,30 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
self._worker_context[WORKER][2],
(_bytes_to_str(self._workers[2].target), NUM_WORKERS, False, True))
def testBetweenGraphStrategyProperties(self):
# Dumps properties of the strategy objects.
distribute_coordinator.run_distribute_coordinator(
self._dump_strategy_property,
MockStrategy(between_graph=True, should_init=True),
cluster_spec=self._cluster_spec)
# There is only one type of task and there are three such tasks.
self.assertEqual(len(self._strategy_property), 1)
self.assertTrue(WORKER in self._strategy_property)
self.assertEqual(len(self._strategy_property[WORKER]), NUM_WORKERS)
# Check whether each task has the right properties of should_init,
# should_checkpoint and should_save_summary.
self.assertEqual(self._strategy_property[WORKER][0], (True, True, True))
self.assertEqual(self._strategy_property[WORKER][1], (True, False, False))
self.assertEqual(self._strategy_property[WORKER][2], (True, False, False))
def testInGraphContext(self):
# Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
self._dump_worker_context,
cluster_spec=self._cluster_spec,
between_graph=False)
MockStrategy(between_graph=False),
cluster_spec=self._cluster_spec)
# There is only a "None" task in the dumped task context.
self.assertEqual(len(self._worker_context), 1)
@ -339,7 +476,9 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
def testLocalContext(self):
# Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
self._dump_worker_context, cluster_spec=None, between_graph=True)
self._dump_worker_context,
MockStrategy(between_graph=False),
cluster_spec=None)
# There is only a "None" task.
self.assertEqual(len(self._worker_context), 1)
@ -348,7 +487,7 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
# Check whether each task has the right master_target, num_workers, is_chief
# and distributed_mode.
self.assertEqual(self._worker_context["None"][0], ("local", 0, True, False))
self.assertEqual(self._worker_context["None"][0], ("", 0, True, False))
def testBetweenGraphContextWithChief(self):
# Adds a chief node, so there are NUM_WORKERS + 1 workers in total.
@ -358,8 +497,8 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
# Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
self._dump_worker_context,
MockStrategy(between_graph=True),
cluster_spec=cluster_spec,
between_graph=True,
rpc_layer="grpc")
# There are one CHIEF and three workers.
@ -391,8 +530,8 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
# Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
self._dump_worker_context,
MockStrategy(between_graph=False),
cluster_spec=cluster_spec,
between_graph=False,
rpc_layer=None)
# There are one "None" task and one EVALUATOR task.
@ -417,8 +556,8 @@ class DistributeCoordinatorTestInpendentWorkerMode(
cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
threads = self._run_multiple_coordinator_in_threads(
self._in_graph_worker_fn,
MockStrategy(between_graph=False),
cluster_spec,
between_graph=False,
mode=INDEPENDENT_WORKER)
threads[WORKER][0].join()
self.assertEqual(self._result_correct, 1)
@ -428,8 +567,22 @@ class DistributeCoordinatorTestInpendentWorkerMode(
num_workers=NUM_WORKERS, num_ps=NUM_PS)
threads = self._run_multiple_coordinator_in_threads(
self._between_graph_worker_fn,
MockStrategy(between_graph=True),
cluster_spec,
mode=INDEPENDENT_WORKER)
for task_id in range(NUM_WORKERS):
threads[WORKER][task_id].join()
# Each finished worker will increment self._result_correct.
self.assertEqual(self._result_correct, NUM_WORKERS)
def testBetweenGraphWithMonitoredSession(self):
cluster_spec = self._create_cluster_spec(
num_workers=NUM_WORKERS, num_ps=NUM_PS)
threads = self._run_multiple_coordinator_in_threads(
self._between_graph_with_monitored_session,
MockStrategy(between_graph=True),
cluster_spec,
between_graph=True,
mode=INDEPENDENT_WORKER)
for task_id in range(NUM_WORKERS):
threads[WORKER][task_id].join()
@ -444,9 +597,9 @@ class DistributeCoordinatorTestInpendentWorkerMode(
self._run_mock_std_server):
threads = self._run_multiple_coordinator_in_threads(
self._dump_worker_context,
MockStrategy(between_graph=True),
cluster_spec,
mode=INDEPENDENT_WORKER,
between_graph=True,
rpc_layer=None)
for task_id in range(NUM_WORKERS):
threads[WORKER][task_id].join()
@ -476,6 +629,31 @@ class DistributeCoordinatorTestInpendentWorkerMode(
self.assertFalse(self._std_servers[WORKER][1].joined)
self.assertFalse(self._std_servers[WORKER][2].joined)
def testBetweenGraphStrategyProperties(self):
cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
# Dumps properties of the strategy objects.
with test.mock.patch.object(distribute_coordinator, "_run_std_server",
self._run_mock_std_server):
threads = self._run_multiple_coordinator_in_threads(
self._dump_strategy_property,
MockStrategy(between_graph=True, should_init=True),
cluster_spec,
mode=INDEPENDENT_WORKER,
rpc_layer=None)
for task_id in range(NUM_WORKERS):
threads[WORKER][task_id].join()
# There is only one type of task and there are three such tasks.
self.assertEqual(len(self._strategy_property), 1)
self.assertTrue(WORKER in self._strategy_property)
self.assertEqual(len(self._strategy_property[WORKER]), NUM_WORKERS)
# Check whether each task has the right properties of should_init,
# should_checkpoint and should_save_summary.
self.assertEqual(self._strategy_property[WORKER][0], (True, True, True))
self.assertEqual(self._strategy_property[WORKER][1], (True, False, False))
self.assertEqual(self._strategy_property[WORKER][2], (True, False, False))
def testInGraphContext(self):
cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
# Dumps the task contexts and std server arguments.
@ -483,9 +661,9 @@ class DistributeCoordinatorTestInpendentWorkerMode(
self._run_mock_std_server):
threads = self._run_multiple_coordinator_in_threads(
self._dump_worker_context,
MockStrategy(between_graph=False),
cluster_spec,
mode=INDEPENDENT_WORKER,
between_graph=False,
rpc_layer=None)
for task_id in range(NUM_WORKERS):
threads[WORKER][task_id].join()
@ -519,9 +697,9 @@ class DistributeCoordinatorTestInpendentWorkerMode(
self._run_mock_std_server):
threads = self._run_multiple_coordinator_in_threads(
self._dump_worker_context,
MockStrategy(between_graph=False),
cluster_spec,
mode=INDEPENDENT_WORKER,
between_graph=False,
rpc_layer=None)
for task_id in range(NUM_WORKERS):
threads[WORKER][task_id].join()

View File

@ -25,6 +25,7 @@ import sys
import six
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@ -284,6 +285,63 @@ class Scaffold(object):
resources.initialize_resources(resources.local_resources()))
def _create_monitored_session_with_worker_context(worker_context, # pylint: disable=missing-docstring
scaffold,
checkpoint_dir=None,
hooks=None,
chief_only_hooks=None,
save_checkpoint_secs=None,
save_summaries_steps=None,
save_summaries_secs=None,
config=None,
stop_grace_period_secs=120,
log_step_count_steps=100,
max_wait_secs=7200,
save_checkpoint_steps=None,
summary_dir=None):
all_hooks = []
if hooks:
all_hooks.extend(hooks)
if chief_only_hooks and worker_context.is_chief:
all_hooks.extend(chief_only_hooks)
summary_dir = summary_dir or checkpoint_dir
if summary_dir and worker_context.should_save_summary:
if log_step_count_steps and log_step_count_steps > 0:
all_hooks.append(
basic_session_run_hooks.StepCounterHook(
output_dir=summary_dir, every_n_steps=log_step_count_steps))
if (save_summaries_steps and save_summaries_steps > 0) or (
save_summaries_secs and save_summaries_secs > 0):
all_hooks.append(
basic_session_run_hooks.SummarySaverHook(
scaffold=scaffold,
save_steps=save_summaries_steps,
save_secs=save_summaries_secs,
output_dir=summary_dir))
if checkpoint_dir and worker_context.should_checkpoint:
if (save_checkpoint_secs and save_checkpoint_secs > 0) or (
save_checkpoint_steps and save_checkpoint_steps > 0):
all_hooks.append(
basic_session_run_hooks.CheckpointSaverHook(
checkpoint_dir,
save_steps=save_checkpoint_steps,
save_secs=save_checkpoint_secs,
scaffold=scaffold))
session_creator = worker_context.session_creator(
scaffold,
config=config,
checkpoint_dir=checkpoint_dir,
max_wait_secs=max_wait_secs)
return MonitoredSession(
session_creator=session_creator,
hooks=all_hooks,
stop_grace_period_secs=stop_grace_period_secs)
@tf_export('train.MonitoredTrainingSession')
def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
is_chief=True,
@ -373,14 +431,35 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
save_checkpoint_steps = None
scaffold = scaffold or Scaffold()
worker_context = distribute_coordinator_context.get_current_worker_context()
if worker_context:
return _create_monitored_session_with_worker_context(
worker_context,
scaffold,
checkpoint_dir=checkpoint_dir,
hooks=hooks,
chief_only_hooks=chief_only_hooks,
save_checkpoint_secs=save_checkpoint_secs,
save_summaries_steps=save_summaries_steps,
save_summaries_secs=save_summaries_secs,
config=config,
stop_grace_period_secs=stop_grace_period_secs,
log_step_count_steps=log_step_count_steps,
max_wait_secs=max_wait_secs,
save_checkpoint_steps=save_checkpoint_steps,
summary_dir=summary_dir)
if not is_chief:
session_creator = WorkerSessionCreator(
scaffold=scaffold,
master=master,
config=config,
max_wait_secs=max_wait_secs)
return MonitoredSession(session_creator=session_creator, hooks=hooks or [],
stop_grace_period_secs=stop_grace_period_secs)
return MonitoredSession(
session_creator=session_creator,
hooks=hooks or [],
stop_grace_period_secs=stop_grace_period_secs)
all_hooks = []
if chief_only_hooks:
@ -400,25 +479,29 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
if (save_summaries_steps and save_summaries_steps > 0) or (
save_summaries_secs and save_summaries_secs > 0):
all_hooks.append(basic_session_run_hooks.SummarySaverHook(
scaffold=scaffold,
save_steps=save_summaries_steps,
save_secs=save_summaries_secs,
output_dir=summary_dir))
all_hooks.append(
basic_session_run_hooks.SummarySaverHook(
scaffold=scaffold,
save_steps=save_summaries_steps,
save_secs=save_summaries_secs,
output_dir=summary_dir))
if checkpoint_dir:
if (save_checkpoint_secs and save_checkpoint_secs > 0) or (
save_checkpoint_steps and save_checkpoint_steps > 0):
all_hooks.append(basic_session_run_hooks.CheckpointSaverHook(
checkpoint_dir,
save_steps=save_checkpoint_steps,
save_secs=save_checkpoint_secs,
scaffold=scaffold))
all_hooks.append(
basic_session_run_hooks.CheckpointSaverHook(
checkpoint_dir,
save_steps=save_checkpoint_steps,
save_secs=save_checkpoint_secs,
scaffold=scaffold))
if hooks:
all_hooks.extend(hooks)
return MonitoredSession(session_creator=session_creator, hooks=all_hooks,
stop_grace_period_secs=stop_grace_period_secs)
return MonitoredSession(
session_creator=session_creator,
hooks=all_hooks,
stop_grace_period_secs=stop_grace_period_secs)
@tf_export('train.SessionCreator')
@ -546,6 +629,11 @@ class _MonitoredSession(object):
self._hooks = hooks or []
for h in self._hooks:
h.begin()
worker_context = distribute_coordinator_context.get_current_worker_context()
if not session_creator and worker_context:
session_creator = worker_context.session_creator()
# Create the session.
self._coordinated_creator = self._CoordinatedSessionCreator(
session_creator=session_creator or ChiefSessionCreator(),
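
Taken together, the monitored-session changes mean that inside a coordinator-managed `worker_fn`, `MonitoredTrainingSession` no longer needs `master` or `is_chief`: the worker context supplies the session creator and gates the checkpoint and summary hooks via the strategy's `should_checkpoint` and `should_save_summary` properties. A hedged sketch (the global-step graph and checkpoint path are placeholders):

# Placeholder training loop; when a worker context is active, the hooks and
# session creator come from the context instead of the master/is_chief args.
from tensorflow.python.framework import ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import monitored_session
from tensorflow.python.training import training_util


def worker_fn(strategy):
  with ops.Graph().as_default():
    gstep = training_util.get_or_create_global_step()
    new_gstep = state_ops.assign_add(gstep, 1)
    with monitored_session.MonitoredTrainingSession(
        checkpoint_dir="/tmp/example_ckpts",  # hypothetical directory
        save_checkpoint_steps=100) as sess:
      for _ in range(100):
        sess.run(new_gstep)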

View File

@ -32,6 +32,7 @@ from tensorflow.contrib.testing.python.framework import util_test
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import debug_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.distribute import distribute_coordinator
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
@ -381,6 +382,119 @@ class MonitoredTrainingSessionTest(test.TestCase):
self.assertEqual(0, session.run(gstep))
class MockStrategy(object):
def __init__(self,
between_graph=False,
should_init=True,
should_checkpoint=None,
should_save_summary=None):
self._between_graph = between_graph
self._should_init = should_init
self._should_checkpoint = should_checkpoint
self._should_save_summary = should_save_summary
@property
def between_graph(self):
return self._between_graph
@property
def should_init(self):
return self._should_init
@property
def should_checkpoint(self):
return self._should_checkpoint
@property
def should_save_summary(self):
return self._should_save_summary
class MonitoredTrainingSessionWithDistributeCoordinatorTest(test.TestCase):
"""Test distribute coordinator controls summary saving and checkpointing."""
def test_summary_hook_enabled(self):
context = distribute_coordinator._WorkerContext(
MockStrategy(should_save_summary=True), None, None, None)
logdir = _test_dir(self.get_temp_dir(), 'test_summaries_enabled')
with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step()
new_gstep = state_ops.assign_add(gstep, 1)
summary.scalar('my_summary_tag', new_gstep * 2)
with context, monitored_session.MonitoredTrainingSession(
checkpoint_dir=logdir,
save_summaries_steps=100,
log_step_count_steps=10) as session:
for _ in range(101):
session.run(new_gstep)
summaries = util_test.latest_summaries(logdir)
tags = [s.summary.value[0].tag for s in summaries]
self.assertIn('my_summary_tag', tags)
self.assertIn('global_step/sec', tags)
def test_summary_hook_disabled(self):
context = distribute_coordinator._WorkerContext(
MockStrategy(should_save_summary=False), None, None, None)
logdir = _test_dir(self.get_temp_dir(), 'test_summaries_disabled')
with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step()
new_gstep = state_ops.assign_add(gstep, 1)
summary.scalar('my_summary_tag', new_gstep * 2)
with context, monitored_session.MonitoredTrainingSession(
checkpoint_dir=logdir,
save_summaries_steps=100,
log_step_count_steps=10) as session:
for _ in range(101):
session.run(new_gstep)
# No summary is saved.
summaries = util_test.latest_summaries(logdir)
self.assertEqual(len(summaries), 0)
def test_checkpoint_hook_enabled(self):
context = distribute_coordinator._WorkerContext(
MockStrategy(should_checkpoint=True), None, None, None)
logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_enabled')
with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step()
new_gstep = state_ops.assign_add(gstep, 1)
with context, monitored_session.MonitoredTrainingSession(
checkpoint_dir=logdir,
save_checkpoint_steps=100,
log_step_count_steps=10) as session:
for _ in range(100):
session.run(new_gstep)
# A restart will find the checkpoint and recover automatically.
with monitored_session.MonitoredTrainingSession(
is_chief=True, checkpoint_dir=logdir) as session:
self.assertEqual(100, session.run(gstep))
def test_checkpoint_hook_disabled(self):
context = distribute_coordinator._WorkerContext(
MockStrategy(should_checkpoint=False), None, None, None)
logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_disabled')
with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step()
new_gstep = state_ops.assign_add(gstep, 1)
with context, monitored_session.MonitoredTrainingSession(
checkpoint_dir=logdir,
save_checkpoint_steps=100,
log_step_count_steps=10) as session:
for _ in range(100):
session.run(new_gstep)
# No checkpoint is saved.
checkpoint = checkpoint_management.latest_checkpoint(logdir)
self.assertIsNone(checkpoint)
class StopAtNSession(monitored_session._WrappedSession):
"""A wrapped session that stops at the N-th call to _check_stop."""
@ -1365,8 +1479,8 @@ class MonitoredSessionTest(test.TestCase):
with monitored_session.MonitoredSession(
session_creator=monitored_session.ChiefSessionCreator(
scaffold,
checkpoint_filename_with_path=
checkpoint_management.latest_checkpoint(logdir))) as session:
checkpoint_filename_with_path=checkpoint_management.
latest_checkpoint(logdir))) as session:
self.assertEqual(2, session.run(gstep))
def test_retry_initialization_on_aborted_error(self):