diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 310383270db..57a19359a48 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -413,6 +413,7 @@ class Context(object): self._inter_op_parallelism_threads = None self._soft_device_placement = None self._log_device_placement = None + self._enable_mlir_bridge = None self._optimizer_experimental_options = {} _python_eager_context_create_counter.get_cell().increase_by(1) @@ -807,6 +808,9 @@ class Context(object): if self._log_device_placement is not None: config.log_device_placement = self._log_device_placement + if self._enable_mlir_bridge is not None: + config.experimental.enable_mlir_bridge = self._enable_mlir_bridge + def rewriter_toggle(option): toggle = self._optimizer_experimental_options.get(option, None) if toggle is None: @@ -1258,6 +1262,16 @@ class Context(object): self._virtual_device_map[dev] = virtual_devices + @property + def enable_mlir_bridge(self): + return self._enable_mlir_bridge + + @enable_mlir_bridge.setter + def enable_mlir_bridge(self, enabled): + self._enable_mlir_bridge = enabled + + self._thread_local_data.function_call_options = None + @property def optimizer_jit(self): level = self.config.graph_options.optimizer_options.global_jit_level diff --git a/tensorflow/python/framework/config.py b/tensorflow/python/framework/config.py index 44f687beda6..2f17c1f3fbe 100644 --- a/tensorflow/python/framework/config.py +++ b/tensorflow/python/framework/config.py @@ -571,3 +571,26 @@ def set_virtual_device_configuration(device, virtual_devices): RuntimeError: Runtime is already initialized. """ context.context().set_virtual_device_configuration(device, virtual_devices) + + +@tf_export('config.experimental.enable_mlir_bridge') +def enable_mlir_bridge(): + """Enables experimental MLIR-Based TensorFlow Compiler Bridge. + + DO NOT USE, DEV AND TESTING ONLY AT THE MOMENT. + + NOTE: MLIR-Based TensorFlow Compiler is under active development and has + missing features, please refrain from using. This API exists for development + and testing only. + + TensorFlow Compiler Bridge (TF Bridge) is responsible for translating parts + of TensorFlow graph into a form that can be accepted as an input by a backend + compiler such as XLA. + """ + context.context().enable_mlir_bridge = True + + +@tf_export('config.experimental.disable_mlir_bridge') +def disable_mlir_bridge(): + """Disables experimental MLIR-Based TensorFlow Compiler Bridge.""" + context.context().enable_mlir_bridge = False diff --git a/tensorflow/python/framework/config_test.py b/tensorflow/python/framework/config_test.py index 5d3b300190e..24250f90100 100644 --- a/tensorflow/python/framework/config_test.py +++ b/tensorflow/python/framework/config_test.py @@ -209,6 +209,19 @@ class ConfigTest(test.TestCase, parameterized.TestCase): # exception. context.set_log_device_placement(False) + @reset_eager + def testEnableMlirBridge(self): + # Default value of enable_mlir_bridge is false. + self.assertFalse(context.context().config.experimental.enable_mlir_bridge) + + # Tests enabling mlir bridge. + config.enable_mlir_bridge() + self.assertTrue(context.context().config.experimental.enable_mlir_bridge) + + # Tests disabling mlir bridge. + config.disable_mlir_bridge() + self.assertFalse(context.context().config.experimental.enable_mlir_bridge) + @test_util.run_gpu_only @reset_eager def testJit(self): diff --git a/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt index e567a662a34..b5cfaadcccc 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt @@ -4,6 +4,14 @@ tf_module { name: "VirtualDeviceConfiguration" mtype: "" } + member_method { + name: "disable_mlir_bridge" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "enable_mlir_bridge" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } member_method { name: "get_device_policy" argspec: "args=[], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt index e567a662a34..b5cfaadcccc 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt @@ -4,6 +4,14 @@ tf_module { name: "VirtualDeviceConfiguration" mtype: "" } + member_method { + name: "disable_mlir_bridge" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "enable_mlir_bridge" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } member_method { name: "get_device_policy" argspec: "args=[], varargs=None, keywords=None, defaults=None"