From 1914c410b274055078cfb97c1f8db5d98c5f9146 Mon Sep 17 00:00:00 2001 From: Marissa Ikonomidis <marissaw@google.com> Date: Tue, 20 Oct 2020 13:20:19 -0700 Subject: [PATCH] Update context and tfe_wrapper to support mlir_bridge_rollout Update eager/context.py and tfe_wrapper to support returning the real value of mlir_bridge_rollout (enabled/disabled/unspecified) instead of a bool. This gives users a clearer signal of whether or not the mlir bridge is being used. At the moment, the mlir bridge is only enabled when mlir_bridge_rollout is set to enabled but this will change in the future. PiperOrigin-RevId: 338124102 Change-Id: I5c93cbdd2815a698e6b41244db8eed716f4988e6 --- tensorflow/compiler/tf2xla/mlir_bridge_pass.h | 8 +++++++- tensorflow/python/eager/context.py | 7 ++++++- tensorflow/python/framework/config_test.py | 9 +++++++++ tensorflow/python/tfe_wrapper.cc | 6 ++++-- 4 files changed, 26 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h index bbddeb6a967..2f08a80e975 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h @@ -31,6 +31,9 @@ class MlirBridgePass : public MlirOptimizationPass { bool IsEnabled(const ConfigProto& config_proto) const override { return config_proto.experimental().enable_mlir_bridge() || + config_proto.experimental().mlir_bridge_rollout() == + tensorflow::ConfigProto::Experimental:: + MLIR_BRIDGE_ROLLOUT_ENABLED || tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge == tensorflow::ConfigProto::Experimental:: MLIR_BRIDGE_ROLLOUT_ENABLED; @@ -50,7 +53,10 @@ class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass { bool IsEnabled(const ConfigProto& config_proto) const override { return config_proto.experimental().enable_mlir_bridge() || - GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge == + config_proto.experimental().mlir_bridge_rollout() == + tensorflow::ConfigProto::Experimental:: + MLIR_BRIDGE_ROLLOUT_ENABLED || + tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge == tensorflow::ConfigProto::Experimental:: MLIR_BRIDGE_ROLLOUT_ENABLED; } diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index a15e37a9151..026dbce321d 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -948,7 +948,12 @@ class Context(object): if self._log_device_placement is not None: config.log_device_placement = self._log_device_placement - config.experimental.enable_mlir_bridge = pywrap_tfe.TF_IsMlirBridgeEnabled() + is_mlir_bridge_enabled = pywrap_tfe.TF_IsMlirBridgeEnabled() + config.experimental.mlir_bridge_rollout = is_mlir_bridge_enabled + if (is_mlir_bridge_enabled == + config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_ENABLED): + config.experimental.enable_mlir_bridge = True + if self._enable_mlir_graph_optimization is not None: config.experimental.enable_mlir_graph_optimization = ( self._enable_mlir_graph_optimization) diff --git a/tensorflow/python/framework/config_test.py b/tensorflow/python/framework/config_test.py index a20af802824..7dd26425037 100644 --- a/tensorflow/python/framework/config_test.py +++ b/tensorflow/python/framework/config_test.py @@ -214,14 +214,23 @@ class ConfigTest(test.TestCase, parameterized.TestCase): def testEnableMlirBridge(self): # Default value of enable_mlir_bridge is false. self.assertFalse(context.context().config.experimental.enable_mlir_bridge) + self.assertEqual( + context.context().config.experimental.mlir_bridge_rollout, + config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_UNSPECIFIED) # Tests enabling mlir bridge. config.enable_mlir_bridge() self.assertTrue(context.context().config.experimental.enable_mlir_bridge) + self.assertEqual( + context.context().config.experimental.mlir_bridge_rollout, + config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_ENABLED) # Tests disabling mlir bridge. config.disable_mlir_bridge() self.assertFalse(context.context().config.experimental.enable_mlir_bridge) + self.assertEqual( + context.context().config.experimental.mlir_bridge_rollout, + config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_DISABLED) @reset_eager def testEnableMlirGraphOptimization(self): diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index e8e66bcef28..980695f28bb 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -580,8 +580,10 @@ PYBIND11_MODULE(_pywrap_tfe, m) { // MLIR Logic m.def("TF_IsMlirBridgeEnabled", [] { - return tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge == - tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED; + // Since python protobuf enums are integers, cast to an integer before + // returning the enum to python. + return static_cast<int32_t>( + tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge); }); m.def("TF_EnableMlirBridge", [](bool enabled) { tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge =