[TF2XLA] [NFC] Clarify config.optimizer.set_jit and config.optimizer.get_jit semantics
Clarify the setting value vs. `@tf.function(jit_compile=True)`.

PiperOrigin-RevId: 354636613
Change-Id: I7afd41be1764a17c1cb68710c1f31a2554295754

This commit is contained in:
  parent 3f5b1c389d
  commit fcb2fc1ec3
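In short, the setting changed here turns XLA autoclustering on for all graphs globally, whereas `@tf.function(jit_compile=True)` requests XLA compilation of one specific function. A minimal sketch of the contrast (the function body and names are illustrative, not part of this commit):

    import tensorflow as tf

    # Global knob: ask TF to autocluster every graph it builds.
    tf.config.optimizer.set_jit('autoclustering')  # `True` is a deprecated alias

    # Per-function knob: compile just this function with XLA,
    # independent of the global setting.
    @tf.function(jit_compile=True)
    def dense(x, w):
        return tf.nn.relu(tf.matmul(x, w))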
tensorflow/python/framework/config.py:

@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from typing import Union
+
 from tensorflow.python.eager import context
 from tensorflow.python.util import _pywrap_tensor_float_32_execution
 from tensorflow.python.util import deprecation
@@ -145,31 +147,42 @@ def set_inter_op_parallelism_threads(num_threads):
 
 
 @tf_export('config.optimizer.get_jit')
-def get_optimizer_jit():
-  """Get if JIT compilation is enabled.
+def get_optimizer_jit() -> str:
+  """Returns JIT compilation configuration for code inside `tf.function`.
 
-  Note that optimizations are only applied to code that is compiled into a
-  graph. In eager mode, which is the TF2 API default, that means only code that
-  is defined under a tf.function decorator.
-
-  Returns:
-    If JIT compilation is enabled.
-  """
-  return context.context().optimizer_jit
+  Possible return values:
+   - `"autoclustering"` if
+     [autoclustering](https://www.tensorflow.org/xla#auto-clustering) is enabled
+   - `""` when no default compilation is applied.
+  """
+  if context.context().optimizer_jit:
+    return 'autoclustering'
+  return ''
 
 
 @tf_export('config.optimizer.set_jit')
-def set_optimizer_jit(enabled):
-  """Set if JIT compilation is enabled.
+@deprecation.deprecated_arg_values(
+    None,
+    '`True` setting is deprecated, use `autoclustering` instead.',
+    warn_once=True,
+    jit_config=True)
+def set_optimizer_jit(enabled: Union[bool, str]):
+  """Configure JIT compilation.
 
-  Note that optimizations are only applied to code that is compiled into a
-  graph. In eager mode, which is the TF2 API default, that means only code that
-  is defined under a tf.function decorator.
+  Note: compilation is only applied to code that is compiled into a
+  graph (in TF2 that's only code inside `tf.function`).
 
   Args:
-    enabled: Whether to enable JIT compilation.
+    enabled: JIT compilation configuration.
+      Possible values:
+       - `"autoclustering"` (`True` is a deprecated alias): perform
+         [autoclustering](https://www.tensorflow.org/xla#auto-clustering)
+         (automatically identify and compile clusters of nodes) on all graphs
+         using [XLA](https://www.tensorflow.org/xla).
+       - `False`: do not automatically compile any graphs.
   """
-  context.context().optimizer_jit = enabled
+  autoclustering_enabled = enabled in (True, 'autoclustering')
+  context.context().optimizer_jit = autoclustering_enabled
 
 
 @tf_export('config.optimizer.get_experimental_options')
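As a usage sketch of the semantics introduced above (not part of the commit; the values in the comments follow directly from the diff):

    import tensorflow as tf

    tf.config.optimizer.set_jit('autoclustering')  # preferred spelling
    print(tf.config.optimizer.get_jit())           # 'autoclustering'

    tf.config.optimizer.set_jit(True)              # deprecated alias; warns once
    print(tf.config.optimizer.get_jit())           # 'autoclustering'

    tf.config.optimizer.set_jit(False)             # disable autoclustering
    print(tf.config.optimizer.get_jit())           # ''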
tensorflow/python/framework/config_test.py:

@@ -251,7 +251,7 @@ class ConfigTest(test.TestCase, parameterized.TestCase):
   @test_util.run_gpu_only
   @reset_eager
   def testJit(self):
-    self.assertEqual(config.get_optimizer_jit(), False)
+    self.assertEqual(config.get_optimizer_jit(), '')
 
     # the following function should cause Op fusion to occur. However, there is
     # unfortunately no straightforward way to ensure this. We will just have to
@@ -267,17 +267,13 @@ class ConfigTest(test.TestCase, parameterized.TestCase):
 
     self.evaluate(fun(a, b))
 
-    config.set_optimizer_jit(True)
-    self.assertEqual(config.get_optimizer_jit(), True)
-    self.assertEqual(config.get_optimizer_jit(),
-                     context.context().optimizer_jit)
+    config.set_optimizer_jit('autoclustering')
+    self.assertEqual(config.get_optimizer_jit(), 'autoclustering')
 
     self.evaluate(fun(a, b))
 
-    config.set_optimizer_jit(False)
-    self.assertEqual(config.get_optimizer_jit(), False)
-    self.assertEqual(config.get_optimizer_jit(),
-                     context.context().optimizer_jit)
+    config.set_optimizer_jit('')
+    self.assertEqual(config.get_optimizer_jit(), '')
 
     self.evaluate(fun(a, b))
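The value mapping this test exercises can be summarized as a small truth table; a standalone sketch (outside the test harness, using only the public API shown in the diff):

    import tensorflow as tf

    # Both `True` and 'autoclustering' enable autoclustering; `False` and ''
    # disable it, per `enabled in (True, 'autoclustering')` in the setter.
    for value, expected in [(True, 'autoclustering'),
                            ('autoclustering', 'autoclustering'),
                            (False, ''),
                            ('', '')]:
        tf.config.optimizer.set_jit(value)
        assert tf.config.optimizer.get_jit() == expected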