Add experimental.enable_mlir_bridge to tf.config module

PiperOrigin-RevId: 272092574
This commit is contained in:
Yanan Cao 2019-09-30 16:32:39 -07:00 committed by TensorFlower Gardener
parent 19156f6e40
commit 1be7d54642
5 changed files with 66 additions and 0 deletions

View File

@ -413,6 +413,7 @@ class Context(object):
self._inter_op_parallelism_threads = None
self._soft_device_placement = None
self._log_device_placement = None
self._enable_mlir_bridge = None
self._optimizer_experimental_options = {}
_python_eager_context_create_counter.get_cell().increase_by(1)
@ -807,6 +808,9 @@ class Context(object):
if self._log_device_placement is not None:
config.log_device_placement = self._log_device_placement
if self._enable_mlir_bridge is not None:
config.experimental.enable_mlir_bridge = self._enable_mlir_bridge
def rewriter_toggle(option):
toggle = self._optimizer_experimental_options.get(option, None)
if toggle is None:
@ -1258,6 +1262,16 @@ class Context(object):
self._virtual_device_map[dev] = virtual_devices
@property
def enable_mlir_bridge(self):
return self._enable_mlir_bridge
@enable_mlir_bridge.setter
def enable_mlir_bridge(self, enabled):
self._enable_mlir_bridge = enabled
self._thread_local_data.function_call_options = None
@property
def optimizer_jit(self):
level = self.config.graph_options.optimizer_options.global_jit_level

View File

@ -571,3 +571,26 @@ def set_virtual_device_configuration(device, virtual_devices):
RuntimeError: Runtime is already initialized.
"""
context.context().set_virtual_device_configuration(device, virtual_devices)
@tf_export('config.experimental.enable_mlir_bridge')
def enable_mlir_bridge():
"""Enables experimental MLIR-Based TensorFlow Compiler Bridge.
DO NOT USE, DEV AND TESTING ONLY AT THE MOMENT.
NOTE: MLIR-Based TensorFlow Compiler is under active development and has
missing features, please refrain from using. This API exists for development
and testing only.
TensorFlow Compiler Bridge (TF Bridge) is responsible for translating parts
of TensorFlow graph into a form that can be accepted as an input by a backend
compiler such as XLA.
"""
context.context().enable_mlir_bridge = True
@tf_export('config.experimental.disable_mlir_bridge')
def disable_mlir_bridge():
"""Disables experimental MLIR-Based TensorFlow Compiler Bridge."""
context.context().enable_mlir_bridge = False

View File

@ -209,6 +209,19 @@ class ConfigTest(test.TestCase, parameterized.TestCase):
# exception.
context.set_log_device_placement(False)
@reset_eager
def testEnableMlirBridge(self):
# Default value of enable_mlir_bridge is false.
self.assertFalse(context.context().config.experimental.enable_mlir_bridge)
# Tests enabling mlir bridge.
config.enable_mlir_bridge()
self.assertTrue(context.context().config.experimental.enable_mlir_bridge)
# Tests disabling mlir bridge.
config.disable_mlir_bridge()
self.assertFalse(context.context().config.experimental.enable_mlir_bridge)
@test_util.run_gpu_only
@reset_eager
def testJit(self):

View File

@ -4,6 +4,14 @@ tf_module {
name: "VirtualDeviceConfiguration"
mtype: "<type \'type\'>"
}
member_method {
name: "disable_mlir_bridge"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "enable_mlir_bridge"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_device_policy"
argspec: "args=[], varargs=None, keywords=None, defaults=None"

View File

@ -4,6 +4,14 @@ tf_module {
name: "VirtualDeviceConfiguration"
mtype: "<type \'type\'>"
}
member_method {
name: "disable_mlir_bridge"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "enable_mlir_bridge"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_device_policy"
argspec: "args=[], varargs=None, keywords=None, defaults=None"