[retry] Graduate MultiWorkerMirroredStrategy out of experimental

Over the past months we've made several improvements:
  - Test coverage is now on par with other strategies.
  - Peer failure will no longer cause the cluster to hang.
  - Major issues with saving are fixed.
  - gather() API is added (see the usage sketch below).
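
A minimal usage sketch of the graduated endpoint and the new gather() API (illustrative only, not part of this change: the tensor values are placeholders, and without a TF_CONFIG the strategy typically runs as a single local worker):

```
# Hedged sketch, assuming TF 2.4+ where tf.distribute.MultiWorkerMirroredStrategy
# and Strategy.gather() are available. Run as-is on each worker of the cluster.
import tensorflow as tf

strategy = tf.distribute.MultiWorkerMirroredStrategy()  # stable, non-experimental endpoint

def replica_fn():
  # Each replica contributes a small vector; the values are placeholders.
  return tf.constant([1.0, 2.0])

per_replica = strategy.run(replica_fn)
# gather() concatenates the per-replica values along `axis` onto the current device.
gathered = strategy.gather(per_replica, axis=0)
print(gathered)
```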

PiperOrigin-RevId: 338175223
Change-Id: I3c52a4d53d1c487558f1caaae7d094fe2245183b
Authored by Ran Chen on 2020-10-20 18:09:24 -07:00, committed by TensorFlower Gardener
parent 1fb1f17465
commit 0e14b0fdc4
9 changed files with 275 additions and 22 deletions

View File

@@ -91,6 +91,10 @@
  `tf.config.experimental.enable_tensor_float_32_execution`.
* `tf.distribute`:
  * `MultiWorkerMirroredStrategy` is graduated out of experimental.
  * Peer failure will no longer cause the cluster to hang.
  * Major issues with saving are fixed.
  * See [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) for a tutorial.
  * Deprecated `experimental_distribute_datasets_from_function` method and renamed it to `distribute_datasets_from_function` as it is no longer experimental.
## Bug Fixes and Other Changes
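
The release note above mentions the rename of `experimental_distribute_datasets_from_function`; a minimal migration sketch (the dataset contents and the global batch size of 64 are placeholders, not part of this commit):

```
# Hedged sketch of the rename, assuming TF 2.4+.
import tensorflow as tf

strategy = tf.distribute.MultiWorkerMirroredStrategy()

def dataset_fn(input_context):
  # Placeholder per-worker dataset; shard and batch to suit your data.
  batch_size = input_context.get_per_replica_batch_size(64)  # 64 = assumed global batch size
  return tf.data.Dataset.range(1024).batch(batch_size)

# Before (deprecated):
#   dist_dataset = strategy.experimental_distribute_datasets_from_function(dataset_fn)
# After:
dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
```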

View File

@@ -430,16 +430,24 @@ py_library(
        ":collective_util",
        ":cross_device_ops",
        ":cross_device_utils",
        ":device_util",
        ":distribute_lib",
        ":distribute_utils",
        ":input_lib",
        ":mirrored_strategy",
        ":multi_worker_util",
        ":numpy_dataset",
        ":reduce_util",
        ":values",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:collective_ops",
        "//tensorflow/python:errors",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:tf_export",
        "//tensorflow/python:training",
        "//tensorflow/python:util",
        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
        "//tensorflow/python/eager:context",
    ],

View File

@@ -37,6 +37,7 @@ from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import ClusterResolver
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.eager import context
@@ -46,10 +47,12 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.tracking import base
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export

-@tf_export("distribute.experimental.MultiWorkerMirroredStrategy", v1=[])
+# pylint: disable=line-too-long
+@tf_export("distribute.MultiWorkerMirroredStrategy", v1=[])
class CollectiveAllReduceStrategy(distribute_lib.Strategy):
  """A distribution strategy for synchronous training on multiple workers.
@@ -63,7 +66,12 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):
  `cluster_resolver` correctly. For example, if you are using
  `tf.distribute.cluster_resolver.TFConfigClusterResolver`, each worker needs to
  have its corresponding `task_type` and `task_id` set in the `TF_CONFIG`
-  environment variable.
+  environment variable. An example TF_CONFIG on worker-0 of a two worker cluster
+  is:

  ```
  TF_CONFIG = '{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 0} }'
  ```

  Your program runs on each worker as-is. Note that collectives require each
  worker to participate. All `tf.distribute` and non `tf.distribute` API may use
@@ -76,8 +84,57 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):
  strategy uses. If it's zero, the strategy uses the CPU. All workers need to
  use the same number of devices, otherwise the behavior is undefined.

-  This strategy is not intended for TPU. Use
-  `tf.distribute.experimental.TPUStrategy` instead.
+  This strategy is not intended for TPU. Use `tf.distribute.TPUStrategy`
+  instead.

  After setting up TF_CONFIG, using this strategy is similar to using
  `tf.distribute.MirroredStrategy` and `tf.distribute.TPUStrategy`.

  ```
  strategy = tf.distribute.MultiWorkerMirroredStrategy()

  with strategy.scope():
    model = tf.keras.Sequential([
      tf.keras.layers.Dense(2, input_shape=(5,)),
    ])
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

  def dataset_fn(ctx):
    x = np.random.random((2, 5)).astype(np.float32)
    y = np.random.randint(2, size=(2, 1))
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    return dataset.repeat().batch(1, drop_remainder=True)

  dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)

  model.compile()
  model.fit(dist_dataset)
  ```

  You can also write your own training loop:

  ```
  @tf.function
  def train_step(iterator):

    def step_fn(inputs):
      features, labels = inputs
      with tf.GradientTape() as tape:
        logits = model(features, training=True)
        loss = tf.keras.losses.sparse_categorical_crossentropy(
            labels, logits)
      grads = tape.gradient(loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))

    strategy.run(step_fn, args=(next(iterator),))

  for _ in range(NUM_STEP):
    train_step(iterator)
  ```

  See
  [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras)
  for a detailed tutorial.

  __Saving__
@@ -98,6 +155,7 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):
  Tensorflow API.
  """
  # pylint: enable=line-too-long

  # TODO(anjalisridhar): Update our guides with examples showing how we can use
  # the cluster_resolver argument.
@@ -106,21 +164,23 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):
  _collective_key_base = 0

  def __init__(self,
-               communication=collective_util.CommunicationImplemenation.AUTO,
-               cluster_resolver=None):
+               cluster_resolver=None,
+               communication_options=None):
    """Creates the strategy.

    Args:
-      communication: optional
-        `tf.distribute.experimental.CommunicationImplemenation`. This is a hint
-        on the preferred collective communication implementation. Possible
-        values include `AUTO`, `RING`, and `NCCL`.
      cluster_resolver: optional
        `tf.distribute.cluster_resolver.ClusterResolver`. If `None`,
        `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used.
      communication_options: optional
        `tf.distribute.experimental.CommunicationOptions`. This configures the
        default options for cross device communications. It can be overridden by
        options provided to the communication APIs like
        `tf.distribute.ReplicaContext.all_reduce`. See
        `tf.distribute.experimental.CommunicationOptions` for details.
    """
-    communication_options = collective_util.Options(
-        implementation=communication)
+    if communication_options is None:
+      communication_options = collective_util.Options()
    super(CollectiveAllReduceStrategy, self).__init__(
        CollectiveAllReduceExtended(
            self,
@@ -136,12 +196,9 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):
        "num_replicas_per_worker").set(self.extended._num_gpus_per_worker)

  @classmethod
-  def _from_local_devices(
-      cls,
-      devices,
-      communication=collective_util.CommunicationImplemenation.AUTO):
+  def _from_local_devices(cls, devices, communication_options=None):
    """A convenience method to create an object with a list of devices."""
-    obj = cls(communication)
+    obj = cls(communication_options=communication_options)
    obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices)  # pylint: disable=protected-access
    return obj
@@ -158,11 +215,66 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):
    return self.extended._cluster_resolver  # pylint: disable=protected-access


class _CollectiveAllReduceStrategyExperimentalMeta(type):

  @classmethod
  def __instancecheck__(cls, instance):
    # This is to make isinstance(tf.distribute.MultiWorkerMirroredStrategy(),
    # tf.distribute.experimental.MultiWorkerMirroredStrategy). Some libraries is
    # performing such check.
    return isinstance(instance, CollectiveAllReduceStrategy)


@tf_export("distribute.experimental.MultiWorkerMirroredStrategy", v1=[])
class _CollectiveAllReduceStrategyExperimental(
    CollectiveAllReduceStrategy,
    metaclass=_CollectiveAllReduceStrategyExperimentalMeta):

  __doc__ = CollectiveAllReduceStrategy.__doc__

  @deprecation.deprecated(
      None, "use distribute.MultiWorkerMirroredStrategy instead")
  def __init__(self,
               communication=collective_util.CommunicationImplemenation.AUTO,
               cluster_resolver=None):
    """Creates the strategy.

    Args:
      communication: optional
        `tf.distribute.experimental.CommunicationImplementation`. This is a hint
        on the preferred collective communication implementation. Possible
        values include `AUTO`, `RING`, and `NCCL`.
      cluster_resolver: optional
        `tf.distribute.cluster_resolver.ClusterResolver`. If `None`,
        `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used.
    """
    communication_options = collective_util.Options(
        implementation=communication)
    super(_CollectiveAllReduceStrategyExperimental,
          self).__init__(cluster_resolver, communication_options)

  @classmethod
  def _from_local_devices(
      cls,
      devices,
      communication=collective_util.CommunicationImplemenation.AUTO):
    """A convenience method to create an object with a list of devices."""
    obj = cls(communication)
    obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices)  # pylint: disable=protected-access
    return obj


_CollectiveAllReduceStrategyExperimental.__name__ = CollectiveAllReduceStrategy.__name__


@tf_export(v1=["distribute.experimental.MultiWorkerMirroredStrategy"])  # pylint: disable=missing-docstring
class CollectiveAllReduceStrategyV1(distribute_lib.StrategyV1):

  __doc__ = CollectiveAllReduceStrategy.__doc__

  # The starting number for collective keys. This should only be set in tests.
  _collective_key_base = 0

  def __init__(self,
               communication=collective_util.CommunicationImplemenation.AUTO,
               cluster_resolver=None):
@@ -200,9 +312,16 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
  def __init__(self, container_strategy, cluster_resolver,
               communication_options):
    if not isinstance(communication_options, collective_util.Options):
      raise ValueError("communication_options must be an instance of "
                       "tf.distribute.experimental.CommunicationOptions")
    self._cluster_resolver = cluster_resolver or TFConfigClusterResolver()
    if not isinstance(self._cluster_resolver, ClusterResolver):
      raise ValueError("cluster_resolver must be an instance of "
                       "tf.distribute.cluster_resolver.ClusterResolver")
    distribute_lib.StrategyExtendedV1.__init__(self, container_strategy)
    self._communication_options = communication_options
    self._collective_key_base = container_strategy._collective_key_base  # pylint: disable=protected-access
    self._initialize_strategy(self._cluster_resolver)
    self._cfer_fn_cache = weakref.WeakKeyDictionary()
    self.experimental_enable_get_next_as_optional = True
@@ -248,7 +367,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
    self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)

    self._collective_keys = cross_device_utils.CollectiveKeys(
-        group_key_start=1 + CollectiveAllReduceStrategy._collective_key_base)  # pylint: disable=protected-access
+        group_key_start=1 + self._collective_key_base)
    self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
        devices=local_devices,
        group_size=len(local_devices),
@@ -363,7 +482,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
      local_devices = (self._worker_device,)

    self._collective_keys = cross_device_utils.CollectiveKeys(
-        group_key_start=1 + CollectiveAllReduceStrategy._collective_key_base)  # pylint: disable=protected-access
+        group_key_start=1 + self._collective_key_base)
    self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
        devices=local_devices,
        group_size=len(local_devices) * self._num_workers,
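
The constructor change above replaces the experimental `communication` enum argument with a single `communication_options` object; a hedged sketch of the new calling convention (the NCCL choice is only an example, not a recommendation from this commit):

```
# Hedged sketch, assuming TF 2.4+ where CommunicationOptions is exported.
import tensorflow as tf

# Old experimental endpoint (now deprecated):
#   tf.distribute.experimental.MultiWorkerMirroredStrategy(
#       communication=tf.distribute.experimental.CommunicationImplementation.NCCL)

# New stable endpoint: communication settings live in CommunicationOptions and
# can still be overridden per call, e.g. in ReplicaContext.all_reduce.
options = tf.distribute.experimental.CommunicationOptions(
    implementation=tf.distribute.experimental.CommunicationImplementation.NCCL)
strategy = tf.distribute.MultiWorkerMirroredStrategy(
    cluster_resolver=None,  # None defaults to TFConfigClusterResolver
    communication_options=options)
```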

View File

@@ -62,6 +62,8 @@ CollectiveAllReduceStrategy = (
    collective_all_reduce_strategy.CollectiveAllReduceStrategy)
CollectiveAllReduceExtended = (
    collective_all_reduce_strategy.CollectiveAllReduceExtended)
_CollectiveAllReduceStrategyExperimental = (
    collective_all_reduce_strategy._CollectiveAllReduceStrategyExperimental)


def create_test_objects(cluster_spec=None,
@@ -610,5 +612,27 @@ class CollectiveAllReduceStrategyV2Test(test.TestCase, parameterized.TestCase):
        strategy.extended._num_workers, results[1].numpy())


class ExperimentalCompatibilityTest(test.TestCase):

  def testIsInstance(self):
    # It's not uncommon for people to special case MultiWorkerMirroredStrategy,
    # so we need to make sure isinstance check works for combinations between
    # the experimental and non-experimental endpoints.
    strategy = CollectiveAllReduceStrategy()
    experimental_strategy = _CollectiveAllReduceStrategyExperimental()
    self.assertIsInstance(strategy, CollectiveAllReduceStrategy)
    self.assertIsInstance(strategy, _CollectiveAllReduceStrategyExperimental)
    self.assertIsInstance(experimental_strategy, CollectiveAllReduceStrategy)
    self.assertIsInstance(experimental_strategy,
                          _CollectiveAllReduceStrategyExperimental)

  def testName(self):
    # Estimator checks the __name__ to special case MultiWorkerMirroredStrategy.
    self.assertEqual(CollectiveAllReduceStrategy.__name__,
                     'CollectiveAllReduceStrategy')
    self.assertEqual(_CollectiveAllReduceStrategyExperimental.__name__,
                     'CollectiveAllReduceStrategy')


if __name__ == '__main__':
  test_util.main()

View File

@@ -470,8 +470,9 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
      devices = ["/device:CPU:0"]
    if use_strategy_object:
      comm_options = collective_util.Options(implementation=communication)
      strategy = (mwms_lib.CollectiveAllReduceStrategy
-                  ._from_local_devices(devices, communication=communication))  # pylint: disable=protected-access
+                  ._from_local_devices(devices, comm_options))  # pylint: disable=protected-access
      return strategy, devices, ""
    else:
      collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
@@ -500,8 +501,9 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
          task_type=task_type,
          task_id=task_id,
          num_accelerators={"GPU": num_gpus})
      comm_options = collective_util.Options(implementation=communication)
      strategy = mwms_lib.CollectiveAllReduceStrategy(
-          cluster_resolver=resolver, communication=communication)
+          communication_options=comm_options, cluster_resolver=resolver)
      return (strategy, devices,
              "grpc://" + self._cluster_spec[task_type][task_id])
    else:

View File

@@ -0,0 +1,91 @@
path: "tensorflow.distribute.MultiWorkerMirroredStrategy"
tf_class {
is_instance: "<class \'tensorflow.python.distribute.collective_all_reduce_strategy.CollectiveAllReduceStrategy\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"
is_instance: "<type \'object\'>"
member {
name: "cluster_resolver"
mtype: "<type \'property\'>"
}
member {
name: "extended"
mtype: "<type \'property\'>"
}
member {
name: "num_replicas_in_sync"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'cluster_resolver\', \'communication_options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "colocate_vars_with"
argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "configure"
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "distribute_datasets_from_function"
argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_distribute_dataset"
argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_distribute_datasets_from_function"
argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_distribute_values_from_function"
argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_local_results"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "gather"
argspec: "args=[\'self\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "group"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "make_dataset_iterator"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "make_input_fn_iterator"
argspec: "args=[\'self\', \'input_fn\', \'replication_mode\'], varargs=None, keywords=None, defaults=[\'InputReplicationMode.PER_WORKER\'], "
}
member_method {
name: "reduce"
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'axis\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "run"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\', \'options\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'None\'], "
}
member_method {
name: "scope"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "unwrap"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "update_config_proto"
argspec: "args=[\'self\', \'config_proto\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@@ -1,5 +1,6 @@
path: "tensorflow.distribute.experimental.MultiWorkerMirroredStrategy"
tf_class {
is_instance: "<class \'tensorflow.python.distribute.collective_all_reduce_strategy._CollectiveAllReduceStrategyExperimental\'>"
is_instance: "<class \'tensorflow.python.distribute.collective_all_reduce_strategy.CollectiveAllReduceStrategy\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.Strategy\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.StrategyBase\'>"

View File

@@ -22,7 +22,7 @@ tf_module {
}
member {
name: "MultiWorkerMirroredStrategy"
-mtype: "<type \'type\'>"
+mtype: "<class \'tensorflow.python.distribute.collective_all_reduce_strategy._CollectiveAllReduceStrategyExperimentalMeta\'>"
}
member {
name: "ParameterServerStrategy"

View File

@@ -36,6 +36,10 @@ tf_module {
name: "MirroredStrategy"
mtype: "<type \'type\'>"
}
member {
name: "MultiWorkerMirroredStrategy"
mtype: "<type \'type\'>"
}
member {
name: "NcclAllReduce"
mtype: "<type \'type\'>"