Define helper function for effective MlirBridgeRollout enum
This decides the effective rollout based on ConfigProto and flags. Planning to use this from _TpuCompileMlir op to fallback to the old bridge if safe mode is enabled. PiperOrigin-RevId: 360405556 Change-Id: I885e0a7cab3555946a8f2d39e9640ec28d9d653f
This commit is contained in:
parent
e305de9662
commit
bf6f03809e
@ -343,6 +343,7 @@ cc_library(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
@ -356,7 +357,9 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:parse_flags_from_env",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/protobuf:for_core_protos_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -292,6 +292,37 @@ MlirCommonFlags* GetMlirCommonFlags() {
|
||||
return mlir_flags;
|
||||
}
|
||||
|
||||
ConfigProto::Experimental::MlirBridgeRollout GetMlirBridgeRolloutState(
|
||||
absl::optional<const ConfigProto> config_proto) {
|
||||
// TF1 graphs that do not override Sessions's ConfigProto and TF2 graphs
|
||||
// can enable/disable the graph via tf_mlir_enable_mlir_bridge.
|
||||
auto tf_mlir_enable_mlir_bridge =
|
||||
GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
|
||||
if (tf_mlir_enable_mlir_bridge !=
|
||||
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED) {
|
||||
return tf_mlir_enable_mlir_bridge;
|
||||
}
|
||||
|
||||
// If a ConfigProto was not passed in, we can assume the caller is
|
||||
// checking if TF2 graph should have the bridge enabled / disabled. In that
|
||||
// case, we have already checked tf_mlir_enable_mlir_bridge so it is safe to
|
||||
// return UNSPECIFIED here.
|
||||
if (!config_proto.has_value()) {
|
||||
return ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED;
|
||||
}
|
||||
|
||||
// TF1 graphs that do override Session's ConfigProto and set
|
||||
// ConfigProto's enable_mlir_bridge or mlir_bridge_rollout fields will not
|
||||
// update tf_mlir_enable_mlir_bridge so check their values.
|
||||
|
||||
// ConfigProto's enable_mlir_bridge defaults to false so only respect it
|
||||
// when it is true.
|
||||
if (config_proto.value().experimental().enable_mlir_bridge()) {
|
||||
return ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
|
||||
}
|
||||
return config_proto.value().experimental().mlir_bridge_rollout();
|
||||
}
|
||||
|
||||
void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
|
||||
absl::call_once(flags_init, &AllocateAndParseFlags);
|
||||
AppendMarkForCompilationPassFlagsInternal(flag_list);
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
@ -156,6 +157,11 @@ GetIntroduceFloatingPointJitterPassFlags();
|
||||
|
||||
MlirCommonFlags* GetMlirCommonFlags();
|
||||
|
||||
// Returns the effective MLIR bridge rollout state based on the flags and the
|
||||
// optional configuration.
|
||||
ConfigProto::Experimental::MlirBridgeRollout GetMlirBridgeRolloutState(
|
||||
absl::optional<const ConfigProto> config_proto);
|
||||
|
||||
// Appends the flag definitions associated with
|
||||
// MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`.
|
||||
//
|
||||
|
Loading…
Reference in New Issue
Block a user