Add XLAControlFlowContext when experimental_compile=True.

Without the context, we may generate ops which are not supported by XLA for control flow (e.g. Optionals)

PiperOrigin-RevId: 274923241
Change-Id: I8c703734daa82e774676464706f58f9ee7d5545b
This commit is contained in:
Yujing Zhang 2019-10-15 17:16:44 -07:00 committed by TensorFlower Gardener
parent 1745367d2f
commit f2733d68a2
3 changed files with 44 additions and 1 deletions

View File

@ -721,6 +721,8 @@ tf_xla_py_test(
"//tensorflow/compiler/tests:xla_test",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:control_flow_util",
"//tensorflow/python:framework_ops",
],
)

View File

@ -555,7 +555,18 @@ class Function(object):
return self._python_function(*args, **kwds)
tracing_count = self._get_tracing_count()
result = self._call(*args, **kwds)
if self._experimental_compile:
# V2 control flow relies on XLAControlFlowContext to generate a
# XLA-compatible function graph.
xla_context = control_flow_ops.XLAControlFlowContext()
try:
xla_context.Enter()
result = self._call(*args, **kwds)
finally:
xla_context.Exit()
else:
result = self._call(*args, **kwds)
if tracing_count == self._get_tracing_count():
self._call_counter.called_without_tracing()
return result

View File

@ -23,6 +23,8 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
@ -79,6 +81,34 @@ class DefFunctionTest(test.TestCase):
# XLA support is not yet enabled for TF ROCm
run_and_check(xla_func)
def testControlFlow(self):
@def_function.function(experimental_compile=True)
def f(x):
assert control_flow_util.GraphOrParentsInXlaContext(
ops.get_default_graph())
x = ops.convert_to_tensor(x)
def body(i, a):
return i + 1, control_flow_ops.cond(i > 2, lambda: a + (x**2),
lambda: a + 3)
return control_flow_ops.while_loop(
lambda i, *_: i < 10,
body, (constant_op.constant(0), constant_op.constant(3.)),
maximum_iterations=10)[1]
@def_function.function(experimental_compile=True)
def g(x):
x = ops.convert_to_tensor(x)
with backprop.GradientTape() as tape:
tape.watch(x)
y = f(x)
return y, tape.gradient(y, x)
self.assertAllClose(40.0, f(2.0))
self.assertAllClose([40.0, 28.0], g(2.0))
if __name__ == '__main__':
ops.enable_eager_execution()