[XLA] Add metadata for HLO modules and passes.
When the xla_dump_module_metadata flag is set, one HloModuleMetadata will be dumped per module.

PiperOrigin-RevId: 342091720
Change-Id: I278e96d2a0fcf7d94054058ff9fb27bf80b71e82
This commit is contained in:
parent
9aec83d8e6
commit
083b03e2f9
@ -42,6 +42,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
|
|||||||
opts.set_xla_dump_hlo_as_html(false);
|
opts.set_xla_dump_hlo_as_html(false);
|
||||||
opts.set_xla_dump_include_timestamp(true);
|
opts.set_xla_dump_include_timestamp(true);
|
||||||
opts.set_xla_dump_max_hlo_modules(-1);
|
opts.set_xla_dump_max_hlo_modules(-1);
|
||||||
|
opts.set_xla_dump_module_metadata(false);
|
||||||
#ifdef INTEL_MKL
|
#ifdef INTEL_MKL
|
||||||
opts.set_xla_cpu_use_mkl_dnn(true);
|
opts.set_xla_cpu_use_mkl_dnn(true);
|
||||||
#endif // INTEL_MKL
|
#endif // INTEL_MKL
|
||||||
@ -230,7 +231,6 @@ static void AllocateFlags() {
|
|||||||
};
|
};
|
||||||
|
|
||||||
flag_objects = new std::vector<tensorflow::Flag>();
|
flag_objects = new std::vector<tensorflow::Flag>();
|
||||||
flag_objects->reserve(55);
|
|
||||||
// Don't use an initializer list for initializing the vector; this would
|
// Don't use an initializer list for initializing the vector; this would
|
||||||
// create a temporary copy, and exceeds the stack space when compiling with
|
// create a temporary copy, and exceeds the stack space when compiling with
|
||||||
// certain configurations.
|
// certain configurations.
|
||||||
@ -513,6 +513,12 @@ static void AllocateFlags() {
|
|||||||
flag_values->xla_dump_max_hlo_modules(),
|
flag_values->xla_dump_max_hlo_modules(),
|
||||||
"Max number of hlo module dumps in a directory. Set to < 0 for "
|
"Max number of hlo module dumps in a directory. Set to < 0 for "
|
||||||
"unbounded."));
|
"unbounded."));
|
||||||
|
flag_objects->push_back(tensorflow::Flag(
|
||||||
|
"xla_dump_module_metadata",
|
||||||
|
bool_setter_for(&DebugOptions::set_xla_dump_module_metadata),
|
||||||
|
flag_values->xla_dump_module_metadata(),
|
||||||
|
"Dumps HloModuleMetadata as text protos to the directory specified "
|
||||||
|
"by --xla_dump_to."));
|
||||||
flag_objects->push_back(tensorflow::Flag(
|
flag_objects->push_back(tensorflow::Flag(
|
||||||
"xla_hlo_graph_addresses",
|
"xla_hlo_graph_addresses",
|
||||||
bool_setter_for(&DebugOptions::set_xla_hlo_graph_addresses),
|
bool_setter_for(&DebugOptions::set_xla_hlo_graph_addresses),
|
||||||
|
|||||||
@ -412,6 +412,7 @@ cc_library(
|
|||||||
"hlo_instruction.cc",
|
"hlo_instruction.cc",
|
||||||
"hlo_instructions.cc",
|
"hlo_instructions.cc",
|
||||||
"hlo_module.cc",
|
"hlo_module.cc",
|
||||||
|
"hlo_module_metadata.cc",
|
||||||
"hlo_opcode.cc",
|
"hlo_opcode.cc",
|
||||||
"hlo_schedule.cc",
|
"hlo_schedule.cc",
|
||||||
"hlo_sharding.cc",
|
"hlo_sharding.cc",
|
||||||
@ -428,6 +429,7 @@ cc_library(
|
|||||||
"hlo_instruction.h",
|
"hlo_instruction.h",
|
||||||
"hlo_instructions.h",
|
"hlo_instructions.h",
|
||||||
"hlo_module.h",
|
"hlo_module.h",
|
||||||
|
"hlo_module_metadata.h",
|
||||||
"hlo_opcode.h",
|
"hlo_opcode.h",
|
||||||
"hlo_schedule.h",
|
"hlo_schedule.h",
|
||||||
"hlo_sharding.h",
|
"hlo_sharding.h",
|
||||||
@ -3047,6 +3049,18 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "hlo_module_metadata_test",
|
||||||
|
srcs = ["hlo_module_metadata_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":hlo",
|
||||||
|
"//tensorflow/compiler/xla:test",
|
||||||
|
"//tensorflow/compiler/xla:test_helpers",
|
||||||
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||||
|
"//tensorflow/stream_executor/lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "buffer_value",
|
name = "buffer_value",
|
||||||
srcs = ["buffer_value.cc"],
|
srcs = ["buffer_value.cc"],
|
||||||
|
|||||||
@ -45,7 +45,8 @@ struct CanonicalDebugOptions {
|
|||||||
dump_as_url(opts.xla_dump_hlo_as_url()),
|
dump_as_url(opts.xla_dump_hlo_as_url()),
|
||||||
dump_snapshots(opts.xla_dump_hlo_snapshots()),
|
dump_snapshots(opts.xla_dump_hlo_snapshots()),
|
||||||
dump_include_timestamp(opts.xla_dump_include_timestamp()),
|
dump_include_timestamp(opts.xla_dump_include_timestamp()),
|
||||||
dump_max_hlo_modules(opts.xla_dump_max_hlo_modules()) {
|
dump_max_hlo_modules(opts.xla_dump_max_hlo_modules()),
|
||||||
|
dump_module_metadata(opts.xla_dump_module_metadata()) {
|
||||||
// This constructor examines the values in `opts` and turns on other flags
|
// This constructor examines the values in `opts` and turns on other flags
|
||||||
// based on what we think is the user's intent. To reduce confusion about
|
// based on what we think is the user's intent. To reduce confusion about
|
||||||
// what was a user-specified value versus an extrapolated value, within this
|
// what was a user-specified value versus an extrapolated value, within this
|
||||||
@ -137,18 +138,20 @@ struct CanonicalDebugOptions {
|
|||||||
bool dump_snapshots;
|
bool dump_snapshots;
|
||||||
bool dump_include_timestamp;
|
bool dump_include_timestamp;
|
||||||
int64 dump_max_hlo_modules;
|
int64 dump_max_hlo_modules;
|
||||||
|
bool dump_module_metadata;
|
||||||
};
|
};
|
||||||
|
|
||||||
void DumpToFileInDirImpl(string_view filename, string_view contents,
|
absl::optional<std::string> DumpToFileInDirImpl(
|
||||||
|
string_view filename, string_view contents,
|
||||||
const CanonicalDebugOptions& opts) {
|
const CanonicalDebugOptions& opts) {
|
||||||
if (opts.dumping_to_stdout()) {
|
if (opts.dumping_to_stdout()) {
|
||||||
LOG(ERROR) << "Refusing to write " << filename
|
LOG(ERROR) << "Refusing to write " << filename
|
||||||
<< " to stdout. Pass --xla_dump_to=<path> to write to a file.";
|
<< " to stdout. Pass --xla_dump_to=<path> to write to a file.";
|
||||||
return;
|
return absl::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (opts.dump_to.empty()) {
|
if (opts.dump_to.empty()) {
|
||||||
return;
|
return absl::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
const string& dir = opts.dump_to;
|
const string& dir = opts.dump_to;
|
||||||
@ -164,7 +167,7 @@ void DumpToFileInDirImpl(string_view filename, string_view contents,
|
|||||||
if (!status.ok() && !env->IsDirectory(dir).ok()) {
|
if (!status.ok() && !env->IsDirectory(dir).ok()) {
|
||||||
LOG(ERROR) << "Could not create directory " << dir
|
LOG(ERROR) << "Could not create directory " << dir
|
||||||
<< " for dumping XLA debug data: " << status;
|
<< " for dumping XLA debug data: " << status;
|
||||||
return;
|
return absl::nullopt;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -181,7 +184,7 @@ void DumpToFileInDirImpl(string_view filename, string_view contents,
|
|||||||
LOG(ERROR) << "Have already dumped " << matches.size()
|
LOG(ERROR) << "Have already dumped " << matches.size()
|
||||||
<< " modules, more than the limit of "
|
<< " modules, more than the limit of "
|
||||||
<< opts.dump_max_hlo_modules;
|
<< opts.dump_max_hlo_modules;
|
||||||
return;
|
return absl::nullopt;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -192,33 +195,42 @@ void DumpToFileInDirImpl(string_view filename, string_view contents,
|
|||||||
LOG(ERROR) << "Could not write XLA debug data to " << file_path << ": "
|
LOG(ERROR) << "Could not write XLA debug data to " << file_path << ": "
|
||||||
<< status;
|
<< status;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return file_path;
|
||||||
}
|
}
|
||||||
|
|
||||||
void DumpToFileInDirOrStdoutImpl(string_view filename, string_view contents,
|
absl::optional<std::string> DumpToFileInDirOrStdoutImpl(
|
||||||
|
string_view filename, string_view contents,
|
||||||
const CanonicalDebugOptions& opts) {
|
const CanonicalDebugOptions& opts) {
|
||||||
// Dump to stdout if that's called for.
|
// Dump to stdout if that's called for.
|
||||||
if (opts.dumping_to_stdout()) {
|
if (opts.dumping_to_stdout()) {
|
||||||
std::cout << "*** Begin " << filename << " ***\n"
|
std::cout << "*** Begin " << filename << " ***\n"
|
||||||
<< contents << "\n*** End " << filename << " ***" << std::endl;
|
<< contents << "\n*** End " << filename << " ***" << std::endl;
|
||||||
return;
|
return absl::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Otherwise, dump to a file.
|
// Otherwise, dump to a file.
|
||||||
DumpToFileInDirImpl(filename, contents, opts);
|
return DumpToFileInDirImpl(filename, contents, opts);
|
||||||
}
|
}
|
||||||
|
|
||||||
void DumpHloModuleImpl(const HloModule& module,
|
// Returns full file paths of all dumps of the module.
|
||||||
|
std::vector<std::string> DumpHloModuleImpl(const HloModule& module,
|
||||||
const BufferAssignment* buffer_assn,
|
const BufferAssignment* buffer_assn,
|
||||||
const HloExecutionProfile* profile, string_view prefix,
|
const HloExecutionProfile* profile,
|
||||||
string_view suffix, const CanonicalDebugOptions& opts) {
|
string_view prefix,
|
||||||
|
string_view suffix,
|
||||||
|
const CanonicalDebugOptions& opts) {
|
||||||
string filename = FilenameFor(module, prefix, suffix);
|
string filename = FilenameFor(module, prefix, suffix);
|
||||||
|
|
||||||
|
std::vector<absl::optional<std::string>> file_paths;
|
||||||
|
|
||||||
if (opts.dump_as_text) {
|
if (opts.dump_as_text) {
|
||||||
DumpToFileInDirOrStdoutImpl(StrCat(filename, ".txt"), module.ToString(),
|
file_paths.push_back(DumpToFileInDirOrStdoutImpl(StrCat(filename, ".txt"),
|
||||||
opts);
|
module.ToString(), opts));
|
||||||
if (buffer_assn) {
|
if (buffer_assn) {
|
||||||
DumpToFileInDirOrStdoutImpl(StrCat(filename, "-buffer-assignment.txt"),
|
file_paths.push_back(DumpToFileInDirOrStdoutImpl(
|
||||||
buffer_assn->ToString(), opts);
|
StrCat(filename, "-buffer-assignment.txt"), buffer_assn->ToString(),
|
||||||
|
opts));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -229,7 +241,8 @@ void DumpHloModuleImpl(const HloModule& module,
|
|||||||
if (!tensorflow::SerializeToStringDeterministic(module_proto, &pb)) {
|
if (!tensorflow::SerializeToStringDeterministic(module_proto, &pb)) {
|
||||||
pb = "Failed to serialize HLO module proto.";
|
pb = "Failed to serialize HLO module proto.";
|
||||||
}
|
}
|
||||||
DumpToFileInDirImpl(StrCat(filename, ".hlo.pb"), pb, opts);
|
file_paths.push_back(
|
||||||
|
DumpToFileInDirImpl(StrCat(filename, ".hlo.pb"), pb, opts));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto render_graph = [&](RenderedGraphFormat format) {
|
auto render_graph = [&](RenderedGraphFormat format) {
|
||||||
@ -244,13 +257,15 @@ void DumpHloModuleImpl(const HloModule& module,
|
|||||||
};
|
};
|
||||||
|
|
||||||
if (opts.dump_as_dot) {
|
if (opts.dump_as_dot) {
|
||||||
|
file_paths.push_back(
|
||||||
DumpToFileInDirImpl(StrFormat("%s.dot", filename),
|
DumpToFileInDirImpl(StrFormat("%s.dot", filename),
|
||||||
render_graph(RenderedGraphFormat::kDot), opts);
|
render_graph(RenderedGraphFormat::kDot), opts));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (opts.dump_as_html) {
|
if (opts.dump_as_html) {
|
||||||
|
file_paths.push_back(
|
||||||
DumpToFileInDirImpl(StrFormat("%s.html", filename),
|
DumpToFileInDirImpl(StrFormat("%s.html", filename),
|
||||||
render_graph(RenderedGraphFormat::kHtml), opts);
|
render_graph(RenderedGraphFormat::kHtml), opts));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Special case for rendering graphs as URLs. We'll dump them to a file
|
// Special case for rendering graphs as URLs. We'll dump them to a file
|
||||||
@ -259,9 +274,35 @@ void DumpHloModuleImpl(const HloModule& module,
|
|||||||
string url = render_graph(RenderedGraphFormat::kUrl);
|
string url = render_graph(RenderedGraphFormat::kUrl);
|
||||||
std::cout << filename << " --> " << url << std::endl;
|
std::cout << filename << " --> " << url << std::endl;
|
||||||
if (!opts.dumping_to_stdout()) {
|
if (!opts.dumping_to_stdout()) {
|
||||||
DumpToFileInDirImpl(StrFormat("%s.url", filename), url, opts);
|
file_paths.push_back(
|
||||||
|
DumpToFileInDirImpl(StrFormat("%s.url", filename), url, opts));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> dumped_file_paths;
|
||||||
|
for (const absl::optional<std::string>& path : file_paths) {
|
||||||
|
if (path.has_value()) {
|
||||||
|
dumped_file_paths.push_back(*path);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return dumped_file_paths;
|
||||||
|
}
|
||||||
|
|
||||||
|
void DumpHloModuleMetadata(const HloModuleMetadataProto& metadata,
|
||||||
|
const CanonicalDebugOptions& opts,
|
||||||
|
absl::flat_hash_set<int64>* dumped_module_ids) {
|
||||||
|
// Return if metadata for this module has already been dumped.
|
||||||
|
if (!dumped_module_ids->insert(metadata.canonical_module_id()).second) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
std::string filename = absl::StrFormat("module_%04d.metadata.textproto",
|
||||||
|
metadata.canonical_module_id());
|
||||||
|
std::string content;
|
||||||
|
if (tensorflow::protobuf::TextFormat::PrintToString(metadata, &content)) {
|
||||||
|
DumpToFileInDirImpl(filename, content, opts);
|
||||||
|
} else {
|
||||||
|
LOG(ERROR) << "Failed to convert HloModuleMetadataProto to text.";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
|
static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
|
||||||
@ -381,18 +422,17 @@ bool DumpingToStdout(const DebugOptions& opts) {
|
|||||||
return CanonicalDebugOptions(opts).dumping_to_stdout();
|
return CanonicalDebugOptions(opts).dumping_to_stdout();
|
||||||
}
|
}
|
||||||
|
|
||||||
void DumpHloModuleBetweenPassesIfEnabled(string_view pipeline_name,
|
std::vector<std::string> DumpHloModuleBetweenPassesIfEnabled(
|
||||||
string_view before_pass_name,
|
string_view pipeline_name, string_view before_pass_name,
|
||||||
string_view after_pass_name,
|
string_view after_pass_name, const HloModule& module) {
|
||||||
const HloModule& module) {
|
|
||||||
CanonicalDebugOptions opts(module.config().debug_options());
|
CanonicalDebugOptions opts(module.config().debug_options());
|
||||||
if (!opts.should_dump_module(module.name())) {
|
if (!opts.should_dump_module(module.name())) {
|
||||||
return;
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!opts.should_dump_pass(before_pass_name) &&
|
if (!opts.should_dump_pass(before_pass_name) &&
|
||||||
!opts.should_dump_pass(after_pass_name)) {
|
!opts.should_dump_pass(after_pass_name)) {
|
||||||
return;
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
int64 step_number = StepNumberForModule(module);
|
int64 step_number = StepNumberForModule(module);
|
||||||
@ -401,7 +441,7 @@ void DumpHloModuleBetweenPassesIfEnabled(string_view pipeline_name,
|
|||||||
string filename_suffix =
|
string filename_suffix =
|
||||||
StrFormat("%04d.%s.after_%s.before_%s", step_number, pipeline_name,
|
StrFormat("%04d.%s.after_%s.before_%s", step_number, pipeline_name,
|
||||||
after_pass_name, before_pass_name);
|
after_pass_name, before_pass_name);
|
||||||
DumpHloModuleImpl(module, /*buffer_assn=*/nullptr, /*profile=*/nullptr,
|
return DumpHloModuleImpl(module, /*buffer_assn=*/nullptr, /*profile=*/nullptr,
|
||||||
timestamp, filename_suffix, opts);
|
timestamp, filename_suffix, opts);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -488,4 +528,21 @@ void DumpHloSnapshotIfEnabled(const HloSnapshot& snapshot,
|
|||||||
DumpToFileInDirImpl(filename, pb, canonical_opts);
|
DumpToFileInDirImpl(filename, pb, canonical_opts);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void DumpHloModuleMetadataIfEnabled(const std::vector<HloModule*>& modules) {
|
||||||
|
absl::flat_hash_set<int64> dumped_module_ids;
|
||||||
|
for (const HloModule* module : modules) {
|
||||||
|
CanonicalDebugOptions opts(module->config().debug_options());
|
||||||
|
if (!opts.dump_module_metadata) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
DumpHloModuleMetadata(module->metadata().proto(), opts, &dumped_module_ids);
|
||||||
|
const absl::optional<HloModuleMetadataProto>& prepartitioning_metadata =
|
||||||
|
module->metadata().prepartitioning_metadata();
|
||||||
|
if (prepartitioning_metadata.has_value()) {
|
||||||
|
DumpHloModuleMetadata(*prepartitioning_metadata, opts,
|
||||||
|
&dumped_module_ids);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|||||||
@ -75,11 +75,11 @@ void DumpHloModuleIfEnabled(const HloModule& module,
|
|||||||
absl::string_view name);
|
absl::string_view name);
|
||||||
|
|
||||||
// Dumps the given HLO module after running one HLO pass and before running
|
// Dumps the given HLO module after running one HLO pass and before running
|
||||||
// another, if that's enabled.
|
// another, if that's enabled. Returns the full file paths of all dumps of the
|
||||||
void DumpHloModuleBetweenPassesIfEnabled(absl::string_view pipeline_name,
|
// module, or an empty vector if nothing was dumped.
|
||||||
absl::string_view before_pass_name,
|
std::vector<std::string> DumpHloModuleBetweenPassesIfEnabled(
|
||||||
absl::string_view after_pass_name,
|
absl::string_view pipeline_name, absl::string_view before_pass_name,
|
||||||
const HloModule& module);
|
absl::string_view after_pass_name, const HloModule& module);
|
||||||
|
|
||||||
// Dumps the given HLO module during the given HLO pass, if that's enabled.
|
// Dumps the given HLO module during the given HLO pass, if that's enabled.
|
||||||
//
|
//
|
||||||
@ -100,6 +100,8 @@ void DumpHloSnapshotIfEnabled(const HloModule& module,
|
|||||||
void DumpHloSnapshotIfEnabled(const HloSnapshot& snapshot,
|
void DumpHloSnapshotIfEnabled(const HloSnapshot& snapshot,
|
||||||
const DebugOptions& opts);
|
const DebugOptions& opts);
|
||||||
|
|
||||||
|
void DumpHloModuleMetadataIfEnabled(const std::vector<HloModule*>& modules);
|
||||||
|
|
||||||
// Returns true if we should dump data for an HloModule. This is useful if you
|
// Returns true if we should dump data for an HloModule. This is useful if you
|
||||||
// want to check if DumpToFileInDir{,OrStdout} will do anything before
|
// want to check if DumpToFileInDir{,OrStdout} will do anything before
|
||||||
// generating an expensive string.
|
// generating an expensive string.
|
||||||
|
|||||||
@ -528,3 +528,66 @@ message HloSnapshot {
|
|||||||
// The name of the platform used to run the graph.
|
// The name of the platform used to run the graph.
|
||||||
string execution_platform = 4;
|
string execution_platform = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Metadata for an HLO module. Dumped after HLO passes and before LLO lowering
|
||||||
|
// with filename module_####.metadata.textproto, where #### is
|
||||||
|
// canonical_module_id.
|
||||||
|
message HloModuleMetadataProto {
|
||||||
|
// Uniquely identifies an HloModuleMetadata. Equal to the first unique_id
|
||||||
|
// of the module (a module may go through multiple unique_ids). If a module
|
||||||
|
// is partitioned into multiple modules, those modules will each have a new
|
||||||
|
// HloModuleMetadata with a different canonical_module_id.
|
||||||
|
int64 canonical_module_id = 1;
|
||||||
|
|
||||||
|
// Name of the module group that the module is part of.
|
||||||
|
string module_group_name = 2;
|
||||||
|
|
||||||
|
// The canonical module id of the module that this one is partitioned from,
|
||||||
|
// if applicable.
|
||||||
|
int64 original_module_id = 3;
|
||||||
|
|
||||||
|
// The canonical module ids of the modules that this one is partitioned into,
|
||||||
|
// if applicable.
|
||||||
|
repeated int64 partitioned_module_ids = 4;
|
||||||
|
|
||||||
|
// Metadata for the HLO passes that are run on the module.
|
||||||
|
repeated HloPassMetadata pass_metadata = 5;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Metadata for one run of an HLO pass on a module. Provides more information
|
||||||
|
// when processing debug dumps of HloProtos about the order of HLO passes and
|
||||||
|
// various other stats like duration. `pass_id` may also be used to identify a
|
||||||
|
// particular run of a pass in debug info that propagates through stages of
|
||||||
|
// compilation.
|
||||||
|
message HloPassMetadata {
|
||||||
|
// For a given module, pass_id uniquely identifies a run of an HLO pass on
|
||||||
|
// that module. Note that a pass_id may not always refer to the same pass
|
||||||
|
// because the order of passes during compilation may change. For finding
|
||||||
|
// metadata for a particular pass, pass_name and pipeline_name would be more
|
||||||
|
// reliable, although note that they may not be unique.
|
||||||
|
int64 pass_id = 1;
|
||||||
|
string pass_name = 2;
|
||||||
|
string pipeline_name = 3;
|
||||||
|
|
||||||
|
// Filenames of the dumps of the module after this pass ran. Module may be
|
||||||
|
// dumped in multiple formats, and the order of formats in this field will
|
||||||
|
// stay consistent across passes.
|
||||||
|
repeated string dump_filenames = 4;
|
||||||
|
|
||||||
|
// Return value of pass.Run(). True if this pass changed the module, or, in
|
||||||
|
// the case where the module was run through this pass as part of a module
|
||||||
|
// group, true if this pass changed any module in the same module group.
|
||||||
|
bool module_changed = 5;
|
||||||
|
|
||||||
|
// The unique_id of the module that this pass is run on. May be different from
|
||||||
|
// the canonical_module_id of the HloModuleMetadata that this HloPassMetadata
|
||||||
|
// is inside.
|
||||||
|
int64 module_id = 6;
|
||||||
|
|
||||||
|
// If the module went through this pass as part of a module group, this is
|
||||||
|
// set as the ids of all the modules in the module group. Empty otherwise.
|
||||||
|
repeated int64 module_group_module_ids = 7;
|
||||||
|
|
||||||
|
int64 start_timestamp_usec = 8;
|
||||||
|
int64 end_timestamp_usec = 9;
|
||||||
|
}
|
||||||
|
|||||||
@ -44,7 +44,10 @@ namespace xla {
|
|||||||
HloModule::HloModule(const string& name, HloModuleConfig config)
|
HloModule::HloModule(const string& name, HloModuleConfig config)
|
||||||
: name_(NameUniquer::GetSanitizedName(name)),
|
: name_(NameUniquer::GetSanitizedName(name)),
|
||||||
config_(std::move(config)),
|
config_(std::move(config)),
|
||||||
unique_id_(next_unique_module_id_++) {}
|
unique_id_(next_unique_module_id_++),
|
||||||
|
metadata_(tensorflow::Env::Default()) {
|
||||||
|
metadata_.set_canonical_module_id(unique_id_);
|
||||||
|
}
|
||||||
|
|
||||||
Status HloModule::set_schedule(HloSchedule schedule) {
|
Status HloModule::set_schedule(HloSchedule schedule) {
|
||||||
TF_RET_CHECK(schedule.module() == this);
|
TF_RET_CHECK(schedule.module() == this);
|
||||||
|
|||||||
@ -35,6 +35,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
|
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_module_metadata.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
|
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
|
||||||
#include "tensorflow/compiler/xla/service/name_uniquer.h"
|
#include "tensorflow/compiler/xla/service/name_uniquer.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
@ -363,6 +364,16 @@ class HloModule {
|
|||||||
return cross_program_prefetches_;
|
return cross_program_prefetches_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const HloModuleMetadata& metadata() const { return metadata_; }
|
||||||
|
HloModuleMetadata* metadata() { return &metadata_; }
|
||||||
|
|
||||||
|
// Moves (not copies) metadata from this HloModule to `module`. To be used
|
||||||
|
// in cases like HloModuleGroup::ReplaceModule when metadata should be
|
||||||
|
// transferred out of a module before it's destroyed.
|
||||||
|
void MoveMetadataToModule(HloModule* module) {
|
||||||
|
module->metadata_ = std::move(metadata_);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
HloComputation* AddComputationInternal(
|
HloComputation* AddComputationInternal(
|
||||||
std::unique_ptr<HloComputation> computation, bool is_entry,
|
std::unique_ptr<HloComputation> computation, bool is_entry,
|
||||||
@ -413,6 +424,9 @@ class HloModule {
|
|||||||
|
|
||||||
// Arguments to be prefetched across programs.
|
// Arguments to be prefetched across programs.
|
||||||
std::vector<std::pair<int64, ShapeIndex>> cross_program_prefetches_;
|
std::vector<std::pair<int64, ShapeIndex>> cross_program_prefetches_;
|
||||||
|
|
||||||
|
// Metadata for this module, such as its canonical id and the HLO passes run.
|
||||||
|
HloModuleMetadata metadata_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|||||||
@ -96,12 +96,14 @@ uint64 HloModuleGroup::Hash() const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void HloModuleGroup::push_back(std::unique_ptr<HloModule> module) {
|
void HloModuleGroup::push_back(std::unique_ptr<HloModule> module) {
|
||||||
|
module->metadata()->set_module_group_name(name());
|
||||||
modules_.push_back(std::move(module));
|
modules_.push_back(std::move(module));
|
||||||
module_ptrs_.push_back(modules_.back().get());
|
module_ptrs_.push_back(modules_.back().get());
|
||||||
}
|
}
|
||||||
|
|
||||||
void HloModuleGroup::ReplaceModule(int index,
|
void HloModuleGroup::ReplaceModule(int index,
|
||||||
std::unique_ptr<HloModule> module) {
|
std::unique_ptr<HloModule> module) {
|
||||||
|
modules_.at(index)->MoveMetadataToModule(module.get());
|
||||||
modules_.at(index) = std::move(module);
|
modules_.at(index) = std::move(module);
|
||||||
module_ptrs_.at(index) = modules_.at(index).get();
|
module_ptrs_.at(index) = modules_.at(index).get();
|
||||||
}
|
}
|
||||||
|
|||||||
@ -27,6 +27,8 @@ namespace xla {
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
namespace op = ::xla::testing::opcode_matchers;
|
namespace op = ::xla::testing::opcode_matchers;
|
||||||
|
using ::testing::Property;
|
||||||
|
using ::testing::StrEq;
|
||||||
|
|
||||||
class HloModuleGroupTest : public HloTestBase {
|
class HloModuleGroupTest : public HloTestBase {
|
||||||
protected:
|
protected:
|
||||||
@ -202,6 +204,31 @@ ENTRY entry {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that metadata is transferred when a module is replaced.
|
||||||
|
TEST_F(HloModuleGroupTest, ReplaceModuleMetadata) {
|
||||||
|
auto old_module = CreateNewVerifiedModule();
|
||||||
|
int old_module_id = old_module->unique_id();
|
||||||
|
old_module->metadata()->RecordPassStart();
|
||||||
|
TF_EXPECT_OK(old_module->metadata()->set_current_pass_name("fake pass"));
|
||||||
|
|
||||||
|
HloModuleGroup group(std::move(old_module));
|
||||||
|
EXPECT_EQ(group.module(0).metadata()->proto().module_group_name(),
|
||||||
|
group.name());
|
||||||
|
|
||||||
|
auto new_module = CreateNewVerifiedModule();
|
||||||
|
group.ReplaceModule(0, std::move(new_module));
|
||||||
|
|
||||||
|
EXPECT_NE(group.module(0).unique_id(), old_module_id);
|
||||||
|
const HloModuleMetadataProto& module_metadata =
|
||||||
|
group.module(0).metadata()->proto();
|
||||||
|
EXPECT_EQ(module_metadata.canonical_module_id(), old_module_id);
|
||||||
|
|
||||||
|
const HloPassMetadata& pass_metadata =
|
||||||
|
*module_metadata.pass_metadata().rbegin();
|
||||||
|
EXPECT_THAT(pass_metadata,
|
||||||
|
Property(&HloPassMetadata::pass_name, StrEq("fake pass")));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|||||||
87
tensorflow/compiler/xla/service/hlo_module_metadata.cc
Normal file
87
tensorflow/compiler/xla/service/hlo_module_metadata.cc
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_module_metadata.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_set.h"
|
||||||
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
StatusOr<HloPassMetadata*> HloModuleMetadata::GetCurrentHloPassMetadata() {
|
||||||
|
if (running_passes_.empty()) {
|
||||||
|
return NotFound(
|
||||||
|
"HloPassMetadata for currently running pass not found, either because "
|
||||||
|
"the pass did not call RecordPassStart or because a pass is "
|
||||||
|
"creating/switching modules without using "
|
||||||
|
"HloModuleGroup::ReplaceModule.");
|
||||||
|
}
|
||||||
|
return running_passes_.back();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status HloModuleMetadata::MutateCurrentHloPassMetadata(
|
||||||
|
const std::function<void(HloPassMetadata*)>& mutator) {
|
||||||
|
TF_ASSIGN_OR_RETURN(HloPassMetadata * pass_metadata,
|
||||||
|
GetCurrentHloPassMetadata());
|
||||||
|
mutator(pass_metadata);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
void HloModuleMetadata::RecordPassStart() {
|
||||||
|
HloPassMetadata* pass_metadata = module_metadata_.add_pass_metadata();
|
||||||
|
pass_metadata->set_pass_id(next_pass_id_++);
|
||||||
|
pass_metadata->set_start_timestamp_usec(env_->NowMicros());
|
||||||
|
running_passes_.push_back(pass_metadata);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status HloModuleMetadata::RecordPassEnd() {
|
||||||
|
TF_ASSIGN_OR_RETURN(HloPassMetadata * pass_metadata,
|
||||||
|
GetCurrentHloPassMetadata());
|
||||||
|
pass_metadata->set_end_timestamp_usec(env_->NowMicros());
|
||||||
|
running_passes_.pop_back();
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
void HloModuleMetadata::set_prepartitioning_metadata(
|
||||||
|
const HloModuleMetadata& prepartitioning_metadata) {
|
||||||
|
module_metadata_.set_original_module_id(
|
||||||
|
prepartitioning_metadata.proto().canonical_module_id());
|
||||||
|
prepartitioning_metadata_ = prepartitioning_metadata.proto();
|
||||||
|
prepartitioning_metadata_->clear_pass_metadata();
|
||||||
|
|
||||||
|
// Because HloPassMetadata represents the completion of a pass, metadata for
|
||||||
|
// all currently running passes must be moved over to the new module.
|
||||||
|
absl::flat_hash_set<HloPassMetadata*> running_passes(
|
||||||
|
prepartitioning_metadata.running_passes_.begin(),
|
||||||
|
prepartitioning_metadata.running_passes_.end());
|
||||||
|
for (const HloPassMetadata& pass_metadata :
|
||||||
|
prepartitioning_metadata.proto().pass_metadata()) {
|
||||||
|
if (running_passes.contains(&pass_metadata)) {
|
||||||
|
HloPassMetadata* added_pass_metadata =
|
||||||
|
module_metadata_.add_pass_metadata();
|
||||||
|
*added_pass_metadata = pass_metadata;
|
||||||
|
running_passes_.push_back(added_pass_metadata);
|
||||||
|
next_pass_id_ =
|
||||||
|
std::max(next_pass_id_,
|
||||||
|
static_cast<int64>(added_pass_metadata->pass_id()) + 1);
|
||||||
|
} else {
|
||||||
|
*prepartitioning_metadata_->add_pass_metadata() = pass_metadata;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace xla
|
||||||
126
tensorflow/compiler/xla/service/hlo_module_metadata.h
Normal file
126
tensorflow/compiler/xla/service/hlo_module_metadata.h
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_METADATA_H_
|
||||||
|
#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_METADATA_H_
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
|
#include "absl/types/optional.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||||
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
// Wrapper class for HloModuleMetadataProto to avoid allowing callers to mutate
|
||||||
|
// arbitrary fields. Specifically, callers cannot set timestamps or ids or
|
||||||
|
// set the fields of any pass not currently running.
|
||||||
|
class HloModuleMetadata {
|
||||||
|
public:
|
||||||
|
explicit HloModuleMetadata(tensorflow::Env* env) : env_(env) {}
|
||||||
|
|
||||||
|
const HloModuleMetadataProto& proto() const { return module_metadata_; }
|
||||||
|
|
||||||
|
// Creates a new HloPassMetadata. All calls to RecordPassStart should be
|
||||||
|
// matched by a later call to RecordPassEnd.
|
||||||
|
void RecordPassStart();
|
||||||
|
|
||||||
|
// Marks the currently running pass as finished. Returns NotFound if metadata
|
||||||
|
// for the currently running pass cannot be found.
|
||||||
|
Status RecordPassEnd();
|
||||||
|
|
||||||
|
const absl::optional<HloModuleMetadataProto>& prepartitioning_metadata()
|
||||||
|
const {
|
||||||
|
return prepartitioning_metadata_;
|
||||||
|
}
|
||||||
|
void set_prepartitioning_metadata(
|
||||||
|
const HloModuleMetadata& prepartitioning_metadata);
|
||||||
|
|
||||||
|
// Setters for HloModuleMetadataProto.
|
||||||
|
void set_module_group_name(const std::string& name) {
|
||||||
|
module_metadata_.set_module_group_name(name);
|
||||||
|
}
|
||||||
|
void set_canonical_module_id(int64 id) {
|
||||||
|
module_metadata_.set_canonical_module_id(id);
|
||||||
|
}
|
||||||
|
void add_partitioned_module_id(int64 id) {
|
||||||
|
module_metadata_.add_partitioned_module_ids(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setters for the current HloPassMetadata.
|
||||||
|
Status set_current_pass_name(const std::string& pass_name) {
|
||||||
|
return MutateCurrentHloPassMetadata(
|
||||||
|
[&pass_name](HloPassMetadata* pass_metadata) {
|
||||||
|
pass_metadata->set_pass_name(pass_name);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Status set_current_pass_pipeline_name(const std::string& pipeline_name) {
|
||||||
|
return MutateCurrentHloPassMetadata(
|
||||||
|
[&pipeline_name](HloPassMetadata* pass_metadata) {
|
||||||
|
pass_metadata->set_pipeline_name(pipeline_name);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Status add_current_pass_dump_filename(const std::string& dump_filename) {
|
||||||
|
return MutateCurrentHloPassMetadata(
|
||||||
|
[&dump_filename](HloPassMetadata* pass_metadata) {
|
||||||
|
pass_metadata->add_dump_filenames(dump_filename);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Status set_current_pass_module_changed(bool module_changed) {
|
||||||
|
return MutateCurrentHloPassMetadata(
|
||||||
|
[&module_changed](HloPassMetadata* pass_metadata) {
|
||||||
|
pass_metadata->set_module_changed(module_changed);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Status set_current_pass_module_id(int64 module_id) {
|
||||||
|
return MutateCurrentHloPassMetadata(
|
||||||
|
[&module_id](HloPassMetadata* pass_metadata) {
|
||||||
|
pass_metadata->set_module_id(module_id);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Status add_current_pass_module_group_module_id(int64 module_id) {
|
||||||
|
return MutateCurrentHloPassMetadata(
|
||||||
|
[&module_id](HloPassMetadata* pass_metadata) {
|
||||||
|
pass_metadata->add_module_group_module_ids(module_id);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Gets mutable metadata for the currently running pass. If passes are nested,
|
||||||
|
// finds the deepest one still running. Returns NotFound if metadata for the
|
||||||
|
// currently running pass cannot be found.
|
||||||
|
StatusOr<HloPassMetadata*> GetCurrentHloPassMetadata();
|
||||||
|
|
||||||
|
Status MutateCurrentHloPassMetadata(
|
||||||
|
const std::function<void(HloPassMetadata*)>& mutator);
|
||||||
|
|
||||||
|
HloModuleMetadataProto module_metadata_;
|
||||||
|
tensorflow::Env* env_;
|
||||||
|
int64 next_pass_id_ = 1;
|
||||||
|
|
||||||
|
// Stack of metadata for passes that are currently running. Size > 1 iff
|
||||||
|
// passes are nested.
|
||||||
|
std::vector<HloPassMetadata*> running_passes_;
|
||||||
|
|
||||||
|
// Metadata from before the module was partitioned, if applicable.
|
||||||
|
absl::optional<HloModuleMetadataProto> prepartitioning_metadata_ =
|
||||||
|
absl::nullopt;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace xla
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_METADATA_H_
|
||||||
145
tensorflow/compiler/xla/service/hlo_module_metadata_test.cc
Normal file
145
tensorflow/compiler/xla/service/hlo_module_metadata_test.cc
Normal file
@ -0,0 +1,145 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_module_metadata.h"
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/test.h"
|
||||||
|
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||||
|
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::testing::ElementsAre;
|
||||||
|
using ::testing::Property;
|
||||||
|
using ::testing::StrEq;
|
||||||
|
|
||||||
|
class TestEnv : public tensorflow::EnvWrapper {
|
||||||
|
public:
|
||||||
|
TestEnv() : EnvWrapper(Env::Default()) {}
|
||||||
|
|
||||||
|
uint64 NowMicros() const override { return current_micros_; }
|
||||||
|
|
||||||
|
void SetCurrentMicros(uint64 micros) { current_micros_ = micros; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
uint64 current_micros_ = 1;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Starting a pass creates one entry stamped with the clock's current value.
TEST(HloModuleMetadata, RecordsPassStart) {
  TestEnv fake_env;
  HloModuleMetadata metadata(&fake_env);

  fake_env.SetCurrentMicros(1234);
  metadata.RecordPassStart();

  const auto& recorded_passes = metadata.proto().pass_metadata();
  EXPECT_THAT(
      recorded_passes,
      ElementsAre(Property(&HloPassMetadata::start_timestamp_usec, 1234)));
}
|
||||||
|
|
||||||
|
// Ending a pass stamps its entry with the clock value at the time of the call.
TEST(HloModuleMetadata, RecordsPassEnd) {
  TestEnv fake_env;
  HloModuleMetadata metadata(&fake_env);

  metadata.RecordPassStart();
  fake_env.SetCurrentMicros(4321);
  EXPECT_IS_OK(metadata.RecordPassEnd());

  const auto& recorded_passes = metadata.proto().pass_metadata();
  EXPECT_THAT(
      recorded_passes,
      ElementsAre(Property(&HloPassMetadata::end_timestamp_usec, 4321)));
}
|
||||||
|
|
||||||
|
// With two nested passes, RecordPassEnd closes the inner pass first and the
// outer pass second.
TEST(HloModuleMetadata, RecordsPassEndInNestedMetadata) {
  TestEnv fake_env;
  HloModuleMetadata metadata(&fake_env);

  // Outer pass, then inner pass.
  metadata.RecordPassStart();
  metadata.RecordPassStart();

  // First end: only the inner (second) entry gets an end timestamp.
  fake_env.SetCurrentMicros(111);
  EXPECT_IS_OK(metadata.RecordPassEnd());
  EXPECT_THAT(metadata.proto().pass_metadata(),
              ElementsAre(Property(&HloPassMetadata::end_timestamp_usec, 0),
                          Property(&HloPassMetadata::end_timestamp_usec, 111)));

  // Second end: the outer (first) entry is closed; the inner one is untouched.
  fake_env.SetCurrentMicros(222);
  EXPECT_IS_OK(metadata.RecordPassEnd());
  EXPECT_THAT(metadata.proto().pass_metadata(),
              ElementsAre(Property(&HloPassMetadata::end_timestamp_usec, 222),
                          Property(&HloPassMetadata::end_timestamp_usec, 111)));
}
|
||||||
|
|
||||||
|
// RecordPassEnd reports NOT_FOUND whenever no pass is running: both before any
// pass has started and after the only running pass has already ended.
TEST(HloModuleMetadata, RecordPassEndReturnsNotFound) {
  HloModuleMetadata metadata(tensorflow::Env::Default());

  // Nothing started yet.
  EXPECT_EQ(metadata.RecordPassEnd().code(), tensorflow::error::NOT_FOUND);

  // One start matched by one end succeeds ...
  metadata.RecordPassStart();
  EXPECT_IS_OK(metadata.RecordPassEnd());

  // ... and a second, unmatched end fails again.
  EXPECT_EQ(metadata.RecordPassEnd().code(), tensorflow::error::NOT_FOUND);
}
|
||||||
|
|
||||||
|
// A current-pass setter writes into the entry of the running pass.
TEST(HloModuleMetadata, SetsHloPassMetadataFields) {
  HloModuleMetadata metadata(tensorflow::Env::Default());

  metadata.RecordPassStart();
  EXPECT_IS_OK(metadata.set_current_pass_name("fake name"));

  const auto& recorded_passes = metadata.proto().pass_metadata();
  EXPECT_THAT(
      recorded_passes,
      ElementsAre(Property(&HloPassMetadata::pass_name, StrEq("fake name"))));
}
|
||||||
|
|
||||||
|
// With nested passes, a current-pass setter only affects the innermost entry;
// the outer entry keeps its default (empty) name.
TEST(HloModuleMetadata, SetsHloPassMetadataFieldsInNestedMetadata) {
  HloModuleMetadata metadata(tensorflow::Env::Default());

  // Outer pass, then inner pass.
  metadata.RecordPassStart();
  metadata.RecordPassStart();
  EXPECT_IS_OK(metadata.set_current_pass_name("fake name"));

  const auto& recorded_passes = metadata.proto().pass_metadata();
  EXPECT_THAT(
      recorded_passes,
      ElementsAre(Property(&HloPassMetadata::pass_name, StrEq("")),
                  Property(&HloPassMetadata::pass_name, StrEq("fake name"))));
}
|
||||||
|
|
||||||
|
// A current-pass setter reports NOT_FOUND when no pass has been started.
TEST(HloModuleMetadata, SetterReturnsNotFound) {
  HloModuleMetadata metadata(tensorflow::Env::Default());
  const Status set_name_status = metadata.set_current_pass_name("fake name");
  EXPECT_EQ(set_name_status.code(), tensorflow::error::NOT_FOUND);
}
|
||||||
|
|
||||||
|
// set_prepartitioning_metadata splits the source's pass entries: passes that
// are still running move into the new module, finished passes stay in the
// stored prepartitioning metadata.
TEST(HloModuleMetadata, CopiesRunningPrepartitioningPasses) {
  HloModuleMetadata source_metadata(tensorflow::Env::Default());

  // Running: outer pass.
  source_metadata.RecordPassStart();
  EXPECT_IS_OK(source_metadata.set_current_pass_name("outer pass"));

  // Finished: nested pass that starts and ends.
  source_metadata.RecordPassStart();
  EXPECT_IS_OK(source_metadata.set_current_pass_name("finished pass"));
  EXPECT_IS_OK(source_metadata.RecordPassEnd());

  // Running: second nested pass, left open.
  source_metadata.RecordPassStart();
  EXPECT_IS_OK(source_metadata.set_current_pass_name("inner pass"));

  HloModuleMetadata partitioned_metadata(tensorflow::Env::Default());
  partitioned_metadata.set_prepartitioning_metadata(source_metadata);

  // Passes that are still running go in the new module.
  EXPECT_THAT(
      partitioned_metadata.proto().pass_metadata(),
      ElementsAre(Property(&HloPassMetadata::pass_name, StrEq("outer pass")),
                  Property(&HloPassMetadata::pass_name, StrEq("inner pass"))));

  // Passes that finished go in the prepartitioning metadata.
  EXPECT_THAT(
      partitioned_metadata.prepartitioning_metadata()->pass_metadata(),
      ElementsAre(
          Property(&HloPassMetadata::pass_name, StrEq("finished pass"))));
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace xla
|
||||||
@ -31,6 +31,72 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Opens a new pass entry in `module`'s metadata and labels it with the pass
// name and the name of the pipeline running it.
void RecordPassStartMetadata(HloModule& module, const std::string& pass_name,
                             const std::string& pipeline_name) {
  auto& metadata = *module.metadata();
  metadata.RecordPassStart();
  // The entry was created by the call just above, so the setters are expected
  // to succeed; any other outcome aborts via TF_CHECK_OK.
  TF_CHECK_OK(metadata.set_current_pass_name(pass_name));
  TF_CHECK_OK(metadata.set_current_pass_pipeline_name(pipeline_name));
}
|
||||||
|
|
||||||
|
// Module-group overload: opens a pass entry, with the same pass and pipeline
// names, in every module of `module_group`.
void RecordPassStartMetadata(HloModuleGroup& module_group,
                             const std::string& pass_name,
                             const std::string& pipeline_name) {
  for (HloModule* member : module_group.modules()) {
    RecordPassStartMetadata(*member, pass_name, pipeline_name);
  }
}
|
||||||
|
|
||||||
|
// Fills in the end-of-pass fields of `module`'s current pass entry and closes
// the entry. Stops at, and returns, the first non-OK Status from the metadata
// calls. `pass_name` is not used by this overload.
Status AttemptRecordPassEndMetadata(HloModule& module,
                                    const std::string& pass_name,
                                    bool module_changed) {
  auto& metadata = *module.metadata();
  // The module id is recorded at the end of the pass rather than at the start
  // because it may change while the pass runs, and the final id is the one we
  // want.
  TF_RETURN_IF_ERROR(metadata.set_current_pass_module_id(module.unique_id()));
  TF_RETURN_IF_ERROR(metadata.set_current_pass_module_changed(module_changed));
  return metadata.RecordPassEnd();
}
|
||||||
|
|
||||||
|
// Closes `module`'s current pass entry; any failure to do so is fatal and is
// logged with the underlying error message.
void RecordPassEndMetadata(HloModule& module, const std::string& pass_name,
                           bool module_changed) {
  const Status end_status =
      AttemptRecordPassEndMetadata(module, pass_name, module_changed);
  if (end_status.ok()) {
    return;
  }
  LOG(FATAL) << end_status.error_message();
}
|
||||||
|
|
||||||
|
// Module-group overload. For each module in the group, in order: records the
// unique ids of all modules in the group on that module's current pass entry,
// then closes the entry through the single-module overload. Returns the first
// non-OK Status encountered; modules after the failing one are left untouched.
Status AttemptRecordPassEndMetadata(HloModuleGroup& module_group,
                                    const std::string& pass_name,
                                    bool module_changed) {
  const auto& members = module_group.modules();
  for (HloModule* member : members) {
    auto& member_metadata = *member->metadata();
    for (HloModule* peer : members) {
      TF_RETURN_IF_ERROR(
          member_metadata.add_current_pass_module_group_module_id(
              peer->unique_id()));
    }
    TF_RETURN_IF_ERROR(
        AttemptRecordPassEndMetadata(*member, pass_name, module_changed));
  }
  return Status::OK();
}
|
||||||
|
|
||||||
|
// Module-group overload: closes the current pass entry of every module in
// `module_group`; any failure is fatal and is logged with the underlying error
// message.
void RecordPassEndMetadata(HloModuleGroup& module_group,
                           const std::string& pass_name, bool module_changed) {
  const Status end_status =
      AttemptRecordPassEndMetadata(module_group, pass_name, module_changed);
  if (end_status.ok()) {
    return;
  }
  LOG(FATAL) << end_status.error_message();
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
template <typename HloT>
|
template <typename HloT>
|
||||||
Status HloPassPipeline::RunInvariantCheckers(
|
Status HloPassPipeline::RunInvariantCheckers(
|
||||||
HloT* hlo, absl::string_view after_pass_name) {
|
HloT* hlo, absl::string_view after_pass_name) {
|
||||||
@ -54,34 +120,48 @@ Status HloPassPipeline::RunInvariantCheckers(
|
|||||||
template <typename HloT>
|
template <typename HloT>
|
||||||
StatusOr<bool> HloPassPipeline::RunPassesInternal(
|
StatusOr<bool> HloPassPipeline::RunPassesInternal(
|
||||||
HloT* hlo, absl::Span<HloPassInterface* const> passes) {
|
HloT* hlo, absl::Span<HloPassInterface* const> passes) {
|
||||||
string last_pass_name = "pipeline-start";
|
static constexpr absl::string_view kPipelineStart = "pipeline-start";
|
||||||
TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, last_pass_name));
|
static constexpr absl::string_view kPipelineEnd = "pipeline-end";
|
||||||
|
std::string pipeline_name = std::string(name());
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, kPipelineStart));
|
||||||
|
|
||||||
|
RecordPassStartMetadata(*hlo, std::string(kPipelineStart), pipeline_name);
|
||||||
|
MaybeDumpHloAndSaveFilenames(*hlo,
|
||||||
|
/*after_pass_name=*/kPipelineStart,
|
||||||
|
/*before_pass_name=*/passes.empty()
|
||||||
|
? kPipelineEnd
|
||||||
|
: passes.front()->name());
|
||||||
|
RecordPassEndMetadata(*hlo, std::string(kPipelineStart),
|
||||||
|
/*module_changed=*/false);
|
||||||
|
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
for (HloPassInterface* pass : passes) {
|
for (int i = 0; i < passes.size(); i++) {
|
||||||
|
HloPassInterface* pass = passes[i];
|
||||||
XLA_SCOPED_LOGGING_TIMER(absl::StrCat("HLO pass: ", pass->name()));
|
XLA_SCOPED_LOGGING_TIMER(absl::StrCat("HLO pass: ", pass->name()));
|
||||||
absl::string_view pass_name = pass->name();
|
std::string pass_name = std::string(pass->name());
|
||||||
VLOG(1) << " HLO pass " << pass_name;
|
VLOG(1) << " HLO pass " << pass_name;
|
||||||
VLOG(2) << " Module hash " << hlo->Hash();
|
VLOG(2) << " Module hash " << hlo->Hash();
|
||||||
MaybeDumpHlo(*hlo,
|
|
||||||
/*after_pass_name=*/last_pass_name,
|
|
||||||
/*before_pass_name=*/pass_name);
|
|
||||||
if (!pass->IsPassPipeline()) {
|
if (!pass->IsPassPipeline()) {
|
||||||
compilation_stats_->StartPass(pass_name);
|
compilation_stats_->StartPass(pass_name);
|
||||||
}
|
}
|
||||||
|
RecordPassStartMetadata(*hlo, pass_name, pipeline_name);
|
||||||
TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo));
|
TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo));
|
||||||
|
MaybeDumpHloAndSaveFilenames(*hlo,
|
||||||
|
/*after_pass_name=*/pass_name,
|
||||||
|
/*before_pass_name=*/i + 1 >= passes.size()
|
||||||
|
? kPipelineEnd
|
||||||
|
: passes[i + 1]->name());
|
||||||
|
RecordPassEndMetadata(*hlo, pass_name, pass_changed);
|
||||||
changed |= pass_changed;
|
changed |= pass_changed;
|
||||||
if (pass_changed) {
|
if (pass_changed) {
|
||||||
VLOG(3) << " Pass caused changes" << pass->name();
|
VLOG(3) << " Pass caused changes" << pass->name();
|
||||||
}
|
}
|
||||||
TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass_name));
|
TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass_name));
|
||||||
last_pass_name = string(pass_name);
|
|
||||||
if (!pass->IsPassPipeline()) {
|
if (!pass->IsPassPipeline()) {
|
||||||
compilation_stats_->EndPass(pass_name);
|
compilation_stats_->EndPass(pass_name);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
MaybeDumpHlo(*hlo,
|
|
||||||
/*after_pass_name=*/last_pass_name,
|
|
||||||
/*before_pass_name=*/"pipeline-end");
|
|
||||||
return changed;
|
return changed;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -129,18 +209,23 @@ std::vector<HloPassInterface*> HloPassPipeline::GetEnabledPasses(
|
|||||||
return enabled_passes;
|
return enabled_passes;
|
||||||
}
|
}
|
||||||
|
|
||||||
void HloPassPipeline::MaybeDumpHlo(const HloModule& module,
|
void HloPassPipeline::MaybeDumpHloAndSaveFilenames(
|
||||||
absl::string_view after_pass_name,
|
HloModule& module, absl::string_view after_pass_name,
|
||||||
absl::string_view before_pass_name) {
|
absl::string_view before_pass_name) {
|
||||||
DumpHloModuleBetweenPassesIfEnabled(name(), before_pass_name, after_pass_name,
|
for (const std::string& filename : DumpHloModuleBetweenPassesIfEnabled(
|
||||||
module);
|
name(), before_pass_name, after_pass_name, module)) {
|
||||||
|
Status status = module.metadata()->add_current_pass_dump_filename(filename);
|
||||||
|
if (!status.ok()) {
|
||||||
|
LOG(FATAL) << status.error_message();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void HloPassPipeline::MaybeDumpHlo(const HloModuleGroup& module_group,
|
void HloPassPipeline::MaybeDumpHloAndSaveFilenames(
|
||||||
absl::string_view after_pass_name,
|
HloModuleGroup& module_group, absl::string_view after_pass_name,
|
||||||
absl::string_view before_pass_name) {
|
absl::string_view before_pass_name) {
|
||||||
for (const HloModule* module : module_group.modules()) {
|
for (HloModule* module : module_group.modules()) {
|
||||||
MaybeDumpHlo(*module, after_pass_name, before_pass_name);
|
MaybeDumpHloAndSaveFilenames(*module, after_pass_name, before_pass_name);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -90,11 +90,13 @@ class HloPassPipeline : public HloPassInterface {
|
|||||||
const DebugOptions& debug_options);
|
const DebugOptions& debug_options);
|
||||||
|
|
||||||
// Maybe dumps the given module or module group depending on flag values
|
// Maybe dumps the given module or module group depending on flag values
|
||||||
// contained in DebugOptions of module config.
|
// contained in DebugOptions of module config. If it is dumped, saves the
|
||||||
void MaybeDumpHlo(const HloModuleGroup& module_group,
|
// filenames of the dumps into module metadata.
|
||||||
|
void MaybeDumpHloAndSaveFilenames(HloModuleGroup& module_group,
|
||||||
absl::string_view after_pass_name,
|
absl::string_view after_pass_name,
|
||||||
absl::string_view before_pass_name);
|
absl::string_view before_pass_name);
|
||||||
void MaybeDumpHlo(const HloModule& module, absl::string_view after_pass_name,
|
void MaybeDumpHloAndSaveFilenames(HloModule& module,
|
||||||
|
absl::string_view after_pass_name,
|
||||||
absl::string_view before_pass_name);
|
absl::string_view before_pass_name);
|
||||||
|
|
||||||
// Runs the invariant checker on the given HLO. HloT can be either HloModule
|
// Runs the invariant checker on the given HLO. HloT can be either HloModule
|
||||||
|
|||||||
@ -26,6 +26,10 @@ limitations under the License.
|
|||||||
namespace xla {
|
namespace xla {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
using ::testing::ElementsAre;
|
||||||
|
using ::testing::SizeIs;
|
||||||
|
using ::testing::StrEq;
|
||||||
|
|
||||||
class HloPassPipelineTest : public HloTestBase {
|
class HloPassPipelineTest : public HloTestBase {
|
||||||
protected:
|
protected:
|
||||||
StatusOr<HloModuleGroup> ParseModuleGroup(
|
StatusOr<HloModuleGroup> ParseModuleGroup(
|
||||||
@ -255,5 +259,43 @@ ENTRY main {
|
|||||||
::testing::HasSubstr("Module group pass cannot be run on a module"));
|
::testing::HasSubstr("Module group pass cannot be run on a module"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that metadata is set when a module group goes through a pass pipeline.
|
||||||
|
TEST_F(HloPassPipelineTest, SetHloModuleMetadata) {
|
||||||
|
HloModuleGroup module_group(TestName());
|
||||||
|
module_group.push_back(CreateNewVerifiedModule());
|
||||||
|
module_group.push_back(CreateNewVerifiedModule());
|
||||||
|
|
||||||
|
HloPassPipeline pipeline(TestName());
|
||||||
|
pipeline.AddPass<BazToQuxModuleGroupPass>();
|
||||||
|
pipeline.AddPass<FooToBarModulePass>();
|
||||||
|
TF_ASSERT_OK(pipeline.RunOnModuleGroup(&module_group).status());
|
||||||
|
ASSERT_THAT(module_group.modules(), SizeIs(2));
|
||||||
|
|
||||||
|
std::vector<std::string> pass_names = {"pipeline-start", "baz2qux",
|
||||||
|
"foo2bar"};
|
||||||
|
std::string pipeline_name = std::string(pipeline.name());
|
||||||
|
for (const HloModule* module : module_group.modules()) {
|
||||||
|
const HloModuleMetadataProto& metadata = module->metadata().proto();
|
||||||
|
EXPECT_EQ(metadata.canonical_module_id(), module->unique_id());
|
||||||
|
EXPECT_EQ(metadata.module_group_name(), module_group.name());
|
||||||
|
|
||||||
|
ASSERT_THAT(metadata.pass_metadata(), SizeIs(3));
|
||||||
|
for (int pass = 0; pass < metadata.pass_metadata().size(); pass++) {
|
||||||
|
const HloPassMetadata& pass_metadata = metadata.pass_metadata(pass);
|
||||||
|
EXPECT_NE(pass_metadata.pass_id(), 0);
|
||||||
|
EXPECT_THAT(pass_metadata.pass_name(), StrEq(pass_names[pass]));
|
||||||
|
EXPECT_THAT(pass_metadata.pipeline_name(), StrEq(pipeline_name));
|
||||||
|
EXPECT_FALSE(pass_metadata.module_changed());
|
||||||
|
EXPECT_EQ(pass_metadata.module_id(), module->unique_id());
|
||||||
|
EXPECT_THAT(pass_metadata.module_group_module_ids(),
|
||||||
|
ElementsAre(module_group.module(0).unique_id(),
|
||||||
|
module_group.module(1).unique_id()));
|
||||||
|
EXPECT_GT(pass_metadata.start_timestamp_usec(), 0);
|
||||||
|
EXPECT_LT(pass_metadata.start_timestamp_usec(),
|
||||||
|
pass_metadata.end_timestamp_usec());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|||||||
@ -261,6 +261,9 @@ message DebugOptions {
|
|||||||
// Max number of hlo module dumps in a directory. Set to < 0 for unbounded.
|
// Max number of hlo module dumps in a directory. Set to < 0 for unbounded.
|
||||||
int32 xla_dump_max_hlo_modules = 132;
|
int32 xla_dump_max_hlo_modules = 132;
|
||||||
|
|
||||||
|
// Dump HloModuleMetadata as a text proto for each HLO module.
|
||||||
|
bool xla_dump_module_metadata = 144;
|
||||||
|
|
||||||
//
|
//
|
||||||
// END flags controlling dumping HLO modules.
|
// END flags controlling dumping HLO modules.
|
||||||
//
|
//
|
||||||
@ -297,7 +300,7 @@ message DebugOptions {
|
|||||||
// Enable detailed logging into vlog.
|
// Enable detailed logging into vlog.
|
||||||
bool xla_detailed_logging = 143;
|
bool xla_detailed_logging = 143;
|
||||||
|
|
||||||
// Next id: 144
|
// Next id: 145
|
||||||
|
|
||||||
// Extra options to pass to the compilation backend (e.g. LLVM); specific
|
// Extra options to pass to the compilation backend (e.g. LLVM); specific
|
||||||
// interpretation of these values is left to the backend.
|
// interpretation of these values is left to the backend.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user