Even if this is only for the first collective that runs for a particular group, having a non-cancellable deadlock is pretty annoying.

PiperOrigin-RevId: 355871763
Change-Id: I76383e9f7eb0670fce943643d632c0eaf33b8c92
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for V2 Collective Operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import threading
import time

from absl.testing import parameterized

from tensorflow.python.compat import v2_compat
from tensorflow.python.data.experimental.ops import testing as dataset_testing
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.eager import cancellation
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import collective_ops as _collective_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
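

# Thin adapters that expose the V1 and V2 collective ops behind the same call
# signature so the tests below can be parameterized over both APIs. The V2
# wrappers route group_size/group_key/instance_key through identity() so the
# keys reach the ops as tensors rather than Python constants.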
class CollectiveOpsV1(object):
  all_reduce = _collective_ops.all_reduce
  all_gather = _collective_ops.all_gather
  broadcast_send = _collective_ops.broadcast_send
  broadcast_recv = _collective_ops.broadcast_recv


class CollectiveOpsV2(object):

  @staticmethod
  def all_reduce(t, group_size, group_key, instance_key, *args, **kwargs):
    group_size = array_ops.identity(group_size)
    group_key = array_ops.identity(group_key)
    instance_key = array_ops.identity(instance_key)
    return _collective_ops.all_reduce_v2(t, group_size, group_key, instance_key,
                                         *args, **kwargs)

  @staticmethod
  def all_gather(t, group_size, group_key, instance_key, *args, **kwargs):
    group_size = array_ops.identity(group_size)
    group_key = array_ops.identity(group_key)
    instance_key = array_ops.identity(instance_key)
    return _collective_ops.all_gather_v2(t, group_size, group_key, instance_key,
                                         *args, **kwargs)

  @staticmethod
  def broadcast_send(t, shape, dtype, group_size, group_key, instance_key,
                     *args, **kwargs):
    group_size = array_ops.identity(group_size)
    group_key = array_ops.identity(group_key)
    instance_key = array_ops.identity(instance_key)
    return _collective_ops.broadcast_send_v2(t, group_size, group_key,
                                             instance_key, *args, **kwargs)

  @staticmethod
  def broadcast_recv(shape, dtype, group_size, group_key, instance_key, *args,
                     **kwargs):
    group_size = array_ops.identity(group_size)
    group_key = array_ops.identity(group_key)
    instance_key = array_ops.identity(instance_key)
    shape = array_ops.identity(shape)
    return _collective_ops.broadcast_recv_v2(
        shape, dtype, group_size, group_key, instance_key, *args, **kwargs)
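

# Device/communication combinations shared by the tests below: CPU with RING,
# or two GPUs with either RING or NCCL.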
device_combination = (
    combinations.combine(device='CPU', communication='RING', required_gpus=0) +
    combinations.combine(
        device='GPU', communication=['RING', 'NCCL'], required_gpus=2))


collective_op_combinations = combinations.times(
    combinations.combine(
        collective_op=[
            combinations.NamedObject('all_reduce', CollectiveOpsV1.all_reduce),
            combinations.NamedObject('all_reduce_v2',
                                     CollectiveOpsV2.all_reduce),
            combinations.NamedObject('all_gather', CollectiveOpsV1.all_gather),
            combinations.NamedObject('all_gather_v2',
                                     CollectiveOpsV2.all_gather),
        ],
        mode='eager'), device_combination)


@combinations.generate(
    combinations.times(
        combinations.combine(
            collective_ops=[
                combinations.NamedObject('v1', CollectiveOpsV1),
                combinations.NamedObject('v2', CollectiveOpsV2)
            ],
            mode='eager'), device_combination))
class CollectiveOpsTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    _setup_context()
    super().setUp()

  def testReduce(self, collective_ops, device, communication):
    dev0 = '/device:%s:0' % device
    dev1 = '/device:%s:1' % device

    @def_function.function
    def run_all_reduce_1device():
      with ops.device(dev0):
        in_value = constant_op.constant([1.])
        group_size = 1
        group_key = 1
        instance_key = 1
        return collective_ops.all_reduce(
            in_value,
            group_size,
            group_key,
            instance_key,
            communication_hint=communication)

    @def_function.function
    def run_all_reduce_2devices():
      in_value = constant_op.constant([1.])
      group_size = 2
      group_key = 2
      instance_key = 2
      collectives = []
      with ops.device(dev0):
        collectives.append(
            collective_ops.all_reduce(
                in_value,
                group_size,
                group_key,
                instance_key,
                communication_hint=communication))
      with ops.device(dev1):
        collectives.append(
            collective_ops.all_reduce(
                in_value,
                group_size,
                group_key,
                instance_key,
                communication_hint=communication))
      return collectives

    self.assertAllClose(run_all_reduce_1device(), [1.], rtol=1e-5, atol=1e-5)
    for result in run_all_reduce_2devices():
      self.assertAllClose(result, [2.], rtol=1e-5, atol=1e-5)

  def testGather(self, collective_ops, device, communication):
    dev0 = '/device:%s:0' % device
    dev1 = '/device:%s:1' % device

    @def_function.function
    def run_all_gather_1device():
      with ops.device(dev0):
        in_value = constant_op.constant([1.])
        group_size = 1
        group_key = 1
        instance_key = 1
        return collective_ops.all_gather(
            in_value,
            group_size,
            group_key,
            instance_key,
            communication_hint=communication)

    @def_function.function
    def run_all_gather_2devices():
      in_value = constant_op.constant([1.])
      group_size = 2
      group_key = 2
      instance_key = 2
      collectives = []
      with ops.device(dev0):
        collectives.append(
            collective_ops.all_gather(
                in_value,
                group_size,
                group_key,
                instance_key,
                communication_hint=communication))
      with ops.device(dev1):
        collectives.append(
            collective_ops.all_gather(
                in_value,
                group_size,
                group_key,
                instance_key,
                communication_hint=communication))
      return collectives

    self.assertAllClose(run_all_gather_1device(), [1.], rtol=1e-5, atol=1e-5)
    for result in run_all_gather_2devices():
      self.assertAllClose(result, [1., 1.], rtol=1e-5, atol=1e-5)

  def testBroadcast(self, collective_ops, device, communication):
    dev0 = '/device:%s:0' % device
    dev1 = '/device:%s:1' % device

    @def_function.function
    def run_broadcast_2devices():
      shape = [3]
      in_value = constant_op.constant([1., 2., 3.], shape=shape)
      group_size = 2
      group_key = 2
      instance_key = 2
      collectives = []
      with ops.device(dev0):
        collectives.append(
            collective_ops.broadcast_send(
                in_value,
                shape,
                in_value.dtype,
                group_size,
                group_key,
                instance_key,
                communication_hint=communication))
      with ops.device(dev1):
        collectives.append(
            collective_ops.broadcast_recv(
                shape,
                in_value.dtype,
                group_size,
                group_key,
                instance_key,
                communication_hint=communication))
      return collectives

    for result in run_broadcast_2devices():
      self.assertAllClose(result, [1., 2., 3.], rtol=1e-5, atol=1e-5)

  def testInstanceKeyScopedUnderGroupKey(self, collective_ops, device,
                                         communication):
    if device == 'GPU' and context.num_gpus() < 4:
      self.skipTest('not enough GPU')

    dev0 = '/device:%s:0' % device
    dev1 = '/device:%s:1' % device
    dev2 = '/device:%s:2' % device
    dev3 = '/device:%s:3' % device

    @def_function.function
    def run_all_reduce_4devices_same_instance_key():
      # Use a common instance key for both groups.
      instance_key = 0
      # We will create 2 groups each with 2 devices.
      group_size = 2
      # Group 0 comprises dev0 and dev1.
      group0_key = 0
      # Group 1 comprises dev2 and dev3.
      group1_key = 1
      collectives = []
      with ops.device(dev0):
        collectives.append(
            collective_ops.all_reduce(
                constant_op.constant(1.), group_size, group0_key, instance_key))
      with ops.device(dev1):
        collectives.append(
            collective_ops.all_reduce(
                constant_op.constant(2.), group_size, group0_key, instance_key))
      with ops.device(dev2):
        collectives.append(
            collective_ops.all_reduce(
                constant_op.constant(3.), group_size, group1_key, instance_key))
      with ops.device(dev3):
        collectives.append(
            collective_ops.all_reduce(
                constant_op.constant(4.), group_size, group1_key, instance_key))
      return collectives

    results = run_all_reduce_4devices_same_instance_key()
    self.assertAllClose(results[0], 3., rtol=1e-5, atol=1e-5)
    self.assertAllClose(results[1], 3., rtol=1e-5, atol=1e-5)
    self.assertAllClose(results[2], 7., rtol=1e-5, atol=1e-5)
    self.assertAllClose(results[3], 7., rtol=1e-5, atol=1e-5)

  def testCollectiveGroupSizeOne(self, collective_ops, device, communication):
    dev0 = '/device:%s:0' % device

    group_size = 1
    group_key = 100
    in_value = [1., 2., 3., 4.]
    in_tensor = constant_op.constant(in_value)

    with ops.device(dev0):
      reduced_tensor = collective_ops.all_reduce(
          in_tensor,
          group_size,
          group_key,
          instance_key=100,
          communication_hint=communication)
    self.assertAllEqual(in_value, reduced_tensor.numpy())

    with ops.device(dev0):
      gathered_tensor = collective_ops.all_gather(
          in_tensor,
          group_size,
          group_key,
          instance_key=200,
          communication_hint=communication)
    self.assertAllEqual(in_value, gathered_tensor.numpy())

  def testCollectiveInvalidKey(self, collective_ops, device, communication):
    dev0 = '/device:%s:0' % device

    group_size = 1
    group_key = 100
    instance_key = 100
    in_value = [1., 2., 3., 4.]
    in_tensor = constant_op.constant(in_value)

    with ops.device(dev0):
      reduced_tensor = collective_ops.all_reduce(
          in_tensor,
          group_size,
          group_key,
          instance_key,
          communication_hint=communication)
    self.assertAllEqual(in_value, reduced_tensor.numpy())

    with self.assertRaisesRegex(
        errors.InternalError, 'instance 100 expected type 0 and data_type 1 but'
        ' got type 2 and data_type 1'):
      with ops.device(dev0):
        collective_ops.all_gather(
            in_tensor,
            group_size,
            group_key,
            instance_key,
            communication_hint=communication)

  def testMultipleGroups(self, collective_ops, device, communication):
    if device == 'GPU' and context.num_gpus() < 4:
      self.skipTest('not enough GPU')

    num_elements = 4

    @def_function.function
    def run_all_reduce(group_size, group_key):
      instance_key = group_key
      input_value = [float(group_key) for i in range(num_elements)]
      collectives = []
      for device_idx in range(group_size):
        with ops.device('/{}:{}'.format(device, device_idx)):
          input_tensor = constant_op.constant(input_value)
          collectives.append(
              collective_ops.all_reduce(
                  input_tensor,
                  group_size,
                  group_key,
                  instance_key,
                  communication_hint=communication))
      return collectives

    def run_and_assert(group_size, group_key):
      for reduced_tensor in run_all_reduce(group_size, group_key):
        self.assertAllEqual(
            [float(group_key) * group_size for i in range(num_elements)],
            reduced_tensor.numpy())

    run_and_assert(group_size=2, group_key=1)
    run_and_assert(group_size=3, group_key=2)
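

# Verifies that context.abort_collective_ops() unblocks collectives stuck in
# group resolution, instance resolution, or communication, and that collectives
# work again once the context is reset.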
@combinations.generate(collective_op_combinations)
class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    _setup_context()
    super().setUp()

  def testAbortGroupParamsResolution(self, collective_op, device,
                                     communication):
    dev0 = '/device:%s:0' % device
    dev1 = '/device:%s:1' % device
    group_size = 2
    group_key = 100
    instance_key = 100
    in_tensor = constant_op.constant([1.])

    def abort_fn():
      time.sleep(2)
      context.context().abort_collective_ops(errors.UNAVAILABLE, 'peer down')

    t = threading.Thread(target=abort_fn)
    t.start()

    with self.assertRaisesRegex(errors.UnavailableError, 'peer down'):
      # This hangs on params resolution since we're only launching one
      # collective for a group size of 2.
      with ops.device(dev0):
        collective_op(
            in_tensor,
            group_size,
            group_key,
            instance_key,
            communication_hint=communication)

    # After abortion, subsequent collectives should fail immediately.
    with self.assertRaisesRegex(errors.UnavailableError, 'peer down'):
      with ops.device(dev0):
        collective_op(
            in_tensor,
            group_size,
            group_key,
            instance_key,
            communication_hint=communication)

    t.join()
    # Reset the context in order to reset the collective executor.
    _setup_context()

    # After reset non-NCCL collectives should work.
    def collective_fn():
      for device in [dev0, dev1]:
        with ops.device(device):
          collective_op(
              in_tensor,
              group_size,
              group_key,
              instance_key,
              communication_hint=communication)

    def_function.function(collective_fn)()

  def testAbortInstanceParamsResolution(self, collective_op, device,
                                        communication):
    dev0 = '/device:%s:0' % device
    dev1 = '/device:%s:1' % device
    group_size = 2
    group_key = 100
    instance_key = 100
    in_tensor = constant_op.constant([1.])

    def collective_fn():
      for device in [dev0, dev1]:
        with ops.device(device):
          collective_op(
              in_tensor,
              group_size,
              group_key,
              instance_key,
              communication_hint=communication)

    # First perform a normal all-reduce to complete the group resolution.
    def_function.function(collective_fn)()

    def abort_fn():
      time.sleep(2)
      context.context().abort_collective_ops(errors.UNAVAILABLE, 'peer down')

    t = threading.Thread(target=abort_fn)
    t.start()

    # Use a different instance key to trigger another instance resolution.
    instance_key = 101
    with self.assertRaisesRegex(errors.UnavailableError, 'peer down'):
      # This hangs on params resolution since we're only launching one
      # collective for a group size of 2.
      with ops.device(dev0):
        collective_op(
            in_tensor,
            group_size,
            group_key,
            instance_key,
            communication_hint=communication)

    # After abortion, subsequent collectives should fail immediately.
    with self.assertRaisesRegex(errors.UnavailableError, 'peer down'):
      with ops.device(dev0):
        collective_op(
            in_tensor,
            group_size,
            group_key,
            instance_key,
            communication_hint=communication)

    context._reset_context()  # pylint: disable=protected-access
    t.join()
    # Reset the context in order to reset the collective executor.
    _setup_context()

    # After reset non-NCCL collectives should work.
    def_function.function(collective_fn)()

  def testAbortCommunication(self, collective_op, device, communication):
    dev0 = '/device:%s:0' % device
    dev1 = '/device:%s:1' % device
    group_size = 2
    group_key = 100
    instance_key = 100
    in_tensor = constant_op.constant([1.])

    # First perform a normal collective to finish resolution.
    def collective_fn():
      for device in [dev0, dev1]:
        with ops.device(device):
          collective_op(
              in_tensor,
              group_size,
              group_key,
              instance_key,
              communication_hint=communication)

    def_function.function(collective_fn)()

    # Launch a collective that hangs, and abort the collective executor after
    # the launch.
    def abort_fn():
      time.sleep(2)
      context.context().abort_collective_ops(errors.UNAVAILABLE, 'peer down')

    t = threading.Thread(target=abort_fn)
    t.start()

    with self.assertRaisesRegex(errors.UnavailableError, 'peer down'):
      with ops.device(dev0):
        collective_op(
            in_tensor,
            group_size,
            group_key,
            instance_key,
            communication_hint=communication)

    # After abortion, subsequent collectives should fail immediately.
    with self.assertRaisesRegex(errors.UnavailableError, 'peer down'):
      with ops.device(dev0):
        collective_op(
            in_tensor,
            group_size,
            group_key,
            instance_key,
            communication_hint=communication)

    # Reset the context in order to reset the collective executor.
    t.join()
    _setup_context()
    def_function.function(collective_fn)()
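

# Verifies how op errors and explicit cancellation interact with pending
# collectives: V1 collectives are aborted when another op fails, while V2
# collectives are left alone and rely on cancellation instead.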
class OpCancellationTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    _setup_context()
    super().setUp()

  @combinations.generate(
      combinations.times(
          combinations.combine(
              collective_op=[
                  combinations.NamedObject('all_reduce',
                                           CollectiveOpsV1.all_reduce),
                  combinations.NamedObject('all_reduce_v2',
                                           CollectiveOpsV2.all_reduce),
                  combinations.NamedObject('all_gather',
                                           CollectiveOpsV1.all_gather),
                  combinations.NamedObject('all_gather_v2',
                                           CollectiveOpsV2.all_gather),
              ],
              mode='eager'), device_combination))
  def testOpErrorNotAbortIfNoCollective(self, collective_op, device,
                                        communication):
    # Do not abort if there are no active collective ops. There could be
    # exceptions like EOF which we expect users to catch; aborting collective
    # ops on all op errors would interfere with that workflow.
    dev0 = '/device:%s:0' % device
    dev1 = '/device:%s:1' % device
    group_size = 2
    group_key = 100
    instance_key = 100
    dataset = dataset_ops.Dataset.from_tensors([1.])

    @def_function.function
    def collective_fn(in_tensor):
      for device in [dev0, dev1]:
        with ops.device(device):
          collective_op(
              in_tensor,
              group_size,
              group_key,
              instance_key,
              communication_hint=communication)

    @def_function.function
    def f():
      iterator = iter(dataset)
      collective_fn(next(iterator))
      # This next(iterator) should raise EOF.
      collective_fn(next(iterator))

    with self.assertRaises(errors.OutOfRangeError):
      f()
    collective_fn(constant_op.constant([1.]))

  @combinations.generate(
      combinations.times(
          combinations.combine(
              collective_op=[
                  combinations.NamedObject('all_reduce',
                                           CollectiveOpsV1.all_reduce),
                  combinations.NamedObject('all_gather',
                                           CollectiveOpsV1.all_gather),
              ],
              mode='eager'), device_combination))
  def testOpErrorAbortWithCollective(self, collective_op, device,
                                     communication):
    # Abort v1 collective ops if there are active collective ops at the time
    # of an op error. This is because v1 collective ops cannot be cancelled,
    # and an op error may leave running collective ops hanging.
    dev0 = '/device:%s:0' % device
    group_size = 2
    group_key = 100
    instance_key = 100
    in_tensor = constant_op.constant([1.])
    # Make the dataset sleep a while so that the collective is being executed
    # when the EOF happens.
    dataset = dataset_ops.Dataset.from_tensors([1.]).apply(
        dataset_testing.sleep(sleep_microseconds=200))

    @def_function.function
    def f():
      # Launch a collective op that won't be able to finish to test abortion
      # when other ops error.
      with ops.device(dev0):
        ret = collective_op(
            in_tensor,
            group_size,
            group_key,
            instance_key,
            communication_hint=communication)
      iterator = iter(dataset)
      next(iterator)
      # This should raise EOF.
      next(iterator)
      return ret

    with self.assertRaises(errors.OutOfRangeError):
      f()
    # Now that collective ops are aborted, subsequent collective ops should
    # fail with the previous error.
    with self.assertRaises(errors.CancelledError):
      with ops.device(dev0):
        collective_op(
            in_tensor,
            group_size,
            group_key,
            instance_key,
            communication_hint=communication)

  @combinations.generate(
      combinations.times(
          combinations.combine(
              collective_op=[
                  combinations.NamedObject('all_reduce_v2',
                                           CollectiveOpsV2.all_reduce),
                  combinations.NamedObject('all_gather_v2',
                                           CollectiveOpsV2.all_gather),
              ],
              mode='eager'), device_combination))
  def testOpErrorNotAbortWithCollective(self, collective_op, device,
                                        communication):
    # Do not abort v2 collective ops even if there are active collective ops
    # at the time of an op error. We rely on cancellation to terminate active
    # collective ops.
    dev0 = '/device:%s:0' % device
    dev1 = '/device:%s:1' % device
    group_size = 2
    group_key = 100
    instance_key = 100
    in_tensor = constant_op.constant([1.])

    @def_function.function
    def collective_fn():
      for device in [dev0, dev1]:
        with ops.device(device):
          collective_op(
              in_tensor,
              group_size,
              group_key,
              instance_key,
              communication_hint=communication)

    # Local params resolution cannot be cancelled yet, so we perform a normal
    # collective so that the group is resolved.
    collective_fn()

    # Make the dataset sleep a while so that the collective is being executed
    # when the EOF happens.
    dataset = dataset_ops.Dataset.from_tensors([1.]).apply(
        dataset_testing.sleep(sleep_microseconds=200))

    @def_function.function
    def f():
      # Launch a collective op that won't be able to finish to test
      # cancellation when other ops error.
      with ops.device(dev0):
        ret = collective_op(
            in_tensor,
            group_size,
            group_key,
            instance_key,
            communication_hint=communication)
      iterator = iter(dataset)
      next(iterator)
      # This should raise EOF.
      next(iterator)
      return ret

    with self.assertRaises(errors.OutOfRangeError):
      f()
    # Collective ops shouldn't be aborted and new collectives should be able
    # to proceed.
    collective_fn()

  @combinations.generate(
      combinations.times(
          combinations.combine(
              collective_op=[
                  combinations.NamedObject('all_reduce_v2',
                                           CollectiveOpsV2.all_reduce),
                  combinations.NamedObject('all_gather_v2',
                                           CollectiveOpsV2.all_gather),
              ],
              mode='eager'), device_combination))
  def testCancelDuringParamResolution(self, collective_op, device,
                                      communication):
    dev0 = '/device:%s:0' % device
    dev1 = '/device:%s:1' % device
    group_size = 2
    group_key = 100
    instance_key = 100
    in_tensor = constant_op.constant([1.])
    t1_cancellation_manager = cancellation.CancellationManager()
    t2_cancellation_manager = cancellation.CancellationManager()

    @def_function.function
    def _collective_fn(x):
      # Run an assertion to crash one of the two function executions running
      # collectives. We explicitly cancel the other in response.
      assert_op = check_ops.assert_equal(x, in_tensor)
      with ops.control_dependencies([assert_op]):
        return collective_op(
            in_tensor,
            group_size,
            group_key,
            instance_key,
            communication_hint=communication)

    collective_concrete = _collective_fn.get_concrete_function(in_tensor)

    finish_mu = threading.Lock()
    finishes = 0

    def _placement_wrapper(device, x, my_cancellation, other_cancellation):
      try:
        with ops.device(device):
          cancelable_collective = my_cancellation.get_cancelable_function(
              collective_concrete)
          return cancelable_collective(x)
      except errors.InvalidArgumentError:
        # `assert_equal` failed for this execution of the function. The other
        # function would deadlock without cancellation.
        other_cancellation.start_cancel()
      except errors.CancelledError:
        pass
      nonlocal finishes
      with finish_mu:
        finishes += 1

    t1 = threading.Thread(
        target=_placement_wrapper,
        args=(dev0, constant_op.constant([1.]), t1_cancellation_manager,
              t2_cancellation_manager))
    t2 = threading.Thread(
        target=_placement_wrapper,
        # Will cause the assertion to fail
        args=(dev1, constant_op.constant([2.]), t2_cancellation_manager,
              t1_cancellation_manager))
    t1.start()
    t2.start()
    t1.join()
    t2.join()
    self.assertEqual(finishes, 2)
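

# Verifies that the timeout argument surfaces DeadlineExceededError both while
# waiting for other workers (param resolution) and during execution.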
@combinations.generate(collective_op_combinations)
class TimeoutTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    _setup_context()
    super().setUp()

  def testTimeout(self, collective_op, device, communication):
    timeout = 1.5

    @def_function.function
    def run(group_size, reported_group_size=None):
      group_key = 20
      instance_key = 30
      tensor = [1., 2., 3., 4.]
      results = []
      if reported_group_size is None:
        reported_group_size = group_size
      for i in range(group_size):
        with ops.device('/{}:{}'.format(device, i)):
          input_data = constant_op.constant(tensor)
          result = collective_op(
              input_data,
              group_size=reported_group_size,
              group_key=group_key,
              instance_key=instance_key,
              communication_hint=communication,
              timeout=timeout)
          results.append(result)
      return results

    run(2, 2)

    start_time = time.time()
    with self.assertRaisesRegex(errors.DeadlineExceededError,
                                'Collective has timed out during execution'):
      run(1, 2)
    elapsed = time.time() - start_time
    self.assertAllGreaterEqual(elapsed, timeout)

  def testParamResolutionAfterTimeout(self, collective_op, device,
                                      communication):
    dev0 = '/device:%s:0' % device
    dev1 = '/device:%s:1' % device
    timeout = 1.5
    group_key = 20
    instance_key = 30
    input_data = constant_op.constant([1., 2., 3., 4.])

    # This timeout comes from param resolution.
    with self.assertRaisesRegex(
        errors.DeadlineExceededError,
        'Collective has timed out waiting for other workers'):
      with ops.device(dev0):
        collective_op(
            input_data,
            group_size=2,
            group_key=group_key,
            instance_key=instance_key,
            communication_hint=communication,
            timeout=timeout)

    # We launch the second device after the first device times out. This is to
    # simulate the situation when other workers are slow and the timeout is
    # short. It should error immediately.
    with self.assertRaisesRegex(
        errors.DeadlineExceededError,
        'Collective has timed out waiting for other workers'):
      with ops.device(dev1):
        collective_op(
            input_data,
            group_size=2,
            group_key=group_key,
            instance_key=instance_key,
            communication_hint=communication)

  def testExecutionAfterTimeout(self, collective_op, device, communication):
    dev0 = '/device:%s:0' % device
    dev1 = '/device:%s:1' % device
    timeout = 1.5
    group_key = 20
    instance_key = 30
    input_data = constant_op.constant([1., 2., 3., 4.])

    @def_function.function
    def run():
      for device in [dev0, dev1]:
        with ops.device(device):
          collective_op(
              input_data,
              group_size=2,
              group_key=group_key,
              instance_key=instance_key,
              communication_hint=communication,
              timeout=timeout)

    # Run a normal all-reduce to complete param resolution.
    run()

    with self.assertRaisesRegex(errors.DeadlineExceededError,
                                'Collective has timed out during execution'):
      with ops.device(dev0):
        collective_op(
            input_data,
            group_size=2,
            group_key=group_key,
            instance_key=instance_key,
            communication_hint=communication,
            timeout=timeout)

    # We launch the second device after the first device times out. This is to
    # simulate the situation when other workers are slow and the timeout is
    # short. It should error immediately.
    with self.assertRaisesRegex(errors.DeadlineExceededError,
                                'Collective has timed out during execution'):
      with ops.device(dev1):
        # No timeout.
        collective_op(
            input_data,
            group_size=2,
            group_key=group_key,
            instance_key=instance_key,
            communication_hint=communication)
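

# Verifies that passing the same ordering_token resource to V2 collectives adds
# control edges between them, fixing their execution order on each device.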
@combinations.generate(
    combinations.times(
        combinations.combine(
            collective_op=[
                combinations.NamedObject('all_reduce_v2',
                                         CollectiveOpsV2.all_reduce),
                combinations.NamedObject('all_gather_v2',
                                         CollectiveOpsV2.all_gather),
            ],
            mode='eager'), device_combination))
class OrderingTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    _setup_context()
    super().setUp()

  def testOrdering(self, collective_op, device, communication):
    dev0 = '/device:%s:0' % device
    dev1 = '/device:%s:1' % device
    group_size = 2
    group_key = 100
    instance_key = 100
    in_tensor = constant_op.constant([1.])

    with ops.device(dev0):
      token0 = resource_variable_ops.ResourceVariable(0.)
    with ops.device(dev1):
      token1 = resource_variable_ops.ResourceVariable(0.)

    @def_function.function
    def f():
      # Launch the first collective with token.
      with ops.device(dev0):
        collective_op(
            in_tensor,
            group_size,
            group_key,
            instance_key,
            ordering_token=token0.handle)
      with ops.device(dev1):
        collective_op(
            in_tensor,
            group_size,
            group_key,
            instance_key,
            ordering_token=token1.handle)
      # Launch the second collective without token.
      with ops.device(dev0):
        collective_op(in_tensor, group_size, group_key, instance_key)
      with ops.device(dev1):
        collective_op(in_tensor, group_size, group_key, instance_key)
      # Launch the third collective with token.
      with ops.device(dev0):
        collective_op(
            in_tensor,
            group_size,
            group_key,
            instance_key,
            ordering_token=token0.handle)
      with ops.device(dev1):
        collective_op(
            in_tensor,
            group_size,
            group_key,
            instance_key,
            ordering_token=token1.handle)

    graph = f.get_concrete_function().graph
    for device in [dev0, dev1]:
      # Try to find the third collective, which should have the first
      # collective as a control input.
      third = None
      for op in graph.get_operations():
        if (op.type.startswith('Collective') and op.device.endswith(device) and
            op.control_inputs and
            op.control_inputs[0].type.startswith('Collective')):
          self.assertIsNone(third)
          third = op
      self.assertIsNotNone(third)
      # Verify it's not the second collective by looking at the inputs.
      self.assertTrue(any(v.dtype == dtypes.resource for v in third.inputs))
      first = third.control_inputs[0]
      self.assertEqual(third.device, first.device)
      # Verify it's not the second collective by looking at the inputs.
      self.assertTrue(any(v.dtype == dtypes.resource for v in first.inputs))
      self.assertEmpty(first.control_inputs)
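

# Resets the eager context (clearing any aborted collective executor) and
# ensures at least four logical CPU devices for the multi-device tests above.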
def _setup_context():
  context._reset_context()
  test_util.set_logical_devices_to_at_least('CPU', 4)
  context.ensure_initialized()


if __name__ == '__main__':
  os.environ['NCCL_DEBUG'] = 'INFO'
  v2_compat.enable_v2_behavior()
  test.main()