Define helper function for effective MlirBridgeRollout enum

This decides the effective rollout based on ConfigProto and flags. Planning to use this from _TpuCompileMlir op to fallback to the old bridge if safe mode is enabled.

PiperOrigin-RevId: 360405556
Change-Id: I885e0a7cab3555946a8f2d39e9640ec28d9d653f
This commit is contained in:
Smit Hinsu 2021-03-02 05:23:34 -08:00 committed by TensorFlower Gardener
parent e305de9662
commit bf6f03809e
3 changed files with 40 additions and 0 deletions
tensorflow/compiler/jit

View File

@ -343,6 +343,7 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/base",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)
@ -356,7 +357,9 @@ cc_library(
"//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core/protobuf:for_core_protos_cc",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)

View File

@ -292,6 +292,37 @@ MlirCommonFlags* GetMlirCommonFlags() {
return mlir_flags;
}
ConfigProto::Experimental::MlirBridgeRollout GetMlirBridgeRolloutState(
absl::optional<const ConfigProto> config_proto) {
// TF1 graphs that do not override Sessions's ConfigProto and TF2 graphs
// can enable/disable the graph via tf_mlir_enable_mlir_bridge.
auto tf_mlir_enable_mlir_bridge =
GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
if (tf_mlir_enable_mlir_bridge !=
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED) {
return tf_mlir_enable_mlir_bridge;
}
// If a ConfigProto was not passed in, we can assume the caller is
// checking if TF2 graph should have the bridge enabled / disabled. In that
// case, we have already checked tf_mlir_enable_mlir_bridge so it is safe to
// return UNSPECIFIED here.
if (!config_proto.has_value()) {
return ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED;
}
// TF1 graphs that do override Session's ConfigProto and set
// ConfigProto's enable_mlir_bridge or mlir_bridge_rollout fields will not
// update tf_mlir_enable_mlir_bridge so check their values.
// ConfigProto's enable_mlir_bridge defaults to false so only respect it
// when it is true.
if (config_proto.value().experimental().enable_mlir_bridge()) {
return ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
}
return config_proto.value().experimental().mlir_bridge_rollout();
}
void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
absl::call_once(flags_init, &AllocateAndParseFlags);
AppendMarkForCompilationPassFlagsInternal(flag_list);

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
#include "absl/types/optional.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/command_line_flags.h"
@ -156,6 +157,11 @@ GetIntroduceFloatingPointJitterPassFlags();
MlirCommonFlags* GetMlirCommonFlags();
// Returns the effective MLIR bridge rollout state based on the flags and the
// optional configuration.
ConfigProto::Experimental::MlirBridgeRollout GetMlirBridgeRolloutState(
absl::optional<const ConfigProto> config_proto);
// Appends the flag definitions associated with
// MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`.
//