Make the check for whether it is multi_worker_mode not rely on TF_CONFIG environment variable, but delegate to strategy extended objects.
PiperOrigin-RevId: 264883884
This commit is contained in:
parent
952ae3f70c
commit
0390084145
@ -255,3 +255,7 @@ class CentralStorageStrategyV1(distribute_lib.StrategyV1):
|
||||
compute_devices=compute_devices,
|
||||
parameter_device=parameter_device))
|
||||
__init__.__doc__ = CentralStorageStrategy.__init__.__doc__
|
||||
|
||||
def _in_multi_worker_mode(self):
|
||||
"""Whether this strategy indicates working in multi-worker settings."""
|
||||
return False
|
||||
|
@ -496,6 +496,10 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
logging.warning("Enabled NCCL communication but no GPUs detected/"
|
||||
"specified.")
|
||||
|
||||
def _in_multi_worker_mode(self):
|
||||
"""Whether this strategy indicates working in multi-worker settings."""
|
||||
return self._num_workers > 1
|
||||
|
||||
@property
|
||||
def experimental_between_graph(self):
|
||||
return True
|
||||
|
@ -97,8 +97,6 @@ from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import enum # pylint: disable=g-bad-import-order
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import weakref
|
||||
|
||||
@ -126,7 +124,6 @@ from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops.losses import loss_reduction
|
||||
from tensorflow.python.ops.losses import losses_impl
|
||||
from tensorflow.python.platform import tf_logging
|
||||
from tensorflow.python.training import server_lib
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
@ -953,20 +950,6 @@ class Strategy(object):
|
||||
def __copy__(self):
|
||||
raise RuntimeError("Must only deepcopy DistributionStrategy.")
|
||||
|
||||
def _in_multi_worker_mode(self):
|
||||
"""Method to infer if this `Strategy` is working in multi-worker settings.
|
||||
|
||||
Experimental. Signature and implementation are subject to change.
|
||||
|
||||
Returns:
|
||||
Whether this strategy indicates working in multi-worker settings.
|
||||
"""
|
||||
# TODO(b/137857865): Check for whether it is multi_worker_mode should not
|
||||
# rely on TF_CONFIG environment variable.
|
||||
tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
|
||||
cluster_spec = server_lib.ClusterSpec(tf_config.get("cluster", {}))
|
||||
return tf_config and "master" not in cluster_spec.jobs
|
||||
|
||||
|
||||
# TF v1.x version has additional deprecated APIs
|
||||
@tf_export(v1=["distribute.Strategy"])
|
||||
@ -1659,6 +1642,23 @@ class StrategyExtendedV2(object):
|
||||
def _update_config_proto(self, config_proto):
|
||||
return copy.deepcopy(config_proto)
|
||||
|
||||
def _in_multi_worker_mode(self):
|
||||
"""Whether this strategy indicates working in multi-worker settings.
|
||||
|
||||
Multi-worker training refers to the setup where the training is
|
||||
distributed across multiple workers, as opposed to the case where
|
||||
only a local process performs the training. This function is
|
||||
used by higher-level apis such as Keras' `model.fit()` to infer
|
||||
for example whether or not a distribute coordinator should be run,
|
||||
and thus TensorFlow servers should be started for communication
|
||||
with other servers in the cluster, or whether or not saving/restoring
|
||||
checkpoints is relevant for preemption fault tolerance.
|
||||
|
||||
Subclasses should override this to provide whether the strategy is
|
||||
currently in multi-worker setup.
|
||||
"""
|
||||
raise NotImplementedError("must be implemented in descendants")
|
||||
|
||||
|
||||
@tf_export(v1=["distribute.StrategyExtended"]) # pylint: disable=missing-docstring
|
||||
class StrategyExtendedV1(StrategyExtendedV2):
|
||||
@ -2200,6 +2200,11 @@ class _DefaultDistributionExtended(StrategyExtendedV1):
|
||||
def non_slot_devices(self, var_list):
|
||||
return min(var_list, key=lambda x: x.name)
|
||||
|
||||
def _in_multi_worker_mode(self):
|
||||
"""Whether this strategy indicates working in multi-worker settings."""
|
||||
# Default strategy doesn't indicate multi-worker training.
|
||||
return False
|
||||
|
||||
# TODO(priyag): This should inherit from `InputIterator`, once dependency
|
||||
# issues have been resolved.
|
||||
class DefaultInputIterator(object):
|
||||
|
@ -420,6 +420,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
|
||||
self._inferred_cross_device_ops = None if self._cross_device_ops else (
|
||||
cross_device_ops_lib.choose_the_best(devices))
|
||||
self._host_input_device = numpy_dataset.SingleDevice("/cpu:0")
|
||||
self._is_multi_worker_training = False
|
||||
|
||||
def _initialize_multi_worker(self, devices):
|
||||
"""Initializes the object for multi-worker training."""
|
||||
@ -446,6 +447,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
|
||||
self._device_map = values.ReplicaDeviceMap(devices)
|
||||
self._input_workers = input_lib.InputWorkers(
|
||||
self._device_map, worker_devices)
|
||||
self._is_multi_worker_training = True
|
||||
|
||||
if len(workers) > 1:
|
||||
if not isinstance(self._cross_device_ops,
|
||||
@ -795,6 +797,10 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
|
||||
"""
|
||||
return True
|
||||
|
||||
def _in_multi_worker_mode(self):
|
||||
"""Whether this strategy indicates working in multi-worker settings."""
|
||||
return False
|
||||
|
||||
|
||||
class _MirroredReplicaThread(threading.Thread):
|
||||
"""A thread that runs() a function on a device."""
|
||||
|
@ -383,6 +383,10 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
|
||||
def value_container(self, value):
|
||||
return value
|
||||
|
||||
def _in_multi_worker_mode(self):
|
||||
"""Whether this strategy indicates working in multi-worker settings."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def _num_replicas_in_sync(self):
|
||||
return 1
|
||||
|
@ -587,6 +587,12 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
|
||||
"/job:%s/task:%d" % (self._task_type, self._task_id))
|
||||
return updated_config
|
||||
|
||||
def _in_multi_worker_mode(self):
|
||||
"""Whether this strategy indicates working in multi-worker settings."""
|
||||
# With a PS job, PS strategy should always be considered as in multi
|
||||
# worker mode.
|
||||
return True
|
||||
|
||||
@property
|
||||
def _num_replicas_in_sync(self):
|
||||
return self._device_map.num_replicas_in_graph
|
||||
|
@ -685,6 +685,15 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
||||
self._tpu_function_cache[fn] = tpu_function
|
||||
return tpu_function
|
||||
|
||||
def _in_multi_worker_mode(self):
|
||||
"""Whether this strategy indicates working in multi-worker settings."""
|
||||
# TPUStrategy has different distributed training structure that the whole
|
||||
# cluster should be treated as single worker from higher-level (e.g. Keras)
|
||||
# library's point of view.
|
||||
# TODO(rchao): Revisit this as we design a fault-tolerance solution for
|
||||
# TPUStrategy.
|
||||
return False
|
||||
|
||||
|
||||
class _TPUReplicaContext(distribute_lib.ReplicaContext):
|
||||
"""Replication Context class for TPU Strategy."""
|
||||
|
@ -5817,7 +5817,7 @@ def configure_and_create_distributed_session(distribution_strategy):
|
||||
|
||||
set_session(session)
|
||||
|
||||
if distribution_strategy._in_multi_worker_mode():
|
||||
if distribution_strategy.extended._in_multi_worker_mode():
|
||||
dc.run_distribute_coordinator(
|
||||
_create_session,
|
||||
distribution_strategy,
|
||||
|
@ -34,10 +34,8 @@ from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import distribute_coordinator as dc
|
||||
from tensorflow.python.distribute import distribute_lib
|
||||
from tensorflow.python.distribute import multi_worker_test_base as test_base
|
||||
from tensorflow.python.distribute import multi_worker_util
|
||||
from tensorflow.python.distribute import parameter_server_strategy
|
||||
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.keras import backend
|
||||
from tensorflow.python.keras import callbacks
|
||||
from tensorflow.python.keras import metrics as metrics_module
|
||||
@ -201,77 +199,6 @@ class MultiWorkerVerificationCallback(callbacks.Callback):
|
||||
})
|
||||
|
||||
|
||||
# TODO(yuefengz): right now, fit or evaluate has to be called under distribution
|
||||
# strategy's scope.
|
||||
def _run_standalone_client(test_obj, strategy, cluster_spec):
|
||||
input_shape = (28, 28, 1)
|
||||
with strategy.scope():
|
||||
orig_model = multi_worker_testing_utils.get_mnist_model(input_shape)
|
||||
|
||||
def worker_fn(strategy):
|
||||
with ops.Graph().as_default():
|
||||
batch_size = 64
|
||||
steps = 2
|
||||
|
||||
with strategy.scope():
|
||||
train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
|
||||
batch_size, steps)
|
||||
model = _clone_and_build_model(orig_model, strategy)
|
||||
|
||||
orig_loss, orig_acc = model.evaluate(train_ds, steps=steps)
|
||||
|
||||
# Workaround for the metrics issue (b/122928955) in async training. This
|
||||
# can only be used in standalone client mode.
|
||||
multi_worker_util.wait_for_other_workers()
|
||||
|
||||
model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)
|
||||
|
||||
multi_worker_util.wait_for_other_workers()
|
||||
|
||||
trained_loss, trained_acc = model.evaluate(train_ds, steps=steps)
|
||||
|
||||
test_obj.assertLessEqual(trained_loss, orig_loss)
|
||||
test_obj.assertGreaterEqual(trained_acc, orig_acc)
|
||||
|
||||
dc.run_distribute_coordinator(
|
||||
worker_fn,
|
||||
strategy,
|
||||
mode=dc.CoordinatorMode.STANDALONE_CLIENT,
|
||||
cluster_spec=cluster_spec)
|
||||
|
||||
|
||||
class KerasMultiWorkerTestStandaloneClient(test.TestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
"""Create a local cluster with 2 workers."""
|
||||
super(KerasMultiWorkerTestStandaloneClient, cls).setUpClass()
|
||||
cls._cluster_spec = test_base.create_in_process_cluster(
|
||||
num_workers=2, num_ps=1, has_eval=False)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
mode=['graph'],
|
||||
strategy_cls=[
|
||||
ParameterServerStrategy,
|
||||
collective_strategy.CollectiveAllReduceStrategy,
|
||||
],
|
||||
required_gpus=[0, 1]))
|
||||
def testSimpleModelStandaloneClient(self, strategy_cls):
|
||||
# With standalone client, training_utils.should_run_multi_worker returns
|
||||
# False which means the distribute coordinator won't be called again in
|
||||
# `fit`. This is still correct and intended since session is still
|
||||
# configured under distribute coordinator's worker context and distribution
|
||||
# strategy object is already configured by distribute coordinator for
|
||||
# multi-worker training.
|
||||
# The logic should be much clearer once standalone client is merged into
|
||||
# core Keras as well.
|
||||
strategy = strategy_cls()
|
||||
|
||||
_run_standalone_client(self, strategy, self._cluster_spec)
|
||||
|
||||
|
||||
class KerasMultiWorkerTestIndependentWorker(test_base.IndependentWorkerTestBase,
|
||||
parameterized.TestCase):
|
||||
|
||||
|
@ -2901,7 +2901,7 @@ class Model(network.Network):
|
||||
# Otherwise, use the strategy whose scope this is in.
|
||||
if not strategy and distribution_strategy_context.has_strategy():
|
||||
strategy = distribution_strategy_context.get_strategy()
|
||||
return strategy and strategy._in_multi_worker_mode() # pylint: disable=protected-access
|
||||
return strategy and strategy.extended._in_multi_worker_mode() # pylint: disable=protected-access
|
||||
|
||||
|
||||
class DistributedCallbackModel(Model):
|
||||
|
Loading…
Reference in New Issue
Block a user