Update tf_mlir_enable_mlir_bridge support unspecified

The existing tf_mlir_enable_mlir_bridge flag allows models to
selectively enable or disable the MLIR bridge via TF_XLA_FLAGS. If
the flag is not set, it defaults to false.

In order to slowly and safely rollout the mlir_bridge, we will
need to distinguish between unspecified and forcibly disabled.
If the flag is unspecified, we can selectively choose when the
bridge is enabled. This will allow us to slowly ramp up the
number of models that use the new bridge.

This patch continues to support the existing TF_XLA_FLAGS
interface (tf_mlir_enable_mlir_bridge can be set to true or false)
but internally, TensorFlow can now distinguish between false
(forcibly disabled) and unset (unspecified).

PiperOrigin-RevId: 337523318
Change-Id: I8ebb49da104663e12e5c1fa6399a1bf79239a44f
This commit is contained in:
Marissa Ikonomidis 2020-10-16 09:46:18 -07:00 committed by TensorFlower Gardener
parent c68a4daebf
commit 0c4416e3c2
10 changed files with 78 additions and 29 deletions

View File

@ -283,6 +283,7 @@ cc_library(
"//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/base",
"@com_google_absl//absl/strings",
],

View File

@ -167,8 +167,8 @@ void AllocateAndParseFlags() {
jitter_flags = new IntroduceFloatingPointJitterPassFlags;
jitter_flags->jitter_amount = 1e-5;
mlir_flags = new MlirCommonFlags;
mlir_flags->tf_mlir_enable_mlir_bridge = false;
bool enable_mlir_bridge = false;
bool enable_mlir_bridge_flag_updated = false;
auto setter_for_jitter_tensor_names = [](string sequence) {
jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
@ -217,12 +217,24 @@ void AllocateAndParseFlags() {
"The amount of jitter to introduce. This amount is added to each "
"element in the tensors named in `tensor_names."),
Flag("tf_mlir_enable_mlir_bridge",
&mlir_flags->tf_mlir_enable_mlir_bridge,
"Enables experimental MLIR-Based TensorFlow Compiler Bridge.")});
Flag("tf_mlir_enable_mlir_bridge", &enable_mlir_bridge,
"Enables experimental MLIR-Based TensorFlow Compiler Bridge.",
&enable_mlir_bridge_flag_updated)});
AppendMarkForCompilationPassFlagsInternal(flag_list);
xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
mlir_flags = new MlirCommonFlags;
if (!enable_mlir_bridge_flag_updated) {
mlir_flags->tf_mlir_enable_mlir_bridge =
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED;
} else if (enable_mlir_bridge) {
mlir_flags->tf_mlir_enable_mlir_bridge =
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
} else {
mlir_flags->tf_mlir_enable_mlir_bridge =
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED;
}
}
} // namespace

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
@ -135,7 +136,7 @@ struct IntroduceFloatingPointJitterPassFlags {
// Flags for common MLIR configurations.
struct MlirCommonFlags {
bool tf_mlir_enable_mlir_bridge;
ConfigProto::Experimental::MlirBridgeRollout tf_mlir_enable_mlir_bridge;
};
// Return a pointer to the DumpGraphFlags struct;

View File

@ -89,7 +89,8 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
XlaOpRegistry::RegisterCompilationKernels();
// Only check for compilability if the MLIR bridge is not enabled.
if (!GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge !=
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>

View File

@ -31,7 +31,9 @@ class MlirBridgePass : public MlirOptimizationPass {
bool IsEnabled(const ConfigProto& config_proto) const override {
return config_proto.experimental().enable_mlir_bridge() ||
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
tensorflow::ConfigProto::Experimental::
MLIR_BRIDGE_ROLLOUT_ENABLED;
}
// This should be used as a thin mapper around mlir::ModulePass::runOnModule
@ -48,7 +50,9 @@ class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass {
bool IsEnabled(const ConfigProto& config_proto) const override {
return config_proto.experimental().enable_mlir_bridge() ||
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
tensorflow::ConfigProto::Experimental::
MLIR_BRIDGE_ROLLOUT_ENABLED;
}
// This should be used as a thin mapper around mlir::ModulePass::runOnModule

View File

@ -754,13 +754,15 @@ Status XlaCompiler::CompileFunction(
VLOG(1) << "====================================================";
#ifdef LIBTPU_ON_GCE
if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
VLOG(1) << "MLIR is not supported in this environment.";
}
TF_RETURN_IF_ERROR(
CompileGraph(options, function_id, std::move(graph), args, result));
#else
if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
VLOG(1) << "Using MLIR bridge";
GraphDebugInfo debug_info;
std::vector<std::string> control_rets;

View File

@ -132,51 +132,61 @@ bool ParseFloatFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
} // namespace
Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text)
Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text,
bool* dst_updated)
: name_(name),
type_(TYPE_INT32),
int32_hook_([dst](int32 value) {
int32_hook_([dst, dst_updated](int32 value) {
*dst = value;
if (dst_updated) *dst_updated = true;
return true;
}),
int32_default_for_display_(*dst),
usage_text_(usage_text) {}
Flag::Flag(const char* name, tensorflow::int64* dst, const string& usage_text)
Flag::Flag(const char* name, tensorflow::int64* dst, const string& usage_text,
bool* dst_updated)
: name_(name),
type_(TYPE_INT64),
int64_hook_([dst](int64 value) {
int64_hook_([dst, dst_updated](int64 value) {
*dst = value;
if (dst_updated) *dst_updated = true;
return true;
}),
int64_default_for_display_(*dst),
usage_text_(usage_text) {}
Flag::Flag(const char* name, float* dst, const string& usage_text)
Flag::Flag(const char* name, float* dst, const string& usage_text,
bool* dst_updated)
: name_(name),
type_(TYPE_FLOAT),
float_hook_([dst](float value) {
float_hook_([dst, dst_updated](float value) {
*dst = value;
if (dst_updated) *dst_updated = true;
return true;
}),
float_default_for_display_(*dst),
usage_text_(usage_text) {}
Flag::Flag(const char* name, bool* dst, const string& usage_text)
Flag::Flag(const char* name, bool* dst, const string& usage_text,
bool* dst_updated)
: name_(name),
type_(TYPE_BOOL),
bool_hook_([dst](bool value) {
bool_hook_([dst, dst_updated](bool value) {
*dst = value;
if (dst_updated) *dst_updated = true;
return true;
}),
bool_default_for_display_(*dst),
usage_text_(usage_text) {}
Flag::Flag(const char* name, string* dst, const string& usage_text)
Flag::Flag(const char* name, string* dst, const string& usage_text,
bool* dst_updated)
: name_(name),
type_(TYPE_STRING),
string_hook_([dst](string value) {
string_hook_([dst, dst_updated](string value) {
*dst = std::move(value);
if (dst_updated) *dst_updated = true;
return true;
}),
string_default_for_display_(*dst),

View File

@ -62,11 +62,16 @@ namespace tensorflow {
// text, and a pointer to the corresponding variable.
class Flag {
public:
Flag(const char* name, int32* dst, const string& usage_text);
Flag(const char* name, int64* dst, const string& usage_text);
Flag(const char* name, bool* dst, const string& usage_text);
Flag(const char* name, string* dst, const string& usage_text);
Flag(const char* name, float* dst, const string& usage_text);
Flag(const char* name, int32* dst, const string& usage_text,
bool* dst_updated = nullptr);
Flag(const char* name, int64* dst, const string& usage_text,
bool* dst_updated = nullptr);
Flag(const char* name, bool* dst, const string& usage_text,
bool* dst_updated = nullptr);
Flag(const char* name, string* dst, const string& usage_text,
bool* dst_updated = nullptr);
Flag(const char* name, float* dst, const string& usage_text,
bool* dst_updated = nullptr);
// These constructors invoke a hook on a match instead of writing to a
// specific memory location. The hook may return false to signal a malformed
@ -85,6 +90,8 @@ class Flag {
Flag(const char* name, std::function<bool(string)> string_hook,
string default_value_for_display, const string& usage_text);
bool is_default_initialized() const { return default_initialized_; }
private:
friend class Flags;
@ -115,6 +122,7 @@ class Flag {
string string_default_for_display_;
string usage_text_;
bool default_initialized_ = true;
};
class Flags {

View File

@ -108,7 +108,7 @@ except Exception: # pylint: disable=broad-except
# Uses the same mechanism as above to selectively enable/disable MLIR
# compilation.
def is_mlir_bridge_enabled():
return False
return None
try:
@ -2022,8 +2022,13 @@ class TensorFlowTestCase(googletest.TestCase):
# disable it here.
pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(True)
# Check if the mlir bridge has been explicitly enabled or disabled. If
# is_mlir_bridge_enabled() returns None, the user did not explicitly enable
# or disable the bridge so do not update enable_mlir_bridge.
if is_mlir_bridge_enabled():
context.context().enable_mlir_bridge = True
elif is_mlir_bridge_enabled() is not None:
context.context().enable_mlir_bridge = False
self._threads = []
self._tempdir = None

View File

@ -580,10 +580,15 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
// MLIR Logic
m.def("TF_IsMlirBridgeEnabled", [] {
return tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
return tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
});
m.def("TF_EnableMlirBridge", [](bool enabled) {
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge = enabled;
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge =
enabled
? tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED
: tensorflow::ConfigProto::Experimental::
MLIR_BRIDGE_ROLLOUT_DISABLED;
});
m.def("TF_EnableXlaDevices", [] {
tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;