From c5d4acd09aedb8bb4289e4cc66661da5e235c8e0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 6 Oct 2020 11:16:14 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 335680049 Change-Id: I91e6edc767caf596d3cf1a28c075cc87388043e2 --- tensorflow/compiler/jit/BUILD | 1 - tensorflow/compiler/jit/flags.cc | 25 ++++++------------- tensorflow/compiler/jit/flags.h | 3 +-- tensorflow/compiler/jit/xla_kernel_creator.cc | 3 +-- tensorflow/compiler/tf2xla/mlir_bridge_pass.h | 4 +-- tensorflow/compiler/tf2xla/xla_compiler.cc | 6 ++--- tensorflow/core/util/command_line_flags.cc | 15 ++++------- tensorflow/core/util/command_line_flags.h | 3 --- tensorflow/python/tfe_wrapper.cc | 9 ++----- 9 files changed, 19 insertions(+), 50 deletions(-) diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 76b411b554a..da3db1789b5 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -283,7 +283,6 @@ cc_library( "//tensorflow/compiler/xla:parse_flags_from_env", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/base", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index 01e43b00c86..ee7daf092da 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -167,6 +167,9 @@ void AllocateAndParseFlags() { jitter_flags = new IntroduceFloatingPointJitterPassFlags; jitter_flags->jitter_amount = 1e-5; + mlir_flags = new MlirCommonFlags; + mlir_flags->tf_mlir_enable_mlir_bridge = false; + auto setter_for_jitter_tensor_names = [](string sequence) { jitter_flags->tensor_names = absl::StrSplit(sequence, ','); return true; @@ -212,28 +215,14 @@ void AllocateAndParseFlags() { Flag("tf_introduce_floating_point_jitter_amount", &jitter_flags->jitter_amount, "The amount of jitter to introduce. This amount is added to each " - "element in the tensors named in `tensor_names.")}); + "element in the tensors named in `tensor_names."), - bool enable_mlir_bridge = false; - flag_list->emplace_back( - "tf_mlir_enable_mlir_bridge", &enable_mlir_bridge, - "Enables experimental MLIR-Based TensorFlow Compiler Bridge."); - const Flag& enable_mlir_bridge_flag = flag_list->back(); + Flag("tf_mlir_enable_mlir_bridge", + &mlir_flags->tf_mlir_enable_mlir_bridge, + "Enables experimental MLIR-Based TensorFlow Compiler Bridge.")}); AppendMarkForCompilationPassFlagsInternal(flag_list); xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list); - - mlir_flags = new MlirCommonFlags; - if (enable_mlir_bridge_flag.is_default_initialized()) { - mlir_flags->tf_mlir_enable_mlir_bridge = - ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED; - } else if (enable_mlir_bridge) { - mlir_flags->tf_mlir_enable_mlir_bridge = - ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED; - } else { - mlir_flags->tf_mlir_enable_mlir_bridge = - ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED; - } } } // namespace diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index a0860da7b04..5612b3b5864 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -19,7 +19,6 @@ limitations under the License. #include #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/util/command_line_flags.h" namespace tensorflow { @@ -136,7 +135,7 @@ struct IntroduceFloatingPointJitterPassFlags { // Flags for common MLIR configurations. struct MlirCommonFlags { - ConfigProto::Experimental::MlirBridgeRollout tf_mlir_enable_mlir_bridge; + bool tf_mlir_enable_mlir_bridge; }; // Return a pointer to the DumpGraphFlags struct; diff --git a/tensorflow/compiler/jit/xla_kernel_creator.cc b/tensorflow/compiler/jit/xla_kernel_creator.cc index 0ccdacbfd02..d4a69da4898 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator.cc @@ -89,8 +89,7 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr, XlaOpRegistry::RegisterCompilationKernels(); // Only check for compilability if the MLIR bridge is not enabled. - if (tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge != - tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) { + if (!GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) { RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map; if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) { std::vector diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h index 8efe2d6a872..f7541e634d4 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h @@ -47,9 +47,7 @@ class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass { bool IsEnabled(const ConfigProto& config_proto) const override { return config_proto.experimental().enable_mlir_bridge() || - tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge == - tensorflow::ConfigProto::Experimental:: - MLIR_BRIDGE_ROLLOUT_ENABLED; + tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge; } // This should be used as a thin mapper around mlir::ModulePass::runOnModule diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 549d631eff1..c62b8286bbe 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -734,15 +734,13 @@ Status XlaCompiler::CompileFunction( VLOG(1) << "===================================================="; #ifdef LIBTPU_ON_GCE - if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge == - tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) { + if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) { VLOG(1) << "MLIR is not supported in this environment."; } TF_RETURN_IF_ERROR( CompileGraph(options, function_id, std::move(graph), args, result)); #else - if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge == - tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) { + if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) { VLOG(1) << "Using MLIR bridge"; GraphDebugInfo debug_info; TF_RETURN_IF_ERROR(CompileGraphToXlaHlo( diff --git a/tensorflow/core/util/command_line_flags.cc b/tensorflow/core/util/command_line_flags.cc index 83bb300ae40..00a9cbaa3d8 100644 --- a/tensorflow/core/util/command_line_flags.cc +++ b/tensorflow/core/util/command_line_flags.cc @@ -135,9 +135,8 @@ bool ParseFloatFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text) : name_(name), type_(TYPE_INT32), - int32_hook_([this, dst](int32 value) { + int32_hook_([dst](int32 value) { *dst = value; - this->default_initialized_ = false; return true; }), int32_default_for_display_(*dst), @@ -146,9 +145,8 @@ Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text) Flag::Flag(const char* name, tensorflow::int64* dst, const string& usage_text) : name_(name), type_(TYPE_INT64), - int64_hook_([this, dst](int64 value) { + int64_hook_([dst](int64 value) { *dst = value; - this->default_initialized_ = false; return true; }), int64_default_for_display_(*dst), @@ -157,9 +155,8 @@ Flag::Flag(const char* name, tensorflow::int64* dst, const string& usage_text) Flag::Flag(const char* name, float* dst, const string& usage_text) : name_(name), type_(TYPE_FLOAT), - float_hook_([this, dst](float value) { + float_hook_([dst](float value) { *dst = value; - this->default_initialized_ = false; return true; }), float_default_for_display_(*dst), @@ -168,9 +165,8 @@ Flag::Flag(const char* name, float* dst, const string& usage_text) Flag::Flag(const char* name, bool* dst, const string& usage_text) : name_(name), type_(TYPE_BOOL), - bool_hook_([this, dst](bool value) { + bool_hook_([dst](bool value) { *dst = value; - this->default_initialized_ = false; return true; }), bool_default_for_display_(*dst), @@ -179,9 +175,8 @@ Flag::Flag(const char* name, bool* dst, const string& usage_text) Flag::Flag(const char* name, string* dst, const string& usage_text) : name_(name), type_(TYPE_STRING), - string_hook_([this, dst](string value) { + string_hook_([dst](string value) { *dst = std::move(value); - this->default_initialized_ = false; return true; }), string_default_for_display_(*dst), diff --git a/tensorflow/core/util/command_line_flags.h b/tensorflow/core/util/command_line_flags.h index 9011ee126a9..928ae8a4e94 100644 --- a/tensorflow/core/util/command_line_flags.h +++ b/tensorflow/core/util/command_line_flags.h @@ -85,8 +85,6 @@ class Flag { Flag(const char* name, std::function string_hook, string default_value_for_display, const string& usage_text); - bool is_default_initialized() const { return default_initialized_; } - private: friend class Flags; @@ -117,7 +115,6 @@ class Flag { string string_default_for_display_; string usage_text_; - bool default_initialized_ = true; }; class Flags { diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index 1dd046f7667..36165deeaad 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -580,15 +580,10 @@ PYBIND11_MODULE(_pywrap_tfe, m) { // MLIR Logic m.def("TF_IsMlirBridgeEnabled", [] { - return (tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge == - tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED); + return tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge; }); m.def("TF_EnableMlirBridge", [](bool enabled) { - tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge = - enabled - ? tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED - : tensorflow::ConfigProto::Experimental:: - MLIR_BRIDGE_ROLLOUT_DISABLED; + tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge = enabled; }); m.def("TF_EnableXlaDevices", [] { tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;