From b9d5e144193f587020c1bac5d7505af88baa24d9 Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Wed, 7 Jun 2017 09:10:01 -0700 Subject: [PATCH] [XLA] Start collecting flags for debug options in a single place. ClientLibraryTestBase will now parse command-line flags for debug options automatically, permitting subclasses to override certain options by using mutable_debug_options. main() still has to call AppendDebugOptionsFlags() explicitly before running the TF flag parser. In the mean-time, this CL leaves flag handling to the current "legacy" approach. However, this is part of a larger plan to move *all* debugging flags for XLA into the DebugOptions message and expose them as flags from a single place. The other flags (which are not controlling debugging options) will have to be propagated more explicitly. PiperOrigin-RevId: 158276294 --- tensorflow/compiler/aot/BUILD | 2 +- tensorflow/compiler/aot/tfcompile_main.cc | 4 +- tensorflow/compiler/xla/legacy_flags/BUILD | 26 +++--- .../xla/legacy_flags/debug_options_flags.cc | 84 +++++++++++++++++++ .../xla/legacy_flags/debug_options_flags.h | 38 +++++++++ .../legacy_flags/hlo_pass_pipeline_flags.cc | 62 -------------- .../legacy_flags/hlo_pass_pipeline_flags.h | 48 ----------- tensorflow/compiler/xla/service/BUILD | 1 - .../compiler/xla/service/hlo_pass_pipeline.cc | 27 ++---- tensorflow/compiler/xla/tests/BUILD | 4 +- .../xla/tests/client_library_test_base.cc | 7 +- .../xla/tests/client_library_test_base.h | 6 +- .../xla/tests/compute_constant_test.cc | 1 - tensorflow/compiler/xla/tests/convert_test.cc | 8 +- tensorflow/compiler/xla/tests/map_test.cc | 16 ++-- .../xla/tests/vector_ops_simple_test.cc | 8 +- 16 files changed, 174 insertions(+), 168 deletions(-) create mode 100644 tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc create mode 100644 tensorflow/compiler/xla/legacy_flags/debug_options_flags.h delete mode 100644 tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.cc delete mode 100644 tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 5e368749a09..71c6b17d51b 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -128,8 +128,8 @@ cc_library( "//tensorflow/compiler/xla/legacy_flags:compiler_functor_flags", "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/legacy_flags:hlo_graph_dumper_flags", - "//tensorflow/compiler/xla/legacy_flags:hlo_pass_pipeline_flags", "//tensorflow/compiler/xla/legacy_flags:llvm_util_flags", "//tensorflow/compiler/xla/legacy_flags:service_flags", "//tensorflow/compiler/xla/legacy_flags:util_flags", diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index 63ec261e01d..6fed46b4329 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -28,8 +28,8 @@ limitations under the License. #include "tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h" #include "tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h" #include "tensorflow/compiler/xla/legacy_flags/service_flags.h" #include "tensorflow/compiler/xla/legacy_flags/util_flags.h" @@ -142,7 +142,7 @@ int main(int argc, char** argv) { xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::legacy_flags::AppendCpuRuntimeFlags(&flag_list); xla::legacy_flags::AppendHloGraphDumperFlags(&flag_list); - xla::legacy_flags::AppendHloPassPipelineFlags(&flag_list); + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendLlvmUtilFlags(&flag_list); xla::legacy_flags::AppendServiceFlags(&flag_list); xla::legacy_flags::AppendUtilFlags(&flag_list); diff --git a/tensorflow/compiler/xla/legacy_flags/BUILD b/tensorflow/compiler/xla/legacy_flags/BUILD index 017cb5bb0ed..b124e2d4251 100644 --- a/tensorflow/compiler/xla/legacy_flags/BUILD +++ b/tensorflow/compiler/xla/legacy_flags/BUILD @@ -65,6 +65,20 @@ cc_library( ], ) +cc_library( + name = "debug_options_flags", + srcs = ["debug_options_flags.cc"], + hdrs = ["debug_options_flags.h"], + deps = + [ + ":parse_flags_from_env", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + cc_library( name = "cpu_compiler_flags", srcs = ["cpu_compiler_flags.cc"], @@ -160,18 +174,6 @@ cc_library( ], ) -cc_library( - name = "hlo_pass_pipeline_flags", - srcs = ["hlo_pass_pipeline_flags.cc"], - hdrs = ["hlo_pass_pipeline_flags.h"], - deps = [ - ":parse_flags_from_env", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - cc_library( name = "alias_analysis_flags", srcs = ["alias_analysis_flags.cc"], diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc new file mode 100644 index 00000000000..0211462cb1a --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -0,0 +1,84 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" + +#include // NOLINT(build/c++11): only using std::call_once, not mutex. +#include +#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace xla { +namespace legacy_flags { + +struct DebugOptionsFlags { + string xla_generate_hlo_graph; + + string xla_disable_hlo_passes; +}; + +namespace { + +DebugOptionsFlags* flag_values; +std::vector* flag_objects; +std::once_flag flags_init; + +// Allocates flag_values and flag_objects; this function must not be called more +// than once - its call done via call_once. +void AllocateFlags() { + flag_values = new DebugOptionsFlags; + flag_values->xla_generate_hlo_graph = ""; + flag_values->xla_disable_hlo_passes = ""; + + flag_objects = new std::vector( + {tensorflow::Flag( + "xla_generate_hlo_graph", &flag_values->xla_generate_hlo_graph, + "HLO modules matching this regex will be dumped to a .dot file " + "throughout various stages in compilation."), + + tensorflow::Flag( + "xla_disable_hlo_passes", &flag_values->xla_disable_hlo_passes, + "Comma-separated list of HLO passes to be disabled. These names " + "must " + "exactly match the passes' names; no whitespace around commas.")}); + ParseFlagsFromEnv(*flag_objects); +} + +} // namespace + +void AppendDebugOptionsFlags(std::vector* flag_list) { + std::call_once(flags_init, &AllocateFlags); + flag_list->insert(flag_list->end(), flag_objects->begin(), + flag_objects->end()); +} + +xla::DebugOptions GetDebugOptionsFromFlags() { + std::call_once(flags_init, &AllocateFlags); + + DebugOptions options; + + options.set_xla_generate_hlo_graph(flag_values->xla_generate_hlo_graph); + + std::vector disabled_passes = + tensorflow::str_util::Split(flag_values->xla_disable_hlo_passes, ','); + for (const auto& passname : disabled_passes) { + options.add_xla_disable_hlo_passes(passname); + } + + return options; +} + +} // namespace legacy_flags +} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.h b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.h new file mode 100644 index 00000000000..d0ef8e66ab0 --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.h @@ -0,0 +1,38 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_ + +#include + +#include "tensorflow/compiler/xla/xla.pb.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace legacy_flags { + +// Appends flag definitions for debug options to flag_list. +void AppendDebugOptionsFlags(std::vector* flag_list); + +// Fetches a DebugOptions proto message from flags provided to the program. +// Flags must be registered with the flags parser using AppendDebugOptionsFlags +// first. +xla::DebugOptions GetDebugOptionsFromFlags(); + +} // namespace legacy_flags +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.cc b/tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.cc deleted file mode 100644 index edc04d51a70..00000000000 --- a/tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.cc +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's hlo_pass_pipeline module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. -#include - -#include "tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static HloPassPipelineFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new HloPassPipelineFlags; - flags->xla_disable_hlo_passes = ""; - flag_list = new std::vector({ - tensorflow::Flag("xla_disable_hlo_passes", &flags->xla_disable_hlo_passes, - "Comma-separated list of HLO passes to disable."), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's hlo_pass_pipeline -// module. -void AppendHloPassPipelineFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the HloPassPipelineFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -HloPassPipelineFlags* GetHloPassPipelineFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h b/tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h deleted file mode 100644 index 520759bbf0d..00000000000 --- a/tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_PASS_PIPELINE_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_PASS_PIPELINE_FLAGS_H_ - -// Legacy flags for XLA's hlo_pass_pipeline module. - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's hlo_pass_pipeline -// module. -void AppendHloPassPipelineFlags(std::vector* flag_list); - -// The values of flags associated with XLA's hlo_pass_pipeline module. -typedef struct { - // Comma-separated list of HLO passes to disable. - string xla_disable_hlo_passes; -} HloPassPipelineFlags; - -// Return a pointer to the HloPassPipelineFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -HloPassPipelineFlags* GetHloPassPipelineFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_PASS_PIPELINE_FLAGS_H_ diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index aa1349a3507..7cb3c95ffa9 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1446,7 +1446,6 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/legacy_flags:hlo_pass_pipeline_flags", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 4e258c2a88c..afc4d3733c8 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -17,7 +17,6 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -44,23 +43,13 @@ StatusOr HloPassPipeline::Run(HloModule* module) { VLOG(1) << "Running HLO pass pipeline " << name(); - legacy_flags::HloPassPipelineFlags* flags = - legacy_flags::GetHloPassPipelineFlags(); - std::unique_ptr> disabled_passes; - if (!flags->xla_disable_hlo_passes.empty()) { - std::vector passes_vec = - tensorflow::str_util::Split(flags->xla_disable_hlo_passes, ','); - disabled_passes = MakeUnique>( - passes_vec.begin(), passes_vec.end()); - } else { - auto repeated_field = - module->config().debug_options().xla_disable_hlo_passes(); - disabled_passes = MakeUnique>( - repeated_field.begin(), repeated_field.end()); - } - if (!disabled_passes->empty()) { + auto repeated_field = + module->config().debug_options().xla_disable_hlo_passes(); + tensorflow::gtl::FlatSet disabled_passes(repeated_field.begin(), + repeated_field.end()); + if (!disabled_passes.empty()) { VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: " - << tensorflow::str_util::Join(*disabled_passes, ", "); + << tensorflow::str_util::Join(disabled_passes, ", "); } auto run_invariant_checkers = [this, module]() -> Status { @@ -75,8 +64,8 @@ StatusOr HloPassPipeline::Run(HloModule* module) { bool changed = false; string message; for (auto& pass : passes_) { - if (!disabled_passes->empty() && - disabled_passes->count(pass->name().ToString()) > 0) { + if (!disabled_passes.empty() && + disabled_passes.count(pass->name().ToString()) > 0) { VLOG(1) << " Skipping HLO pass " << pass->name() << ", disabled by --xla_disable_hlo_passes"; continue; diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index eb31e8cdbff..1971868a386 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -150,7 +150,7 @@ cc_library( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:hlo_pass_pipeline_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", @@ -1153,6 +1153,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -1212,7 +1213,6 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", - "//tensorflow/compiler/xla/legacy_flags:hlo_pass_pipeline_flags", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 2d052e7a4d2..03552d7bbf0 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -45,7 +45,10 @@ Client* GetOrCreateLocalClientOrDie(se::Platform* platform) { } // namespace ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform) - : client_(GetOrCreateLocalClientOrDie(platform)) {} + : client_(GetOrCreateLocalClientOrDie(platform)) { + *(execution_options_.mutable_debug_options()) = + legacy_flags::GetDebugOptionsFromFlags(); +} string ClientLibraryTestBase::TestName() const { return ::testing::UnitTest::GetInstance()->current_test_info()->name(); diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 9f0d6272f48..e6fc0f457a3 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -57,8 +57,10 @@ class ClientLibraryTestBase : public ::testing::Test { void SetSeed(uint64 seed) { execution_options_.set_seed(seed); } - void SetDebugOptions(const DebugOptions& debug_options) { - *(execution_options_.mutable_debug_options()) = debug_options; + // Provides mutable access to the execution DebugOptions field; this lets + // tests tweak the options that will be used to compile/run the graph. + DebugOptions* mutable_debug_options() { + return execution_options_.mutable_debug_options(); } // TODO(b/25566808): Add helper that populates a literal from a testdata file. diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index 25b645557e9..72a8d47ac9d 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 7b2f201d1b0..f6178608c89 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -37,10 +38,8 @@ class ConvertTest : public ClientLibraryTestBase { public: explicit ConvertTest(perftools::gputools::Platform* platform = nullptr) : ClientLibraryTestBase(platform) { - DebugOptions debug_options; - debug_options.add_xla_disable_hlo_passes("algsimp"); - debug_options.add_xla_disable_hlo_passes("inline"); - SetDebugOptions(debug_options); + mutable_debug_options()->mutable_xla_disable_hlo_passes()->Add("algsimp"); + mutable_debug_options()->mutable_xla_disable_hlo_passes()->Add("inline"); } }; @@ -199,6 +198,7 @@ TEST_F(ConvertTest, ConvertReshape) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc index e263400929b..6c82460c7c4 100644 --- a/tensorflow/compiler/xla/tests/map_test.cc +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -42,10 +43,8 @@ class MapTest : public ClientLibraryTestBase { public: explicit MapTest(perftools::gputools::Platform* platform = nullptr) : ClientLibraryTestBase(platform) { - DebugOptions debug_options; - debug_options.add_xla_disable_hlo_passes("algsimp"); - debug_options.add_xla_disable_hlo_passes("inline"); - SetDebugOptions(debug_options); + mutable_debug_options()->mutable_xla_disable_hlo_passes()->Add("algsimp"); + mutable_debug_options()->mutable_xla_disable_hlo_passes()->Add("inline"); } // Creates a function that adds its scalar argument with the constant 1.0. @@ -103,8 +102,8 @@ class MapTest : public ClientLibraryTestBase { // Creates a function that adds its scalar argument with the constant 1.0 and // then multiplies by the original element. // - // /---------------\ - // / \ + // /------------------| + // / | // x {R0F32} ----> (add) ----> (mul) // / // 1.0f ---------/ @@ -150,8 +149,8 @@ class MapTest : public ClientLibraryTestBase { // Creates a function that adds three scalar arguments // - // x {R0F32} ----\ - // \ + // x {R0F32} -------| + // | // y {R0F32} ----> (add) ---> (add) // / // z {R0F32} ---------------/ @@ -624,6 +623,7 @@ TEST_F(MapTestWithFullOpt, MapSquare) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc index c380c046ce8..a41c2797bf6 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -42,10 +43,8 @@ class VecOpsSimpleTest : public ClientLibraryTestBase { public: explicit VecOpsSimpleTest(perftools::gputools::Platform* platform = nullptr) : ClientLibraryTestBase(platform) { - DebugOptions debug_options; - debug_options.add_xla_disable_hlo_passes("algsimp"); - debug_options.add_xla_disable_hlo_passes("inline"); - SetDebugOptions(debug_options); + mutable_debug_options()->mutable_xla_disable_hlo_passes()->Add("algsimp"); + mutable_debug_options()->mutable_xla_disable_hlo_passes()->Add("inline"); } ErrorSpec error_spec_{0.0001}; @@ -443,6 +442,7 @@ XLA_TEST_F(VecOpsSimpleTest, VectorPredicateNotEqual) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) {