Support safe mode in the mlir bridge

Add plumbing to support enabling the mlir bridge on a per graph
basis based on the analysis of the features in the graph. If the
mlir bridge can support all of the features, run the mlir bridge.

PiperOrigin-RevId: 347701287
Change-Id: I2cb26194d0f858f59474952aa29db09ae67692cc
This commit is contained in:
Marissa Ikonomidis 2020-12-15 15:09:18 -08:00 committed by TensorFlower Gardener
parent 3c374ed73b
commit b14587ad60

View File

@ -177,6 +177,7 @@ void AllocateAndParseFlags() {
// bridge, on a per-graph basis).
bool enable_mlir_bridge = false;
bool enable_mlir_bridge_is_explicit = false;
bool mlir_bridge_safe_mode = false;
auto setter_for_jitter_tensor_names = [](string sequence) {
jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
@ -227,7 +228,13 @@ void AllocateAndParseFlags() {
Flag("tf_mlir_enable_mlir_bridge", &enable_mlir_bridge,
"Enables experimental MLIR-Based TensorFlow Compiler Bridge.",
&enable_mlir_bridge_is_explicit)});
&enable_mlir_bridge_is_explicit),
Flag(
"tf_mlir_bridge_safe_mode", &mlir_bridge_safe_mode,
"When tf_mlir_enable_mlir_bridge is true, this field can enable "
"the MLIR bridge's safe mode. When the MLIR bridge is in safe mode, "
"it only runs for graphs that use features MLIR bridge currently "
"supports.")});
AppendMarkForCompilationPassFlagsInternal(flag_list);
xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
@ -238,7 +245,9 @@ void AllocateAndParseFlags() {
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED;
} else if (enable_mlir_bridge) {
mlir_flags->tf_mlir_enable_mlir_bridge =
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
(mlir_bridge_safe_mode)
? ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_SAFE_MODE_ENABLED
: ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
} else {
mlir_flags->tf_mlir_enable_mlir_bridge =
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED;