Add experimental.enable_mlir_bridge to tf.config module
PiperOrigin-RevId: 272092574
This commit is contained in:
parent
19156f6e40
commit
1be7d54642
tensorflow
python
tools/api/golden
@ -413,6 +413,7 @@ class Context(object):
|
||||
self._inter_op_parallelism_threads = None
|
||||
self._soft_device_placement = None
|
||||
self._log_device_placement = None
|
||||
self._enable_mlir_bridge = None
|
||||
self._optimizer_experimental_options = {}
|
||||
|
||||
_python_eager_context_create_counter.get_cell().increase_by(1)
|
||||
@ -807,6 +808,9 @@ class Context(object):
|
||||
if self._log_device_placement is not None:
|
||||
config.log_device_placement = self._log_device_placement
|
||||
|
||||
if self._enable_mlir_bridge is not None:
|
||||
config.experimental.enable_mlir_bridge = self._enable_mlir_bridge
|
||||
|
||||
def rewriter_toggle(option):
|
||||
toggle = self._optimizer_experimental_options.get(option, None)
|
||||
if toggle is None:
|
||||
@ -1258,6 +1262,16 @@ class Context(object):
|
||||
|
||||
self._virtual_device_map[dev] = virtual_devices
|
||||
|
||||
@property
|
||||
def enable_mlir_bridge(self):
|
||||
return self._enable_mlir_bridge
|
||||
|
||||
@enable_mlir_bridge.setter
|
||||
def enable_mlir_bridge(self, enabled):
|
||||
self._enable_mlir_bridge = enabled
|
||||
|
||||
self._thread_local_data.function_call_options = None
|
||||
|
||||
@property
|
||||
def optimizer_jit(self):
|
||||
level = self.config.graph_options.optimizer_options.global_jit_level
|
||||
|
@ -571,3 +571,26 @@ def set_virtual_device_configuration(device, virtual_devices):
|
||||
RuntimeError: Runtime is already initialized.
|
||||
"""
|
||||
context.context().set_virtual_device_configuration(device, virtual_devices)
|
||||
|
||||
|
||||
@tf_export('config.experimental.enable_mlir_bridge')
|
||||
def enable_mlir_bridge():
|
||||
"""Enables experimental MLIR-Based TensorFlow Compiler Bridge.
|
||||
|
||||
DO NOT USE, DEV AND TESTING ONLY AT THE MOMENT.
|
||||
|
||||
NOTE: MLIR-Based TensorFlow Compiler is under active development and has
|
||||
missing features, please refrain from using. This API exists for development
|
||||
and testing only.
|
||||
|
||||
TensorFlow Compiler Bridge (TF Bridge) is responsible for translating parts
|
||||
of TensorFlow graph into a form that can be accepted as an input by a backend
|
||||
compiler such as XLA.
|
||||
"""
|
||||
context.context().enable_mlir_bridge = True
|
||||
|
||||
|
||||
@tf_export('config.experimental.disable_mlir_bridge')
|
||||
def disable_mlir_bridge():
|
||||
"""Disables experimental MLIR-Based TensorFlow Compiler Bridge."""
|
||||
context.context().enable_mlir_bridge = False
|
||||
|
@ -209,6 +209,19 @@ class ConfigTest(test.TestCase, parameterized.TestCase):
|
||||
# exception.
|
||||
context.set_log_device_placement(False)
|
||||
|
||||
@reset_eager
|
||||
def testEnableMlirBridge(self):
|
||||
# Default value of enable_mlir_bridge is false.
|
||||
self.assertFalse(context.context().config.experimental.enable_mlir_bridge)
|
||||
|
||||
# Tests enabling mlir bridge.
|
||||
config.enable_mlir_bridge()
|
||||
self.assertTrue(context.context().config.experimental.enable_mlir_bridge)
|
||||
|
||||
# Tests disabling mlir bridge.
|
||||
config.disable_mlir_bridge()
|
||||
self.assertFalse(context.context().config.experimental.enable_mlir_bridge)
|
||||
|
||||
@test_util.run_gpu_only
|
||||
@reset_eager
|
||||
def testJit(self):
|
||||
|
@ -4,6 +4,14 @@ tf_module {
|
||||
name: "VirtualDeviceConfiguration"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "disable_mlir_bridge"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "enable_mlir_bridge"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_device_policy"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
|
@ -4,6 +4,14 @@ tf_module {
|
||||
name: "VirtualDeviceConfiguration"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "disable_mlir_bridge"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "enable_mlir_bridge"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_device_policy"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
|
Loading…
Reference in New Issue
Block a user