Allow functions with raw Send & Recv to be inlined

We allow functions that have Send & Recv to be inlined. This causes them
to use the same rendezvous as the outer function, allowing functions to
communicate with each other.

Further, We all Recv to be run asynchronously but ensure Send is added
added as a control output so that is doesn't get pruned.

PiperOrigin-RevId: 309314078
Change-Id: I485bd3486e1ff40ed1aee2b157780ccdbac681df
This commit is contained in:
Gaurav Jain 2020-04-30 15:42:57 -07:00 committed by TensorFlower Gardener
parent 306f371dbc
commit 8d1992fd27
4 changed files with 18 additions and 3 deletions

View File

@ -828,7 +828,7 @@ const bool IsExemptFromSideEffectsExecutionValidation(const string& op) {
// Op types that should not run in program order, e.g. because they need
// to run asynchronously to avoid deadlock.
"CollectiveGather", "CollectiveReduce", "CollectiveBcastSend",
"CollectiveBcastRecv", "NcclAllReduce",
"CollectiveBcastRecv", "NcclAllReduce", "Send", "Recv",
// Legacy random ops.
// See details in tensorflow/python/framework/auto_control_deps.py.

View File

@ -1686,6 +1686,7 @@ tf_py_test(
deps = [
":auto_control_deps",
":client_testlib",
":sendrecv_ops_gen",
],
)

View File

@ -45,6 +45,9 @@ ASYNC_STATEFUL_OPS = [
"CollectiveBcastSend",
"CollectiveBcastRecv",
"NcclAllReduce",
# We do not add "Send" here since we want it to be added as a control output
# in order to avoid being pruned.
"Recv",
]
LEGACY_RANDOM_OPS = [

View File

@ -31,6 +31,7 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gen_sendrecv_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@ -165,6 +166,16 @@ class AutomaticControlDependenciesTest(test.TestCase):
# Last write must be in `ops_which_must_run`.
self.assertIn(assign_op4, c.ops_which_must_run)
def testSendInOpsWithMustRun(self):
with context.graph_mode(), self.cached_session():
v = resource_variable_ops.ResourceVariable(1.0)
self.evaluate(variables.global_variables_initializer())
with acd.AutomaticControlDependencies() as c:
send_op = gen_sendrecv_ops.send(v, "x", "/", 0, "/")
# Send must be in `ops_which_must_run`.
self.assertIn(send_op, c.ops_which_must_run)
def _testVariableReadInFunctionalOp(self, build_functional_op, op_type):
v = resource_variable_ops.ResourceVariable(1.0)
self.evaluate(variables.global_variables_initializer())
@ -751,7 +762,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
grad = backprop.implicit_grad(lambda v: v**2)(v)
with self.assertRaisesRegexp(TypeError,
'.*must return zero or more Tensors.*'):
".*must return zero or more Tensors.*"):
# TODO(akshayka): We might want to allow defun-ing Python functions
# that return operations (and just execute the op instead of running it).
optimizer.apply_gradients(grad)
@ -803,6 +814,6 @@ class AutomaticControlDependenciesTest(test.TestCase):
self.assertEqual(self.evaluate(outer()), 2.0)
if __name__ == '__main__':
if __name__ == "__main__":
ops.enable_eager_execution()
test.main()