Add XLAControlFlowContext when experimental_compile=True.
Without the context, we may generate ops which are not supported by XLA for control flow (e.g. Optionals) PiperOrigin-RevId: 274923241 Change-Id: I8c703734daa82e774676464706f58f9ee7d5545b
This commit is contained in:
parent
1745367d2f
commit
f2733d68a2
@ -721,6 +721,8 @@ tf_xla_py_test(
|
||||
"//tensorflow/compiler/tests:xla_test",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:control_flow_util",
|
||||
"//tensorflow/python:framework_ops",
|
||||
],
|
||||
)
|
||||
|
@ -555,7 +555,18 @@ class Function(object):
|
||||
return self._python_function(*args, **kwds)
|
||||
|
||||
tracing_count = self._get_tracing_count()
|
||||
result = self._call(*args, **kwds)
|
||||
if self._experimental_compile:
|
||||
# V2 control flow relies on XLAControlFlowContext to generate a
|
||||
# XLA-compatible function graph.
|
||||
xla_context = control_flow_ops.XLAControlFlowContext()
|
||||
try:
|
||||
xla_context.Enter()
|
||||
result = self._call(*args, **kwds)
|
||||
finally:
|
||||
xla_context.Exit()
|
||||
else:
|
||||
result = self._call(*args, **kwds)
|
||||
|
||||
if tracing_count == self._get_tracing_count():
|
||||
self._call_counter.called_without_tracing()
|
||||
return result
|
||||
|
@ -23,6 +23,8 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -79,6 +81,34 @@ class DefFunctionTest(test.TestCase):
|
||||
# XLA support is not yet enabled for TF ROCm
|
||||
run_and_check(xla_func)
|
||||
|
||||
def testControlFlow(self):
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
def f(x):
|
||||
assert control_flow_util.GraphOrParentsInXlaContext(
|
||||
ops.get_default_graph())
|
||||
x = ops.convert_to_tensor(x)
|
||||
|
||||
def body(i, a):
|
||||
return i + 1, control_flow_ops.cond(i > 2, lambda: a + (x**2),
|
||||
lambda: a + 3)
|
||||
|
||||
return control_flow_ops.while_loop(
|
||||
lambda i, *_: i < 10,
|
||||
body, (constant_op.constant(0), constant_op.constant(3.)),
|
||||
maximum_iterations=10)[1]
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
def g(x):
|
||||
x = ops.convert_to_tensor(x)
|
||||
with backprop.GradientTape() as tape:
|
||||
tape.watch(x)
|
||||
y = f(x)
|
||||
return y, tape.gradient(y, x)
|
||||
|
||||
self.assertAllClose(40.0, f(2.0))
|
||||
self.assertAllClose([40.0, 28.0], g(2.0))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
|
Loading…
Reference in New Issue
Block a user