diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 998b014264d..cfd0cdef07c 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -343,6 +343,7 @@ cc_library(
         "//tensorflow/core:protos_all_cc",
         "@com_google_absl//absl/base",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:optional",
     ],
 )
 
@@ -356,7 +357,9 @@ cc_library(
         "//tensorflow/compiler/xla:parse_flags_from_env",
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
+        "//tensorflow/core/protobuf:for_core_protos_cc",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:optional",
     ],
 )
 
diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc
index 15bd5340503..9112b8d021b 100644
--- a/tensorflow/compiler/jit/flags.cc
+++ b/tensorflow/compiler/jit/flags.cc
@@ -292,6 +292,37 @@ MlirCommonFlags* GetMlirCommonFlags() {
   return mlir_flags;
 }
 
+ConfigProto::Experimental::MlirBridgeRollout GetMlirBridgeRolloutState(
+    absl::optional<const ConfigProto> config_proto) {
+  // TF1 graphs that do not override Sessions's ConfigProto and TF2 graphs
+  // can enable/disable the graph via tf_mlir_enable_mlir_bridge.
+  auto tf_mlir_enable_mlir_bridge =
+      GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
+  if (tf_mlir_enable_mlir_bridge !=
+      ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED) {
+    return tf_mlir_enable_mlir_bridge;
+  }
+
+  // If a ConfigProto was not passed in, we can assume the caller is
+  // checking if TF2 graph should have the bridge enabled / disabled. In that
+  // case, we have already checked tf_mlir_enable_mlir_bridge so it is safe to
+  // return UNSPECIFIED here.
+  if (!config_proto.has_value()) {
+    return ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED;
+  }
+
+  // TF1 graphs that do override Session's ConfigProto and set
+  // ConfigProto's enable_mlir_bridge or mlir_bridge_rollout fields will not
+  // update tf_mlir_enable_mlir_bridge so check their values.
+
+  // ConfigProto's enable_mlir_bridge defaults to false so only respect it
+  // when it is true.
+  if (config_proto.value().experimental().enable_mlir_bridge()) {
+    return ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
+  }
+  return config_proto.value().experimental().mlir_bridge_rollout();
+}
+
 void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
   absl::call_once(flags_init, &AllocateAndParseFlags);
   AppendMarkForCompilationPassFlagsInternal(flag_list);
diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h
index b54dcf942c7..ef4d89b2b56 100644
--- a/tensorflow/compiler/jit/flags.h
+++ b/tensorflow/compiler/jit/flags.h
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include <vector>
 
+#include "absl/types/optional.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/protobuf/config.pb.h"
 #include "tensorflow/core/util/command_line_flags.h"
@@ -156,6 +157,11 @@ GetIntroduceFloatingPointJitterPassFlags();
 
 MlirCommonFlags* GetMlirCommonFlags();
 
+// Returns the effective MLIR bridge rollout state based on the flags and the
+// optional configuration.
+ConfigProto::Experimental::MlirBridgeRollout GetMlirBridgeRolloutState(
+    absl::optional<const ConfigProto> config_proto);
+
 // Appends the flag definitions associated with
 // MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`.
 //