[XLA] Add metadata for HLO modules and passes.
When the xla_dump_module_metadata flag is set, one HloModuleMetadata will be dumped per module. PiperOrigin-RevId: 342091720 Change-Id: I278e96d2a0fcf7d94054058ff9fb27bf80b71e82
This commit is contained in:
parent
9aec83d8e6
commit
083b03e2f9
@ -42,6 +42,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
|
||||
opts.set_xla_dump_hlo_as_html(false);
|
||||
opts.set_xla_dump_include_timestamp(true);
|
||||
opts.set_xla_dump_max_hlo_modules(-1);
|
||||
opts.set_xla_dump_module_metadata(false);
|
||||
#ifdef INTEL_MKL
|
||||
opts.set_xla_cpu_use_mkl_dnn(true);
|
||||
#endif // INTEL_MKL
|
||||
@ -230,7 +231,6 @@ static void AllocateFlags() {
|
||||
};
|
||||
|
||||
flag_objects = new std::vector<tensorflow::Flag>();
|
||||
flag_objects->reserve(55);
|
||||
// Don't use an initializer list for initializing the vector; this would
|
||||
// create a temporary copy, and exceeds the stack space when compiling with
|
||||
// certain configurations.
|
||||
@ -513,6 +513,12 @@ static void AllocateFlags() {
|
||||
flag_values->xla_dump_max_hlo_modules(),
|
||||
"Max number of hlo module dumps in a directory. Set to < 0 for "
|
||||
"unbounded."));
|
||||
flag_objects->push_back(tensorflow::Flag(
|
||||
"xla_dump_module_metadata",
|
||||
bool_setter_for(&DebugOptions::set_xla_dump_module_metadata),
|
||||
flag_values->xla_dump_module_metadata(),
|
||||
"Dumps HloModuleMetadata as text protos to the directory specified "
|
||||
"by --xla_dump_to."));
|
||||
flag_objects->push_back(tensorflow::Flag(
|
||||
"xla_hlo_graph_addresses",
|
||||
bool_setter_for(&DebugOptions::set_xla_hlo_graph_addresses),
|
||||
|
@ -412,6 +412,7 @@ cc_library(
|
||||
"hlo_instruction.cc",
|
||||
"hlo_instructions.cc",
|
||||
"hlo_module.cc",
|
||||
"hlo_module_metadata.cc",
|
||||
"hlo_opcode.cc",
|
||||
"hlo_schedule.cc",
|
||||
"hlo_sharding.cc",
|
||||
@ -428,6 +429,7 @@ cc_library(
|
||||
"hlo_instruction.h",
|
||||
"hlo_instructions.h",
|
||||
"hlo_module.h",
|
||||
"hlo_module_metadata.h",
|
||||
"hlo_opcode.h",
|
||||
"hlo_schedule.h",
|
||||
"hlo_sharding.h",
|
||||
@ -3047,6 +3049,18 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "hlo_module_metadata_test",
|
||||
srcs = ["hlo_module_metadata_test.cc"],
|
||||
deps = [
|
||||
":hlo",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "buffer_value",
|
||||
srcs = ["buffer_value.cc"],
|
||||
|
@ -45,7 +45,8 @@ struct CanonicalDebugOptions {
|
||||
dump_as_url(opts.xla_dump_hlo_as_url()),
|
||||
dump_snapshots(opts.xla_dump_hlo_snapshots()),
|
||||
dump_include_timestamp(opts.xla_dump_include_timestamp()),
|
||||
dump_max_hlo_modules(opts.xla_dump_max_hlo_modules()) {
|
||||
dump_max_hlo_modules(opts.xla_dump_max_hlo_modules()),
|
||||
dump_module_metadata(opts.xla_dump_module_metadata()) {
|
||||
// This constructor examines the values in `opts` and turns on other flags
|
||||
// based on what we think is the user's intent. To reduce confusion about
|
||||
// what was a user-specified value versus an extrapolated value, within this
|
||||
@ -137,18 +138,20 @@ struct CanonicalDebugOptions {
|
||||
bool dump_snapshots;
|
||||
bool dump_include_timestamp;
|
||||
int64 dump_max_hlo_modules;
|
||||
bool dump_module_metadata;
|
||||
};
|
||||
|
||||
void DumpToFileInDirImpl(string_view filename, string_view contents,
|
||||
const CanonicalDebugOptions& opts) {
|
||||
absl::optional<std::string> DumpToFileInDirImpl(
|
||||
string_view filename, string_view contents,
|
||||
const CanonicalDebugOptions& opts) {
|
||||
if (opts.dumping_to_stdout()) {
|
||||
LOG(ERROR) << "Refusing to write " << filename
|
||||
<< " to stdout. Pass --xla_dump_to=<path> to write to a file.";
|
||||
return;
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
if (opts.dump_to.empty()) {
|
||||
return;
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
const string& dir = opts.dump_to;
|
||||
@ -164,7 +167,7 @@ void DumpToFileInDirImpl(string_view filename, string_view contents,
|
||||
if (!status.ok() && !env->IsDirectory(dir).ok()) {
|
||||
LOG(ERROR) << "Could not create directory " << dir
|
||||
<< " for dumping XLA debug data: " << status;
|
||||
return;
|
||||
return absl::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
@ -181,7 +184,7 @@ void DumpToFileInDirImpl(string_view filename, string_view contents,
|
||||
LOG(ERROR) << "Have already dumped " << matches.size()
|
||||
<< " modules, more than the limit of "
|
||||
<< opts.dump_max_hlo_modules;
|
||||
return;
|
||||
return absl::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
@ -192,33 +195,42 @@ void DumpToFileInDirImpl(string_view filename, string_view contents,
|
||||
LOG(ERROR) << "Could not write XLA debug data to " << file_path << ": "
|
||||
<< status;
|
||||
}
|
||||
|
||||
return file_path;
|
||||
}
|
||||
|
||||
void DumpToFileInDirOrStdoutImpl(string_view filename, string_view contents,
|
||||
const CanonicalDebugOptions& opts) {
|
||||
absl::optional<std::string> DumpToFileInDirOrStdoutImpl(
|
||||
string_view filename, string_view contents,
|
||||
const CanonicalDebugOptions& opts) {
|
||||
// Dump to stdout if that's called for.
|
||||
if (opts.dumping_to_stdout()) {
|
||||
std::cout << "*** Begin " << filename << " ***\n"
|
||||
<< contents << "\n*** End " << filename << " ***" << std::endl;
|
||||
return;
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
// Otherwise, dump to a file.
|
||||
DumpToFileInDirImpl(filename, contents, opts);
|
||||
return DumpToFileInDirImpl(filename, contents, opts);
|
||||
}
|
||||
|
||||
void DumpHloModuleImpl(const HloModule& module,
|
||||
const BufferAssignment* buffer_assn,
|
||||
const HloExecutionProfile* profile, string_view prefix,
|
||||
string_view suffix, const CanonicalDebugOptions& opts) {
|
||||
// Returns full file paths of all dumps of the module.
|
||||
std::vector<std::string> DumpHloModuleImpl(const HloModule& module,
|
||||
const BufferAssignment* buffer_assn,
|
||||
const HloExecutionProfile* profile,
|
||||
string_view prefix,
|
||||
string_view suffix,
|
||||
const CanonicalDebugOptions& opts) {
|
||||
string filename = FilenameFor(module, prefix, suffix);
|
||||
|
||||
std::vector<absl::optional<std::string>> file_paths;
|
||||
|
||||
if (opts.dump_as_text) {
|
||||
DumpToFileInDirOrStdoutImpl(StrCat(filename, ".txt"), module.ToString(),
|
||||
opts);
|
||||
file_paths.push_back(DumpToFileInDirOrStdoutImpl(StrCat(filename, ".txt"),
|
||||
module.ToString(), opts));
|
||||
if (buffer_assn) {
|
||||
DumpToFileInDirOrStdoutImpl(StrCat(filename, "-buffer-assignment.txt"),
|
||||
buffer_assn->ToString(), opts);
|
||||
file_paths.push_back(DumpToFileInDirOrStdoutImpl(
|
||||
StrCat(filename, "-buffer-assignment.txt"), buffer_assn->ToString(),
|
||||
opts));
|
||||
}
|
||||
}
|
||||
|
||||
@ -229,7 +241,8 @@ void DumpHloModuleImpl(const HloModule& module,
|
||||
if (!tensorflow::SerializeToStringDeterministic(module_proto, &pb)) {
|
||||
pb = "Failed to serialize HLO module proto.";
|
||||
}
|
||||
DumpToFileInDirImpl(StrCat(filename, ".hlo.pb"), pb, opts);
|
||||
file_paths.push_back(
|
||||
DumpToFileInDirImpl(StrCat(filename, ".hlo.pb"), pb, opts));
|
||||
}
|
||||
|
||||
auto render_graph = [&](RenderedGraphFormat format) {
|
||||
@ -244,13 +257,15 @@ void DumpHloModuleImpl(const HloModule& module,
|
||||
};
|
||||
|
||||
if (opts.dump_as_dot) {
|
||||
DumpToFileInDirImpl(StrFormat("%s.dot", filename),
|
||||
render_graph(RenderedGraphFormat::kDot), opts);
|
||||
file_paths.push_back(
|
||||
DumpToFileInDirImpl(StrFormat("%s.dot", filename),
|
||||
render_graph(RenderedGraphFormat::kDot), opts));
|
||||
}
|
||||
|
||||
if (opts.dump_as_html) {
|
||||
DumpToFileInDirImpl(StrFormat("%s.html", filename),
|
||||
render_graph(RenderedGraphFormat::kHtml), opts);
|
||||
file_paths.push_back(
|
||||
DumpToFileInDirImpl(StrFormat("%s.html", filename),
|
||||
render_graph(RenderedGraphFormat::kHtml), opts));
|
||||
}
|
||||
|
||||
// Special case for rendering graphs as URLs. We'll dump them to a file
|
||||
@ -259,9 +274,35 @@ void DumpHloModuleImpl(const HloModule& module,
|
||||
string url = render_graph(RenderedGraphFormat::kUrl);
|
||||
std::cout << filename << " --> " << url << std::endl;
|
||||
if (!opts.dumping_to_stdout()) {
|
||||
DumpToFileInDirImpl(StrFormat("%s.url", filename), url, opts);
|
||||
file_paths.push_back(
|
||||
DumpToFileInDirImpl(StrFormat("%s.url", filename), url, opts));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::string> dumped_file_paths;
|
||||
for (const absl::optional<std::string>& path : file_paths) {
|
||||
if (path.has_value()) {
|
||||
dumped_file_paths.push_back(*path);
|
||||
}
|
||||
}
|
||||
return dumped_file_paths;
|
||||
}
|
||||
|
||||
void DumpHloModuleMetadata(const HloModuleMetadataProto& metadata,
|
||||
const CanonicalDebugOptions& opts,
|
||||
absl::flat_hash_set<int64>* dumped_module_ids) {
|
||||
// Return if metadata for this module has already been dumped.
|
||||
if (!dumped_module_ids->insert(metadata.canonical_module_id()).second) {
|
||||
return;
|
||||
}
|
||||
std::string filename = absl::StrFormat("module_%04d.metadata.textproto",
|
||||
metadata.canonical_module_id());
|
||||
std::string content;
|
||||
if (tensorflow::protobuf::TextFormat::PrintToString(metadata, &content)) {
|
||||
DumpToFileInDirImpl(filename, content, opts);
|
||||
} else {
|
||||
LOG(ERROR) << "Failed to convert HloModuleMetadataProto to text.";
|
||||
}
|
||||
}
|
||||
|
||||
static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
|
||||
@ -381,18 +422,17 @@ bool DumpingToStdout(const DebugOptions& opts) {
|
||||
return CanonicalDebugOptions(opts).dumping_to_stdout();
|
||||
}
|
||||
|
||||
void DumpHloModuleBetweenPassesIfEnabled(string_view pipeline_name,
|
||||
string_view before_pass_name,
|
||||
string_view after_pass_name,
|
||||
const HloModule& module) {
|
||||
std::vector<std::string> DumpHloModuleBetweenPassesIfEnabled(
|
||||
string_view pipeline_name, string_view before_pass_name,
|
||||
string_view after_pass_name, const HloModule& module) {
|
||||
CanonicalDebugOptions opts(module.config().debug_options());
|
||||
if (!opts.should_dump_module(module.name())) {
|
||||
return;
|
||||
return {};
|
||||
}
|
||||
|
||||
if (!opts.should_dump_pass(before_pass_name) &&
|
||||
!opts.should_dump_pass(after_pass_name)) {
|
||||
return;
|
||||
return {};
|
||||
}
|
||||
|
||||
int64 step_number = StepNumberForModule(module);
|
||||
@ -401,8 +441,8 @@ void DumpHloModuleBetweenPassesIfEnabled(string_view pipeline_name,
|
||||
string filename_suffix =
|
||||
StrFormat("%04d.%s.after_%s.before_%s", step_number, pipeline_name,
|
||||
after_pass_name, before_pass_name);
|
||||
DumpHloModuleImpl(module, /*buffer_assn=*/nullptr, /*profile=*/nullptr,
|
||||
timestamp, filename_suffix, opts);
|
||||
return DumpHloModuleImpl(module, /*buffer_assn=*/nullptr, /*profile=*/nullptr,
|
||||
timestamp, filename_suffix, opts);
|
||||
}
|
||||
|
||||
void DumpHloModuleDuringPassIfEnabled(string_view pass_name,
|
||||
@ -488,4 +528,21 @@ void DumpHloSnapshotIfEnabled(const HloSnapshot& snapshot,
|
||||
DumpToFileInDirImpl(filename, pb, canonical_opts);
|
||||
}
|
||||
|
||||
void DumpHloModuleMetadataIfEnabled(const std::vector<HloModule*>& modules) {
|
||||
absl::flat_hash_set<int64> dumped_module_ids;
|
||||
for (const HloModule* module : modules) {
|
||||
CanonicalDebugOptions opts(module->config().debug_options());
|
||||
if (!opts.dump_module_metadata) {
|
||||
continue;
|
||||
}
|
||||
DumpHloModuleMetadata(module->metadata().proto(), opts, &dumped_module_ids);
|
||||
const absl::optional<HloModuleMetadataProto>& prepartitioning_metadata =
|
||||
module->metadata().prepartitioning_metadata();
|
||||
if (prepartitioning_metadata.has_value()) {
|
||||
DumpHloModuleMetadata(*prepartitioning_metadata, opts,
|
||||
&dumped_module_ids);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -75,11 +75,11 @@ void DumpHloModuleIfEnabled(const HloModule& module,
|
||||
absl::string_view name);
|
||||
|
||||
// Dumps the given HLO module after running one HLO pass and before running
|
||||
// another, if that's enabled.
|
||||
void DumpHloModuleBetweenPassesIfEnabled(absl::string_view pipeline_name,
|
||||
absl::string_view before_pass_name,
|
||||
absl::string_view after_pass_name,
|
||||
const HloModule& module);
|
||||
// another, if that's enabled. Returns the full file paths of all dumps of the
|
||||
// module, or an empty vector if nothing was dumped.
|
||||
std::vector<std::string> DumpHloModuleBetweenPassesIfEnabled(
|
||||
absl::string_view pipeline_name, absl::string_view before_pass_name,
|
||||
absl::string_view after_pass_name, const HloModule& module);
|
||||
|
||||
// Dumps the given HLO module during the given HLO pass, if that's enabled.
|
||||
//
|
||||
@ -100,6 +100,8 @@ void DumpHloSnapshotIfEnabled(const HloModule& module,
|
||||
void DumpHloSnapshotIfEnabled(const HloSnapshot& snapshot,
|
||||
const DebugOptions& opts);
|
||||
|
||||
void DumpHloModuleMetadataIfEnabled(const std::vector<HloModule*>& modules);
|
||||
|
||||
// Returns true if we should dump data for an HloModule. This is useful if you
|
||||
// want to check if DumpToFileInDir{,OrStdout} will do anything before
|
||||
// generating an expensive string.
|
||||
|
@ -528,3 +528,66 @@ message HloSnapshot {
|
||||
// The name of the platform used to run the graph.
|
||||
string execution_platform = 4;
|
||||
}
|
||||
|
||||
// Metadata for an HLO module. Dumped after HLO passes and before LLO lowering
|
||||
// with filename module_####.metadata.textproto, where #### is
|
||||
// canonical_module_id.
|
||||
message HloModuleMetadataProto {
|
||||
// Uniquely identifies an HloModuleMetadata. Equal to the first unique_id
|
||||
// of the module (a module may go through multiple unique_ids). If a module
|
||||
// is partitioned into multiple modules, those modules will each have a new
|
||||
// HloModuleMetadata with a different canonical_module_id.
|
||||
int64 canonical_module_id = 1;
|
||||
|
||||
// Name of the module group that the module is part of.
|
||||
string module_group_name = 2;
|
||||
|
||||
// The canonical module id of the module that this one is partitioned from,
|
||||
// if applicable.
|
||||
int64 original_module_id = 3;
|
||||
|
||||
// The canonical module ids of the modules that this one is partitioned into,
|
||||
// if applicable.
|
||||
repeated int64 partitioned_module_ids = 4;
|
||||
|
||||
// Metadata for the HLO passes that are run on the module.
|
||||
repeated HloPassMetadata pass_metadata = 5;
|
||||
}
|
||||
|
||||
// Metadata for one run of an HLO pass on a module. Provides more information
|
||||
// when processing debug dumps of HloProtos about the order of HLO passes and
|
||||
// various other stats like duration. `pass_id` may also be used to identify a
|
||||
// particular run of a pass in debug info that propagates through stages of
|
||||
// compilation.
|
||||
message HloPassMetadata {
|
||||
// For a given module, pass_id uniquely identifies a run of an HLO pass on
|
||||
// that module. Note that a pass_id may not always refer to the same pass
|
||||
// because the order of passes during compilation may change. For finding
|
||||
// metadata for a particular pass, pass_name and pipeline_name would be more
|
||||
// reliable, although note that they may not be unique.
|
||||
int64 pass_id = 1;
|
||||
string pass_name = 2;
|
||||
string pipeline_name = 3;
|
||||
|
||||
// Filenames of the dumps of the module after this pass ran. Module may be
|
||||
// dumped in multiple formats, and the order of formats in this field will
|
||||
// stay consistent across passes.
|
||||
repeated string dump_filenames = 4;
|
||||
|
||||
// Return value of pass.Run(). True if this pass changed the module, or, in
|
||||
// the case where the module was run through this pass as part of a module
|
||||
// group, true if this pass changed any module in the same module group.
|
||||
bool module_changed = 5;
|
||||
|
||||
// The unique_id of the module that this pass is run on. May be different from
|
||||
// the canonical_module_id of the HloModuleMetadata that this HloPassMetadata
|
||||
// is inside.
|
||||
int64 module_id = 6;
|
||||
|
||||
// If the module went through this pass as part of a module group, this is
|
||||
// set as the ids of all the modules in the module group. Empty otherwise.
|
||||
repeated int64 module_group_module_ids = 7;
|
||||
|
||||
int64 start_timestamp_usec = 8;
|
||||
int64 end_timestamp_usec = 9;
|
||||
}
|
||||
|
@ -44,7 +44,10 @@ namespace xla {
|
||||
HloModule::HloModule(const string& name, HloModuleConfig config)
|
||||
: name_(NameUniquer::GetSanitizedName(name)),
|
||||
config_(std::move(config)),
|
||||
unique_id_(next_unique_module_id_++) {}
|
||||
unique_id_(next_unique_module_id_++),
|
||||
metadata_(tensorflow::Env::Default()) {
|
||||
metadata_.set_canonical_module_id(unique_id_);
|
||||
}
|
||||
|
||||
Status HloModule::set_schedule(HloSchedule schedule) {
|
||||
TF_RET_CHECK(schedule.module() == this);
|
||||
|
@ -35,6 +35,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module_metadata.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
|
||||
#include "tensorflow/compiler/xla/service/name_uniquer.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
@ -363,6 +364,16 @@ class HloModule {
|
||||
return cross_program_prefetches_;
|
||||
}
|
||||
|
||||
const HloModuleMetadata& metadata() const { return metadata_; }
|
||||
HloModuleMetadata* metadata() { return &metadata_; }
|
||||
|
||||
// Moves (not copies) metadata from this HloModule to `module`. To be used
|
||||
// in cases like HloModuleGroup::ReplaceModule when metadata should be
|
||||
// transferred out of a module before it's destroyed.
|
||||
void MoveMetadataToModule(HloModule* module) {
|
||||
module->metadata_ = std::move(metadata_);
|
||||
}
|
||||
|
||||
private:
|
||||
HloComputation* AddComputationInternal(
|
||||
std::unique_ptr<HloComputation> computation, bool is_entry,
|
||||
@ -413,6 +424,9 @@ class HloModule {
|
||||
|
||||
// Arguments to be prefetched across programs.
|
||||
std::vector<std::pair<int64, ShapeIndex>> cross_program_prefetches_;
|
||||
|
||||
// Metadata for this module, such as its canonical id and the HLO passes run.
|
||||
HloModuleMetadata metadata_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
@ -96,12 +96,14 @@ uint64 HloModuleGroup::Hash() const {
|
||||
}
|
||||
|
||||
void HloModuleGroup::push_back(std::unique_ptr<HloModule> module) {
|
||||
module->metadata()->set_module_group_name(name());
|
||||
modules_.push_back(std::move(module));
|
||||
module_ptrs_.push_back(modules_.back().get());
|
||||
}
|
||||
|
||||
void HloModuleGroup::ReplaceModule(int index,
|
||||
std::unique_ptr<HloModule> module) {
|
||||
modules_.at(index)->MoveMetadataToModule(module.get());
|
||||
modules_.at(index) = std::move(module);
|
||||
module_ptrs_.at(index) = modules_.at(index).get();
|
||||
}
|
||||
|
@ -27,6 +27,8 @@ namespace xla {
|
||||
namespace {
|
||||
|
||||
namespace op = ::xla::testing::opcode_matchers;
|
||||
using ::testing::Property;
|
||||
using ::testing::StrEq;
|
||||
|
||||
class HloModuleGroupTest : public HloTestBase {
|
||||
protected:
|
||||
@ -202,6 +204,31 @@ ENTRY entry {
|
||||
}
|
||||
}
|
||||
|
||||
// Test that metadata is transferred when a module is replaced.
|
||||
TEST_F(HloModuleGroupTest, ReplaceModuleMetadata) {
|
||||
auto old_module = CreateNewVerifiedModule();
|
||||
int old_module_id = old_module->unique_id();
|
||||
old_module->metadata()->RecordPassStart();
|
||||
TF_EXPECT_OK(old_module->metadata()->set_current_pass_name("fake pass"));
|
||||
|
||||
HloModuleGroup group(std::move(old_module));
|
||||
EXPECT_EQ(group.module(0).metadata()->proto().module_group_name(),
|
||||
group.name());
|
||||
|
||||
auto new_module = CreateNewVerifiedModule();
|
||||
group.ReplaceModule(0, std::move(new_module));
|
||||
|
||||
EXPECT_NE(group.module(0).unique_id(), old_module_id);
|
||||
const HloModuleMetadataProto& module_metadata =
|
||||
group.module(0).metadata()->proto();
|
||||
EXPECT_EQ(module_metadata.canonical_module_id(), old_module_id);
|
||||
|
||||
const HloPassMetadata& pass_metadata =
|
||||
*module_metadata.pass_metadata().rbegin();
|
||||
EXPECT_THAT(pass_metadata,
|
||||
Property(&HloPassMetadata::pass_name, StrEq("fake pass")));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace xla
|
||||
|
87
tensorflow/compiler/xla/service/hlo_module_metadata.cc
Normal file
87
tensorflow/compiler/xla/service/hlo_module_metadata.cc
Normal file
@ -0,0 +1,87 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_module_metadata.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
StatusOr<HloPassMetadata*> HloModuleMetadata::GetCurrentHloPassMetadata() {
|
||||
if (running_passes_.empty()) {
|
||||
return NotFound(
|
||||
"HloPassMetadata for currently running pass not found, either because "
|
||||
"the pass did not call RecordPassStart or because a pass is "
|
||||
"creating/switching modules without using "
|
||||
"HloModuleGroup::ReplaceModule.");
|
||||
}
|
||||
return running_passes_.back();
|
||||
}
|
||||
|
||||
Status HloModuleMetadata::MutateCurrentHloPassMetadata(
|
||||
const std::function<void(HloPassMetadata*)>& mutator) {
|
||||
TF_ASSIGN_OR_RETURN(HloPassMetadata * pass_metadata,
|
||||
GetCurrentHloPassMetadata());
|
||||
mutator(pass_metadata);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void HloModuleMetadata::RecordPassStart() {
|
||||
HloPassMetadata* pass_metadata = module_metadata_.add_pass_metadata();
|
||||
pass_metadata->set_pass_id(next_pass_id_++);
|
||||
pass_metadata->set_start_timestamp_usec(env_->NowMicros());
|
||||
running_passes_.push_back(pass_metadata);
|
||||
}
|
||||
|
||||
Status HloModuleMetadata::RecordPassEnd() {
|
||||
TF_ASSIGN_OR_RETURN(HloPassMetadata * pass_metadata,
|
||||
GetCurrentHloPassMetadata());
|
||||
pass_metadata->set_end_timestamp_usec(env_->NowMicros());
|
||||
running_passes_.pop_back();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void HloModuleMetadata::set_prepartitioning_metadata(
|
||||
const HloModuleMetadata& prepartitioning_metadata) {
|
||||
module_metadata_.set_original_module_id(
|
||||
prepartitioning_metadata.proto().canonical_module_id());
|
||||
prepartitioning_metadata_ = prepartitioning_metadata.proto();
|
||||
prepartitioning_metadata_->clear_pass_metadata();
|
||||
|
||||
// Because HloPassMetadata represents the completion of a pass, metadata for
|
||||
// all currently running passes need to be moved over to the new module.
|
||||
absl::flat_hash_set<HloPassMetadata*> running_passes(
|
||||
prepartitioning_metadata.running_passes_.begin(),
|
||||
prepartitioning_metadata.running_passes_.end());
|
||||
for (const HloPassMetadata& pass_metadata :
|
||||
prepartitioning_metadata.proto().pass_metadata()) {
|
||||
if (running_passes.contains(&pass_metadata)) {
|
||||
HloPassMetadata* added_pass_metadata =
|
||||
module_metadata_.add_pass_metadata();
|
||||
*added_pass_metadata = pass_metadata;
|
||||
running_passes_.push_back(added_pass_metadata);
|
||||
next_pass_id_ =
|
||||
std::max(next_pass_id_,
|
||||
static_cast<int64>(added_pass_metadata->pass_id()) + 1);
|
||||
} else {
|
||||
*prepartitioning_metadata_->add_pass_metadata() = pass_metadata;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace xla
|
126
tensorflow/compiler/xla/service/hlo_module_metadata.h
Normal file
126
tensorflow/compiler/xla/service/hlo_module_metadata.h
Normal file
@ -0,0 +1,126 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_METADATA_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_METADATA_H_
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Wrapper class for HloModuleMetadataProto to avoid allowing callers to mutate
|
||||
// arbitrary fields. Specifically, callers cannot set timestamps or ids or
|
||||
// set the fields of any pass not currently running.
|
||||
class HloModuleMetadata {
|
||||
public:
|
||||
explicit HloModuleMetadata(tensorflow::Env* env) : env_(env) {}
|
||||
|
||||
const HloModuleMetadataProto& proto() const { return module_metadata_; }
|
||||
|
||||
// Creates a new HloPassMetadata. All calls to RecordPassStart should be
|
||||
// matched by a later call to RecordPassEnd.
|
||||
void RecordPassStart();
|
||||
|
||||
// Marks the currently running pass as finished. Returns NotFound if metadata
|
||||
// for the currently running pass cannot be found.
|
||||
Status RecordPassEnd();
|
||||
|
||||
const absl::optional<HloModuleMetadataProto>& prepartitioning_metadata()
|
||||
const {
|
||||
return prepartitioning_metadata_;
|
||||
}
|
||||
void set_prepartitioning_metadata(
|
||||
const HloModuleMetadata& prepartitioning_metadata);
|
||||
|
||||
// Setters for HloModuleMetadataProto.
|
||||
void set_module_group_name(const std::string& name) {
|
||||
module_metadata_.set_module_group_name(name);
|
||||
}
|
||||
void set_canonical_module_id(int64 id) {
|
||||
module_metadata_.set_canonical_module_id(id);
|
||||
}
|
||||
void add_partitioned_module_id(int64 id) {
|
||||
module_metadata_.add_partitioned_module_ids(id);
|
||||
}
|
||||
|
||||
// Setters for the current HloPassMetadata.
|
||||
Status set_current_pass_name(const std::string& pass_name) {
|
||||
return MutateCurrentHloPassMetadata(
|
||||
[&pass_name](HloPassMetadata* pass_metadata) {
|
||||
pass_metadata->set_pass_name(pass_name);
|
||||
});
|
||||
}
|
||||
Status set_current_pass_pipeline_name(const std::string& pipeline_name) {
|
||||
return MutateCurrentHloPassMetadata(
|
||||
[&pipeline_name](HloPassMetadata* pass_metadata) {
|
||||
pass_metadata->set_pipeline_name(pipeline_name);
|
||||
});
|
||||
}
|
||||
Status add_current_pass_dump_filename(const std::string& dump_filename) {
|
||||
return MutateCurrentHloPassMetadata(
|
||||
[&dump_filename](HloPassMetadata* pass_metadata) {
|
||||
pass_metadata->add_dump_filenames(dump_filename);
|
||||
});
|
||||
}
|
||||
Status set_current_pass_module_changed(bool module_changed) {
|
||||
return MutateCurrentHloPassMetadata(
|
||||
[&module_changed](HloPassMetadata* pass_metadata) {
|
||||
pass_metadata->set_module_changed(module_changed);
|
||||
});
|
||||
}
|
||||
Status set_current_pass_module_id(int64 module_id) {
|
||||
return MutateCurrentHloPassMetadata(
|
||||
[&module_id](HloPassMetadata* pass_metadata) {
|
||||
pass_metadata->set_module_id(module_id);
|
||||
});
|
||||
}
|
||||
Status add_current_pass_module_group_module_id(int64 module_id) {
|
||||
return MutateCurrentHloPassMetadata(
|
||||
[&module_id](HloPassMetadata* pass_metadata) {
|
||||
pass_metadata->add_module_group_module_ids(module_id);
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
// Gets mutable metadata for the currently running pass. If passes are nested,
|
||||
// finds the deepest one still running. Returns NotFound if metadata for the
|
||||
// currently running pass cannot be found.
|
||||
StatusOr<HloPassMetadata*> GetCurrentHloPassMetadata();
|
||||
|
||||
Status MutateCurrentHloPassMetadata(
|
||||
const std::function<void(HloPassMetadata*)>& mutator);
|
||||
|
||||
HloModuleMetadataProto module_metadata_;
|
||||
tensorflow::Env* env_;
|
||||
int64 next_pass_id_ = 1;
|
||||
|
||||
// Stack of metadata for passes that are currently running. Size > 1 iff
|
||||
// passes are nested.
|
||||
std::vector<HloPassMetadata*> running_passes_;
|
||||
|
||||
// Metadata from before the module was partitioned, if applicable.
|
||||
absl::optional<HloModuleMetadataProto> prepartitioning_metadata_ =
|
||||
absl::nullopt;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_METADATA_H_
|
145
tensorflow/compiler/xla/service/hlo_module_metadata_test.cc
Normal file
145
tensorflow/compiler/xla/service/hlo_module_metadata_test.cc
Normal file
@ -0,0 +1,145 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_module_metadata.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAre;
|
||||
using ::testing::Property;
|
||||
using ::testing::StrEq;
|
||||
|
||||
class TestEnv : public tensorflow::EnvWrapper {
|
||||
public:
|
||||
TestEnv() : EnvWrapper(Env::Default()) {}
|
||||
|
||||
uint64 NowMicros() const override { return current_micros_; }
|
||||
|
||||
void SetCurrentMicros(uint64 micros) { current_micros_ = micros; }
|
||||
|
||||
private:
|
||||
uint64 current_micros_ = 1;
|
||||
};
|
||||
|
||||
TEST(HloModuleMetadata, RecordsPassStart) {
|
||||
TestEnv env;
|
||||
HloModuleMetadata module_metadata(&env);
|
||||
env.SetCurrentMicros(1234);
|
||||
module_metadata.RecordPassStart();
|
||||
EXPECT_THAT(
|
||||
module_metadata.proto().pass_metadata(),
|
||||
ElementsAre(Property(&HloPassMetadata::start_timestamp_usec, 1234)));
|
||||
}
|
||||
|
||||
TEST(HloModuleMetadata, RecordsPassEnd) {
|
||||
TestEnv env;
|
||||
HloModuleMetadata module_metadata(&env);
|
||||
module_metadata.RecordPassStart();
|
||||
env.SetCurrentMicros(4321);
|
||||
EXPECT_IS_OK(module_metadata.RecordPassEnd());
|
||||
EXPECT_THAT(
|
||||
module_metadata.proto().pass_metadata(),
|
||||
ElementsAre(Property(&HloPassMetadata::end_timestamp_usec, 4321)));
|
||||
}
|
||||
|
||||
TEST(HloModuleMetadata, RecordsPassEndInNestedMetadata) {
|
||||
TestEnv env;
|
||||
HloModuleMetadata module_metadata(&env);
|
||||
module_metadata.RecordPassStart();
|
||||
module_metadata.RecordPassStart();
|
||||
env.SetCurrentMicros(111);
|
||||
EXPECT_IS_OK(module_metadata.RecordPassEnd());
|
||||
EXPECT_THAT(module_metadata.proto().pass_metadata(),
|
||||
ElementsAre(Property(&HloPassMetadata::end_timestamp_usec, 0),
|
||||
Property(&HloPassMetadata::end_timestamp_usec, 111)));
|
||||
|
||||
env.SetCurrentMicros(222);
|
||||
EXPECT_IS_OK(module_metadata.RecordPassEnd());
|
||||
EXPECT_THAT(module_metadata.proto().pass_metadata(),
|
||||
ElementsAre(Property(&HloPassMetadata::end_timestamp_usec, 222),
|
||||
Property(&HloPassMetadata::end_timestamp_usec, 111)));
|
||||
}
|
||||
|
||||
TEST(HloModuleMetadata, RecordPassEndReturnsNotFound) {
|
||||
HloModuleMetadata module_metadata(tensorflow::Env::Default());
|
||||
EXPECT_EQ(module_metadata.RecordPassEnd().code(),
|
||||
tensorflow::error::NOT_FOUND);
|
||||
|
||||
module_metadata.RecordPassStart();
|
||||
EXPECT_IS_OK(module_metadata.RecordPassEnd());
|
||||
EXPECT_EQ(module_metadata.RecordPassEnd().code(),
|
||||
tensorflow::error::NOT_FOUND);
|
||||
}
|
||||
|
||||
TEST(HloModuleMetadata, SetsHloPassMetadataFields) {
|
||||
HloModuleMetadata module_metadata(tensorflow::Env::Default());
|
||||
module_metadata.RecordPassStart();
|
||||
EXPECT_IS_OK(module_metadata.set_current_pass_name("fake name"));
|
||||
EXPECT_THAT(
|
||||
module_metadata.proto().pass_metadata(),
|
||||
ElementsAre(Property(&HloPassMetadata::pass_name, StrEq("fake name"))));
|
||||
}
|
||||
|
||||
TEST(HloModuleMetadata, SetsHloPassMetadataFieldsInNestedMetadata) {
|
||||
HloModuleMetadata module_metadata(tensorflow::Env::Default());
|
||||
module_metadata.RecordPassStart();
|
||||
module_metadata.RecordPassStart();
|
||||
EXPECT_IS_OK(module_metadata.set_current_pass_name("fake name"));
|
||||
EXPECT_THAT(
|
||||
module_metadata.proto().pass_metadata(),
|
||||
ElementsAre(Property(&HloPassMetadata::pass_name, StrEq("")),
|
||||
Property(&HloPassMetadata::pass_name, StrEq("fake name"))));
|
||||
}
|
||||
|
||||
TEST(HloModuleMetadata, SetterReturnsNotFound) {
|
||||
HloModuleMetadata module_metadata(tensorflow::Env::Default());
|
||||
EXPECT_EQ(module_metadata.set_current_pass_name("fake name").code(),
|
||||
tensorflow::error::NOT_FOUND);
|
||||
}
|
||||
|
||||
TEST(HloModuleMetadata, CopiesRunningPrepartitioningPasses) {
|
||||
HloModuleMetadata old_module_metadata(tensorflow::Env::Default());
|
||||
old_module_metadata.RecordPassStart();
|
||||
EXPECT_IS_OK(old_module_metadata.set_current_pass_name("outer pass"));
|
||||
|
||||
old_module_metadata.RecordPassStart();
|
||||
EXPECT_IS_OK(old_module_metadata.set_current_pass_name("finished pass"));
|
||||
EXPECT_IS_OK(old_module_metadata.RecordPassEnd());
|
||||
|
||||
old_module_metadata.RecordPassStart();
|
||||
EXPECT_IS_OK(old_module_metadata.set_current_pass_name("inner pass"));
|
||||
|
||||
HloModuleMetadata new_module_metadata(tensorflow::Env::Default());
|
||||
new_module_metadata.set_prepartitioning_metadata(old_module_metadata);
|
||||
|
||||
// Passes that are still running go in the new module.
|
||||
EXPECT_THAT(
|
||||
new_module_metadata.proto().pass_metadata(),
|
||||
ElementsAre(Property(&HloPassMetadata::pass_name, StrEq("outer pass")),
|
||||
Property(&HloPassMetadata::pass_name, StrEq("inner pass"))));
|
||||
|
||||
// Passes that finished go in the prepartitioning metadata.
|
||||
EXPECT_THAT(new_module_metadata.prepartitioning_metadata()->pass_metadata(),
|
||||
ElementsAre(Property(&HloPassMetadata::pass_name,
|
||||
StrEq("finished pass"))));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
@ -31,6 +31,72 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
namespace {
|
||||
|
||||
void RecordPassStartMetadata(HloModule& module, const std::string& pass_name,
|
||||
const std::string& pipeline_name) {
|
||||
module.metadata()->RecordPassStart();
|
||||
// An HloPassMetadata was just created so Status should always be OK.
|
||||
TF_CHECK_OK(module.metadata()->set_current_pass_name(pass_name));
|
||||
TF_CHECK_OK(module.metadata()->set_current_pass_pipeline_name(pipeline_name));
|
||||
}
|
||||
|
||||
void RecordPassStartMetadata(HloModuleGroup& module_group,
|
||||
const std::string& pass_name,
|
||||
const std::string& pipeline_name) {
|
||||
for (HloModule* module : module_group.modules()) {
|
||||
RecordPassStartMetadata(*module, pass_name, pipeline_name);
|
||||
}
|
||||
}
|
||||
|
||||
Status AttemptRecordPassEndMetadata(HloModule& module,
|
||||
const std::string& pass_name,
|
||||
bool module_changed) {
|
||||
// Module id is set here instead of RecordPassStartMetadata because it may
|
||||
// change in the middle of the pass, and we want the final id.
|
||||
TF_RETURN_IF_ERROR(
|
||||
module.metadata()->set_current_pass_module_id(module.unique_id()));
|
||||
TF_RETURN_IF_ERROR(
|
||||
module.metadata()->set_current_pass_module_changed(module_changed));
|
||||
TF_RETURN_IF_ERROR(module.metadata()->RecordPassEnd());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void RecordPassEndMetadata(HloModule& module, const std::string& pass_name,
|
||||
bool module_changed) {
|
||||
Status status =
|
||||
AttemptRecordPassEndMetadata(module, pass_name, module_changed);
|
||||
if (!status.ok()) {
|
||||
LOG(FATAL) << status.error_message();
|
||||
}
|
||||
}
|
||||
|
||||
Status AttemptRecordPassEndMetadata(HloModuleGroup& module_group,
|
||||
const std::string& pass_name,
|
||||
bool module_changed) {
|
||||
for (HloModule* module : module_group.modules()) {
|
||||
for (HloModule* other_module : module_group.modules()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
module->metadata()->add_current_pass_module_group_module_id(
|
||||
other_module->unique_id()));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
AttemptRecordPassEndMetadata(*module, pass_name, module_changed));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void RecordPassEndMetadata(HloModuleGroup& module_group,
|
||||
const std::string& pass_name, bool module_changed) {
|
||||
Status status =
|
||||
AttemptRecordPassEndMetadata(module_group, pass_name, module_changed);
|
||||
if (!status.ok()) {
|
||||
LOG(FATAL) << status.error_message();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename HloT>
|
||||
Status HloPassPipeline::RunInvariantCheckers(
|
||||
HloT* hlo, absl::string_view after_pass_name) {
|
||||
@ -54,34 +120,48 @@ Status HloPassPipeline::RunInvariantCheckers(
|
||||
template <typename HloT>
|
||||
StatusOr<bool> HloPassPipeline::RunPassesInternal(
|
||||
HloT* hlo, absl::Span<HloPassInterface* const> passes) {
|
||||
string last_pass_name = "pipeline-start";
|
||||
TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, last_pass_name));
|
||||
static constexpr absl::string_view kPipelineStart = "pipeline-start";
|
||||
static constexpr absl::string_view kPipelineEnd = "pipeline-end";
|
||||
std::string pipeline_name = std::string(name());
|
||||
|
||||
TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, kPipelineStart));
|
||||
|
||||
RecordPassStartMetadata(*hlo, std::string(kPipelineStart), pipeline_name);
|
||||
MaybeDumpHloAndSaveFilenames(*hlo,
|
||||
/*after_pass_name=*/kPipelineStart,
|
||||
/*before_pass_name=*/passes.empty()
|
||||
? kPipelineEnd
|
||||
: passes.front()->name());
|
||||
RecordPassEndMetadata(*hlo, std::string(kPipelineStart),
|
||||
/*module_changed=*/false);
|
||||
|
||||
bool changed = false;
|
||||
for (HloPassInterface* pass : passes) {
|
||||
for (int i = 0; i < passes.size(); i++) {
|
||||
HloPassInterface* pass = passes[i];
|
||||
XLA_SCOPED_LOGGING_TIMER(absl::StrCat("HLO pass: ", pass->name()));
|
||||
absl::string_view pass_name = pass->name();
|
||||
std::string pass_name = std::string(pass->name());
|
||||
VLOG(1) << " HLO pass " << pass_name;
|
||||
VLOG(2) << " Module hash " << hlo->Hash();
|
||||
MaybeDumpHlo(*hlo,
|
||||
/*after_pass_name=*/last_pass_name,
|
||||
/*before_pass_name=*/pass_name);
|
||||
if (!pass->IsPassPipeline()) {
|
||||
compilation_stats_->StartPass(pass_name);
|
||||
}
|
||||
RecordPassStartMetadata(*hlo, pass_name, pipeline_name);
|
||||
TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo));
|
||||
MaybeDumpHloAndSaveFilenames(*hlo,
|
||||
/*after_pass_name=*/pass_name,
|
||||
/*before_pass_name=*/i + 1 >= passes.size()
|
||||
? kPipelineEnd
|
||||
: passes[i + 1]->name());
|
||||
RecordPassEndMetadata(*hlo, pass_name, pass_changed);
|
||||
changed |= pass_changed;
|
||||
if (pass_changed) {
|
||||
VLOG(3) << " Pass caused changes" << pass->name();
|
||||
}
|
||||
TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass_name));
|
||||
last_pass_name = string(pass_name);
|
||||
if (!pass->IsPassPipeline()) {
|
||||
compilation_stats_->EndPass(pass_name);
|
||||
}
|
||||
}
|
||||
MaybeDumpHlo(*hlo,
|
||||
/*after_pass_name=*/last_pass_name,
|
||||
/*before_pass_name=*/"pipeline-end");
|
||||
return changed;
|
||||
}
|
||||
|
||||
@ -129,18 +209,23 @@ std::vector<HloPassInterface*> HloPassPipeline::GetEnabledPasses(
|
||||
return enabled_passes;
|
||||
}
|
||||
|
||||
void HloPassPipeline::MaybeDumpHlo(const HloModule& module,
|
||||
absl::string_view after_pass_name,
|
||||
absl::string_view before_pass_name) {
|
||||
DumpHloModuleBetweenPassesIfEnabled(name(), before_pass_name, after_pass_name,
|
||||
module);
|
||||
void HloPassPipeline::MaybeDumpHloAndSaveFilenames(
|
||||
HloModule& module, absl::string_view after_pass_name,
|
||||
absl::string_view before_pass_name) {
|
||||
for (const std::string& filename : DumpHloModuleBetweenPassesIfEnabled(
|
||||
name(), before_pass_name, after_pass_name, module)) {
|
||||
Status status = module.metadata()->add_current_pass_dump_filename(filename);
|
||||
if (!status.ok()) {
|
||||
LOG(FATAL) << status.error_message();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void HloPassPipeline::MaybeDumpHlo(const HloModuleGroup& module_group,
|
||||
absl::string_view after_pass_name,
|
||||
absl::string_view before_pass_name) {
|
||||
for (const HloModule* module : module_group.modules()) {
|
||||
MaybeDumpHlo(*module, after_pass_name, before_pass_name);
|
||||
void HloPassPipeline::MaybeDumpHloAndSaveFilenames(
|
||||
HloModuleGroup& module_group, absl::string_view after_pass_name,
|
||||
absl::string_view before_pass_name) {
|
||||
for (HloModule* module : module_group.modules()) {
|
||||
MaybeDumpHloAndSaveFilenames(*module, after_pass_name, before_pass_name);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -90,12 +90,14 @@ class HloPassPipeline : public HloPassInterface {
|
||||
const DebugOptions& debug_options);
|
||||
|
||||
// Maybe dumps the given module or module group depending on flag values
|
||||
// contained in DebugOptions of module config.
|
||||
void MaybeDumpHlo(const HloModuleGroup& module_group,
|
||||
absl::string_view after_pass_name,
|
||||
absl::string_view before_pass_name);
|
||||
void MaybeDumpHlo(const HloModule& module, absl::string_view after_pass_name,
|
||||
absl::string_view before_pass_name);
|
||||
// contained in DebugOptions of module config. If it is dumped, saves the
|
||||
// filenames of the dumps into module metadata.
|
||||
void MaybeDumpHloAndSaveFilenames(HloModuleGroup& module_group,
|
||||
absl::string_view after_pass_name,
|
||||
absl::string_view before_pass_name);
|
||||
void MaybeDumpHloAndSaveFilenames(HloModule& module,
|
||||
absl::string_view after_pass_name,
|
||||
absl::string_view before_pass_name);
|
||||
|
||||
// Runs the invariant checker on the given HLO. HloT can be either HloModule
|
||||
// or HloModuleGroup.
|
||||
|
@ -26,6 +26,10 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAre;
|
||||
using ::testing::SizeIs;
|
||||
using ::testing::StrEq;
|
||||
|
||||
class HloPassPipelineTest : public HloTestBase {
|
||||
protected:
|
||||
StatusOr<HloModuleGroup> ParseModuleGroup(
|
||||
@ -255,5 +259,43 @@ ENTRY main {
|
||||
::testing::HasSubstr("Module group pass cannot be run on a module"));
|
||||
}
|
||||
|
||||
// Test that metadata is set when a module group goes through a pass pipeline.
|
||||
TEST_F(HloPassPipelineTest, SetHloModuleMetadata) {
|
||||
HloModuleGroup module_group(TestName());
|
||||
module_group.push_back(CreateNewVerifiedModule());
|
||||
module_group.push_back(CreateNewVerifiedModule());
|
||||
|
||||
HloPassPipeline pipeline(TestName());
|
||||
pipeline.AddPass<BazToQuxModuleGroupPass>();
|
||||
pipeline.AddPass<FooToBarModulePass>();
|
||||
TF_ASSERT_OK(pipeline.RunOnModuleGroup(&module_group).status());
|
||||
ASSERT_THAT(module_group.modules(), SizeIs(2));
|
||||
|
||||
std::vector<std::string> pass_names = {"pipeline-start", "baz2qux",
|
||||
"foo2bar"};
|
||||
std::string pipeline_name = std::string(pipeline.name());
|
||||
for (const HloModule* module : module_group.modules()) {
|
||||
const HloModuleMetadataProto& metadata = module->metadata().proto();
|
||||
EXPECT_EQ(metadata.canonical_module_id(), module->unique_id());
|
||||
EXPECT_EQ(metadata.module_group_name(), module_group.name());
|
||||
|
||||
ASSERT_THAT(metadata.pass_metadata(), SizeIs(3));
|
||||
for (int pass = 0; pass < metadata.pass_metadata().size(); pass++) {
|
||||
const HloPassMetadata& pass_metadata = metadata.pass_metadata(pass);
|
||||
EXPECT_NE(pass_metadata.pass_id(), 0);
|
||||
EXPECT_THAT(pass_metadata.pass_name(), StrEq(pass_names[pass]));
|
||||
EXPECT_THAT(pass_metadata.pipeline_name(), StrEq(pipeline_name));
|
||||
EXPECT_FALSE(pass_metadata.module_changed());
|
||||
EXPECT_EQ(pass_metadata.module_id(), module->unique_id());
|
||||
EXPECT_THAT(pass_metadata.module_group_module_ids(),
|
||||
ElementsAre(module_group.module(0).unique_id(),
|
||||
module_group.module(1).unique_id()));
|
||||
EXPECT_GT(pass_metadata.start_timestamp_usec(), 0);
|
||||
EXPECT_LT(pass_metadata.start_timestamp_usec(),
|
||||
pass_metadata.end_timestamp_usec());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -261,6 +261,9 @@ message DebugOptions {
|
||||
// Max number of hlo module dumps in a directory. Set to < 0 for unbounded.
|
||||
int32 xla_dump_max_hlo_modules = 132;
|
||||
|
||||
// Dump HloModuleMetadata as a text proto for each HLO module.
|
||||
bool xla_dump_module_metadata = 144;
|
||||
|
||||
//
|
||||
// END flags controlling dumping HLO modules.
|
||||
//
|
||||
@ -297,7 +300,7 @@ message DebugOptions {
|
||||
// Enable detailed logging into vlog.
|
||||
bool xla_detailed_logging = 143;
|
||||
|
||||
// Next id: 144
|
||||
// Next id: 145
|
||||
|
||||
// Extra options to pass to the compilation backend (e.g. LLVM); specific
|
||||
// interpretation of these values is left to the backend.
|
||||
|
Loading…
x
Reference in New Issue
Block a user