diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 998b014264d..cfd0cdef07c 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -343,6 +343,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/base", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -356,7 +357,9 @@ cc_library( "//tensorflow/compiler/xla:parse_flags_from_env", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "//tensorflow/core/protobuf:for_core_protos_cc", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index 15bd5340503..9112b8d021b 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -292,6 +292,37 @@ MlirCommonFlags* GetMlirCommonFlags() { return mlir_flags; } +ConfigProto::Experimental::MlirBridgeRollout GetMlirBridgeRolloutState( + absl::optional<const ConfigProto> config_proto) { + // TF1 graphs that do not override Sessions's ConfigProto and TF2 graphs + // can enable/disable the graph via tf_mlir_enable_mlir_bridge. + auto tf_mlir_enable_mlir_bridge = + GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge; + if (tf_mlir_enable_mlir_bridge != + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED) { + return tf_mlir_enable_mlir_bridge; + } + + // If a ConfigProto was not passed in, we can assume the caller is + // checking if TF2 graph should have the bridge enabled / disabled. In that + // case, we have already checked tf_mlir_enable_mlir_bridge so it is safe to + // return UNSPECIFIED here. + if (!config_proto.has_value()) { + return ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED; + } + + // TF1 graphs that do override Session's ConfigProto and set + // ConfigProto's enable_mlir_bridge or mlir_bridge_rollout fields will not + // update tf_mlir_enable_mlir_bridge so check their values. + + // ConfigProto's enable_mlir_bridge defaults to false so only respect it + // when it is true. + if (config_proto.value().experimental().enable_mlir_bridge()) { + return ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED; + } + return config_proto.value().experimental().mlir_bridge_rollout(); +} + void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) { absl::call_once(flags_init, &AllocateAndParseFlags); AppendMarkForCompilationPassFlagsInternal(flag_list); diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index b54dcf942c7..ef4d89b2b56 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -18,6 +18,7 @@ limitations under the License. #include <vector> +#include "absl/types/optional.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/util/command_line_flags.h" @@ -156,6 +157,11 @@ GetIntroduceFloatingPointJitterPassFlags(); MlirCommonFlags* GetMlirCommonFlags(); +// Returns the effective MLIR bridge rollout state based on the flags and the +// optional configuration. +ConfigProto::Experimental::MlirBridgeRollout GetMlirBridgeRolloutState( + absl::optional<const ConfigProto> config_proto); + // Appends the flag definitions associated with // MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`. //