From 1914c410b274055078cfb97c1f8db5d98c5f9146 Mon Sep 17 00:00:00 2001
From: Marissa Ikonomidis <marissaw@google.com>
Date: Tue, 20 Oct 2020 13:20:19 -0700
Subject: [PATCH] Update context and tfe_wrapper to support mlir_bridge_rollout

Update eager/context.py and tfe_wrapper to support returning
the real value of mlir_bridge_rollout (enabled/disabled/unspecified)
instead of a bool. This gives users a clearer signal of whether
or not the mlir bridge is being used. At the moment, the mlir
bridge is only enabled when mlir_bridge_rollout is set to
enabled but this will change in the future.

PiperOrigin-RevId: 338124102
Change-Id: I5c93cbdd2815a698e6b41244db8eed716f4988e6
---
 tensorflow/compiler/tf2xla/mlir_bridge_pass.h | 8 +++++++-
 tensorflow/python/eager/context.py            | 7 ++++++-
 tensorflow/python/framework/config_test.py    | 9 +++++++++
 tensorflow/python/tfe_wrapper.cc              | 6 ++++--
 4 files changed, 26 insertions(+), 4 deletions(-)

diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h
index bbddeb6a967..2f08a80e975 100644
--- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h
+++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h
@@ -31,6 +31,9 @@ class MlirBridgePass : public MlirOptimizationPass {
 
   bool IsEnabled(const ConfigProto& config_proto) const override {
     return config_proto.experimental().enable_mlir_bridge() ||
+           config_proto.experimental().mlir_bridge_rollout() ==
+               tensorflow::ConfigProto::Experimental::
+                   MLIR_BRIDGE_ROLLOUT_ENABLED ||
            tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
                tensorflow::ConfigProto::Experimental::
                    MLIR_BRIDGE_ROLLOUT_ENABLED;
@@ -50,7 +53,10 @@ class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass {
 
   bool IsEnabled(const ConfigProto& config_proto) const override {
     return config_proto.experimental().enable_mlir_bridge() ||
-           GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
+           config_proto.experimental().mlir_bridge_rollout() ==
+               tensorflow::ConfigProto::Experimental::
+                   MLIR_BRIDGE_ROLLOUT_ENABLED ||
+           tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
                tensorflow::ConfigProto::Experimental::
                    MLIR_BRIDGE_ROLLOUT_ENABLED;
   }
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index a15e37a9151..026dbce321d 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -948,7 +948,12 @@ class Context(object):
     if self._log_device_placement is not None:
       config.log_device_placement = self._log_device_placement
 
-    config.experimental.enable_mlir_bridge = pywrap_tfe.TF_IsMlirBridgeEnabled()
+    is_mlir_bridge_enabled = pywrap_tfe.TF_IsMlirBridgeEnabled()
+    config.experimental.mlir_bridge_rollout = is_mlir_bridge_enabled
+    if (is_mlir_bridge_enabled ==
+        config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_ENABLED):
+      config.experimental.enable_mlir_bridge = True
+
     if self._enable_mlir_graph_optimization is not None:
       config.experimental.enable_mlir_graph_optimization = (
           self._enable_mlir_graph_optimization)
diff --git a/tensorflow/python/framework/config_test.py b/tensorflow/python/framework/config_test.py
index a20af802824..7dd26425037 100644
--- a/tensorflow/python/framework/config_test.py
+++ b/tensorflow/python/framework/config_test.py
@@ -214,14 +214,23 @@ class ConfigTest(test.TestCase, parameterized.TestCase):
   def testEnableMlirBridge(self):
     # Default value of enable_mlir_bridge is false.
     self.assertFalse(context.context().config.experimental.enable_mlir_bridge)
+    self.assertEqual(
+        context.context().config.experimental.mlir_bridge_rollout,
+        config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_UNSPECIFIED)
 
     # Tests enabling mlir bridge.
     config.enable_mlir_bridge()
     self.assertTrue(context.context().config.experimental.enable_mlir_bridge)
+    self.assertEqual(
+        context.context().config.experimental.mlir_bridge_rollout,
+        config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_ENABLED)
 
     # Tests disabling mlir bridge.
     config.disable_mlir_bridge()
     self.assertFalse(context.context().config.experimental.enable_mlir_bridge)
+    self.assertEqual(
+        context.context().config.experimental.mlir_bridge_rollout,
+        config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_DISABLED)
 
   @reset_eager
   def testEnableMlirGraphOptimization(self):
diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc
index e8e66bcef28..980695f28bb 100644
--- a/tensorflow/python/tfe_wrapper.cc
+++ b/tensorflow/python/tfe_wrapper.cc
@@ -580,8 +580,10 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
 
   // MLIR Logic
   m.def("TF_IsMlirBridgeEnabled", [] {
-    return tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
-           tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
+    // Since python protobuf enums are integers, cast to an integer before
+    // returning the enum to python.
+    return static_cast<int32_t>(
+        tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge);
   });
   m.def("TF_EnableMlirBridge", [](bool enabled) {
     tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge =