Use the distribution strategy to configure the distribute coordinator.

Add session_creator and a couple of properties to the worker context, which are then used to configure monitored sessions.

PiperOrigin-RevId: 209026599
This commit is contained in:
parent 5360d73687
commit 1326f33515
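Roughly, the intended usage after this change is: a distribution strategy object is passed to run_distribute_coordinator, which copies it, calls its configure(session_config, cluster_spec, task_type, task_id) method for the current task, and then calls worker_fn(strategy) inside a worker context; MonitoredTrainingSession (and MonitoredSession) pick that context up and build their session creator, checkpoint hook and summary hook from it. The sketch below is illustrative only and not part of the commit; SomeDistributionStrategy is a hypothetical stand-in for any strategy object that implements configure, between_graph, should_init, should_checkpoint and should_save_summary as described in the diff.

# Illustrative sketch (not part of this commit).
from tensorflow.python.distribute import distribute_coordinator
from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.training import monitored_session


def worker_fn(strategy):
  # Runs under a _WorkerContext set up by the coordinator, so the worker
  # context (master target, chief-ness, checkpoint/summary switches) is
  # available here.
  context = distribute_coordinator_context.get_current_worker_context()
  assert context is not None
  # MonitoredTrainingSession detects the worker context and derives its
  # session creator, CheckpointSaverHook and SummarySaverHook from the
  # strategy's should_init/should_checkpoint/should_save_summary properties.
  with monitored_session.MonitoredTrainingSession(
      checkpoint_dir="/tmp/ckpt") as sess:
    pass  # training steps would call sess.run(...) here


strategy = SomeDistributionStrategy()  # hypothetical strategy implementation
distribute_coordinator.run_distribute_coordinator(
    worker_fn,
    strategy,
    mode=distribute_coordinator.CoordinatorMode.STANDALONE_CLIENT,
    cluster_spec={"worker": ["localhost:12345", "localhost:23456"],
                  "ps": ["localhost:34567"]})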
@@ -3265,6 +3265,7 @@ py_library(
        "@six_archive//:six",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/distribute:distribute_coordinator_context",
        "//tensorflow/python/eager:backprop",
        "//tensorflow/python/eager:context",
        # `layers` dependency only exists due to the use of a small utility.
@@ -4658,7 +4659,10 @@ py_test(
    size = "medium",
    srcs = ["training/monitored_session_test.py"],
    srcs_version = "PY2AND3",
-    tags = ["notsan"],  # b/67945581
+    tags = [
+        "no_pip",
+        "notsan",  # b/67945581
+    ],
    deps = [
        ":array_ops",
        ":checkpoint_management",
@@ -4676,6 +4680,7 @@ py_test(
        "//tensorflow/contrib/framework:framework_py",
        "//tensorflow/contrib/testing:testing_py",
        "//tensorflow/core:protos_all_py",
+        "//tensorflow/python/distribute:distribute_coordinator",
    ],
)

@@ -41,3 +41,12 @@ py_test(
        "//tensorflow/python:variables",
    ],
)
+
+py_library(
+    name = "distribute_coordinator_context",
+    srcs = [
+        "distribute_coordinator_context.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [],
+)

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""A unified and split coordinator for distributed TensorFlow."""
+"""A component for running distributed TensorFlow."""

from __future__ import absolute_import
from __future__ import division
@@ -24,6 +24,8 @@ import os
import threading

from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.python.distribute import distribute_coordinator_context
+from tensorflow.python.training import monitored_session
from tensorflow.python.training import server_lib


@@ -43,23 +45,12 @@ class CoordinatorMode(object):
  # client and connects to remote servers for training. Each remote server can
  # use the distribute coordinator binary with task_type set correctly which
  # will then turn into standard servers.
-  SPLIT_CLIENT = 0
+  STANDALONE_CLIENT = "standalone_client"

  # The distribute coordinator runs on each worker. It will run a standard
  # server on each worker and optionally run the `worker_fn` that is configured
  # to talk to its standard server.
-  INDEPENDENT_WORKER = 1
-
-
-_worker_context = threading.local()
-
-
-def get_current_worker_context():
-  """Returns the current task context."""
-  try:
-    return _worker_context.current
-  except AttributeError:
-    return None
+  INDEPENDENT_WORKER = "independent_worker"


class _Barrier(object):
@@ -113,14 +104,17 @@ class _WorkerContext(object):
  """

  def __init__(self,
+               strategy,
               cluster_spec,
               task_type,
               task_id,
+               session_config=None,
               rpc_layer="grpc",
               worker_barrier=None):
    """Initialize the worker context object.

    Args:
+      strategy: a `DistributionStrategy` object.
      cluster_spec: a ClusterSpec object. It can be empty or None in the local
        training case.
      task_type: a string indicating the role of the corresponding task, such as
@@ -128,14 +122,17 @@ class _WorkerContext(object):
        replicated training.
      task_id: an integer indicating id of the corresponding task. It can be
        None if it is local training or in-graph replicated training.
+      session_config: an optional @{tf.ConfigProto} object.
      rpc_layer: optional string specifying the RPC protocol for communication
        with worker masters. If None or empty, hosts in the `cluster_spec` will
        be used directly.
      worker_barrier: optional, the barrier object for worker synchronization.
    """
+    self._strategy = strategy
    self._cluster_spec = cluster_spec
    self._task_type = task_type
    self._task_id = task_id
+    self._session_config = session_config
    self._worker_barrier = worker_barrier
    self._rpc_layer = rpc_layer
    self._master_target = self._get_master_target()
@@ -143,26 +140,31 @@ class _WorkerContext(object):
    self._is_chief_node = self._is_chief()

  def _debug_message(self):
-    return "[cluster_spec: %r, task_type: %r, task_id: %r]" % (
-        self._cluster_spec, self.task_type, self.task_id)
+    if self._cluster_spec:
+      return "[cluster_spec: %r, task_type: %r, task_id: %r]" % (
+          self._cluster_spec, self.task_type, self.task_id)
+    else:
+      return "[local]"

  def __enter__(self):
-    old_context = get_current_worker_context()
+    old_context = distribute_coordinator_context.get_current_worker_context()
    if old_context:
      raise ValueError(
          "You cannot run distribute coordinator in a `worker_fn`.\t" +
          self._debug_message())
-    _worker_context.current = self
+    # pylint: disable=protected-access
+    distribute_coordinator_context._worker_context.current = self

  def __exit__(self, unused_exception_type, unused_exception_value,
               unused_traceback):
-    _worker_context.current = None
+    # pylint: disable=protected-access
+    distribute_coordinator_context._worker_context.current = None

  def _get_master_target(self):
    """Return the master target for a task."""
    # If cluster_spec is None or empty, we use local master.
    if not self._cluster_spec:
-      return "local"
+      return ""

    # If task_type is None, then it is in-graph replicated training. In this
    # case we use the chief or first worker's master target.
@@ -207,6 +209,47 @@ class _WorkerContext(object):
          self._debug_message())
    self._worker_barrier.wait()

+  def session_creator(self,
+                      scaffold=None,
+                      config=None,
+                      checkpoint_dir=None,
+                      checkpoint_filename_with_path=None,
+                      max_wait_secs=7200):
+    """Returns a session creator.
+
+    The returned session creator will be configured with the correct master
+    target and session configs. It will also run either init ops or ready ops
+    by querying the `strategy` object when `create_session` is called on it.
+
+    Args:
+      scaffold: A `Scaffold` used for gathering or building supportive ops. If
+        not specified a default one is created. It's used to finalize the graph.
+      config: `ConfigProto` proto used to configure the session.
+      checkpoint_dir: A string. Optional path to a directory where to restore
+        variables.
+      checkpoint_filename_with_path: Full file name path to the checkpoint file.
+        Only one of `checkpoint_dir` or `checkpoint_filename_with_path` can be
+        specified.
+      max_wait_secs: Maximum time to wait for the session to become available.
+
+    Returns:
+      a descendant of SessionCreator.
+    """
+    # TODO(yuefengz): merge session config.
+    if self._strategy.should_init:
+      return monitored_session.ChiefSessionCreator(
+          scaffold,
+          master=self.master_target,
+          config=config or self._session_config,
+          checkpoint_dir=checkpoint_dir,
+          checkpoint_filename_with_path=checkpoint_filename_with_path)
+    else:
+      return monitored_session.WorkerSessionCreator(
+          scaffold,
+          master=self.master_target,
+          config=config or self._session_config,
+          max_wait_secs=max_wait_secs)
+
  @property
  def has_barrier(self):
    """Whether the barrier is set or not."""
@@ -247,21 +290,38 @@ class _WorkerContext(object):
    """Returns number of workers in the cluster, including chief."""
    return self._num_workers

+  @property
+  def should_checkpoint(self):
+    """Whether to save checkpoint."""
+    return self._strategy.should_checkpoint
+
+  @property
+  def should_save_summary(self):
+    """Whether to save summaries."""
+    return self._strategy.should_save_summary
+

def _run_single_worker(worker_fn,
+                       strategy,
                       cluster_spec,
                       task_type,
                       task_id,
-                       rpc_layer,
+                       session_config,
+                       rpc_layer="",
                       worker_barrier=None):
  """Runs a single worker by calling `worker_fn` under context."""
-  with _WorkerContext(
+  strategy = copy.deepcopy(strategy)
+  strategy.configure(session_config, cluster_spec, task_type, task_id)
+  context = _WorkerContext(
+      strategy,
      cluster_spec,
      task_type,
      task_id,
+      session_config=session_config,
      rpc_layer=rpc_layer,
-      worker_barrier=worker_barrier):
-    worker_fn()
+      worker_barrier=worker_barrier)
+  with context:
+    worker_fn(strategy)


def _run_std_server(cluster_spec=None,
@@ -280,13 +340,15 @@ def _run_std_server(cluster_spec=None,
  return server


-def _run_between_graph_client(worker_fn, cluster_spec, rpc_layer):
+def _run_between_graph_client(worker_fn, strategy, cluster_spec, session_config,
+                              rpc_layer):
  """Runs a standalone client for between-graph replication."""
  eval_thread = None
  if _TaskType.EVALUATOR in cluster_spec.jobs:
    eval_thread = threading.Thread(
        target=_run_single_worker,
-        args=(worker_fn, cluster_spec, _TaskType.EVALUATOR, 0),
+        args=(worker_fn, strategy, cluster_spec, _TaskType.EVALUATOR, 0,
+              session_config),
        kwargs={
            "rpc_layer": rpc_layer,
        })
@@ -298,7 +360,8 @@ def _run_between_graph_client(worker_fn, cluster_spec, rpc_layer):
  for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
    t = threading.Thread(
        target=_run_single_worker,
-        args=(worker_fn, cluster_spec, task_type, task_id),
+        args=(worker_fn, strategy, cluster_spec, task_type, task_id,
+              session_config),
        kwargs={
            "rpc_layer": rpc_layer,
            "worker_barrier": worker_barrier
@@ -315,43 +378,53 @@ def _run_between_graph_client(worker_fn, cluster_spec, rpc_layer):
    eval_thread.join()


-def _run_in_graph_client(worker_fn, cluster_spec, rpc_layer):
+def _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config,
+                         rpc_layer):
  """Runs a standalone client for in-graph replication."""
  eval_thread = None
  if _TaskType.EVALUATOR in cluster_spec.jobs:
    eval_thread = threading.Thread(
        target=_run_single_worker,
-        args=(worker_fn, cluster_spec, _TaskType.EVALUATOR, 0),
+        args=(worker_fn, strategy, cluster_spec, _TaskType.EVALUATOR, 0,
+              session_config),
        kwargs={
            "rpc_layer": rpc_layer,
        })
    eval_thread.start()

-  _run_single_worker(worker_fn, cluster_spec, None, None, rpc_layer)
+  _run_single_worker(
+      worker_fn,
+      strategy,
+      cluster_spec,
+      None,
+      None,
+      session_config,
+      rpc_layer=rpc_layer)
  if eval_thread:
    eval_thread.join()


-# TODO(yuefengz): propagate cluster_spec in the SPLIT_CLIENT mode.
+# TODO(yuefengz): propagate cluster_spec in the STANDALONE_CLIENT mode.
# TODO(yuefengz): we may need a smart way to figure out whether the current task
# is the special task when we support cluster_spec propagation.
def run_distribute_coordinator(worker_fn,
-                               mode=CoordinatorMode.SPLIT_CLIENT,
+                               strategy,
+                               mode=CoordinatorMode.STANDALONE_CLIENT,
                               cluster_spec=None,
                               task_type=None,
                               task_id=None,
-                               between_graph=False,
+                               session_config=None,
                               rpc_layer="grpc"):
  """Runs the coordinator for distributed TensorFlow.

  This function runs a split coordinator for distributed TensorFlow in its
-  default mode, i.e the SPLIT_CLIENT mode. Given a `cluster_spec` specifying
-  server addresses and their roles in a cluster, this coordinator will figure
-  out how to set them up, give the underlying function the right targets for
-  master sessions via a scope object and coordinate their training. The cluster
-  consisting of standard servers needs to be brought up either with the standard
-  server binary or with a binary running distribute coordinator with `task_type`
-  set to non-client type which will then turn into standard servers.
+  default mode, i.e the STANDALONE_CLIENT mode. Given a `cluster_spec`
+  specifying server addresses and their roles in a cluster, this coordinator
+  will figure out how to set them up, give the underlying function the right
+  targets for master sessions via a scope object and coordinate their training.
+  The cluster consisting of standard servers needs to be brought up either with
+  the standard server binary or with a binary running distribute coordinator
+  with `task_type` set to non-client type which will then turn into standard
+  servers.

  In addition to be the distribute coordinator, this is also the source of
  configurations for each job in the distributed training. As there are multiple
@@ -370,6 +443,14 @@ def run_distribute_coordinator(worker_fn,
  `worker_fn` depending whether it is between-graph training or in-graph
  replicated training.

+  The `strategy` object is expected to be a DistributionStrategy object which
+  has implemented methods needed by distributed coordinator such as
+  `configure(session_config, cluster_spec, task_type, task_id)` which configures
+  the strategy object for a specific task and `should_init` property which
+  instructs the distribute coordinator whether to run init ops for a task. The
+  distribute coordinator will make a copy of the `strategy` object, call its
+  `configure` method and pass it to `worker_fn` as an argument.
+
  The `worker_fn` defines the training logic and is called under a its own
  worker context which can be accessed to via `get_current_worker_context`. A
  worker context provides access to configurations for each task, e.g. the
@@ -413,16 +494,20 @@ def run_distribute_coordinator(worker_fn,
    evaluation.

  Args:
-    worker_fn: the function to be called and given the access to a coordinator
-      context object.
+    worker_fn: the function to be called. The function should accept a
+      `strategy` object and will be given access to a context object via a
+      context manager scope.
+    strategy: a DistributionStrategy object which specifying whether it should
+      run between-graph replicated training or not, whether to run init ops,
+      etc. This object will also be configured given `session_config`,
+      `cluster_spc`, `task_type` and `task_id`.
    mode: in which mode this distribute coordinator runs.
    cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles
      in a cluster. If not set or empty, fall back to local training.
    task_type: the current task type, optional if this is a client.
    task_id: the current task id, optional if this is a client.
-    between_graph: a boolean. It is only useful when `cluster_spec` is set and
-      not empty. If true, it will use between-graph replicated training;
-      otherwise it will use in-graph replicated training.
+    session_config: an optional @{tf.ConfigProto} object which will be passed
+      to `strategy`'s `configure` method and used to create a session.
    rpc_layer: optional string, the protocol for RPC, e.g. "grpc".

  Raises:
@@ -448,15 +533,18 @@ def run_distribute_coordinator(worker_fn,

  if not cluster_spec:
    # `mode` is ignored in the local case.
-    _run_single_worker(worker_fn, None, None, None, rpc_layer)
-  elif mode == CoordinatorMode.SPLIT_CLIENT:
+    _run_single_worker(worker_fn, strategy, None, None, None, session_config,
+                       rpc_layer)
+  elif mode == CoordinatorMode.STANDALONE_CLIENT:
    # The client must know the cluster but servers in the cluster don't have to
    # know the client.
    if task_type in [_TaskType.CLIENT, None]:
-      if between_graph:
-        _run_between_graph_client(worker_fn, cluster_spec, rpc_layer)
+      if strategy.between_graph:
+        _run_between_graph_client(worker_fn, strategy, cluster_spec,
+                                  session_config, rpc_layer)
      else:
-        _run_in_graph_client(worker_fn, cluster_spec, rpc_layer)
+        _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config,
+                             rpc_layer)
    else:
      # If not a client job, run the standard server.
      server = _run_std_server(
@@ -471,19 +559,21 @@ def run_distribute_coordinator(worker_fn,
        cluster_spec=cluster_spec, task_type=task_type, task_id=task_id)

    if task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
-      if between_graph:
+      if strategy.between_graph:
        # All jobs run `worker_fn` if between-graph.
-        _run_single_worker(worker_fn, cluster_spec, task_type, task_id,
-                           rpc_layer)
+        _run_single_worker(worker_fn, strategy, cluster_spec, task_type,
+                           task_id, session_config, rpc_layer)
      else:
        # Only one node runs `worker_fn` if in-graph.
-        context = _WorkerContext(cluster_spec, task_type, task_id, rpc_layer)
+        context = _WorkerContext(strategy, cluster_spec, task_type, task_id)
        if context.is_chief:
-          _run_single_worker(worker_fn, cluster_spec, None, None, rpc_layer)
+          _run_single_worker(worker_fn, strategy, cluster_spec, None, None,
+                             session_config, rpc_layer)
        else:
          server.join()
    elif task_type == _TaskType.EVALUATOR:
-      _run_single_worker(worker_fn, cluster_spec, task_type, task_id, rpc_layer)
+      _run_single_worker(worker_fn, strategy, cluster_spec, task_type, task_id,
+                         session_config, rpc_layer)
    else:
      if task_type != _TaskType.PS:
        raise ValueError("Unexpected task_type: %r" % task_type)
@@ -0,0 +1,31 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The context retrieval method for distribute coordinator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import threading
+
+_worker_context = threading.local()
+
+
+def get_current_worker_context():
+  """Returns the current task context."""
+  try:
+    return _worker_context.current
+  except AttributeError:
+    return None
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for distribute coordinator."""
+"""Tests for Distribute Coordinator."""

from __future__ import absolute_import
from __future__ import division
@@ -37,6 +37,7 @@ except ImportError as _error:
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator
+from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
@@ -44,17 +45,17 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
+from tensorflow.python.training import monitored_session


CHIEF = distribute_coordinator._TaskType.CHIEF
WORKER = distribute_coordinator._TaskType.WORKER
PS = distribute_coordinator._TaskType.PS
EVALUATOR = distribute_coordinator._TaskType.EVALUATOR

-SPLIT_CLIENT = distribute_coordinator.CoordinatorMode.SPLIT_CLIENT
+STANDALONE_CLIENT = distribute_coordinator.CoordinatorMode.STANDALONE_CLIENT
INDEPENDENT_WORKER = distribute_coordinator.CoordinatorMode.INDEPENDENT_WORKER

RUN_STD_SERVER_METHOD = "tensorflow.python.distribute.distribute_coordinator._run_std_server"

NUM_WORKERS = 3
NUM_PS = 2

@@ -74,6 +75,57 @@ def _strip_protocol(target):
  return target


+class MockStrategy(object):
+
+  def __init__(self,
+               between_graph=False,
+               should_init=None,
+               should_checkpoint=None,
+               should_save_summary=None):
+    self._between_graph = between_graph
+    self._should_init = should_init
+    self._should_checkpoint = should_checkpoint
+    self._should_save_summary = should_save_summary
+
+  @property
+  def between_graph(self):
+    return self._between_graph
+
+  def configure(self,
+                session_options=None,
+                cluster_spec=None,
+                task_type=None,
+                task_id=None):
+    del session_options, cluster_spec, task_type
+    if self._should_init is None:
+      if task_id == 0:
+        self._should_init = True
+      else:
+        self._should_init = False
+    if self._should_checkpoint is None:
+      if task_id == 0:
+        self._should_checkpoint = True
+      else:
+        self._should_checkpoint = False
+    if self._should_save_summary is None:
+      if task_id == 0:
+        self._should_save_summary = True
+      else:
+        self._should_save_summary = False
+
+  @property
+  def should_init(self):
+    return self._should_init
+
+  @property
+  def should_checkpoint(self):
+    return self._should_checkpoint
+
+  @property
+  def should_save_summary(self):
+    return self._should_save_summary
+
+
class MockServer(object):

  def __init__(self):
@@ -108,6 +160,7 @@ class DistributeCoordinatorTestBase(test.TestCase):
    self._result_correct = 0
    self._lock = threading.Lock()
    self._worker_context = {}
+    self._strategy_property = {}
    self._std_servers = {}
    self._barrier = distribute_coordinator._Barrier(NUM_WORKERS)

@@ -142,8 +195,8 @@ class DistributeCoordinatorTestBase(test.TestCase):
    cluster_spec[EVALUATOR] = ["localhost:%s" % portpicker.pick_unused_port()]
    return cluster_spec

-  def _in_graph_worker_fn(self):
-    context = distribute_coordinator.get_current_worker_context()
+  def _in_graph_worker_fn(self, strategy):
+    context = distribute_coordinator_context.get_current_worker_context()
    self.assertTrue(context is not None)
    with self._test_session(target=context.master_target) as sess:
      xs = []
@@ -164,22 +217,23 @@ class DistributeCoordinatorTestBase(test.TestCase):
      if result_value == expected:
        self._result_correct += 1

-  def _run_coordinator_in_thread(self, worker_fn, **kwargs):
+  def _run_coordinator_in_thread(self, worker_fn, strategy, **kwargs):
    t = threading.Thread(
        target=distribute_coordinator.run_distribute_coordinator,
-        args=(worker_fn,),
+        args=(worker_fn, strategy),
        kwargs=kwargs)
    t.start()
    return t

-  def _run_multiple_coordinator_in_threads(self, worker_fn, cluster_spec,
-                                           **kwargs):
+  def _run_multiple_coordinator_in_threads(self, worker_fn, strategy,
+                                           cluster_spec, **kwargs):
    threads = {}
    for task_type in cluster_spec.keys():
      threads[task_type] = []
      for task_id in range(len(cluster_spec[task_type])):
        t = self._run_coordinator_in_thread(
            worker_fn,
+            strategy,
            cluster_spec=cluster_spec,
            task_type=task_type,
            task_id=task_id,
@@ -187,8 +241,8 @@ class DistributeCoordinatorTestBase(test.TestCase):
        threads[task_type].append(t)
    return threads

-  def _between_graph_worker_fn(self):
-    context = distribute_coordinator.get_current_worker_context()
+  def _between_graph_worker_fn(self, strategy):
+    context = distribute_coordinator_context.get_current_worker_context()
    self.assertTrue(context is not None)
    with self._test_session(target=context.master_target) as sess:
      with ops.device("/job:ps/task:0"):
@@ -234,14 +288,50 @@ class DistributeCoordinatorTestBase(test.TestCase):
      with self._lock:
        self._result_correct += 1

-  def _dump_worker_context(self):
+  def _between_graph_with_monitored_session(self, strategy):
+    context = distribute_coordinator_context.get_current_worker_context()
+    self.assertTrue(context is not None)
+    with ops.device("/job:ps/task:0"):
+      # TODO(yuefengz): investigate why not using resource variable will make
+      # the test flaky.
+      x = variable_scope.get_variable("x", initializer=10.0, use_resource=True)
+    with ops.device("/job:ps/task:1"):
+      y = variable_scope.get_variable("y", initializer=20.0, use_resource=True)
+
+    x_add = x.assign_add(2.0)
+    y_sub = y.assign_sub(2.0)
+    train_op = control_flow_ops.group([x_add, y_sub])
+
+    # The monitored session will run init or ready ops.
+    with monitored_session.MonitoredSession() as sess:
+      sess.run(train_op)
+
+      # Synchronize workers after one step to make sure they all have finished
+      # training.
+      if context.has_barrier:
+        context.wait_for_other_workers()
+      else:
+        self._barrier.wait()
+
+      x_val, y_val = sess.run([x, y])
+
+    self.assertEqual(x_val, 16.0)
+    self.assertEqual(y_val, 14.0)
+    if x_val == 16.0 and y_val == 14.0:
+      with self._lock:
+        self._result_correct += 1
+
+  def _dump_worker_context(self, strategy):
    """Dumps the propoerties of each worker context.

    It dumps the context properties to a dict mapping from task_type to a list
    of tuples of master_target, num_workers, is_chief and distribute_mode, where
    the list is indexed by the task_id.

+    Args:
+      strategy: a `DistributionStrategy` object.
    """
-    context = distribute_coordinator.get_current_worker_context()
+    context = distribute_coordinator_context.get_current_worker_context()
    self.assertTrue(context is not None)
    task_type = str(context.task_type)
    task_id = context.task_id or 0
@@ -255,6 +345,25 @@ class DistributeCoordinatorTestBase(test.TestCase):
          context.is_chief,
          context.distributed_mode)

+  def _dump_strategy_property(self, strategy):
+    context = distribute_coordinator_context.get_current_worker_context()
+    self.assertTrue(context is not None)
+
+    self.assertEqual(context._strategy.should_init, strategy.should_init)
+    self.assertEqual(context.should_checkpoint, strategy.should_checkpoint)
+    self.assertEqual(context.should_save_summary, strategy.should_save_summary)
+
+    task_type = str(context.task_type)
+    task_id = context.task_id or 0
+    with self._lock:
+      if task_type not in self._strategy_property:
+        self._strategy_property[task_type] = []
+      while len(self._strategy_property[task_type]) <= task_id:
+        self._strategy_property[task_type].append(None)
+      self._strategy_property[task_type][task_id] = (
+          context._strategy.should_init, context.should_checkpoint,
+          context.should_save_summary)
+
  def _run_mock_std_server(self,
                           session_config=None,
                           cluster_spec=None,
@@ -274,22 +383,32 @@ class DistributeCoordinatorTestBase(test.TestCase):
    return server


-class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
+class DistributeCoordinatorTestStandaloneMode(DistributeCoordinatorTestBase):

-  def testInGraphSplitMode(self):
-    """Test it runs in-graph replication in split client mode."""
+  def testInGraphStandaloneMode(self):
+    """Test it runs in-graph replication in standalone client mode."""
    distribute_coordinator.run_distribute_coordinator(
        self._in_graph_worker_fn,
-        cluster_spec=self._cluster_spec,
-        between_graph=False)
+        MockStrategy(between_graph=False),
+        cluster_spec=self._cluster_spec)
    self.assertEqual(self._result_correct, 1)

  def testBetweenGraph(self):
-    """Test it runs between-graph replication in split client mode."""
+    """Test it runs between-graph replication in standalone client mode."""
    distribute_coordinator.run_distribute_coordinator(
        self._between_graph_worker_fn,
-        cluster_spec=self._cluster_spec,
-        between_graph=True)
+        MockStrategy(between_graph=True),
+        cluster_spec=self._cluster_spec)

    # Each finished worker will increment self._result_correct.
    self.assertEqual(self._result_correct, NUM_WORKERS)

+  def testBetweenGraphWithMonitoredSession(self):
+    """Test monitored session in standalone client mode."""
+    distribute_coordinator.run_distribute_coordinator(
+        self._between_graph_with_monitored_session,
+        MockStrategy(between_graph=True),
+        cluster_spec=self._cluster_spec)
+
+    # Each finished worker will increment self._result_correct.
+    self.assertEqual(self._result_correct, NUM_WORKERS)
@@ -298,8 +417,8 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
    # Dumps the task contexts to the self._worker_context dict.
    distribute_coordinator.run_distribute_coordinator(
        self._dump_worker_context,
-        cluster_spec=self._cluster_spec,
-        between_graph=True)
+        MockStrategy(between_graph=True),
+        cluster_spec=self._cluster_spec)

    # There is only one type of task and there three such tasks.
    self.assertEqual(len(self._worker_context), 1)
@@ -318,12 +437,30 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
        self._worker_context[WORKER][2],
        (_bytes_to_str(self._workers[2].target), NUM_WORKERS, False, True))

+  def testBetweenGraphStrategyProperties(self):
+    # Dumps properties of the strategy objects.
+    distribute_coordinator.run_distribute_coordinator(
+        self._dump_strategy_property,
+        MockStrategy(between_graph=True, should_init=True),
+        cluster_spec=self._cluster_spec)
+
+    # There is only one type of task and there three such tasks.
+    self.assertEqual(len(self._strategy_property), 1)
+    self.assertTrue(WORKER in self._strategy_property)
+    self.assertEqual(len(self._strategy_property[WORKER]), NUM_WORKERS)
+
+    # Check whether each task has the right properties of should_init,
+    # should_checkpoint and should_save_summary.
+    self.assertEqual(self._strategy_property[WORKER][0], (True, True, True))
+    self.assertEqual(self._strategy_property[WORKER][1], (True, False, False))
+    self.assertEqual(self._strategy_property[WORKER][2], (True, False, False))
+
  def testInGraphContext(self):
    # Dumps the task contexts to the self._worker_context dict.
    distribute_coordinator.run_distribute_coordinator(
        self._dump_worker_context,
-        cluster_spec=self._cluster_spec,
-        between_graph=False)
+        MockStrategy(between_graph=False),
+        cluster_spec=self._cluster_spec)

    # There is only a "None" task in the dumped task context.
    self.assertEqual(len(self._worker_context), 1)
@@ -339,7 +476,9 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
  def testLocalContext(self):
    # Dumps the task contexts to the self._worker_context dict.
    distribute_coordinator.run_distribute_coordinator(
-        self._dump_worker_context, cluster_spec=None, between_graph=True)
+        self._dump_worker_context,
+        MockStrategy(between_graph=False),
+        cluster_spec=None)

    # There is only a "None" task.
    self.assertEqual(len(self._worker_context), 1)
@@ -348,7 +487,7 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):

    # Check whether each task has the right master_target, num_workers, is_chief
    # and distributed_mode.
-    self.assertEqual(self._worker_context["None"][0], ("local", 0, True, False))
+    self.assertEqual(self._worker_context["None"][0], ("", 0, True, False))

  def testBetweenGraphContextWithChief(self):
    # Adds a chief node, so there are NUM_WORKERS + 1 workers in total.
@@ -358,8 +497,8 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
    # Dumps the task contexts to the self._worker_context dict.
    distribute_coordinator.run_distribute_coordinator(
        self._dump_worker_context,
+        MockStrategy(between_graph=True),
        cluster_spec=cluster_spec,
-        between_graph=True,
        rpc_layer="grpc")

    # There are one CHIEF and three workers.
@@ -391,8 +530,8 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
    # Dumps the task contexts to the self._worker_context dict.
    distribute_coordinator.run_distribute_coordinator(
        self._dump_worker_context,
+        MockStrategy(between_graph=False),
        cluster_spec=cluster_spec,
-        between_graph=False,
        rpc_layer=None)

    # There are one "None" task and one EVALUATOR task.
@@ -417,8 +556,8 @@ class DistributeCoordinatorTestInpendentWorkerMode(
    cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
    threads = self._run_multiple_coordinator_in_threads(
        self._in_graph_worker_fn,
+        MockStrategy(between_graph=False),
        cluster_spec,
-        between_graph=False,
        mode=INDEPENDENT_WORKER)
    threads[WORKER][0].join()
    self.assertEqual(self._result_correct, 1)
@@ -428,8 +567,22 @@ class DistributeCoordinatorTestInpendentWorkerMode(
        num_workers=NUM_WORKERS, num_ps=NUM_PS)
    threads = self._run_multiple_coordinator_in_threads(
        self._between_graph_worker_fn,
+        MockStrategy(between_graph=True),
        cluster_spec,
        mode=INDEPENDENT_WORKER)
    for task_id in range(NUM_WORKERS):
      threads[WORKER][task_id].join()

    # Each finished worker will increment self._result_correct.
    self.assertEqual(self._result_correct, NUM_WORKERS)

+  def testBetweenGraphWithMonitoredSession(self):
+    cluster_spec = self._create_cluster_spec(
+        num_workers=NUM_WORKERS, num_ps=NUM_PS)
+    threads = self._run_multiple_coordinator_in_threads(
+        self._between_graph_with_monitored_session,
+        MockStrategy(between_graph=True),
+        cluster_spec,
+        between_graph=True,
+        mode=INDEPENDENT_WORKER)
+    for task_id in range(NUM_WORKERS):
+      threads[WORKER][task_id].join()
@@ -444,9 +597,9 @@ class DistributeCoordinatorTestInpendentWorkerMode(
                                self._run_mock_std_server):
      threads = self._run_multiple_coordinator_in_threads(
          self._dump_worker_context,
+          MockStrategy(between_graph=True),
          cluster_spec,
          mode=INDEPENDENT_WORKER,
-          between_graph=True,
          rpc_layer=None)
      for task_id in range(NUM_WORKERS):
        threads[WORKER][task_id].join()
@@ -476,6 +629,31 @@ class DistributeCoordinatorTestInpendentWorkerMode(
    self.assertFalse(self._std_servers[WORKER][1].joined)
    self.assertFalse(self._std_servers[WORKER][2].joined)

+  def testBetweenGraphStrategyProperties(self):
+    cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
+    # Dumps properties of the strategy objects.
+    with test.mock.patch.object(distribute_coordinator, "_run_std_server",
+                                self._run_mock_std_server):
+      threads = self._run_multiple_coordinator_in_threads(
+          self._dump_strategy_property,
+          MockStrategy(between_graph=True, should_init=True),
+          cluster_spec,
+          mode=INDEPENDENT_WORKER,
+          rpc_layer=None)
+      for task_id in range(NUM_WORKERS):
+        threads[WORKER][task_id].join()
+
+    # There is only one type of task and there three such tasks.
+    self.assertEqual(len(self._strategy_property), 1)
+    self.assertTrue(WORKER in self._strategy_property)
+    self.assertEqual(len(self._strategy_property[WORKER]), NUM_WORKERS)
+
+    # Check whether each task has the right properties of should_init,
+    # should_checkpoint and should_save_summary.
+    self.assertEqual(self._strategy_property[WORKER][0], (True, True, True))
+    self.assertEqual(self._strategy_property[WORKER][1], (True, False, False))
+    self.assertEqual(self._strategy_property[WORKER][2], (True, False, False))
+
  def testInGraphContext(self):
    cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
    # Dumps the task contexts and std server arguments.
@@ -483,9 +661,9 @@ class DistributeCoordinatorTestInpendentWorkerMode(
                                self._run_mock_std_server):
      threads = self._run_multiple_coordinator_in_threads(
          self._dump_worker_context,
+          MockStrategy(between_graph=False),
          cluster_spec,
          mode=INDEPENDENT_WORKER,
-          between_graph=False,
          rpc_layer=None)
      for task_id in range(NUM_WORKERS):
        threads[WORKER][task_id].join()
@@ -519,9 +697,9 @@ class DistributeCoordinatorTestInpendentWorkerMode(
                                self._run_mock_std_server):
      threads = self._run_multiple_coordinator_in_threads(
          self._dump_worker_context,
+          MockStrategy(between_graph=False),
          cluster_spec,
          mode=INDEPENDENT_WORKER,
-          between_graph=False,
          rpc_layer=None)
      for task_id in range(NUM_WORKERS):
        threads[WORKER][task_id].join()
@@ -25,6 +25,7 @@ import sys
import six

from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -284,6 +285,63 @@ class Scaffold(object):
        resources.initialize_resources(resources.local_resources()))


+def _create_monitored_session_with_worker_context(worker_context,  # pylint: disable=missing-docstring
+                                                   scaffold,
+                                                   checkpoint_dir=None,
+                                                   hooks=None,
+                                                   chief_only_hooks=None,
+                                                   save_checkpoint_secs=None,
+                                                   save_summaries_steps=None,
+                                                   save_summaries_secs=None,
+                                                   config=None,
+                                                   stop_grace_period_secs=120,
+                                                   log_step_count_steps=100,
+                                                   max_wait_secs=7200,
+                                                   save_checkpoint_steps=None,
+                                                   summary_dir=None):
+  all_hooks = []
+  if hooks:
+    all_hooks.extend(hooks)
+  if chief_only_hooks and worker_context.is_chief:
+    all_hooks.extend(chief_only_hooks)
+
+  summary_dir = summary_dir or checkpoint_dir
+  if summary_dir and worker_context.should_save_summary:
+    if log_step_count_steps and log_step_count_steps > 0:
+      all_hooks.append(
+          basic_session_run_hooks.StepCounterHook(
+              output_dir=summary_dir, every_n_steps=log_step_count_steps))
+
+    if (save_summaries_steps and save_summaries_steps > 0) or (
+        save_summaries_secs and save_summaries_secs > 0):
+      all_hooks.append(
+          basic_session_run_hooks.SummarySaverHook(
+              scaffold=scaffold,
+              save_steps=save_summaries_steps,
+              save_secs=save_summaries_secs,
+              output_dir=summary_dir))
+
+  if checkpoint_dir and worker_context.should_checkpoint:
+    if (save_checkpoint_secs and save_checkpoint_secs > 0) or (
+        save_checkpoint_steps and save_checkpoint_steps > 0):
+      all_hooks.append(
+          basic_session_run_hooks.CheckpointSaverHook(
+              checkpoint_dir,
+              save_steps=save_checkpoint_steps,
+              save_secs=save_checkpoint_secs,
+              scaffold=scaffold))
+
+  session_creator = worker_context.session_creator(
+      scaffold,
+      config=config,
+      checkpoint_dir=checkpoint_dir,
+      max_wait_secs=max_wait_secs)
+  return MonitoredSession(
+      session_creator=session_creator,
+      hooks=all_hooks,
+      stop_grace_period_secs=stop_grace_period_secs)
+
+
@tf_export('train.MonitoredTrainingSession')
def MonitoredTrainingSession(master='',  # pylint: disable=invalid-name
                             is_chief=True,
@@ -373,14 +431,35 @@ def MonitoredTrainingSession(master='',  # pylint: disable=invalid-name
    save_checkpoint_steps = None

  scaffold = scaffold or Scaffold()
+  worker_context = distribute_coordinator_context.get_current_worker_context()
+
+  if worker_context:
+    return _create_monitored_session_with_worker_context(
+        worker_context,
+        scaffold,
+        checkpoint_dir=checkpoint_dir,
+        hooks=hooks,
+        chief_only_hooks=chief_only_hooks,
+        save_checkpoint_secs=save_checkpoint_secs,
+        save_summaries_steps=save_summaries_steps,
+        save_summaries_secs=save_summaries_secs,
+        config=config,
+        stop_grace_period_secs=stop_grace_period_secs,
+        log_step_count_steps=log_step_count_steps,
+        max_wait_secs=max_wait_secs,
+        save_checkpoint_steps=save_checkpoint_steps,
+        summary_dir=summary_dir)
+
  if not is_chief:
    session_creator = WorkerSessionCreator(
        scaffold=scaffold,
        master=master,
        config=config,
        max_wait_secs=max_wait_secs)
-    return MonitoredSession(session_creator=session_creator, hooks=hooks or [],
-                            stop_grace_period_secs=stop_grace_period_secs)
+    return MonitoredSession(
+        session_creator=session_creator,
+        hooks=hooks or [],
+        stop_grace_period_secs=stop_grace_period_secs)

  all_hooks = []
  if chief_only_hooks:
@@ -400,25 +479,29 @@ def MonitoredTrainingSession(master='',  # pylint: disable=invalid-name

  if (save_summaries_steps and save_summaries_steps > 0) or (
      save_summaries_secs and save_summaries_secs > 0):
-    all_hooks.append(basic_session_run_hooks.SummarySaverHook(
-        scaffold=scaffold,
-        save_steps=save_summaries_steps,
-        save_secs=save_summaries_secs,
-        output_dir=summary_dir))
+    all_hooks.append(
+        basic_session_run_hooks.SummarySaverHook(
+            scaffold=scaffold,
+            save_steps=save_summaries_steps,
+            save_secs=save_summaries_secs,
+            output_dir=summary_dir))

  if checkpoint_dir:
    if (save_checkpoint_secs and save_checkpoint_secs > 0) or (
        save_checkpoint_steps and save_checkpoint_steps > 0):
-      all_hooks.append(basic_session_run_hooks.CheckpointSaverHook(
-          checkpoint_dir,
-          save_steps=save_checkpoint_steps,
-          save_secs=save_checkpoint_secs,
-          scaffold=scaffold))
+      all_hooks.append(
+          basic_session_run_hooks.CheckpointSaverHook(
+              checkpoint_dir,
+              save_steps=save_checkpoint_steps,
+              save_secs=save_checkpoint_secs,
+              scaffold=scaffold))

  if hooks:
    all_hooks.extend(hooks)
-  return MonitoredSession(session_creator=session_creator, hooks=all_hooks,
-                          stop_grace_period_secs=stop_grace_period_secs)
+  return MonitoredSession(
+      session_creator=session_creator,
+      hooks=all_hooks,
+      stop_grace_period_secs=stop_grace_period_secs)


@tf_export('train.SessionCreator')
@@ -546,6 +629,11 @@ class _MonitoredSession(object):
    self._hooks = hooks or []
    for h in self._hooks:
      h.begin()

+    worker_context = distribute_coordinator_context.get_current_worker_context()
+    if not session_creator and worker_context:
+      session_creator = worker_context.session_creator()
+
    # Create the session.
    self._coordinated_creator = self._CoordinatedSessionCreator(
        session_creator=session_creator or ChiefSessionCreator(),
@@ -32,6 +32,7 @@ from tensorflow.contrib.testing.python.framework import util_test
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import debug_pb2
from tensorflow.python.client import session as session_lib
+from tensorflow.python.distribute import distribute_coordinator
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
@@ -381,6 +382,119 @@ class MonitoredTrainingSessionTest(test.TestCase):
    self.assertEqual(0, session.run(gstep))


+class MockStrategy(object):
+
+  def __init__(self,
+               between_graph=False,
+               should_init=True,
+               should_checkpoint=None,
+               should_save_summary=None):
+    self._between_graph = between_graph
+    self._should_init = should_init
+    self._should_checkpoint = should_checkpoint
+    self._should_save_summary = should_save_summary
+
+  @property
+  def between_graph(self):
+    return self._between_graph
+
+  @property
+  def should_init(self):
+    return self._should_init
+
+  @property
+  def should_checkpoint(self):
+    return self._should_checkpoint
+
+  @property
+  def should_save_summary(self):
+    return self._should_save_summary
+
+
+class MonitoredTrainingSessionWithDistributeCoordinatorTest(test.TestCase):
+  """Test distribute coordinator controls summary saving and checkpointing."""
+
+  def test_summary_hook_enabled(self):
+    context = distribute_coordinator._WorkerContext(
+        MockStrategy(should_save_summary=True), None, None, None)
+
+    logdir = _test_dir(self.get_temp_dir(), 'test_summaries_enabled')
+    with ops.Graph().as_default():
+      gstep = variables_lib.get_or_create_global_step()
+      new_gstep = state_ops.assign_add(gstep, 1)
+      summary.scalar('my_summary_tag', new_gstep * 2)
+      with context, monitored_session.MonitoredTrainingSession(
+          checkpoint_dir=logdir,
+          save_summaries_steps=100,
+          log_step_count_steps=10) as session:
+        for _ in range(101):
+          session.run(new_gstep)
+
+    summaries = util_test.latest_summaries(logdir)
+    tags = [s.summary.value[0].tag for s in summaries]
+    self.assertIn('my_summary_tag', tags)
+    self.assertIn('global_step/sec', tags)
+
+  def test_summary_hook_disabled(self):
+    context = distribute_coordinator._WorkerContext(
+        MockStrategy(should_save_summary=False), None, None, None)
+
+    logdir = _test_dir(self.get_temp_dir(), 'test_summaries_disabled')
+    with ops.Graph().as_default():
+      gstep = variables_lib.get_or_create_global_step()
+      new_gstep = state_ops.assign_add(gstep, 1)
+      summary.scalar('my_summary_tag', new_gstep * 2)
+      with context, monitored_session.MonitoredTrainingSession(
+          checkpoint_dir=logdir,
+          save_summaries_steps=100,
+          log_step_count_steps=10) as session:
+        for _ in range(101):
+          session.run(new_gstep)
+
+    # No summary is saved.
+    summaries = util_test.latest_summaries(logdir)
+    self.assertEqual(len(summaries), 0)
+
+  def test_checkpoint_hook_enabled(self):
+    context = distribute_coordinator._WorkerContext(
+        MockStrategy(should_checkpoint=True), None, None, None)
+
+    logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_enabled')
+    with ops.Graph().as_default():
+      gstep = variables_lib.get_or_create_global_step()
+      new_gstep = state_ops.assign_add(gstep, 1)
+      with context, monitored_session.MonitoredTrainingSession(
+          checkpoint_dir=logdir,
+          save_checkpoint_steps=100,
+          log_step_count_steps=10) as session:
+        for _ in range(100):
+          session.run(new_gstep)
+
+      # A restart will find the checkpoint and recover automatically.
+      with monitored_session.MonitoredTrainingSession(
+          is_chief=True, checkpoint_dir=logdir) as session:
+        self.assertEqual(100, session.run(gstep))
+
+  def test_checkpoint_hook_disabled(self):
+    context = distribute_coordinator._WorkerContext(
+        MockStrategy(should_checkpoint=False), None, None, None)
+
+    logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_disabled')
+    with ops.Graph().as_default():
+      gstep = variables_lib.get_or_create_global_step()
+      new_gstep = state_ops.assign_add(gstep, 1)
+      with context, monitored_session.MonitoredTrainingSession(
+          checkpoint_dir=logdir,
+          save_checkpoint_steps=100,
+          log_step_count_steps=10) as session:
+        for _ in range(100):
+          session.run(new_gstep)
+
+    # No checkpoint is saved.
+    checkpoint = checkpoint_management.latest_checkpoint(logdir)
+    self.assertIsNone(checkpoint)
+
+
class StopAtNSession(monitored_session._WrappedSession):
  """A wrapped session that stops at the N-th call to _check_stop."""

@@ -1365,8 +1479,8 @@ class MonitoredSessionTest(test.TestCase):
    with monitored_session.MonitoredSession(
        session_creator=monitored_session.ChiefSessionCreator(
            scaffold,
-            checkpoint_filename_with_path=
-            checkpoint_management.latest_checkpoint(logdir))) as session:
+            checkpoint_filename_with_path=checkpoint_management.
+            latest_checkpoint(logdir))) as session:
      self.assertEqual(2, session.run(gstep))

  def test_retry_initialization_on_aborted_error(self):