[retry] Graduate MultiWorkerMirroredStrategy out of experimental
Over the past months we've made several improvements:

- Test coverage is now on par with other strategies.
- Peer failure will no longer cause the cluster to hang.
- Major issues with saving are fixed.
- gather() API is added.

PiperOrigin-RevId: 338175223
Change-Id: I3c52a4d53d1c487558f1caaae7d094fe2245183b
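For readers of this change, here is a minimal sketch (not part of the commit) of the graduated endpoint and the new gather() API, assuming a TF 2.4-style build with TF_CONFIG set on each worker:

```python
import tensorflow as tf

# Non-experimental endpoint graduated by this change; the old
# tf.distribute.experimental name keeps working but is deprecated.
strategy = tf.distribute.MultiWorkerMirroredStrategy()

@tf.function
def replica_fn():
  # Each replica produces a (2, 3) tensor.
  return tf.random.uniform((2, 3))

per_replica = strategy.run(replica_fn)
# gather() concatenates the per-replica values along axis 0 onto the
# current device, so the result has 2 * num_replicas_in_sync rows.
gathered = strategy.gather(per_replica, axis=0)
```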
parent 1fb1f17465
commit 0e14b0fdc4
Changed files:
RELEASE.md
tensorflow/python/distribute
tensorflow/tools/api/golden/v2
@@ -91,6 +91,10 @@
    `tf.config.experimental.enable_tensor_float_32_execution`.

*   `tf.distribute`:
    *   `MultiWorkerMirroredStrategy` is graduated out of experimental.
        *   Peer failure will no longer cause the cluster to hang.
        *   Major issues with saving are fixed.
        *   See [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) for a tutorial.
    *   Deprecated `experimental_distribute_datasets_from_function` method and renamed it to `distribute_datasets_from_function` as it is no longer experimental.

## Bug Fixes and Other Changes
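The release note above renames `experimental_distribute_datasets_from_function` to `distribute_datasets_from_function`. A hedged sketch of the renamed call, not taken from this commit (the dataset contents and the global batch size of 64 are arbitrary examples):

```python
import tensorflow as tf

strategy = tf.distribute.MultiWorkerMirroredStrategy()

def dataset_fn(input_context):
  # Shard the per-worker pipeline and derive the per-replica batch size
  # from the example global batch size.
  batch_size = input_context.get_per_replica_batch_size(64)
  ds = tf.data.Dataset.range(1000)
  ds = ds.shard(input_context.num_input_pipelines,
                input_context.input_pipeline_id)
  return ds.batch(batch_size)

# New, non-experimental name introduced by this change:
dist_ds = strategy.distribute_datasets_from_function(dataset_fn)
# The old experimental_ name remains as a deprecated alias.
```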
@@ -430,16 +430,24 @@ py_library(
        ":collective_util",
        ":cross_device_ops",
        ":cross_device_utils",
        ":device_util",
        ":distribute_lib",
        ":distribute_utils",
        ":input_lib",
        ":mirrored_strategy",
        ":multi_worker_util",
        ":numpy_dataset",
        ":reduce_util",
        ":values",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:collective_ops",
        "//tensorflow/python:errors",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:tf_export",
        "//tensorflow/python:training",
        "//tensorflow/python:util",
        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
        "//tensorflow/python/eager:context",
    ],
@@ -37,6 +37,7 @@ from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import ClusterResolver
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.eager import context
@@ -46,10 +47,12 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.tracking import base
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@tf_export("distribute.experimental.MultiWorkerMirroredStrategy", v1=[])
# pylint: disable=line-too-long
@tf_export("distribute.MultiWorkerMirroredStrategy", v1=[])
class CollectiveAllReduceStrategy(distribute_lib.Strategy):
  """A distribution strategy for synchronous training on multiple workers.

@@ -63,7 +66,12 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):
  `cluster_resolver` correctly. For example, if you are using
  `tf.distribute.cluster_resolver.TFConfigClusterResolver`, each worker needs to
  have its corresponding `task_type` and `task_id` set in the `TF_CONFIG`
  environment variable.
  environment variable. An example TF_CONFIG on worker-0 of a two worker cluster
  is:

  ```
  TF_CONFIG = '{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 0} }'
  ```

  Your program runs on each worker as-is. Note that collectives require each
  worker to participate. All `tf.distribute` and non `tf.distribute` API may use
@@ -76,8 +84,57 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):
  strategy uses. If it's zero, the strategy uses the CPU. All workers need to
  use the same number of devices, otherwise the behavior is undefined.

  This strategy is not intended for TPU. Use
  `tf.distribute.experimental.TPUStrategy` instead.
  This strategy is not intended for TPU. Use `tf.distribute.TPUStrategy`
  instead.

  After setting up TF_CONFIG, using this strategy is similar to using
  `tf.distribute.MirroredStrategy` and `tf.distribute.TPUStrategy`.

  ```
  strategy = tf.distribute.MultiWorkerMirroredStrategy()

  with strategy.scope():
    model = tf.keras.Sequential([
      tf.keras.layers.Dense(2, input_shape=(5,)),
    ])
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

  def dataset_fn(ctx):
    x = np.random.random((2, 5)).astype(np.float32)
    y = np.random.randint(2, size=(2, 1))
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    return dataset.repeat().batch(1, drop_remainder=True)
  dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)

  model.compile()
  model.fit(dist_dataset)
  ```

  You can also write your own training loop:

  ```
  @tf.function
  def train_step(iterator):

    def step_fn(inputs):
      features, labels = inputs
      with tf.GradientTape() as tape:
        logits = model(features, training=True)
        loss = tf.keras.losses.sparse_categorical_crossentropy(
            labels, logits)

      grads = tape.gradient(loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))

    strategy.run(step_fn, args=(next(iterator),))

  for _ in range(NUM_STEP):
    train_step(iterator)
  ```

  See
  [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras)
  for a detailed tutorial.

  __Saving__

@@ -98,6 +155,7 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):
  Tensorflow API.

  """
  # pylint: enable=line-too-long

  # TODO(anjalisridhar): Update our guides with examples showing how we can use
  # the cluster_resolver argument.
@@ -106,21 +164,23 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):
  _collective_key_base = 0

  def __init__(self,
               communication=collective_util.CommunicationImplemenation.AUTO,
               cluster_resolver=None):
               cluster_resolver=None,
               communication_options=None):
    """Creates the strategy.

    Args:
      communication: optional
        `tf.distribute.experimental.CommunicationImplemenation`. This is a hint
        on the preferred collective communication implementation. Possible
        values include `AUTO`, `RING`, and `NCCL`.
      cluster_resolver: optional
        `tf.distribute.cluster_resolver.ClusterResolver`. If `None`,
        `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used.
      communication_options: optional
        `tf.distribute.experimental.CommunicationOptions`. This configures the
        default options for cross device communications. It can be overridden by
        options provided to the communication APIs like
        `tf.distribute.ReplicaContext.all_reduce`. See
        `tf.distribute.experimental.CommunicationOptions` for details.
    """
    communication_options = collective_util.Options(
        implementation=communication)
    if communication_options is None:
      communication_options = collective_util.Options()
    super(CollectiveAllReduceStrategy, self).__init__(
        CollectiveAllReduceExtended(
            self,
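The constructor change above drops the `communication` argument in favor of `communication_options`. A hedged sketch of the new call pattern, not taken from the diff (the NCCL choice is just an example value):

```python
import tensorflow as tf

# Configure the collective implementation through CommunicationOptions
# instead of the removed `communication` constructor argument.
options = tf.distribute.experimental.CommunicationOptions(
    implementation=tf.distribute.experimental.CommunicationImplementation.NCCL)
strategy = tf.distribute.MultiWorkerMirroredStrategy(
    communication_options=options)
```

Per-call options passed to APIs such as `tf.distribute.ReplicaContext.all_reduce` still override this constructor default, as the docstring above notes.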
@@ -136,12 +196,9 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):
        "num_replicas_per_worker").set(self.extended._num_gpus_per_worker)

  @classmethod
  def _from_local_devices(
      cls,
      devices,
      communication=collective_util.CommunicationImplemenation.AUTO):
  def _from_local_devices(cls, devices, communication_options=None):
    """A convenience method to create an object with a list of devices."""
    obj = cls(communication)
    obj = cls(communication_options=communication_options)
    obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices)  # pylint: disable=protected-access
    return obj

@@ -158,11 +215,66 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):
    return self.extended._cluster_resolver  # pylint: disable=protected-access


class _CollectiveAllReduceStrategyExperimentalMeta(type):

  @classmethod
  def __instancecheck__(cls, instance):
    # This is to make isinstance(tf.distribute.MultiWorkerMirroredStrategy(),
    # tf.distribute.experimental.MultiWorkerMirroredStrategy). Some libraries is
    # performing such check.
    return isinstance(instance, CollectiveAllReduceStrategy)


@tf_export("distribute.experimental.MultiWorkerMirroredStrategy", v1=[])
class _CollectiveAllReduceStrategyExperimental(
    CollectiveAllReduceStrategy,
    metaclass=_CollectiveAllReduceStrategyExperimentalMeta):

  __doc__ = CollectiveAllReduceStrategy.__doc__

  @deprecation.deprecated(
      None, "use distribute.MultiWorkerMirroredStrategy instead")
  def __init__(self,
               communication=collective_util.CommunicationImplemenation.AUTO,
               cluster_resolver=None):
    """Creates the strategy.

    Args:
      communication: optional
        `tf.distribute.experimental.CommunicationImplementation`. This is a hint
        on the preferred collective communication implementation. Possible
        values include `AUTO`, `RING`, and `NCCL`.
      cluster_resolver: optional
        `tf.distribute.cluster_resolver.ClusterResolver`. If `None`,
        `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used.
    """
    communication_options = collective_util.Options(
        implementation=communication)
    super(_CollectiveAllReduceStrategyExperimental,
          self).__init__(cluster_resolver, communication_options)

  @classmethod
  def _from_local_devices(
      cls,
      devices,
      communication=collective_util.CommunicationImplemenation.AUTO):
    """A convenience method to create an object with a list of devices."""
    obj = cls(communication)
    obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices)  # pylint: disable=protected-access
    return obj


_CollectiveAllReduceStrategyExperimental.__name__ = CollectiveAllReduceStrategy.__name__


@tf_export(v1=["distribute.experimental.MultiWorkerMirroredStrategy"])  # pylint: disable=missing-docstring
class CollectiveAllReduceStrategyV1(distribute_lib.StrategyV1):

  __doc__ = CollectiveAllReduceStrategy.__doc__

  # The starting number for collective keys. This should only be set in tests.
  _collective_key_base = 0

  def __init__(self,
               communication=collective_util.CommunicationImplemenation.AUTO,
               cluster_resolver=None):
@@ -200,9 +312,16 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):

  def __init__(self, container_strategy, cluster_resolver,
               communication_options):
    if not isinstance(communication_options, collective_util.Options):
      raise ValueError("communication_options must be an instance of "
                       "tf.distribute.experimental.CommunicationOptions")
    self._cluster_resolver = cluster_resolver or TFConfigClusterResolver()
    if not isinstance(self._cluster_resolver, ClusterResolver):
      raise ValueError("cluster_resolver must be an instance of "
                       "tf.distribute.cluster_resolver.ClusterResolver")
    distribute_lib.StrategyExtendedV1.__init__(self, container_strategy)
    self._communication_options = communication_options
    self._collective_key_base = container_strategy._collective_key_base  # pylint: disable=protected-access
    self._initialize_strategy(self._cluster_resolver)
    self._cfer_fn_cache = weakref.WeakKeyDictionary()
    self.experimental_enable_get_next_as_optional = True
@@ -248,7 +367,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
    self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)

    self._collective_keys = cross_device_utils.CollectiveKeys(
        group_key_start=1 + CollectiveAllReduceStrategy._collective_key_base)  # pylint: disable=protected-access
        group_key_start=1 + self._collective_key_base)
    self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
        devices=local_devices,
        group_size=len(local_devices),
@@ -363,7 +482,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
      local_devices = (self._worker_device,)

    self._collective_keys = cross_device_utils.CollectiveKeys(
        group_key_start=1 + CollectiveAllReduceStrategy._collective_key_base)  # pylint: disable=protected-access
        group_key_start=1 + self._collective_key_base)
    self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
        devices=local_devices,
        group_size=len(local_devices) * self._num_workers,
@@ -62,6 +62,8 @@ CollectiveAllReduceStrategy = (
    collective_all_reduce_strategy.CollectiveAllReduceStrategy)
CollectiveAllReduceExtended = (
    collective_all_reduce_strategy.CollectiveAllReduceExtended)
_CollectiveAllReduceStrategyExperimental = (
    collective_all_reduce_strategy._CollectiveAllReduceStrategyExperimental)


def create_test_objects(cluster_spec=None,
@@ -610,5 +612,27 @@ class CollectiveAllReduceStrategyV2Test(test.TestCase, parameterized.TestCase):
        strategy.extended._num_workers, results[1].numpy())


class ExperimentalCompatibilityTest(test.TestCase):

  def testIsInstance(self):
    # It's not uncommon for people to special case MultiWorkerMirroredStrategy,
    # so we need to make sure isinstance check works for combinations between
    # the experimental and non-experimental endpoints.
    strategy = CollectiveAllReduceStrategy()
    experimental_strategy = _CollectiveAllReduceStrategyExperimental()
    self.assertIsInstance(strategy, CollectiveAllReduceStrategy)
    self.assertIsInstance(strategy, _CollectiveAllReduceStrategyExperimental)
    self.assertIsInstance(experimental_strategy, CollectiveAllReduceStrategy)
    self.assertIsInstance(experimental_strategy,
                          _CollectiveAllReduceStrategyExperimental)

  def testName(self):
    # Estimator checks the __name__ to special case MultiWorkerMirroredStrategy.
    self.assertEqual(CollectiveAllReduceStrategy.__name__,
                     'CollectiveAllReduceStrategy')
    self.assertEqual(_CollectiveAllReduceStrategyExperimental.__name__,
                     'CollectiveAllReduceStrategy')


if __name__ == '__main__':
  test_util.main()
@@ -470,8 +470,9 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
      devices = ["/device:CPU:0"]

    if use_strategy_object:
      comm_options = collective_util.Options(implementation=communication)
      strategy = (mwms_lib.CollectiveAllReduceStrategy
                  ._from_local_devices(devices, communication=communication))  # pylint: disable=protected-access
                  ._from_local_devices(devices, comm_options))  # pylint: disable=protected-access
      return strategy, devices, ""
    else:
      collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
@@ -500,8 +501,9 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
          task_type=task_type,
          task_id=task_id,
          num_accelerators={"GPU": num_gpus})
      comm_options = collective_util.Options(implementation=communication)
      strategy = mwms_lib.CollectiveAllReduceStrategy(
          cluster_resolver=resolver, communication=communication)
          communication_options=comm_options, cluster_resolver=resolver)
      return (strategy, devices,
              "grpc://" + self._cluster_spec[task_type][task_id])
    else:
@@ -0,0 +1,91 @@
path: "tensorflow.distribute.MultiWorkerMirroredStrategy"
tf_class {
  is_instance: "<class \'tensorflow.python.distribute.collective_all_reduce_strategy.CollectiveAllReduceStrategy\'>"
  is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
  is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "cluster_resolver"
    mtype: "<type \'property\'>"
  }
  member {
    name: "extended"
    mtype: "<type \'property\'>"
  }
  member {
    name: "num_replicas_in_sync"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'cluster_resolver\', \'communication_options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
  }
  member_method {
    name: "colocate_vars_with"
    argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "configure"
    argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "distribute_datasets_from_function"
    argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "experimental_distribute_dataset"
    argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "experimental_distribute_datasets_from_function"
    argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "experimental_distribute_values_from_function"
    argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "experimental_local_results"
    argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "experimental_run"
    argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "gather"
    argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "group"
    argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "make_dataset_iterator"
    argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "make_input_fn_iterator"
    argspec: "args=[\'self\', \'input_fn\', \'replication_mode\'], varargs=None, keywords=None, defaults=[\'InputReplicationMode.PER_WORKER\'], "
  }
  member_method {
    name: "reduce"
    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "run"
    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
  }
  member_method {
    name: "scope"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "unwrap"
    argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "update_config_proto"
    argspec: "args=[\'self\', \'config_proto\'], varargs=None, keywords=None, defaults=None"
  }
}
@@ -1,5 +1,6 @@
path: "tensorflow.distribute.experimental.MultiWorkerMirroredStrategy"
tf_class {
  is_instance: "<class \'tensorflow.python.distribute.collective_all_reduce_strategy._CollectiveAllReduceStrategyExperimental\'>"
  is_instance: "<class \'tensorflow.python.distribute.collective_all_reduce_strategy.CollectiveAllReduceStrategy\'>"
  is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
  is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
@@ -22,7 +22,7 @@ tf_module {
  }
  member {
    name: "MultiWorkerMirroredStrategy"
    mtype: "<type \'type\'>"
    mtype: "<class \'tensorflow.python.distribute.collective_all_reduce_strategy._CollectiveAllReduceStrategyExperimentalMeta\'>"
  }
  member {
    name: "ParameterServerStrategy"
@@ -36,6 +36,10 @@ tf_module {
    name: "MirroredStrategy"
    mtype: "<type \'type\'>"
  }
  member {
    name: "MultiWorkerMirroredStrategy"
    mtype: "<type \'type\'>"
  }
  member {
    name: "NcclAllReduce"
    mtype: "<type \'type\'>"