Make cluster_resolver a standard property in tf.distribute strategies.

PiperOrigin-RevId: 317771299
Change-Id: I71b5c585cef7bd7ef80e66b75e30287fddcf89e2

parent: e74a115bcf
commit: 4d13d6416d
@@ -204,6 +204,7 @@ py_test(
         "//tensorflow/python:variables",
         "//tensorflow/python/autograph/core:test_lib",
         "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
         "//third_party/py/numpy",
     ],
 )

@@ -1847,10 +1848,11 @@ py_test(
     ],
 )
 
-cuda_py_test(
+distribute_py_test(
     name = "strategy_common_test",
     srcs = ["strategy_common_test.py"],
     python_version = "PY3",
+    shard_count = 12,
     tags = [
         "multi_and_single_gpu",
         # TODO(b/155301154): Enable this test on multi-gpu guitar once multi process

@@ -1859,6 +1861,7 @@ cuda_py_test(
     ],
     xla_enable_strict_auto_jit = True,
     deps = [
+        ":collective_all_reduce_strategy",
         ":combinations",
         ":multi_worker_test_base",
         ":reduce_util",

@@ -138,6 +138,18 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):
     """
     return super(CollectiveAllReduceStrategy, self).scope()
 
+  @property
+  def cluster_resolver(self):
+    """Returns the cluster resolver associated with this strategy.
+
+    As a multi-worker strategy,
+    `tf.distribute.experimental.MultiWorkerMirroredStrategy` provides the
+    associated `tf.distribute.cluster_resolver.ClusterResolver`. If the user
+    provides one in `__init__`, that instance is returned; if the user does
+    not, a default `TFConfigClusterResolver` is provided.
+    """
+    return self.extended._cluster_resolver  # pylint: disable=protected-access
+
 
 @tf_export(v1=["distribute.experimental.MultiWorkerMirroredStrategy"])  # pylint: disable=missing-docstring
 class CollectiveAllReduceStrategyV1(distribute_lib.StrategyV1):

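A minimal sketch of what the new property exposes, assuming this script runs as worker 0 of a real two-worker job (the `localhost` ports below are placeholders for live worker addresses); with no resolver passed to `__init__`, the strategy falls back to a `TFConfigClusterResolver` built from `TF_CONFIG`:

```python
import json
import os

import tensorflow as tf

# Placeholder two-worker cluster; both workers must actually be running.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {'worker': ['localhost:12345', 'localhost:23456']},
    'task': {'type': 'worker', 'index': 0}
})

# No resolver given, so a default TFConfigClusterResolver is constructed.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

resolver = strategy.cluster_resolver
print(resolver.task_type)                 # 'worker'
print(resolver.task_id)                   # 0
print(resolver.cluster_spec().as_dict())  # the cluster defined in TF_CONFIG
```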
@@ -505,8 +505,7 @@ class DistributedCollectiveAllReduceStrategyTest(
     self.assertEqual(['CollectiveReduce'],
                      new_rewrite_options.scoped_allocator_opts.enable_op)
 
-  @combinations.generate(combinations.combine(mode=['eager']))
-  def testEnableCollectiveOps(self):
+  def _get_strategy_with_mocked_methods(self):
     mock_called = [False]
 
     # pylint: disable=dangerous-default-value

@@ -525,9 +524,21 @@ class DistributedCollectiveAllReduceStrategyTest(
                        mock_configure_collective_ops):
       strategy, _, _ = self._get_test_object(
           task_type='worker', task_id=1, num_gpus=2)
+
+    return strategy, mock_called
+
+  @combinations.generate(combinations.combine(mode=['eager']))
+  def testEnableCollectiveOps(self):
+    strategy, mock_called = self._get_strategy_with_mocked_methods()
     self.assertTrue(strategy.extended._std_server_started)
     self.assertTrue(mock_called[0])
 
+  @combinations.generate(combinations.combine(mode=['eager']))
+  def testEnableCollectiveOpsAndClusterResolver(self):
+    strategy, _ = self._get_strategy_with_mocked_methods()
+    self.assertEqual(strategy.cluster_resolver.task_type, 'worker')
+    self.assertEqual(strategy.cluster_resolver.task_id, 1)
+
 
 class DistributedCollectiveAllReduceStrategyTestWithChief(
     CollectiveAllReduceStrategyTestBase, parameterized.TestCase):

@@ -1439,6 +1439,65 @@ class StrategyBase(object):
   def __copy__(self):
     raise RuntimeError("Must only deepcopy DistributionStrategy.")
 
+  @property
+  def cluster_resolver(self):
+    """Returns the cluster resolver associated with this strategy.
+
+    In general, when using a multi-worker `tf.distribute` strategy such as
+    `tf.distribute.experimental.MultiWorkerMirroredStrategy` or
+    `tf.distribute.experimental.TPUStrategy()`, there is a
+    `tf.distribute.cluster_resolver.ClusterResolver` associated with the
+    strategy used, and such an instance is returned by this property.
+
+    Strategies that intend to have an associated
+    `tf.distribute.cluster_resolver.ClusterResolver` must set the
+    relevant attribute, or override this property; otherwise, `None` is returned
+    by default. Those strategies should also provide information regarding what
+    is returned by this property.
+
+    Single-worker strategies usually do not have a
+    `tf.distribute.cluster_resolver.ClusterResolver`, and in those cases this
+    property will return `None`.
+
+    The `tf.distribute.cluster_resolver.ClusterResolver` may be useful when the
+    user needs to access information such as the cluster spec, task type or task
+    id. For example,
+
+    ```python
+
+    os.environ['TF_CONFIG'] = json.dumps({
+      'cluster': {
+        'worker': ["localhost:12345", "localhost:23456"],
+        'ps': ["localhost:34567"]
+      },
+      'task': {'type': 'worker', 'index': 0}
+    })
+
+    # This implicitly uses TF_CONFIG for the cluster and current task info.
+    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
+
+    ...
+
+    if strategy.cluster_resolver.task_type == 'worker':
+      # Perform something that's only applicable on workers. Since we set this
+      # as a worker above, this block will run on this particular instance.
+    elif strategy.cluster_resolver.task_type == 'ps':
+      # Perform something that's only applicable on parameter servers. Since we
+      # set this as a worker above, this block will not run on this particular
+      # instance.
+    ```
+
+    For more information, please see
+    `tf.distribute.cluster_resolver.ClusterResolver`'s API docstring.
+
+    Returns:
+      The cluster resolver associated with this strategy. Returns `None` if a
+      cluster resolver is not applicable or available in this strategy.
+    """
+    if hasattr(self.extended, "_cluster_resolver"):
+      return self.extended._cluster_resolver  # pylint: disable=protected-access
+    return None
+
 
 @tf_export("distribute.Strategy", v1=[])  # pylint: disable=g-missing-docstring
 class Strategy(StrategyBase):

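Because the base-class property falls back to `None`, callers can simply guard on the return value. A minimal sketch for a single-worker strategy, which does not set `extended._cluster_resolver`:

```python
import tensorflow as tf

# Single-worker strategies have no cluster resolver attached, so the
# StrategyBase property returns None.
mirrored = tf.distribute.MirroredStrategy()
assert mirrored.cluster_resolver is None

resolver = mirrored.cluster_resolver
if resolver is not None:
  print(resolver.cluster_spec().as_dict())
else:
  print('No cluster resolver; running single-worker.')
```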
@@ -28,6 +28,7 @@ from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
 from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import reduce_util
+from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op

@@ -36,6 +37,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
+from tensorflow.python.training import server_lib
 from tensorflow.python.util import nest
 
 

@@ -422,6 +424,17 @@ class TestStrategyTest(test.TestCase):
 
     test_fn()
 
+  def testClusterResolverDefaultNotImplemented(self):
+    dist = _TestStrategy()
+    self.assertIsNone(dist.cluster_resolver)
+    base_cluster_spec = server_lib.ClusterSpec({
+        "ps": ["ps0:2222", "ps1:2222"],
+        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
+    })
+    cluster_resolver = SimpleClusterResolver(base_cluster_spec)
+    dist.extended._cluster_resolver = cluster_resolver
+    self.assertIs(dist.cluster_resolver, cluster_resolver)
+
 
 # _TestStrategy2 is like _TestStrategy, except it doesn't change variable
 # creation.

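For illustration, the `ClusterSpec`/`SimpleClusterResolver` pair used by the test above can be built through the public API; the host names below are placeholders:

```python
import tensorflow as tf

# Hypothetical cluster layout mirroring the spec in the test above.
cluster_spec = tf.train.ClusterSpec({
    'ps': ['ps0:2222', 'ps1:2222'],
    'worker': ['worker0:2222', 'worker1:2222', 'worker2:2222']
})
resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(cluster_spec)

# These are the kinds of queries `strategy.cluster_resolver` is meant to serve.
print(resolver.cluster_spec().as_dict())
print(resolver.master())  # '' unless a master address was given
```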
@@ -27,6 +27,8 @@ from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.distribute import strategy_test_lib
+from tensorflow.python.distribute.collective_all_reduce_strategy import CollectiveAllReduceStrategy
+from tensorflow.python.distribute.tpu_strategy import TPUStrategy
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops

@@ -184,5 +186,38 @@ class DistributedCollectiveAllReduceStrategyTest(
     # worker strategy combinations can run on a fixed number of GPUs.
 
 
+class StrategyClusterResolverTest(test.TestCase, parameterized.TestCase):
+
+  @combinations.generate(
+      combinations.combine(
+          strategy=[strategy_combinations.multi_worker_mirrored_two_workers] +
+          strategy_combinations.all_strategies,
+          mode=['eager']))
+  def testClusterResolverProperty(self, strategy):
+    # CollectiveAllReduceStrategy and TPUStrategy must have a cluster resolver.
+    # `None` otherwise.
+    resolver = strategy.cluster_resolver
+    if not isinstance(strategy, CollectiveAllReduceStrategy) and not isinstance(
+        strategy, TPUStrategy):
+      self.assertIsNone(resolver)
+      return
+
+    with strategy.scope():
+      self.assertIs(strategy.cluster_resolver, resolver)
+      self.assertTrue(hasattr(resolver, 'cluster_spec'))
+      self.assertTrue(hasattr(resolver, 'environment'))
+      self.assertTrue(hasattr(resolver, 'master'))
+      self.assertTrue(hasattr(resolver, 'num_accelerators'))
+      self.assertIsNone(resolver.rpc_layer)
+      if isinstance(strategy, CollectiveAllReduceStrategy):
+        self.assertGreaterEqual(resolver.task_id, 0)
+        self.assertLessEqual(resolver.task_id, 1)
+        self.assertEqual(resolver.task_type, 'worker')
+      elif isinstance(strategy, TPUStrategy):
+        # TPUStrategy does not have task_id and task_type applicable.
+        self.assertIsNone(resolver.task_id)
+        self.assertIsNone(resolver.task_type)
+
+
 if __name__ == '__main__':
   combinations.main()

@@ -345,6 +345,18 @@ class TPUStrategy(distribute_lib.Strategy):
     options = options or distribute_lib.RunOptions()
     return self.extended.tpu_run(fn, args, kwargs, options)
 
+  @property
+  def cluster_resolver(self):
+    """Returns the cluster resolver associated with this strategy.
+
+    `tf.distribute.experimental.TPUStrategy` provides the
+    associated `tf.distribute.cluster_resolver.ClusterResolver`. If the user
+    provides one in `__init__`, that instance is returned; if the user does
+    not, a default
+    `tf.distribute.cluster_resolver.TPUClusterResolver` is provided.
+    """
+    return self.extended._tpu_cluster_resolver  # pylint: disable=protected-access
+
 
 @tf_export(v1=["distribute.experimental.TPUStrategy"])
 class TPUStrategyV1(distribute_lib.StrategyV1):

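A hedged sketch of the TPU path, mirroring the test added below; `'my-tpu'` is a placeholder TPU name and the snippet needs a reachable TPU worker to run:

```python
import tensorflow as tf

# Placeholder TPU name/address; resolution requires a real TPU worker.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='my-tpu')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)

strategy = tf.distribute.experimental.TPUStrategy(resolver)

# The resolver passed to __init__ is what the new property hands back.
assert strategy.cluster_resolver is resolver
print(strategy.cluster_resolver.master())
```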
@@ -555,6 +555,13 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
     update_variable.get_concrete_function()
     self.assertLen(strategy.extended.worker_devices, trace_count[0])
 
+  def test_cluster_resolver_available(self, enable_packed_var):
+    resolver = get_tpu_cluster_resolver()
+    remote.connect_to_cluster(resolver)
+    tpu_strategy_util.initialize_tpu_system(resolver)
+    strategy = tpu_lib.TPUStrategy(resolver)
+    self.assertIs(strategy.cluster_resolver, resolver)
+
 
 class TPUStrategyDataPrefetchTest(test.TestCase):
 

The API golden files gain the same `cluster_resolver` member. The hunk below repeats, essentially verbatim, across the affected goldens: seven for classes whose goldens list `StrategyV1` as a base and eight for classes whose goldens list `Strategy`.

@@ -4,6 +4,10 @@ tf_class {
   is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyV1\'>"
   is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
   is_instance: "<type \'object\'>"
+  member {
+    name: "cluster_resolver"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "extended"
     mtype: "<type \'property\'>"

@@ -4,6 +4,10 @@ tf_class {
   is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
   is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
   is_instance: "<type \'object\'>"
+  member {
+    name: "cluster_resolver"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "extended"
     mtype: "<type \'property\'>"