STT-tensorflow/tensorflow/python/kernel_tests/collective_ops_test.py
Ayush Dubey e38610399a Add CollectiveBcastSend/RecvV2 which takes input tensors rather than attributes.
CollectiveBcastSend/Recv accept the following inputs as attributes on the op:
group_size, group_key, and instance_key.  Because they are attributes, these
values are embedded in the NodeDef.

The use case motivating this change is a compact representation for SPMD
computation.  The goal is to turn those collective-op inputs that can be
supplied at runtime into tensors rather than attributes.  This lets the graph
builder avoid an early explosion of the SPMD program.

This op is not exposed in the `tf.` namespace for now, and should be considered
experimental.

PiperOrigin-RevId: 348049802
Change-Id: I15ffcb562fb75cc64e5970ac7795cb81ff187fd8
2020-12-17 10:45:19 -08:00

931 lines
31 KiB
Python

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for V2 Collective Operations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
import time
from absl.testing import parameterized
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.experimental.ops import testing as dataset_testing
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops as _collective_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
class CollectiveOpsV1(object):
  """Namespace exposing the V1 collective ops under short, uniform names.

  Every attribute is a direct alias of the same-named function in
  `_collective_ops`; nothing is wrapped or adapted here.
  """
  all_reduce, all_gather = (_collective_ops.all_reduce,
                            _collective_ops.all_gather)
  broadcast_send, broadcast_recv = (_collective_ops.broadcast_send,
                                    _collective_ops.broadcast_recv)
class CollectiveOpsV2(object):
  """V2 collective ops wrapped to share the call signatures of the V1 ops.

  Each wrapper passes `group_size`, `group_key` and `instance_key` through
  `array_ops.identity` before handing them to the V2 op, so the V2 op is fed
  the results of identity ops rather than the raw values given by the caller.
  """

  @staticmethod
  def all_reduce(t, group_size, group_key, instance_key, *args, **kwargs):
    """Calls `all_reduce_v2` on `t`; extra arguments are forwarded as-is."""
    size, gkey, ikey = [
        array_ops.identity(param)
        for param in (group_size, group_key, instance_key)
    ]
    return _collective_ops.all_reduce_v2(t, size, gkey, ikey, *args, **kwargs)

  @staticmethod
  def all_gather(t, group_size, group_key, instance_key, *args, **kwargs):
    """Calls `all_gather_v2` on `t`; extra arguments are forwarded as-is."""
    size, gkey, ikey = [
        array_ops.identity(param)
        for param in (group_size, group_key, instance_key)
    ]
    return _collective_ops.all_gather_v2(t, size, gkey, ikey, *args, **kwargs)

  @staticmethod
  def broadcast_send(t, shape, dtype, group_size, group_key, instance_key,
                     *args, **kwargs):
    """Calls `broadcast_send_v2` on `t`.

    `shape` and `dtype` are accepted but not forwarded to the V2 op.
    """
    size, gkey, ikey = [
        array_ops.identity(param)
        for param in (group_size, group_key, instance_key)
    ]
    return _collective_ops.broadcast_send_v2(t, size, gkey, ikey, *args,
                                             **kwargs)

  @staticmethod
  def broadcast_recv(shape, dtype, group_size, group_key, instance_key, *args,
                     **kwargs):
    """Calls `broadcast_recv_v2`; `shape` is also routed through identity."""
    # Identity ops are created in the order: size, group key, instance key,
    # shape.
    size, gkey, ikey, shape_t = [
        array_ops.identity(param)
        for param in (group_size, group_key, instance_key, shape)
    ]
    return _collective_ops.broadcast_recv_v2(shape_t, dtype, size, gkey, ikey,
                                             *args, **kwargs)
# Device/communication matrix: CPU with RING (no GPU required), plus GPU with
# RING and with NCCL (two GPUs required).
device_combination = (
    combinations.combine(device='CPU', communication='RING', required_gpus=0) +
    combinations.combine(
        device='GPU', communication=['RING', 'NCCL'], required_gpus=2))

# Every all-reduce / all-gather flavour (V1 and V2) crossed with the device
# matrix above, in eager mode.
collective_op_combinations = combinations.times(
    combinations.combine(
        collective_op=[
            combinations.NamedObject(op_name, op_fn)
            for op_name, op_fn in (
                ('all_reduce', CollectiveOpsV1.all_reduce),
                ('all_reduce_v2', CollectiveOpsV2.all_reduce),
                ('all_gather', CollectiveOpsV1.all_gather),
                ('all_gather_v2', CollectiveOpsV2.all_gather),
            )
        ],
        mode='eager'), device_combination)
@combinations.generate(
    combinations.times(
        combinations.combine(
            collective_ops=[
                combinations.NamedObject('v1', CollectiveOpsV1),
                combinations.NamedObject('v2', CollectiveOpsV2)
            ],
            mode='eager'), device_combination))
class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
  """Checks the results of reduce/gather/broadcast for both wrapper classes.

  Each test receives `collective_ops` (one of the wrapper classes), the
  `device` type and the `communication` hint from the combinations above.
  """

  def setUp(self):
    # Rebuild the eager context before the base-class setup runs.
    _setup_context()
    super().setUp()

  def testReduce(self, collective_ops, device, communication):
    """All-reduces [1.] within a one-device group and a two-device group."""
    single_dev = '/device:%s:0' % device
    device_pair = (single_dev, '/device:%s:1' % device)

    @def_function.function
    def reduce_on_one_device():
      with ops.device(single_dev):
        value = constant_op.constant([1.])
        # group_size, group_key, instance_key.
        group_params = (1, 1, 1)
        return collective_ops.all_reduce(
            value, *group_params, communication_hint=communication)

    @def_function.function
    def reduce_on_two_devices():
      value = constant_op.constant([1.])
      # group_size, group_key, instance_key.
      group_params = (2, 2, 2)
      outputs = []
      for target in device_pair:
        with ops.device(target):
          outputs.append(
              collective_ops.all_reduce(
                  value, *group_params, communication_hint=communication))
      return outputs

    self.assertAllClose(reduce_on_one_device(), [1.], rtol=1e-5, atol=1e-5)
    for reduced in reduce_on_two_devices():
      self.assertAllClose(reduced, [2.], rtol=1e-5, atol=1e-5)

  def testGather(self, collective_ops, device, communication):
    """All-gathers [1.] within a one-device group and a two-device group."""
    single_dev = '/device:%s:0' % device
    device_pair = (single_dev, '/device:%s:1' % device)

    @def_function.function
    def gather_on_one_device():
      with ops.device(single_dev):
        value = constant_op.constant([1.])
        # group_size, group_key, instance_key.
        group_params = (1, 1, 1)
        return collective_ops.all_gather(
            value, *group_params, communication_hint=communication)

    @def_function.function
    def gather_on_two_devices():
      value = constant_op.constant([1.])
      # group_size, group_key, instance_key.
      group_params = (2, 2, 2)
      outputs = []
      for target in device_pair:
        with ops.device(target):
          outputs.append(
              collective_ops.all_gather(
                  value, *group_params, communication_hint=communication))
      return outputs

    self.assertAllClose(gather_on_one_device(), [1.], rtol=1e-5, atol=1e-5)
    for gathered in gather_on_two_devices():
      self.assertAllClose(gathered, [1., 1.], rtol=1e-5, atol=1e-5)

  def testBroadcast(self, collective_ops, device, communication):
    """Broadcasts [1., 2., 3.] from device 0 to device 1."""
    sender_dev = '/device:%s:0' % device
    receiver_dev = '/device:%s:1' % device

    @def_function.function
    def broadcast_between_two_devices():
      shape = [3]
      payload = constant_op.constant([1., 2., 3.], shape=shape)
      # group_size, group_key, instance_key.
      group_params = (2, 2, 2)
      with ops.device(sender_dev):
        sent = collective_ops.broadcast_send(
            payload,
            shape,
            payload.dtype,
            *group_params,
            communication_hint=communication)
      with ops.device(receiver_dev):
        received = collective_ops.broadcast_recv(
            shape,
            payload.dtype,
            *group_params,
            communication_hint=communication)
      return [sent, received]

    # Both the sender's and the receiver's outputs must equal the payload.
    for output in broadcast_between_two_devices():
      self.assertAllClose(output, [1., 2., 3.], rtol=1e-5, atol=1e-5)

  def testInstanceKeyScopedUnderGroupKey(self, collective_ops, device,
                                         communication):
    """Runs two groups that share an instance key and checks each group's sum."""
    if device == 'GPU' and context.num_gpus() < 4:
      self.skipTest('not enough GPU')
    devs = ['/device:%s:%d' % (device, idx) for idx in range(4)]

    @def_function.function
    def reduce_in_two_groups_sharing_instance_key():
      # Both groups use the same instance key.
      shared_instance_key = 0
      # Two groups with two devices each.
      members_per_group = 2
      # (device, input value, group key): group 0 is devices 0 and 1,
      # group 1 is devices 2 and 3.
      placements = (
          (devs[0], 1., 0),
          (devs[1], 2., 0),
          (devs[2], 3., 1),
          (devs[3], 4., 1),
      )
      outputs = []
      for target, value, key in placements:
        with ops.device(target):
          outputs.append(
              collective_ops.all_reduce(
                  constant_op.constant(value), members_per_group, key,
                  shared_instance_key))
      return outputs

    outputs = reduce_in_two_groups_sharing_instance_key()
    # Group 0 sums 1 + 2, group 1 sums 3 + 4.
    for idx, expected_sum in enumerate((3., 3., 7., 7.)):
      self.assertAllClose(outputs[idx], expected_sum, rtol=1e-5, atol=1e-5)

  def testCollectiveGroupSizeOne(self, collective_ops, device, communication):
    """With a group of one, reduce and gather both return the input values."""
    if communication == 'NCCL':
      self.skipTest('b/170672646: it crashes with NCCL and group size one')
    target = '/device:%s:0' % device
    expected = [1., 2., 3., 4.]
    tensor = constant_op.constant(expected)
    # group_size, group_key, instance_key.
    group_params = (1, 100, 100)

    with ops.device(target):
      reduced = collective_ops.all_reduce(
          tensor, *group_params, communication_hint=communication)
    self.assertAllEqual(expected, reduced.numpy())

    with ops.device(target):
      gathered = collective_ops.all_gather(
          tensor, *group_params, communication_hint=communication)
    self.assertAllEqual(expected, gathered.numpy())

  def testMultipleGroups(self, collective_ops, device, communication):
    """Runs all-reduce for a group of two and then for a group of three."""
    if device == 'GPU' and context.num_gpus() < 4:
      self.skipTest('not enough GPU')
    num_elements = 4

    @def_function.function
    def run_all_reduce(group_size, group_key):
      # Every device contributes `num_elements` copies of the group key; the
      # group key doubles as the instance key.
      values = [float(group_key)] * num_elements
      outputs = []
      for idx in range(group_size):
        with ops.device('/{}:{}'.format(device, idx)):
          outputs.append(
              collective_ops.all_reduce(
                  constant_op.constant(values),
                  group_size,
                  group_key,
                  group_key,
                  communication_hint=communication))
      return outputs

    def run_and_assert(group_size, group_key):
      expected = [float(group_key) * group_size] * num_elements
      for reduced in run_all_reduce(group_size, group_key):
        self.assertAllEqual(expected, reduced.numpy())

    run_and_assert(group_size=2, group_key=1)
    run_and_assert(group_size=3, group_key=2)
@combinations.generate(collective_op_combinations)
class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):
  """Tests aborting hung collectives via `context.abort_collective_ops`.

  Every test launches a collective on a single device for a group of size two
  (so it cannot complete), aborts it from a helper thread after two seconds,
  checks that the launch and a follow-up launch both fail with the abort
  error, then resets the context and runs a complete two-device collective.
  """

  def setUp(self):
    # Rebuild the eager context before the base-class setup runs.
    _setup_context()
    super().setUp()

  def testAbortGroupParamsResolution(self, collective_op, device,
                                     communication):
    """Aborts a collective launched before any group has been resolved."""
    dev0 = '/device:%s:0' % device
    dev1 = '/device:%s:1' % device
    group_size = 2
    group_key = 100
    instance_key = 100
    in_tensor = constant_op.constant([1.])

    def abort_fn():
      # Give the main thread time to launch the collective before aborting.
      time.sleep(2)
      context.context().abort_collective_ops(errors.UNAVAILABLE, 'peer down')

    t = threading.Thread(target=abort_fn)
    t.start()

    with self.assertRaisesRegex(errors.UnavailableError, 'peer down'):
      # This hangs on params resolution since we're only launching one
      # collective for a group size of 2.
      with ops.device(dev0):
        collective_op(
            in_tensor,
            group_size,
            group_key,
            instance_key,
            communication_hint=communication)

    # After abortion, subsequent collectives should fail immediately.
    with self.assertRaisesRegex(errors.UnavailableError, 'peer down'):
      with ops.device(dev0):
        collective_op(
            in_tensor,
            group_size,
            group_key,
            instance_key,
            communication_hint=communication)

    t.join()
    # Reset the context in order to reset the collective executor.
    _setup_context()

    # After the reset, a collective launched on both devices should run.
    def collective_fn():
      for target in [dev0, dev1]:
        with ops.device(target):
          collective_op(
              in_tensor,
              group_size,
              group_key,
              instance_key,
              communication_hint=communication)

    def_function.function(collective_fn)()

  def testAbortInstanceParamsResolution(self, collective_op, device,
                                        communication):
    """Aborts a collective that uses a new instance key in a resolved group."""
    dev0 = '/device:%s:0' % device
    dev1 = '/device:%s:1' % device
    group_size = 2
    group_key = 100
    instance_key = 100
    in_tensor = constant_op.constant([1.])

    def collective_fn():
      # `instance_key` is read when the function is traced, so the second
      # invocation below uses the value assigned later in this test.
      for target in [dev0, dev1]:
        with ops.device(target):
          collective_op(
              in_tensor,
              group_size,
              group_key,
              instance_key,
              communication_hint=communication)

    # First perform a normal all-reduce to complete the group resolution.
    def_function.function(collective_fn)()

    def abort_fn():
      # Give the main thread time to launch the collective before aborting.
      time.sleep(2)
      context.context().abort_collective_ops(errors.UNAVAILABLE, 'peer down')

    t = threading.Thread(target=abort_fn)
    t.start()

    # Use a different instance key to trigger another instance resolution.
    instance_key = 101
    with self.assertRaisesRegex(errors.UnavailableError, 'peer down'):
      # This hangs on params resolution since we're only launching one
      # collective for a group size of 2.
      with ops.device(dev0):
        collective_op(
            in_tensor,
            group_size,
            group_key,
            instance_key,
            communication_hint=communication)

    # After abortion, subsequent collectives should fail immediately.
    with self.assertRaisesRegex(errors.UnavailableError, 'peer down'):
      with ops.device(dev0):
        collective_op(
            in_tensor,
            group_size,
            group_key,
            instance_key,
            communication_hint=communication)

    # Wait for the abort thread before touching the context, as the sibling
    # tests do; `_setup_context` performs the (single) context reset.
    t.join()
    # Reset the context in order to reset the collective executor.
    _setup_context()

    # After the reset, a collective launched on both devices should run.
    def_function.function(collective_fn)()

  def testAbortCommunication(self, collective_op, device, communication):
    """Aborts a collective launched after group and instance are resolved."""
    dev0 = '/device:%s:0' % device
    dev1 = '/device:%s:1' % device
    group_size = 2
    group_key = 100
    instance_key = 100
    in_tensor = constant_op.constant([1.])

    # First perform a normal collective to finish resolution.
    def collective_fn():
      for target in [dev0, dev1]:
        with ops.device(target):
          collective_op(
              in_tensor,
              group_size,
              group_key,
              instance_key,
              communication_hint=communication)

    def_function.function(collective_fn)()

    # Launch a collective that hangs, and abort the collective executor after
    # the launch.
    def abort_fn():
      time.sleep(2)
      context.context().abort_collective_ops(errors.UNAVAILABLE, 'peer down')

    t = threading.Thread(target=abort_fn)
    t.start()

    with self.assertRaisesRegex(errors.UnavailableError, 'peer down'):
      with ops.device(dev0):
        collective_op(
            in_tensor,
            group_size,
            group_key,
            instance_key,
            communication_hint=communication)

    # After abortion, subsequent collectives should fail immediately.
    with self.assertRaisesRegex(errors.UnavailableError, 'peer down'):
      with ops.device(dev0):
        collective_op(
            in_tensor,
            group_size,
            group_key,
            instance_key,
            communication_hint=communication)

    # Reset the context in order to reset the collective executor.
    t.join()
    _setup_context()
    def_function.function(collective_fn)()
class OpCancellationTest(test.TestCase, parameterized.TestCase):
  """Tests what happens to collective ops when an unrelated op errors."""

  def setUp(self):
    # Rebuild the eager context before the base-class setup runs.
    _setup_context()
    super().setUp()

  @combinations.generate(
      combinations.times(
          combinations.combine(
              collective_op=[
                  combinations.NamedObject('all_reduce',
                                           CollectiveOpsV1.all_reduce),
                  combinations.NamedObject('all_reduce_v2',
                                           CollectiveOpsV2.all_reduce),
                  combinations.NamedObject('all_gather',
                                           CollectiveOpsV1.all_gather),
                  combinations.NamedObject('all_gather_v2',
                                           CollectiveOpsV2.all_gather),
              ],
              mode='eager'), device_combination))
  def testOpErrorNotAbortIfNoCollective(self, collective_op, device,
                                        communication):
    """An op error with no collective in flight must not abort collectives.

    Errors such as EOF are expected to be caught by users; aborting collective
    ops on every op error would interfere with that workflow.
    """
    device_pair = ['/device:%s:0' % device, '/device:%s:1' % device]
    # group_size, group_key, instance_key.
    group_params = (2, 100, 100)
    single_element_ds = dataset_ops.Dataset.from_tensors([1.])

    @def_function.function
    def collective_fn(in_tensor):
      for target in device_pair:
        with ops.device(target):
          collective_op(
              in_tensor, *group_params, communication_hint=communication)

    @def_function.function
    def f():
      elements = iter(single_element_ds)
      collective_fn(next(elements))
      # The dataset holds one element, so this second next() should raise EOF.
      collective_fn(next(elements))

    with self.assertRaises(errors.OutOfRangeError):
      f()
    # A fresh collective is still expected to run after the error.
    collective_fn(constant_op.constant([1.]))

  @combinations.generate(
      combinations.times(
          combinations.combine(
              collective_op=[
                  combinations.NamedObject('all_reduce',
                                           CollectiveOpsV1.all_reduce),
                  combinations.NamedObject('all_gather',
                                           CollectiveOpsV1.all_gather),
              ],
              mode='eager'), device_combination))
  def testOpErrorAbortWithCollective(self, collective_op, device,
                                     communication):
    """V1 collectives are aborted when an op errors while one is in flight.

    V1 collective ops cannot be cancelled, and an op error may leave a running
    collective hanging, hence the abort.
    """
    first_dev = '/device:%s:0' % device
    # group_size, group_key, instance_key.
    group_params = (2, 100, 100)
    in_tensor = constant_op.constant([1.])
    # Make the dataset sleep a while so that the collective is being executed
    # when the EOF happens.
    slow_ds = dataset_ops.Dataset.from_tensors([1.]).apply(
        dataset_testing.sleep(sleep_microseconds=200))

    def launch_on_first_device():
      with ops.device(first_dev):
        return collective_op(
            in_tensor, *group_params, communication_hint=communication)

    @def_function.function
    def f():
      # Only one of the two group members is launched, so this collective
      # cannot finish; it is there to exercise abortion when another op errors.
      ret = launch_on_first_device()
      elements = iter(slow_ds)
      next(elements)
      # This should raise EOF.
      next(elements)
      return ret

    with self.assertRaises(errors.OutOfRangeError):
      f()
    # Collective ops are now aborted, so a subsequent collective should fail.
    with self.assertRaises(errors.CancelledError):
      launch_on_first_device()

  @combinations.generate(
      combinations.times(
          combinations.combine(
              collective_op=[
                  combinations.NamedObject('all_reduce_v2',
                                           CollectiveOpsV2.all_reduce),
                  combinations.NamedObject('all_gather_v2',
                                           CollectiveOpsV2.all_gather),
              ],
              mode='eager'), device_combination))
  def testOpErrorNotAbortWithCollective(self, collective_op, device,
                                        communication):
    """V2 collectives are not aborted when an op errors while one is in flight.

    The test relies on cancellation to terminate the active collective.
    """
    first_dev = '/device:%s:0' % device
    device_pair = [first_dev, '/device:%s:1' % device]
    # group_size, group_key, instance_key.
    group_params = (2, 100, 100)
    in_tensor = constant_op.constant([1.])

    @def_function.function
    def collective_fn():
      for target in device_pair:
        with ops.device(target):
          collective_op(
              in_tensor, *group_params, communication_hint=communication)

    # Local params resolution cannot be cancelled yet, so we perform a normal
    # collective so that the group is resolved.
    collective_fn()

    # Make the dataset sleep a while so that the collective is being executed
    # when the EOF happens.
    slow_ds = dataset_ops.Dataset.from_tensors([1.]).apply(
        dataset_testing.sleep(sleep_microseconds=200))

    @def_function.function
    def f():
      # Only one of the two group members is launched, so this collective
      # cannot finish; it is there to exercise cancellation when another op
      # errors.
      with ops.device(first_dev):
        ret = collective_op(
            in_tensor, *group_params, communication_hint=communication)
      elements = iter(slow_ds)
      next(elements)
      # This should raise EOF.
      next(elements)
      return ret

    with self.assertRaises(errors.OutOfRangeError):
      f()
    # Collective ops shouldn't be aborted and new collectives should be able to
    # proceed.
    collective_fn()
@combinations.generate(collective_op_combinations)
class TimeoutTest(test.TestCase, parameterized.TestCase):
  """Tests the `timeout` argument of the collective ops."""

  def setUp(self):
    # Rebuild the eager context before the base-class setup runs.
    _setup_context()
    super().setUp()

  def testTimeout(self, collective_op, device, communication):
    """A collective missing a group member fails once `timeout` has elapsed."""
    timeout = 1.5

    @def_function.function
    def run(group_size, reported_group_size=None):
      """Launches `group_size` collectives that claim `reported_group_size`."""
      group_key = 20
      instance_key = 30
      tensor = [1., 2., 3., 4.]
      results = []
      if reported_group_size is None:
        reported_group_size = group_size
      for i in range(group_size):
        with ops.device('/{}:{}'.format(device, i)):
          input_data = constant_op.constant(tensor)
          result = collective_op(
              input_data,
              group_size=reported_group_size,
              group_key=group_key,
              instance_key=instance_key,
              communication_hint=communication,
              timeout=timeout)
          results.append(result)
      return results

    # A complete two-device run first.
    run(2, 2)

    # Use a monotonic clock so that wall-clock adjustments cannot shorten the
    # measured interval.
    start_time = time.monotonic()
    with self.assertRaisesRegex(errors.DeadlineExceededError,
                                'Collective has timed out during execution'):
      # Only one of the two advertised group members is launched.
      run(1, 2)
    elapsed = time.monotonic() - start_time
    self.assertGreaterEqual(elapsed, timeout)

  def testParamResolutionAfterTimeout(self, collective_op, device,
                                      communication):
    """A late peer errors out after the first peer timed out in resolution."""
    dev0 = '/device:%s:0' % device
    dev1 = '/device:%s:1' % device
    timeout = 1.5
    group_key = 20
    instance_key = 30
    input_data = constant_op.constant([1., 2., 3., 4.])

    # This timeout comes from param resolution.
    with self.assertRaisesRegex(
        errors.DeadlineExceededError,
        'Collective has timed out waiting for other workers'):
      with ops.device(dev0):
        collective_op(
            input_data,
            group_size=2,
            group_key=group_key,
            instance_key=instance_key,
            communication_hint=communication,
            timeout=timeout)

    # We launch the second device after the first device times out. This is to
    # simulate the situation when other workers are slow and the timeout is
    # short. It should error immediately.
    with self.assertRaisesRegex(
        errors.DeadlineExceededError,
        'Collective has timed out waiting for other workers'):
      with ops.device(dev1):
        # No timeout is passed here.
        collective_op(
            input_data,
            group_size=2,
            group_key=group_key,
            instance_key=instance_key,
            communication_hint=communication)

  def testExecutionAfterTimeout(self, collective_op, device, communication):
    """A late peer errors out after the first peer timed out in execution."""
    dev0 = '/device:%s:0' % device
    dev1 = '/device:%s:1' % device
    timeout = 1.5
    group_key = 20
    instance_key = 30
    input_data = constant_op.constant([1., 2., 3., 4.])

    @def_function.function
    def run():
      for target in [dev0, dev1]:
        with ops.device(target):
          collective_op(
              input_data,
              group_size=2,
              group_key=group_key,
              instance_key=instance_key,
              communication_hint=communication,
              timeout=timeout)

    # Run a normal all-reduce to complete param resolution.
    run()

    with self.assertRaisesRegex(errors.DeadlineExceededError,
                                'Collective has timed out during execution'):
      with ops.device(dev0):
        collective_op(
            input_data,
            group_size=2,
            group_key=group_key,
            instance_key=instance_key,
            communication_hint=communication,
            timeout=timeout)

    # We launch the second device after the first device times out. This is to
    # simulate the situation when other workers are slow and the timeout is
    # short. It should error immediately.
    with self.assertRaisesRegex(errors.DeadlineExceededError,
                                'Collective has timed out during execution'):
      with ops.device(dev1):
        # No timeout.
        collective_op(
            input_data,
            group_size=2,
            group_key=group_key,
            instance_key=instance_key,
            communication_hint=communication)
@combinations.generate(
    combinations.times(
        combinations.combine(
            collective_op=[
                combinations.NamedObject('all_reduce_v2',
                                         CollectiveOpsV2.all_reduce),
                combinations.NamedObject('all_gather_v2',
                                         CollectiveOpsV2.all_gather),
            ],
            mode='eager'), device_combination))
class OrderingTest(test.TestCase, parameterized.TestCase):
  """Tests control dependencies created by `ordering_token` on V2 ops."""

  def setUp(self):
    # Rebuild the eager context before the base-class setup runs.
    _setup_context()
    super().setUp()

  def testOrdering(self, collective_op, device, communication):
    """Collectives sharing a token are chained; token-less ones are not."""
    device_pair = ('/device:%s:0' % device, '/device:%s:1' % device)
    group_size = 2
    group_key = 100
    instance_key = 100
    in_tensor = constant_op.constant([1.])

    # One resource variable per device; its handle is the ordering token.
    tokens = {}
    for target in device_pair:
      with ops.device(target):
        tokens[target] = resource_variable_ops.ResourceVariable(0.)

    @def_function.function
    def f():
      # Three rounds on both devices: the first and third collectives carry
      # the device's token, the second one is launched without a token.
      for with_token in (True, False, True):
        for target in device_pair:
          with ops.device(target):
            if with_token:
              collective_op(
                  in_tensor,
                  group_size,
                  group_key,
                  instance_key,
                  ordering_token=tokens[target].handle)
            else:
              collective_op(in_tensor, group_size, group_key, instance_key)

    graph = f.get_concrete_function().graph
    for target in device_pair:
      # Look for the third collective: the only collective on this device
      # whose first control input is itself a collective.
      third = None
      for op in graph.get_operations():
        if not op.type.startswith('Collective'):
          continue
        if not op.device.endswith(target):
          continue
        control_inputs = op.control_inputs
        if not control_inputs:
          continue
        if not control_inputs[0].type.startswith('Collective'):
          continue
        # At most one op per device may match.
        self.assertIsNone(third)
        third = op
      self.assertIsNotNone(third)
      # A resource-typed input shows it is not the token-less second
      # collective.
      self.assertTrue(any(v.dtype == dtypes.resource for v in third.inputs))
      first = third.control_inputs[0]
      self.assertEqual(third.device, first.device)
      # Likewise, the op it depends on must take a resource input, and it must
      # have no control inputs of its own.
      self.assertTrue(any(v.dtype == dtypes.resource for v in first.inputs))
      self.assertEmpty(first.control_inputs)
def _setup_context():
  """Resets the eager context and re-initializes it with >= 4 logical CPUs."""
  # The order matters: reset first, configure the logical devices, and only
  # then initialize the new context.
  context._reset_context()  # pylint: disable=protected-access
  test_util.set_logical_devices_to_at_least('CPU', 4)
  context.ensure_initialized()
if __name__ == '__main__':
  # Switch on V2 behavior before handing control to the test runner.
  v2_compat.enable_v2_behavior()
  test.main()