Yunxing Dai fb9f77525b [XLA] Add option to print out hlo pass hash.
PiperOrigin-RevId: 297693781
Change-Id: I5a804056643f7c97bd51da6a7e14fe2e8bc77153
2020-02-27 14:54:06 -08:00

170 lines
6.0 KiB
C++

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include <functional>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
template <typename HloT>
Status HloPassPipeline::RunInvariantCheckers(
HloT* hlo, absl::string_view after_pass_name) {
for (auto& invariant_checker : invariant_checkers_) {
VLOG(1) << " Invariant checker " << invariant_checker->name();
StatusOr<bool> changed_status = RunHelper(invariant_checker.get(), hlo);
VLOG(1) << " Invariant checker done " << invariant_checker->name();
if (!changed_status.ok()) {
VLOG(2) << "Failed invariant check:";
XLA_VLOG_LINES(2, hlo->ToString());
return Status(changed_status.status().code(),
absl::StrCat(changed_status.status().error_message(),
"\n\nFailed after ", after_pass_name));
}
TF_RET_CHECK(!changed_status.ValueOrDie())
<< "invariant checkers must not change the graph";
}
return Status::OK();
}
template <typename HloT>
StatusOr<bool> HloPassPipeline::RunPassesInternal(
HloT* hlo, absl::Span<HloPassInterface* const> passes) {
string last_pass_name = "pipeline-start";
TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, last_pass_name));
bool changed = false;
for (HloPassInterface* pass : passes) {
absl::string_view pass_name = pass->name();
VLOG(1) << " HLO pass " << pass_name;
VLOG(2) << " Module hash " << hlo->Hash();
MaybeDumpHlo(*hlo,
/*after_pass_name=*/last_pass_name,
/*before_pass_name=*/pass_name);
if (!pass->IsPassPipeline()) {
compilation_stats_->StartPass(pass_name);
}
TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo));
changed |= pass_changed;
TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass_name));
last_pass_name = string(pass_name);
if (!pass->IsPassPipeline()) {
compilation_stats_->EndPass(pass_name);
}
}
MaybeDumpHlo(*hlo,
/*after_pass_name=*/last_pass_name,
/*before_pass_name=*/"pipeline-end");
return changed;
}
std::vector<HloPassInterface*> HloPassPipeline::GetEnabledPasses(
const DebugOptions& debug_options) {
if (debug_options.xla_disable_all_hlo_passes()) {
VLOG(1) << "*All* passes disabled by --xla_disable_all_hlo_passes.";
return {};
}
absl::flat_hash_set<string> disabled_pass_names(
debug_options.xla_disable_hlo_passes().begin(),
debug_options.xla_disable_hlo_passes().end());
absl::flat_hash_set<string> enabled_pass_names(
debug_options.xla_enable_hlo_passes_only().begin(),
debug_options.xla_enable_hlo_passes_only().end());
if (!disabled_pass_names.empty()) {
VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
<< absl::StrJoin(disabled_pass_names, ", ");
}
if (!enabled_pass_names.empty()) {
VLOG(1) << "Passes enabled by --xla_enable_hlo_passes_only: "
<< absl::StrJoin(enabled_pass_names, ", ");
}
CHECK(disabled_pass_names.empty() || enabled_pass_names.empty());
std::vector<HloPassInterface*> enabled_passes;
if (!enabled_pass_names.empty()) {
for (auto& pass : passes_) {
if (enabled_pass_names.contains(pass->name())) {
enabled_passes.push_back(pass.get());
}
}
} else {
for (auto& pass : passes_) {
if (!disabled_pass_names.contains(pass->name())) {
enabled_passes.push_back(pass.get());
}
}
}
return enabled_passes;
}
void HloPassPipeline::MaybeDumpHlo(const HloModule& module,
absl::string_view after_pass_name,
absl::string_view before_pass_name) {
DumpHloModuleBetweenPassesIfEnabled(name(), before_pass_name, after_pass_name,
module);
}
void HloPassPipeline::MaybeDumpHlo(const HloModuleGroup& module_group,
absl::string_view after_pass_name,
absl::string_view before_pass_name) {
for (const HloModule* module : module_group.modules()) {
MaybeDumpHlo(*module, after_pass_name, before_pass_name);
}
}
StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
run_called_ = true;
VLOG(1) << "Running HLO pass pipeline on module " << module->name() << ": "
<< name();
return RunPassesInternal(module,
GetEnabledPasses(module->config().debug_options()));
}
StatusOr<bool> HloPassPipeline::RunOnModuleGroup(HloModuleGroup* module_group) {
run_called_ = true;
VLOG(1) << "Running HLO pass pipeline on module group "
<< module_group->name() << ": " << name();
if (module_group->modules().empty()) {
VLOG(1) << "Module group is empty. Nothing to do.";
return false;
}
return RunPassesInternal(
module_group,
GetEnabledPasses(module_group->module(0).config().debug_options()));
}
} // namespace xla