diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 9c378bd3b2b..4a76bd79513 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -721,6 +721,8 @@ tf_xla_py_test( "//tensorflow/compiler/tests:xla_test", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:control_flow_util", "//tensorflow/python:framework_ops", ], ) diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py index 261bf3f5862..e7b4a6f84b2 100644 --- a/tensorflow/python/eager/def_function.py +++ b/tensorflow/python/eager/def_function.py @@ -555,7 +555,18 @@ class Function(object): return self._python_function(*args, **kwds) tracing_count = self._get_tracing_count() - result = self._call(*args, **kwds) + if self._experimental_compile: + # V2 control flow relies on XLAControlFlowContext to generate a + # XLA-compatible function graph. + xla_context = control_flow_ops.XLAControlFlowContext() + try: + xla_context.Enter() + result = self._call(*args, **kwds) + finally: + xla_context.Exit() + else: + result = self._call(*args, **kwds) + if tracing_count == self._get_tracing_count(): self._call_counter.called_without_tracing() return result diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py index c477f5e0532..5338725f88d 100644 --- a/tensorflow/python/eager/def_function_xla_jit_test.py +++ b/tensorflow/python/eager/def_function_xla_jit_test.py @@ -23,6 +23,8 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test @@ -79,6 +81,34 @@ class DefFunctionTest(test.TestCase): # XLA support is not yet enabled for TF ROCm run_and_check(xla_func) + def testControlFlow(self): + + @def_function.function(experimental_compile=True) + def f(x): + assert control_flow_util.GraphOrParentsInXlaContext( + ops.get_default_graph()) + x = ops.convert_to_tensor(x) + + def body(i, a): + return i + 1, control_flow_ops.cond(i > 2, lambda: a + (x**2), + lambda: a + 3) + + return control_flow_ops.while_loop( + lambda i, *_: i < 10, + body, (constant_op.constant(0), constant_op.constant(3.)), + maximum_iterations=10)[1] + + @def_function.function(experimental_compile=True) + def g(x): + x = ops.convert_to_tensor(x) + with backprop.GradientTape() as tape: + tape.watch(x) + y = f(x) + return y, tape.gradient(y, x) + + self.assertAllClose(40.0, f(2.0)) + self.assertAllClose([40.0, 28.0], g(2.0)) + if __name__ == '__main__': ops.enable_eager_execution()