Export xla.compile() and friends in v2.

PiperOrigin-RevId: 240485793
This commit is contained in:
Yanan Cao 2019-03-26 20:44:05 -07:00 committed by TensorFlower Gardener
parent a26413ef0a
commit 4bc51cb6a8
11 changed files with 157 additions and 26 deletions

View File

@ -15,6 +15,7 @@ py_library(
":xla",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python/eager:context",
],
)
@ -56,6 +57,7 @@ py_library(
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/distribute:summary_op_util",
"//tensorflow/python/eager:context",
],
)

View File

@ -21,6 +21,7 @@ from __future__ import print_function
import contextlib
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.util.tf_export import tf_export
@ -37,7 +38,7 @@ class _XlaScope(object):
@contextlib.contextmanager
@tf_export(v1=["xla.experimental.jit_scope"])
@tf_export("xla.experimental.jit_scope")
def experimental_jit_scope(compile_ops=True, separate_compiled_gradients=False):
"""Enable or disable JIT compilation of operators within the scope.
@ -75,10 +76,15 @@ def experimental_jit_scope(compile_ops=True, separate_compiled_gradients=False):
as the name of the gradients. As a result, the gradients will be compiled
in a scope that is separate from both the forward computation, and from
other gradients.
Raises:
RuntimeError: if called when eager execution is enabled.
Yields:
The current scope, enabling or disabling compilation.
"""
if context.executing_eagerly():
raise RuntimeError("xla.experimental.jit_scope is not supported when eager "
"execution is enabled. Try use it inside tf.function.")
if callable(compile_ops):
def xla_compile(node_def):
return attr_value_pb2.AttrValue(b=compile_ops(node_def))

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.compiler.xla import jit
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import function
@ -45,8 +47,7 @@ def enable_jit_nonstateful(node_def):
raise ValueError("Unregistered op being created: %s" % node_def)
@test_util.run_v1_only("b/128927195")
class JITTest(test.TestCase):
class JITTest(test.TestCase, parameterized.TestCase):
def compute(self, use_jit, compute_fn):
random_seed.set_random_seed(1234)
@ -56,6 +57,16 @@ class JITTest(test.TestCase):
sess.run(variables.global_variables_initializer())
return (r, sess.run(r))
@test_util.run_v2_only
def testJITInEager(self):
with self.assertRaisesRegexp(
RuntimeError, "xla.experimental.jit_scope is not supported when eager "
"execution is enabled. Try use it inside tf.function."):
with jit.experimental_jit_scope(True):
constant_op.constant(1)
@test_util.build_as_function_and_v1_graph
def testJITCreateOpsLambda(self):
"""Test several ways of customizing the compilation attribute."""
def create_ops():
@ -89,6 +100,7 @@ class JITTest(test.TestCase):
self.assertAllClose(v_true_1, v_true_2)
self.assertAllClose(v_false_1, v_true_1)
@test_util.build_as_function_and_v1_graph
def testJITXlaScope(self):
with self.session(graph=ops.Graph()):
with jit.experimental_jit_scope(True):
@ -116,6 +128,7 @@ class JITTest(test.TestCase):
self.assertEqual(b"jit_scope_1", a5.op.get_attr("_XlaScope"))
self.assertEqual(b"jit_scope_2", a6.op.get_attr("_XlaScope"))
@test_util.build_as_function_and_v1_graph
def testJITVariableSeed(self):
"""Test that the stateful initializer is not marked for compilation.
@ -139,6 +152,7 @@ class JITTest(test.TestCase):
self.assertAllClose(v_true_1, v_true_2)
self.assertAllClose(v_false_1, v_true_1)
@test_util.build_as_function_and_v1_graph
def testDefunNoJitScope(self):
with self.session(graph=ops.Graph()):
@ -155,6 +169,7 @@ class JITTest(test.TestCase):
# No enclosing jit scope so function sets its own value for _XlaScope.
self.assertEqual(b"function_mulop", func_attrs["_XlaScope"].s)
@test_util.build_as_function_and_v1_graph
def testDefunInheritsJitScope(self):
with self.session(graph=ops.Graph()):
with jit.experimental_jit_scope(True):
@ -172,9 +187,9 @@ class JITTest(test.TestCase):
self.assertEqual(b"jit_scope_0", func_attrs["_XlaScope"].s)
@test_util.run_v1_only("b/128927195")
class CompilationEnabledInGradientTest(test.TestCase):
class CompilationEnabledInGradientTest(test.TestCase, parameterized.TestCase):
@test_util.build_as_function_and_v1_graph
def testCompilationInGradient(self):
with self.cached_session():
x = constant_op.constant([[3.]])
@ -198,6 +213,7 @@ class CompilationEnabledInGradientTest(test.TestCase):
# d/dx (x ** 4) = 4 * (x ** 3)
self.assertAllClose([[108]], x_grads.eval())
@test_util.build_as_function_and_v1_graph
def testCompilationGradientScopeNames(self):
with self.session(graph=ops.Graph()):
with jit.experimental_jit_scope():
@ -220,6 +236,7 @@ class CompilationEnabledInGradientTest(test.TestCase):
self.assertEqual(b"jit_scope_0", grad_a1.op.get_attr("_XlaScope"))
self.assertEqual(b"jit_scope_1", grad_a2.op.get_attr("_XlaScope"))
@test_util.build_as_function_and_v1_graph
def testCompilationSeparateGradientScopeNames(self):
with self.session(graph=ops.Graph()):
with jit.experimental_jit_scope(True, separate_compiled_gradients=True):
@ -244,6 +261,7 @@ class CompilationEnabledInGradientTest(test.TestCase):
self.assertEqual(b"jit_scope_1_grad_GB",
grad_a2.op.get_attr("_XlaScope"))
@test_util.build_as_function_and_v1_graph
def testPlaysNicelyWithDefun(self):
with self.session(graph=ops.Graph()) as sess:
with jit.experimental_jit_scope(True):
@ -269,6 +287,7 @@ class CompilationEnabledInGradientTest(test.TestCase):
# Ensure the ops run: grad(x1*x1) = 2*x1
self.assertAllClose([1.0, 1.0, 2.0], sess.run([x, r, g_r]))
@test_util.build_as_function_and_v1_graph
def testPlaysNicelyWithDefunSeparateGradientScope(self):
with self.session(graph=ops.Graph()) as sess:
with jit.experimental_jit_scope(True):

View File

@ -26,6 +26,7 @@ from tensorflow.compiler.jit.ops import xla_ops
from tensorflow.compiler.jit.ops import xla_ops_grad # pylint: disable=unused-import
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.distribute import summary_op_util
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@ -60,7 +61,7 @@ _UNSUPPORTED_OPS = set([
])
@tf_export(v1=['xla.experimental.compile'])
@tf_export('xla.experimental.compile')
def compile(computation, inputs=None): # pylint: disable=redefined-builtin
"""Builds an operator that compiles and runs `computation` with XLA.
@ -92,7 +93,13 @@ def compile(computation, inputs=None): # pylint: disable=redefined-builtin
3) Operation-only outputs: a NoOp would be returned which
control-depends on computation.
TODO(b/121383831): Investigate into removing these special cases.
Raises:
RuntimeError: if called when eager execution is enabled.
"""
if context.executing_eagerly():
raise RuntimeError('xla.experimental.compile is not supported when eager '
'execution is enabled. Try use it inside tf.function.')
# pylint: disable=protected-access
return _compile_internal(computation, inputs)

View File

@ -18,9 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.contrib.tpu.python.tpu import tpu_feed
from tensorflow.python import summary
from tensorflow.python.compiler.xla import xla
from tensorflow.python.eager import def_function
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
@ -41,15 +44,15 @@ _EXPECTED_FEATURE = 2
_EXPECTED_LABEL = 3
@test_util.run_v1_only('b/128927195')
class XLACompileContextTest(test.TestCase):
class XLACompileContextTest(test.TestCase, parameterized.TestCase):
def create_test_xla_compile_context(self):
computation_name = ops.get_default_graph().unique_name('computation')
pivot = control_flow_ops.no_op(name=computation_name + '/pivot')
return xla.XLACompileContext(name=computation_name, pivot=pivot)
def test_report_unsupported_operations(self):
@test_util.run_v1_only('Testing graph mode behavior only')
def test_report_unsupported_operations_graph_mode(self):
"""Tests that unsupported operations are detected."""
context = self.create_test_xla_compile_context()
context.Enter()
@ -75,16 +78,33 @@ class XLACompileContextTest(test.TestCase):
u'print_op'
])
def test_resource_variable(self):
@test_util.run_v1_only('Testing graph mode behavior only')
def test_resource_variable_graph_mode(self):
"""Tests that resource variable usage is allowed."""
a = variable_scope.get_variable(
name='variable_a', shape=(1), use_resource=True)
name='variable_a', use_resource=True, initializer=1)
context = self.create_test_xla_compile_context()
context.Enter()
state_ops.assign(a, a + 1)
a.assign(2)
context.Exit()
def test_resource_variable_in_function(self):
"""Tests that resource variable usage is allowed."""
a = variable_scope.get_variable(
name='variable_a', use_resource=True, initializer=1)
@def_function.function
def func():
context = self.create_test_xla_compile_context()
context.Enter()
o = a.assign(2)
context.Exit()
return o
self.assertEqual(self.evaluate(func()), 2)
@test_util.run_v1_only('Testing v1-only ref variable handling.')
def test_non_resource_variable_error(self):
"""Tests that non-resource variable usage is disallowed."""
a = variable_scope.get_variable(
@ -98,6 +118,7 @@ class XLACompileContextTest(test.TestCase):
state_ops.assign(a, a + 1)
context.Exit()
@test_util.build_as_function_and_v1_graph
def test_nested_xla_compile_error(self):
"""Tests that nested XLA computation leads to fatal error."""
context1 = self.create_test_xla_compile_context()
@ -111,6 +132,7 @@ class XLACompileContextTest(test.TestCase):
context2.Exit()
context1.Exit()
@test_util.build_as_function_and_v1_graph
def test_xla_compile_attr(self):
"""Tests that ops are tagged with XLA compile ID attribute."""
context = self.create_test_xla_compile_context()
@ -119,6 +141,7 @@ class XLACompileContextTest(test.TestCase):
context.Exit()
self.assertIn('_xla_compile_id', op.op.node_def.attr)
@test_util.build_as_function_and_v1_graph
def test_op_without_input(self):
"""Tests that ops without inputs depend on pivot correctly."""
context = self.create_test_xla_compile_context()
@ -128,7 +151,8 @@ class XLACompileContextTest(test.TestCase):
self.assertIn(context._pivot, op.op.control_inputs)
def test_external_control_edges(self):
@test_util.run_v1_only('Testing graph mode behavior only')
def test_external_control_edges_graph_mode(self):
"""Tests that external control edges are handled correctly."""
i = constant_op.constant(1)
op1 = constant_op.constant(1)
@ -150,6 +174,7 @@ class XLACompileContextTest(test.TestCase):
control_flow_ops.while_loop(
cond=lambda i: math_ops.less(i, 10), body=while_body, loop_vars=[i])
@test_util.build_as_function_and_v1_graph
def test_op_output_marked_as_seen(self):
"""Tests that any op output is marked as seen in context."""
context = self.create_test_xla_compile_context()
@ -159,7 +184,8 @@ class XLACompileContextTest(test.TestCase):
self.assertIn(op.name, context._values)
def testOpIsInContext(self):
@test_util.build_as_function_and_v1_graph
def test_op_is_in_context(self):
"""Tests that XLACompileContext is recognized as an XLA context."""
op1 = constant_op.constant(1)
context = self.create_test_xla_compile_context()
@ -169,7 +195,8 @@ class XLACompileContextTest(test.TestCase):
self.assertFalse(control_flow_util.IsInXLAContext(op1.op))
self.assertTrue(control_flow_util.IsInXLAContext(op2.op))
def testOpPreventFeeding(self):
@test_util.build_as_function_and_v1_graph
def test_op_prevent_feeding(self):
"""Tests that ops created inside XLACompileContext can not be fed."""
context = self.create_test_xla_compile_context()
context.Enter()
@ -177,7 +204,8 @@ class XLACompileContextTest(test.TestCase):
context.Exit()
self.assertFalse(op.graph.is_feedable(op.op))
def testOpPreventFetching(self):
@test_util.build_as_function_and_v1_graph
def test_op_prevent_fetching(self):
"""Tests that ops created inside XLACompileContext can not be fetched."""
context = self.create_test_xla_compile_context()
context.Enter()
@ -186,10 +214,55 @@ class XLACompileContextTest(test.TestCase):
self.assertFalse(op.graph.is_fetchable(op.op))
@test_util.run_v1_only('b/128927195')
class XlaCompileTest(test.TestCase):
@test_util.run_v2_only
def test_xla_compile_eager(self):
"""Tests that xla.compile raises proper exception when used eagerly."""
def computation():
return 1
with self.assertRaisesRegexp(
RuntimeError, 'xla.experimental.compile is not supported when eager '
'execution is enabled. Try use it inside tf.function.'):
xla.compile(computation)
def test_xla_compile_in_function(self):
"""Tests that xla.compile works in tf.function."""
@def_function.function
def func_wrapper(a):
def compute(a):
return a + 1
return xla.compile(compute, [a])
self.assertEqual(self.evaluate(func_wrapper(1))[0], 2)
def test_xla_compile_write_variable_in_function(self):
"""Tests that xla.compile works with variable in tf.function."""
a = variable_scope.get_variable(
name='variable_a', use_resource=True, initializer=1)
@def_function.function
def func_wrapper():
def compute():
a.assign_add(1)
a.assign_sub(2)
return a.read_value()
return xla.compile(compute)
self.evaluate(a.initializer)
self.assertEqual(self.evaluate(func_wrapper())[0], 0)
class CheckFunctionArgumentCountTest(test.TestCase):
def testSimple(self):
def test_simple(self):
"""Tests that arg checker works for functions with no varargs or defaults.
"""
@ -204,7 +277,7 @@ class CheckFunctionArgumentCountTest(test.TestCase):
self.assertEqual('exactly 3 arguments',
xla.check_function_argument_count(func, 2, queue))
def testDefaultArgs(self):
def test_default_args(self):
"""Tests that arg checker works for a function with no varargs."""
def func(x, y, z=17):
@ -224,7 +297,7 @@ class CheckFunctionArgumentCountTest(test.TestCase):
self.assertEqual('at most 3 arguments',
xla.check_function_argument_count(func, 4, queue))
def testVarArgs(self):
def test_var_args(self):
"""Tests that arg checker works for a function with varargs."""
def func(x, y, *z):
@ -242,7 +315,7 @@ class CheckFunctionArgumentCountTest(test.TestCase):
self.assertEqual('at least 2 arguments',
xla.check_function_argument_count(func, 0, queue))
def testVarArgsAndDefaults(self):
def test_var_args_and_defaults(self):
"""Tests that arg checker works for a function with varargs and defaults."""
def func(x, y, z=17, *q): # pylint: disable=keyword-arg-before-vararg

View File

@ -58,6 +58,8 @@ TENSORFLOW_API_INIT_FILES = [
"train/__init__.py",
"train/experimental/__init__.py",
"version/__init__.py",
"xla/__init__.py",
"xla/experimental/__init__.py",
# END GENERATED FILES
]

View File

@ -372,6 +372,10 @@ tf_module {
name: "version"
mtype: "<type \'module\'>"
}
member {
name: "xla"
mtype: "<type \'module\'>"
}
member {
name: "zeros_initializer"
mtype: "<type \'type\'>"

View File

@ -0,0 +1,11 @@
path: "tensorflow.xla.experimental"
tf_module {
member_method {
name: "compile"
argspec: "args=[\'computation\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "jit_scope"
argspec: "args=[], varargs=args, keywords=kwds, defaults=None"
}
}

View File

@ -0,0 +1,7 @@
path: "tensorflow.xla"
tf_module {
member {
name: "experimental"
mtype: "<type \'module\'>"
}
}

View File

@ -884,9 +884,9 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
"tf.test.compute_gradient":
"tf.compat.v1.test.compute_gradient",
"tf.xla.experimental.compile":
"tf.compat.v1.xla.experimental.compile",
"tf.xla.experimental.compile",
"tf.xla.experimental.jit_scope":
"tf.compat.v1.xla.experimental.jit_scope",
"tf.xla.experimental.jit_scope",
}
# pylint: enable=line-too-long

View File

@ -1549,12 +1549,12 @@ def _log_prob(self, x):
def testXlaExperimental(self):
text = "tf.xla.experimental.jit_scope(0)"
expected_text = "tf.compat.v1.xla.experimental.jit_scope(0)"
expected_text = "tf.xla.experimental.jit_scope(0)"
_, _, _, new_text = self._upgrade(text)
self.assertEqual(new_text, expected_text)
text = "tf.xla.experimental.compile(0)"
expected_text = "tf.compat.v1.xla.experimental.compile(0)"
expected_text = "tf.xla.experimental.compile(0)"
_, _, _, new_text = self._upgrade(text)
self.assertEqual(new_text, expected_text)