Parallel device: add a test for collectives inside a function

PiperOrigin-RevId: 319064461
Change-Id: Ib791be6f09194e2df2a64153c32ce32bb83096fd
This commit is contained in:
Allen Lavoie 2020-06-30 11:38:01 -07:00 committed by TensorFlower Gardener
parent ba3ad73a25
commit 54a01c3a6b

View File

@ -23,12 +23,14 @@ import threading
from tensorflow.python.distribute.parallel_device import parallel_device
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@ -42,7 +44,7 @@ from tensorflow.python.util import nest
# communicate.
# TODO(allenl): Switch to using a collective manager.
_COUNTER_LOCK = threading.Lock()
_COUNTER = 0
_COUNTER = 100
def _collective_reduce(inputs, operation, num_replicas):
@ -171,6 +173,32 @@ class ParallelDeviceTests(_VirtualDeviceTestCase):
context._reset_context()
config.set_synchronous_execution(previous)
def test_collective_in_function(self):
c = constant_op.constant([2])
@def_function.function
def broadcast_send_recv(device_id):
@def_function.function
def send():
s0 = collective_ops.broadcast_send(
c * 3, c.shape, c.dtype, group_size=2, group_key=1, instance_key=1)
with ops.control_dependencies([s0.op]):
return array_ops.identity(c)
@def_function.function
def recv():
r0 = collective_ops.broadcast_recv(
c.shape, c.dtype, group_size=2, group_key=1, instance_key=1)
return r0
return control_flow_ops.switch_case(
device_id, branch_fns={0: send, 1: recv})
with ops.device(self.device.name):
result = broadcast_send_recv(self.device.device_ids)
self.assertAllClose([[2], [6]], self.device.unpack(result))
def test_checkpointing(self):
self.skipTest(
"Disable saving until SaveableObject's methods are traceable.")