Add DistStrat ReplicaContext.all_reduce
.
PiperOrigin-RevId: 228302155
This commit is contained in:
parent
07c6aa1a40
commit
81492c074a
@ -397,8 +397,10 @@ class DistributedCollectiveAllReduceStrategyTestWithChief(
|
||||
self._test_complex_model, self._cluster_spec, num_gpus=num_gpus)
|
||||
|
||||
|
||||
class LocalCollectiveAllReduceStrategy(CollectiveAllReduceStrategyTestBase,
|
||||
class LocalCollectiveAllReduceStrategy(
|
||||
CollectiveAllReduceStrategyTestBase,
|
||||
strategy_test_lib.DistributionTestBase,
|
||||
strategy_test_lib.TwoDeviceDistributionTestBase,
|
||||
parameterized.TestCase):
|
||||
|
||||
def testMinimizeLossGraph(self, num_gpus=2):
|
||||
@ -428,6 +430,42 @@ class LocalCollectiveAllReduceStrategy(CollectiveAllReduceStrategyTestBase,
|
||||
self._test_input_fn_iterator(None, None, num_gpus,
|
||||
input_fn, expected_values)
|
||||
|
||||
def testAllReduceSum(self):
|
||||
if context.num_gpus() < 2: self.skipTest('Not enough GPUs')
|
||||
distribution, target, config = self._get_test_object(None, None, num_gpus=2)
|
||||
with self.cached_session(config=config, target=target):
|
||||
self._test_all_reduce_sum(distribution)
|
||||
|
||||
def testAllReduceSumGradients(self):
|
||||
if context.num_gpus() < 2: self.skipTest('Not enough GPUs')
|
||||
distribution, target, config = self._get_test_object(None, None, num_gpus=2)
|
||||
with self.cached_session(config=config, target=target):
|
||||
self._test_all_reduce_sum_gradients(distribution)
|
||||
|
||||
def testAllReduceSumGradientTape(self):
|
||||
if context.num_gpus() < 2: self.skipTest('Not enough GPUs')
|
||||
distribution, target, config = self._get_test_object(None, None, num_gpus=2)
|
||||
with self.cached_session(config=config, target=target):
|
||||
self._test_all_reduce_sum_gradient_tape(distribution)
|
||||
|
||||
def testAllReduceMean(self):
|
||||
if context.num_gpus() < 2: self.skipTest('Not enough GPUs')
|
||||
distribution, target, config = self._get_test_object(None, None, num_gpus=2)
|
||||
with self.cached_session(config=config, target=target):
|
||||
self._test_all_reduce_mean(distribution)
|
||||
|
||||
def testAllReduceMeanGradients(self):
|
||||
if context.num_gpus() < 2: self.skipTest('Not enough GPUs')
|
||||
distribution, target, config = self._get_test_object(None, None, num_gpus=2)
|
||||
with self.cached_session(config=config, target=target):
|
||||
self._test_all_reduce_mean_gradients(distribution)
|
||||
|
||||
def testAllReduceMeanGradientTape(self):
|
||||
if context.num_gpus() < 2: self.skipTest('Not enough GPUs')
|
||||
distribution, target, config = self._get_test_object(None, None, num_gpus=2)
|
||||
with self.cached_session(config=config, target=target):
|
||||
self._test_all_reduce_mean_gradient_tape(distribution)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -66,7 +66,9 @@ GPU_TEST = "test_gpu" in sys.argv[0]
|
||||
combinations.core_mirrored_strategy_with_gpu_and_cpu,
|
||||
combinations.core_mirrored_strategy_with_two_gpus],
|
||||
mode=["graph", "eager"]))
|
||||
class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase,
|
||||
class MirroredTwoDeviceDistributionTest(
|
||||
strategy_test_lib.DistributionTestBase,
|
||||
strategy_test_lib.TwoDeviceDistributionTestBase,
|
||||
parameterized.TestCase):
|
||||
|
||||
def testMinimizeLoss(self, distribution):
|
||||
@ -117,6 +119,24 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase,
|
||||
def testGlobalStepUpdate(self, distribution):
|
||||
self._test_global_step_update(distribution)
|
||||
|
||||
def testAllReduceSum(self, distribution):
|
||||
self._test_all_reduce_sum(distribution)
|
||||
|
||||
def testAllReduceSumGradients(self, distribution):
|
||||
self._test_all_reduce_sum_gradients(distribution)
|
||||
|
||||
def testAllReduceSumGradientTape(self, distribution):
|
||||
self._test_all_reduce_sum_gradient_tape(distribution)
|
||||
|
||||
def testAllReduceMean(self, distribution):
|
||||
self._test_all_reduce_mean(distribution)
|
||||
|
||||
def testAllReduceMeanGradients(self, distribution):
|
||||
self._test_all_reduce_mean_gradients(distribution)
|
||||
|
||||
def testAllReduceMeanGradientTape(self, distribution):
|
||||
self._test_all_reduce_mean_gradient_tape(distribution)
|
||||
|
||||
|
||||
def one_device_combinations():
|
||||
return combinations.combine(
|
||||
@ -128,25 +148,42 @@ def one_device_combinations():
|
||||
mode=["graph", "eager"])
|
||||
|
||||
|
||||
@combinations.generate(one_device_combinations())
|
||||
class MirroredOneDeviceDistributionTest(
|
||||
strategy_test_lib.DistributionTestBase,
|
||||
strategy_test_lib.OneDeviceDistributionTestBase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@combinations.generate(one_device_combinations())
|
||||
def testMinimizeLoss(self, distribution):
|
||||
if context.executing_eagerly():
|
||||
self._test_minimize_loss_eager(distribution)
|
||||
else:
|
||||
self._test_minimize_loss_graph(distribution)
|
||||
|
||||
@combinations.generate(one_device_combinations())
|
||||
def testReplicaId(self, distribution):
|
||||
self._test_replica_id(distribution)
|
||||
|
||||
@combinations.generate(one_device_combinations())
|
||||
def testCallAndMergeExceptions(self, distribution):
|
||||
self._test_call_and_merge_exceptions(distribution)
|
||||
|
||||
def testAllReduceSum(self, distribution):
|
||||
self._test_all_reduce_sum(distribution)
|
||||
|
||||
def testAllReduceSumGradients(self, distribution):
|
||||
self._test_all_reduce_sum_gradients(distribution)
|
||||
|
||||
def testAllReduceSumGradientTape(self, distribution):
|
||||
self._test_all_reduce_sum_gradient_tape(distribution)
|
||||
|
||||
def testAllReduceMean(self, distribution):
|
||||
self._test_all_reduce_mean(distribution)
|
||||
|
||||
def testAllReduceMeanGradients(self, distribution):
|
||||
self._test_all_reduce_mean_gradients(distribution)
|
||||
|
||||
def testAllReduceMeanGradientTape(self, distribution):
|
||||
self._test_all_reduce_mean_gradient_tape(distribution)
|
||||
|
||||
|
||||
class MirroredStrategyVariableCreatorStackTest(
|
||||
test.TestCase, parameterized.TestCase):
|
||||
|
@ -25,7 +25,9 @@ from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import test_util
|
||||
|
||||
|
||||
class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase):
|
||||
class OneDeviceStrategyTest(
|
||||
strategy_test_lib.DistributionTestBase,
|
||||
strategy_test_lib.OneDeviceDistributionTestBase):
|
||||
|
||||
def _get_distribution_strategy(self):
|
||||
return one_device_strategy.OneDeviceStrategy("/device:CPU:0")
|
||||
@ -57,6 +59,24 @@ class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase):
|
||||
self._test_input_fn_iterator(
|
||||
iterator, d.extended.worker_devices, expected_values)
|
||||
|
||||
def testAllReduceSum(self):
|
||||
self._test_all_reduce_sum(self._get_distribution_strategy())
|
||||
|
||||
def testAllReduceSumGradients(self):
|
||||
self._test_all_reduce_sum_gradients(self._get_distribution_strategy())
|
||||
|
||||
def testAllReduceSumGradientTape(self):
|
||||
self._test_all_reduce_sum_gradient_tape(self._get_distribution_strategy())
|
||||
|
||||
def testAllReduceMean(self):
|
||||
self._test_all_reduce_mean(self._get_distribution_strategy())
|
||||
|
||||
def testAllReduceMeanGradients(self):
|
||||
self._test_all_reduce_mean_gradients(self._get_distribution_strategy())
|
||||
|
||||
def testAllReduceMeanGradientTape(self):
|
||||
self._test_all_reduce_mean_gradient_tape(self._get_distribution_strategy())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -603,8 +603,10 @@ class ParameterServerStrategyTestBase(
|
||||
self.assertEqual(expected_value, computed_value)
|
||||
|
||||
|
||||
class ParameterServerStrategyTest(ParameterServerStrategyTestBase,
|
||||
class ParameterServerStrategyTest(
|
||||
ParameterServerStrategyTestBase,
|
||||
strategy_test_lib.DistributionTestBase,
|
||||
strategy_test_lib.TwoDeviceDistributionTestBase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@classmethod
|
||||
@ -782,6 +784,36 @@ class ParameterServerStrategyTest(ParameterServerStrategyTestBase,
|
||||
# Verify isolate_session_state
|
||||
self.assertTrue(new_config.isolate_session_state)
|
||||
|
||||
def testAllReduceSum(self):
|
||||
distribution = parameter_server_strategy.ParameterServerStrategy(
|
||||
num_gpus_per_worker=2)
|
||||
self._test_all_reduce_sum(distribution)
|
||||
|
||||
def testAllReduceSumGradients(self):
|
||||
distribution = parameter_server_strategy.ParameterServerStrategy(
|
||||
num_gpus_per_worker=2)
|
||||
self._test_all_reduce_sum_gradients(distribution)
|
||||
|
||||
def testAllReduceSumGradientTape(self):
|
||||
distribution = parameter_server_strategy.ParameterServerStrategy(
|
||||
num_gpus_per_worker=2)
|
||||
self._test_all_reduce_sum_gradient_tape(distribution)
|
||||
|
||||
def testAllReduceMean(self):
|
||||
distribution = parameter_server_strategy.ParameterServerStrategy(
|
||||
num_gpus_per_worker=2)
|
||||
self._test_all_reduce_mean(distribution)
|
||||
|
||||
def testAllReduceMeanGradients(self):
|
||||
distribution = parameter_server_strategy.ParameterServerStrategy(
|
||||
num_gpus_per_worker=2)
|
||||
self._test_all_reduce_mean_gradients(distribution)
|
||||
|
||||
def testAllReduceMeanGradientTape(self):
|
||||
distribution = parameter_server_strategy.ParameterServerStrategy(
|
||||
num_gpus_per_worker=2)
|
||||
self._test_all_reduce_mean_gradient_tape(distribution)
|
||||
|
||||
|
||||
class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase,
|
||||
parameterized.TestCase):
|
||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
||||
from tensorflow.python.distribute import reduce_util
|
||||
from tensorflow.python.distribute import values
|
||||
@ -31,6 +32,7 @@ from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.layers import core
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
@ -292,3 +294,163 @@ class DistributionTestBase(test.TestCase):
|
||||
global_step_tensors = strategy.unwrap(value)
|
||||
global_step_values = self.evaluate(global_step_tensors)
|
||||
self.assertEqual((1,) * len(global_step_tensors), global_step_values)
|
||||
|
||||
|
||||
class OneDeviceDistributionTestBase(test.TestCase):
|
||||
"""Some tests that should work with any one-device DistributionStrategy."""
|
||||
|
||||
def _test_all_reduce_sum(self, strategy):
|
||||
self._test_collective_comms(
|
||||
strategy, _all_sum, inputs=(4., [42., 43.]), expected=(4., [42., 43.]))
|
||||
|
||||
def _test_all_reduce_sum_gradients(self, strategy):
|
||||
self._test_collective_comms_gradients(
|
||||
strategy, _all_sum, inputs=[4.], expected_grads=[4.])
|
||||
|
||||
def _test_all_reduce_sum_gradient_tape(self, strategy):
|
||||
self._test_collective_comms_gradient_tape(
|
||||
strategy, _all_sum, inputs=[4.], expected_grads=[4.])
|
||||
|
||||
def _test_all_reduce_mean(self, strategy):
|
||||
self._test_collective_comms(
|
||||
strategy, _all_mean, inputs=(2., [21., 22.]), expected=(2., [21., 22.]))
|
||||
|
||||
def _test_all_reduce_mean_gradients(self, strategy):
|
||||
self._test_collective_comms_gradients(
|
||||
strategy, _all_mean, inputs=[5.], expected_grads=[5.])
|
||||
|
||||
def _test_all_reduce_mean_gradient_tape(self, strategy):
|
||||
self._test_collective_comms_gradient_tape(
|
||||
strategy, _all_mean, inputs=[5.], expected_grads=[5.])
|
||||
|
||||
def _test_collective_comms(self, strategy, comm_fn, inputs, expected):
|
||||
inputs = strategy.make_input_fn_iterator(
|
||||
lambda _: dataset_ops.Dataset.from_tensors(inputs))
|
||||
|
||||
self.evaluate(inputs.initialize())
|
||||
outputs = self.evaluate(
|
||||
list(map(strategy.unwrap, strategy.experimental_run(comm_fn, inputs))))
|
||||
self.assertAllEqual([expected[0]], outputs[0])
|
||||
self.assertAllEqual([expected[1]], outputs[1])
|
||||
|
||||
def _test_collective_comms_gradients(
|
||||
self, strategy, comm_fn, inputs, expected_grads):
|
||||
if context.executing_eagerly():
|
||||
self.skipTest("`tf.gradients` is not supported with eager execution.")
|
||||
|
||||
def step(c):
|
||||
x = constant_op.constant(42.)
|
||||
y = comm_fn(x) * c
|
||||
return gradients_impl.gradients(y, [x])[0]
|
||||
|
||||
inputs = strategy.make_input_fn_iterator(
|
||||
lambda _: dataset_ops.Dataset.from_tensors(inputs))
|
||||
|
||||
self.evaluate(inputs.initialize())
|
||||
self.assertAllEqual(
|
||||
expected_grads,
|
||||
self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs))))
|
||||
|
||||
def _test_collective_comms_gradient_tape(
|
||||
self, strategy, comm_fn, inputs, expected_grads):
|
||||
def step(c):
|
||||
x = constant_op.constant(42.)
|
||||
with backprop.GradientTape() as tape:
|
||||
tape.watch(x)
|
||||
y = comm_fn(x) * c
|
||||
return tape.gradient(y, x)
|
||||
|
||||
inputs = strategy.make_input_fn_iterator(
|
||||
lambda _: dataset_ops.Dataset.from_tensors(inputs))
|
||||
|
||||
self.evaluate(inputs.initialize())
|
||||
self.assertAllEqual(
|
||||
expected_grads,
|
||||
self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs))))
|
||||
|
||||
|
||||
class TwoDeviceDistributionTestBase(test.TestCase):
|
||||
"""Some tests that should work with any two-device DistributionStrategy."""
|
||||
|
||||
def _test_all_reduce_sum(self, strategy):
|
||||
self._test_collective_comms(
|
||||
strategy, _all_sum,
|
||||
inputs=([1., 3.], [[39., 2.], [3., 41.]]),
|
||||
expected=(4., [42., 43.]))
|
||||
|
||||
def _test_all_reduce_sum_gradients(self, strategy):
|
||||
self._test_collective_comms_gradients(
|
||||
strategy, _all_sum, inputs=[1., 3.], expected_grads=[4., 4.])
|
||||
|
||||
def _test_all_reduce_sum_gradient_tape(self, strategy):
|
||||
self._test_collective_comms_gradient_tape(
|
||||
strategy, _all_sum, inputs=[1., 3.], expected_grads=[4., 4.])
|
||||
|
||||
def _test_all_reduce_mean(self, strategy):
|
||||
self._test_collective_comms(
|
||||
strategy, _all_mean,
|
||||
inputs=([1., 3.], [[39., 2.], [3., 41.]]),
|
||||
expected=(2., [21., 21.5]))
|
||||
|
||||
def _test_all_reduce_mean_gradients(self, strategy):
|
||||
self._test_collective_comms_gradients(
|
||||
strategy, _all_mean, inputs=[1., 3.], expected_grads=[2., 2.])
|
||||
|
||||
def _test_all_reduce_mean_gradient_tape(self, strategy):
|
||||
self._test_collective_comms_gradient_tape(
|
||||
strategy, _all_mean, inputs=[1., 3.], expected_grads=[2., 2.])
|
||||
|
||||
def _test_collective_comms(self, strategy, comm_fn, inputs, expected):
|
||||
inputs = strategy.make_input_fn_iterator(
|
||||
lambda _: dataset_ops.Dataset.from_tensor_slices(inputs))
|
||||
|
||||
self.evaluate(inputs.initialize())
|
||||
outputs = self.evaluate(
|
||||
list(map(strategy.unwrap, strategy.experimental_run(comm_fn, inputs))))
|
||||
self.assertAllEqual([expected[0], expected[0]], outputs[0])
|
||||
self.assertAllEqual([expected[1], expected[1]], outputs[1])
|
||||
|
||||
def _test_collective_comms_gradients(
|
||||
self, strategy, comm_fn, inputs, expected_grads):
|
||||
if context.executing_eagerly():
|
||||
self.skipTest("`tf.gradients` is not supported with eager execution.")
|
||||
|
||||
def step(c):
|
||||
x = constant_op.constant(42.)
|
||||
y = comm_fn(x) * c
|
||||
return gradients_impl.gradients(y, [x])[0]
|
||||
|
||||
inputs = strategy.make_input_fn_iterator(
|
||||
lambda _: dataset_ops.Dataset.from_tensor_slices(inputs))
|
||||
|
||||
self.evaluate(inputs.initialize())
|
||||
self.assertAllEqual(
|
||||
expected_grads,
|
||||
self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs))))
|
||||
|
||||
def _test_collective_comms_gradient_tape(
|
||||
self, strategy, comm_fn, inputs, expected_grads):
|
||||
def step(c):
|
||||
x = constant_op.constant(42.)
|
||||
with backprop.GradientTape() as tape:
|
||||
tape.watch(x)
|
||||
y = comm_fn(x) * c
|
||||
return tape.gradient(y, x)
|
||||
|
||||
inputs = strategy.make_input_fn_iterator(
|
||||
lambda _: dataset_ops.Dataset.from_tensor_slices(inputs))
|
||||
|
||||
self.evaluate(inputs.initialize())
|
||||
self.assertAllEqual(
|
||||
expected_grads,
|
||||
self.evaluate(strategy.unwrap(strategy.experimental_run(step, inputs))))
|
||||
|
||||
|
||||
def _all_sum(value):
|
||||
ctx = ds_context.get_replica_context()
|
||||
return ctx.all_reduce(reduce_util.ReduceOp.SUM, value)
|
||||
|
||||
|
||||
def _all_mean(value):
|
||||
ctx = ds_context.get_replica_context()
|
||||
return ctx.all_reduce(reduce_util.ReduceOp.MEAN, value)
|
||||
|
@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import custom_gradient
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops.losses import losses_impl
|
||||
@ -1554,6 +1555,50 @@ class ReplicaContext(object):
|
||||
require_replica_context(self)
|
||||
return (device_util.current(),)
|
||||
|
||||
def all_reduce(self, reduce_op, value):
|
||||
"""All-reduces the given `Tensor` nest across replicas.
|
||||
|
||||
If `all_reduce` is called in any replica, it must be called in all replicas.
|
||||
The nested structure and `Tensor` shapes must be identical in all replicas.
|
||||
|
||||
IMPORTANT: The ordering of communications must be identical in all replicas.
|
||||
|
||||
Example with two replicas:
|
||||
Replica 0 `value`: {'a': 1, 'b': [40, 1]}
|
||||
Replica 1 `value`: {'a': 3, 'b': [ 2, 98]}
|
||||
|
||||
If `reduce_op` == `SUM`:
|
||||
Result (on all replicas): {'a': 4, 'b': [42, 99]}
|
||||
|
||||
If `reduce_op` == `MEAN`:
|
||||
Result (on all replicas): {'a': 2, 'b': [21, 49.5]}
|
||||
|
||||
Args:
|
||||
reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum.
|
||||
value: The nested structure of `Tensor`s to all-reduced.
|
||||
The structure must be compatible with `tf.nest`.
|
||||
|
||||
Returns:
|
||||
A `Tensor` nest with the reduced `value`s from each replica.
|
||||
"""
|
||||
def batch_all_reduce(strategy, *value_flat):
|
||||
return strategy.extended.batch_reduce_to(
|
||||
reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat])
|
||||
|
||||
if reduce_op in [reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN]:
|
||||
# TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad.
|
||||
@custom_gradient.custom_gradient
|
||||
def grad_wrapper(*xs):
|
||||
ys = self.merge_call(batch_all_reduce, args=xs)
|
||||
# The gradient of an all-sum is itself an all-sum (all-mean, likewise).
|
||||
return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s)
|
||||
return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value)))
|
||||
else:
|
||||
# TODO(cjfj): Implement gradients for other reductions.
|
||||
reduced = nest.pack_sequence_as(
|
||||
value, self.merge_call(batch_all_reduce, args=nest.flatten(value)))
|
||||
return nest.map_structure(array_ops.prevent_gradient, reduced)
|
||||
|
||||
# TODO(josh11b): Implement `start_all_reduce(method, t)` for efficient
|
||||
# all-reduce. It would return a function returning the result of reducing `t`
|
||||
# across all replicas. The caller would wait to call this function until they
|
||||
@ -1564,6 +1609,15 @@ class ReplicaContext(object):
|
||||
# to that point that the first result is needed. Most likely this can be
|
||||
# implemented in terms of `merge_call()` and `batch_reduce_to()`.
|
||||
|
||||
|
||||
def _batch_reduce_destination(x):
|
||||
"""Returns the destinations for batch all-reduce."""
|
||||
if isinstance(x, ops.Tensor): # One device strategies.
|
||||
return x.device
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
@ -26,6 +26,10 @@ tf_class {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'strategy\', \'replica_id_in_sync_group\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "all_reduce"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'value\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "merge_call"
|
||||
argspec: "args=[\'self\', \'merge_fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
|
@ -26,6 +26,10 @@ tf_class {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'strategy\', \'replica_id_in_sync_group\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "all_reduce"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'value\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "merge_call"
|
||||
argspec: "args=[\'self\', \'merge_fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user