[TF2XLA] [NFC] Clarify config.optimizer.set_jit and config.optimizer.get_jit semantics

Clarify the setting value vs. `@tf.function(jit_compile=True)`

PiperOrigin-RevId: 354636613
Change-Id: I7afd41be1764a17c1cb68710c1f31a2554295754
This commit is contained in:
George Karpenkov 2021-01-29 16:45:46 -08:00 committed by TensorFlower Gardener
parent 3f5b1c389d
commit fcb2fc1ec3
2 changed files with 34 additions and 25 deletions
tensorflow/python/framework

View File

@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Union
from tensorflow.python.eager import context
from tensorflow.python.util import _pywrap_tensor_float_32_execution
from tensorflow.python.util import deprecation
@@ -145,31 +147,42 @@ def set_inter_op_parallelism_threads(num_threads):
@tf_export('config.optimizer.get_jit')
def get_optimizer_jit():
"""Get if JIT compilation is enabled.
def get_optimizer_jit() -> str:
"""Returns JIT compilation configuration for code inside `tf.function`.
Note that optimizations are only applied to code that is compiled into a
graph. In eager mode, which is the TF2 API default, that means only code that
is defined under a tf.function decorator.
Returns:
If JIT compilation is enabled.
Possible return values:
- `"autoclustering"` if
[autoclustering](https://www.tensorflow.org/xla#auto-clustering) is enabled.
- `""` when no default compilation is applied.
"""
return context.context().optimizer_jit
if context.context().optimizer_jit:
return 'autoclustering'
return ''
@tf_export('config.optimizer.set_jit')
def set_optimizer_jit(enabled):
"""Set if JIT compilation is enabled.
@deprecation.deprecated_arg_values(
None,
'`True` setting is deprecated, use `autoclustering` instead.',
warn_once=True,
jit_config=True)
def set_optimizer_jit(enabled: Union[bool, str]):
"""Configure JIT compilation.
Note that optimizations are only applied to code that is compiled into a
graph. In eager mode, which is the TF2 API default, that means only code that
is defined under a tf.function decorator.
Note: compilation is only applied to code that is compiled into a
graph (in TF2 that's only code inside `tf.function`).
Args:
enabled: Whether to enable JIT compilation.
enabled: JIT compilation configuration.
Possible values:
- `"autoclustering"` (`True` is a deprecated alias): perform
[autoclustering](https://www.tensorflow.org/xla#auto-clustering)
(automatically identify and compile clusters of nodes) on all graphs using
[XLA](https://www.tensorflow.org/xla).
- `False`: do not automatically compile any graphs.
"""
context.context().optimizer_jit = enabled
autoclustering_enabled = enabled in (True, 'autoclustering')
context.context().optimizer_jit = autoclustering_enabled
@tf_export('config.optimizer.get_experimental_options')

View File

@@ -251,7 +251,7 @@ class ConfigTest(test.TestCase, parameterized.TestCase):
@test_util.run_gpu_only
@reset_eager
def testJit(self):
self.assertEqual(config.get_optimizer_jit(), False)
self.assertEqual(config.get_optimizer_jit(), '')
# the following function should cause Op fusion to occur. However, there is
# unfortunately no straightforward way to ensure this. We will just have to
@@ -267,17 +267,13 @@ class ConfigTest(test.TestCase, parameterized.TestCase):
self.evaluate(fun(a, b))
config.set_optimizer_jit(True)
self.assertEqual(config.get_optimizer_jit(), True)
self.assertEqual(config.get_optimizer_jit(),
context.context().optimizer_jit)
config.set_optimizer_jit('autoclustering')
self.assertEqual(config.get_optimizer_jit(), 'autoclustering')
self.evaluate(fun(a, b))
config.set_optimizer_jit(False)
self.assertEqual(config.get_optimizer_jit(), False)
self.assertEqual(config.get_optimizer_jit(),
context.context().optimizer_jit)
config.set_optimizer_jit('')
self.assertEqual(config.get_optimizer_jit(), '')
self.evaluate(fun(a, b))