Dist strat interop

PiperOrigin-RevId: 317410192
Change-Id: Ibfd1e3ac143422ccffa5f240075f5ae93a90ad07
This commit is contained in:
Akshay Modi 2020-06-19 17:33:29 -07:00 committed by TensorFlower Gardener
parent 715b02167d
commit 3427843d70

View File

@ -22,8 +22,13 @@ from __future__ import print_function
import numpy as onp
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
@ -36,6 +41,17 @@ from tensorflow.python.platform import test
class InteropTest(test.TestCase):
def setUp(self):
super(InteropTest, self).setUp()
physical_devices = config.list_physical_devices('CPU')
configs = config.get_logical_device_configuration(physical_devices[0])
if configs is None:
logical_devices = [
context.LogicalDeviceConfiguration() for _ in range(3)
]
config.set_logical_device_configuration(physical_devices[0],
logical_devices)
def testGradientTapeInterop(self):
with backprop.GradientTape() as t:
x = np_array_ops.asarray(3.0)
@ -139,6 +155,39 @@ class InteropTest(test.TestCase):
# self.assertEqual(t.numpy(), [1., 2., 3.])
def testDistStratInterop(self):
strategy = mirrored_strategy.MirroredStrategy(
devices=['CPU:0', 'CPU:1', 'CPU:2'])
multiplier = np_array_ops.asarray(5.)
with strategy.scope():
@def_function.function
def run():
ctx = distribution_strategy_context.get_replica_context()
val = np_array_ops.asarray(ctx.replica_id_in_sync_group)
return val * multiplier
distributed_values = strategy.run(run)
reduced = strategy.reduce(reduce_util.ReduceOp.SUM,
distributed_values, axis=None)
values = distributed_values.values
# Note that this should match the number of virtual CPUs.
self.assertLen(values, 3)
self.assertIsInstance(values[0], np_arrays.ndarray)
self.assertIsInstance(values[1], np_arrays.ndarray)
self.assertIsInstance(values[2], np_arrays.ndarray)
self.assertAllClose(values[0], 0)
self.assertAllClose(values[1], 5)
self.assertAllClose(values[2], 10)
# "strategy.reduce" doesn't rewrap in ndarray.
# self.assertIsInstance(reduced, np_arrays.ndarray)
self.assertAllClose(reduced, 15)
if __name__ == '__main__':
ops.enable_eager_execution()
test.main()