Rationale:

- Just `compile` is a Python built-in, and is already overloaded by the Keras
  `compile` method.
- `jit` is closer to other ML frameworks.
- Technically speaking, `jit=True` is not self-descriptive (it does not
  specify what it is doing just-in-time).
- Moreover, `tf.function` by itself without compilation could be described as
  a JIT.
- Also, `jit` by itself is less grep'able.
- Thus `@tf.function(jit_compile=True)` is the preferred spelling.
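
For reference, a minimal sketch of the two spellings (assumes a TF build at or
after this change, where both keywords exist and `experimental_compile`
survives only as a deprecated alias):

```python
import tensorflow as tf

@tf.function(jit_compile=True)            # preferred spelling
def add_one(x):
  return x + 1

@tf.function(experimental_compile=True)   # deprecated alias, still accepted
def add_one_old(x):
  return x + 1

print(add_one(tf.constant(1)))      # compiled with XLA (must-compile semantics)
print(add_one_old(tf.constant(1)))  # same behavior, emits a deprecation warning
```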

PiperOrigin-RevId: 340503501
Change-Id: I7bffe60aca69be6640390f6e6c4af40c6c4dbfda
George Karpenkov 2020-11-03 12:41:28 -08:00 committed by TensorFlower Gardener
parent 6561045a9b
commit 9df9d06e27
32 changed files with 178 additions and 166 deletions

View File

@@ -36,6 +36,9 @@
 * Exposing `tf.data.experimental.ExternalStatePolicy`, which can be used
   to control how external state should be handled during dataset
   serialization or iterator checkpointing.
+* XLA compilation:
+  * `tf.function(experimental_compile=True)` has become a stable API,
+    renamed `tf.function(jit_compile=True)`.
 * `tf.lite`:
   * NNAPI

View File

@@ -398,7 +398,7 @@ static void ShowXlaDeviceDeprecationWarning(
 absl::call_once(once, [] {
 LOG(INFO) << "XLA_GPU and XLA_CPU devices are deprecated and will be "
 "removed in subsequent releases. Instead, use either "
-"@tf.function(experimental_compile=True) for must-compile "
+"@tf.function(jit_compile=True) for must-compile "
 "semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 "
 "for auto-clustering best-effort compilation.";
 });

View File

@@ -1481,7 +1481,7 @@ tf_xla_py_test(
 tags = [
 "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
 ],
-use_xla_device = False, # Uses tf.function(experimental_compile=True)
+use_xla_device = False, # Uses tf.function(jit_compile=True)
 deps = [
 ":xla_test",
 "//tensorflow/compiler/tf2xla/python:xla",

View File

@@ -32,7 +32,7 @@ class CaseTest(xla_test.XLATestCase):
 def testCaseBasic(self):
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def switch_case_test(branch_index):
 def f1():
@@ -58,7 +58,7 @@ class CaseTest(xla_test.XLATestCase):
 def testBranchIsPruned(self):
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def switch_case_test():
 branch_index = array_ops.constant(0)

View File

@@ -693,7 +693,7 @@ class EagerFunctionTest(xla_test.XLATestCase):
 return x, y
 wholly_compiled_f = def_function.function(f)
-op_by_op_f = def_function.function(f, experimental_compile=False)
+op_by_op_f = def_function.function(f, jit_compile=False)
 x = array_ops.identity([0.0, 2.0], name='data')

View File

@@ -45,22 +45,22 @@ flags.DEFINE_bool('vary_seed', False,
 NUM_SAMPLES = int(1e3)
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def _igamma(a, x):
 return math_ops.igamma(a, x)
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def _igammac(a, x):
 return math_ops.igammac(a, x)
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def _polygamma(n, x):
 return math_ops.polygamma(n, x)
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def _zeta(a, q):
 return math_ops.zeta(a, q)
@@ -72,7 +72,7 @@ def implicit_reparameterization_grad(a, x):
 return -gen_math_ops.igamma_grad_a(a, x) / prob
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def _log1p(x):
 return math_ops.log1p(x)

View File

@@ -51,7 +51,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
 scenarios (e.g. TPU). The new version of stateless_random_* requires the
 intermediate tensor `alg` to be compile-time constant, so we need to check
 that this requirement is met. We use xla.compile instead of tf.function's
-experimental_compile because the latter doesn't throw an error even if the
+jit_compile because the latter doesn't throw an error even if the
 compile-time-constant constraint is not met.
 """
 if config.list_logical_devices('TPU'):

View File

@@ -78,7 +78,7 @@ For example, the following TensorFlow function which performs the MNIST training
 is compiled with XLA:
 ```
-@tf.function(experimental_compile=True)
+@tf.function(jit_compile=True)
 def train_mnist(images, labels):
 images, labels = cast(images, labels)
@@ -92,7 +92,7 @@ def train_mnist(images, labels):
 optimizer.apply_gradients(zip(grads, layer_variables))
 ```
-The `experimental_compile` API has _must-compile_ semantics: either the entire
+The `jit_compile` API has _must-compile_ semantics: either the entire
 function is compiled with XLA, or an `errors.InvalidArgumentError` exception is
 thrown. XLA can not currently compile functions where dimensions are not
 _inferrable_: that is, if it's not possible to infer the dimensions of all
@@ -108,7 +108,7 @@ def not_compilable(x):
 Shapes can vary across the runs though:
 ```
-@tf.function(experimental_compile=True)
+@tf.function(jit_compile=True)
 def recompiled_on_launch(a, b):
 return a + b
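
A short hedged illustration of the must-compile semantics described in the doc change above; the function name and input are hypothetical, `tf.unique` is used because this commit's own tests treat it as unsupported by XLA, and an XLA-enabled TensorFlow build is assumed:

```python
import tensorflow as tf

@tf.function(jit_compile=True)
def uses_unique(x):
  # tf.unique has a data-dependent output shape, which XLA cannot compile.
  return tf.unique(x).y

try:
  uses_unique(tf.constant([1, 2, 2, 3]))
except tf.errors.InvalidArgumentError as e:
  print("must-compile failed as expected:", e)
```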

View File

@@ -164,7 +164,7 @@
 "source": [
 "# Define the training function\n",
 "\n",
-"In the training function, you get the predicted labels using the layer defined above, and then minimize the gradient of the loss using the optimizer. In order to compile the computation using XLA, place it inside `tf.function` with `experimental_compile=True`."
+"In the training function, you get the predicted labels using the layer defined above, and then minimize the gradient of the loss using the optimizer. In order to compile the computation using XLA, place it inside `tf.function` with `jit_compile=True`."
 ]
 },
 {
@@ -177,7 +177,7 @@
 },
 "outputs": [],
 "source": [
-"@tf.function(experimental_compile=True)\n",
+"@tf.function(jit_compile=True)\n",
 "def train_mnist(images, labels):\n",
 " images, labels = cast(images, labels)\n",
 "\n",

View File

@@ -179,12 +179,12 @@ message FunctionSpec {
 // field, so we instead map to an enum.
 //
 // See `tf.function` for details.
-enum ExperimentalCompile {
+enum JitCompile {
 DEFAULT = 0;
 ON = 1;
 OFF = 2;
 }
-ExperimentalCompile experimental_compile = 6;
+JitCompile jit_compile = 6;
 reserved 3, 4;
 }

View File

@@ -95,8 +95,8 @@ cuda_py_test(
 )
 cuda_py_test(
-name = "experimental_compile_test",
-srcs = ["experimental_compile_test.py"],
+name = "jit_compile_test",
+srcs = ["jit_compile_test.py"],
 python_version = "PY3",
 tags = [
 "no_mac",

View File

@@ -27,7 +27,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import test
-class ExperimentalCompileTest(test.TestCase):
+class JitCompileTest(test.TestCase):
 def testBasic(self):
 with ops.Graph().as_default() as g:
@@ -35,7 +35,7 @@ class ExperimentalCompileTest(test.TestCase):
 def fn(x, a):
 return x + a
-xla_func = def_function.function(fn, experimental_compile=True)
+xla_func = def_function.function(fn, jit_compile=True)
 inputs = array_ops.placeholder(dtypes.float32, [5])
 # XLA support is not yet enabled for TF ROCm
 if not test.is_built_with_rocm():
@@ -55,7 +55,7 @@ class ExperimentalCompileTest(test.TestCase):
 return 2 * x + a
 with ops.Graph().as_default() as g:
-xla_func = def_function.function(fn, experimental_compile=True)
+xla_func = def_function.function(fn, jit_compile=True)
 with backprop.GradientTape() as tape:
 inputs = array_ops.placeholder(dtypes.float32, [5])
 tape.watch(inputs)
@@ -79,7 +79,7 @@ class ExperimentalCompileTest(test.TestCase):
 def fn(x, a):
 return x + a
-xla_func = def_function.function(fn, experimental_compile=True)
+xla_func = def_function.function(fn, jit_compile=True)
 inputs = array_ops.placeholder(dtypes.int32, [5])
 # XLA support is not yet enabled for TF ROCm
 if not test.is_built_with_rocm():
@@ -98,7 +98,7 @@ class ExperimentalCompileTest(test.TestCase):
 def fn(x):
 return array_ops.unique(x).y # Unique is not supported by XLA
-xla_func = def_function.function(fn, experimental_compile=True)
+xla_func = def_function.function(fn, jit_compile=True)
 inputs = array_ops.placeholder(dtypes.float32, [5])
 x = xla_func(inputs)
 # XLA support is not yet enabled for TF ROCm

View File

@@ -67,7 +67,7 @@ _UNSUPPORTED_OPS = set([
 @tf_export('xla.experimental.compile')
 @deprecated(
 None, 'xla.experimental.compile is deprecated. Consider using '
-'tf.function(experimental_compile=True)',
+'tf.function(jit_compile=True)',
 warn_once=True)
 def compile(computation, inputs=None): # pylint: disable=redefined-builtin
 """Builds an operator that compiles and runs `computation` with XLA.

View File

@@ -917,7 +917,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
 x = constant_op.constant([[1, 0.], [0., 0.]])
 if defunc:
 reduce_func = def_function.function(
-math_ops.reduce_logsumexp, experimental_compile=xla_compile)
+math_ops.reduce_logsumexp, jit_compile=xla_compile)
 func = lambda: reduce_func(x)
 else:
 func = lambda: math_ops.reduce_logsumexp(x)

View File

@@ -465,10 +465,10 @@ class Function(object):
 name,
 input_signature=None,
 autograph=True,
+jit_compile=None,
 experimental_implements=None,
 experimental_autograph_options=None,
 experimental_relax_shapes=False,
-experimental_compile=None,
 experimental_follow_type_hints=None):
 """Initializes a `Function`.
@@ -477,10 +477,10 @@
 name: the name given to it.
 input_signature: See the documentation for `tf.function`.
 autograph: See the documentation for `tf.function`.
+jit_compile: See the documentation for `tf.function`.
 experimental_implements: See the documentation for `tf.function`.
 experimental_autograph_options: See the documentation for `tf.function`.
 experimental_relax_shapes: See the documentation for `tf.function`.
-experimental_compile: See the documentation for `tf.function`.
 experimental_follow_type_hints: See the documentation for `tf.function`.
 Raises:
@@ -492,7 +492,7 @@
 self._function_spec = function_lib.FunctionSpec.from_function_and_signature(
 python_function,
 input_signature,
-experimental_compile=experimental_compile,
+jit_compile=jit_compile,
 experimental_follow_type_hints=experimental_follow_type_hints,
 )
 self._implements = experimental_implements
@@ -503,7 +503,7 @@
 self._autograph = autograph
 self._experimental_autograph_options = experimental_autograph_options
 self._experimental_relax_shapes = experimental_relax_shapes
-self._experimental_compile = experimental_compile
+self._jit_compile = jit_compile
 if experimental_follow_type_hints is None:
 experimental_follow_type_hints = False
 self._experimental_follow_type_hints = experimental_follow_type_hints
@@ -558,7 +558,7 @@
 """Creates a defun wrapped inside a variable creator scope."""
 weak_wrapped_fn = None
-compile_with_xla = self._experimental_compile
+compile_with_xla = self._jit_compile
 def wrapped_fn(*args, **kwds):
 """Wraps `self._python_function` in a variable creator scope."""
@@ -629,9 +629,9 @@
 if share is not None:
 attributes[function_lib.SHARED_RENDEZVOUS_ATTRIBUTE_NAME] = share
-if self._experimental_compile is not None:
-attributes.update(_XlaMustCompile=bool(self._experimental_compile))
-if self._experimental_compile:
+if self._jit_compile is not None:
+attributes.update(_XlaMustCompile=bool(self._jit_compile))
+if self._jit_compile:
 attributes.update(_noinline=True)
 if not attributes:
 attributes = None
@@ -640,8 +640,8 @@
 input_signature=self.input_signature,
 attributes=attributes,
 autograph=self._autograph,
+jit_compile=self._jit_compile,
 experimental_autograph_options=self._experimental_autograph_options,
-experimental_compile=self._experimental_compile,
 experimental_follow_type_hints=self._experimental_follow_type_hints,
 experimental_relax_shapes=self._experimental_relax_shapes)
@@ -698,10 +698,10 @@
 name=self._name,
 input_signature=self._input_signature,
 autograph=self._autograph,
+jit_compile=self._jit_compile,
 experimental_implements=self._implements,
 experimental_autograph_options=self._experimental_autograph_options,
 experimental_relax_shapes=self._experimental_relax_shapes,
-experimental_compile=self._experimental_compile,
 experimental_follow_type_hints=self._experimental_follow_type_hints)
 if self._shared_rendezvous:
@@ -782,7 +782,7 @@
 tracing_count = self.experimental_get_tracing_count()
 with trace.Trace(self._name) as tm:
 result = self._call(*args, **kwds)
-compiler = "xla" if self._experimental_compile else "nonXla"
+compiler = "xla" if self._jit_compile else "nonXla"
 new_tracing_count = self.experimental_get_tracing_count()
 without_tracing = (tracing_count == new_tracing_count)
 execution_mode = "notTraced" if without_tracing else "traced"
@@ -934,7 +934,7 @@
 For example, for
 ```python
-@tf.function(experimental_compile=True)
+@tf.function(jit_compile=True)
 def f(x):
 return x + 1
@@ -962,14 +962,13 @@
 Raises:
 ValueError: If an invalid `stage` is selected or if applied to a function
-which is not compiled (`experimental_compile=True` is not set).
+which is not compiled (`jit_compile=True` is not set).
 TypeError: When called with input in graph mode.
 """
 context.ensure_initialized()
-if not self._experimental_compile:
-raise ValueError(
-"Compiler IR can only be returned for functions marked with "
-"experimental_compile=True")
+if not self._jit_compile:
+raise ValueError("Compiler IR can only be returned for functions marked "
+"with 'jit_compile=True'")
 concrete_fn = self.get_concrete_function(*args, **kwargs)
 fn_name = concrete_fn.name
@@ -1285,9 +1284,13 @@
 @tf_export("function")
+@deprecation.deprecated_args(None,
+"experimental_compile is deprecated, use "
+"jit_compile instead", "experimental_compile")
 def function(func=None,
 input_signature=None,
 autograph=True,
+jit_compile=None,
 experimental_implements=None,
 experimental_autograph_options=None,
 experimental_relax_shapes=False,
@@ -1497,6 +1500,20 @@ def function(func=None,
 graph. Data-dependent control flow requires `autograph=True`. For more
 information, see the [tf.function and AutoGraph guide](
 https://www.tensorflow.org/guide/function).
+jit_compile: If `True`, compiles the function using
+[XLA](https://tensorflow.org/xla). XLA performs compiler optimizations,
+such as fusion, and attempts to emit more efficient code. This may
+drastically improve the performance. If set to `True`,
+the whole function needs to be compilable by XLA, or an
+`errors.InvalidArgumentError` is thrown.
+If `None` (default), compiles the function with XLA when running on TPU
+and goes through the regular function execution path when running on
+other devices.
+If `False`, executes the function without XLA compilation. Set this value
+to `False` when directly running a multi-device function on TPUs (e.g. two
+TPU cores, one TPU core and its host CPU).
+Not all functions are compilable, see a list of
+[sharp corners](https://tensorflow.org/xla/known_issues).
 experimental_implements: If provided, contains a name of a "known" function
 this implements. For example "mycompany.my_recurrent_cell".
 This is stored as an attribute in inference function,
@@ -1519,22 +1536,7 @@ def function(func=None,
 `tf.autograph.experimental.Feature` values.
 experimental_relax_shapes: When True, `tf.function` may generate fewer,
 graphs that are less specialized on input shapes.
-experimental_compile: If `True`, compiles the function using XLA
-(see https://tensorflow.org/xla). XLA performs compiler optimizations,
-such as fusion, and attempts to emit more efficient code. This may
-drastically improve the performance. If set to `True`,
-the whole function needs to be compilable by XLA, or an
-`errors.InvalidArgumentError` is thrown.
-If `None` (default), compiles the function with XLA when running on TPU
-and goes through the regular function execution path when running on
-other devices.
-If `False`, executes the function in a regular way (graph rewrite
-passes are applied, kernels are dispatched one-by-one by the TensorFlow
-executor). Set this value to `False` when directly running a
-multi-device function on TPUs (e.g. two TPU cores, one TPU core and its
-host CPU).
-Not all functions are compilable, see
-https://tensorflow.org/xla/known_issues for a list of sharp corners.
+experimental_compile: Deprecated alias to 'jit_compile'.
 experimental_follow_type_hints: When True, the function may use type
 annotations from `func` to optimize the tracing performance. For example,
 arguments annotated with `tf.Tensor` will automatically be converted
@@ -1547,8 +1549,8 @@
 `func` argument, returns a callable equivalent to the case above.
 Raises:
-ValueError when attempting to use experimental_compile, but XLA support is
-not enabled.
+ValueError when attempting to use jit_compile=True, but XLA support is not
+linked.
 """
 # TODO(mdan): Link to `tf.types` section once published.
 if input_signature is not None:
@@ -1571,7 +1573,14 @@
 autograph=autograph,
 experimental_autograph_options=experimental_autograph_options,
 experimental_relax_shapes=experimental_relax_shapes,
-experimental_compile=experimental_compile,
+# TODO(b/171825496): Update once `experimental_compile` is removed
+# entirely in favor of 'jit_compile'.
+jit_compile=deprecation.deprecated_argument_lookup(
+"jit_compile",
+jit_compile,
+"experimental_compile",
+experimental_compile),
 experimental_implements=experimental_implements,
 experimental_follow_type_hints=experimental_follow_type_hints))
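
The alias handling added above routes `experimental_compile` through `deprecation.deprecated_argument_lookup`. A hypothetical, simplified standalone sketch of that lookup's behavior (not the TensorFlow implementation; names and error text are illustrative only):

```python
def deprecated_argument_lookup(new_name, new_value, old_name, old_value):
  """Prefer the new argument; fall back to the deprecated one if only it is set."""
  if old_value is not None:
    if new_value is not None:
      raise ValueError("Cannot specify both '%s' and '%s'" % (new_name, old_name))
    return old_value
  return new_value

# jit_compile wins unless only experimental_compile was passed:
print(deprecated_argument_lookup("jit_compile", True, "experimental_compile", None))   # True
print(deprecated_argument_lookup("jit_compile", None, "experimental_compile", False))  # False
```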

View File

@@ -683,7 +683,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
 experimental_implements=implements,
 experimental_autograph_options=autograph_options,
 experimental_relax_shapes=relax_shapes,
-experimental_compile=compile_)
+jit_compile=compile_)
 if override_function:
 cloned_py_function = lambda x: x + 1
@@ -699,7 +699,7 @@
 self.assertEqual(implements, cloned._implements)
 self.assertEqual(autograph_options, cloned._experimental_autograph_options)
 self.assertEqual(relax_shapes, cloned._experimental_relax_shapes)
-self.assertEqual(compile_, cloned._experimental_compile)
+self.assertEqual(compile_, cloned._jit_compile)
 # This test does not run with XLA JIT support linked in so we can only check
 # the output of the function if compile is disabled.

View File

@@ -27,20 +27,20 @@ from tensorflow.python.platform import test
 class DefFunctionCpuOnlyTest(test.TestCase, parameterized.TestCase):
-"""Test that experimental_compile=True correctly throws an exception if XLA is not available.
+"""Test that jit_compile=True correctly throws an exception if XLA is not available.
 This test should only be run without `--config=cuda`, as that implicitly links
 in XLA JIT.
 """
-def testExperimentalCompileRaisesExceptionWhenXlaIsUnsupported(self):
+def testJitCompileRaisesExceptionWhenXlaIsUnsupported(self):
 if test.is_built_with_rocm() or test_util.is_xla_enabled():
 return
 with self.assertRaisesRegex(errors.UnimplementedError,
 'check target linkage'):
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def fn(x):
 return x + x

View File

@@ -45,11 +45,11 @@ class DefFunctionTest(xla_test.XLATestCase):
 with ops.device('device:{}:0'.format(self.device)):
-@def_function.function(experimental_compile=False)
+@def_function.function(jit_compile=False)
 def outer(a, b, c):
 return a * inner(b, c) + c
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def inner(b, c):
 return b + c * b
@@ -71,8 +71,8 @@ class DefFunctionTest(xla_test.XLATestCase):
 def fn(x, a):
 return x + a
-func = def_function.function(fn, experimental_compile=False)
-xla_func = def_function.function(fn, experimental_compile=True)
+func = def_function.function(fn, jit_compile=False)
+xla_func = def_function.function(fn, jit_compile=True)
 inputs = constant_op.constant([1, 2, 2, 3, 3])
 self.assertAllClose([2, 3, 3, 4, 4], func(inputs, 1))
@@ -81,7 +81,7 @@ class DefFunctionTest(xla_test.XLATestCase):
 def testBasicInt32(self):
 with ops.device('device:{}:0'.format(self.device)):
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def fn(x, a):
 return x + a
@@ -94,7 +94,7 @@ class DefFunctionTest(xla_test.XLATestCase):
 def fn(x, a):
 return 2 * x + a
-xla_func = def_function.function(fn, experimental_compile=True)
+xla_func = def_function.function(fn, jit_compile=True)
 with backprop.GradientTape() as tape:
 inputs = constant_op.constant([1., 2., 2., 3., 3.])
@@ -112,19 +112,19 @@ class DefFunctionTest(xla_test.XLATestCase):
 self.assertTrue(backward.function_def.attr['_XlaMustCompile'])
 self.assertTrue(forward.definition.attr['_XlaMustCompile'])
-# Calling function with experimental_compile=True from
-# experimental_compile=False should compile the inner func.
+# Calling function with jit_compile=True from
+# jit_compile=False should compile the inner func.
 def testNestedCall(self):
 if 'tpu' in self.device.lower():
 self.skipTest('b/162800687: Inner function runs on host')
 with ops.device('device:{}:0'.format(self.device)):
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def fn(x, a):
 return x + a
-@def_function.function(experimental_compile=False)
+@def_function.function(jit_compile=False)
 def fn2(x, a):
 return fn(x, a)
@@ -139,12 +139,12 @@ class DefFunctionTest(xla_test.XLATestCase):
 def fn(x):
 return array_ops.unique(x).y
-xla_func = def_function.function(fn, experimental_compile=True)
+xla_func = def_function.function(fn, jit_compile=True)
 def fn2(x):
 return xla_func(x)
-func = def_function.function(fn2, experimental_compile=False)
+func = def_function.function(fn2, jit_compile=False)
 inputs = constant_op.constant([1, 2, 2, 3, 3])
 with self.assertRaisesRegex(errors.InvalidArgumentError,
 'not compilable'):
@@ -158,8 +158,8 @@ class DefFunctionTest(xla_test.XLATestCase):
 def fn(x):
 return array_ops.unique(x).y # Unique is not supported by XLA
-func = def_function.function(fn, experimental_compile=False)
-xla_func = def_function.function(fn, experimental_compile=True)
+func = def_function.function(fn, jit_compile=False)
+xla_func = def_function.function(fn, jit_compile=True)
 inputs = constant_op.constant([1, 2, 2, 3, 3])
 self.assertAllClose([1, 2, 3], func(inputs))
@@ -174,8 +174,8 @@ class DefFunctionTest(xla_test.XLATestCase):
 def fn(x):
 return v * x
-func = def_function.function(fn, experimental_compile=False)
-xla_func = def_function.function(fn, experimental_compile=True)
+func = def_function.function(fn, jit_compile=False)
+xla_func = def_function.function(fn, jit_compile=True)
 def run_and_check(test_func):
 x = constant_op.constant(3.0)
@@ -195,7 +195,7 @@ class DefFunctionTest(xla_test.XLATestCase):
 with ops.device('device:{}:0'.format(self.device)):
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def f(x):
 assert control_flow_util.GraphOrParentsInXlaContext(
 ops.get_default_graph())
@@ -210,7 +210,7 @@ class DefFunctionTest(xla_test.XLATestCase):
 body, (constant_op.constant(0), constant_op.constant(3.)),
 maximum_iterations=10)[1]
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def g(x):
 x = ops.convert_to_tensor(x)
 with backprop.GradientTape() as tape:
@@ -232,7 +232,7 @@ class DefFunctionTest(xla_test.XLATestCase):
 class C(object):
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def f1(self, x, a):
 return x + a
@@ -248,7 +248,7 @@ class DefFunctionTest(xla_test.XLATestCase):
 class C(object):
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def f1(self, x):
 return array_ops.unique(x).y
@@ -264,11 +264,11 @@ class DefFunctionTest(xla_test.XLATestCase):
 with ops.device('device:{}:0'.format(self.device)):
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def f():
 return constant_op.constant([0, 2, 1], dtype=dtypes.int32)
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def g(a, b):
 return array_ops.transpose(a, b)
@@ -283,11 +283,11 @@ class DefFunctionTest(xla_test.XLATestCase):
 def testArgMinMax(self):
 with ops.device('device:{}:0'.format(self.device)):
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def argmax(x):
 return math_ops.argmax(x)
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def argmin(x):
 return math_ops.argmin(x)
@@ -300,7 +300,7 @@ class DefFunctionTest(xla_test.XLATestCase):
 def testErrorMessagePassingTensorArray(self):
 with ops.device('device:{}:0'.format(self.device)):
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def f(x):
 ta = tensor_array_ops.TensorArray(
 dtype=dtypes.float32, size=1, element_shape=[])
@@ -328,7 +328,7 @@ class DefFunctionTest(xla_test.XLATestCase):
 ta = ta.write(1, 3 * x)
 return ta.concat()
-compiled_f = def_function.function(experimental_compile=True)(f)
+compiled_f = def_function.function(jit_compile=True)(f)
 inputs = constant_op.constant([3.14, 2.68, 7.69])
@@ -348,7 +348,7 @@ class DefFunctionTest(xla_test.XLATestCase):
 ta = ta.write(1, 3 * x)
 return ta.concat()
-compiled_f = def_function.function(experimental_compile=True)(f)
+compiled_f = def_function.function(jit_compile=True)(f)
 inputs = constant_op.constant([[3.14, 21.1], [2.68, 22.2], [7.69, 23.3]])
 self.assertAllClose(f(inputs), compiled_f(inputs))
@@ -365,7 +365,7 @@ class DefFunctionTest(xla_test.XLATestCase):
 ta = ta.write(1, 3 * x)
 return ta.concat()
-compiled_f = def_function.function(experimental_compile=True)(f)
+compiled_f = def_function.function(jit_compile=True)(f)
 inputs = constant_op.constant([3.14])
 self.assertAllClose(f(inputs), compiled_f(inputs))
@@ -388,7 +388,7 @@ class DefFunctionTest(xla_test.XLATestCase):
 y = f(x)
 return tape.gradient(y, x)
-compiled_g = def_function.function(experimental_compile=True)(g)
+compiled_g = def_function.function(jit_compile=True)(g)
 self.assertAllClose([5.0, 5.0, 5.0], g())
 self.assertAllClose(compiled_g(), g())
@@ -398,7 +398,7 @@ class DefFunctionTest(xla_test.XLATestCase):
 def testTensorListConcatGradNestedCompile(self):
 with ops.device('device:{}:0'.format(self.device)):
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def f(x):
 ta = tensor_array_ops.TensorArray(
 dtype=dtypes.float32, size=2, element_shape=[3])
@@ -406,7 +406,7 @@ class DefFunctionTest(xla_test.XLATestCase):
 ta = ta.write(1, 3 * x)
 return ta.concat()
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def g():
 x = constant_op.constant([3.14, 2.68, 7.69])
 with backprop.GradientTape() as tape:
@@ -423,7 +423,7 @@ class DefFunctionTest(xla_test.XLATestCase):
 with ops.device('device:{}:0'.format(self.device)):
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def f(x):
 return math_ops.cumsum(x)
@@ -434,7 +434,7 @@ class DefFunctionTest(xla_test.XLATestCase):
 with ops.device('device:{}:0'.format(self.device)):
 inner_retracings = 0
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def inner(a, b):
 nonlocal inner_retracings
 inner_retracings += 1
@@ -455,7 +455,7 @@ class DefFunctionTest(xla_test.XLATestCase):
 on_gpu = 'gpu' in self.device.lower()
 v = variables.Variable([3.1, 3.2])
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def update_var(a, b):
 v.assign_add(a * b)
@@ -476,7 +476,7 @@ class DefFunctionTest(xla_test.XLATestCase):
 class C(object):
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def update_var(self, a, b):
 if not hasattr(self, 'v'):
 self.v = variables.Variable(3.1)
@@ -497,7 +497,7 @@ class DefFunctionTest(xla_test.XLATestCase):
 with ops.device('device:{}:0'.format(self.device)):
 v = variables.Variable(3.1)
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def update_var(a, b):
 v.assign_add(a * b)
 return a * b + v
@@ -509,7 +509,7 @@ class DefFunctionTest(xla_test.XLATestCase):
 def testReturnIdentity(self):
 with ops.device('device:{}:0'.format(self.device)):
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def f(a, b):
 return (a, b)
@@ -532,7 +532,7 @@ class DefFunctionTest(xla_test.XLATestCase):
 with ops.device('device:{}:0'.format(self.device)):
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def f(a, b):
 return array_ops.transpose(a, b)
@@ -549,7 +549,7 @@ class DefFunctionTest(xla_test.XLATestCase):
 v = variables.Variable([3.1, 3.2])
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def f(a, b):
 v.assign_add(a * b)
@@ -568,17 +568,17 @@ class DefFunctionTest(xla_test.XLATestCase):
 a = random_ops.random_normal([10, 10])
 with self.assertRaisesRegex(ValueError,
-'marked with experimental_compile'):
+'marked with \'jit_compile'):
 f.experimental_get_compiler_ir(a)()
 def testGetCompilerIrNested(self):
 with ops.device('device:{}:0'.format(self.device)):
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def fn(x, a):
 return x + a
-@def_function.function(experimental_compile=False)
+@def_function.function(jit_compile=False)
 def fn2(x, a):
 fn.experimental_get_compiler_ir(x, a)()
 return fn(x, a)
@@ -592,7 +592,7 @@ class DefFunctionTest(xla_test.XLATestCase):
 v = variables.Variable([0.1, 0.1])
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def f(a, b):
 return (a + b) * v
@@ -605,7 +605,7 @@ class DefFunctionTest(xla_test.XLATestCase):
 def testGetCompilerIrDot(self):
 with ops.device('device:{}:0'.format(self.device)):
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def f(a, b):
 return a + b
@@ -620,7 +620,7 @@ class DefFunctionTest(xla_test.XLATestCase):
 if 'gpu' not in self.device.lower():
 self.skipTest('Testing get_compiler_ir on GPUs without placement')
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def f(a, b):
 return a + b
@@ -634,7 +634,7 @@ class DefFunctionTest(xla_test.XLATestCase):
 def testGetCompilerIrNonTensors(self):
 with ops.device('device:{}:0'.format(self.device)):
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def f(l):
 return l[0] + l[1]
@@ -649,7 +649,7 @@ class DefFunctionTest(xla_test.XLATestCase):
 s = random_ops.random_uniform([2], 1, 10, dtypes.int32)
 l = random_ops.random_normal([s[0] * s[1]])
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def f(l):
 return array_ops.reshape(l, s)

View File

@@ -2349,7 +2349,7 @@ class FunctionSpec(object):
 input_signature,
 is_pure=False,
 experimental_follow_type_hints=False,
-experimental_compile=None):
+jit_compile=None):
 """Create a FunctionSpec instance given a python function and signature.
 Args:
@@ -2358,7 +2358,7 @@
 is_pure: if True all input arguments (including variables and constants)
 will be converted to tensors and no variable changes allowed.
 experimental_follow_type_hints: see `tf.function`
-experimental_compile: see `tf.function`
+jit_compile: see `tf.function`
 Returns:
 instance of FunctionSpec
@@ -2440,7 +2440,7 @@
 is_method,
 input_signature,
 is_pure=is_pure,
-experimental_compile=experimental_compile,
+jit_compile=jit_compile,
 experimental_follow_type_hints=experimental_follow_type_hints,
 name=name)
@@ -2451,7 +2451,7 @@
 is_pure=False,
 experimental_follow_type_hints=False,
 name=None,
-experimental_compile=None):
+jit_compile=None):
 """Constructs a FunctionSpec describing a python function.
 Args:
@@ -2462,12 +2462,12 @@
 will be converted to tensors and no variable changes allowed.
 experimental_follow_type_hints: see `tf.function`.
 name: Name of the function
-experimental_compile: see `tf.function`.
+jit_compile: see `tf.function`.
 """
 self._fullargspec = fullargspec
 self._is_method = is_method
 self._is_pure = is_pure
-self._experimental_compile = experimental_compile
+self._jit_compile = jit_compile
 self._experimental_follow_type_hints = experimental_follow_type_hints
 # TODO(edloper): Include name when serializing for SavedModel?
@@ -2538,8 +2538,8 @@
 return self._is_pure
 @property
-def experimental_compile(self):
-return self._experimental_compile
+def jit_compile(self):
+return self._jit_compile
 @property
 def arg_names(self):
@@ -2904,7 +2904,7 @@
 autograph_options=None,
 experimental_relax_shapes=False,
 capture_by_value=None,
-experimental_compile=None,
+jit_compile=None,
 experimental_follow_type_hints=False):
 """Initializes a `Function`.
@@ -2927,8 +2927,8 @@
 capture_by_value: Experimental. Whether to capture resource variables by
 value or reference. If None, will inherit from a parent context or
 default to False.
-experimental_compile: Force-compile the function with XLA, cf.
-def_function.Function doc on experimental_compile.
+jit_compile: Force-compile the function with XLA, cf.
+def_function.Function doc on jit_compile.
 experimental_follow_type_hints: See the documentation for `tf.function`.
 Raises:
@@ -2959,7 +2959,7 @@
 # `Function`, used to make sure defun-decorated methods create different
 # functions for each instance.
 self._descriptor_cache = weakref.WeakKeyDictionary()
-self._experimental_compile = experimental_compile
+self._jit_compile = jit_compile
 self._experimental_follow_type_hints = experimental_follow_type_hints
 def __call__(self, *args, **kwargs):
@@ -3774,7 +3774,7 @@ def defun_with_attributes(func=None,
 attributes=None,
 autograph=True,
 experimental_autograph_options=None,
-experimental_compile=None,
+jit_compile=None,
 experimental_relax_shapes=False,
 experimental_follow_type_hints=False):
 """Compiles a Python function into a callable TensorFlow graph.
@@ -3796,7 +3796,7 @@
 autograph: same as defun()'s autograph.
 experimental_autograph_options: same as defun()'s
 experimental_autograph_options.
-experimental_compile: same as defun()'s experimental_compile.
+jit_compile: same as defun()'s jit_compile.
 experimental_relax_shapes: same as defun()'s experimental_relax_shapes
 experimental_follow_type_hints: see `tf.function`.
@@ -3825,7 +3825,7 @@
 attributes=attributes,
 autograph=autograph,
 autograph_options=experimental_autograph_options,
-experimental_compile=experimental_compile,
+jit_compile=jit_compile,
 experimental_relax_shapes=experimental_relax_shapes,
 experimental_follow_type_hints=experimental_follow_type_hints))
@@ -3925,7 +3925,7 @@ def class_method_to_instance_method(original_function, instance):
 autograph=original_function._autograph,
 input_signature=original_function.input_signature,
 experimental_relax_shapes=original_function._experimental_relax_shapes,
-experimental_compile=original_function._experimental_compile)
+jit_compile=original_function._jit_compile)
 # pylint: enable=protected-access
 # And we wrap the function with tf_decorator so inspection works correctly

View File

@@ -92,7 +92,7 @@ class KerasLayerBenchmarks(six.with_metaclass(
 layer = layer_cls(**layer_args)
 x = tf.ones(input_shape)
 layer.call = tf.function(
-layer.call, experimental_compile=True)
+layer.call, jit_compile=True)
 fn = functools.partial(layer, x)
 name = _get_benchmark_name(self._get_name())
@@ -156,7 +156,7 @@ class KerasLayerBenchmarksBackwardXLA(six.with_metaclass(
 layer = layer_cls(**layer_args)
 x = tf.ones(input_shape)
 layer.call = tf.function(
-layer.call, experimental_compile=True)
+layer.call, jit_compile=True)
 fn = functools.partial(_layer_call_backward, layer, x)
 name = _get_benchmark_name(self._get_name())

View File

@@ -419,7 +419,7 @@ class KerasModelsXLATest(test.TestCase, parameterized.TestCase):
 @ds_combinations.generate(
 combinations.combine(
 distribution=strategy_combinations.tpu_strategies, mode=["eager"]))
-def test_tf_function_experimental_compile(self, distribution):
+def test_tf_function_jit_compile(self, distribution):
 dataset = _get_dataset()
 input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
@@ -433,7 +433,7 @@
 self.kernel = self.add_variable(
 "kernel", shape=[int(input_shape[-1]), self.num_outputs])
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def call(self, inputs):
 return math_ops.matmul(inputs, self.kernel)

View File

@@ -416,7 +416,7 @@ class NonFusedAdam(optimizer_v2.OptimizerV2):
 weights = weights[:len(params)]
 super(NonFusedAdam, self).set_weights(weights)
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def _resource_apply_dense(self, grad, var, apply_state=None):
 var_device, var_dtype = var.device, var.dtype.base_dtype
 coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
@@ -437,7 +437,7 @@ class NonFusedAdam(optimizer_v2.OptimizerV2):
 var.assign_sub(
 (m * alpha) / (math_ops.sqrt(v) - coefficients['epsilon']))
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
 var_device, var_dtype = var.device, var.dtype.base_dtype
 coefficients = ((apply_state or {}).get((var_device, var_dtype)) or

View File

@@ -56,7 +56,7 @@ class CollectiveOpXlaTest(test.TestCase):
 tensor_val = [i + 1.] * tensor_size
 constant = constant_op.constant(tensor_val)
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def f(x):
 return 2 * x + 1

View File

@@ -1288,7 +1288,7 @@ class ExecuteFnForDeviceTest(test_util.TensorFlowTestCase):
 def gpu_fn(x):
 return x * x
-@def_function.function(experimental_compile=True)
+@def_function.function(jit_compile=True)
 def flexible_defun(a):
 branches = {"CPU": lambda: cpu_fn(a), "GPU": lambda: gpu_fn(a)}
 return control_flow_ops.execute_fn_for_device(branches, lambda: cpu_fn(a))

View File

@@ -76,12 +76,12 @@ class PForTest(PForTestCase):
         vectorized_compute, inputs=[array_ops.ones((10, 5, 3))])
     self.run_and_assert_equal(result, array_ops.ones((10, 1, 3)))
 
-  def test_function_experimental_compile(self):
+  def test_function_jit_compile(self):
 
     def compute(x):
       return math_ops.reduce_mean(x, axis=0, keepdims=True)
 
-    @def_function.function(experimental_compile=True)
+    @def_function.function(jit_compile=True)
     def vectorized_compute(x):
       return pfor_control_flow_ops.vectorized_map(compute, x)
@@ -112,7 +112,7 @@ class PForTest(PForTestCase):
   def test_reduce_mean(self):
     x = random_ops.random_uniform([8, 3])
 
-    @def_function.function(experimental_compile=True)
+    @def_function.function(jit_compile=True)
     def f():
 
       def loop_fn(i, pfor_config):
@@ -172,7 +172,7 @@ class WhileV2Test(PForTestCase):
     # TODO(agarwal): The following may complain about uncompilable nodes. Hence
     # these are currently not enabled for all tests.
     if force_xla:
-      out_exp_compile_f = def_function.function(experimental_compile=True)(f)()
+      out_exp_compile_f = def_function.function(jit_compile=True)(f)()
       self.run_and_assert_equal(out, out_exp_compile_f)
       out_xla_compile_f = xla.compile(f, inputs=[])
       self.run_and_assert_equal(out, out_xla_compile_f)
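Note: a hedged sketch of the pattern the pfor test above exercises, using the public API names. `vectorized_mean` is an illustrative name; the shapes mirror the test inputs.

    import tensorflow as tf

    @tf.function(jit_compile=True)
    def vectorized_mean(x):
      # tf.vectorized_map applies the per-example reduction across the
      # leading batch dimension; XLA then compiles the vectorized graph.
      return tf.vectorized_map(
          lambda elem: tf.reduce_mean(elem, axis=0, keepdims=True), x)

    # Input [10, 5, 3] -> output [10, 1, 3], matching the test's assertion.
    print(vectorized_mean(tf.ones((10, 5, 3))).shape)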

View File

@@ -592,7 +592,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
     random.set_global_generator(None)
 
-    @def_function.function(experimental_compile=True)
+    @def_function.function(jit_compile=True)
     def make_seed():
       generator = random.get_global_generator()
       state = array_ops.identity(generator.state, name="state")
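Note: a small sketch, not from this commit, of a jit-compiled function drawing from a `tf.random.Generator`, the stateful RNG the test above covers; it assumes the stateful RNG kernels are available under XLA on the target device.

    import tensorflow as tf

    gen = tf.random.Generator.from_seed(1234)

    @tf.function(jit_compile=True)
    def sample(shape):
      # Generator-based random ops have XLA lowerings, unlike the legacy
      # tf.random.* stateful ops.
      return gen.normal(shape)

    print(sample([2, 3]))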

View File

@@ -143,17 +143,17 @@ def _deserialize_function_spec_as_nonmethod(function_spec_proto, coder):
       annotations=typeless_fullargspec.annotations)
   input_signature = coder.decode_proto(function_spec_proto.input_signature)
 
-  # See `tf.function` and the ExperimentalCompile proto for details.
-  experimental_compile = {
-      saved_object_graph_pb2.FunctionSpec.ExperimentalCompile.DEFAULT: None,
-      saved_object_graph_pb2.FunctionSpec.ExperimentalCompile.ON: True,
-      saved_object_graph_pb2.FunctionSpec.ExperimentalCompile.OFF: False,
-  }.get(function_spec_proto.experimental_compile)
+  # See `tf.function` and the JitCompile proto for details.
+  jit_compile = {
+      saved_object_graph_pb2.FunctionSpec.JitCompile.DEFAULT: None,
+      saved_object_graph_pb2.FunctionSpec.JitCompile.ON: True,
+      saved_object_graph_pb2.FunctionSpec.JitCompile.OFF: False,
+  }.get(function_spec_proto.jit_compile)
 
   return function_lib.FunctionSpec(fullargspec=fullargspec,
                                    is_method=False,
                                    input_signature=input_signature,
-                                   experimental_compile=experimental_compile)
+                                   jit_compile=jit_compile)
 
 # TODO(allenl): The fact that we can't derive ConcreteFunction calling
@@ -191,7 +191,7 @@ class RestoredFunction(def_function.Function):
     # TODO(mdan): We may enable autograph once exceptions are supported.
     super(RestoredFunction, self).__init__(
         python_function, name, autograph=False,
-        experimental_compile=function_spec.experimental_compile)
+        jit_compile=function_spec.jit_compile)
     self.concrete_functions = concrete_functions
     self._function_spec = function_spec

View File

@@ -47,12 +47,12 @@ def _serialize_function_spec(function_spec, coder):
   proto.input_signature.CopyFrom(
       coder.encode_structure(function_spec.input_signature))
 
-  # See `tf.function` and the ExperimentalCompile proto for details.
-  proto.experimental_compile = {
-      None: saved_object_graph_pb2.FunctionSpec.ExperimentalCompile.DEFAULT,
-      True: saved_object_graph_pb2.FunctionSpec.ExperimentalCompile.ON,
-      False: saved_object_graph_pb2.FunctionSpec.ExperimentalCompile.OFF,
-  }.get(function_spec.experimental_compile)
+  # See `tf.function` and the JitCompile proto for details.
+  proto.jit_compile = {
+      None: saved_object_graph_pb2.FunctionSpec.JitCompile.DEFAULT,
+      True: saved_object_graph_pb2.FunctionSpec.JitCompile.ON,
+      False: saved_object_graph_pb2.FunctionSpec.JitCompile.OFF,
+  }.get(function_spec.jit_compile)
 
   return proto
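Note: the serializer above round-trips a tri-state Python value (None/True/False) through a proto enum. Below is a standalone sketch of the same idiom; the local enum is only a stand-in for the real `FunctionSpec.JitCompile` proto enum.

    import enum

    class JitCompileEnum(enum.Enum):
      # Stand-in for saved_object_graph_pb2.FunctionSpec.JitCompile;
      # illustration only, not the real proto.
      DEFAULT = 0
      ON = 1
      OFF = 2

    def to_enum(jit_compile):
      # None (unset) must stay distinct from an explicit False.
      return {None: JitCompileEnum.DEFAULT,
              True: JitCompileEnum.ON,
              False: JitCompileEnum.OFF}[jit_compile]

    def from_enum(value):
      return {JitCompileEnum.DEFAULT: None,
              JitCompileEnum.ON: True,
              JitCompileEnum.OFF: False}[value]

    assert from_enum(to_enum(False)) is False
    assert from_enum(to_enum(None)) is None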

View File

@@ -931,13 +931,13 @@ class LoadTest(test.TestCase, parameterized.TestCase):
     self.assertAllEqual([2, 4, 6],
                         imported.f(constant_op.constant([1, 2, 3])).numpy())
 
-  def test_experimental_compile(self, cycles):
+  def test_jit_compile(self, cycles):
     # It'd be nice to use parameterize here, but the library does not support
     # having parameterized test methods inside already-parameterized classes.
-    for experimental_compile in (None, True, False):
+    for jit_compile in (None, True, False):
 
-      @def_function.function(experimental_compile=experimental_compile)
+      @def_function.function(jit_compile=jit_compile)
       def f(x):
         return x + 1.
@@ -948,7 +948,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
       imported = cycle(root, cycles)
-      self.assertEqual(imported.f._experimental_compile, experimental_compile)
+      self.assertEqual(imported.f._jit_compile, jit_compile)
 
   def test_get_concrete_function(self, cycles):
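Note: a hedged sketch of the behavior the load test checks, i.e. that the `jit_compile` setting survives a SavedModel round trip. The module class and the `/tmp` path are illustrative, not taken from the test.

    import tensorflow as tf

    class Root(tf.Module):

      @tf.function(
          input_signature=[tf.TensorSpec([None], tf.float32)],
          jit_compile=True)
      def f(self, x):
        return x + 1.0

    root = Root()
    tf.saved_model.save(root, "/tmp/jit_fn")      # illustrative path
    reloaded = tf.saved_model.load("/tmp/jit_fn")
    # The restored function is still marked for XLA compilation.
    print(reloaded.f(tf.constant([1.0, 2.0, 3.0])))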

View File

@@ -292,7 +292,7 @@ class MultiDeviceSaver(object):
     # latest values of options like experimental_io_device.
     if context.executing_eagerly() and len(self._single_device_savers) > 1:
       # Explicitly place the identity op on the first device.
-      @def_function.function(experimental_compile=False)
+      @def_function.function(jit_compile=False)
       def tf_function_save():
         save_fn()
       tf_function_save()
@@ -328,7 +328,7 @@ class MultiDeviceSaver(object):
     # latest values of options like experimental_io_device.
     if context.executing_eagerly() and len(self._single_device_savers) > 1:
       first_device, _ = list(self._single_device_savers.items())[0]
-      @def_function.function(experimental_compile=False)
+      @def_function.function(jit_compile=False)
       def tf_function_restore():
         restore_ops = restore_fn()
         restore_tensors = {}
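Note: the saver above uses the opt-out spelling. A tiny sketch, with an illustrative function name, of what `jit_compile=False` means in isolation:

    import tensorflow as tf

    @tf.function(jit_compile=False)
    def plain_graph_fn(x):
      # Explicitly opts this function out of must-compile XLA semantics;
      # it still runs as an ordinary tf.function graph.
      return tf.reduce_sum(x * x)

    print(plain_graph_fn(tf.constant([1.0, 2.0, 3.0])))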

View File

@@ -1342,7 +1342,7 @@ tf_module {
   }
   member_method {
     name: "function"
-    argspec: "args=[\'func\', \'input_signature\', \'autograph\', \'experimental_implements\', \'experimental_autograph_options\', \'experimental_relax_shapes\', \'experimental_compile\', \'experimental_follow_type_hints\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'None\', \'None\', \'False\', \'None\', \'None\'], "
+    argspec: "args=[\'func\', \'input_signature\', \'autograph\', \'jit_compile\', \'experimental_implements\', \'experimental_autograph_options\', \'experimental_relax_shapes\', \'experimental_compile\', \'experimental_follow_type_hints\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\'], "
   }
   member_method {
     name: "gather"

View File

@@ -674,7 +674,7 @@ tf_module {
   }
   member_method {
     name: "function"
-    argspec: "args=[\'func\', \'input_signature\', \'autograph\', \'experimental_implements\', \'experimental_autograph_options\', \'experimental_relax_shapes\', \'experimental_compile\', \'experimental_follow_type_hints\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'None\', \'None\', \'False\', \'None\', \'None\'], "
+    argspec: "args=[\'func\', \'input_signature\', \'autograph\', \'jit_compile\', \'experimental_implements\', \'experimental_autograph_options\', \'experimental_relax_shapes\', \'experimental_compile\', \'experimental_follow_type_hints\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\'], "
   }
   member_method {
     name: "gather"