[TF2XLA] [NFC] Clarify config.optimizer.set_jit and config.optimizer.get_jit semantics
Clarify the setting value vs. `@tf.function(jit_compile=True)` PiperOrigin-RevId: 354636613 Change-Id: I7afd41be1764a17c1cb68710c1f31a2554295754
This commit is contained in:
parent
3f5b1c389d
commit
fcb2fc1ec3
tensorflow/python/framework
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from typing import Union
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.util import _pywrap_tensor_float_32_execution
|
||||
from tensorflow.python.util import deprecation
|
||||
@ -145,31 +147,42 @@ def set_inter_op_parallelism_threads(num_threads):
|
||||
|
||||
|
||||
@tf_export('config.optimizer.get_jit')
|
||||
def get_optimizer_jit():
|
||||
"""Get if JIT compilation is enabled.
|
||||
def get_optimizer_jit() -> str:
|
||||
"""Returns JIT compilation configuration for code inside `tf.function`.
|
||||
|
||||
Note that optimizations are only applied to code that is compiled into a
|
||||
graph. In eager mode, which is the TF2 API default, that means only code that
|
||||
is defined under a tf.function decorator.
|
||||
|
||||
Returns:
|
||||
If JIT compilation is enabled.
|
||||
Possible return values:
|
||||
-`"autoclustering"` if
|
||||
[autoclustering](https://www.tensorflow.org/xla#auto-clustering) is enabled
|
||||
- `""` when no default compilation is applied.
|
||||
"""
|
||||
return context.context().optimizer_jit
|
||||
if context.context().optimizer_jit:
|
||||
return 'autoclustering'
|
||||
return ''
|
||||
|
||||
|
||||
@tf_export('config.optimizer.set_jit')
|
||||
def set_optimizer_jit(enabled):
|
||||
"""Set if JIT compilation is enabled.
|
||||
@deprecation.deprecated_arg_values(
|
||||
None,
|
||||
'`True` setting is deprecated, use `autoclustering` instead.',
|
||||
warn_once=True,
|
||||
jit_config=True)
|
||||
def set_optimizer_jit(enabled: Union[bool, str]):
|
||||
"""Configure JIT compilation.
|
||||
|
||||
Note that optimizations are only applied to code that is compiled into a
|
||||
graph. In eager mode, which is the TF2 API default, that means only code that
|
||||
is defined under a tf.function decorator.
|
||||
Note: compilation is only applied to code that is compiled into a
|
||||
graph (in TF2 that's only a code inside `tf.function`).
|
||||
|
||||
Args:
|
||||
enabled: Whether to enable JIT compilation.
|
||||
enabled: JIT compilation configuration.
|
||||
Possible values:
|
||||
- `"autoclustering"` (`True` is a deprecated alias): perform
|
||||
[autoclustering](https://www.tensorflow.org/xla#auto-clustering)
|
||||
(automatically identify and compile clusters of nodes) on all graphs using
|
||||
[XLA](https://www.tensorflow.org/xla).
|
||||
- `False`: do not automatically compile any graphs.
|
||||
"""
|
||||
context.context().optimizer_jit = enabled
|
||||
autoclustering_enabled = enabled in (True, 'autoclustering')
|
||||
context.context().optimizer_jit = autoclustering_enabled
|
||||
|
||||
|
||||
@tf_export('config.optimizer.get_experimental_options')
|
||||
|
@ -251,7 +251,7 @@ class ConfigTest(test.TestCase, parameterized.TestCase):
|
||||
@test_util.run_gpu_only
|
||||
@reset_eager
|
||||
def testJit(self):
|
||||
self.assertEqual(config.get_optimizer_jit(), False)
|
||||
self.assertEqual(config.get_optimizer_jit(), '')
|
||||
|
||||
# the following function should cause Op fusion to occur. However, there is
|
||||
# unfortunately no straightforward way to ensure this. We will just have to
|
||||
@ -267,17 +267,13 @@ class ConfigTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
self.evaluate(fun(a, b))
|
||||
|
||||
config.set_optimizer_jit(True)
|
||||
self.assertEqual(config.get_optimizer_jit(), True)
|
||||
self.assertEqual(config.get_optimizer_jit(),
|
||||
context.context().optimizer_jit)
|
||||
config.set_optimizer_jit('autoclustering')
|
||||
self.assertEqual(config.get_optimizer_jit(), 'autoclustering')
|
||||
|
||||
self.evaluate(fun(a, b))
|
||||
|
||||
config.set_optimizer_jit(False)
|
||||
self.assertEqual(config.get_optimizer_jit(), False)
|
||||
self.assertEqual(config.get_optimizer_jit(),
|
||||
context.context().optimizer_jit)
|
||||
config.set_optimizer_jit('')
|
||||
self.assertEqual(config.get_optimizer_jit(), '')
|
||||
|
||||
self.evaluate(fun(a, b))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user