Add a unit test demonstrating collective ops with different groups of devices.
Collectives configured this way can be used to implement all-reduce for batch norm with a subset of all available devices. PiperOrigin-RevId: 316927203 Change-Id: Ic288d01134776efbe0e49a83fd8030f721890725
This commit is contained in:
parent
18db4c71cd
commit
5e7fc9584a
@ -581,6 +581,41 @@ class CollectiveOpTest(test.TestCase):
|
||||
results = sess.run(run_ops)
|
||||
self.assertEqual(results, [3., 3., 3., 3.])
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testMultipleGroups(self):
|
||||
context._reset_context()
|
||||
cpus = config.list_physical_devices('CPU')
|
||||
self.assertEqual(len(cpus), 1)
|
||||
config.set_logical_device_configuration(cpus[0], [
|
||||
context.LogicalDeviceConfiguration(),
|
||||
context.LogicalDeviceConfiguration(),
|
||||
context.LogicalDeviceConfiguration()
|
||||
])
|
||||
context.ensure_initialized()
|
||||
num_elements = 4
|
||||
|
||||
@def_function.function
|
||||
def run_all_reduce(group_size, group_key):
|
||||
instance_key = group_key
|
||||
input_value = [group_key for i in range(num_elements)]
|
||||
collectives = []
|
||||
for device_idx in range(group_size):
|
||||
with ops.device('/CPU:{}'.format(device_idx)):
|
||||
input_tensor = constant_op.constant(input_value)
|
||||
collectives.append(collective_ops.all_reduce(
|
||||
input_tensor, group_size, group_key, instance_key, merge_op='Add',
|
||||
final_op='Id'))
|
||||
return collectives
|
||||
|
||||
def run_and_assert(group_size, group_key):
|
||||
for reduced_tensor in run_all_reduce(group_size, group_key):
|
||||
self.assertAllEqual(
|
||||
[group_key * group_size for i in range(num_elements)],
|
||||
reduced_tensor.numpy())
|
||||
|
||||
run_and_assert(group_size=2, group_key=1)
|
||||
run_and_assert(group_size=3, group_key=2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user