Update tf_mlir_enable_mlir_bridge support unspecified
The existing tf_mlir_enable_mlir_bridge flag allows models to selectively enable or disable the model via TF_XLA_FLAGS. If the flag is not set, it defaults to false. In order to slowly and safely rollout the mlir_bridge, we will need to distinguish between unspecified and forcibly disabled. If the flag is unspecified, we can selectively choose when the bridge is enabled. This will allow us to slowly ramp up the number of models that use the new bridge. This patch continues to support the existing TF_XLA_FLAG interface (tf_mlir_enable_mlir_bridge can be set to true or false) but internally, TensorFlow can now distinguish between false (forcibly disabled) and unset (unspecified). PiperOrigin-RevId: 337523318 Change-Id: I8ebb49da104663e12e5c1fa6399a1bf79239a44f
This commit is contained in:
parent
c68a4daebf
commit
0c4416e3c2
tensorflow
compiler
core/util
python
@ -283,6 +283,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:parse_flags_from_env",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
|
@ -167,8 +167,8 @@ void AllocateAndParseFlags() {
|
||||
jitter_flags = new IntroduceFloatingPointJitterPassFlags;
|
||||
jitter_flags->jitter_amount = 1e-5;
|
||||
|
||||
mlir_flags = new MlirCommonFlags;
|
||||
mlir_flags->tf_mlir_enable_mlir_bridge = false;
|
||||
bool enable_mlir_bridge = false;
|
||||
bool enable_mlir_bridge_flag_updated = false;
|
||||
|
||||
auto setter_for_jitter_tensor_names = [](string sequence) {
|
||||
jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
|
||||
@ -217,12 +217,24 @@ void AllocateAndParseFlags() {
|
||||
"The amount of jitter to introduce. This amount is added to each "
|
||||
"element in the tensors named in `tensor_names."),
|
||||
|
||||
Flag("tf_mlir_enable_mlir_bridge",
|
||||
&mlir_flags->tf_mlir_enable_mlir_bridge,
|
||||
"Enables experimental MLIR-Based TensorFlow Compiler Bridge.")});
|
||||
Flag("tf_mlir_enable_mlir_bridge", &enable_mlir_bridge,
|
||||
"Enables experimental MLIR-Based TensorFlow Compiler Bridge.",
|
||||
&enable_mlir_bridge_flag_updated)});
|
||||
|
||||
AppendMarkForCompilationPassFlagsInternal(flag_list);
|
||||
xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
|
||||
|
||||
mlir_flags = new MlirCommonFlags;
|
||||
if (!enable_mlir_bridge_flag_updated) {
|
||||
mlir_flags->tf_mlir_enable_mlir_bridge =
|
||||
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED;
|
||||
} else if (enable_mlir_bridge) {
|
||||
mlir_flags->tf_mlir_enable_mlir_bridge =
|
||||
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
|
||||
} else {
|
||||
mlir_flags->tf_mlir_enable_mlir_bridge =
|
||||
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -135,7 +136,7 @@ struct IntroduceFloatingPointJitterPassFlags {
|
||||
|
||||
// Flags for common MLIR configurations.
|
||||
struct MlirCommonFlags {
|
||||
bool tf_mlir_enable_mlir_bridge;
|
||||
ConfigProto::Experimental::MlirBridgeRollout tf_mlir_enable_mlir_bridge;
|
||||
};
|
||||
|
||||
// Return a pointer to the DumpGraphFlags struct;
|
||||
|
@ -89,7 +89,8 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
|
||||
XlaOpRegistry::RegisterCompilationKernels();
|
||||
|
||||
// Only check for compilability if the MLIR bridge is not enabled.
|
||||
if (!GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
|
||||
if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge !=
|
||||
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
|
||||
if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
|
||||
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
|
||||
|
@ -31,7 +31,9 @@ class MlirBridgePass : public MlirOptimizationPass {
|
||||
|
||||
bool IsEnabled(const ConfigProto& config_proto) const override {
|
||||
return config_proto.experimental().enable_mlir_bridge() ||
|
||||
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
|
||||
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
|
||||
tensorflow::ConfigProto::Experimental::
|
||||
MLIR_BRIDGE_ROLLOUT_ENABLED;
|
||||
}
|
||||
|
||||
// This should be used as a thin mapper around mlir::ModulePass::runOnModule
|
||||
@ -48,7 +50,9 @@ class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass {
|
||||
|
||||
bool IsEnabled(const ConfigProto& config_proto) const override {
|
||||
return config_proto.experimental().enable_mlir_bridge() ||
|
||||
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
|
||||
GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
|
||||
tensorflow::ConfigProto::Experimental::
|
||||
MLIR_BRIDGE_ROLLOUT_ENABLED;
|
||||
}
|
||||
|
||||
// This should be used as a thin mapper around mlir::ModulePass::runOnModule
|
||||
|
@ -754,13 +754,15 @@ Status XlaCompiler::CompileFunction(
|
||||
|
||||
VLOG(1) << "====================================================";
|
||||
#ifdef LIBTPU_ON_GCE
|
||||
if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
|
||||
if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
|
||||
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
|
||||
VLOG(1) << "MLIR is not supported in this environment.";
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
CompileGraph(options, function_id, std::move(graph), args, result));
|
||||
#else
|
||||
if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
|
||||
if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
|
||||
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
|
||||
VLOG(1) << "Using MLIR bridge";
|
||||
GraphDebugInfo debug_info;
|
||||
std::vector<std::string> control_rets;
|
||||
|
@ -132,51 +132,61 @@ bool ParseFloatFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
|
||||
|
||||
} // namespace
|
||||
|
||||
Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text)
|
||||
Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text,
|
||||
bool* dst_updated)
|
||||
: name_(name),
|
||||
type_(TYPE_INT32),
|
||||
int32_hook_([dst](int32 value) {
|
||||
int32_hook_([dst, dst_updated](int32 value) {
|
||||
*dst = value;
|
||||
if (dst_updated) *dst_updated = true;
|
||||
return true;
|
||||
}),
|
||||
int32_default_for_display_(*dst),
|
||||
usage_text_(usage_text) {}
|
||||
|
||||
Flag::Flag(const char* name, tensorflow::int64* dst, const string& usage_text)
|
||||
Flag::Flag(const char* name, tensorflow::int64* dst, const string& usage_text,
|
||||
bool* dst_updated)
|
||||
: name_(name),
|
||||
type_(TYPE_INT64),
|
||||
int64_hook_([dst](int64 value) {
|
||||
int64_hook_([dst, dst_updated](int64 value) {
|
||||
*dst = value;
|
||||
if (dst_updated) *dst_updated = true;
|
||||
return true;
|
||||
}),
|
||||
int64_default_for_display_(*dst),
|
||||
usage_text_(usage_text) {}
|
||||
|
||||
Flag::Flag(const char* name, float* dst, const string& usage_text)
|
||||
Flag::Flag(const char* name, float* dst, const string& usage_text,
|
||||
bool* dst_updated)
|
||||
: name_(name),
|
||||
type_(TYPE_FLOAT),
|
||||
float_hook_([dst](float value) {
|
||||
float_hook_([dst, dst_updated](float value) {
|
||||
*dst = value;
|
||||
if (dst_updated) *dst_updated = true;
|
||||
return true;
|
||||
}),
|
||||
float_default_for_display_(*dst),
|
||||
usage_text_(usage_text) {}
|
||||
|
||||
Flag::Flag(const char* name, bool* dst, const string& usage_text)
|
||||
Flag::Flag(const char* name, bool* dst, const string& usage_text,
|
||||
bool* dst_updated)
|
||||
: name_(name),
|
||||
type_(TYPE_BOOL),
|
||||
bool_hook_([dst](bool value) {
|
||||
bool_hook_([dst, dst_updated](bool value) {
|
||||
*dst = value;
|
||||
if (dst_updated) *dst_updated = true;
|
||||
return true;
|
||||
}),
|
||||
bool_default_for_display_(*dst),
|
||||
usage_text_(usage_text) {}
|
||||
|
||||
Flag::Flag(const char* name, string* dst, const string& usage_text)
|
||||
Flag::Flag(const char* name, string* dst, const string& usage_text,
|
||||
bool* dst_updated)
|
||||
: name_(name),
|
||||
type_(TYPE_STRING),
|
||||
string_hook_([dst](string value) {
|
||||
string_hook_([dst, dst_updated](string value) {
|
||||
*dst = std::move(value);
|
||||
if (dst_updated) *dst_updated = true;
|
||||
return true;
|
||||
}),
|
||||
string_default_for_display_(*dst),
|
||||
|
@ -62,11 +62,16 @@ namespace tensorflow {
|
||||
// text, and a pointer to the corresponding variable.
|
||||
class Flag {
|
||||
public:
|
||||
Flag(const char* name, int32* dst, const string& usage_text);
|
||||
Flag(const char* name, int64* dst, const string& usage_text);
|
||||
Flag(const char* name, bool* dst, const string& usage_text);
|
||||
Flag(const char* name, string* dst, const string& usage_text);
|
||||
Flag(const char* name, float* dst, const string& usage_text);
|
||||
Flag(const char* name, int32* dst, const string& usage_text,
|
||||
bool* dst_updated = nullptr);
|
||||
Flag(const char* name, int64* dst, const string& usage_text,
|
||||
bool* dst_updated = nullptr);
|
||||
Flag(const char* name, bool* dst, const string& usage_text,
|
||||
bool* dst_updated = nullptr);
|
||||
Flag(const char* name, string* dst, const string& usage_text,
|
||||
bool* dst_updated = nullptr);
|
||||
Flag(const char* name, float* dst, const string& usage_text,
|
||||
bool* dst_updated = nullptr);
|
||||
|
||||
// These constructors invoke a hook on a match instead of writing to a
|
||||
// specific memory location. The hook may return false to signal a malformed
|
||||
@ -85,6 +90,8 @@ class Flag {
|
||||
Flag(const char* name, std::function<bool(string)> string_hook,
|
||||
string default_value_for_display, const string& usage_text);
|
||||
|
||||
bool is_default_initialized() const { return default_initialized_; }
|
||||
|
||||
private:
|
||||
friend class Flags;
|
||||
|
||||
@ -115,6 +122,7 @@ class Flag {
|
||||
string string_default_for_display_;
|
||||
|
||||
string usage_text_;
|
||||
bool default_initialized_ = true;
|
||||
};
|
||||
|
||||
class Flags {
|
||||
|
@ -108,7 +108,7 @@ except Exception: # pylint: disable=broad-except
|
||||
# Uses the same mechanism as above to selectively enable/disable MLIR
|
||||
# compilation.
|
||||
def is_mlir_bridge_enabled():
|
||||
return False
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
@ -2022,8 +2022,13 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
# disable it here.
|
||||
pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(True)
|
||||
|
||||
# Check if the mlir bridge has been explicitly enabled or disabled. If
|
||||
# is_mlir_bridge_enabled() returns None, the user did not explictly enable
|
||||
# or disable the bridge so do not update enable_mlir_bridge.
|
||||
if is_mlir_bridge_enabled():
|
||||
context.context().enable_mlir_bridge = True
|
||||
elif is_mlir_bridge_enabled() is not None:
|
||||
context.context().enable_mlir_bridge = False
|
||||
|
||||
self._threads = []
|
||||
self._tempdir = None
|
||||
|
@ -580,10 +580,15 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||
|
||||
// MLIR Logic
|
||||
m.def("TF_IsMlirBridgeEnabled", [] {
|
||||
return tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
|
||||
return tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
|
||||
tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
|
||||
});
|
||||
m.def("TF_EnableMlirBridge", [](bool enabled) {
|
||||
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge = enabled;
|
||||
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge =
|
||||
enabled
|
||||
? tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED
|
||||
: tensorflow::ConfigProto::Experimental::
|
||||
MLIR_BRIDGE_ROLLOUT_DISABLED;
|
||||
});
|
||||
m.def("TF_EnableXlaDevices", [] {
|
||||
tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
|
||||
|
Loading…
Reference in New Issue
Block a user