Support using xla.compile directly in eager mode
PiperOrigin-RevId: 242888342
This commit is contained in:
parent
615b9daaca
commit
57b6183b0d
@ -60,6 +60,7 @@ py_library(
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python/distribute:summary_op_util",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -27,6 +27,7 @@ from tensorflow.compiler.jit.ops import xla_ops_grad # pylint: disable=unused-i
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.python.distribute import summary_op_util
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -98,9 +99,12 @@ def compile(computation, inputs=None): # pylint: disable=redefined-builtin
|
||||
RuntimeError: if called when eager execution is enabled.
|
||||
"""
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('xla.experimental.compile is not supported when eager '
|
||||
'execution is enabled. Try use it inside tf.function.')
|
||||
# pylint: disable=protected-access
|
||||
@def_function.function
|
||||
def xla_compile_wrapper():
|
||||
return _compile_internal(computation, inputs)
|
||||
|
||||
return xla_compile_wrapper()
|
||||
|
||||
return _compile_internal(computation, inputs)
|
||||
|
||||
|
||||
|
@ -220,13 +220,10 @@ class XlaCompileTest(test.TestCase):
|
||||
def test_xla_compile_eager(self):
|
||||
"""Tests that xla.compile raises proper exception when used eagerly."""
|
||||
|
||||
def computation():
|
||||
return 1
|
||||
def computation(a, b):
|
||||
return a + b
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
RuntimeError, 'xla.experimental.compile is not supported when eager '
|
||||
'execution is enabled. Try use it inside tf.function.'):
|
||||
xla.compile(computation)
|
||||
self.assertEqual(self.evaluate(xla.compile(computation, [1, 2])[0]), 3)
|
||||
|
||||
def test_xla_compile_in_function(self):
|
||||
"""Tests that xla.compile works in tf.function."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user