Remove run_deprecated_v1 annotations from collective ops tests.

PiperOrigin-RevId: 317365063
Change-Id: Ibf13ad8629947becd40038d41ee213d3466b6292
This commit is contained in:
Ayush Dubey 2020-06-19 13:06:52 -07:00 committed by TensorFlower Gardener
parent 3ae2cf9610
commit c575e2ba93

View File

@ -104,39 +104,42 @@ class CollectiveOpTest(test.TestCase):
for i in range(group_size * num_instances): for i in range(group_size * num_instances):
self.assertAllClose(results[i], expected, rtol=1e-5, atol=1e-5) self.assertAllClose(results[i], expected, rtol=1e-5, atol=1e-5)
@test_util.run_deprecated_v1
def testCollectiveReduce(self): def testCollectiveReduce(self):
self._testCollectiveReduce( # Tests that execute collectives need to be enclosed in graph or tf.function
inputs=[[0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1], with ops.Graph().as_default():
[0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3]], self._testCollectiveReduce(
expected=[0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2], inputs=[[0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
set_graph_key=True) [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3]],
expected=[0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2],
set_graph_key=True)
@test_util.run_deprecated_v1
def testCollectiveAutoGraphKey(self): def testCollectiveAutoGraphKey(self):
self._testCollectiveReduce( # Tests that execute collectives need to be enclosed in graph or tf.function
inputs=[[0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1], with ops.Graph().as_default():
[0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3]], self._testCollectiveReduce(
expected=[0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2], inputs=[[0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
set_graph_key=False) [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3]],
expected=[0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2],
set_graph_key=False)
@test_util.run_deprecated_v1
def testFp16Reduce(self): def testFp16Reduce(self):
self._testCollectiveReduce( # Tests that execute collectives need to be enclosed in graph or tf.function
inputs=[[0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1], with ops.Graph().as_default():
[0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3]], self._testCollectiveReduce(
expected=[0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2], inputs=[[0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
set_graph_key=True, [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3]],
fp16=True) expected=[0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2],
set_graph_key=True,
fp16=True)
@test_util.run_deprecated_v1
def testCollectiveMultipleConcurrentReduce(self): def testCollectiveMultipleConcurrentReduce(self):
self._testMultipleConcurrentCollectiveReduce( # Tests that execute collectives need to be enclosed in graph or tf.function
[0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1], with ops.Graph().as_default():
[0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3], self._testMultipleConcurrentCollectiveReduce(
[0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2]) [0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
[0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3],
[0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2])
@test_util.run_deprecated_v1
def testCollectiveTimeoutV1(self): def testCollectiveTimeoutV1(self):
timeout = 4.5 timeout = 4.5
kwargs = dict( kwargs = dict(
@ -145,14 +148,17 @@ class CollectiveOpTest(test.TestCase):
set_graph_key=True, set_graph_key=True,
timeout=timeout) timeout=timeout)
self._testCollectiveReduce(**kwargs) # Tests that execute collectives need to be enclosed in graph or tf.function
with ops.Graph().as_default():
self._testCollectiveReduce(**kwargs)
start_time = time.time() start_time = time.time()
with self.assertRaisesRegex( with ops.Graph().as_default():
errors.DeadlineExceededError, with self.assertRaisesRegex(
'Collective has timed out waiting for other workers'): errors.DeadlineExceededError,
self._testCollectiveReduce( 'Collective has timed out waiting for other workers'):
reported_group_size=len(kwargs['inputs']) + 1, **kwargs) self._testCollectiveReduce(
reported_group_size=len(kwargs['inputs']) + 1, **kwargs)
elapsed = time.time() - start_time elapsed = time.time() - start_time
self.assertAllGreaterEqual(elapsed, timeout) self.assertAllGreaterEqual(elapsed, timeout)
@ -199,17 +205,18 @@ class CollectiveOpTest(test.TestCase):
elapsed = time.time() - start_time elapsed = time.time() - start_time
self.assertAllGreaterEqual(elapsed, timeout) self.assertAllGreaterEqual(elapsed, timeout)
@test_util.run_deprecated_v1
def testNcclHintFallbackToRingReduce(self): def testNcclHintFallbackToRingReduce(self):
"""Tests that setting `communication_hint=nccl` works on non-GPU builds.""" """Tests that setting `communication_hint=nccl` works on non-GPU builds."""
if kernels.get_registered_kernels_for_op('NcclAllReduce'): if kernels.get_registered_kernels_for_op('NcclAllReduce'):
self.skipTest('Run only on non-GPU environments') self.skipTest('Run only on non-GPU environments')
self._testCollectiveReduce( # Tests that execute collectives need to be enclosed in graph or tf.function
inputs=[[0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1], with ops.Graph().as_default():
[0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3]], self._testCollectiveReduce(
expected=[0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2], inputs=[[0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1],
set_graph_key=False, [0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3]],
communication_hint='nccl') expected=[0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2],
set_graph_key=False,
communication_hint='nccl')
def _testWhile(self, num_vars, num_iterations, key_base): def _testWhile(self, num_vars, num_iterations, key_base):
group_size = 2 group_size = 2
@ -262,15 +269,16 @@ class CollectiveOpTest(test.TestCase):
[((1 << (num_iterations + v)) * 1.) for v in range(num_vars)] [((1 << (num_iterations + v)) * 1.) for v in range(num_vars)]
for _ in range(group_size)]) for _ in range(group_size)])
@test_util.run_deprecated_v1
def testSimpleWhile(self): def testSimpleWhile(self):
self._testWhile(num_vars=1, num_iterations=4, key_base=20) # Tests that execute collectives need to be enclosed in graph or tf.function
with ops.Graph().as_default():
self._testWhile(num_vars=1, num_iterations=4, key_base=20)
@test_util.run_deprecated_v1
def testWhileMultipleAllReduce(self): def testWhileMultipleAllReduce(self):
self._testWhile(num_vars=2, num_iterations=4, key_base=20) # Tests that execute collectives need to be enclosed in graph or tf.function
with ops.Graph().as_default():
self._testWhile(num_vars=2, num_iterations=4, key_base=20)
@test_util.run_deprecated_v1
def testWhileWithScopedAllocator(self): def testWhileWithScopedAllocator(self):
group_size = 2 group_size = 2
group_key = 1 group_key = 1
@ -284,47 +292,52 @@ class CollectiveOpTest(test.TestCase):
del rewrite_options.scoped_allocator_opts.enable_op[:] del rewrite_options.scoped_allocator_opts.enable_op[:]
rewrite_options.scoped_allocator_opts.enable_op.append('CollectiveReduce') rewrite_options.scoped_allocator_opts.enable_op.append('CollectiveReduce')
with self.session(config=config) as sess: # Tests that execute collectives need to be enclosed in graph or tf.function
run_ops = [] with ops.Graph().as_default():
for i in range(group_size): with self.session(config=config) as sess:
with ops.device('CPU:%d' % i): run_ops = []
constant = constant_op.constant(0.) for i in range(group_size):
cond = lambda i: math_ops.less(i, 10.) with ops.device('CPU:%d' % i):
body = lambda i: math_ops.add(i, 1.) constant = constant_op.constant(0.)
input0 = control_flow_ops.while_loop(cond, body, [constant]) cond = lambda i: math_ops.less(i, 10.)
input1 = math_ops.add(constant, 5) body = lambda i: math_ops.add(i, 1.)
colred0 = collective_ops.all_reduce(input0, group_size, group_key, input0 = control_flow_ops.while_loop(cond, body, [constant])
instance_key0, 'Add', 'Id') input1 = math_ops.add(constant, 5)
colred1 = collective_ops.all_reduce(input1, group_size, group_key, colred0 = collective_ops.all_reduce(input0, group_size, group_key,
instance_key1, 'Add', 'Id') instance_key0, 'Add', 'Id')
run_ops.append(math_ops.add_n([colred0, colred1])) colred1 = collective_ops.all_reduce(input1, group_size, group_key,
results = sess.run(run_ops) instance_key1, 'Add', 'Id')
run_ops.append(math_ops.add_n([colred0, colred1]))
results = sess.run(run_ops)
self.assertEqual(results, [30., 30.]) self.assertEqual(results, [30., 30.])
@test_util.run_deprecated_v1
def testCollectiveReduceScalar(self): def testCollectiveReduceScalar(self):
self._testCollectiveReduce(inputs=[0.1, 0.3], expected=0.2, # Tests that execute collectives need to be enclosed in graph or tf.function
set_graph_key=True) with ops.Graph().as_default():
self._testCollectiveReduce(inputs=[0.1, 0.3], expected=0.2,
set_graph_key=True)
@test_util.run_deprecated_v1
def testCollectiveReduceMaximum(self): def testCollectiveReduceMaximum(self):
self._testCollectiveReduce( # Tests that execute collectives need to be enclosed in graph or tf.function
inputs=[[1., 20., 3., 40., 5.], [10., 2., 30., 4., 50.]], with ops.Graph().as_default():
expected=[10., 20., 30., 40., 50.], self._testCollectiveReduce(
set_graph_key=True, inputs=[[1., 20., 3., 40., 5.], [10., 2., 30., 4., 50.]],
instance_key=30, expected=[10., 20., 30., 40., 50.],
merge_op='Max', set_graph_key=True,
final_op='Id') instance_key=30,
merge_op='Max',
final_op='Id')
@test_util.run_deprecated_v1
def testCollectiveReduceMinimum(self): def testCollectiveReduceMinimum(self):
self._testCollectiveReduce( # Tests that execute collectives need to be enclosed in graph or tf.function
inputs=[[1., 20., 3., 40., 5.], [10., 2., 30., 4., 50.]], with ops.Graph().as_default():
expected=[1., 2., 3., 4., 5.], self._testCollectiveReduce(
set_graph_key=True, inputs=[[1., 20., 3., 40., 5.], [10., 2., 30., 4., 50.]],
instance_key=40, expected=[1., 2., 3., 4., 5.],
merge_op='Min', set_graph_key=True,
final_op='Id') instance_key=40,
merge_op='Min',
final_op='Id')
def _testCollectiveBroadcast(self, in_val): def _testCollectiveBroadcast(self, in_val):
group_key = 1 group_key = 1
@ -345,13 +358,15 @@ class CollectiveOpTest(test.TestCase):
self.assertAllClose(results[0], in_val, rtol=1e-5, atol=1e-5) self.assertAllClose(results[0], in_val, rtol=1e-5, atol=1e-5)
self.assertAllClose(results[1], in_val, rtol=1e-5, atol=1e-5) self.assertAllClose(results[1], in_val, rtol=1e-5, atol=1e-5)
@test_util.run_deprecated_v1
def testCollectiveBroadcast(self): def testCollectiveBroadcast(self):
self._testCollectiveBroadcast([0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1]) # Tests that execute collectives need to be enclosed in graph or tf.function
with ops.Graph().as_default():
self._testCollectiveBroadcast([0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1])
@test_util.run_deprecated_v1
def testCollectiveBroadcastBool(self): def testCollectiveBroadcastBool(self):
self._testCollectiveBroadcast([True, False]) # Tests that execute collectives need to be enclosed in graph or tf.function
with ops.Graph().as_default():
self._testCollectiveBroadcast([True, False])
def _testCollectiveGather(self, t0, t1, expected, set_graph_key): def _testCollectiveGather(self, t0, t1, expected, set_graph_key):
group_key = 1 group_key = 1
@ -371,94 +386,101 @@ class CollectiveOpTest(test.TestCase):
self.assertAllClose(results[0], expected, rtol=1e-5, atol=1e-5) self.assertAllClose(results[0], expected, rtol=1e-5, atol=1e-5)
self.assertAllClose(results[1], expected, rtol=1e-5, atol=1e-5) self.assertAllClose(results[1], expected, rtol=1e-5, atol=1e-5)
@test_util.run_deprecated_v1
def testCollectiveGather(self): def testCollectiveGather(self):
self._testCollectiveGather([0, 1, 2, 3, 4, 5, 6, 7], # Tests that execute collectives need to be enclosed in graph or tf.function
[10, 11, 12, 13, 14, 15, 16, 17], with ops.Graph().as_default():
[0, 1, 2, 3, 4, 5, 6, 7, self._testCollectiveGather([0, 1, 2, 3, 4, 5, 6, 7],
10, 11, 12, 13, 14, 15, 16, 17], [10, 11, 12, 13, 14, 15, 16, 17],
True) [0, 1, 2, 3, 4, 5, 6, 7,
self._testCollectiveGather([[0, 1, 2, 3], [4, 5, 6, 7]], 10, 11, 12, 13, 14, 15, 16, 17],
[[10, 11, 12, 13], [14, 15, 16, 17]], True)
[[0, 1, 2, 3], [4, 5, 6, 7], self._testCollectiveGather([[0, 1, 2, 3], [4, 5, 6, 7]],
[10, 11, 12, 13], [14, 15, 16, 17]], [[10, 11, 12, 13], [14, 15, 16, 17]],
True) [[0, 1, 2, 3], [4, 5, 6, 7],
self._testCollectiveGather([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], [10, 11, 12, 13], [14, 15, 16, 17]],
[[[10, 11], [12, 13]], [[14, 15], [16, 17]]], True)
[[[0, 1], [2, 3]], [[4, 5], [6, 7]], self._testCollectiveGather([[[0, 1], [2, 3]], [[4, 5], [6, 7]]],
[[10, 11], [12, 13]], [[14, 15], [16, 17]]], [[[10, 11], [12, 13]], [[14, 15], [16, 17]]],
True) [[[0, 1], [2, 3]], [[4, 5], [6, 7]],
[[10, 11], [12, 13]], [[14, 15], [16, 17]]],
True)
@test_util.run_deprecated_v1
def testCollectiveGatherShapeMismatch(self): def testCollectiveGatherShapeMismatch(self):
group_key = 1 group_key = 1
instance_key = 1 instance_key = 1
t0 = [1, 2, 3, 4] t0 = [1, 2, 3, 4]
t1 = [5, 6, 7, 8] t1 = [5, 6, 7, 8]
t2 = [9, 10] t2 = [9, 10]
with self.session( # Tests that execute collectives need to be enclosed in graph or tf.function
config=config_pb2.ConfigProto(device_count={'CPU': 2})) as sess: with ops.Graph().as_default():
with ops.device('/CPU:0'): with self.session(
in0 = constant_op.constant(t0) config=config_pb2.ConfigProto(device_count={'CPU': 2})) as sess:
c0 = collective_ops.all_gather(in0, 2, group_key, instance_key) with ops.device('/CPU:0'):
with ops.device('/CPU:1'): in0 = constant_op.constant(t0)
in1 = constant_op.constant(t1) c0 = collective_ops.all_gather(in0, 2, group_key, instance_key)
in2 = constant_op.constant(t2) with ops.device('/CPU:1'):
c1 = collective_ops.all_gather(in1, 2, group_key, instance_key) in1 = constant_op.constant(t1)
c2 = collective_ops.all_gather(in2, 2, group_key, instance_key) in2 = constant_op.constant(t2)
run_options = config_pb2.RunOptions() c1 = collective_ops.all_gather(in1, 2, group_key, instance_key)
run_options.experimental.collective_graph_key = 1 c2 = collective_ops.all_gather(in2, 2, group_key, instance_key)
sess.run([c0, c1], options=run_options) run_options = config_pb2.RunOptions()
with self.assertRaisesRegexp(errors.InvalidArgumentError, run_options.experimental.collective_graph_key = 1
'Shape mismatch'): sess.run([c0, c1], options=run_options)
sess.run([c0, c2], options=run_options) with self.assertRaisesRegexp(errors.InvalidArgumentError,
'Shape mismatch'):
sess.run([c0, c2], options=run_options)
@test_util.run_deprecated_v1
def testCollectiveGatherShapeMismatchAcrossDevices(self): def testCollectiveGatherShapeMismatchAcrossDevices(self):
group_key = 1 group_key = 1
instance_key = 1 instance_key = 1
t0 = [1, 2, 3, 4] t0 = [1, 2, 3, 4]
t1 = [5, 6] t1 = [5, 6]
with self.session( # Tests that execute collectives need to be enclosed in graph or tf.function
config=config_pb2.ConfigProto(device_count={'CPU': 2})) as sess: with ops.Graph().as_default():
with ops.device('/CPU:0'): with self.session(
in0 = constant_op.constant(t0) config=config_pb2.ConfigProto(device_count={'CPU': 2})) as sess:
c0 = collective_ops.all_gather(in0, 2, group_key, instance_key) with ops.device('/CPU:0'):
with ops.device('/CPU:1'): in0 = constant_op.constant(t0)
in1 = constant_op.constant(t1) c0 = collective_ops.all_gather(in0, 2, group_key, instance_key)
c1 = collective_ops.all_gather(in1, 2, group_key, instance_key) with ops.device('/CPU:1'):
run_options = config_pb2.RunOptions() in1 = constant_op.constant(t1)
run_options.experimental.collective_graph_key = 1 c1 = collective_ops.all_gather(in1, 2, group_key, instance_key)
with self.assertRaisesRegexp(errors.InvalidArgumentError, run_options = config_pb2.RunOptions()
'Shape mismatch'): run_options.experimental.collective_graph_key = 1
sess.run([c0, c1], options=run_options) with self.assertRaisesRegexp(errors.InvalidArgumentError,
'Shape mismatch'):
sess.run([c0, c1], options=run_options)
@test_util.run_deprecated_v1
def testCollectiveGatherPolymorphicShape(self): def testCollectiveGatherPolymorphicShape(self):
t0 = [0, 1, 2, 3, 4, 5, 6, 7] t0 = [0, 1, 2, 3, 4, 5, 6, 7]
t1 = [10, 11, 12, 13, 14, 15, 16, 17] t1 = [10, 11, 12, 13, 14, 15, 16, 17]
group_size = 2 group_size = 2
group_key = 1 group_key = 1
instance_key = 123 instance_key = 123
with self.session( # Tests that execute collectives need to be enclosed in graph or tf.function
config=config_pb2.ConfigProto( with ops.Graph().as_default():
device_count={'CPU': group_size})) as sess: with self.session(
with ops.device('/CPU:0'): config=config_pb2.ConfigProto(
in0 = array_ops.placeholder(dtype=dtypes.int32, shape=[None]) device_count={'CPU': group_size})) as sess:
c0 = collective_ops.all_gather(in0, group_size, group_key, instance_key) with ops.device('/CPU:0'):
with ops.device('/CPU:1'): in0 = array_ops.placeholder(dtype=dtypes.int32, shape=[None])
in1 = array_ops.placeholder(dtype=dtypes.int32, shape=[None]) c0 = collective_ops.all_gather(in0, group_size, group_key,
c1 = collective_ops.all_gather(in1, group_size, group_key, instance_key) instance_key)
with ops.device('/CPU:1'):
in1 = array_ops.placeholder(dtype=dtypes.int32, shape=[None])
c1 = collective_ops.all_gather(in1, group_size, group_key,
instance_key)
results = sess.run([c0, c1], feed_dict={in0: t0, in1: t1}) results = sess.run([c0, c1], feed_dict={in0: t0, in1: t1})
expected_output = [0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, 16, 17] results_ = sess.run([c0, c1], feed_dict={in0: t0[1:], in1: t1[1:]})
self.assertAllClose(results[0], expected_output, rtol=1e-5, atol=1e-5)
self.assertAllClose(results[1], expected_output, rtol=1e-5, atol=1e-5)
results_ = sess.run([c0, c1], feed_dict={in0: t0[1:], in1: t1[1:]}) expected_output = [0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, 16, 17]
expected_output_ = [1, 2, 3, 4, 5, 6, 7, 11, 12, 13, 14, 15, 16, 17] self.assertAllClose(results[0], expected_output, rtol=1e-5, atol=1e-5)
self.assertAllClose(results_[0], expected_output_, rtol=1e-5, atol=1e-5) self.assertAllClose(results[1], expected_output, rtol=1e-5, atol=1e-5)
self.assertAllClose(results_[1], expected_output_, rtol=1e-5, atol=1e-5)
expected_output_ = [1, 2, 3, 4, 5, 6, 7, 11, 12, 13, 14, 15, 16, 17]
self.assertAllClose(results_[0], expected_output_, rtol=1e-5, atol=1e-5)
self.assertAllClose(results_[1], expected_output_, rtol=1e-5, atol=1e-5)
@test_util.run_v2_only @test_util.run_v2_only
def testCollectiveGroupSizeMismatch(self): def testCollectiveGroupSizeMismatch(self):
@ -492,8 +514,17 @@ class CollectiveOpTest(test.TestCase):
'but that group has size'): 'but that group has size'):
run_all_reduce() run_all_reduce()
@test_util.run_deprecated_v1 @test_util.run_v2_only
def testCollectiveTensorsHaveNoDeviceSpecified(self): def testCollectiveTensorsHaveNoDeviceSpecified(self):
context._reset_context()
cpus = config.list_physical_devices('CPU')
self.assertEqual(len(cpus), 1)
config.set_logical_device_configuration(cpus[0], [
context.LogicalDeviceConfiguration(),
context.LogicalDeviceConfiguration()
])
context.ensure_initialized()
group_size = 2 group_size = 2
group_key = 1 group_key = 1
instance_key = 1 instance_key = 1
@ -517,20 +548,12 @@ class CollectiveOpTest(test.TestCase):
return results return results
with self.session(config=config_pb2.ConfigProto( with ops.device('/CPU:0'):
device_count={'CPU': 2})) as sess: in0 = constant_op.constant(1)
with ops.device('/CPU:0'): with ops.device('/CPU:1'):
in0 = constant_op.constant(1) in1 = constant_op.constant(3)
with ops.device('/CPU:1'): result = fn([in0, in1])
in1 = constant_op.constant(3) self.assertAllClose(result, [2, 2])
result_op = fn([in0, in1])
run_options = config_pb2.RunOptions()
run_options.experimental.collective_graph_key = 1
result = sess.run(result_op, options=run_options)
self.assertAllClose(result, [2, 2])
@test_util.run_v2_only @test_util.run_v2_only
def testCollectiveGroupSizeOne(self): def testCollectiveGroupSizeOne(self):
@ -548,7 +571,6 @@ class CollectiveOpTest(test.TestCase):
in_tensor, group_size, group_key, instance_key) in_tensor, group_size, group_key, instance_key)
self.assertAllEqual(in_value, gathered_tensor.numpy()) self.assertAllEqual(in_value, gathered_tensor.numpy())
@test_util.run_deprecated_v1
def testConstantWithScopedAllocator(self): def testConstantWithScopedAllocator(self):
group_size = 2 group_size = 2
group_key = 1 group_key = 1
@ -565,21 +587,25 @@ class CollectiveOpTest(test.TestCase):
del rewrite_options.scoped_allocator_opts.enable_op[:] del rewrite_options.scoped_allocator_opts.enable_op[:]
rewrite_options.scoped_allocator_opts.enable_op.append('CollectiveReduce') rewrite_options.scoped_allocator_opts.enable_op.append('CollectiveReduce')
with self.session(config=cfg) as sess: # Tests that execute collectives need to be enclosed in graph or tf.function
run_ops = [] with ops.Graph().as_default():
for i in range(group_size): with self.session(config=cfg) as sess:
with ops.device('CPU:%d' % i): run_ops = []
constant = constant_op.constant(i + 1.) for i in range(group_size):
input_tensor1 = array_ops.identity(constant) with ops.device('CPU:%d' % i):
input_tensor2 = array_ops.identity(constant) constant = constant_op.constant(i + 1.)
reduced_tensor1 = collective_ops.all_reduce( input_tensor1 = array_ops.identity(constant)
input_tensor1, group_size, group_key, instance_key1, 'Add', 'Id') input_tensor2 = array_ops.identity(constant)
reduced_tensor2 = collective_ops.all_reduce( reduced_tensor1 = collective_ops.all_reduce(
input_tensor2, group_size, group_key, instance_key2, 'Add', 'Id') input_tensor1, group_size, group_key, instance_key1, 'Add',
run_ops.append(array_ops.identity(reduced_tensor1)) 'Id')
run_ops.append(array_ops.identity(reduced_tensor2)) reduced_tensor2 = collective_ops.all_reduce(
results = sess.run(run_ops) input_tensor2, group_size, group_key, instance_key2, 'Add',
self.assertEqual(results, [3., 3., 3., 3.]) 'Id')
run_ops.append(array_ops.identity(reduced_tensor1))
run_ops.append(array_ops.identity(reduced_tensor2))
results = sess.run(run_ops)
self.assertEqual(results, [3., 3., 3., 3.])
@test_util.run_v2_only @test_util.run_v2_only
def testMultipleGroups(self): def testMultipleGroups(self):