Parallel device: add a test for collectives inside a function
PiperOrigin-RevId: 319064461 Change-Id: Ib791be6f09194e2df2a64153c32ce32bb83096fd
This commit is contained in:
parent
ba3ad73a25
commit
54a01c3a6b
@ -23,12 +23,14 @@ import threading
|
||||
from tensorflow.python.distribute.parallel_device import parallel_device
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.module import module
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import collective_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
@ -42,7 +44,7 @@ from tensorflow.python.util import nest
|
||||
# communicate.
|
||||
# TODO(allenl): Switch to using a collective manager.
|
||||
_COUNTER_LOCK = threading.Lock()
|
||||
_COUNTER = 0
|
||||
_COUNTER = 100
|
||||
|
||||
|
||||
def _collective_reduce(inputs, operation, num_replicas):
|
||||
@ -171,6 +173,32 @@ class ParallelDeviceTests(_VirtualDeviceTestCase):
|
||||
context._reset_context()
|
||||
config.set_synchronous_execution(previous)
|
||||
|
||||
def test_collective_in_function(self):
|
||||
c = constant_op.constant([2])
|
||||
|
||||
@def_function.function
|
||||
def broadcast_send_recv(device_id):
|
||||
|
||||
@def_function.function
|
||||
def send():
|
||||
s0 = collective_ops.broadcast_send(
|
||||
c * 3, c.shape, c.dtype, group_size=2, group_key=1, instance_key=1)
|
||||
with ops.control_dependencies([s0.op]):
|
||||
return array_ops.identity(c)
|
||||
|
||||
@def_function.function
|
||||
def recv():
|
||||
r0 = collective_ops.broadcast_recv(
|
||||
c.shape, c.dtype, group_size=2, group_key=1, instance_key=1)
|
||||
return r0
|
||||
|
||||
return control_flow_ops.switch_case(
|
||||
device_id, branch_fns={0: send, 1: recv})
|
||||
|
||||
with ops.device(self.device.name):
|
||||
result = broadcast_send_recv(self.device.device_ids)
|
||||
self.assertAllClose([[2], [6]], self.device.unpack(result))
|
||||
|
||||
def test_checkpointing(self):
|
||||
self.skipTest(
|
||||
"Disable saving until SaveableObject's methods are traceable.")
|
||||
|
Loading…
x
Reference in New Issue
Block a user