Support using xla.compile directly in eager mode

PiperOrigin-RevId: 242888342
This commit is contained in:
Yanan Cao 2019-04-10 09:59:06 -07:00 committed by TensorFlower Gardener
parent 615b9daaca
commit 57b6183b0d
3 changed files with 11 additions and 9 deletions

View File

@ -60,6 +60,7 @@ py_library(
"//tensorflow/python:variable_scope", "//tensorflow/python:variable_scope",
"//tensorflow/python/distribute:summary_op_util", "//tensorflow/python/distribute:summary_op_util",
"//tensorflow/python/eager:context", "//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
], ],
) )

View File

@ -27,6 +27,7 @@ from tensorflow.compiler.jit.ops import xla_ops_grad # pylint: disable=unused-i
from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.distribute import summary_op_util from tensorflow.python.distribute import summary_op_util
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
@ -98,9 +99,12 @@ def compile(computation, inputs=None): # pylint: disable=redefined-builtin
RuntimeError: if called when eager execution is enabled. RuntimeError: if called when eager execution is enabled.
""" """
if context.executing_eagerly(): if context.executing_eagerly():
raise RuntimeError('xla.experimental.compile is not supported when eager ' @def_function.function
'execution is enabled. Try use it inside tf.function.') def xla_compile_wrapper():
# pylint: disable=protected-access return _compile_internal(computation, inputs)
return xla_compile_wrapper()
return _compile_internal(computation, inputs) return _compile_internal(computation, inputs)

View File

@ -220,13 +220,10 @@ class XlaCompileTest(test.TestCase):
def test_xla_compile_eager(self): def test_xla_compile_eager(self):
"""Tests that xla.compile raises proper exception when used eagerly.""" """Tests that xla.compile raises proper exception when used eagerly."""
def computation(): def computation(a, b):
return 1 return a + b
with self.assertRaisesRegexp( self.assertEqual(self.evaluate(xla.compile(computation, [1, 2])[0]), 3)
RuntimeError, 'xla.experimental.compile is not supported when eager '
'execution is enabled. Try use it inside tf.function.'):
xla.compile(computation)
def test_xla_compile_in_function(self): def test_xla_compile_in_function(self):
"""Tests that xla.compile works in tf.function.""" """Tests that xla.compile works in tf.function."""