Export xla.compile() and friends in v2.
PiperOrigin-RevId: 240485793
This commit is contained in:
parent
a26413ef0a
commit
4bc51cb6a8
tensorflow
python
tools
@ -15,6 +15,7 @@ py_library(
|
||||
":xla",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python/eager:context",
|
||||
],
|
||||
)
|
||||
|
||||
@ -56,6 +57,7 @@ py_library(
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python/distribute:summary_op_util",
|
||||
"//tensorflow/python/eager:context",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
import contextlib
|
||||
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
@ -37,7 +38,7 @@ class _XlaScope(object):
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
@tf_export(v1=["xla.experimental.jit_scope"])
|
||||
@tf_export("xla.experimental.jit_scope")
|
||||
def experimental_jit_scope(compile_ops=True, separate_compiled_gradients=False):
|
||||
"""Enable or disable JIT compilation of operators within the scope.
|
||||
|
||||
@ -75,10 +76,15 @@ def experimental_jit_scope(compile_ops=True, separate_compiled_gradients=False):
|
||||
as the name of the gradients. As a result, the gradients will be compiled
|
||||
in a scope that is separate from both the forward computation, and from
|
||||
other gradients.
|
||||
Raises:
|
||||
RuntimeError: if called when eager execution is enabled.
|
||||
Yields:
|
||||
The current scope, enabling or disabling compilation.
|
||||
|
||||
"""
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError("xla.experimental.jit_scope is not supported when eager "
|
||||
"execution is enabled. Try use it inside tf.function.")
|
||||
|
||||
if callable(compile_ops):
|
||||
def xla_compile(node_def):
|
||||
return attr_value_pb2.AttrValue(b=compile_ops(node_def))
|
||||
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.compiler.xla import jit
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import function
|
||||
@ -45,8 +47,7 @@ def enable_jit_nonstateful(node_def):
|
||||
raise ValueError("Unregistered op being created: %s" % node_def)
|
||||
|
||||
|
||||
@test_util.run_v1_only("b/128927195")
|
||||
class JITTest(test.TestCase):
|
||||
class JITTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def compute(self, use_jit, compute_fn):
|
||||
random_seed.set_random_seed(1234)
|
||||
@ -56,6 +57,16 @@ class JITTest(test.TestCase):
|
||||
sess.run(variables.global_variables_initializer())
|
||||
return (r, sess.run(r))
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testJITInEager(self):
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
RuntimeError, "xla.experimental.jit_scope is not supported when eager "
|
||||
"execution is enabled. Try use it inside tf.function."):
|
||||
with jit.experimental_jit_scope(True):
|
||||
constant_op.constant(1)
|
||||
|
||||
@test_util.build_as_function_and_v1_graph
|
||||
def testJITCreateOpsLambda(self):
|
||||
"""Test several ways of customizing the compilation attribute."""
|
||||
def create_ops():
|
||||
@ -89,6 +100,7 @@ class JITTest(test.TestCase):
|
||||
self.assertAllClose(v_true_1, v_true_2)
|
||||
self.assertAllClose(v_false_1, v_true_1)
|
||||
|
||||
@test_util.build_as_function_and_v1_graph
|
||||
def testJITXlaScope(self):
|
||||
with self.session(graph=ops.Graph()):
|
||||
with jit.experimental_jit_scope(True):
|
||||
@ -116,6 +128,7 @@ class JITTest(test.TestCase):
|
||||
self.assertEqual(b"jit_scope_1", a5.op.get_attr("_XlaScope"))
|
||||
self.assertEqual(b"jit_scope_2", a6.op.get_attr("_XlaScope"))
|
||||
|
||||
@test_util.build_as_function_and_v1_graph
|
||||
def testJITVariableSeed(self):
|
||||
"""Test that the stateful initializer is not marked for compilation.
|
||||
|
||||
@ -139,6 +152,7 @@ class JITTest(test.TestCase):
|
||||
self.assertAllClose(v_true_1, v_true_2)
|
||||
self.assertAllClose(v_false_1, v_true_1)
|
||||
|
||||
@test_util.build_as_function_and_v1_graph
|
||||
def testDefunNoJitScope(self):
|
||||
with self.session(graph=ops.Graph()):
|
||||
|
||||
@ -155,6 +169,7 @@ class JITTest(test.TestCase):
|
||||
# No enclosing jit scope so function sets its own value for _XlaScope.
|
||||
self.assertEqual(b"function_mulop", func_attrs["_XlaScope"].s)
|
||||
|
||||
@test_util.build_as_function_and_v1_graph
|
||||
def testDefunInheritsJitScope(self):
|
||||
with self.session(graph=ops.Graph()):
|
||||
with jit.experimental_jit_scope(True):
|
||||
@ -172,9 +187,9 @@ class JITTest(test.TestCase):
|
||||
self.assertEqual(b"jit_scope_0", func_attrs["_XlaScope"].s)
|
||||
|
||||
|
||||
@test_util.run_v1_only("b/128927195")
|
||||
class CompilationEnabledInGradientTest(test.TestCase):
|
||||
class CompilationEnabledInGradientTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@test_util.build_as_function_and_v1_graph
|
||||
def testCompilationInGradient(self):
|
||||
with self.cached_session():
|
||||
x = constant_op.constant([[3.]])
|
||||
@ -198,6 +213,7 @@ class CompilationEnabledInGradientTest(test.TestCase):
|
||||
# d/dx (x ** 4) = 4 * (x ** 3)
|
||||
self.assertAllClose([[108]], x_grads.eval())
|
||||
|
||||
@test_util.build_as_function_and_v1_graph
|
||||
def testCompilationGradientScopeNames(self):
|
||||
with self.session(graph=ops.Graph()):
|
||||
with jit.experimental_jit_scope():
|
||||
@ -220,6 +236,7 @@ class CompilationEnabledInGradientTest(test.TestCase):
|
||||
self.assertEqual(b"jit_scope_0", grad_a1.op.get_attr("_XlaScope"))
|
||||
self.assertEqual(b"jit_scope_1", grad_a2.op.get_attr("_XlaScope"))
|
||||
|
||||
@test_util.build_as_function_and_v1_graph
|
||||
def testCompilationSeparateGradientScopeNames(self):
|
||||
with self.session(graph=ops.Graph()):
|
||||
with jit.experimental_jit_scope(True, separate_compiled_gradients=True):
|
||||
@ -244,6 +261,7 @@ class CompilationEnabledInGradientTest(test.TestCase):
|
||||
self.assertEqual(b"jit_scope_1_grad_GB",
|
||||
grad_a2.op.get_attr("_XlaScope"))
|
||||
|
||||
@test_util.build_as_function_and_v1_graph
|
||||
def testPlaysNicelyWithDefun(self):
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
with jit.experimental_jit_scope(True):
|
||||
@ -269,6 +287,7 @@ class CompilationEnabledInGradientTest(test.TestCase):
|
||||
# Ensure the ops run: grad(x1*x1) = 2*x1
|
||||
self.assertAllClose([1.0, 1.0, 2.0], sess.run([x, r, g_r]))
|
||||
|
||||
@test_util.build_as_function_and_v1_graph
|
||||
def testPlaysNicelyWithDefunSeparateGradientScope(self):
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
with jit.experimental_jit_scope(True):
|
||||
|
@ -26,6 +26,7 @@ from tensorflow.compiler.jit.ops import xla_ops
|
||||
from tensorflow.compiler.jit.ops import xla_ops_grad # pylint: disable=unused-import
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.python.distribute import summary_op_util
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -60,7 +61,7 @@ _UNSUPPORTED_OPS = set([
|
||||
])
|
||||
|
||||
|
||||
@tf_export(v1=['xla.experimental.compile'])
|
||||
@tf_export('xla.experimental.compile')
|
||||
def compile(computation, inputs=None): # pylint: disable=redefined-builtin
|
||||
"""Builds an operator that compiles and runs `computation` with XLA.
|
||||
|
||||
@ -92,7 +93,13 @@ def compile(computation, inputs=None): # pylint: disable=redefined-builtin
|
||||
3) Operation-only outputs: a NoOp would be returned which
|
||||
control-depends on computation.
|
||||
TODO(b/121383831): Investigate into removing these special cases.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if called when eager execution is enabled.
|
||||
"""
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('xla.experimental.compile is not supported when eager '
|
||||
'execution is enabled. Try use it inside tf.function.')
|
||||
# pylint: disable=protected-access
|
||||
return _compile_internal(computation, inputs)
|
||||
|
||||
|
@ -18,9 +18,12 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu_feed
|
||||
from tensorflow.python import summary
|
||||
from tensorflow.python.compiler.xla import xla
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
@ -41,15 +44,15 @@ _EXPECTED_FEATURE = 2
|
||||
_EXPECTED_LABEL = 3
|
||||
|
||||
|
||||
@test_util.run_v1_only('b/128927195')
|
||||
class XLACompileContextTest(test.TestCase):
|
||||
class XLACompileContextTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def create_test_xla_compile_context(self):
|
||||
computation_name = ops.get_default_graph().unique_name('computation')
|
||||
pivot = control_flow_ops.no_op(name=computation_name + '/pivot')
|
||||
return xla.XLACompileContext(name=computation_name, pivot=pivot)
|
||||
|
||||
def test_report_unsupported_operations(self):
|
||||
@test_util.run_v1_only('Testing graph mode behavior only')
|
||||
def test_report_unsupported_operations_graph_mode(self):
|
||||
"""Tests that unsupported operations are detected."""
|
||||
context = self.create_test_xla_compile_context()
|
||||
context.Enter()
|
||||
@ -75,16 +78,33 @@ class XLACompileContextTest(test.TestCase):
|
||||
u'print_op'
|
||||
])
|
||||
|
||||
def test_resource_variable(self):
|
||||
@test_util.run_v1_only('Testing graph mode behavior only')
|
||||
def test_resource_variable_graph_mode(self):
|
||||
"""Tests that resource variable usage is allowed."""
|
||||
a = variable_scope.get_variable(
|
||||
name='variable_a', shape=(1), use_resource=True)
|
||||
name='variable_a', use_resource=True, initializer=1)
|
||||
|
||||
context = self.create_test_xla_compile_context()
|
||||
context.Enter()
|
||||
state_ops.assign(a, a + 1)
|
||||
a.assign(2)
|
||||
context.Exit()
|
||||
|
||||
def test_resource_variable_in_function(self):
|
||||
"""Tests that resource variable usage is allowed."""
|
||||
a = variable_scope.get_variable(
|
||||
name='variable_a', use_resource=True, initializer=1)
|
||||
|
||||
@def_function.function
|
||||
def func():
|
||||
context = self.create_test_xla_compile_context()
|
||||
context.Enter()
|
||||
o = a.assign(2)
|
||||
context.Exit()
|
||||
return o
|
||||
|
||||
self.assertEqual(self.evaluate(func()), 2)
|
||||
|
||||
@test_util.run_v1_only('Testing v1-only ref variable handling.')
|
||||
def test_non_resource_variable_error(self):
|
||||
"""Tests that non-resource variable usage is disallowed."""
|
||||
a = variable_scope.get_variable(
|
||||
@ -98,6 +118,7 @@ class XLACompileContextTest(test.TestCase):
|
||||
state_ops.assign(a, a + 1)
|
||||
context.Exit()
|
||||
|
||||
@test_util.build_as_function_and_v1_graph
|
||||
def test_nested_xla_compile_error(self):
|
||||
"""Tests that nested XLA computation leads to fatal error."""
|
||||
context1 = self.create_test_xla_compile_context()
|
||||
@ -111,6 +132,7 @@ class XLACompileContextTest(test.TestCase):
|
||||
context2.Exit()
|
||||
context1.Exit()
|
||||
|
||||
@test_util.build_as_function_and_v1_graph
|
||||
def test_xla_compile_attr(self):
|
||||
"""Tests that ops are tagged with XLA compile ID attribute."""
|
||||
context = self.create_test_xla_compile_context()
|
||||
@ -119,6 +141,7 @@ class XLACompileContextTest(test.TestCase):
|
||||
context.Exit()
|
||||
self.assertIn('_xla_compile_id', op.op.node_def.attr)
|
||||
|
||||
@test_util.build_as_function_and_v1_graph
|
||||
def test_op_without_input(self):
|
||||
"""Tests that ops without inputs depend on pivot correctly."""
|
||||
context = self.create_test_xla_compile_context()
|
||||
@ -128,7 +151,8 @@ class XLACompileContextTest(test.TestCase):
|
||||
|
||||
self.assertIn(context._pivot, op.op.control_inputs)
|
||||
|
||||
def test_external_control_edges(self):
|
||||
@test_util.run_v1_only('Testing graph mode behavior only')
|
||||
def test_external_control_edges_graph_mode(self):
|
||||
"""Tests that external control edges are handled correctly."""
|
||||
i = constant_op.constant(1)
|
||||
op1 = constant_op.constant(1)
|
||||
@ -150,6 +174,7 @@ class XLACompileContextTest(test.TestCase):
|
||||
control_flow_ops.while_loop(
|
||||
cond=lambda i: math_ops.less(i, 10), body=while_body, loop_vars=[i])
|
||||
|
||||
@test_util.build_as_function_and_v1_graph
|
||||
def test_op_output_marked_as_seen(self):
|
||||
"""Tests that any op output is marked as seen in context."""
|
||||
context = self.create_test_xla_compile_context()
|
||||
@ -159,7 +184,8 @@ class XLACompileContextTest(test.TestCase):
|
||||
|
||||
self.assertIn(op.name, context._values)
|
||||
|
||||
def testOpIsInContext(self):
|
||||
@test_util.build_as_function_and_v1_graph
|
||||
def test_op_is_in_context(self):
|
||||
"""Tests that XLACompileContext is recognized as an XLA context."""
|
||||
op1 = constant_op.constant(1)
|
||||
context = self.create_test_xla_compile_context()
|
||||
@ -169,7 +195,8 @@ class XLACompileContextTest(test.TestCase):
|
||||
self.assertFalse(control_flow_util.IsInXLAContext(op1.op))
|
||||
self.assertTrue(control_flow_util.IsInXLAContext(op2.op))
|
||||
|
||||
def testOpPreventFeeding(self):
|
||||
@test_util.build_as_function_and_v1_graph
|
||||
def test_op_prevent_feeding(self):
|
||||
"""Tests that ops created inside XLACompileContext can not be fed."""
|
||||
context = self.create_test_xla_compile_context()
|
||||
context.Enter()
|
||||
@ -177,7 +204,8 @@ class XLACompileContextTest(test.TestCase):
|
||||
context.Exit()
|
||||
self.assertFalse(op.graph.is_feedable(op.op))
|
||||
|
||||
def testOpPreventFetching(self):
|
||||
@test_util.build_as_function_and_v1_graph
|
||||
def test_op_prevent_fetching(self):
|
||||
"""Tests that ops created inside XLACompileContext can not be fetched."""
|
||||
context = self.create_test_xla_compile_context()
|
||||
context.Enter()
|
||||
@ -186,10 +214,55 @@ class XLACompileContextTest(test.TestCase):
|
||||
self.assertFalse(op.graph.is_fetchable(op.op))
|
||||
|
||||
|
||||
@test_util.run_v1_only('b/128927195')
|
||||
class XlaCompileTest(test.TestCase):
|
||||
|
||||
@test_util.run_v2_only
|
||||
def test_xla_compile_eager(self):
|
||||
"""Tests that xla.compile raises proper exception when used eagerly."""
|
||||
|
||||
def computation():
|
||||
return 1
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
RuntimeError, 'xla.experimental.compile is not supported when eager '
|
||||
'execution is enabled. Try use it inside tf.function.'):
|
||||
xla.compile(computation)
|
||||
|
||||
def test_xla_compile_in_function(self):
|
||||
"""Tests that xla.compile works in tf.function."""
|
||||
|
||||
@def_function.function
|
||||
def func_wrapper(a):
|
||||
|
||||
def compute(a):
|
||||
return a + 1
|
||||
|
||||
return xla.compile(compute, [a])
|
||||
|
||||
self.assertEqual(self.evaluate(func_wrapper(1))[0], 2)
|
||||
|
||||
def test_xla_compile_write_variable_in_function(self):
|
||||
"""Tests that xla.compile works with variable in tf.function."""
|
||||
a = variable_scope.get_variable(
|
||||
name='variable_a', use_resource=True, initializer=1)
|
||||
|
||||
@def_function.function
|
||||
def func_wrapper():
|
||||
|
||||
def compute():
|
||||
a.assign_add(1)
|
||||
a.assign_sub(2)
|
||||
return a.read_value()
|
||||
|
||||
return xla.compile(compute)
|
||||
|
||||
self.evaluate(a.initializer)
|
||||
self.assertEqual(self.evaluate(func_wrapper())[0], 0)
|
||||
|
||||
|
||||
class CheckFunctionArgumentCountTest(test.TestCase):
|
||||
|
||||
def testSimple(self):
|
||||
def test_simple(self):
|
||||
"""Tests that arg checker works for functions with no varargs or defaults.
|
||||
"""
|
||||
|
||||
@ -204,7 +277,7 @@ class CheckFunctionArgumentCountTest(test.TestCase):
|
||||
self.assertEqual('exactly 3 arguments',
|
||||
xla.check_function_argument_count(func, 2, queue))
|
||||
|
||||
def testDefaultArgs(self):
|
||||
def test_default_args(self):
|
||||
"""Tests that arg checker works for a function with no varargs."""
|
||||
|
||||
def func(x, y, z=17):
|
||||
@ -224,7 +297,7 @@ class CheckFunctionArgumentCountTest(test.TestCase):
|
||||
self.assertEqual('at most 3 arguments',
|
||||
xla.check_function_argument_count(func, 4, queue))
|
||||
|
||||
def testVarArgs(self):
|
||||
def test_var_args(self):
|
||||
"""Tests that arg checker works for a function with varargs."""
|
||||
|
||||
def func(x, y, *z):
|
||||
@ -242,7 +315,7 @@ class CheckFunctionArgumentCountTest(test.TestCase):
|
||||
self.assertEqual('at least 2 arguments',
|
||||
xla.check_function_argument_count(func, 0, queue))
|
||||
|
||||
def testVarArgsAndDefaults(self):
|
||||
def test_var_args_and_defaults(self):
|
||||
"""Tests that arg checker works for a function with varargs and defaults."""
|
||||
|
||||
def func(x, y, z=17, *q): # pylint: disable=keyword-arg-before-vararg
|
||||
|
@ -58,6 +58,8 @@ TENSORFLOW_API_INIT_FILES = [
|
||||
"train/__init__.py",
|
||||
"train/experimental/__init__.py",
|
||||
"version/__init__.py",
|
||||
"xla/__init__.py",
|
||||
"xla/experimental/__init__.py",
|
||||
# END GENERATED FILES
|
||||
]
|
||||
|
||||
|
@ -372,6 +372,10 @@ tf_module {
|
||||
name: "version"
|
||||
mtype: "<type \'module\'>"
|
||||
}
|
||||
member {
|
||||
name: "xla"
|
||||
mtype: "<type \'module\'>"
|
||||
}
|
||||
member {
|
||||
name: "zeros_initializer"
|
||||
mtype: "<type \'type\'>"
|
||||
|
@ -0,0 +1,11 @@
|
||||
path: "tensorflow.xla.experimental"
|
||||
tf_module {
|
||||
member_method {
|
||||
name: "compile"
|
||||
argspec: "args=[\'computation\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "jit_scope"
|
||||
argspec: "args=[], varargs=args, keywords=kwds, defaults=None"
|
||||
}
|
||||
}
|
7
tensorflow/tools/api/golden/v2/tensorflow.xla.pbtxt
Normal file
7
tensorflow/tools/api/golden/v2/tensorflow.xla.pbtxt
Normal file
@ -0,0 +1,7 @@
|
||||
path: "tensorflow.xla"
|
||||
tf_module {
|
||||
member {
|
||||
name: "experimental"
|
||||
mtype: "<type \'module\'>"
|
||||
}
|
||||
}
|
@ -884,9 +884,9 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
|
||||
"tf.test.compute_gradient":
|
||||
"tf.compat.v1.test.compute_gradient",
|
||||
"tf.xla.experimental.compile":
|
||||
"tf.compat.v1.xla.experimental.compile",
|
||||
"tf.xla.experimental.compile",
|
||||
"tf.xla.experimental.jit_scope":
|
||||
"tf.compat.v1.xla.experimental.jit_scope",
|
||||
"tf.xla.experimental.jit_scope",
|
||||
}
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
|
@ -1549,12 +1549,12 @@ def _log_prob(self, x):
|
||||
|
||||
def testXlaExperimental(self):
|
||||
text = "tf.xla.experimental.jit_scope(0)"
|
||||
expected_text = "tf.compat.v1.xla.experimental.jit_scope(0)"
|
||||
expected_text = "tf.xla.experimental.jit_scope(0)"
|
||||
_, _, _, new_text = self._upgrade(text)
|
||||
self.assertEqual(new_text, expected_text)
|
||||
|
||||
text = "tf.xla.experimental.compile(0)"
|
||||
expected_text = "tf.compat.v1.xla.experimental.compile(0)"
|
||||
expected_text = "tf.xla.experimental.compile(0)"
|
||||
_, _, _, new_text = self._upgrade(text)
|
||||
self.assertEqual(new_text, expected_text)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user