Add a unit test demonstrating collective ops with different groups of devices.

Collectives configured this way can be used to implement all-reduce for batch
norm with a subset of all available devices.

PiperOrigin-RevId: 316927203
Change-Id: Ic288d01134776efbe0e49a83fd8030f721890725
This commit is contained in:
Ayush Dubey 2020-06-17 11:21:20 -07:00 committed by TensorFlower Gardener
parent 18db4c71cd
commit 5e7fc9584a

View File

@ -581,6 +581,41 @@ class CollectiveOpTest(test.TestCase):
results = sess.run(run_ops)
self.assertEqual(results, [3., 3., 3., 3.])
@test_util.run_v2_only
def testMultipleGroups(self):
context._reset_context()
cpus = config.list_physical_devices('CPU')
self.assertEqual(len(cpus), 1)
config.set_logical_device_configuration(cpus[0], [
context.LogicalDeviceConfiguration(),
context.LogicalDeviceConfiguration(),
context.LogicalDeviceConfiguration()
])
context.ensure_initialized()
num_elements = 4
@def_function.function
def run_all_reduce(group_size, group_key):
instance_key = group_key
input_value = [group_key for i in range(num_elements)]
collectives = []
for device_idx in range(group_size):
with ops.device('/CPU:{}'.format(device_idx)):
input_tensor = constant_op.constant(input_value)
collectives.append(collective_ops.all_reduce(
input_tensor, group_size, group_key, instance_key, merge_op='Add',
final_op='Id'))
return collectives
def run_and_assert(group_size, group_key):
for reduced_tensor in run_all_reduce(group_size, group_key):
self.assertAllEqual(
[group_key * group_size for i in range(num_elements)],
reduced_tensor.numpy())
run_and_assert(group_size=2, group_key=1)
run_and_assert(group_size=3, group_key=2)
if __name__ == '__main__':
test.main()