From 57b6183b0dd4337ae71b6ba171d5808efb393ad3 Mon Sep 17 00:00:00 2001 From: Yanan Cao Date: Wed, 10 Apr 2019 09:59:06 -0700 Subject: [PATCH] Support using xla.compile directly in eager mode PiperOrigin-RevId: 242888342 --- tensorflow/python/compiler/xla/BUILD | 1 + tensorflow/python/compiler/xla/xla.py | 10 +++++++--- tensorflow/python/compiler/xla/xla_test.py | 9 +++------ 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/tensorflow/python/compiler/xla/BUILD b/tensorflow/python/compiler/xla/BUILD index f9c8705d434..87ea46fa261 100644 --- a/tensorflow/python/compiler/xla/BUILD +++ b/tensorflow/python/compiler/xla/BUILD @@ -60,6 +60,7 @@ py_library( "//tensorflow/python:variable_scope", "//tensorflow/python/distribute:summary_op_util", "//tensorflow/python/eager:context", + "//tensorflow/python/eager:def_function", ], ) diff --git a/tensorflow/python/compiler/xla/xla.py b/tensorflow/python/compiler/xla/xla.py index a1eac2e6615..f25f08aa755 100644 --- a/tensorflow/python/compiler/xla/xla.py +++ b/tensorflow/python/compiler/xla/xla.py @@ -27,6 +27,7 @@ from tensorflow.compiler.jit.ops import xla_ops_grad # pylint: disable=unused-i from tensorflow.core.framework import attr_value_pb2 from tensorflow.python.distribute import summary_op_util from tensorflow.python.eager import context +from tensorflow.python.eager import def_function from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -98,9 +99,12 @@ def compile(computation, inputs=None): # pylint: disable=redefined-builtin RuntimeError: if called when eager execution is enabled. """ if context.executing_eagerly(): - raise RuntimeError('xla.experimental.compile is not supported when eager ' - 'execution is enabled. Try use it inside tf.function.') - # pylint: disable=protected-access + @def_function.function + def xla_compile_wrapper(): + return _compile_internal(computation, inputs) + + return xla_compile_wrapper() + return _compile_internal(computation, inputs) diff --git a/tensorflow/python/compiler/xla/xla_test.py b/tensorflow/python/compiler/xla/xla_test.py index b654676eb33..6dc0789ba4f 100644 --- a/tensorflow/python/compiler/xla/xla_test.py +++ b/tensorflow/python/compiler/xla/xla_test.py @@ -220,13 +220,10 @@ class XlaCompileTest(test.TestCase): def test_xla_compile_eager(self): """Tests that xla.compile raises proper exception when used eagerly.""" - def computation(): - return 1 + def computation(a, b): + return a + b - with self.assertRaisesRegexp( - RuntimeError, 'xla.experimental.compile is not supported when eager ' - 'execution is enabled. Try use it inside tf.function.'): - xla.compile(computation) + self.assertEqual(self.evaluate(xla.compile(computation, [1, 2])[0]), 3) def test_xla_compile_in_function(self): """Tests that xla.compile works in tf.function."""