Dist strat interop
PiperOrigin-RevId: 317410192 Change-Id: Ibfd1e3ac143422ccffa5f240075f5ae93a90ad07
This commit is contained in:
parent
715b02167d
commit
3427843d70
@ -22,8 +22,13 @@ from __future__ import print_function
|
||||
import numpy as onp
|
||||
|
||||
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.distribute import mirrored_strategy
|
||||
from tensorflow.python.distribute import reduce_util
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -36,6 +41,17 @@ from tensorflow.python.platform import test
|
||||
|
||||
class InteropTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(InteropTest, self).setUp()
|
||||
physical_devices = config.list_physical_devices('CPU')
|
||||
configs = config.get_logical_device_configuration(physical_devices[0])
|
||||
if configs is None:
|
||||
logical_devices = [
|
||||
context.LogicalDeviceConfiguration() for _ in range(3)
|
||||
]
|
||||
config.set_logical_device_configuration(physical_devices[0],
|
||||
logical_devices)
|
||||
|
||||
def testGradientTapeInterop(self):
|
||||
with backprop.GradientTape() as t:
|
||||
x = np_array_ops.asarray(3.0)
|
||||
@ -139,6 +155,39 @@ class InteropTest(test.TestCase):
|
||||
|
||||
# self.assertEqual(t.numpy(), [1., 2., 3.])
|
||||
|
||||
def testDistStratInterop(self):
|
||||
strategy = mirrored_strategy.MirroredStrategy(
|
||||
devices=['CPU:0', 'CPU:1', 'CPU:2'])
|
||||
|
||||
multiplier = np_array_ops.asarray(5.)
|
||||
|
||||
with strategy.scope():
|
||||
@def_function.function
|
||||
def run():
|
||||
ctx = distribution_strategy_context.get_replica_context()
|
||||
val = np_array_ops.asarray(ctx.replica_id_in_sync_group)
|
||||
return val * multiplier
|
||||
|
||||
distributed_values = strategy.run(run)
|
||||
reduced = strategy.reduce(reduce_util.ReduceOp.SUM,
|
||||
distributed_values, axis=None)
|
||||
|
||||
values = distributed_values.values
|
||||
|
||||
# Note that this should match the number of virtual CPUs.
|
||||
self.assertLen(values, 3)
|
||||
self.assertIsInstance(values[0], np_arrays.ndarray)
|
||||
self.assertIsInstance(values[1], np_arrays.ndarray)
|
||||
self.assertIsInstance(values[2], np_arrays.ndarray)
|
||||
self.assertAllClose(values[0], 0)
|
||||
self.assertAllClose(values[1], 5)
|
||||
self.assertAllClose(values[2], 10)
|
||||
|
||||
# "strategy.reduce" doesn't rewrap in ndarray.
|
||||
# self.assertIsInstance(reduced, np_arrays.ndarray)
|
||||
self.assertAllClose(reduced, 15)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user