Make the multi_worker_mode check not rely on the TF_CONFIG environment variable; instead, delegate to the strategy extended objects.

PiperOrigin-RevId: 264883884
Rick Chao authored 2019-08-22 11:41:51 -07:00; committed by TensorFlower Gardener
parent 952ae3f70c
commit 0390084145
10 changed files with 57 additions and 92 deletions
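
The gist of the change, as a minimal standalone sketch (illustrative names throughout; only `_in_multi_worker_mode`, the `extended` attribute, and the TF_CONFIG parsing mirror the diff below): the TF_CONFIG-based inference in `Strategy._in_multi_worker_mode` is removed, each strategy's extended object now answers the question itself, and higher-level callers such as Keras delegate to `strategy.extended._in_multi_worker_mode()`.

import json
import os


def in_multi_worker_mode_from_tf_config():
  """Old approach (removed in this commit): infer the mode from TF_CONFIG."""
  tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
  cluster = tf_config.get("cluster", {})
  # Roughly what the removed Strategy._in_multi_worker_mode did, minus the
  # ClusterSpec wrapper around the cluster dict.
  return bool(tf_config) and "master" not in cluster


class StrategyExtendedSketch(object):
  """Stand-in for StrategyExtendedV2: descendants must provide the answer."""

  def _in_multi_worker_mode(self):
    raise NotImplementedError("must be implemented in descendants")


class CollectiveExtendedSketch(StrategyExtendedSketch):
  """Stand-in for CollectiveAllReduceExtended."""

  def __init__(self, num_workers):
    self._num_workers = num_workers

  def _in_multi_worker_mode(self):
    # Multi-worker only when more than one worker is configured.
    return self._num_workers > 1


class StrategySketch(object):
  """Stand-in for a Strategy exposing its extended object."""

  def __init__(self, extended):
    self.extended = extended


def should_run_distribute_coordinator(strategy):
  """New approach: higher-level code asks the strategy's extended object."""
  return strategy.extended._in_multi_worker_mode()


# Example: a collective strategy configured with two workers reports True.
assert should_run_distribute_coordinator(
    StrategySketch(CollectiveExtendedSketch(num_workers=2)))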

View File

@@ -255,3 +255,7 @@ class CentralStorageStrategyV1(distribute_lib.StrategyV1):
compute_devices=compute_devices,
parameter_device=parameter_device))
__init__.__doc__ = CentralStorageStrategy.__init__.__doc__
def _in_multi_worker_mode(self):
"""Whether this strategy indicates working in multi-worker settings."""
return False

View File

@@ -496,6 +496,10 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
logging.warning("Enabled NCCL communication but no GPUs detected/"
"specified.")
def _in_multi_worker_mode(self):
"""Whether this strategy indicates working in multi-worker settings."""
return self._num_workers > 1
@property
def experimental_between_graph(self):
return True

View File

@@ -97,8 +97,6 @@ from __future__ import print_function
import copy
import enum # pylint: disable=g-bad-import-order
import json
import os
import threading
import weakref
@@ -126,7 +124,6 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import loss_reduction
from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.platform import tf_logging
from tensorflow.python.training import server_lib
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
@@ -953,20 +950,6 @@ class Strategy(object):
def __copy__(self):
raise RuntimeError("Must only deepcopy DistributionStrategy.")
def _in_multi_worker_mode(self):
"""Method to infer if this `Strategy` is working in multi-worker settings.
Experimental. Signature and implementation are subject to change.
Returns:
Whether this strategy indicates working in multi-worker settings.
"""
# TODO(b/137857865): Check for whether it is multi_worker_mode should not
# rely on TF_CONFIG environment variable.
tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
cluster_spec = server_lib.ClusterSpec(tf_config.get("cluster", {}))
return tf_config and "master" not in cluster_spec.jobs
# TF v1.x version has additional deprecated APIs
@tf_export(v1=["distribute.Strategy"])
@@ -1659,6 +1642,23 @@ class StrategyExtendedV2(object):
def _update_config_proto(self, config_proto):
return copy.deepcopy(config_proto)
def _in_multi_worker_mode(self):
"""Whether this strategy indicates working in multi-worker settings.
Multi-worker training refers to the setup where the training is
distributed across multiple workers, as opposed to the case where
only a local process performs the training. Higher-level APIs such as
Keras' `model.fit()` use this function to infer, for example, whether
a distribute coordinator should be run (and thus whether TensorFlow
servers should be started for communication with the other servers in
the cluster), or whether saving/restoring checkpoints is relevant for
preemption fault tolerance.
Subclasses should override this to report whether the strategy is
currently in a multi-worker setup.
"""
raise NotImplementedError("must be implemented in descendants")
@tf_export(v1=["distribute.StrategyExtended"]) # pylint: disable=missing-docstring
class StrategyExtendedV1(StrategyExtendedV2):
@@ -2200,6 +2200,11 @@ class _DefaultDistributionExtended(StrategyExtendedV1):
def non_slot_devices(self, var_list):
return min(var_list, key=lambda x: x.name)
def _in_multi_worker_mode(self):
"""Whether this strategy indicates working in multi-worker settings."""
# Default strategy doesn't indicate multi-worker training.
return False
# TODO(priyag): This should inherit from `InputIterator`, once dependency
# issues have been resolved.
class DefaultInputIterator(object):

View File

@@ -420,6 +420,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
self._inferred_cross_device_ops = None if self._cross_device_ops else (
cross_device_ops_lib.choose_the_best(devices))
self._host_input_device = numpy_dataset.SingleDevice("/cpu:0")
self._is_multi_worker_training = False
def _initialize_multi_worker(self, devices):
"""Initializes the object for multi-worker training."""
@@ -446,6 +447,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
self._device_map = values.ReplicaDeviceMap(devices)
self._input_workers = input_lib.InputWorkers(
self._device_map, worker_devices)
self._is_multi_worker_training = True
if len(workers) > 1:
if not isinstance(self._cross_device_ops,
@@ -795,6 +797,10 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
"""
return True
def _in_multi_worker_mode(self):
"""Whether this strategy indicates working in multi-worker settings."""
return False
class _MirroredReplicaThread(threading.Thread):
"""A thread that runs() a function on a device."""

View File

@@ -383,6 +383,10 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
def value_container(self, value):
return value
def _in_multi_worker_mode(self):
"""Whether this strategy indicates working in multi-worker settings."""
return False
@property
def _num_replicas_in_sync(self):
return 1

View File

@@ -587,6 +587,12 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
"/job:%s/task:%d" % (self._task_type, self._task_id))
return updated_config
def _in_multi_worker_mode(self):
"""Whether this strategy indicates working in multi-worker settings."""
# With a PS job, the PS strategy should always be considered to be in
# multi-worker mode.
return True
@property
def _num_replicas_in_sync(self):
return self._device_map.num_replicas_in_graph

View File

@@ -685,6 +685,15 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
self._tpu_function_cache[fn] = tpu_function
return tpu_function
def _in_multi_worker_mode(self):
"""Whether this strategy indicates working in multi-worker settings."""
# TPUStrategy has a different distributed training structure: from a
# higher-level (e.g. Keras) library's point of view, the whole cluster
# should be treated as a single worker.
# TODO(rchao): Revisit this as we design a fault-tolerance solution for
# TPUStrategy.
return False
class _TPUReplicaContext(distribute_lib.ReplicaContext):
"""Replication Context class for TPU Strategy."""

View File

@@ -5817,7 +5817,7 @@ def configure_and_create_distributed_session(distribution_strategy):
set_session(session)
- if distribution_strategy._in_multi_worker_mode():
+ if distribution_strategy.extended._in_multi_worker_mode():
dc.run_distribute_coordinator(
_create_session,
distribution_strategy,

View File

@@ -34,10 +34,8 @@ from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import multi_worker_test_base as test_base
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.keras import callbacks
from tensorflow.python.keras import metrics as metrics_module
@@ -201,77 +199,6 @@ class MultiWorkerVerificationCallback(callbacks.Callback):
})
# TODO(yuefengz): right now, fit or evaluate has to be called under distribution
# strategy's scope.
def _run_standalone_client(test_obj, strategy, cluster_spec):
input_shape = (28, 28, 1)
with strategy.scope():
orig_model = multi_worker_testing_utils.get_mnist_model(input_shape)
def worker_fn(strategy):
with ops.Graph().as_default():
batch_size = 64
steps = 2
with strategy.scope():
train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
batch_size, steps)
model = _clone_and_build_model(orig_model, strategy)
orig_loss, orig_acc = model.evaluate(train_ds, steps=steps)
# Workaround for the metrics issue (b/122928955) in async training. This
# can only be used in standalone client mode.
multi_worker_util.wait_for_other_workers()
model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)
multi_worker_util.wait_for_other_workers()
trained_loss, trained_acc = model.evaluate(train_ds, steps=steps)
test_obj.assertLessEqual(trained_loss, orig_loss)
test_obj.assertGreaterEqual(trained_acc, orig_acc)
dc.run_distribute_coordinator(
worker_fn,
strategy,
mode=dc.CoordinatorMode.STANDALONE_CLIENT,
cluster_spec=cluster_spec)
class KerasMultiWorkerTestStandaloneClient(test.TestCase,
parameterized.TestCase):
@classmethod
def setUpClass(cls):
"""Create a local cluster with 2 workers."""
super(KerasMultiWorkerTestStandaloneClient, cls).setUpClass()
cls._cluster_spec = test_base.create_in_process_cluster(
num_workers=2, num_ps=1, has_eval=False)
@combinations.generate(
combinations.combine(
mode=['graph'],
strategy_cls=[
ParameterServerStrategy,
collective_strategy.CollectiveAllReduceStrategy,
],
required_gpus=[0, 1]))
def testSimpleModelStandaloneClient(self, strategy_cls):
# With the standalone client, training_utils.should_run_multi_worker returns
# False, which means the distribute coordinator won't be called again in
# `fit`. This is still correct and intended, since the session is still
# configured under the distribute coordinator's worker context and the
# distribution strategy object has already been configured by the distribute
# coordinator for multi-worker training.
# The logic should be much clearer once standalone client is merged into
# core Keras as well.
strategy = strategy_cls()
_run_standalone_client(self, strategy, self._cluster_spec)
class KerasMultiWorkerTestIndependentWorker(test_base.IndependentWorkerTestBase,
parameterized.TestCase):

View File

@@ -2901,7 +2901,7 @@ class Model(network.Network):
# Otherwise, use the strategy whose scope this is in.
if not strategy and distribution_strategy_context.has_strategy():
strategy = distribution_strategy_context.get_strategy()
- return strategy and strategy._in_multi_worker_mode() # pylint: disable=protected-access
+ return strategy and strategy.extended._in_multi_worker_mode() # pylint: disable=protected-access
class DistributedCallbackModel(Model):