diff --git a/tensorflow/python/ops/numpy_ops/np_interop_test.py b/tensorflow/python/ops/numpy_ops/np_interop_test.py index f52d3dae78b..9580b787202 100644 --- a/tensorflow/python/ops/numpy_ops/np_interop_test.py +++ b/tensorflow/python/ops/numpy_ops/np_interop_test.py @@ -22,8 +22,13 @@ from __future__ import print_function import numpy as onp +from tensorflow.python.distribute import distribution_strategy_context +from tensorflow.python.distribute import mirrored_strategy +from tensorflow.python.distribute import reduce_util from tensorflow.python.eager import backprop +from tensorflow.python.eager import context from tensorflow.python.eager import def_function +from tensorflow.python.framework import config from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops @@ -36,6 +41,17 @@ from tensorflow.python.platform import test class InteropTest(test.TestCase): + def setUp(self): + super(InteropTest, self).setUp() + physical_devices = config.list_physical_devices('CPU') + configs = config.get_logical_device_configuration(physical_devices[0]) + if configs is None: + logical_devices = [ + context.LogicalDeviceConfiguration() for _ in range(3) + ] + config.set_logical_device_configuration(physical_devices[0], + logical_devices) + def testGradientTapeInterop(self): with backprop.GradientTape() as t: x = np_array_ops.asarray(3.0) @@ -139,6 +155,39 @@ class InteropTest(test.TestCase): # self.assertEqual(t.numpy(), [1., 2., 3.]) + def testDistStratInterop(self): + strategy = mirrored_strategy.MirroredStrategy( + devices=['CPU:0', 'CPU:1', 'CPU:2']) + + multiplier = np_array_ops.asarray(5.) + + with strategy.scope(): + @def_function.function + def run(): + ctx = distribution_strategy_context.get_replica_context() + val = np_array_ops.asarray(ctx.replica_id_in_sync_group) + return val * multiplier + + distributed_values = strategy.run(run) + reduced = strategy.reduce(reduce_util.ReduceOp.SUM, + distributed_values, axis=None) + + values = distributed_values.values + + # Note that this should match the number of virtual CPUs. + self.assertLen(values, 3) + self.assertIsInstance(values[0], np_arrays.ndarray) + self.assertIsInstance(values[1], np_arrays.ndarray) + self.assertIsInstance(values[2], np_arrays.ndarray) + self.assertAllClose(values[0], 0) + self.assertAllClose(values[1], 5) + self.assertAllClose(values[2], 10) + + # "strategy.reduce" doesn't rewrap in ndarray. + # self.assertIsInstance(reduced, np_arrays.ndarray) + self.assertAllClose(reduced, 15) + + if __name__ == '__main__': ops.enable_eager_execution() test.main()