From 9df9d06e27882504b75e16e4c444ee22405340b3 Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Tue, 3 Nov 2020 12:41:28 -0800 Subject: [PATCH] Rationale: - Just `compile` is a Python built-in, and is already overloaded by the Keras `compile` method. - `jit` is closer to other ML frameworks. - Technically speaking, `jit=True` is not self-descriptive (it does not specify what is it doing just-in-time) - Moreover, `tf.function` by itself without compilation could be described as a JIT. - Also `jit` by itself is less grep'able. - Thus `@tf.function(jit_compile=True)` is the preferred spelling. PiperOrigin-RevId: 340503501 Change-Id: I7bffe60aca69be6640390f6e6c4af40c6c4dbfda --- RELEASE.md | 3 + tensorflow/compiler/jit/xla_device.cc | 2 +- tensorflow/compiler/tests/BUILD | 2 +- tensorflow/compiler/tests/case_test.py | 4 +- tensorflow/compiler/tests/eager_test.py | 2 +- .../compiler/tests/special_math_test.py | 10 +- .../tests/stateless_random_ops_test.py | 2 +- tensorflow/compiler/xla/g3doc/index.md | 6 +- .../xla/g3doc/tutorials/compile.ipynb | 4 +- .../core/protobuf/saved_object_graph.proto | 4 +- tensorflow/python/compiler/xla/BUILD | 4 +- ...al_compile_test.py => jit_compile_test.py} | 10 +- tensorflow/python/compiler/xla/xla.py | 2 +- tensorflow/python/eager/benchmarks_test.py | 2 +- tensorflow/python/eager/def_function.py | 81 +++++++++------- tensorflow/python/eager/def_function_test.py | 4 +- .../eager/def_function_test_cpu_only.py | 6 +- .../python/eager/def_function_xla_jit_test.py | 94 +++++++++---------- tensorflow/python/eager/function.py | 32 +++---- .../layer_benchmarks/layer_benchmarks_test.py | 4 +- .../custom_training_loop_models_test.py | 4 +- tensorflow/python/keras/optimizer_v2/adam.py | 4 +- .../python/ops/collective_ops_xla_test.py | 2 +- .../python/ops/control_flow_ops_test.py | 2 +- .../parallel_for/xla_control_flow_ops_test.py | 8 +- .../python/ops/stateful_random_ops_test.py | 2 +- .../saved_model/function_deserialization.py | 16 ++-- .../saved_model/function_serialization.py | 12 +-- tensorflow/python/saved_model/load_test.py | 8 +- .../training/saving/functional_saver.py | 4 +- .../tools/api/golden/v1/tensorflow.pbtxt | 2 +- .../tools/api/golden/v2/tensorflow.pbtxt | 2 +- 32 files changed, 178 insertions(+), 166 deletions(-) rename tensorflow/python/compiler/xla/{experimental_compile_test.py => jit_compile_test.py} (92%) diff --git a/RELEASE.md b/RELEASE.md index 962cc87ae28..de634638a3a 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -36,6 +36,9 @@ * Exposing `tf.data.experimental.ExternalStatePolicy`, which can be used to control how external state should be handled during dataset serialization or iterator checkpointing. +* XLA compilation: + * `tf.function(experimental_compile=True)` has become a stable API, + renamed `tf.function(jit_compile=True)`. * `tf.lite`: * NNAPI diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 089d22dca03..f0e236de511 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -398,7 +398,7 @@ static void ShowXlaDeviceDeprecationWarning( absl::call_once(once, [] { LOG(INFO) << "XLA_GPU and XLA_CPU devices are deprecated and will be " "removed in subsequent releases. Instead, use either " - "@tf.function(experimental_compile=True) for must-compile " + "@tf.function(jit_compile=True) for must-compile " "semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 " "for auto-clustering best-effort compilation."; }); diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 4eedf8c8f72..93a7c7c318a 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1481,7 +1481,7 @@ tf_xla_py_test( tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ], - use_xla_device = False, # Uses tf.function(experimental_compile=True) + use_xla_device = False, # Uses tf.function(jit_compile=True) deps = [ ":xla_test", "//tensorflow/compiler/tf2xla/python:xla", diff --git a/tensorflow/compiler/tests/case_test.py b/tensorflow/compiler/tests/case_test.py index 4da9c4fac7a..1be0f08e236 100644 --- a/tensorflow/compiler/tests/case_test.py +++ b/tensorflow/compiler/tests/case_test.py @@ -32,7 +32,7 @@ class CaseTest(xla_test.XLATestCase): def testCaseBasic(self): - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def switch_case_test(branch_index): def f1(): @@ -58,7 +58,7 @@ class CaseTest(xla_test.XLATestCase): def testBranchIsPruned(self): - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def switch_case_test(): branch_index = array_ops.constant(0) diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index eef9d24766d..1a61e58dffc 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -693,7 +693,7 @@ class EagerFunctionTest(xla_test.XLATestCase): return x, y wholly_compiled_f = def_function.function(f) - op_by_op_f = def_function.function(f, experimental_compile=False) + op_by_op_f = def_function.function(f, jit_compile=False) x = array_ops.identity([0.0, 2.0], name='data') diff --git a/tensorflow/compiler/tests/special_math_test.py b/tensorflow/compiler/tests/special_math_test.py index 5e7f8763743..c9ec628e3fa 100644 --- a/tensorflow/compiler/tests/special_math_test.py +++ b/tensorflow/compiler/tests/special_math_test.py @@ -45,22 +45,22 @@ flags.DEFINE_bool('vary_seed', False, NUM_SAMPLES = int(1e3) -@def_function.function(experimental_compile=True) +@def_function.function(jit_compile=True) def _igamma(a, x): return math_ops.igamma(a, x) -@def_function.function(experimental_compile=True) +@def_function.function(jit_compile=True) def _igammac(a, x): return math_ops.igammac(a, x) -@def_function.function(experimental_compile=True) +@def_function.function(jit_compile=True) def _polygamma(n, x): return math_ops.polygamma(n, x) -@def_function.function(experimental_compile=True) +@def_function.function(jit_compile=True) def _zeta(a, q): return math_ops.zeta(a, q) @@ -72,7 +72,7 @@ def implicit_reparameterization_grad(a, x): return -gen_math_ops.igamma_grad_a(a, x) / prob -@def_function.function(experimental_compile=True) +@def_function.function(jit_compile=True) def _log1p(x): return math_ops.log1p(x) diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index 23e827f18e8..c8f75e87dc5 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -51,7 +51,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): scenarios (e.g. TPU). The new version of stateless_random_* requires the intermediate tensor `alg` to be compile-time constant, so we need to check that this requirement is met. We use xla.compile instead of tf.function's - experimental_compile because the latter doesn't throw an error even if the + jit_compile because the latter doesn't throw an error even if the compile-time-constant constraint is not met. """ if config.list_logical_devices('TPU'): diff --git a/tensorflow/compiler/xla/g3doc/index.md b/tensorflow/compiler/xla/g3doc/index.md index 45abd9b4c92..d749e8baadd 100644 --- a/tensorflow/compiler/xla/g3doc/index.md +++ b/tensorflow/compiler/xla/g3doc/index.md @@ -78,7 +78,7 @@ For example, the following TensorFlow function which performs the MNIST training is compiled with XLA: ``` -@tf.function(experimental_compile=True) +@tf.function(jit_compile=True) def train_mnist(images, labels): images, labels = cast(images, labels) @@ -92,7 +92,7 @@ def train_mnist(images, labels): optimizer.apply_gradients(zip(grads, layer_variables)) ``` -The `experimental_compile` API has _must-compile_ semantics: either the entire +The `jit_compile` API has _must-compile_ semantics: either the entire function is compiled with XLA, or an `errors.InvalidArgumentError` exception is thrown. XLA can not currently compile functions where dimensions are not _inferrable_: that is, if it's not possible to infer the dimensions of all @@ -108,7 +108,7 @@ def not_compilable(x): Shapes can vary across the runs though: ``` -@tf.function(experimental_compile=True) +@tf.function(jit_compile=True) def recompiled_on_launch(a, b): return a + b diff --git a/tensorflow/compiler/xla/g3doc/tutorials/compile.ipynb b/tensorflow/compiler/xla/g3doc/tutorials/compile.ipynb index 59523a549d8..aec3ee7e658 100644 --- a/tensorflow/compiler/xla/g3doc/tutorials/compile.ipynb +++ b/tensorflow/compiler/xla/g3doc/tutorials/compile.ipynb @@ -164,7 +164,7 @@ "source": [ "# Define the training function\n", "\n", - "In the training function, you get the predicted labels using the layer defined above, and then minimize the gradient of the loss using the optimizer. In order to compile the computation using XLA, place it inside `tf.function` with `experimental_compile=True`." + "In the training function, you get the predicted labels using the layer defined above, and then minimize the gradient of the loss using the optimizer. In order to compile the computation using XLA, place it inside `tf.function` with `jit_compile=True`." ] }, { @@ -177,7 +177,7 @@ }, "outputs": [], "source": [ - "@tf.function(experimental_compile=True)\n", + "@tf.function(jit_compile=True)\n", "def train_mnist(images, labels):\n", " images, labels = cast(images, labels)\n", "\n", diff --git a/tensorflow/core/protobuf/saved_object_graph.proto b/tensorflow/core/protobuf/saved_object_graph.proto index 8df58683ead..643e32b15b5 100644 --- a/tensorflow/core/protobuf/saved_object_graph.proto +++ b/tensorflow/core/protobuf/saved_object_graph.proto @@ -179,12 +179,12 @@ message FunctionSpec { // field, so we instead map to an enum. // // See `tf.function` for details. - enum ExperimentalCompile { + enum JitCompile { DEFAULT = 0; ON = 1; OFF = 2; } - ExperimentalCompile experimental_compile = 6; + JitCompile jit_compile = 6; reserved 3, 4; } diff --git a/tensorflow/python/compiler/xla/BUILD b/tensorflow/python/compiler/xla/BUILD index 79c18571f9a..a5014064cd8 100644 --- a/tensorflow/python/compiler/xla/BUILD +++ b/tensorflow/python/compiler/xla/BUILD @@ -95,8 +95,8 @@ cuda_py_test( ) cuda_py_test( - name = "experimental_compile_test", - srcs = ["experimental_compile_test.py"], + name = "jit_compile_test", + srcs = ["jit_compile_test.py"], python_version = "PY3", tags = [ "no_mac", diff --git a/tensorflow/python/compiler/xla/experimental_compile_test.py b/tensorflow/python/compiler/xla/jit_compile_test.py similarity index 92% rename from tensorflow/python/compiler/xla/experimental_compile_test.py rename to tensorflow/python/compiler/xla/jit_compile_test.py index 963a92d4384..7b715736489 100644 --- a/tensorflow/python/compiler/xla/experimental_compile_test.py +++ b/tensorflow/python/compiler/xla/jit_compile_test.py @@ -27,7 +27,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class ExperimentalCompileTest(test.TestCase): +class JitCompileTest(test.TestCase): def testBasic(self): with ops.Graph().as_default() as g: @@ -35,7 +35,7 @@ class ExperimentalCompileTest(test.TestCase): def fn(x, a): return x + a - xla_func = def_function.function(fn, experimental_compile=True) + xla_func = def_function.function(fn, jit_compile=True) inputs = array_ops.placeholder(dtypes.float32, [5]) # XLA support is not yet enabled for TF ROCm if not test.is_built_with_rocm(): @@ -55,7 +55,7 @@ class ExperimentalCompileTest(test.TestCase): return 2 * x + a with ops.Graph().as_default() as g: - xla_func = def_function.function(fn, experimental_compile=True) + xla_func = def_function.function(fn, jit_compile=True) with backprop.GradientTape() as tape: inputs = array_ops.placeholder(dtypes.float32, [5]) tape.watch(inputs) @@ -79,7 +79,7 @@ class ExperimentalCompileTest(test.TestCase): def fn(x, a): return x + a - xla_func = def_function.function(fn, experimental_compile=True) + xla_func = def_function.function(fn, jit_compile=True) inputs = array_ops.placeholder(dtypes.int32, [5]) # XLA support is not yet enabled for TF ROCm if not test.is_built_with_rocm(): @@ -98,7 +98,7 @@ class ExperimentalCompileTest(test.TestCase): def fn(x): return array_ops.unique(x).y # Unique is not supported by XLA - xla_func = def_function.function(fn, experimental_compile=True) + xla_func = def_function.function(fn, jit_compile=True) inputs = array_ops.placeholder(dtypes.float32, [5]) x = xla_func(inputs) # XLA support is not yet enabled for TF ROCm diff --git a/tensorflow/python/compiler/xla/xla.py b/tensorflow/python/compiler/xla/xla.py index 59b70f2a217..5aa4bfb10c4 100644 --- a/tensorflow/python/compiler/xla/xla.py +++ b/tensorflow/python/compiler/xla/xla.py @@ -67,7 +67,7 @@ _UNSUPPORTED_OPS = set([ @tf_export('xla.experimental.compile') @deprecated( None, 'xla.experimental.compile is deprecated. Consider using ' - 'tf.function(experimental_compile=True)', + 'tf.function(jit_compile=True)', warn_once=True) def compile(computation, inputs=None): # pylint: disable=redefined-builtin """Builds an operator that compiles and runs `computation` with XLA. diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py index c567fcb762c..3287a1548ac 100644 --- a/tensorflow/python/eager/benchmarks_test.py +++ b/tensorflow/python/eager/benchmarks_test.py @@ -917,7 +917,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): x = constant_op.constant([[1, 0.], [0., 0.]]) if defunc: reduce_func = def_function.function( - math_ops.reduce_logsumexp, experimental_compile=xla_compile) + math_ops.reduce_logsumexp, jit_compile=xla_compile) func = lambda: reduce_func(x) else: func = lambda: math_ops.reduce_logsumexp(x) diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py index 907c257d605..98913c53fa3 100644 --- a/tensorflow/python/eager/def_function.py +++ b/tensorflow/python/eager/def_function.py @@ -465,10 +465,10 @@ class Function(object): name, input_signature=None, autograph=True, + jit_compile=None, experimental_implements=None, experimental_autograph_options=None, experimental_relax_shapes=False, - experimental_compile=None, experimental_follow_type_hints=None): """Initializes a `Function`. @@ -477,10 +477,10 @@ class Function(object): name: the name given to it. input_signature: See the documentation for `tf.function`. autograph: See the documentation for `tf.function`. + jit_compile: See the documentation for `tf.function`. experimental_implements: See the documentation for `tf.function`. experimental_autograph_options: See the documentation for `tf.function`. experimental_relax_shapes: See the documentation for `tf.function`. - experimental_compile: See the documentation for `tf.function`. experimental_follow_type_hints: See the documentation for `tf.function`. Raises: @@ -492,7 +492,7 @@ class Function(object): self._function_spec = function_lib.FunctionSpec.from_function_and_signature( python_function, input_signature, - experimental_compile=experimental_compile, + jit_compile=jit_compile, experimental_follow_type_hints=experimental_follow_type_hints, ) self._implements = experimental_implements @@ -503,7 +503,7 @@ class Function(object): self._autograph = autograph self._experimental_autograph_options = experimental_autograph_options self._experimental_relax_shapes = experimental_relax_shapes - self._experimental_compile = experimental_compile + self._jit_compile = jit_compile if experimental_follow_type_hints is None: experimental_follow_type_hints = False self._experimental_follow_type_hints = experimental_follow_type_hints @@ -558,7 +558,7 @@ class Function(object): """Creates a defun wrapped inside a variable creator scope.""" weak_wrapped_fn = None - compile_with_xla = self._experimental_compile + compile_with_xla = self._jit_compile def wrapped_fn(*args, **kwds): """Wraps `self._python_function` in a variable creator scope.""" @@ -629,9 +629,9 @@ class Function(object): if share is not None: attributes[function_lib.SHARED_RENDEZVOUS_ATTRIBUTE_NAME] = share - if self._experimental_compile is not None: - attributes.update(_XlaMustCompile=bool(self._experimental_compile)) - if self._experimental_compile: + if self._jit_compile is not None: + attributes.update(_XlaMustCompile=bool(self._jit_compile)) + if self._jit_compile: attributes.update(_noinline=True) if not attributes: attributes = None @@ -640,8 +640,8 @@ class Function(object): input_signature=self.input_signature, attributes=attributes, autograph=self._autograph, + jit_compile=self._jit_compile, experimental_autograph_options=self._experimental_autograph_options, - experimental_compile=self._experimental_compile, experimental_follow_type_hints=self._experimental_follow_type_hints, experimental_relax_shapes=self._experimental_relax_shapes) @@ -698,10 +698,10 @@ class Function(object): name=self._name, input_signature=self._input_signature, autograph=self._autograph, + jit_compile=self._jit_compile, experimental_implements=self._implements, experimental_autograph_options=self._experimental_autograph_options, experimental_relax_shapes=self._experimental_relax_shapes, - experimental_compile=self._experimental_compile, experimental_follow_type_hints=self._experimental_follow_type_hints) if self._shared_rendezvous: @@ -782,7 +782,7 @@ class Function(object): tracing_count = self.experimental_get_tracing_count() with trace.Trace(self._name) as tm: result = self._call(*args, **kwds) - compiler = "xla" if self._experimental_compile else "nonXla" + compiler = "xla" if self._jit_compile else "nonXla" new_tracing_count = self.experimental_get_tracing_count() without_tracing = (tracing_count == new_tracing_count) execution_mode = "notTraced" if without_tracing else "traced" @@ -934,7 +934,7 @@ class Function(object): For example, for ```python - @tf.function(experimental_compile=True) + @tf.function(jit_compile=True) def f(x): return x + 1 @@ -962,14 +962,13 @@ class Function(object): Raises: ValueError: If an invalid `stage` is selected or if applied to a function - which is not compiled (`experimental_compile=True` is not set). + which is not compiled (`jit_compile=True` is not set). TypeError: When called with input in graph mode. """ context.ensure_initialized() - if not self._experimental_compile: - raise ValueError( - "Compiler IR can only be returned for functions marked with " - "experimental_compile=True") + if not self._jit_compile: + raise ValueError("Compiler IR can only be returned for functions marked " + "with 'jit_compile=True'") concrete_fn = self.get_concrete_function(*args, **kwargs) fn_name = concrete_fn.name @@ -1285,9 +1284,13 @@ class Function(object): @tf_export("function") +@deprecation.deprecated_args(None, + "experimental_compile is deprecated, use " + "jit_compile instead", "experimental_compile") def function(func=None, input_signature=None, autograph=True, + jit_compile=None, experimental_implements=None, experimental_autograph_options=None, experimental_relax_shapes=False, @@ -1497,6 +1500,20 @@ def function(func=None, graph. Data-dependent control flow requires `autograph=True`. For more information, see the [tf.function and AutoGraph guide]( https://www.tensorflow.org/guide/function). + jit_compile: If `True`, compiles the function using + [XLA](https://tensorflow.org/xla). XLA performs compiler optimizations, + such as fusion, and attempts to emit more efficient code. This may + drastically improve the performance. If set to `True`, + the whole function needs to be compilable by XLA, or an + `errors.InvalidArgumentError` is thrown. + If `None` (default), compiles the function with XLA when running on TPU + and goes through the regular function execution path when running on + other devices. + If `False`, executes the function without XLA compilation. Set this value + to `False` when directly running a multi-device function on TPUs (e.g. two + TPU cores, one TPU core and its host CPU). + Not all functions are compilable, see a list of + [sharp corners](https://tensorflow.org/xla/known_issues). experimental_implements: If provided, contains a name of a "known" function this implements. For example "mycompany.my_recurrent_cell". This is stored as an attribute in inference function, @@ -1519,22 +1536,7 @@ def function(func=None, `tf.autograph.experimental.Feature` values. experimental_relax_shapes: When True, `tf.function` may generate fewer, graphs that are less specialized on input shapes. - experimental_compile: If `True`, compiles the function using XLA - (see https://tensorflow.org/xla). XLA performs compiler optimizations, - such as fusion, and attempts to emit more efficient code. This may - drastically improve the performance. If set to `True`, - the whole function needs to be compilable by XLA, or an - `errors.InvalidArgumentError` is thrown. - If `None` (default), compiles the function with XLA when running on TPU - and goes through the regular function execution path when running on - other devices. - If `False`, executes the function in a regular way (graph rewrite - passes are applied, kernels are dispatched one-by-one by the TensorFlow - executor). Set this value to `False` when directly running a - multi-device function on TPUs (e.g. two TPU cores, one TPU core and its - host CPU). - Not all functions are compilable, see - https://tensorflow.org/xla/known_issues for a list of sharp corners. + experimental_compile: Deprecated alias to 'jit_compile'. experimental_follow_type_hints: When True, the function may use type annotations from `func` to optimize the tracing performance. For example, arguments annotated with `tf.Tensor` will automatically be converted @@ -1547,8 +1549,8 @@ def function(func=None, `func` argument, returns a callable equivalent to the case above. Raises: - ValueError when attempting to use experimental_compile, but XLA support is - not enabled. + ValueError when attempting to use jit_compile=True, but XLA support is not + linked. """ # TODO(mdan): Link to `tf.types` section once published. if input_signature is not None: @@ -1571,7 +1573,14 @@ def function(func=None, autograph=autograph, experimental_autograph_options=experimental_autograph_options, experimental_relax_shapes=experimental_relax_shapes, - experimental_compile=experimental_compile, + + # TODO(b/171825496): Update once `experimental_compile` is removed + # entirely in favor of 'jit_compile'. + jit_compile=deprecation.deprecated_argument_lookup( + "jit_compile", + jit_compile, + "experimental_compile", + experimental_compile), experimental_implements=experimental_implements, experimental_follow_type_hints=experimental_follow_type_hints)) diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py index 42af94c6cb1..a30885392b4 100644 --- a/tensorflow/python/eager/def_function_test.py +++ b/tensorflow/python/eager/def_function_test.py @@ -683,7 +683,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): experimental_implements=implements, experimental_autograph_options=autograph_options, experimental_relax_shapes=relax_shapes, - experimental_compile=compile_) + jit_compile=compile_) if override_function: cloned_py_function = lambda x: x + 1 @@ -699,7 +699,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): self.assertEqual(implements, cloned._implements) self.assertEqual(autograph_options, cloned._experimental_autograph_options) self.assertEqual(relax_shapes, cloned._experimental_relax_shapes) - self.assertEqual(compile_, cloned._experimental_compile) + self.assertEqual(compile_, cloned._jit_compile) # This test does not run with XLA JIT support linked in so we can only check # the output of the function if compile is disabled. diff --git a/tensorflow/python/eager/def_function_test_cpu_only.py b/tensorflow/python/eager/def_function_test_cpu_only.py index 8f54845fa41..f6cdf0a209c 100644 --- a/tensorflow/python/eager/def_function_test_cpu_only.py +++ b/tensorflow/python/eager/def_function_test_cpu_only.py @@ -27,20 +27,20 @@ from tensorflow.python.platform import test class DefFunctionCpuOnlyTest(test.TestCase, parameterized.TestCase): - """Test that experimental_compile=True correctly throws an exception if XLA is not available. + """Test that jit_compile=True correctly throws an exception if XLA is not available. This test should only be run without `--config=cuda`, as that implicitly links in XLA JIT. """ - def testExperimentalCompileRaisesExceptionWhenXlaIsUnsupported(self): + def testJitCompileRaisesExceptionWhenXlaIsUnsupported(self): if test.is_built_with_rocm() or test_util.is_xla_enabled(): return with self.assertRaisesRegex(errors.UnimplementedError, 'check target linkage'): - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def fn(x): return x + x diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py index ed1085d8b54..5820bec31be 100644 --- a/tensorflow/python/eager/def_function_xla_jit_test.py +++ b/tensorflow/python/eager/def_function_xla_jit_test.py @@ -45,11 +45,11 @@ class DefFunctionTest(xla_test.XLATestCase): with ops.device('device:{}:0'.format(self.device)): - @def_function.function(experimental_compile=False) + @def_function.function(jit_compile=False) def outer(a, b, c): return a * inner(b, c) + c - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def inner(b, c): return b + c * b @@ -71,8 +71,8 @@ class DefFunctionTest(xla_test.XLATestCase): def fn(x, a): return x + a - func = def_function.function(fn, experimental_compile=False) - xla_func = def_function.function(fn, experimental_compile=True) + func = def_function.function(fn, jit_compile=False) + xla_func = def_function.function(fn, jit_compile=True) inputs = constant_op.constant([1, 2, 2, 3, 3]) self.assertAllClose([2, 3, 3, 4, 4], func(inputs, 1)) @@ -81,7 +81,7 @@ class DefFunctionTest(xla_test.XLATestCase): def testBasicInt32(self): with ops.device('device:{}:0'.format(self.device)): - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def fn(x, a): return x + a @@ -94,7 +94,7 @@ class DefFunctionTest(xla_test.XLATestCase): def fn(x, a): return 2 * x + a - xla_func = def_function.function(fn, experimental_compile=True) + xla_func = def_function.function(fn, jit_compile=True) with backprop.GradientTape() as tape: inputs = constant_op.constant([1., 2., 2., 3., 3.]) @@ -112,19 +112,19 @@ class DefFunctionTest(xla_test.XLATestCase): self.assertTrue(backward.function_def.attr['_XlaMustCompile']) self.assertTrue(forward.definition.attr['_XlaMustCompile']) - # Calling function with experimental_compile=True from - # experimental_compile=False should compile the inner func. + # Calling function with jit_compile=True from + # jit_compile=False should compile the inner func. def testNestedCall(self): if 'tpu' in self.device.lower(): self.skipTest('b/162800687: Inner function runs on host') with ops.device('device:{}:0'.format(self.device)): - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def fn(x, a): return x + a - @def_function.function(experimental_compile=False) + @def_function.function(jit_compile=False) def fn2(x, a): return fn(x, a) @@ -139,12 +139,12 @@ class DefFunctionTest(xla_test.XLATestCase): def fn(x): return array_ops.unique(x).y - xla_func = def_function.function(fn, experimental_compile=True) + xla_func = def_function.function(fn, jit_compile=True) def fn2(x): return xla_func(x) - func = def_function.function(fn2, experimental_compile=False) + func = def_function.function(fn2, jit_compile=False) inputs = constant_op.constant([1, 2, 2, 3, 3]) with self.assertRaisesRegex(errors.InvalidArgumentError, 'not compilable'): @@ -158,8 +158,8 @@ class DefFunctionTest(xla_test.XLATestCase): def fn(x): return array_ops.unique(x).y # Unique is not supported by XLA - func = def_function.function(fn, experimental_compile=False) - xla_func = def_function.function(fn, experimental_compile=True) + func = def_function.function(fn, jit_compile=False) + xla_func = def_function.function(fn, jit_compile=True) inputs = constant_op.constant([1, 2, 2, 3, 3]) self.assertAllClose([1, 2, 3], func(inputs)) @@ -174,8 +174,8 @@ class DefFunctionTest(xla_test.XLATestCase): def fn(x): return v * x - func = def_function.function(fn, experimental_compile=False) - xla_func = def_function.function(fn, experimental_compile=True) + func = def_function.function(fn, jit_compile=False) + xla_func = def_function.function(fn, jit_compile=True) def run_and_check(test_func): x = constant_op.constant(3.0) @@ -195,7 +195,7 @@ class DefFunctionTest(xla_test.XLATestCase): with ops.device('device:{}:0'.format(self.device)): - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def f(x): assert control_flow_util.GraphOrParentsInXlaContext( ops.get_default_graph()) @@ -210,7 +210,7 @@ class DefFunctionTest(xla_test.XLATestCase): body, (constant_op.constant(0), constant_op.constant(3.)), maximum_iterations=10)[1] - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def g(x): x = ops.convert_to_tensor(x) with backprop.GradientTape() as tape: @@ -232,7 +232,7 @@ class DefFunctionTest(xla_test.XLATestCase): class C(object): - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def f1(self, x, a): return x + a @@ -248,7 +248,7 @@ class DefFunctionTest(xla_test.XLATestCase): class C(object): - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def f1(self, x): return array_ops.unique(x).y @@ -264,11 +264,11 @@ class DefFunctionTest(xla_test.XLATestCase): with ops.device('device:{}:0'.format(self.device)): - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def f(): return constant_op.constant([0, 2, 1], dtype=dtypes.int32) - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def g(a, b): return array_ops.transpose(a, b) @@ -283,11 +283,11 @@ class DefFunctionTest(xla_test.XLATestCase): def testArgMinMax(self): with ops.device('device:{}:0'.format(self.device)): - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def argmax(x): return math_ops.argmax(x) - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def argmin(x): return math_ops.argmin(x) @@ -300,7 +300,7 @@ class DefFunctionTest(xla_test.XLATestCase): def testErrorMessagePassingTensorArray(self): with ops.device('device:{}:0'.format(self.device)): - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def f(x): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, size=1, element_shape=[]) @@ -328,7 +328,7 @@ class DefFunctionTest(xla_test.XLATestCase): ta = ta.write(1, 3 * x) return ta.concat() - compiled_f = def_function.function(experimental_compile=True)(f) + compiled_f = def_function.function(jit_compile=True)(f) inputs = constant_op.constant([3.14, 2.68, 7.69]) @@ -348,7 +348,7 @@ class DefFunctionTest(xla_test.XLATestCase): ta = ta.write(1, 3 * x) return ta.concat() - compiled_f = def_function.function(experimental_compile=True)(f) + compiled_f = def_function.function(jit_compile=True)(f) inputs = constant_op.constant([[3.14, 21.1], [2.68, 22.2], [7.69, 23.3]]) self.assertAllClose(f(inputs), compiled_f(inputs)) @@ -365,7 +365,7 @@ class DefFunctionTest(xla_test.XLATestCase): ta = ta.write(1, 3 * x) return ta.concat() - compiled_f = def_function.function(experimental_compile=True)(f) + compiled_f = def_function.function(jit_compile=True)(f) inputs = constant_op.constant([3.14]) self.assertAllClose(f(inputs), compiled_f(inputs)) @@ -388,7 +388,7 @@ class DefFunctionTest(xla_test.XLATestCase): y = f(x) return tape.gradient(y, x) - compiled_g = def_function.function(experimental_compile=True)(g) + compiled_g = def_function.function(jit_compile=True)(g) self.assertAllClose([5.0, 5.0, 5.0], g()) self.assertAllClose(compiled_g(), g()) @@ -398,7 +398,7 @@ class DefFunctionTest(xla_test.XLATestCase): def testTensorListConcatGradNestedCompile(self): with ops.device('device:{}:0'.format(self.device)): - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def f(x): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, size=2, element_shape=[3]) @@ -406,7 +406,7 @@ class DefFunctionTest(xla_test.XLATestCase): ta = ta.write(1, 3 * x) return ta.concat() - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def g(): x = constant_op.constant([3.14, 2.68, 7.69]) with backprop.GradientTape() as tape: @@ -423,7 +423,7 @@ class DefFunctionTest(xla_test.XLATestCase): with ops.device('device:{}:0'.format(self.device)): - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def f(x): return math_ops.cumsum(x) @@ -434,7 +434,7 @@ class DefFunctionTest(xla_test.XLATestCase): with ops.device('device:{}:0'.format(self.device)): inner_retracings = 0 - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def inner(a, b): nonlocal inner_retracings inner_retracings += 1 @@ -455,7 +455,7 @@ class DefFunctionTest(xla_test.XLATestCase): on_gpu = 'gpu' in self.device.lower() v = variables.Variable([3.1, 3.2]) - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def update_var(a, b): v.assign_add(a * b) @@ -476,7 +476,7 @@ class DefFunctionTest(xla_test.XLATestCase): class C(object): - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def update_var(self, a, b): if not hasattr(self, 'v'): self.v = variables.Variable(3.1) @@ -497,7 +497,7 @@ class DefFunctionTest(xla_test.XLATestCase): with ops.device('device:{}:0'.format(self.device)): v = variables.Variable(3.1) - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def update_var(a, b): v.assign_add(a * b) return a * b + v @@ -509,7 +509,7 @@ class DefFunctionTest(xla_test.XLATestCase): def testReturnIdentity(self): with ops.device('device:{}:0'.format(self.device)): - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def f(a, b): return (a, b) @@ -532,7 +532,7 @@ class DefFunctionTest(xla_test.XLATestCase): with ops.device('device:{}:0'.format(self.device)): - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def f(a, b): return array_ops.transpose(a, b) @@ -549,7 +549,7 @@ class DefFunctionTest(xla_test.XLATestCase): v = variables.Variable([3.1, 3.2]) - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def f(a, b): v.assign_add(a * b) @@ -568,17 +568,17 @@ class DefFunctionTest(xla_test.XLATestCase): a = random_ops.random_normal([10, 10]) with self.assertRaisesRegex(ValueError, - 'marked with experimental_compile'): + 'marked with \'jit_compile'): f.experimental_get_compiler_ir(a)() def testGetCompilerIrNested(self): with ops.device('device:{}:0'.format(self.device)): - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def fn(x, a): return x + a - @def_function.function(experimental_compile=False) + @def_function.function(jit_compile=False) def fn2(x, a): fn.experimental_get_compiler_ir(x, a)() return fn(x, a) @@ -592,7 +592,7 @@ class DefFunctionTest(xla_test.XLATestCase): v = variables.Variable([0.1, 0.1]) - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def f(a, b): return (a + b) * v @@ -605,7 +605,7 @@ class DefFunctionTest(xla_test.XLATestCase): def testGetCompilerIrDot(self): with ops.device('device:{}:0'.format(self.device)): - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def f(a, b): return a + b @@ -620,7 +620,7 @@ class DefFunctionTest(xla_test.XLATestCase): if 'gpu' not in self.device.lower(): self.skipTest('Testing get_compiler_ir on GPUs without placement') - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def f(a, b): return a + b @@ -634,7 +634,7 @@ class DefFunctionTest(xla_test.XLATestCase): def testGetCompilerIrNonTensors(self): with ops.device('device:{}:0'.format(self.device)): - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def f(l): return l[0] + l[1] @@ -649,7 +649,7 @@ class DefFunctionTest(xla_test.XLATestCase): s = random_ops.random_uniform([2], 1, 10, dtypes.int32) l = random_ops.random_normal([s[0] * s[1]]) - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def f(l): return array_ops.reshape(l, s) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index d92440e9594..09cf404828d 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -2349,7 +2349,7 @@ class FunctionSpec(object): input_signature, is_pure=False, experimental_follow_type_hints=False, - experimental_compile=None): + jit_compile=None): """Create a FunctionSpec instance given a python function and signature. Args: @@ -2358,7 +2358,7 @@ class FunctionSpec(object): is_pure: if True all input arguments (including variables and constants) will be converted to tensors and no variable changes allowed. experimental_follow_type_hints: see `tf.function` - experimental_compile: see `tf.function` + jit_compile: see `tf.function` Returns: instance of FunctionSpec @@ -2440,7 +2440,7 @@ class FunctionSpec(object): is_method, input_signature, is_pure=is_pure, - experimental_compile=experimental_compile, + jit_compile=jit_compile, experimental_follow_type_hints=experimental_follow_type_hints, name=name) @@ -2451,7 +2451,7 @@ class FunctionSpec(object): is_pure=False, experimental_follow_type_hints=False, name=None, - experimental_compile=None): + jit_compile=None): """Constructs a FunctionSpec describing a python function. Args: @@ -2462,12 +2462,12 @@ class FunctionSpec(object): will be converted to tensors and no variable changes allowed. experimental_follow_type_hints: see `tf.function`. name: Name of the function - experimental_compile: see `tf.function`. + jit_compile: see `tf.function`. """ self._fullargspec = fullargspec self._is_method = is_method self._is_pure = is_pure - self._experimental_compile = experimental_compile + self._jit_compile = jit_compile self._experimental_follow_type_hints = experimental_follow_type_hints # TODO(edloper): Include name when serializing for SavedModel? @@ -2538,8 +2538,8 @@ class FunctionSpec(object): return self._is_pure @property - def experimental_compile(self): - return self._experimental_compile + def jit_compile(self): + return self._jit_compile @property def arg_names(self): @@ -2904,7 +2904,7 @@ class Function(object): autograph_options=None, experimental_relax_shapes=False, capture_by_value=None, - experimental_compile=None, + jit_compile=None, experimental_follow_type_hints=False): """Initializes a `Function`. @@ -2927,8 +2927,8 @@ class Function(object): capture_by_value: Experimental. Whether to capture resource variables by value or reference. If None, will inherit from a parent context or default to False. - experimental_compile: Force-compile the function with XLA, cf. - def_function.Function doc on experimental_compile. + jit_compile: Force-compile the function with XLA, cf. + def_function.Function doc on jit_compile. experimental_follow_type_hints: See the documentation for `tf.function`. Raises: @@ -2959,7 +2959,7 @@ class Function(object): # `Function`, used to make sure defun-decorated methods create different # functions for each instance. self._descriptor_cache = weakref.WeakKeyDictionary() - self._experimental_compile = experimental_compile + self._jit_compile = jit_compile self._experimental_follow_type_hints = experimental_follow_type_hints def __call__(self, *args, **kwargs): @@ -3774,7 +3774,7 @@ def defun_with_attributes(func=None, attributes=None, autograph=True, experimental_autograph_options=None, - experimental_compile=None, + jit_compile=None, experimental_relax_shapes=False, experimental_follow_type_hints=False): """Compiles a Python function into a callable TensorFlow graph. @@ -3796,7 +3796,7 @@ def defun_with_attributes(func=None, autograph: same as defun()'s autograph. experimental_autograph_options: same as defun()'s experimental_autograph_options. - experimental_compile: same as defun()'s experimental_compile. + jit_compile: same as defun()'s jit_compile. experimental_relax_shapes: same as defun()'s experimental_relax_shapes experimental_follow_type_hints: see `tf.function`. @@ -3825,7 +3825,7 @@ def defun_with_attributes(func=None, attributes=attributes, autograph=autograph, autograph_options=experimental_autograph_options, - experimental_compile=experimental_compile, + jit_compile=jit_compile, experimental_relax_shapes=experimental_relax_shapes, experimental_follow_type_hints=experimental_follow_type_hints)) @@ -3925,7 +3925,7 @@ def class_method_to_instance_method(original_function, instance): autograph=original_function._autograph, input_signature=original_function.input_signature, experimental_relax_shapes=original_function._experimental_relax_shapes, - experimental_compile=original_function._experimental_compile) + jit_compile=original_function._jit_compile) # pylint: enable=protected-access # And we wrap the function with tf_decorator so inspection works correctly diff --git a/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test.py b/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test.py index 59ae60257dd..ebc343158c0 100644 --- a/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test.py +++ b/tensorflow/python/keras/benchmarks/layer_benchmarks/layer_benchmarks_test.py @@ -92,7 +92,7 @@ class KerasLayerBenchmarks(six.with_metaclass( layer = layer_cls(**layer_args) x = tf.ones(input_shape) layer.call = tf.function( - layer.call, experimental_compile=True) + layer.call, jit_compile=True) fn = functools.partial(layer, x) name = _get_benchmark_name(self._get_name()) @@ -156,7 +156,7 @@ class KerasLayerBenchmarksBackwardXLA(six.with_metaclass( layer = layer_cls(**layer_args) x = tf.ones(input_shape) layer.call = tf.function( - layer.call, experimental_compile=True) + layer.call, jit_compile=True) fn = functools.partial(_layer_call_backward, layer, x) name = _get_benchmark_name(self._get_name()) diff --git a/tensorflow/python/keras/distribute/custom_training_loop_models_test.py b/tensorflow/python/keras/distribute/custom_training_loop_models_test.py index e66b174b3aa..4fad158e5ae 100644 --- a/tensorflow/python/keras/distribute/custom_training_loop_models_test.py +++ b/tensorflow/python/keras/distribute/custom_training_loop_models_test.py @@ -419,7 +419,7 @@ class KerasModelsXLATest(test.TestCase, parameterized.TestCase): @ds_combinations.generate( combinations.combine( distribution=strategy_combinations.tpu_strategies, mode=["eager"])) - def test_tf_function_experimental_compile(self, distribution): + def test_tf_function_jit_compile(self, distribution): dataset = _get_dataset() input_iterator = iter(distribution.experimental_distribute_dataset(dataset)) @@ -433,7 +433,7 @@ class KerasModelsXLATest(test.TestCase, parameterized.TestCase): self.kernel = self.add_variable( "kernel", shape=[int(input_shape[-1]), self.num_outputs]) - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def call(self, inputs): return math_ops.matmul(inputs, self.kernel) diff --git a/tensorflow/python/keras/optimizer_v2/adam.py b/tensorflow/python/keras/optimizer_v2/adam.py index e4896fd167e..a1d8b703c3f 100644 --- a/tensorflow/python/keras/optimizer_v2/adam.py +++ b/tensorflow/python/keras/optimizer_v2/adam.py @@ -416,7 +416,7 @@ class NonFusedAdam(optimizer_v2.OptimizerV2): weights = weights[:len(params)] super(NonFusedAdam, self).set_weights(weights) - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def _resource_apply_dense(self, grad, var, apply_state=None): var_device, var_dtype = var.device, var.dtype.base_dtype coefficients = ((apply_state or {}).get((var_device, var_dtype)) or @@ -437,7 +437,7 @@ class NonFusedAdam(optimizer_v2.OptimizerV2): var.assign_sub( (m * alpha) / (math_ops.sqrt(v) - coefficients['epsilon'])) - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def _resource_apply_sparse(self, grad, var, indices, apply_state=None): var_device, var_dtype = var.device, var.dtype.base_dtype coefficients = ((apply_state or {}).get((var_device, var_dtype)) or diff --git a/tensorflow/python/ops/collective_ops_xla_test.py b/tensorflow/python/ops/collective_ops_xla_test.py index c7550c854e0..bdfe816bd0b 100644 --- a/tensorflow/python/ops/collective_ops_xla_test.py +++ b/tensorflow/python/ops/collective_ops_xla_test.py @@ -56,7 +56,7 @@ class CollectiveOpXlaTest(test.TestCase): tensor_val = [i + 1.] * tensor_size constant = constant_op.constant(tensor_val) - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def f(x): return 2 * x + 1 diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py index c3e33231465..0b6bbbd7223 100644 --- a/tensorflow/python/ops/control_flow_ops_test.py +++ b/tensorflow/python/ops/control_flow_ops_test.py @@ -1288,7 +1288,7 @@ class ExecuteFnForDeviceTest(test_util.TensorFlowTestCase): def gpu_fn(x): return x * x - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def flexible_defun(a): branches = {"CPU": lambda: cpu_fn(a), "GPU": lambda: gpu_fn(a)} return control_flow_ops.execute_fn_for_device(branches, lambda: cpu_fn(a)) diff --git a/tensorflow/python/ops/parallel_for/xla_control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/xla_control_flow_ops_test.py index 188df3f9b87..b709356af50 100644 --- a/tensorflow/python/ops/parallel_for/xla_control_flow_ops_test.py +++ b/tensorflow/python/ops/parallel_for/xla_control_flow_ops_test.py @@ -76,12 +76,12 @@ class PForTest(PForTestCase): vectorized_compute, inputs=[array_ops.ones((10, 5, 3))]) self.run_and_assert_equal(result, array_ops.ones((10, 1, 3))) - def test_function_experimental_compile(self): + def test_function_jit_compile(self): def compute(x): return math_ops.reduce_mean(x, axis=0, keepdims=True) - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def vectorized_compute(x): return pfor_control_flow_ops.vectorized_map(compute, x) @@ -112,7 +112,7 @@ class PForTest(PForTestCase): def test_reduce_mean(self): x = random_ops.random_uniform([8, 3]) - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def f(): def loop_fn(i, pfor_config): @@ -172,7 +172,7 @@ class WhileV2Test(PForTestCase): # TODO(agarwal): The following may complain about uncompilable nodes. Hence # these are currently not enabled for all tests. if force_xla: - out_exp_compile_f = def_function.function(experimental_compile=True)(f)() + out_exp_compile_f = def_function.function(jit_compile=True)(f)() self.run_and_assert_equal(out, out_exp_compile_f) out_xla_compile_f = xla.compile(f, inputs=[]) self.run_and_assert_equal(out, out_xla_compile_f) diff --git a/tensorflow/python/ops/stateful_random_ops_test.py b/tensorflow/python/ops/stateful_random_ops_test.py index 756ead401b4..cc82e5058e0 100644 --- a/tensorflow/python/ops/stateful_random_ops_test.py +++ b/tensorflow/python/ops/stateful_random_ops_test.py @@ -592,7 +592,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase): random.set_global_generator(None) - @def_function.function(experimental_compile=True) + @def_function.function(jit_compile=True) def make_seed(): generator = random.get_global_generator() state = array_ops.identity(generator.state, name="state") diff --git a/tensorflow/python/saved_model/function_deserialization.py b/tensorflow/python/saved_model/function_deserialization.py index 7e3e3d3b32b..368fbd105f6 100644 --- a/tensorflow/python/saved_model/function_deserialization.py +++ b/tensorflow/python/saved_model/function_deserialization.py @@ -143,17 +143,17 @@ def _deserialize_function_spec_as_nonmethod(function_spec_proto, coder): annotations=typeless_fullargspec.annotations) input_signature = coder.decode_proto(function_spec_proto.input_signature) - # See `tf.function` and the ExperimentalCompile proto for details. - experimental_compile = { - saved_object_graph_pb2.FunctionSpec.ExperimentalCompile.DEFAULT: None, - saved_object_graph_pb2.FunctionSpec.ExperimentalCompile.ON: True, - saved_object_graph_pb2.FunctionSpec.ExperimentalCompile.OFF: False, - }.get(function_spec_proto.experimental_compile) + # See `tf.function` and the JitCompile proto for details. + jit_compile = { + saved_object_graph_pb2.FunctionSpec.JitCompile.DEFAULT: None, + saved_object_graph_pb2.FunctionSpec.JitCompile.ON: True, + saved_object_graph_pb2.FunctionSpec.JitCompile.OFF: False, + }.get(function_spec_proto.jit_compile) return function_lib.FunctionSpec(fullargspec=fullargspec, is_method=False, input_signature=input_signature, - experimental_compile=experimental_compile) + jit_compile=jit_compile) # TODO(allenl): The fact that we can't derive ConcreteFunction calling @@ -191,7 +191,7 @@ class RestoredFunction(def_function.Function): # TODO(mdan): We may enable autograph once exceptions are supported. super(RestoredFunction, self).__init__( python_function, name, autograph=False, - experimental_compile=function_spec.experimental_compile) + jit_compile=function_spec.jit_compile) self.concrete_functions = concrete_functions self._function_spec = function_spec diff --git a/tensorflow/python/saved_model/function_serialization.py b/tensorflow/python/saved_model/function_serialization.py index ad18e8f5d2a..aaaead7c3fe 100644 --- a/tensorflow/python/saved_model/function_serialization.py +++ b/tensorflow/python/saved_model/function_serialization.py @@ -47,12 +47,12 @@ def _serialize_function_spec(function_spec, coder): proto.input_signature.CopyFrom( coder.encode_structure(function_spec.input_signature)) - # See `tf.function` and the ExperimentalCompile proto for details. - proto.experimental_compile = { - None: saved_object_graph_pb2.FunctionSpec.ExperimentalCompile.DEFAULT, - True: saved_object_graph_pb2.FunctionSpec.ExperimentalCompile.ON, - False: saved_object_graph_pb2.FunctionSpec.ExperimentalCompile.OFF, - }.get(function_spec.experimental_compile) + # See `tf.function` and the JitCompile proto for details. + proto.jit_compile = { + None: saved_object_graph_pb2.FunctionSpec.JitCompile.DEFAULT, + True: saved_object_graph_pb2.FunctionSpec.JitCompile.ON, + False: saved_object_graph_pb2.FunctionSpec.JitCompile.OFF, + }.get(function_spec.jit_compile) return proto diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py index d58e6c1239e..6d8d7520c2f 100644 --- a/tensorflow/python/saved_model/load_test.py +++ b/tensorflow/python/saved_model/load_test.py @@ -931,13 +931,13 @@ class LoadTest(test.TestCase, parameterized.TestCase): self.assertAllEqual([2, 4, 6], imported.f(constant_op.constant([1, 2, 3])).numpy()) - def test_experimental_compile(self, cycles): + def test_jit_compile(self, cycles): # It'd be nice to use parameterize here, but the library does not support # having parameterized test methods inside already-parameterized classes. - for experimental_compile in (None, True, False): + for jit_compile in (None, True, False): - @def_function.function(experimental_compile=experimental_compile) + @def_function.function(jit_compile=jit_compile) def f(x): return x + 1. @@ -948,7 +948,7 @@ class LoadTest(test.TestCase, parameterized.TestCase): imported = cycle(root, cycles) - self.assertEqual(imported.f._experimental_compile, experimental_compile) + self.assertEqual(imported.f._jit_compile, jit_compile) def test_get_concrete_function(self, cycles): diff --git a/tensorflow/python/training/saving/functional_saver.py b/tensorflow/python/training/saving/functional_saver.py index 62b8b72ce2a..9511fdbaa05 100644 --- a/tensorflow/python/training/saving/functional_saver.py +++ b/tensorflow/python/training/saving/functional_saver.py @@ -292,7 +292,7 @@ class MultiDeviceSaver(object): # latest values of options like experimental_io_device. if context.executing_eagerly() and len(self._single_device_savers) > 1: # Explicitly place the identity op on the first device. - @def_function.function(experimental_compile=False) + @def_function.function(jit_compile=False) def tf_function_save(): save_fn() tf_function_save() @@ -328,7 +328,7 @@ class MultiDeviceSaver(object): # latest values of options like experimental_io_device. if context.executing_eagerly() and len(self._single_device_savers) > 1: first_device, _ = list(self._single_device_savers.items())[0] - @def_function.function(experimental_compile=False) + @def_function.function(jit_compile=False) def tf_function_restore(): restore_ops = restore_fn() restore_tensors = {} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index abd33957365..610e6b5275f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -1342,7 +1342,7 @@ tf_module { } member_method { name: "function" - argspec: "args=[\'func\', \'input_signature\', \'autograph\', \'experimental_implements\', \'experimental_autograph_options\', \'experimental_relax_shapes\', \'experimental_compile\', \'experimental_follow_type_hints\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'None\', \'None\', \'False\', \'None\', \'None\'], " + argspec: "args=[\'func\', \'input_signature\', \'autograph\', \'jit_compile\', \'experimental_implements\', \'experimental_autograph_options\', \'experimental_relax_shapes\', \'experimental_compile\', \'experimental_follow_type_hints\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\'], " } member_method { name: "gather" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt index 0aa5e8924a3..fa80942edad 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -674,7 +674,7 @@ tf_module { } member_method { name: "function" - argspec: "args=[\'func\', \'input_signature\', \'autograph\', \'experimental_implements\', \'experimental_autograph_options\', \'experimental_relax_shapes\', \'experimental_compile\', \'experimental_follow_type_hints\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'None\', \'None\', \'False\', \'None\', \'None\'], " + argspec: "args=[\'func\', \'input_signature\', \'autograph\', \'jit_compile\', \'experimental_implements\', \'experimental_autograph_options\', \'experimental_relax_shapes\', \'experimental_compile\', \'experimental_follow_type_hints\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\'], " } member_method { name: "gather"