[XLA] Add metadata for HLO modules and passes.

When the xla_dump_module_metadata flag is set, one HloModuleMetadata will be dumped per module.

PiperOrigin-RevId: 342091720
Change-Id: I278e96d2a0fcf7d94054058ff9fb27bf80b71e82
This commit is contained in:
A. Unique TensorFlower 2020-11-12 11:28:53 -08:00 committed by TensorFlower Gardener
parent 9aec83d8e6
commit 083b03e2f9
16 changed files with 746 additions and 68 deletions

View File

@ -42,6 +42,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_dump_hlo_as_html(false);
opts.set_xla_dump_include_timestamp(true);
opts.set_xla_dump_max_hlo_modules(-1);
opts.set_xla_dump_module_metadata(false);
#ifdef INTEL_MKL
opts.set_xla_cpu_use_mkl_dnn(true);
#endif // INTEL_MKL
@ -230,7 +231,6 @@ static void AllocateFlags() {
};
flag_objects = new std::vector<tensorflow::Flag>();
flag_objects->reserve(55);
// Don't use an initializer list for initializing the vector; this would
// create a temporary copy, and exceeds the stack space when compiling with
// certain configurations.
@ -513,6 +513,12 @@ static void AllocateFlags() {
flag_values->xla_dump_max_hlo_modules(),
"Max number of hlo module dumps in a directory. Set to < 0 for "
"unbounded."));
flag_objects->push_back(tensorflow::Flag(
"xla_dump_module_metadata",
bool_setter_for(&DebugOptions::set_xla_dump_module_metadata),
flag_values->xla_dump_module_metadata(),
"Dumps HloModuleMetadata as text protos to the directory specified "
"by --xla_dump_to."));
flag_objects->push_back(tensorflow::Flag(
"xla_hlo_graph_addresses",
bool_setter_for(&DebugOptions::set_xla_hlo_graph_addresses),

View File

@ -412,6 +412,7 @@ cc_library(
"hlo_instruction.cc",
"hlo_instructions.cc",
"hlo_module.cc",
"hlo_module_metadata.cc",
"hlo_opcode.cc",
"hlo_schedule.cc",
"hlo_sharding.cc",
@ -428,6 +429,7 @@ cc_library(
"hlo_instruction.h",
"hlo_instructions.h",
"hlo_module.h",
"hlo_module_metadata.h",
"hlo_opcode.h",
"hlo_schedule.h",
"hlo_sharding.h",
@ -3047,6 +3049,18 @@ tf_cc_test(
],
)
tf_cc_test(
name = "hlo_module_metadata_test",
srcs = ["hlo_module_metadata_test.cc"],
deps = [
":hlo",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/stream_executor/lib",
],
)
cc_library(
name = "buffer_value",
srcs = ["buffer_value.cc"],

View File

@ -45,7 +45,8 @@ struct CanonicalDebugOptions {
dump_as_url(opts.xla_dump_hlo_as_url()),
dump_snapshots(opts.xla_dump_hlo_snapshots()),
dump_include_timestamp(opts.xla_dump_include_timestamp()),
dump_max_hlo_modules(opts.xla_dump_max_hlo_modules()) {
dump_max_hlo_modules(opts.xla_dump_max_hlo_modules()),
dump_module_metadata(opts.xla_dump_module_metadata()) {
// This constructor examines the values in `opts` and turns on other flags
// based on what we think is the user's intent. To reduce confusion about
// what was a user-specified value versus an extrapolated value, within this
@ -137,18 +138,20 @@ struct CanonicalDebugOptions {
bool dump_snapshots;
bool dump_include_timestamp;
int64 dump_max_hlo_modules;
bool dump_module_metadata;
};
void DumpToFileInDirImpl(string_view filename, string_view contents,
const CanonicalDebugOptions& opts) {
absl::optional<std::string> DumpToFileInDirImpl(
string_view filename, string_view contents,
const CanonicalDebugOptions& opts) {
if (opts.dumping_to_stdout()) {
LOG(ERROR) << "Refusing to write " << filename
<< " to stdout. Pass --xla_dump_to=<path> to write to a file.";
return;
return absl::nullopt;
}
if (opts.dump_to.empty()) {
return;
return absl::nullopt;
}
const string& dir = opts.dump_to;
@ -164,7 +167,7 @@ void DumpToFileInDirImpl(string_view filename, string_view contents,
if (!status.ok() && !env->IsDirectory(dir).ok()) {
LOG(ERROR) << "Could not create directory " << dir
<< " for dumping XLA debug data: " << status;
return;
return absl::nullopt;
}
}
@ -181,7 +184,7 @@ void DumpToFileInDirImpl(string_view filename, string_view contents,
LOG(ERROR) << "Have already dumped " << matches.size()
<< " modules, more than the limit of "
<< opts.dump_max_hlo_modules;
return;
return absl::nullopt;
}
}
@ -192,33 +195,42 @@ void DumpToFileInDirImpl(string_view filename, string_view contents,
LOG(ERROR) << "Could not write XLA debug data to " << file_path << ": "
<< status;
}
return file_path;
}
void DumpToFileInDirOrStdoutImpl(string_view filename, string_view contents,
const CanonicalDebugOptions& opts) {
absl::optional<std::string> DumpToFileInDirOrStdoutImpl(
string_view filename, string_view contents,
const CanonicalDebugOptions& opts) {
// Dump to stdout if that's called for.
if (opts.dumping_to_stdout()) {
std::cout << "*** Begin " << filename << " ***\n"
<< contents << "\n*** End " << filename << " ***" << std::endl;
return;
return absl::nullopt;
}
// Otherwise, dump to a file.
DumpToFileInDirImpl(filename, contents, opts);
return DumpToFileInDirImpl(filename, contents, opts);
}
void DumpHloModuleImpl(const HloModule& module,
const BufferAssignment* buffer_assn,
const HloExecutionProfile* profile, string_view prefix,
string_view suffix, const CanonicalDebugOptions& opts) {
// Returns full file paths of all dumps of the module.
std::vector<std::string> DumpHloModuleImpl(const HloModule& module,
const BufferAssignment* buffer_assn,
const HloExecutionProfile* profile,
string_view prefix,
string_view suffix,
const CanonicalDebugOptions& opts) {
string filename = FilenameFor(module, prefix, suffix);
std::vector<absl::optional<std::string>> file_paths;
if (opts.dump_as_text) {
DumpToFileInDirOrStdoutImpl(StrCat(filename, ".txt"), module.ToString(),
opts);
file_paths.push_back(DumpToFileInDirOrStdoutImpl(StrCat(filename, ".txt"),
module.ToString(), opts));
if (buffer_assn) {
DumpToFileInDirOrStdoutImpl(StrCat(filename, "-buffer-assignment.txt"),
buffer_assn->ToString(), opts);
file_paths.push_back(DumpToFileInDirOrStdoutImpl(
StrCat(filename, "-buffer-assignment.txt"), buffer_assn->ToString(),
opts));
}
}
@ -229,7 +241,8 @@ void DumpHloModuleImpl(const HloModule& module,
if (!tensorflow::SerializeToStringDeterministic(module_proto, &pb)) {
pb = "Failed to serialize HLO module proto.";
}
DumpToFileInDirImpl(StrCat(filename, ".hlo.pb"), pb, opts);
file_paths.push_back(
DumpToFileInDirImpl(StrCat(filename, ".hlo.pb"), pb, opts));
}
auto render_graph = [&](RenderedGraphFormat format) {
@ -244,13 +257,15 @@ void DumpHloModuleImpl(const HloModule& module,
};
if (opts.dump_as_dot) {
DumpToFileInDirImpl(StrFormat("%s.dot", filename),
render_graph(RenderedGraphFormat::kDot), opts);
file_paths.push_back(
DumpToFileInDirImpl(StrFormat("%s.dot", filename),
render_graph(RenderedGraphFormat::kDot), opts));
}
if (opts.dump_as_html) {
DumpToFileInDirImpl(StrFormat("%s.html", filename),
render_graph(RenderedGraphFormat::kHtml), opts);
file_paths.push_back(
DumpToFileInDirImpl(StrFormat("%s.html", filename),
render_graph(RenderedGraphFormat::kHtml), opts));
}
// Special case for rendering graphs as URLs. We'll dump them to a file
@ -259,9 +274,35 @@ void DumpHloModuleImpl(const HloModule& module,
string url = render_graph(RenderedGraphFormat::kUrl);
std::cout << filename << " --> " << url << std::endl;
if (!opts.dumping_to_stdout()) {
DumpToFileInDirImpl(StrFormat("%s.url", filename), url, opts);
file_paths.push_back(
DumpToFileInDirImpl(StrFormat("%s.url", filename), url, opts));
}
}
std::vector<std::string> dumped_file_paths;
for (const absl::optional<std::string>& path : file_paths) {
if (path.has_value()) {
dumped_file_paths.push_back(*path);
}
}
return dumped_file_paths;
}
void DumpHloModuleMetadata(const HloModuleMetadataProto& metadata,
const CanonicalDebugOptions& opts,
absl::flat_hash_set<int64>* dumped_module_ids) {
// Return if metadata for this module has already been dumped.
if (!dumped_module_ids->insert(metadata.canonical_module_id()).second) {
return;
}
std::string filename = absl::StrFormat("module_%04d.metadata.textproto",
metadata.canonical_module_id());
std::string content;
if (tensorflow::protobuf::TextFormat::PrintToString(metadata, &content)) {
DumpToFileInDirImpl(filename, content, opts);
} else {
LOG(ERROR) << "Failed to convert HloModuleMetadataProto to text.";
}
}
static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
@ -381,18 +422,17 @@ bool DumpingToStdout(const DebugOptions& opts) {
return CanonicalDebugOptions(opts).dumping_to_stdout();
}
void DumpHloModuleBetweenPassesIfEnabled(string_view pipeline_name,
string_view before_pass_name,
string_view after_pass_name,
const HloModule& module) {
std::vector<std::string> DumpHloModuleBetweenPassesIfEnabled(
string_view pipeline_name, string_view before_pass_name,
string_view after_pass_name, const HloModule& module) {
CanonicalDebugOptions opts(module.config().debug_options());
if (!opts.should_dump_module(module.name())) {
return;
return {};
}
if (!opts.should_dump_pass(before_pass_name) &&
!opts.should_dump_pass(after_pass_name)) {
return;
return {};
}
int64 step_number = StepNumberForModule(module);
@ -401,8 +441,8 @@ void DumpHloModuleBetweenPassesIfEnabled(string_view pipeline_name,
string filename_suffix =
StrFormat("%04d.%s.after_%s.before_%s", step_number, pipeline_name,
after_pass_name, before_pass_name);
DumpHloModuleImpl(module, /*buffer_assn=*/nullptr, /*profile=*/nullptr,
timestamp, filename_suffix, opts);
return DumpHloModuleImpl(module, /*buffer_assn=*/nullptr, /*profile=*/nullptr,
timestamp, filename_suffix, opts);
}
void DumpHloModuleDuringPassIfEnabled(string_view pass_name,
@ -488,4 +528,21 @@ void DumpHloSnapshotIfEnabled(const HloSnapshot& snapshot,
DumpToFileInDirImpl(filename, pb, canonical_opts);
}
void DumpHloModuleMetadataIfEnabled(const std::vector<HloModule*>& modules) {
absl::flat_hash_set<int64> dumped_module_ids;
for (const HloModule* module : modules) {
CanonicalDebugOptions opts(module->config().debug_options());
if (!opts.dump_module_metadata) {
continue;
}
DumpHloModuleMetadata(module->metadata().proto(), opts, &dumped_module_ids);
const absl::optional<HloModuleMetadataProto>& prepartitioning_metadata =
module->metadata().prepartitioning_metadata();
if (prepartitioning_metadata.has_value()) {
DumpHloModuleMetadata(*prepartitioning_metadata, opts,
&dumped_module_ids);
}
}
}
} // namespace xla

View File

@ -75,11 +75,11 @@ void DumpHloModuleIfEnabled(const HloModule& module,
absl::string_view name);
// Dumps the given HLO module after running one HLO pass and before running
// another, if that's enabled.
void DumpHloModuleBetweenPassesIfEnabled(absl::string_view pipeline_name,
absl::string_view before_pass_name,
absl::string_view after_pass_name,
const HloModule& module);
// another, if that's enabled. Returns the full file paths of all dumps of the
// module, or an empty vector if nothing was dumped.
std::vector<std::string> DumpHloModuleBetweenPassesIfEnabled(
absl::string_view pipeline_name, absl::string_view before_pass_name,
absl::string_view after_pass_name, const HloModule& module);
// Dumps the given HLO module during the given HLO pass, if that's enabled.
//
@ -100,6 +100,8 @@ void DumpHloSnapshotIfEnabled(const HloModule& module,
void DumpHloSnapshotIfEnabled(const HloSnapshot& snapshot,
const DebugOptions& opts);
void DumpHloModuleMetadataIfEnabled(const std::vector<HloModule*>& modules);
// Returns true if we should dump data for an HloModule. This is useful if you
// want to check if DumpToFileInDir{,OrStdout} will do anything before
// generating an expensive string.

View File

@ -528,3 +528,66 @@ message HloSnapshot {
// The name of the platform used to run the graph.
string execution_platform = 4;
}
// Metadata for an HLO module. Dumped after HLO passes and before LLO lowering
// with filename module_####.metadata.textproto, where #### is
// canonical_module_id.
message HloModuleMetadataProto {
// Uniquely identifies an HloModuleMetadata. Equal to the first unique_id
// of the module (a module may go through multiple unique_ids). If a module
// is partitioned into multiple modules, those modules will each have a new
// HloModuleMetadata with a different canonical_module_id.
int64 canonical_module_id = 1;
// Name of the module group that the module is part of.
string module_group_name = 2;
// The canonical module id of the module that this one is partitioned from,
// if applicable.
int64 original_module_id = 3;
// The canonical module ids of the modules that this one is partitioned into,
// if applicable.
repeated int64 partitioned_module_ids = 4;
// Metadata for the HLO passes that are run on the module.
repeated HloPassMetadata pass_metadata = 5;
}
// Metadata for one run of an HLO pass on a module. Provides more information
// when processing debug dumps of HloProtos about the order of HLO passes and
// various other stats like duration. `pass_id` may also be used to identify a
// particular run of a pass in debug info that propagates through stages of
// compilation.
message HloPassMetadata {
// For a given module, pass_id uniquely identifies a run of an HLO pass on
// that module. Note that a pass_id may not always refer to the same pass
// because the order of passes during compilation may change. For finding
// metadata for a particular pass, pass_name and pipeline_name would be more
// reliable, although note that they may not be unique.
int64 pass_id = 1;
string pass_name = 2;
string pipeline_name = 3;
// Filenames of the dumps of the module after this pass ran. Module may be
// dumped in multiple formats, and the order of formats in this field will
// stay consistent across passes.
repeated string dump_filenames = 4;
// Return value of pass.Run(). True if this pass changed the module, or, in
// the case where the module was run through this pass as part of a module
// group, true if this pass changed any module in the same module group.
bool module_changed = 5;
// The unique_id of the module that this pass is run on. May be different from
// the canonical_module_id of the HloModuleMetadata that this HloPassMetadata
// is inside.
int64 module_id = 6;
// If the module went through this pass as part of a module group, this is
// set as the ids of all the modules in the module group. Empty otherwise.
repeated int64 module_group_module_ids = 7;
int64 start_timestamp_usec = 8;
int64 end_timestamp_usec = 9;
}

View File

@ -44,7 +44,10 @@ namespace xla {
HloModule::HloModule(const string& name, HloModuleConfig config)
: name_(NameUniquer::GetSanitizedName(name)),
config_(std::move(config)),
unique_id_(next_unique_module_id_++) {}
unique_id_(next_unique_module_id_++),
metadata_(tensorflow::Env::Default()) {
metadata_.set_canonical_module_id(unique_id_);
}
Status HloModule::set_schedule(HloSchedule schedule) {
TF_RET_CHECK(schedule.module() == this);

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_module_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/types.h"
@ -363,6 +364,16 @@ class HloModule {
return cross_program_prefetches_;
}
const HloModuleMetadata& metadata() const { return metadata_; }
HloModuleMetadata* metadata() { return &metadata_; }
// Moves (not copies) metadata from this HloModule to `module`. To be used
// in cases like HloModuleGroup::ReplaceModule when metadata should be
// transferred out of a module before it's destroyed.
void MoveMetadataToModule(HloModule* module) {
module->metadata_ = std::move(metadata_);
}
private:
HloComputation* AddComputationInternal(
std::unique_ptr<HloComputation> computation, bool is_entry,
@ -413,6 +424,9 @@ class HloModule {
// Arguments to be prefetched across programs.
std::vector<std::pair<int64, ShapeIndex>> cross_program_prefetches_;
// Metadata for this module, such as its canonical id and the HLO passes run.
HloModuleMetadata metadata_;
};
} // namespace xla

View File

@ -96,12 +96,14 @@ uint64 HloModuleGroup::Hash() const {
}
void HloModuleGroup::push_back(std::unique_ptr<HloModule> module) {
module->metadata()->set_module_group_name(name());
modules_.push_back(std::move(module));
module_ptrs_.push_back(modules_.back().get());
}
void HloModuleGroup::ReplaceModule(int index,
std::unique_ptr<HloModule> module) {
modules_.at(index)->MoveMetadataToModule(module.get());
modules_.at(index) = std::move(module);
module_ptrs_.at(index) = modules_.at(index).get();
}

View File

@ -27,6 +27,8 @@ namespace xla {
namespace {
namespace op = ::xla::testing::opcode_matchers;
using ::testing::Property;
using ::testing::StrEq;
class HloModuleGroupTest : public HloTestBase {
protected:
@ -202,6 +204,31 @@ ENTRY entry {
}
}
// Test that metadata is transferred when a module is replaced.
TEST_F(HloModuleGroupTest, ReplaceModuleMetadata) {
auto old_module = CreateNewVerifiedModule();
int old_module_id = old_module->unique_id();
old_module->metadata()->RecordPassStart();
TF_EXPECT_OK(old_module->metadata()->set_current_pass_name("fake pass"));
HloModuleGroup group(std::move(old_module));
EXPECT_EQ(group.module(0).metadata()->proto().module_group_name(),
group.name());
auto new_module = CreateNewVerifiedModule();
group.ReplaceModule(0, std::move(new_module));
EXPECT_NE(group.module(0).unique_id(), old_module_id);
const HloModuleMetadataProto& module_metadata =
group.module(0).metadata()->proto();
EXPECT_EQ(module_metadata.canonical_module_id(), old_module_id);
const HloPassMetadata& pass_metadata =
*module_metadata.pass_metadata().rbegin();
EXPECT_THAT(pass_metadata,
Property(&HloPassMetadata::pass_name, StrEq("fake pass")));
}
} // namespace
} // namespace xla

View File

@ -0,0 +1,87 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_module_metadata.h"
#include <algorithm>
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/platform/env.h"
namespace xla {
StatusOr<HloPassMetadata*> HloModuleMetadata::GetCurrentHloPassMetadata() {
if (running_passes_.empty()) {
return NotFound(
"HloPassMetadata for currently running pass not found, either because "
"the pass did not call RecordPassStart or because a pass is "
"creating/switching modules without using "
"HloModuleGroup::ReplaceModule.");
}
return running_passes_.back();
}
Status HloModuleMetadata::MutateCurrentHloPassMetadata(
const std::function<void(HloPassMetadata*)>& mutator) {
TF_ASSIGN_OR_RETURN(HloPassMetadata * pass_metadata,
GetCurrentHloPassMetadata());
mutator(pass_metadata);
return Status::OK();
}
void HloModuleMetadata::RecordPassStart() {
HloPassMetadata* pass_metadata = module_metadata_.add_pass_metadata();
pass_metadata->set_pass_id(next_pass_id_++);
pass_metadata->set_start_timestamp_usec(env_->NowMicros());
running_passes_.push_back(pass_metadata);
}
Status HloModuleMetadata::RecordPassEnd() {
TF_ASSIGN_OR_RETURN(HloPassMetadata * pass_metadata,
GetCurrentHloPassMetadata());
pass_metadata->set_end_timestamp_usec(env_->NowMicros());
running_passes_.pop_back();
return Status::OK();
}
void HloModuleMetadata::set_prepartitioning_metadata(
const HloModuleMetadata& prepartitioning_metadata) {
module_metadata_.set_original_module_id(
prepartitioning_metadata.proto().canonical_module_id());
prepartitioning_metadata_ = prepartitioning_metadata.proto();
prepartitioning_metadata_->clear_pass_metadata();
// Because HloPassMetadata represents the completion of a pass, metadata for
// all currently running passes need to be moved over to the new module.
absl::flat_hash_set<HloPassMetadata*> running_passes(
prepartitioning_metadata.running_passes_.begin(),
prepartitioning_metadata.running_passes_.end());
for (const HloPassMetadata& pass_metadata :
prepartitioning_metadata.proto().pass_metadata()) {
if (running_passes.contains(&pass_metadata)) {
HloPassMetadata* added_pass_metadata =
module_metadata_.add_pass_metadata();
*added_pass_metadata = pass_metadata;
running_passes_.push_back(added_pass_metadata);
next_pass_id_ =
std::max(next_pass_id_,
static_cast<int64>(added_pass_metadata->pass_id()) + 1);
} else {
*prepartitioning_metadata_->add_pass_metadata() = pass_metadata;
}
}
}
} // namespace xla

View File

@ -0,0 +1,126 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_METADATA_H_
#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_METADATA_H_
#include <functional>
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/env.h"
namespace xla {
// Wrapper class for HloModuleMetadataProto to avoid allowing callers to mutate
// arbitrary fields. Specifically, callers cannot set timestamps or ids or
// set the fields of any pass not currently running.
class HloModuleMetadata {
public:
explicit HloModuleMetadata(tensorflow::Env* env) : env_(env) {}
const HloModuleMetadataProto& proto() const { return module_metadata_; }
// Creates a new HloPassMetadata. All calls to RecordPassStart should be
// matched by a later call to RecordPassEnd.
void RecordPassStart();
// Marks the currently running pass as finished. Returns NotFound if metadata
// for the currently running pass cannot be found.
Status RecordPassEnd();
const absl::optional<HloModuleMetadataProto>& prepartitioning_metadata()
const {
return prepartitioning_metadata_;
}
void set_prepartitioning_metadata(
const HloModuleMetadata& prepartitioning_metadata);
// Setters for HloModuleMetadataProto.
void set_module_group_name(const std::string& name) {
module_metadata_.set_module_group_name(name);
}
void set_canonical_module_id(int64 id) {
module_metadata_.set_canonical_module_id(id);
}
void add_partitioned_module_id(int64 id) {
module_metadata_.add_partitioned_module_ids(id);
}
// Setters for the current HloPassMetadata.
Status set_current_pass_name(const std::string& pass_name) {
return MutateCurrentHloPassMetadata(
[&pass_name](HloPassMetadata* pass_metadata) {
pass_metadata->set_pass_name(pass_name);
});
}
Status set_current_pass_pipeline_name(const std::string& pipeline_name) {
return MutateCurrentHloPassMetadata(
[&pipeline_name](HloPassMetadata* pass_metadata) {
pass_metadata->set_pipeline_name(pipeline_name);
});
}
Status add_current_pass_dump_filename(const std::string& dump_filename) {
return MutateCurrentHloPassMetadata(
[&dump_filename](HloPassMetadata* pass_metadata) {
pass_metadata->add_dump_filenames(dump_filename);
});
}
Status set_current_pass_module_changed(bool module_changed) {
return MutateCurrentHloPassMetadata(
[&module_changed](HloPassMetadata* pass_metadata) {
pass_metadata->set_module_changed(module_changed);
});
}
Status set_current_pass_module_id(int64 module_id) {
return MutateCurrentHloPassMetadata(
[&module_id](HloPassMetadata* pass_metadata) {
pass_metadata->set_module_id(module_id);
});
}
Status add_current_pass_module_group_module_id(int64 module_id) {
return MutateCurrentHloPassMetadata(
[&module_id](HloPassMetadata* pass_metadata) {
pass_metadata->add_module_group_module_ids(module_id);
});
}
private:
// Gets mutable metadata for the currently running pass. If passes are nested,
// finds the deepest one still running. Returns NotFound if metadata for the
// currently running pass cannot be found.
StatusOr<HloPassMetadata*> GetCurrentHloPassMetadata();
Status MutateCurrentHloPassMetadata(
const std::function<void(HloPassMetadata*)>& mutator);
HloModuleMetadataProto module_metadata_;
tensorflow::Env* env_;
int64 next_pass_id_ = 1;
// Stack of metadata for passes that are currently running. Size > 1 iff
// passes are nested.
std::vector<HloPassMetadata*> running_passes_;
// Metadata from before the module was partitioned, if applicable.
absl::optional<HloModuleMetadataProto> prepartitioning_metadata_ =
absl::nullopt;
};
} // namespace xla
#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_METADATA_H_

View File

@ -0,0 +1,145 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_module_metadata.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace xla {
namespace {
using ::testing::ElementsAre;
using ::testing::Property;
using ::testing::StrEq;
class TestEnv : public tensorflow::EnvWrapper {
public:
TestEnv() : EnvWrapper(Env::Default()) {}
uint64 NowMicros() const override { return current_micros_; }
void SetCurrentMicros(uint64 micros) { current_micros_ = micros; }
private:
uint64 current_micros_ = 1;
};
TEST(HloModuleMetadata, RecordsPassStart) {
TestEnv env;
HloModuleMetadata module_metadata(&env);
env.SetCurrentMicros(1234);
module_metadata.RecordPassStart();
EXPECT_THAT(
module_metadata.proto().pass_metadata(),
ElementsAre(Property(&HloPassMetadata::start_timestamp_usec, 1234)));
}
TEST(HloModuleMetadata, RecordsPassEnd) {
TestEnv env;
HloModuleMetadata module_metadata(&env);
module_metadata.RecordPassStart();
env.SetCurrentMicros(4321);
EXPECT_IS_OK(module_metadata.RecordPassEnd());
EXPECT_THAT(
module_metadata.proto().pass_metadata(),
ElementsAre(Property(&HloPassMetadata::end_timestamp_usec, 4321)));
}
TEST(HloModuleMetadata, RecordsPassEndInNestedMetadata) {
TestEnv env;
HloModuleMetadata module_metadata(&env);
module_metadata.RecordPassStart();
module_metadata.RecordPassStart();
env.SetCurrentMicros(111);
EXPECT_IS_OK(module_metadata.RecordPassEnd());
EXPECT_THAT(module_metadata.proto().pass_metadata(),
ElementsAre(Property(&HloPassMetadata::end_timestamp_usec, 0),
Property(&HloPassMetadata::end_timestamp_usec, 111)));
env.SetCurrentMicros(222);
EXPECT_IS_OK(module_metadata.RecordPassEnd());
EXPECT_THAT(module_metadata.proto().pass_metadata(),
ElementsAre(Property(&HloPassMetadata::end_timestamp_usec, 222),
Property(&HloPassMetadata::end_timestamp_usec, 111)));
}
TEST(HloModuleMetadata, RecordPassEndReturnsNotFound) {
HloModuleMetadata module_metadata(tensorflow::Env::Default());
EXPECT_EQ(module_metadata.RecordPassEnd().code(),
tensorflow::error::NOT_FOUND);
module_metadata.RecordPassStart();
EXPECT_IS_OK(module_metadata.RecordPassEnd());
EXPECT_EQ(module_metadata.RecordPassEnd().code(),
tensorflow::error::NOT_FOUND);
}
TEST(HloModuleMetadata, SetsHloPassMetadataFields) {
HloModuleMetadata module_metadata(tensorflow::Env::Default());
module_metadata.RecordPassStart();
EXPECT_IS_OK(module_metadata.set_current_pass_name("fake name"));
EXPECT_THAT(
module_metadata.proto().pass_metadata(),
ElementsAre(Property(&HloPassMetadata::pass_name, StrEq("fake name"))));
}
TEST(HloModuleMetadata, SetsHloPassMetadataFieldsInNestedMetadata) {
HloModuleMetadata module_metadata(tensorflow::Env::Default());
module_metadata.RecordPassStart();
module_metadata.RecordPassStart();
EXPECT_IS_OK(module_metadata.set_current_pass_name("fake name"));
EXPECT_THAT(
module_metadata.proto().pass_metadata(),
ElementsAre(Property(&HloPassMetadata::pass_name, StrEq("")),
Property(&HloPassMetadata::pass_name, StrEq("fake name"))));
}
TEST(HloModuleMetadata, SetterReturnsNotFound) {
HloModuleMetadata module_metadata(tensorflow::Env::Default());
EXPECT_EQ(module_metadata.set_current_pass_name("fake name").code(),
tensorflow::error::NOT_FOUND);
}
TEST(HloModuleMetadata, CopiesRunningPrepartitioningPasses) {
HloModuleMetadata old_module_metadata(tensorflow::Env::Default());
old_module_metadata.RecordPassStart();
EXPECT_IS_OK(old_module_metadata.set_current_pass_name("outer pass"));
old_module_metadata.RecordPassStart();
EXPECT_IS_OK(old_module_metadata.set_current_pass_name("finished pass"));
EXPECT_IS_OK(old_module_metadata.RecordPassEnd());
old_module_metadata.RecordPassStart();
EXPECT_IS_OK(old_module_metadata.set_current_pass_name("inner pass"));
HloModuleMetadata new_module_metadata(tensorflow::Env::Default());
new_module_metadata.set_prepartitioning_metadata(old_module_metadata);
// Passes that are still running go in the new module.
EXPECT_THAT(
new_module_metadata.proto().pass_metadata(),
ElementsAre(Property(&HloPassMetadata::pass_name, StrEq("outer pass")),
Property(&HloPassMetadata::pass_name, StrEq("inner pass"))));
// Passes that finished go in the prepartitioning metadata.
EXPECT_THAT(new_module_metadata.prepartitioning_metadata()->pass_metadata(),
ElementsAre(Property(&HloPassMetadata::pass_name,
StrEq("finished pass"))));
}
} // namespace
} // namespace xla

View File

@ -31,6 +31,72 @@ limitations under the License.
namespace xla {
namespace {
void RecordPassStartMetadata(HloModule& module, const std::string& pass_name,
const std::string& pipeline_name) {
module.metadata()->RecordPassStart();
// An HloPassMetadata was just created so Status should always be OK.
TF_CHECK_OK(module.metadata()->set_current_pass_name(pass_name));
TF_CHECK_OK(module.metadata()->set_current_pass_pipeline_name(pipeline_name));
}
void RecordPassStartMetadata(HloModuleGroup& module_group,
const std::string& pass_name,
const std::string& pipeline_name) {
for (HloModule* module : module_group.modules()) {
RecordPassStartMetadata(*module, pass_name, pipeline_name);
}
}
Status AttemptRecordPassEndMetadata(HloModule& module,
const std::string& pass_name,
bool module_changed) {
// Module id is set here instead of RecordPassStartMetadata because it may
// change in the middle of the pass, and we want the final id.
TF_RETURN_IF_ERROR(
module.metadata()->set_current_pass_module_id(module.unique_id()));
TF_RETURN_IF_ERROR(
module.metadata()->set_current_pass_module_changed(module_changed));
TF_RETURN_IF_ERROR(module.metadata()->RecordPassEnd());
return Status::OK();
}
void RecordPassEndMetadata(HloModule& module, const std::string& pass_name,
bool module_changed) {
Status status =
AttemptRecordPassEndMetadata(module, pass_name, module_changed);
if (!status.ok()) {
LOG(FATAL) << status.error_message();
}
}
Status AttemptRecordPassEndMetadata(HloModuleGroup& module_group,
const std::string& pass_name,
bool module_changed) {
for (HloModule* module : module_group.modules()) {
for (HloModule* other_module : module_group.modules()) {
TF_RETURN_IF_ERROR(
module->metadata()->add_current_pass_module_group_module_id(
other_module->unique_id()));
}
TF_RETURN_IF_ERROR(
AttemptRecordPassEndMetadata(*module, pass_name, module_changed));
}
return Status::OK();
}
void RecordPassEndMetadata(HloModuleGroup& module_group,
const std::string& pass_name, bool module_changed) {
Status status =
AttemptRecordPassEndMetadata(module_group, pass_name, module_changed);
if (!status.ok()) {
LOG(FATAL) << status.error_message();
}
}
} // namespace
template <typename HloT>
Status HloPassPipeline::RunInvariantCheckers(
HloT* hlo, absl::string_view after_pass_name) {
@ -54,34 +120,48 @@ Status HloPassPipeline::RunInvariantCheckers(
template <typename HloT>
StatusOr<bool> HloPassPipeline::RunPassesInternal(
HloT* hlo, absl::Span<HloPassInterface* const> passes) {
string last_pass_name = "pipeline-start";
TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, last_pass_name));
static constexpr absl::string_view kPipelineStart = "pipeline-start";
static constexpr absl::string_view kPipelineEnd = "pipeline-end";
std::string pipeline_name = std::string(name());
TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, kPipelineStart));
RecordPassStartMetadata(*hlo, std::string(kPipelineStart), pipeline_name);
MaybeDumpHloAndSaveFilenames(*hlo,
/*after_pass_name=*/kPipelineStart,
/*before_pass_name=*/passes.empty()
? kPipelineEnd
: passes.front()->name());
RecordPassEndMetadata(*hlo, std::string(kPipelineStart),
/*module_changed=*/false);
bool changed = false;
for (HloPassInterface* pass : passes) {
for (int i = 0; i < passes.size(); i++) {
HloPassInterface* pass = passes[i];
XLA_SCOPED_LOGGING_TIMER(absl::StrCat("HLO pass: ", pass->name()));
absl::string_view pass_name = pass->name();
std::string pass_name = std::string(pass->name());
VLOG(1) << " HLO pass " << pass_name;
VLOG(2) << " Module hash " << hlo->Hash();
MaybeDumpHlo(*hlo,
/*after_pass_name=*/last_pass_name,
/*before_pass_name=*/pass_name);
if (!pass->IsPassPipeline()) {
compilation_stats_->StartPass(pass_name);
}
RecordPassStartMetadata(*hlo, pass_name, pipeline_name);
TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo));
MaybeDumpHloAndSaveFilenames(*hlo,
/*after_pass_name=*/pass_name,
/*before_pass_name=*/i + 1 >= passes.size()
? kPipelineEnd
: passes[i + 1]->name());
RecordPassEndMetadata(*hlo, pass_name, pass_changed);
changed |= pass_changed;
if (pass_changed) {
VLOG(3) << " Pass caused changes" << pass->name();
}
TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass_name));
last_pass_name = string(pass_name);
if (!pass->IsPassPipeline()) {
compilation_stats_->EndPass(pass_name);
}
}
MaybeDumpHlo(*hlo,
/*after_pass_name=*/last_pass_name,
/*before_pass_name=*/"pipeline-end");
return changed;
}
@ -129,18 +209,23 @@ std::vector<HloPassInterface*> HloPassPipeline::GetEnabledPasses(
return enabled_passes;
}
void HloPassPipeline::MaybeDumpHlo(const HloModule& module,
absl::string_view after_pass_name,
absl::string_view before_pass_name) {
DumpHloModuleBetweenPassesIfEnabled(name(), before_pass_name, after_pass_name,
module);
void HloPassPipeline::MaybeDumpHloAndSaveFilenames(
HloModule& module, absl::string_view after_pass_name,
absl::string_view before_pass_name) {
for (const std::string& filename : DumpHloModuleBetweenPassesIfEnabled(
name(), before_pass_name, after_pass_name, module)) {
Status status = module.metadata()->add_current_pass_dump_filename(filename);
if (!status.ok()) {
LOG(FATAL) << status.error_message();
}
}
}
void HloPassPipeline::MaybeDumpHlo(const HloModuleGroup& module_group,
absl::string_view after_pass_name,
absl::string_view before_pass_name) {
for (const HloModule* module : module_group.modules()) {
MaybeDumpHlo(*module, after_pass_name, before_pass_name);
void HloPassPipeline::MaybeDumpHloAndSaveFilenames(
HloModuleGroup& module_group, absl::string_view after_pass_name,
absl::string_view before_pass_name) {
for (HloModule* module : module_group.modules()) {
MaybeDumpHloAndSaveFilenames(*module, after_pass_name, before_pass_name);
}
}

View File

@ -90,12 +90,14 @@ class HloPassPipeline : public HloPassInterface {
const DebugOptions& debug_options);
// Maybe dumps the given module or module group depending on flag values
// contained in DebugOptions of module config.
void MaybeDumpHlo(const HloModuleGroup& module_group,
absl::string_view after_pass_name,
absl::string_view before_pass_name);
void MaybeDumpHlo(const HloModule& module, absl::string_view after_pass_name,
absl::string_view before_pass_name);
// contained in DebugOptions of module config. If it is dumped, saves the
// filenames of the dumps into module metadata.
void MaybeDumpHloAndSaveFilenames(HloModuleGroup& module_group,
absl::string_view after_pass_name,
absl::string_view before_pass_name);
void MaybeDumpHloAndSaveFilenames(HloModule& module,
absl::string_view after_pass_name,
absl::string_view before_pass_name);
// Runs the invariant checker on the given HLO. HloT can be either HloModule
// or HloModuleGroup.

View File

@ -26,6 +26,10 @@ limitations under the License.
namespace xla {
namespace {
using ::testing::ElementsAre;
using ::testing::SizeIs;
using ::testing::StrEq;
class HloPassPipelineTest : public HloTestBase {
protected:
StatusOr<HloModuleGroup> ParseModuleGroup(
@ -255,5 +259,43 @@ ENTRY main {
::testing::HasSubstr("Module group pass cannot be run on a module"));
}
// Test that metadata is set when a module group goes through a pass pipeline.
TEST_F(HloPassPipelineTest, SetHloModuleMetadata) {
HloModuleGroup module_group(TestName());
module_group.push_back(CreateNewVerifiedModule());
module_group.push_back(CreateNewVerifiedModule());
HloPassPipeline pipeline(TestName());
pipeline.AddPass<BazToQuxModuleGroupPass>();
pipeline.AddPass<FooToBarModulePass>();
TF_ASSERT_OK(pipeline.RunOnModuleGroup(&module_group).status());
ASSERT_THAT(module_group.modules(), SizeIs(2));
std::vector<std::string> pass_names = {"pipeline-start", "baz2qux",
"foo2bar"};
std::string pipeline_name = std::string(pipeline.name());
for (const HloModule* module : module_group.modules()) {
const HloModuleMetadataProto& metadata = module->metadata().proto();
EXPECT_EQ(metadata.canonical_module_id(), module->unique_id());
EXPECT_EQ(metadata.module_group_name(), module_group.name());
ASSERT_THAT(metadata.pass_metadata(), SizeIs(3));
for (int pass = 0; pass < metadata.pass_metadata().size(); pass++) {
const HloPassMetadata& pass_metadata = metadata.pass_metadata(pass);
EXPECT_NE(pass_metadata.pass_id(), 0);
EXPECT_THAT(pass_metadata.pass_name(), StrEq(pass_names[pass]));
EXPECT_THAT(pass_metadata.pipeline_name(), StrEq(pipeline_name));
EXPECT_FALSE(pass_metadata.module_changed());
EXPECT_EQ(pass_metadata.module_id(), module->unique_id());
EXPECT_THAT(pass_metadata.module_group_module_ids(),
ElementsAre(module_group.module(0).unique_id(),
module_group.module(1).unique_id()));
EXPECT_GT(pass_metadata.start_timestamp_usec(), 0);
EXPECT_LT(pass_metadata.start_timestamp_usec(),
pass_metadata.end_timestamp_usec());
}
}
}
} // namespace
} // namespace xla

View File

@ -261,6 +261,9 @@ message DebugOptions {
// Max number of hlo module dumps in a directory. Set to < 0 for unbounded.
int32 xla_dump_max_hlo_modules = 132;
// Dump HloModuleMetadata as a text proto for each HLO module.
bool xla_dump_module_metadata = 144;
//
// END flags controlling dumping HLO modules.
//
@ -297,7 +300,7 @@ message DebugOptions {
// Enable detailed logging into vlog.
bool xla_detailed_logging = 143;
// Next id: 144
// Next id: 145
// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.