Support using xla.compile directly in eager mode

PiperOrigin-RevId: 242888342
This commit is contained in:
Yanan Cao 2019-04-10 09:59:06 -07:00 committed by TensorFlower Gardener
parent 615b9daaca
commit 57b6183b0d
3 changed files with 11 additions and 9 deletions

View File

@ -60,6 +60,7 @@ py_library(
"//tensorflow/python:variable_scope",
"//tensorflow/python/distribute:summary_op_util",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
],
)

View File

@ -27,6 +27,7 @@ from tensorflow.compiler.jit.ops import xla_ops_grad # pylint: disable=unused-i
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.distribute import summary_op_util
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@ -98,9 +99,12 @@ def compile(computation, inputs=None): # pylint: disable=redefined-builtin
RuntimeError: if called when eager execution is enabled.
"""
if context.executing_eagerly():
raise RuntimeError('xla.experimental.compile is not supported when eager '
'execution is enabled. Try use it inside tf.function.')
# pylint: disable=protected-access
@def_function.function
def xla_compile_wrapper():
return _compile_internal(computation, inputs)
return xla_compile_wrapper()
return _compile_internal(computation, inputs)

View File

@ -220,13 +220,10 @@ class XlaCompileTest(test.TestCase):
def test_xla_compile_eager(self):
"""Tests that xla.compile raises proper exception when used eagerly."""
def computation():
return 1
def computation(a, b):
return a + b
with self.assertRaisesRegexp(
RuntimeError, 'xla.experimental.compile is not supported when eager '
'execution is enabled. Try use it inside tf.function.'):
xla.compile(computation)
self.assertEqual(self.evaluate(xla.compile(computation, [1, 2])[0]), 3)
def test_xla_compile_in_function(self):
"""Tests that xla.compile works in tf.function."""