Internal change
PiperOrigin-RevId: 335680049
Change-Id: I91e6edc767caf596d3cf1a28c075cc87388043e2
parent 9340214eef
commit c5d4acd09a
@@ -283,7 +283,6 @@ cc_library(
         "//tensorflow/compiler/xla:parse_flags_from_env",
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
-        "//tensorflow/core:protos_all_cc",
         "@com_google_absl//absl/base",
         "@com_google_absl//absl/strings",
     ],
@@ -167,6 +167,9 @@ void AllocateAndParseFlags() {
   jitter_flags = new IntroduceFloatingPointJitterPassFlags;
   jitter_flags->jitter_amount = 1e-5;
 
+  mlir_flags = new MlirCommonFlags;
+  mlir_flags->tf_mlir_enable_mlir_bridge = false;
+
   auto setter_for_jitter_tensor_names = [](string sequence) {
     jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
     return true;
@@ -212,28 +215,14 @@ void AllocateAndParseFlags() {
        Flag("tf_introduce_floating_point_jitter_amount",
             &jitter_flags->jitter_amount,
             "The amount of jitter to introduce. This amount is added to each "
-            "element in the tensors named in `tensor_names.")});
+            "element in the tensors named in `tensor_names."),
 
-  bool enable_mlir_bridge = false;
-  flag_list->emplace_back(
-      "tf_mlir_enable_mlir_bridge", &enable_mlir_bridge,
-      "Enables experimental MLIR-Based TensorFlow Compiler Bridge.");
-  const Flag& enable_mlir_bridge_flag = flag_list->back();
+       Flag("tf_mlir_enable_mlir_bridge",
+            &mlir_flags->tf_mlir_enable_mlir_bridge,
+            "Enables experimental MLIR-Based TensorFlow Compiler Bridge.")});
 
   AppendMarkForCompilationPassFlagsInternal(flag_list);
   xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
-
-  mlir_flags = new MlirCommonFlags;
-  if (enable_mlir_bridge_flag.is_default_initialized()) {
-    mlir_flags->tf_mlir_enable_mlir_bridge =
-        ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED;
-  } else if (enable_mlir_bridge) {
-    mlir_flags->tf_mlir_enable_mlir_bridge =
-        ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
-  } else {
-    mlir_flags->tf_mlir_enable_mlir_bridge =
-        ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED;
-  }
 }
 
 } // namespace
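
Note (editorial, not part of the commit): after this change the bridge toggle held by MlirCommonFlags is a plain bool again, so callers read it directly instead of comparing against MLIR_BRIDGE_ROLLOUT_ENABLED. A minimal sketch of such a caller, assuming the usual tensorflow/compiler/jit/flags.h header and that GetMlirCommonFlags() lazily runs the AllocateAndParseFlags() shown above:

// Sketch only; the header path and the lazy-initialization detail are
// assumptions, not part of this diff.
#include "tensorflow/compiler/jit/flags.h"

bool ShouldUseMlirBridge() {
  // Plain bool after the rollback (default false), settable via
  // --tf_mlir_enable_mlir_bridge in TF_XLA_FLAGS.
  return tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
}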
@@ -19,7 +19,6 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/protobuf/config.pb.h"
 #include "tensorflow/core/util/command_line_flags.h"
 
 namespace tensorflow {
@@ -136,7 +135,7 @@ struct IntroduceFloatingPointJitterPassFlags {
 
 // Flags for common MLIR configurations.
 struct MlirCommonFlags {
-  ConfigProto::Experimental::MlirBridgeRollout tf_mlir_enable_mlir_bridge;
+  bool tf_mlir_enable_mlir_bridge;
 };
 
 // Return a pointer to the DumpGraphFlags struct;
@@ -89,8 +89,7 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
   XlaOpRegistry::RegisterCompilationKernels();
 
   // Only check for compilability if the MLIR bridge is not enabled.
-  if (tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge !=
-      tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
+  if (!GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
     RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
     if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
       std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
@@ -47,9 +47,7 @@ class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass {
 
   bool IsEnabled(const ConfigProto& config_proto) const override {
     return config_proto.experimental().enable_mlir_bridge() ||
-           tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
-               tensorflow::ConfigProto::Experimental::
-                   MLIR_BRIDGE_ROLLOUT_ENABLED;
+           tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
   }
 
   // This should be used as a thin mapper around mlir::ModulePass::runOnModule
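
For context on the IsEnabled() check above (editorial note, not part of the diff): the bridge can still be requested per session through ConfigProto.Experimental's enable_mlir_bridge field, which is ORed with the global flag. A small sketch, assuming only the generated protobuf API for config.proto:

// Sketch only; everything beyond the ConfigProto field named in the diff is
// illustrative.
#include "tensorflow/core/protobuf/config.pb.h"

tensorflow::ConfigProto MakeBridgeEnabledConfig() {
  tensorflow::ConfigProto config;
  // Per-session knob checked by IsEnabled() above, independent of the
  // global --tf_mlir_enable_mlir_bridge flag.
  config.mutable_experimental()->set_enable_mlir_bridge(true);
  return config;
}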
@@ -734,15 +734,13 @@ Status XlaCompiler::CompileFunction(
 
   VLOG(1) << "====================================================";
 #ifdef LIBTPU_ON_GCE
-  if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
-      tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
+  if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
     VLOG(1) << "MLIR is not supported in this environment.";
   }
   TF_RETURN_IF_ERROR(
       CompileGraph(options, function_id, std::move(graph), args, result));
 #else
-  if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
-      tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
+  if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
     VLOG(1) << "Using MLIR bridge";
     GraphDebugInfo debug_info;
     TF_RETURN_IF_ERROR(CompileGraphToXlaHlo(
@@ -135,9 +135,8 @@ bool ParseFloatFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
 Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text)
     : name_(name),
       type_(TYPE_INT32),
-      int32_hook_([this, dst](int32 value) {
+      int32_hook_([dst](int32 value) {
         *dst = value;
-        this->default_initialized_ = false;
         return true;
       }),
       int32_default_for_display_(*dst),
@@ -146,9 +145,8 @@ Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text)
 Flag::Flag(const char* name, tensorflow::int64* dst, const string& usage_text)
     : name_(name),
       type_(TYPE_INT64),
-      int64_hook_([this, dst](int64 value) {
+      int64_hook_([dst](int64 value) {
         *dst = value;
-        this->default_initialized_ = false;
         return true;
       }),
       int64_default_for_display_(*dst),
@@ -157,9 +155,8 @@ Flag::Flag(const char* name, tensorflow::int64* dst, const string& usage_text)
 Flag::Flag(const char* name, float* dst, const string& usage_text)
     : name_(name),
       type_(TYPE_FLOAT),
-      float_hook_([this, dst](float value) {
+      float_hook_([dst](float value) {
         *dst = value;
-        this->default_initialized_ = false;
         return true;
       }),
       float_default_for_display_(*dst),
@@ -168,9 +165,8 @@ Flag::Flag(const char* name, float* dst, const string& usage_text)
 Flag::Flag(const char* name, bool* dst, const string& usage_text)
     : name_(name),
       type_(TYPE_BOOL),
-      bool_hook_([this, dst](bool value) {
+      bool_hook_([dst](bool value) {
         *dst = value;
-        this->default_initialized_ = false;
         return true;
       }),
       bool_default_for_display_(*dst),
@@ -179,9 +175,8 @@ Flag::Flag(const char* name, bool* dst, const string& usage_text)
 Flag::Flag(const char* name, string* dst, const string& usage_text)
     : name_(name),
       type_(TYPE_STRING),
-      string_hook_([this, dst](string value) {
+      string_hook_([dst](string value) {
         *dst = std::move(value);
-        this->default_initialized_ = false;
         return true;
       }),
       string_default_for_display_(*dst),
@@ -85,8 +85,6 @@ class Flag {
   Flag(const char* name, std::function<bool(string)> string_hook,
        string default_value_for_display, const string& usage_text);
 
-  bool is_default_initialized() const { return default_initialized_; }
-
  private:
   friend class Flags;
 
@@ -117,7 +115,6 @@ class Flag {
   string string_default_for_display_;
 
   string usage_text_;
-  bool default_initialized_ = true;
 };
 
 class Flags {
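
To make the constructor change above concrete, here is a small usage sketch (editorial, not from the commit; the flag name and program are invented). It relies only on the Flag(const char*, bool*, const string&) constructor and Flags::Parse from this header; after the rollback the value hooks simply write through to *dst, with no default_initialized_ bookkeeping.

// Illustrative only: parse a made-up bool flag with the Flag/Flags API
// edited above.
#include <vector>

#include "tensorflow/core/util/command_line_flags.h"

int main(int argc, char** argv) {
  bool enable_feature = false;  // hypothetical flag target
  std::vector<tensorflow::Flag> flag_list = {
      tensorflow::Flag("enable_feature", &enable_feature,
                       "Example flag; bool_hook_ just stores the parsed value."),
  };
  // Flags::Parse consumes recognized --enable_feature=... arguments from argv.
  tensorflow::Flags::Parse(&argc, argv, flag_list);
  return enable_feature ? 0 : 1;
}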
@@ -580,15 +580,10 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
 
   // MLIR Logic
   m.def("TF_IsMlirBridgeEnabled", [] {
-    return (tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
-            tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED);
+    return tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
   });
   m.def("TF_EnableMlirBridge", [](bool enabled) {
-    tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge =
-        enabled
-            ? tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED
-            : tensorflow::ConfigProto::Experimental::
-                  MLIR_BRIDGE_ROLLOUT_DISABLED;
+    tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge = enabled;
   });
   m.def("TF_EnableXlaDevices", [] {
     tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;