Allow functions with raw Send & Recv to be inlined
We allow functions that have Send & Recv to be inlined. This causes them to use the same rendezvous as the outer function, allowing functions to communicate with each other. Further, We all Recv to be run asynchronously but ensure Send is added added as a control output so that is doesn't get pruned. PiperOrigin-RevId: 309314078 Change-Id: I485bd3486e1ff40ed1aee2b157780ccdbac681df
This commit is contained in:
parent
306f371dbc
commit
8d1992fd27
@ -828,7 +828,7 @@ const bool IsExemptFromSideEffectsExecutionValidation(const string& op) {
|
||||
// Op types that should not run in program order, e.g. because they need
|
||||
// to run asynchronously to avoid deadlock.
|
||||
"CollectiveGather", "CollectiveReduce", "CollectiveBcastSend",
|
||||
"CollectiveBcastRecv", "NcclAllReduce",
|
||||
"CollectiveBcastRecv", "NcclAllReduce", "Send", "Recv",
|
||||
|
||||
// Legacy random ops.
|
||||
// See details in tensorflow/python/framework/auto_control_deps.py.
|
||||
|
@ -1686,6 +1686,7 @@ tf_py_test(
|
||||
deps = [
|
||||
":auto_control_deps",
|
||||
":client_testlib",
|
||||
":sendrecv_ops_gen",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -45,6 +45,9 @@ ASYNC_STATEFUL_OPS = [
|
||||
"CollectiveBcastSend",
|
||||
"CollectiveBcastRecv",
|
||||
"NcclAllReduce",
|
||||
# We do not add "Send" here since we want it to be added as a control output
|
||||
# in order to avoid being pruned.
|
||||
"Recv",
|
||||
]
|
||||
|
||||
LEGACY_RANDOM_OPS = [
|
||||
|
@ -31,6 +31,7 @@ from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import gen_resource_variable_ops
|
||||
from tensorflow.python.ops import gen_sendrecv_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
@ -165,6 +166,16 @@ class AutomaticControlDependenciesTest(test.TestCase):
|
||||
# Last write must be in `ops_which_must_run`.
|
||||
self.assertIn(assign_op4, c.ops_which_must_run)
|
||||
|
||||
def testSendInOpsWithMustRun(self):
|
||||
with context.graph_mode(), self.cached_session():
|
||||
v = resource_variable_ops.ResourceVariable(1.0)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
with acd.AutomaticControlDependencies() as c:
|
||||
send_op = gen_sendrecv_ops.send(v, "x", "/", 0, "/")
|
||||
|
||||
# Send must be in `ops_which_must_run`.
|
||||
self.assertIn(send_op, c.ops_which_must_run)
|
||||
|
||||
def _testVariableReadInFunctionalOp(self, build_functional_op, op_type):
|
||||
v = resource_variable_ops.ResourceVariable(1.0)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
@ -751,7 +762,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
|
||||
grad = backprop.implicit_grad(lambda v: v**2)(v)
|
||||
|
||||
with self.assertRaisesRegexp(TypeError,
|
||||
'.*must return zero or more Tensors.*'):
|
||||
".*must return zero or more Tensors.*"):
|
||||
# TODO(akshayka): We might want to allow defun-ing Python functions
|
||||
# that return operations (and just execute the op instead of running it).
|
||||
optimizer.apply_gradients(grad)
|
||||
@ -803,6 +814,6 @@ class AutomaticControlDependenciesTest(test.TestCase):
|
||||
self.assertEqual(self.evaluate(outer()), 2.0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
ops.enable_eager_execution()
|
||||
test.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user