[XLA] Add xla_disable_hlo_passes to DebugOptions
Also add a SetDebugOptions method to ClientLibraryTestBas; this lets us set debug options in tests by calling it. As an example, this CL removes the current way of passing xla_disable_hlo_passes programmatically in tests - it used to employ a special constructor parameter which is no longer required. PiperOrigin-RevId: 158169006
This commit is contained in:
parent
2b3535c649
commit
cabc5c35c2
@ -46,12 +46,21 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
|
||||
|
||||
legacy_flags::HloPassPipelineFlags* flags =
|
||||
legacy_flags::GetHloPassPipelineFlags();
|
||||
std::vector<string> tmp =
|
||||
std::unique_ptr<tensorflow::gtl::FlatSet<string>> disabled_passes;
|
||||
if (!flags->xla_disable_hlo_passes.empty()) {
|
||||
std::vector<string> passes_vec =
|
||||
tensorflow::str_util::Split(flags->xla_disable_hlo_passes, ',');
|
||||
tensorflow::gtl::FlatSet<string> disabled_passes(tmp.begin(), tmp.end());
|
||||
if (!disabled_passes.empty()) {
|
||||
disabled_passes = MakeUnique<tensorflow::gtl::FlatSet<string>>(
|
||||
passes_vec.begin(), passes_vec.end());
|
||||
} else {
|
||||
auto repeated_field =
|
||||
module->config().debug_options().xla_disable_hlo_passes();
|
||||
disabled_passes = MakeUnique<tensorflow::gtl::FlatSet<string>>(
|
||||
repeated_field.begin(), repeated_field.end());
|
||||
}
|
||||
if (!disabled_passes->empty()) {
|
||||
VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
|
||||
<< tensorflow::str_util::Join(disabled_passes, ", ");
|
||||
<< tensorflow::str_util::Join(*disabled_passes, ", ");
|
||||
}
|
||||
|
||||
auto run_invariant_checkers = [this, module]() -> Status {
|
||||
@ -66,8 +75,8 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
|
||||
bool changed = false;
|
||||
string message;
|
||||
for (auto& pass : passes_) {
|
||||
if (!disabled_passes.empty() &&
|
||||
disabled_passes.count(pass->name().ToString()) > 0) {
|
||||
if (!disabled_passes->empty() &&
|
||||
disabled_passes->count(pass->name().ToString()) > 0) {
|
||||
VLOG(1) << " Skipping HLO pass " << pass->name()
|
||||
<< ", disabled by --xla_disable_hlo_passes";
|
||||
continue;
|
||||
|
@ -44,15 +44,8 @@ Client* GetOrCreateLocalClientOrDie(se::Platform* platform) {
|
||||
}
|
||||
} // namespace
|
||||
|
||||
ClientLibraryTestBase::ClientLibraryTestBase(
|
||||
se::Platform* platform,
|
||||
tensorflow::gtl::ArraySlice<string> disabled_pass_names)
|
||||
: client_(GetOrCreateLocalClientOrDie(platform)) {
|
||||
legacy_flags::HloPassPipelineFlags* flags =
|
||||
legacy_flags::GetHloPassPipelineFlags();
|
||||
flags->xla_disable_hlo_passes =
|
||||
tensorflow::str_util::Join(disabled_pass_names, ",");
|
||||
}
|
||||
ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform)
|
||||
: client_(GetOrCreateLocalClientOrDie(platform)) {}
|
||||
|
||||
string ClientLibraryTestBase::TestName() const {
|
||||
return ::testing::UnitTest::GetInstance()->current_test_info()->name();
|
||||
|
@ -46,8 +46,7 @@ namespace xla {
|
||||
class ClientLibraryTestBase : public ::testing::Test {
|
||||
protected:
|
||||
explicit ClientLibraryTestBase(
|
||||
perftools::gputools::Platform* platform = nullptr,
|
||||
tensorflow::gtl::ArraySlice<string> disabled_pass_names = {});
|
||||
perftools::gputools::Platform* platform = nullptr);
|
||||
|
||||
// Returns the name of the test currently being run.
|
||||
string TestName() const;
|
||||
@ -58,6 +57,12 @@ class ClientLibraryTestBase : public ::testing::Test {
|
||||
|
||||
void SetSeed(uint64 seed) { execution_options_.set_seed(seed); }
|
||||
|
||||
void SetDebugOptions(const DebugOptions& debug_options) {
|
||||
*(execution_options_.mutable_debug_options()) = debug_options;
|
||||
}
|
||||
|
||||
// TODO(b/25566808): Add helper that populates a literal from a testdata file.
|
||||
|
||||
// Convenience methods for building and running a computation from a builder.
|
||||
StatusOr<std::unique_ptr<GlobalData>> Execute(
|
||||
ComputationBuilder* builder,
|
||||
|
@ -46,14 +46,8 @@ ClientType client_types[] = {ClientType::kLocal, ClientType::kCompileOnly};
|
||||
class ComputeConstantTest : public ::testing::Test {
|
||||
public:
|
||||
explicit ComputeConstantTest(
|
||||
perftools::gputools::Platform* platform = nullptr,
|
||||
tensorflow::gtl::ArraySlice<string> disabled_pass_names = {})
|
||||
: platform_(platform) {
|
||||
legacy_flags::HloPassPipelineFlags* flags =
|
||||
legacy_flags::GetHloPassPipelineFlags();
|
||||
flags->xla_disable_hlo_passes =
|
||||
tensorflow::str_util::Join(disabled_pass_names, ",");
|
||||
}
|
||||
perftools::gputools::Platform* platform = nullptr)
|
||||
: platform_(platform) {}
|
||||
|
||||
string TestName() const {
|
||||
return ::testing::UnitTest::GetInstance()->current_test_info()->name();
|
||||
|
@ -36,8 +36,12 @@ namespace {
|
||||
class ConvertTest : public ClientLibraryTestBase {
|
||||
public:
|
||||
explicit ConvertTest(perftools::gputools::Platform* platform = nullptr)
|
||||
: ClientLibraryTestBase(platform,
|
||||
/*disabled_pass_names=*/{"algsimp", "inline"}) {}
|
||||
: ClientLibraryTestBase(platform) {
|
||||
DebugOptions debug_options;
|
||||
debug_options.add_xla_disable_hlo_passes("algsimp");
|
||||
debug_options.add_xla_disable_hlo_passes("inline");
|
||||
SetDebugOptions(debug_options);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(ConvertTest, ConvertR1S32ToR1S32) {
|
||||
|
@ -41,8 +41,12 @@ namespace {
|
||||
class MapTest : public ClientLibraryTestBase {
|
||||
public:
|
||||
explicit MapTest(perftools::gputools::Platform* platform = nullptr)
|
||||
: ClientLibraryTestBase(platform,
|
||||
/*disabled_pass_names=*/{"algsimp", "inline"}) {}
|
||||
: ClientLibraryTestBase(platform) {
|
||||
DebugOptions debug_options;
|
||||
debug_options.add_xla_disable_hlo_passes("algsimp");
|
||||
debug_options.add_xla_disable_hlo_passes("inline");
|
||||
SetDebugOptions(debug_options);
|
||||
}
|
||||
|
||||
// Creates a function that adds its scalar argument with the constant 1.0.
|
||||
//
|
||||
|
@ -41,8 +41,12 @@ namespace {
|
||||
class VecOpsSimpleTest : public ClientLibraryTestBase {
|
||||
public:
|
||||
explicit VecOpsSimpleTest(perftools::gputools::Platform* platform = nullptr)
|
||||
: ClientLibraryTestBase(platform,
|
||||
/*disabled_pass_names=*/{"algsimp", "inline"}) {}
|
||||
: ClientLibraryTestBase(platform) {
|
||||
DebugOptions debug_options;
|
||||
debug_options.add_xla_disable_hlo_passes("algsimp");
|
||||
debug_options.add_xla_disable_hlo_passes("inline");
|
||||
SetDebugOptions(debug_options);
|
||||
}
|
||||
|
||||
ErrorSpec error_spec_{0.0001};
|
||||
};
|
||||
|
@ -27,6 +27,10 @@ message DebugOptions {
|
||||
// various stages in compilation (file names are LOG(INFO)'d). Set to ".*" to
|
||||
// dump *all* HLO modules.
|
||||
string xla_generate_hlo_graph = 1;
|
||||
|
||||
// List of HLO passes to disable. These names must exactly match the pass
|
||||
// names as specified by the HloPassInterface::name() method.
|
||||
repeated string xla_disable_hlo_passes = 2;
|
||||
}
|
||||
|
||||
// These settings control how XLA compiles and/or runs code. Not all settings
|
||||
|
Loading…
Reference in New Issue
Block a user