Rationale:

- Just `compile` is a Python built-in, and is already overloaded by the Keras
   `compile` method.
 - `jit` is closer to other ML frameworks.
 - Technically speaking, `jit=True` is not self-descriptive (it does not
   specify what is it doing just-in-time)
 - Moreover, `tf.function` by itself without compilation could be described as
   a JIT.
 - Also `jit` by itself is less grep'able.
 - Thus `@tf.function(jit_compile=True)` is the preferred spelling.

PiperOrigin-RevId: 340503501
Change-Id: I7bffe60aca69be6640390f6e6c4af40c6c4dbfda
This commit is contained in:
George Karpenkov 2020-11-03 12:41:28 -08:00 committed by TensorFlower Gardener
parent 6561045a9b
commit 9df9d06e27
32 changed files with 178 additions and 166 deletions

View File

@ -36,6 +36,9 @@
* Exposing `tf.data.experimental.ExternalStatePolicy`, which can be used
to control how external state should be handled during dataset
serialization or iterator checkpointing.
* XLA compilation:
* `tf.function(experimental_compile=True)` has become a stable API,
renamed `tf.function(jit_compile=True)`.
* `tf.lite`:
* NNAPI

View File

@ -398,7 +398,7 @@ static void ShowXlaDeviceDeprecationWarning(
absl::call_once(once, [] {
LOG(INFO) << "XLA_GPU and XLA_CPU devices are deprecated and will be "
"removed in subsequent releases. Instead, use either "
"@tf.function(experimental_compile=True) for must-compile "
"@tf.function(jit_compile=True) for must-compile "
"semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 "
"for auto-clustering best-effort compilation.";
});

View File

@ -1481,7 +1481,7 @@ tf_xla_py_test(
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
],
use_xla_device = False, # Uses tf.function(experimental_compile=True)
use_xla_device = False, # Uses tf.function(jit_compile=True)
deps = [
":xla_test",
"//tensorflow/compiler/tf2xla/python:xla",

View File

@ -32,7 +32,7 @@ class CaseTest(xla_test.XLATestCase):
def testCaseBasic(self):
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def switch_case_test(branch_index):
def f1():
@ -58,7 +58,7 @@ class CaseTest(xla_test.XLATestCase):
def testBranchIsPruned(self):
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def switch_case_test():
branch_index = array_ops.constant(0)

View File

@ -693,7 +693,7 @@ class EagerFunctionTest(xla_test.XLATestCase):
return x, y
wholly_compiled_f = def_function.function(f)
op_by_op_f = def_function.function(f, experimental_compile=False)
op_by_op_f = def_function.function(f, jit_compile=False)
x = array_ops.identity([0.0, 2.0], name='data')

View File

@ -45,22 +45,22 @@ flags.DEFINE_bool('vary_seed', False,
NUM_SAMPLES = int(1e3)
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def _igamma(a, x):
return math_ops.igamma(a, x)
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def _igammac(a, x):
return math_ops.igammac(a, x)
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def _polygamma(n, x):
return math_ops.polygamma(n, x)
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def _zeta(a, q):
return math_ops.zeta(a, q)
@ -72,7 +72,7 @@ def implicit_reparameterization_grad(a, x):
return -gen_math_ops.igamma_grad_a(a, x) / prob
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def _log1p(x):
return math_ops.log1p(x)

View File

@ -51,7 +51,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
scenarios (e.g. TPU). The new version of stateless_random_* requires the
intermediate tensor `alg` to be compile-time constant, so we need to check
that this requirement is met. We use xla.compile instead of tf.function's
experimental_compile because the latter doesn't throw an error even if the
jit_compile because the latter doesn't throw an error even if the
compile-time-constant constraint is not met.
"""
if config.list_logical_devices('TPU'):

View File

@ -78,7 +78,7 @@ For example, the following TensorFlow function which performs the MNIST training
is compiled with XLA:
```
@tf.function(experimental_compile=True)
@tf.function(jit_compile=True)
def train_mnist(images, labels):
images, labels = cast(images, labels)
@ -92,7 +92,7 @@ def train_mnist(images, labels):
optimizer.apply_gradients(zip(grads, layer_variables))
```
The `experimental_compile` API has _must-compile_ semantics: either the entire
The `jit_compile` API has _must-compile_ semantics: either the entire
function is compiled with XLA, or an `errors.InvalidArgumentError` exception is
thrown. XLA can not currently compile functions where dimensions are not
_inferrable_: that is, if it's not possible to infer the dimensions of all
@ -108,7 +108,7 @@ def not_compilable(x):
Shapes can vary across the runs though:
```
@tf.function(experimental_compile=True)
@tf.function(jit_compile=True)
def recompiled_on_launch(a, b):
return a + b

View File

@ -164,7 +164,7 @@
"source": [
"# Define the training function\n",
"\n",
"In the training function, you get the predicted labels using the layer defined above, and then minimize the gradient of the loss using the optimizer. In order to compile the computation using XLA, place it inside `tf.function` with `experimental_compile=True`."
"In the training function, you get the predicted labels using the layer defined above, and then minimize the gradient of the loss using the optimizer. In order to compile the computation using XLA, place it inside `tf.function` with `jit_compile=True`."
]
},
{
@ -177,7 +177,7 @@
},
"outputs": [],
"source": [
"@tf.function(experimental_compile=True)\n",
"@tf.function(jit_compile=True)\n",
"def train_mnist(images, labels):\n",
" images, labels = cast(images, labels)\n",
"\n",

View File

@ -179,12 +179,12 @@ message FunctionSpec {
// field, so we instead map to an enum.
//
// See `tf.function` for details.
enum ExperimentalCompile {
enum JitCompile {
DEFAULT = 0;
ON = 1;
OFF = 2;
}
ExperimentalCompile experimental_compile = 6;
JitCompile jit_compile = 6;
reserved 3, 4;
}

View File

@ -95,8 +95,8 @@ cuda_py_test(
)
cuda_py_test(
name = "experimental_compile_test",
srcs = ["experimental_compile_test.py"],
name = "jit_compile_test",
srcs = ["jit_compile_test.py"],
python_version = "PY3",
tags = [
"no_mac",

View File

@ -27,7 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class ExperimentalCompileTest(test.TestCase):
class JitCompileTest(test.TestCase):
def testBasic(self):
with ops.Graph().as_default() as g:
@ -35,7 +35,7 @@ class ExperimentalCompileTest(test.TestCase):
def fn(x, a):
return x + a
xla_func = def_function.function(fn, experimental_compile=True)
xla_func = def_function.function(fn, jit_compile=True)
inputs = array_ops.placeholder(dtypes.float32, [5])
# XLA support is not yet enabled for TF ROCm
if not test.is_built_with_rocm():
@ -55,7 +55,7 @@ class ExperimentalCompileTest(test.TestCase):
return 2 * x + a
with ops.Graph().as_default() as g:
xla_func = def_function.function(fn, experimental_compile=True)
xla_func = def_function.function(fn, jit_compile=True)
with backprop.GradientTape() as tape:
inputs = array_ops.placeholder(dtypes.float32, [5])
tape.watch(inputs)
@ -79,7 +79,7 @@ class ExperimentalCompileTest(test.TestCase):
def fn(x, a):
return x + a
xla_func = def_function.function(fn, experimental_compile=True)
xla_func = def_function.function(fn, jit_compile=True)
inputs = array_ops.placeholder(dtypes.int32, [5])
# XLA support is not yet enabled for TF ROCm
if not test.is_built_with_rocm():
@ -98,7 +98,7 @@ class ExperimentalCompileTest(test.TestCase):
def fn(x):
return array_ops.unique(x).y # Unique is not supported by XLA
xla_func = def_function.function(fn, experimental_compile=True)
xla_func = def_function.function(fn, jit_compile=True)
inputs = array_ops.placeholder(dtypes.float32, [5])
x = xla_func(inputs)
# XLA support is not yet enabled for TF ROCm

View File

@ -67,7 +67,7 @@ _UNSUPPORTED_OPS = set([
@tf_export('xla.experimental.compile')
@deprecated(
None, 'xla.experimental.compile is deprecated. Consider using '
'tf.function(experimental_compile=True)',
'tf.function(jit_compile=True)',
warn_once=True)
def compile(computation, inputs=None): # pylint: disable=redefined-builtin
"""Builds an operator that compiles and runs `computation` with XLA.

View File

@ -917,7 +917,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
x = constant_op.constant([[1, 0.], [0., 0.]])
if defunc:
reduce_func = def_function.function(
math_ops.reduce_logsumexp, experimental_compile=xla_compile)
math_ops.reduce_logsumexp, jit_compile=xla_compile)
func = lambda: reduce_func(x)
else:
func = lambda: math_ops.reduce_logsumexp(x)

View File

@ -465,10 +465,10 @@ class Function(object):
name,
input_signature=None,
autograph=True,
jit_compile=None,
experimental_implements=None,
experimental_autograph_options=None,
experimental_relax_shapes=False,
experimental_compile=None,
experimental_follow_type_hints=None):
"""Initializes a `Function`.
@ -477,10 +477,10 @@ class Function(object):
name: the name given to it.
input_signature: See the documentation for `tf.function`.
autograph: See the documentation for `tf.function`.
jit_compile: See the documentation for `tf.function`.
experimental_implements: See the documentation for `tf.function`.
experimental_autograph_options: See the documentation for `tf.function`.
experimental_relax_shapes: See the documentation for `tf.function`.
experimental_compile: See the documentation for `tf.function`.
experimental_follow_type_hints: See the documentation for `tf.function`.
Raises:
@ -492,7 +492,7 @@ class Function(object):
self._function_spec = function_lib.FunctionSpec.from_function_and_signature(
python_function,
input_signature,
experimental_compile=experimental_compile,
jit_compile=jit_compile,
experimental_follow_type_hints=experimental_follow_type_hints,
)
self._implements = experimental_implements
@ -503,7 +503,7 @@ class Function(object):
self._autograph = autograph
self._experimental_autograph_options = experimental_autograph_options
self._experimental_relax_shapes = experimental_relax_shapes
self._experimental_compile = experimental_compile
self._jit_compile = jit_compile
if experimental_follow_type_hints is None:
experimental_follow_type_hints = False
self._experimental_follow_type_hints = experimental_follow_type_hints
@ -558,7 +558,7 @@ class Function(object):
"""Creates a defun wrapped inside a variable creator scope."""
weak_wrapped_fn = None
compile_with_xla = self._experimental_compile
compile_with_xla = self._jit_compile
def wrapped_fn(*args, **kwds):
"""Wraps `self._python_function` in a variable creator scope."""
@ -629,9 +629,9 @@ class Function(object):
if share is not None:
attributes[function_lib.SHARED_RENDEZVOUS_ATTRIBUTE_NAME] = share
if self._experimental_compile is not None:
attributes.update(_XlaMustCompile=bool(self._experimental_compile))
if self._experimental_compile:
if self._jit_compile is not None:
attributes.update(_XlaMustCompile=bool(self._jit_compile))
if self._jit_compile:
attributes.update(_noinline=True)
if not attributes:
attributes = None
@ -640,8 +640,8 @@ class Function(object):
input_signature=self.input_signature,
attributes=attributes,
autograph=self._autograph,
jit_compile=self._jit_compile,
experimental_autograph_options=self._experimental_autograph_options,
experimental_compile=self._experimental_compile,
experimental_follow_type_hints=self._experimental_follow_type_hints,
experimental_relax_shapes=self._experimental_relax_shapes)
@ -698,10 +698,10 @@ class Function(object):
name=self._name,
input_signature=self._input_signature,
autograph=self._autograph,
jit_compile=self._jit_compile,
experimental_implements=self._implements,
experimental_autograph_options=self._experimental_autograph_options,
experimental_relax_shapes=self._experimental_relax_shapes,
experimental_compile=self._experimental_compile,
experimental_follow_type_hints=self._experimental_follow_type_hints)
if self._shared_rendezvous:
@ -782,7 +782,7 @@ class Function(object):
tracing_count = self.experimental_get_tracing_count()
with trace.Trace(self._name) as tm:
result = self._call(*args, **kwds)
compiler = "xla" if self._experimental_compile else "nonXla"
compiler = "xla" if self._jit_compile else "nonXla"
new_tracing_count = self.experimental_get_tracing_count()
without_tracing = (tracing_count == new_tracing_count)
execution_mode = "notTraced" if without_tracing else "traced"
@ -934,7 +934,7 @@ class Function(object):
For example, for
```python
@tf.function(experimental_compile=True)
@tf.function(jit_compile=True)
def f(x):
return x + 1
@ -962,14 +962,13 @@ class Function(object):
Raises:
ValueError: If an invalid `stage` is selected or if applied to a function
which is not compiled (`experimental_compile=True` is not set).
which is not compiled (`jit_compile=True` is not set).
TypeError: When called with input in graph mode.
"""
context.ensure_initialized()
if not self._experimental_compile:
raise ValueError(
"Compiler IR can only be returned for functions marked with "
"experimental_compile=True")
if not self._jit_compile:
raise ValueError("Compiler IR can only be returned for functions marked "
"with 'jit_compile=True'")
concrete_fn = self.get_concrete_function(*args, **kwargs)
fn_name = concrete_fn.name
@ -1285,9 +1284,13 @@ class Function(object):
@tf_export("function")
@deprecation.deprecated_args(None,
"experimental_compile is deprecated, use "
"jit_compile instead", "experimental_compile")
def function(func=None,
input_signature=None,
autograph=True,
jit_compile=None,
experimental_implements=None,
experimental_autograph_options=None,
experimental_relax_shapes=False,
@ -1497,6 +1500,20 @@ def function(func=None,
graph. Data-dependent control flow requires `autograph=True`. For more
information, see the [tf.function and AutoGraph guide](
https://www.tensorflow.org/guide/function).
jit_compile: If `True`, compiles the function using
[XLA](https://tensorflow.org/xla). XLA performs compiler optimizations,
such as fusion, and attempts to emit more efficient code. This may
drastically improve the performance. If set to `True`,
the whole function needs to be compilable by XLA, or an
`errors.InvalidArgumentError` is thrown.
If `None` (default), compiles the function with XLA when running on TPU
and goes through the regular function execution path when running on
other devices.
If `False`, executes the function without XLA compilation. Set this value
to `False` when directly running a multi-device function on TPUs (e.g. two
TPU cores, one TPU core and its host CPU).
Not all functions are compilable, see a list of
[sharp corners](https://tensorflow.org/xla/known_issues).
experimental_implements: If provided, contains a name of a "known" function
this implements. For example "mycompany.my_recurrent_cell".
This is stored as an attribute in inference function,
@ -1519,22 +1536,7 @@ def function(func=None,
`tf.autograph.experimental.Feature` values.
experimental_relax_shapes: When True, `tf.function` may generate fewer,
graphs that are less specialized on input shapes.
experimental_compile: If `True`, compiles the function using XLA
(see https://tensorflow.org/xla). XLA performs compiler optimizations,
such as fusion, and attempts to emit more efficient code. This may
drastically improve the performance. If set to `True`,
the whole function needs to be compilable by XLA, or an
`errors.InvalidArgumentError` is thrown.
If `None` (default), compiles the function with XLA when running on TPU
and goes through the regular function execution path when running on
other devices.
If `False`, executes the function in a regular way (graph rewrite
passes are applied, kernels are dispatched one-by-one by the TensorFlow
executor). Set this value to `False` when directly running a
multi-device function on TPUs (e.g. two TPU cores, one TPU core and its
host CPU).
Not all functions are compilable, see
https://tensorflow.org/xla/known_issues for a list of sharp corners.
experimental_compile: Deprecated alias to 'jit_compile'.
experimental_follow_type_hints: When True, the function may use type
annotations from `func` to optimize the tracing performance. For example,
arguments annotated with `tf.Tensor` will automatically be converted
@ -1547,8 +1549,8 @@ def function(func=None,
`func` argument, returns a callable equivalent to the case above.
Raises:
ValueError when attempting to use experimental_compile, but XLA support is
not enabled.
ValueError when attempting to use jit_compile=True, but XLA support is not
linked.
"""
# TODO(mdan): Link to `tf.types` section once published.
if input_signature is not None:
@ -1571,7 +1573,14 @@ def function(func=None,
autograph=autograph,
experimental_autograph_options=experimental_autograph_options,
experimental_relax_shapes=experimental_relax_shapes,
experimental_compile=experimental_compile,
# TODO(b/171825496): Update once `experimental_compile` is removed
# entirely in favor of 'jit_compile'.
jit_compile=deprecation.deprecated_argument_lookup(
"jit_compile",
jit_compile,
"experimental_compile",
experimental_compile),
experimental_implements=experimental_implements,
experimental_follow_type_hints=experimental_follow_type_hints))

View File

@ -683,7 +683,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
experimental_implements=implements,
experimental_autograph_options=autograph_options,
experimental_relax_shapes=relax_shapes,
experimental_compile=compile_)
jit_compile=compile_)
if override_function:
cloned_py_function = lambda x: x + 1
@ -699,7 +699,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
self.assertEqual(implements, cloned._implements)
self.assertEqual(autograph_options, cloned._experimental_autograph_options)
self.assertEqual(relax_shapes, cloned._experimental_relax_shapes)
self.assertEqual(compile_, cloned._experimental_compile)
self.assertEqual(compile_, cloned._jit_compile)
# This test does not run with XLA JIT support linked in so we can only check
# the output of the function if compile is disabled.

View File

@ -27,20 +27,20 @@ from tensorflow.python.platform import test
class DefFunctionCpuOnlyTest(test.TestCase, parameterized.TestCase):
"""Test that experimental_compile=True correctly throws an exception if XLA is not available.
"""Test that jit_compile=True correctly throws an exception if XLA is not available.
This test should only be run without `--config=cuda`, as that implicitly links
in XLA JIT.
"""
def testExperimentalCompileRaisesExceptionWhenXlaIsUnsupported(self):
def testJitCompileRaisesExceptionWhenXlaIsUnsupported(self):
if test.is_built_with_rocm() or test_util.is_xla_enabled():
return
with self.assertRaisesRegex(errors.UnimplementedError,
'check target linkage'):
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def fn(x):
return x + x

View File

@ -45,11 +45,11 @@ class DefFunctionTest(xla_test.XLATestCase):
with ops.device('device:{}:0'.format(self.device)):
@def_function.function(experimental_compile=False)
@def_function.function(jit_compile=False)
def outer(a, b, c):
return a * inner(b, c) + c
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def inner(b, c):
return b + c * b
@ -71,8 +71,8 @@ class DefFunctionTest(xla_test.XLATestCase):
def fn(x, a):
return x + a
func = def_function.function(fn, experimental_compile=False)
xla_func = def_function.function(fn, experimental_compile=True)
func = def_function.function(fn, jit_compile=False)
xla_func = def_function.function(fn, jit_compile=True)
inputs = constant_op.constant([1, 2, 2, 3, 3])
self.assertAllClose([2, 3, 3, 4, 4], func(inputs, 1))
@ -81,7 +81,7 @@ class DefFunctionTest(xla_test.XLATestCase):
def testBasicInt32(self):
with ops.device('device:{}:0'.format(self.device)):
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def fn(x, a):
return x + a
@ -94,7 +94,7 @@ class DefFunctionTest(xla_test.XLATestCase):
def fn(x, a):
return 2 * x + a
xla_func = def_function.function(fn, experimental_compile=True)
xla_func = def_function.function(fn, jit_compile=True)
with backprop.GradientTape() as tape:
inputs = constant_op.constant([1., 2., 2., 3., 3.])
@ -112,19 +112,19 @@ class DefFunctionTest(xla_test.XLATestCase):
self.assertTrue(backward.function_def.attr['_XlaMustCompile'])
self.assertTrue(forward.definition.attr['_XlaMustCompile'])
# Calling function with experimental_compile=True from
# experimental_compile=False should compile the inner func.
# Calling function with jit_compile=True from
# jit_compile=False should compile the inner func.
def testNestedCall(self):
if 'tpu' in self.device.lower():
self.skipTest('b/162800687: Inner function runs on host')
with ops.device('device:{}:0'.format(self.device)):
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def fn(x, a):
return x + a
@def_function.function(experimental_compile=False)
@def_function.function(jit_compile=False)
def fn2(x, a):
return fn(x, a)
@ -139,12 +139,12 @@ class DefFunctionTest(xla_test.XLATestCase):
def fn(x):
return array_ops.unique(x).y
xla_func = def_function.function(fn, experimental_compile=True)
xla_func = def_function.function(fn, jit_compile=True)
def fn2(x):
return xla_func(x)
func = def_function.function(fn2, experimental_compile=False)
func = def_function.function(fn2, jit_compile=False)
inputs = constant_op.constant([1, 2, 2, 3, 3])
with self.assertRaisesRegex(errors.InvalidArgumentError,
'not compilable'):
@ -158,8 +158,8 @@ class DefFunctionTest(xla_test.XLATestCase):
def fn(x):
return array_ops.unique(x).y # Unique is not supported by XLA
func = def_function.function(fn, experimental_compile=False)
xla_func = def_function.function(fn, experimental_compile=True)
func = def_function.function(fn, jit_compile=False)
xla_func = def_function.function(fn, jit_compile=True)
inputs = constant_op.constant([1, 2, 2, 3, 3])
self.assertAllClose([1, 2, 3], func(inputs))
@ -174,8 +174,8 @@ class DefFunctionTest(xla_test.XLATestCase):
def fn(x):
return v * x
func = def_function.function(fn, experimental_compile=False)
xla_func = def_function.function(fn, experimental_compile=True)
func = def_function.function(fn, jit_compile=False)
xla_func = def_function.function(fn, jit_compile=True)
def run_and_check(test_func):
x = constant_op.constant(3.0)
@ -195,7 +195,7 @@ class DefFunctionTest(xla_test.XLATestCase):
with ops.device('device:{}:0'.format(self.device)):
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def f(x):
assert control_flow_util.GraphOrParentsInXlaContext(
ops.get_default_graph())
@ -210,7 +210,7 @@ class DefFunctionTest(xla_test.XLATestCase):
body, (constant_op.constant(0), constant_op.constant(3.)),
maximum_iterations=10)[1]
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def g(x):
x = ops.convert_to_tensor(x)
with backprop.GradientTape() as tape:
@ -232,7 +232,7 @@ class DefFunctionTest(xla_test.XLATestCase):
class C(object):
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def f1(self, x, a):
return x + a
@ -248,7 +248,7 @@ class DefFunctionTest(xla_test.XLATestCase):
class C(object):
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def f1(self, x):
return array_ops.unique(x).y
@ -264,11 +264,11 @@ class DefFunctionTest(xla_test.XLATestCase):
with ops.device('device:{}:0'.format(self.device)):
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def f():
return constant_op.constant([0, 2, 1], dtype=dtypes.int32)
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def g(a, b):
return array_ops.transpose(a, b)
@ -283,11 +283,11 @@ class DefFunctionTest(xla_test.XLATestCase):
def testArgMinMax(self):
with ops.device('device:{}:0'.format(self.device)):
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def argmax(x):
return math_ops.argmax(x)
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def argmin(x):
return math_ops.argmin(x)
@ -300,7 +300,7 @@ class DefFunctionTest(xla_test.XLATestCase):
def testErrorMessagePassingTensorArray(self):
with ops.device('device:{}:0'.format(self.device)):
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def f(x):
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, size=1, element_shape=[])
@ -328,7 +328,7 @@ class DefFunctionTest(xla_test.XLATestCase):
ta = ta.write(1, 3 * x)
return ta.concat()
compiled_f = def_function.function(experimental_compile=True)(f)
compiled_f = def_function.function(jit_compile=True)(f)
inputs = constant_op.constant([3.14, 2.68, 7.69])
@ -348,7 +348,7 @@ class DefFunctionTest(xla_test.XLATestCase):
ta = ta.write(1, 3 * x)
return ta.concat()
compiled_f = def_function.function(experimental_compile=True)(f)
compiled_f = def_function.function(jit_compile=True)(f)
inputs = constant_op.constant([[3.14, 21.1], [2.68, 22.2], [7.69, 23.3]])
self.assertAllClose(f(inputs), compiled_f(inputs))
@ -365,7 +365,7 @@ class DefFunctionTest(xla_test.XLATestCase):
ta = ta.write(1, 3 * x)
return ta.concat()
compiled_f = def_function.function(experimental_compile=True)(f)
compiled_f = def_function.function(jit_compile=True)(f)
inputs = constant_op.constant([3.14])
self.assertAllClose(f(inputs), compiled_f(inputs))
@ -388,7 +388,7 @@ class DefFunctionTest(xla_test.XLATestCase):
y = f(x)
return tape.gradient(y, x)
compiled_g = def_function.function(experimental_compile=True)(g)
compiled_g = def_function.function(jit_compile=True)(g)
self.assertAllClose([5.0, 5.0, 5.0], g())
self.assertAllClose(compiled_g(), g())
@ -398,7 +398,7 @@ class DefFunctionTest(xla_test.XLATestCase):
def testTensorListConcatGradNestedCompile(self):
with ops.device('device:{}:0'.format(self.device)):
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def f(x):
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, size=2, element_shape=[3])
@ -406,7 +406,7 @@ class DefFunctionTest(xla_test.XLATestCase):
ta = ta.write(1, 3 * x)
return ta.concat()
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def g():
x = constant_op.constant([3.14, 2.68, 7.69])
with backprop.GradientTape() as tape:
@ -423,7 +423,7 @@ class DefFunctionTest(xla_test.XLATestCase):
with ops.device('device:{}:0'.format(self.device)):
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def f(x):
return math_ops.cumsum(x)
@ -434,7 +434,7 @@ class DefFunctionTest(xla_test.XLATestCase):
with ops.device('device:{}:0'.format(self.device)):
inner_retracings = 0
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def inner(a, b):
nonlocal inner_retracings
inner_retracings += 1
@ -455,7 +455,7 @@ class DefFunctionTest(xla_test.XLATestCase):
on_gpu = 'gpu' in self.device.lower()
v = variables.Variable([3.1, 3.2])
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def update_var(a, b):
v.assign_add(a * b)
@ -476,7 +476,7 @@ class DefFunctionTest(xla_test.XLATestCase):
class C(object):
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def update_var(self, a, b):
if not hasattr(self, 'v'):
self.v = variables.Variable(3.1)
@ -497,7 +497,7 @@ class DefFunctionTest(xla_test.XLATestCase):
with ops.device('device:{}:0'.format(self.device)):
v = variables.Variable(3.1)
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def update_var(a, b):
v.assign_add(a * b)
return a * b + v
@ -509,7 +509,7 @@ class DefFunctionTest(xla_test.XLATestCase):
def testReturnIdentity(self):
with ops.device('device:{}:0'.format(self.device)):
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def f(a, b):
return (a, b)
@ -532,7 +532,7 @@ class DefFunctionTest(xla_test.XLATestCase):
with ops.device('device:{}:0'.format(self.device)):
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def f(a, b):
return array_ops.transpose(a, b)
@ -549,7 +549,7 @@ class DefFunctionTest(xla_test.XLATestCase):
v = variables.Variable([3.1, 3.2])
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def f(a, b):
v.assign_add(a * b)
@ -568,17 +568,17 @@ class DefFunctionTest(xla_test.XLATestCase):
a = random_ops.random_normal([10, 10])
with self.assertRaisesRegex(ValueError,
'marked with experimental_compile'):
'marked with \'jit_compile'):
f.experimental_get_compiler_ir(a)()
def testGetCompilerIrNested(self):
with ops.device('device:{}:0'.format(self.device)):
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def fn(x, a):
return x + a
@def_function.function(experimental_compile=False)
@def_function.function(jit_compile=False)
def fn2(x, a):
fn.experimental_get_compiler_ir(x, a)()
return fn(x, a)
@ -592,7 +592,7 @@ class DefFunctionTest(xla_test.XLATestCase):
v = variables.Variable([0.1, 0.1])
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def f(a, b):
return (a + b) * v
@ -605,7 +605,7 @@ class DefFunctionTest(xla_test.XLATestCase):
def testGetCompilerIrDot(self):
with ops.device('device:{}:0'.format(self.device)):
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def f(a, b):
return a + b
@ -620,7 +620,7 @@ class DefFunctionTest(xla_test.XLATestCase):
if 'gpu' not in self.device.lower():
self.skipTest('Testing get_compiler_ir on GPUs without placement')
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def f(a, b):
return a + b
@ -634,7 +634,7 @@ class DefFunctionTest(xla_test.XLATestCase):
def testGetCompilerIrNonTensors(self):
with ops.device('device:{}:0'.format(self.device)):
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def f(l):
return l[0] + l[1]
@ -649,7 +649,7 @@ class DefFunctionTest(xla_test.XLATestCase):
s = random_ops.random_uniform([2], 1, 10, dtypes.int32)
l = random_ops.random_normal([s[0] * s[1]])
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def f(l):
return array_ops.reshape(l, s)

View File

@ -2349,7 +2349,7 @@ class FunctionSpec(object):
input_signature,
is_pure=False,
experimental_follow_type_hints=False,
experimental_compile=None):
jit_compile=None):
"""Create a FunctionSpec instance given a python function and signature.
Args:
@ -2358,7 +2358,7 @@ class FunctionSpec(object):
is_pure: if True all input arguments (including variables and constants)
will be converted to tensors and no variable changes allowed.
experimental_follow_type_hints: see `tf.function`
experimental_compile: see `tf.function`
jit_compile: see `tf.function`
Returns:
instance of FunctionSpec
@ -2440,7 +2440,7 @@ class FunctionSpec(object):
is_method,
input_signature,
is_pure=is_pure,
experimental_compile=experimental_compile,
jit_compile=jit_compile,
experimental_follow_type_hints=experimental_follow_type_hints,
name=name)
@ -2451,7 +2451,7 @@ class FunctionSpec(object):
is_pure=False,
experimental_follow_type_hints=False,
name=None,
experimental_compile=None):
jit_compile=None):
"""Constructs a FunctionSpec describing a python function.
Args:
@ -2462,12 +2462,12 @@ class FunctionSpec(object):
will be converted to tensors and no variable changes allowed.
experimental_follow_type_hints: see `tf.function`.
name: Name of the function
experimental_compile: see `tf.function`.
jit_compile: see `tf.function`.
"""
self._fullargspec = fullargspec
self._is_method = is_method
self._is_pure = is_pure
self._experimental_compile = experimental_compile
self._jit_compile = jit_compile
self._experimental_follow_type_hints = experimental_follow_type_hints
# TODO(edloper): Include name when serializing for SavedModel?
@ -2538,8 +2538,8 @@ class FunctionSpec(object):
return self._is_pure
@property
def experimental_compile(self):
return self._experimental_compile
def jit_compile(self):
return self._jit_compile
@property
def arg_names(self):
@ -2904,7 +2904,7 @@ class Function(object):
autograph_options=None,
experimental_relax_shapes=False,
capture_by_value=None,
experimental_compile=None,
jit_compile=None,
experimental_follow_type_hints=False):
"""Initializes a `Function`.
@ -2927,8 +2927,8 @@ class Function(object):
capture_by_value: Experimental. Whether to capture resource variables by
value or reference. If None, will inherit from a parent context or
default to False.
experimental_compile: Force-compile the function with XLA, cf.
def_function.Function doc on experimental_compile.
jit_compile: Force-compile the function with XLA, cf.
def_function.Function doc on jit_compile.
experimental_follow_type_hints: See the documentation for `tf.function`.
Raises:
@ -2959,7 +2959,7 @@ class Function(object):
# `Function`, used to make sure defun-decorated methods create different
# functions for each instance.
self._descriptor_cache = weakref.WeakKeyDictionary()
self._experimental_compile = experimental_compile
self._jit_compile = jit_compile
self._experimental_follow_type_hints = experimental_follow_type_hints
def __call__(self, *args, **kwargs):
@ -3774,7 +3774,7 @@ def defun_with_attributes(func=None,
attributes=None,
autograph=True,
experimental_autograph_options=None,
experimental_compile=None,
jit_compile=None,
experimental_relax_shapes=False,
experimental_follow_type_hints=False):
"""Compiles a Python function into a callable TensorFlow graph.
@ -3796,7 +3796,7 @@ def defun_with_attributes(func=None,
autograph: same as defun()'s autograph.
experimental_autograph_options: same as defun()'s
experimental_autograph_options.
experimental_compile: same as defun()'s experimental_compile.
jit_compile: same as defun()'s jit_compile.
experimental_relax_shapes: same as defun()'s experimental_relax_shapes
experimental_follow_type_hints: see `tf.function`.
@ -3825,7 +3825,7 @@ def defun_with_attributes(func=None,
attributes=attributes,
autograph=autograph,
autograph_options=experimental_autograph_options,
experimental_compile=experimental_compile,
jit_compile=jit_compile,
experimental_relax_shapes=experimental_relax_shapes,
experimental_follow_type_hints=experimental_follow_type_hints))
@ -3925,7 +3925,7 @@ def class_method_to_instance_method(original_function, instance):
autograph=original_function._autograph,
input_signature=original_function.input_signature,
experimental_relax_shapes=original_function._experimental_relax_shapes,
experimental_compile=original_function._experimental_compile)
jit_compile=original_function._jit_compile)
# pylint: enable=protected-access
# And we wrap the function with tf_decorator so inspection works correctly

View File

@ -92,7 +92,7 @@ class KerasLayerBenchmarks(six.with_metaclass(
layer = layer_cls(**layer_args)
x = tf.ones(input_shape)
layer.call = tf.function(
layer.call, experimental_compile=True)
layer.call, jit_compile=True)
fn = functools.partial(layer, x)
name = _get_benchmark_name(self._get_name())
@ -156,7 +156,7 @@ class KerasLayerBenchmarksBackwardXLA(six.with_metaclass(
layer = layer_cls(**layer_args)
x = tf.ones(input_shape)
layer.call = tf.function(
layer.call, experimental_compile=True)
layer.call, jit_compile=True)
fn = functools.partial(_layer_call_backward, layer, x)
name = _get_benchmark_name(self._get_name())

View File

@ -419,7 +419,7 @@ class KerasModelsXLATest(test.TestCase, parameterized.TestCase):
@ds_combinations.generate(
combinations.combine(
distribution=strategy_combinations.tpu_strategies, mode=["eager"]))
def test_tf_function_experimental_compile(self, distribution):
def test_tf_function_jit_compile(self, distribution):
dataset = _get_dataset()
input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
@ -433,7 +433,7 @@ class KerasModelsXLATest(test.TestCase, parameterized.TestCase):
self.kernel = self.add_variable(
"kernel", shape=[int(input_shape[-1]), self.num_outputs])
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def call(self, inputs):
return math_ops.matmul(inputs, self.kernel)

View File

@ -416,7 +416,7 @@ class NonFusedAdam(optimizer_v2.OptimizerV2):
weights = weights[:len(params)]
super(NonFusedAdam, self).set_weights(weights)
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def _resource_apply_dense(self, grad, var, apply_state=None):
var_device, var_dtype = var.device, var.dtype.base_dtype
coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
@ -437,7 +437,7 @@ class NonFusedAdam(optimizer_v2.OptimizerV2):
var.assign_sub(
(m * alpha) / (math_ops.sqrt(v) - coefficients['epsilon']))
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
var_device, var_dtype = var.device, var.dtype.base_dtype
coefficients = ((apply_state or {}).get((var_device, var_dtype)) or

View File

@ -56,7 +56,7 @@ class CollectiveOpXlaTest(test.TestCase):
tensor_val = [i + 1.] * tensor_size
constant = constant_op.constant(tensor_val)
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def f(x):
return 2 * x + 1

View File

@ -1288,7 +1288,7 @@ class ExecuteFnForDeviceTest(test_util.TensorFlowTestCase):
def gpu_fn(x):
return x * x
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def flexible_defun(a):
branches = {"CPU": lambda: cpu_fn(a), "GPU": lambda: gpu_fn(a)}
return control_flow_ops.execute_fn_for_device(branches, lambda: cpu_fn(a))

View File

@ -76,12 +76,12 @@ class PForTest(PForTestCase):
vectorized_compute, inputs=[array_ops.ones((10, 5, 3))])
self.run_and_assert_equal(result, array_ops.ones((10, 1, 3)))
def test_function_experimental_compile(self):
def test_function_jit_compile(self):
def compute(x):
return math_ops.reduce_mean(x, axis=0, keepdims=True)
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def vectorized_compute(x):
return pfor_control_flow_ops.vectorized_map(compute, x)
@ -112,7 +112,7 @@ class PForTest(PForTestCase):
def test_reduce_mean(self):
x = random_ops.random_uniform([8, 3])
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def f():
def loop_fn(i, pfor_config):
@ -172,7 +172,7 @@ class WhileV2Test(PForTestCase):
# TODO(agarwal): The following may complain about uncompilable nodes. Hence
# these are currently not enabled for all tests.
if force_xla:
out_exp_compile_f = def_function.function(experimental_compile=True)(f)()
out_exp_compile_f = def_function.function(jit_compile=True)(f)()
self.run_and_assert_equal(out, out_exp_compile_f)
out_xla_compile_f = xla.compile(f, inputs=[])
self.run_and_assert_equal(out, out_xla_compile_f)

View File

@ -592,7 +592,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
random.set_global_generator(None)
@def_function.function(experimental_compile=True)
@def_function.function(jit_compile=True)
def make_seed():
generator = random.get_global_generator()
state = array_ops.identity(generator.state, name="state")

View File

@ -143,17 +143,17 @@ def _deserialize_function_spec_as_nonmethod(function_spec_proto, coder):
annotations=typeless_fullargspec.annotations)
input_signature = coder.decode_proto(function_spec_proto.input_signature)
# See `tf.function` and the ExperimentalCompile proto for details.
experimental_compile = {
saved_object_graph_pb2.FunctionSpec.ExperimentalCompile.DEFAULT: None,
saved_object_graph_pb2.FunctionSpec.ExperimentalCompile.ON: True,
saved_object_graph_pb2.FunctionSpec.ExperimentalCompile.OFF: False,
}.get(function_spec_proto.experimental_compile)
# See `tf.function` and the JitCompile proto for details.
jit_compile = {
saved_object_graph_pb2.FunctionSpec.JitCompile.DEFAULT: None,
saved_object_graph_pb2.FunctionSpec.JitCompile.ON: True,
saved_object_graph_pb2.FunctionSpec.JitCompile.OFF: False,
}.get(function_spec_proto.jit_compile)
return function_lib.FunctionSpec(fullargspec=fullargspec,
is_method=False,
input_signature=input_signature,
experimental_compile=experimental_compile)
jit_compile=jit_compile)
# TODO(allenl): The fact that we can't derive ConcreteFunction calling
@ -191,7 +191,7 @@ class RestoredFunction(def_function.Function):
# TODO(mdan): We may enable autograph once exceptions are supported.
super(RestoredFunction, self).__init__(
python_function, name, autograph=False,
experimental_compile=function_spec.experimental_compile)
jit_compile=function_spec.jit_compile)
self.concrete_functions = concrete_functions
self._function_spec = function_spec

View File

@ -47,12 +47,12 @@ def _serialize_function_spec(function_spec, coder):
proto.input_signature.CopyFrom(
coder.encode_structure(function_spec.input_signature))
# See `tf.function` and the ExperimentalCompile proto for details.
proto.experimental_compile = {
None: saved_object_graph_pb2.FunctionSpec.ExperimentalCompile.DEFAULT,
True: saved_object_graph_pb2.FunctionSpec.ExperimentalCompile.ON,
False: saved_object_graph_pb2.FunctionSpec.ExperimentalCompile.OFF,
}.get(function_spec.experimental_compile)
# See `tf.function` and the JitCompile proto for details.
proto.jit_compile = {
None: saved_object_graph_pb2.FunctionSpec.JitCompile.DEFAULT,
True: saved_object_graph_pb2.FunctionSpec.JitCompile.ON,
False: saved_object_graph_pb2.FunctionSpec.JitCompile.OFF,
}.get(function_spec.jit_compile)
return proto

View File

@ -931,13 +931,13 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual([2, 4, 6],
imported.f(constant_op.constant([1, 2, 3])).numpy())
def test_experimental_compile(self, cycles):
def test_jit_compile(self, cycles):
# It'd be nice to use parameterize here, but the library does not support
# having parameterized test methods inside already-parameterized classes.
for experimental_compile in (None, True, False):
for jit_compile in (None, True, False):
@def_function.function(experimental_compile=experimental_compile)
@def_function.function(jit_compile=jit_compile)
def f(x):
return x + 1.
@ -948,7 +948,7 @@ class LoadTest(test.TestCase, parameterized.TestCase):
imported = cycle(root, cycles)
self.assertEqual(imported.f._experimental_compile, experimental_compile)
self.assertEqual(imported.f._jit_compile, jit_compile)
def test_get_concrete_function(self, cycles):

View File

@ -292,7 +292,7 @@ class MultiDeviceSaver(object):
# latest values of options like experimental_io_device.
if context.executing_eagerly() and len(self._single_device_savers) > 1:
# Explicitly place the identity op on the first device.
@def_function.function(experimental_compile=False)
@def_function.function(jit_compile=False)
def tf_function_save():
save_fn()
tf_function_save()
@ -328,7 +328,7 @@ class MultiDeviceSaver(object):
# latest values of options like experimental_io_device.
if context.executing_eagerly() and len(self._single_device_savers) > 1:
first_device, _ = list(self._single_device_savers.items())[0]
@def_function.function(experimental_compile=False)
@def_function.function(jit_compile=False)
def tf_function_restore():
restore_ops = restore_fn()
restore_tensors = {}

View File

@ -1342,7 +1342,7 @@ tf_module {
}
member_method {
name: "function"
argspec: "args=[\'func\', \'input_signature\', \'autograph\', \'experimental_implements\', \'experimental_autograph_options\', \'experimental_relax_shapes\', \'experimental_compile\', \'experimental_follow_type_hints\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'None\', \'None\', \'False\', \'None\', \'None\'], "
argspec: "args=[\'func\', \'input_signature\', \'autograph\', \'jit_compile\', \'experimental_implements\', \'experimental_autograph_options\', \'experimental_relax_shapes\', \'experimental_compile\', \'experimental_follow_type_hints\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\'], "
}
member_method {
name: "gather"

View File

@ -674,7 +674,7 @@ tf_module {
}
member_method {
name: "function"
argspec: "args=[\'func\', \'input_signature\', \'autograph\', \'experimental_implements\', \'experimental_autograph_options\', \'experimental_relax_shapes\', \'experimental_compile\', \'experimental_follow_type_hints\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'None\', \'None\', \'False\', \'None\', \'None\'], "
argspec: "args=[\'func\', \'input_signature\', \'autograph\', \'jit_compile\', \'experimental_implements\', \'experimental_autograph_options\', \'experimental_relax_shapes\', \'experimental_compile\', \'experimental_follow_type_hints\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\'], "
}
member_method {
name: "gather"