[XLA] Add xla_disable_hlo_passes to DebugOptions

Also add a SetDebugOptions method to ClientLibraryTestBas; this lets us set
debug options in tests by calling it.

As an example, this CL removes the current way of passing
xla_disable_hlo_passes programmatically in tests - it used to employ a special
constructor parameter which is no longer required.

PiperOrigin-RevId: 158169006
This commit is contained in:
Eli Bendersky 2017-06-06 11:48:39 -07:00 committed by TensorFlower Gardener
parent 2b3535c649
commit cabc5c35c2
8 changed files with 49 additions and 32 deletions

View File

@ -46,12 +46,21 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
legacy_flags::HloPassPipelineFlags* flags =
legacy_flags::GetHloPassPipelineFlags();
std::vector<string> tmp =
tensorflow::str_util::Split(flags->xla_disable_hlo_passes, ',');
tensorflow::gtl::FlatSet<string> disabled_passes(tmp.begin(), tmp.end());
if (!disabled_passes.empty()) {
std::unique_ptr<tensorflow::gtl::FlatSet<string>> disabled_passes;
if (!flags->xla_disable_hlo_passes.empty()) {
std::vector<string> passes_vec =
tensorflow::str_util::Split(flags->xla_disable_hlo_passes, ',');
disabled_passes = MakeUnique<tensorflow::gtl::FlatSet<string>>(
passes_vec.begin(), passes_vec.end());
} else {
auto repeated_field =
module->config().debug_options().xla_disable_hlo_passes();
disabled_passes = MakeUnique<tensorflow::gtl::FlatSet<string>>(
repeated_field.begin(), repeated_field.end());
}
if (!disabled_passes->empty()) {
VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
<< tensorflow::str_util::Join(disabled_passes, ", ");
<< tensorflow::str_util::Join(*disabled_passes, ", ");
}
auto run_invariant_checkers = [this, module]() -> Status {
@ -66,8 +75,8 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
bool changed = false;
string message;
for (auto& pass : passes_) {
if (!disabled_passes.empty() &&
disabled_passes.count(pass->name().ToString()) > 0) {
if (!disabled_passes->empty() &&
disabled_passes->count(pass->name().ToString()) > 0) {
VLOG(1) << " Skipping HLO pass " << pass->name()
<< ", disabled by --xla_disable_hlo_passes";
continue;

View File

@ -44,15 +44,8 @@ Client* GetOrCreateLocalClientOrDie(se::Platform* platform) {
}
} // namespace
ClientLibraryTestBase::ClientLibraryTestBase(
se::Platform* platform,
tensorflow::gtl::ArraySlice<string> disabled_pass_names)
: client_(GetOrCreateLocalClientOrDie(platform)) {
legacy_flags::HloPassPipelineFlags* flags =
legacy_flags::GetHloPassPipelineFlags();
flags->xla_disable_hlo_passes =
tensorflow::str_util::Join(disabled_pass_names, ",");
}
ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform)
: client_(GetOrCreateLocalClientOrDie(platform)) {}
string ClientLibraryTestBase::TestName() const {
return ::testing::UnitTest::GetInstance()->current_test_info()->name();

View File

@ -46,8 +46,7 @@ namespace xla {
class ClientLibraryTestBase : public ::testing::Test {
protected:
explicit ClientLibraryTestBase(
perftools::gputools::Platform* platform = nullptr,
tensorflow::gtl::ArraySlice<string> disabled_pass_names = {});
perftools::gputools::Platform* platform = nullptr);
// Returns the name of the test currently being run.
string TestName() const;
@ -58,6 +57,12 @@ class ClientLibraryTestBase : public ::testing::Test {
void SetSeed(uint64 seed) { execution_options_.set_seed(seed); }
void SetDebugOptions(const DebugOptions& debug_options) {
*(execution_options_.mutable_debug_options()) = debug_options;
}
// TODO(b/25566808): Add helper that populates a literal from a testdata file.
// Convenience methods for building and running a computation from a builder.
StatusOr<std::unique_ptr<GlobalData>> Execute(
ComputationBuilder* builder,

View File

@ -46,14 +46,8 @@ ClientType client_types[] = {ClientType::kLocal, ClientType::kCompileOnly};
class ComputeConstantTest : public ::testing::Test {
public:
explicit ComputeConstantTest(
perftools::gputools::Platform* platform = nullptr,
tensorflow::gtl::ArraySlice<string> disabled_pass_names = {})
: platform_(platform) {
legacy_flags::HloPassPipelineFlags* flags =
legacy_flags::GetHloPassPipelineFlags();
flags->xla_disable_hlo_passes =
tensorflow::str_util::Join(disabled_pass_names, ",");
}
perftools::gputools::Platform* platform = nullptr)
: platform_(platform) {}
string TestName() const {
return ::testing::UnitTest::GetInstance()->current_test_info()->name();

View File

@ -36,8 +36,12 @@ namespace {
class ConvertTest : public ClientLibraryTestBase {
public:
explicit ConvertTest(perftools::gputools::Platform* platform = nullptr)
: ClientLibraryTestBase(platform,
/*disabled_pass_names=*/{"algsimp", "inline"}) {}
: ClientLibraryTestBase(platform) {
DebugOptions debug_options;
debug_options.add_xla_disable_hlo_passes("algsimp");
debug_options.add_xla_disable_hlo_passes("inline");
SetDebugOptions(debug_options);
}
};
TEST_F(ConvertTest, ConvertR1S32ToR1S32) {

View File

@ -41,8 +41,12 @@ namespace {
class MapTest : public ClientLibraryTestBase {
public:
explicit MapTest(perftools::gputools::Platform* platform = nullptr)
: ClientLibraryTestBase(platform,
/*disabled_pass_names=*/{"algsimp", "inline"}) {}
: ClientLibraryTestBase(platform) {
DebugOptions debug_options;
debug_options.add_xla_disable_hlo_passes("algsimp");
debug_options.add_xla_disable_hlo_passes("inline");
SetDebugOptions(debug_options);
}
// Creates a function that adds its scalar argument with the constant 1.0.
//

View File

@ -41,8 +41,12 @@ namespace {
class VecOpsSimpleTest : public ClientLibraryTestBase {
public:
explicit VecOpsSimpleTest(perftools::gputools::Platform* platform = nullptr)
: ClientLibraryTestBase(platform,
/*disabled_pass_names=*/{"algsimp", "inline"}) {}
: ClientLibraryTestBase(platform) {
DebugOptions debug_options;
debug_options.add_xla_disable_hlo_passes("algsimp");
debug_options.add_xla_disable_hlo_passes("inline");
SetDebugOptions(debug_options);
}
ErrorSpec error_spec_{0.0001};
};

View File

@ -27,6 +27,10 @@ message DebugOptions {
// various stages in compilation (file names are LOG(INFO)'d). Set to ".*" to
// dump *all* HLO modules.
string xla_generate_hlo_graph = 1;
// List of HLO passes to disable. These names must exactly match the pass
// names as specified by the HloPassInterface::name() method.
repeated string xla_disable_hlo_passes = 2;
}
// These settings control how XLA compiles and/or runs code. Not all settings