Internal change

PiperOrigin-RevId: 335680049
Change-Id: I91e6edc767caf596d3cf1a28c075cc87388043e2
This commit is contained in:
A. Unique TensorFlower 2020-10-06 11:16:14 -07:00 committed by TensorFlower Gardener
parent 9340214eef
commit c5d4acd09a
9 changed files with 19 additions and 50 deletions

View File

@ -283,7 +283,6 @@ cc_library(
"//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/base",
"@com_google_absl//absl/strings",
],

View File

@ -167,6 +167,9 @@ void AllocateAndParseFlags() {
jitter_flags = new IntroduceFloatingPointJitterPassFlags;
jitter_flags->jitter_amount = 1e-5;
mlir_flags = new MlirCommonFlags;
mlir_flags->tf_mlir_enable_mlir_bridge = false;
auto setter_for_jitter_tensor_names = [](string sequence) {
jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
return true;
@ -212,28 +215,14 @@ void AllocateAndParseFlags() {
Flag("tf_introduce_floating_point_jitter_amount",
&jitter_flags->jitter_amount,
"The amount of jitter to introduce. This amount is added to each "
"element in the tensors named in `tensor_names.")});
"element in the tensors named in `tensor_names."),
bool enable_mlir_bridge = false;
flag_list->emplace_back(
"tf_mlir_enable_mlir_bridge", &enable_mlir_bridge,
"Enables experimental MLIR-Based TensorFlow Compiler Bridge.");
const Flag& enable_mlir_bridge_flag = flag_list->back();
Flag("tf_mlir_enable_mlir_bridge",
&mlir_flags->tf_mlir_enable_mlir_bridge,
"Enables experimental MLIR-Based TensorFlow Compiler Bridge.")});
AppendMarkForCompilationPassFlagsInternal(flag_list);
xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
mlir_flags = new MlirCommonFlags;
if (enable_mlir_bridge_flag.is_default_initialized()) {
mlir_flags->tf_mlir_enable_mlir_bridge =
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED;
} else if (enable_mlir_bridge) {
mlir_flags->tf_mlir_enable_mlir_bridge =
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
} else {
mlir_flags->tf_mlir_enable_mlir_bridge =
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED;
}
}
} // namespace

View File

@ -19,7 +19,6 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
@ -136,7 +135,7 @@ struct IntroduceFloatingPointJitterPassFlags {
// Flags for common MLIR configurations.
struct MlirCommonFlags {
ConfigProto::Experimental::MlirBridgeRollout tf_mlir_enable_mlir_bridge;
bool tf_mlir_enable_mlir_bridge;
};
// Return a pointer to the DumpGraphFlags struct;

View File

@ -89,8 +89,7 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
XlaOpRegistry::RegisterCompilationKernels();
// Only check for compilability if the MLIR bridge is not enabled.
if (tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge !=
tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
if (!GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>

View File

@ -47,9 +47,7 @@ class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass {
bool IsEnabled(const ConfigProto& config_proto) const override {
return config_proto.experimental().enable_mlir_bridge() ||
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
tensorflow::ConfigProto::Experimental::
MLIR_BRIDGE_ROLLOUT_ENABLED;
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
}
// This should be used as a thin mapper around mlir::ModulePass::runOnModule

View File

@ -734,15 +734,13 @@ Status XlaCompiler::CompileFunction(
VLOG(1) << "====================================================";
#ifdef LIBTPU_ON_GCE
if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
VLOG(1) << "MLIR is not supported in this environment.";
}
TF_RETURN_IF_ERROR(
CompileGraph(options, function_id, std::move(graph), args, result));
#else
if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
VLOG(1) << "Using MLIR bridge";
GraphDebugInfo debug_info;
TF_RETURN_IF_ERROR(CompileGraphToXlaHlo(

View File

@ -135,9 +135,8 @@ bool ParseFloatFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text)
: name_(name),
type_(TYPE_INT32),
int32_hook_([this, dst](int32 value) {
int32_hook_([dst](int32 value) {
*dst = value;
this->default_initialized_ = false;
return true;
}),
int32_default_for_display_(*dst),
@ -146,9 +145,8 @@ Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text)
Flag::Flag(const char* name, tensorflow::int64* dst, const string& usage_text)
: name_(name),
type_(TYPE_INT64),
int64_hook_([this, dst](int64 value) {
int64_hook_([dst](int64 value) {
*dst = value;
this->default_initialized_ = false;
return true;
}),
int64_default_for_display_(*dst),
@ -157,9 +155,8 @@ Flag::Flag(const char* name, tensorflow::int64* dst, const string& usage_text)
Flag::Flag(const char* name, float* dst, const string& usage_text)
: name_(name),
type_(TYPE_FLOAT),
float_hook_([this, dst](float value) {
float_hook_([dst](float value) {
*dst = value;
this->default_initialized_ = false;
return true;
}),
float_default_for_display_(*dst),
@ -168,9 +165,8 @@ Flag::Flag(const char* name, float* dst, const string& usage_text)
Flag::Flag(const char* name, bool* dst, const string& usage_text)
: name_(name),
type_(TYPE_BOOL),
bool_hook_([this, dst](bool value) {
bool_hook_([dst](bool value) {
*dst = value;
this->default_initialized_ = false;
return true;
}),
bool_default_for_display_(*dst),
@ -179,9 +175,8 @@ Flag::Flag(const char* name, bool* dst, const string& usage_text)
Flag::Flag(const char* name, string* dst, const string& usage_text)
: name_(name),
type_(TYPE_STRING),
string_hook_([this, dst](string value) {
string_hook_([dst](string value) {
*dst = std::move(value);
this->default_initialized_ = false;
return true;
}),
string_default_for_display_(*dst),

View File

@ -85,8 +85,6 @@ class Flag {
Flag(const char* name, std::function<bool(string)> string_hook,
string default_value_for_display, const string& usage_text);
bool is_default_initialized() const { return default_initialized_; }
private:
friend class Flags;
@ -117,7 +115,6 @@ class Flag {
string string_default_for_display_;
string usage_text_;
bool default_initialized_ = true;
};
class Flags {

View File

@ -580,15 +580,10 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
// MLIR Logic
m.def("TF_IsMlirBridgeEnabled", [] {
return (tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge ==
tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED);
return tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
});
m.def("TF_EnableMlirBridge", [](bool enabled) {
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge =
enabled
? tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED
: tensorflow::ConfigProto::Experimental::
MLIR_BRIDGE_ROLLOUT_DISABLED;
tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge = enabled;
});
m.def("TF_EnableXlaDevices", [] {
tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;