diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index 757fe9dbe7e..8d78c88a0f6 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -42,6 +42,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_dump_hlo_as_html(false); opts.set_xla_dump_include_timestamp(true); opts.set_xla_dump_max_hlo_modules(-1); + opts.set_xla_dump_module_metadata(false); #ifdef INTEL_MKL opts.set_xla_cpu_use_mkl_dnn(true); #endif // INTEL_MKL @@ -230,7 +231,6 @@ static void AllocateFlags() { }; flag_objects = new std::vector(); - flag_objects->reserve(55); // Don't use an initializer list for initializing the vector; this would // create a temporary copy, and exceeds the stack space when compiling with // certain configurations. @@ -513,6 +513,12 @@ static void AllocateFlags() { flag_values->xla_dump_max_hlo_modules(), "Max number of hlo module dumps in a directory. Set to < 0 for " "unbounded.")); + flag_objects->push_back(tensorflow::Flag( + "xla_dump_module_metadata", + bool_setter_for(&DebugOptions::set_xla_dump_module_metadata), + flag_values->xla_dump_module_metadata(), + "Dumps HloModuleMetadata as text protos to the directory specified " + "by --xla_dump_to.")); flag_objects->push_back(tensorflow::Flag( "xla_hlo_graph_addresses", bool_setter_for(&DebugOptions::set_xla_hlo_graph_addresses), diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index edbc0078869..df2dc34be10 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -412,6 +412,7 @@ cc_library( "hlo_instruction.cc", "hlo_instructions.cc", "hlo_module.cc", + "hlo_module_metadata.cc", "hlo_opcode.cc", "hlo_schedule.cc", "hlo_sharding.cc", @@ -428,6 +429,7 @@ cc_library( "hlo_instruction.h", "hlo_instructions.h", "hlo_module.h", + "hlo_module_metadata.h", "hlo_opcode.h", "hlo_schedule.h", "hlo_sharding.h", @@ -3047,6 +3049,18 @@ tf_cc_test( ], ) +tf_cc_test( + name = "hlo_module_metadata_test", + srcs = ["hlo_module_metadata_test.cc"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/stream_executor/lib", + ], +) + cc_library( name = "buffer_value", srcs = ["buffer_value.cc"], diff --git a/tensorflow/compiler/xla/service/dump.cc b/tensorflow/compiler/xla/service/dump.cc index 0afcc4cd961..9c2aa7fe4d0 100644 --- a/tensorflow/compiler/xla/service/dump.cc +++ b/tensorflow/compiler/xla/service/dump.cc @@ -45,7 +45,8 @@ struct CanonicalDebugOptions { dump_as_url(opts.xla_dump_hlo_as_url()), dump_snapshots(opts.xla_dump_hlo_snapshots()), dump_include_timestamp(opts.xla_dump_include_timestamp()), - dump_max_hlo_modules(opts.xla_dump_max_hlo_modules()) { + dump_max_hlo_modules(opts.xla_dump_max_hlo_modules()), + dump_module_metadata(opts.xla_dump_module_metadata()) { // This constructor examines the values in `opts` and turns on other flags // based on what we think is the user's intent. 
To reduce confusion about // what was a user-specified value versus an extrapolated value, within this @@ -137,18 +138,20 @@ struct CanonicalDebugOptions { bool dump_snapshots; bool dump_include_timestamp; int64 dump_max_hlo_modules; + bool dump_module_metadata; }; -void DumpToFileInDirImpl(string_view filename, string_view contents, - const CanonicalDebugOptions& opts) { +absl::optional DumpToFileInDirImpl( + string_view filename, string_view contents, + const CanonicalDebugOptions& opts) { if (opts.dumping_to_stdout()) { LOG(ERROR) << "Refusing to write " << filename << " to stdout. Pass --xla_dump_to= to write to a file."; - return; + return absl::nullopt; } if (opts.dump_to.empty()) { - return; + return absl::nullopt; } const string& dir = opts.dump_to; @@ -164,7 +167,7 @@ void DumpToFileInDirImpl(string_view filename, string_view contents, if (!status.ok() && !env->IsDirectory(dir).ok()) { LOG(ERROR) << "Could not create directory " << dir << " for dumping XLA debug data: " << status; - return; + return absl::nullopt; } } @@ -181,7 +184,7 @@ void DumpToFileInDirImpl(string_view filename, string_view contents, LOG(ERROR) << "Have already dumped " << matches.size() << " modules, more than the limit of " << opts.dump_max_hlo_modules; - return; + return absl::nullopt; } } @@ -192,33 +195,42 @@ void DumpToFileInDirImpl(string_view filename, string_view contents, LOG(ERROR) << "Could not write XLA debug data to " << file_path << ": " << status; } + + return file_path; } -void DumpToFileInDirOrStdoutImpl(string_view filename, string_view contents, - const CanonicalDebugOptions& opts) { +absl::optional DumpToFileInDirOrStdoutImpl( + string_view filename, string_view contents, + const CanonicalDebugOptions& opts) { // Dump to stdout if that's called for. if (opts.dumping_to_stdout()) { std::cout << "*** Begin " << filename << " ***\n" << contents << "\n*** End " << filename << " ***" << std::endl; - return; + return absl::nullopt; } // Otherwise, dump to a file. - DumpToFileInDirImpl(filename, contents, opts); + return DumpToFileInDirImpl(filename, contents, opts); } -void DumpHloModuleImpl(const HloModule& module, - const BufferAssignment* buffer_assn, - const HloExecutionProfile* profile, string_view prefix, - string_view suffix, const CanonicalDebugOptions& opts) { +// Returns full file paths of all dumps of the module. 
+std::vector DumpHloModuleImpl(const HloModule& module, + const BufferAssignment* buffer_assn, + const HloExecutionProfile* profile, + string_view prefix, + string_view suffix, + const CanonicalDebugOptions& opts) { string filename = FilenameFor(module, prefix, suffix); + std::vector> file_paths; + if (opts.dump_as_text) { - DumpToFileInDirOrStdoutImpl(StrCat(filename, ".txt"), module.ToString(), - opts); + file_paths.push_back(DumpToFileInDirOrStdoutImpl(StrCat(filename, ".txt"), + module.ToString(), opts)); if (buffer_assn) { - DumpToFileInDirOrStdoutImpl(StrCat(filename, "-buffer-assignment.txt"), - buffer_assn->ToString(), opts); + file_paths.push_back(DumpToFileInDirOrStdoutImpl( + StrCat(filename, "-buffer-assignment.txt"), buffer_assn->ToString(), + opts)); } } @@ -229,7 +241,8 @@ void DumpHloModuleImpl(const HloModule& module, if (!tensorflow::SerializeToStringDeterministic(module_proto, &pb)) { pb = "Failed to serialize HLO module proto."; } - DumpToFileInDirImpl(StrCat(filename, ".hlo.pb"), pb, opts); + file_paths.push_back( + DumpToFileInDirImpl(StrCat(filename, ".hlo.pb"), pb, opts)); } auto render_graph = [&](RenderedGraphFormat format) { @@ -244,13 +257,15 @@ void DumpHloModuleImpl(const HloModule& module, }; if (opts.dump_as_dot) { - DumpToFileInDirImpl(StrFormat("%s.dot", filename), - render_graph(RenderedGraphFormat::kDot), opts); + file_paths.push_back( + DumpToFileInDirImpl(StrFormat("%s.dot", filename), + render_graph(RenderedGraphFormat::kDot), opts)); } if (opts.dump_as_html) { - DumpToFileInDirImpl(StrFormat("%s.html", filename), - render_graph(RenderedGraphFormat::kHtml), opts); + file_paths.push_back( + DumpToFileInDirImpl(StrFormat("%s.html", filename), + render_graph(RenderedGraphFormat::kHtml), opts)); } // Special case for rendering graphs as URLs. We'll dump them to a file @@ -259,9 +274,35 @@ void DumpHloModuleImpl(const HloModule& module, string url = render_graph(RenderedGraphFormat::kUrl); std::cout << filename << " --> " << url << std::endl; if (!opts.dumping_to_stdout()) { - DumpToFileInDirImpl(StrFormat("%s.url", filename), url, opts); + file_paths.push_back( + DumpToFileInDirImpl(StrFormat("%s.url", filename), url, opts)); } } + + std::vector dumped_file_paths; + for (const absl::optional& path : file_paths) { + if (path.has_value()) { + dumped_file_paths.push_back(*path); + } + } + return dumped_file_paths; +} + +void DumpHloModuleMetadata(const HloModuleMetadataProto& metadata, + const CanonicalDebugOptions& opts, + absl::flat_hash_set* dumped_module_ids) { + // Return if metadata for this module has already been dumped. 
+ if (!dumped_module_ids->insert(metadata.canonical_module_id()).second) { + return; + } + std::string filename = absl::StrFormat("module_%04d.metadata.textproto", + metadata.canonical_module_id()); + std::string content; + if (tensorflow::protobuf::TextFormat::PrintToString(metadata, &content)) { + DumpToFileInDirImpl(filename, content, opts); + } else { + LOG(ERROR) << "Failed to convert HloModuleMetadataProto to text."; + } } static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); @@ -381,18 +422,17 @@ bool DumpingToStdout(const DebugOptions& opts) { return CanonicalDebugOptions(opts).dumping_to_stdout(); } -void DumpHloModuleBetweenPassesIfEnabled(string_view pipeline_name, - string_view before_pass_name, - string_view after_pass_name, - const HloModule& module) { +std::vector DumpHloModuleBetweenPassesIfEnabled( + string_view pipeline_name, string_view before_pass_name, + string_view after_pass_name, const HloModule& module) { CanonicalDebugOptions opts(module.config().debug_options()); if (!opts.should_dump_module(module.name())) { - return; + return {}; } if (!opts.should_dump_pass(before_pass_name) && !opts.should_dump_pass(after_pass_name)) { - return; + return {}; } int64 step_number = StepNumberForModule(module); @@ -401,8 +441,8 @@ void DumpHloModuleBetweenPassesIfEnabled(string_view pipeline_name, string filename_suffix = StrFormat("%04d.%s.after_%s.before_%s", step_number, pipeline_name, after_pass_name, before_pass_name); - DumpHloModuleImpl(module, /*buffer_assn=*/nullptr, /*profile=*/nullptr, - timestamp, filename_suffix, opts); + return DumpHloModuleImpl(module, /*buffer_assn=*/nullptr, /*profile=*/nullptr, + timestamp, filename_suffix, opts); } void DumpHloModuleDuringPassIfEnabled(string_view pass_name, @@ -488,4 +528,21 @@ void DumpHloSnapshotIfEnabled(const HloSnapshot& snapshot, DumpToFileInDirImpl(filename, pb, canonical_opts); } +void DumpHloModuleMetadataIfEnabled(const std::vector& modules) { + absl::flat_hash_set dumped_module_ids; + for (const HloModule* module : modules) { + CanonicalDebugOptions opts(module->config().debug_options()); + if (!opts.dump_module_metadata) { + continue; + } + DumpHloModuleMetadata(module->metadata().proto(), opts, &dumped_module_ids); + const absl::optional& prepartitioning_metadata = + module->metadata().prepartitioning_metadata(); + if (prepartitioning_metadata.has_value()) { + DumpHloModuleMetadata(*prepartitioning_metadata, opts, + &dumped_module_ids); + } + } +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/dump.h b/tensorflow/compiler/xla/service/dump.h index 045acca2632..c787882170d 100644 --- a/tensorflow/compiler/xla/service/dump.h +++ b/tensorflow/compiler/xla/service/dump.h @@ -75,11 +75,11 @@ void DumpHloModuleIfEnabled(const HloModule& module, absl::string_view name); // Dumps the given HLO module after running one HLO pass and before running -// another, if that's enabled. -void DumpHloModuleBetweenPassesIfEnabled(absl::string_view pipeline_name, - absl::string_view before_pass_name, - absl::string_view after_pass_name, - const HloModule& module); +// another, if that's enabled. Returns the full file paths of all dumps of the +// module, or an empty vector if nothing was dumped. +std::vector DumpHloModuleBetweenPassesIfEnabled( + absl::string_view pipeline_name, absl::string_view before_pass_name, + absl::string_view after_pass_name, const HloModule& module); // Dumps the given HLO module during the given HLO pass, if that's enabled. 
// @@ -100,6 +100,8 @@ void DumpHloSnapshotIfEnabled(const HloModule& module, void DumpHloSnapshotIfEnabled(const HloSnapshot& snapshot, const DebugOptions& opts); +void DumpHloModuleMetadataIfEnabled(const std::vector& modules); + // Returns true if we should dump data for an HloModule. This is useful if you // want to check if DumpToFileInDir{,OrStdout} will do anything before // generating an expensive string. diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 8c9d3a606c6..66194b1d98f 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -528,3 +528,66 @@ message HloSnapshot { // The name of the platform used to run the graph. string execution_platform = 4; } + +// Metadata for an HLO module. Dumped after HLO passes and before LLO lowering +// with filename module_####.metadata.textproto, where #### is +// canonical_module_id. +message HloModuleMetadataProto { + // Uniquely identifies an HloModuleMetadata. Equal to the first unique_id + // of the module (a module may go through multiple unique_ids). If a module + // is partitioned into multiple modules, those modules will each have a new + // HloModuleMetadata with a different canonical_module_id. + int64 canonical_module_id = 1; + + // Name of the module group that the module is part of. + string module_group_name = 2; + + // The canonical module id of the module that this one is partitioned from, + // if applicable. + int64 original_module_id = 3; + + // The canonical module ids of the modules that this one is partitioned into, + // if applicable. + repeated int64 partitioned_module_ids = 4; + + // Metadata for the HLO passes that are run on the module. + repeated HloPassMetadata pass_metadata = 5; +} + +// Metadata for one run of an HLO pass on a module. Provides more information +// when processing debug dumps of HloProtos about the order of HLO passes and +// various other stats like duration. `pass_id` may also be used to identify a +// particular run of a pass in debug info that propagates through stages of +// compilation. +message HloPassMetadata { + // For a given module, pass_id uniquely identifies a run of an HLO pass on + // that module. Note that a pass_id may not always refer to the same pass + // because the order of passes during compilation may change. For finding + // metadata for a particular pass, pass_name and pipeline_name would be more + // reliable, although note that they may not be unique. + int64 pass_id = 1; + string pass_name = 2; + string pipeline_name = 3; + + // Filenames of the dumps of the module after this pass ran. Module may be + // dumped in multiple formats, and the order of formats in this field will + // stay consistent across passes. + repeated string dump_filenames = 4; + + // Return value of pass.Run(). True if this pass changed the module, or, in + // the case where the module was run through this pass as part of a module + // group, true if this pass changed any module in the same module group. + bool module_changed = 5; + + // The unique_id of the module that this pass is run on. May be different from + // the canonical_module_id of the HloModuleMetadata that this HloPassMetadata + // is inside. + int64 module_id = 6; + + // If the module went through this pass as part of a module group, this is + // set as the ids of all the modules in the module group. Empty otherwise. 
+ repeated int64 module_group_module_ids = 7; + + int64 start_timestamp_usec = 8; + int64 end_timestamp_usec = 9; +} diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 4a67c1d2146..961e945fd15 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -44,7 +44,10 @@ namespace xla { HloModule::HloModule(const string& name, HloModuleConfig config) : name_(NameUniquer::GetSanitizedName(name)), config_(std::move(config)), - unique_id_(next_unique_module_id_++) {} + unique_id_(next_unique_module_id_++), + metadata_(tensorflow::Env::Default()) { + metadata_.set_canonical_module_id(unique_id_); +} Status HloModule::set_schedule(HloSchedule schedule) { TF_RET_CHECK(schedule.module() == this); diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index bfc4ddc7e22..1e10e24843d 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_module_metadata.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/types.h" @@ -363,6 +364,16 @@ class HloModule { return cross_program_prefetches_; } + const HloModuleMetadata& metadata() const { return metadata_; } + HloModuleMetadata* metadata() { return &metadata_; } + + // Moves (not copies) metadata from this HloModule to `module`. To be used + // in cases like HloModuleGroup::ReplaceModule when metadata should be + // transferred out of a module before it's destroyed. + void MoveMetadataToModule(HloModule* module) { + module->metadata_ = std::move(metadata_); + } + private: HloComputation* AddComputationInternal( std::unique_ptr computation, bool is_entry, @@ -413,6 +424,9 @@ class HloModule { // Arguments to be prefetched across programs. std::vector> cross_program_prefetches_; + + // Metadata for this module, such as its canonical id and the HLO passes run. 
+ HloModuleMetadata metadata_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group.cc b/tensorflow/compiler/xla/service/hlo_module_group.cc index 8c49fd2d4f4..eb70fcdf79f 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group.cc @@ -96,12 +96,14 @@ uint64 HloModuleGroup::Hash() const { } void HloModuleGroup::push_back(std::unique_ptr module) { + module->metadata()->set_module_group_name(name()); modules_.push_back(std::move(module)); module_ptrs_.push_back(modules_.back().get()); } void HloModuleGroup::ReplaceModule(int index, std::unique_ptr module) { + modules_.at(index)->MoveMetadataToModule(module.get()); modules_.at(index) = std::move(module); module_ptrs_.at(index) = modules_.at(index).get(); } diff --git a/tensorflow/compiler/xla/service/hlo_module_group_test.cc b/tensorflow/compiler/xla/service/hlo_module_group_test.cc index 1b26451e6e4..9958b4283c5 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_test.cc @@ -27,6 +27,8 @@ namespace xla { namespace { namespace op = ::xla::testing::opcode_matchers; +using ::testing::Property; +using ::testing::StrEq; class HloModuleGroupTest : public HloTestBase { protected: @@ -202,6 +204,31 @@ ENTRY entry { } } +// Test that metadata is transferred when a module is replaced. +TEST_F(HloModuleGroupTest, ReplaceModuleMetadata) { + auto old_module = CreateNewVerifiedModule(); + int old_module_id = old_module->unique_id(); + old_module->metadata()->RecordPassStart(); + TF_EXPECT_OK(old_module->metadata()->set_current_pass_name("fake pass")); + + HloModuleGroup group(std::move(old_module)); + EXPECT_EQ(group.module(0).metadata()->proto().module_group_name(), + group.name()); + + auto new_module = CreateNewVerifiedModule(); + group.ReplaceModule(0, std::move(new_module)); + + EXPECT_NE(group.module(0).unique_id(), old_module_id); + const HloModuleMetadataProto& module_metadata = + group.module(0).metadata()->proto(); + EXPECT_EQ(module_metadata.canonical_module_id(), old_module_id); + + const HloPassMetadata& pass_metadata = + *module_metadata.pass_metadata().rbegin(); + EXPECT_THAT(pass_metadata, + Property(&HloPassMetadata::pass_name, StrEq("fake pass"))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_metadata.cc new file mode 100644 index 00000000000..772326744a7 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_metadata.cc @@ -0,0 +1,87 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_metadata.h" + +#include + +#include "absl/container/flat_hash_set.h" +#include "tensorflow/core/platform/env.h" + +namespace xla { + +StatusOr HloModuleMetadata::GetCurrentHloPassMetadata() { + if (running_passes_.empty()) { + return NotFound( + "HloPassMetadata for currently running pass not found, either because " + "the pass did not call RecordPassStart or because a pass is " + "creating/switching modules without using " + "HloModuleGroup::ReplaceModule."); + } + return running_passes_.back(); +} + +Status HloModuleMetadata::MutateCurrentHloPassMetadata( + const std::function& mutator) { + TF_ASSIGN_OR_RETURN(HloPassMetadata * pass_metadata, + GetCurrentHloPassMetadata()); + mutator(pass_metadata); + return Status::OK(); +} + +void HloModuleMetadata::RecordPassStart() { + HloPassMetadata* pass_metadata = module_metadata_.add_pass_metadata(); + pass_metadata->set_pass_id(next_pass_id_++); + pass_metadata->set_start_timestamp_usec(env_->NowMicros()); + running_passes_.push_back(pass_metadata); +} + +Status HloModuleMetadata::RecordPassEnd() { + TF_ASSIGN_OR_RETURN(HloPassMetadata * pass_metadata, + GetCurrentHloPassMetadata()); + pass_metadata->set_end_timestamp_usec(env_->NowMicros()); + running_passes_.pop_back(); + return Status::OK(); +} + +void HloModuleMetadata::set_prepartitioning_metadata( + const HloModuleMetadata& prepartitioning_metadata) { + module_metadata_.set_original_module_id( + prepartitioning_metadata.proto().canonical_module_id()); + prepartitioning_metadata_ = prepartitioning_metadata.proto(); + prepartitioning_metadata_->clear_pass_metadata(); + + // Because HloPassMetadata represents the completion of a pass, metadata for + // all currently running passes need to be moved over to the new module. + absl::flat_hash_set running_passes( + prepartitioning_metadata.running_passes_.begin(), + prepartitioning_metadata.running_passes_.end()); + for (const HloPassMetadata& pass_metadata : + prepartitioning_metadata.proto().pass_metadata()) { + if (running_passes.contains(&pass_metadata)) { + HloPassMetadata* added_pass_metadata = + module_metadata_.add_pass_metadata(); + *added_pass_metadata = pass_metadata; + running_passes_.push_back(added_pass_metadata); + next_pass_id_ = + std::max(next_pass_id_, + static_cast(added_pass_metadata->pass_id()) + 1); + } else { + *prepartitioning_metadata_->add_pass_metadata() = pass_metadata; + } + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_metadata.h b/tensorflow/compiler/xla/service/hlo_module_metadata.h new file mode 100644 index 00000000000..434e3bb0a26 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_metadata.h @@ -0,0 +1,126 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_METADATA_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_METADATA_H_ + +#include + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/platform/env.h" + +namespace xla { + +// Wrapper class for HloModuleMetadataProto to avoid allowing callers to mutate +// arbitrary fields. Specifically, callers cannot set timestamps or ids or +// set the fields of any pass not currently running. +class HloModuleMetadata { + public: + explicit HloModuleMetadata(tensorflow::Env* env) : env_(env) {} + + const HloModuleMetadataProto& proto() const { return module_metadata_; } + + // Creates a new HloPassMetadata. All calls to RecordPassStart should be + // matched by a later call to RecordPassEnd. + void RecordPassStart(); + + // Marks the currently running pass as finished. Returns NotFound if metadata + // for the currently running pass cannot be found. + Status RecordPassEnd(); + + const absl::optional& prepartitioning_metadata() + const { + return prepartitioning_metadata_; + } + void set_prepartitioning_metadata( + const HloModuleMetadata& prepartitioning_metadata); + + // Setters for HloModuleMetadataProto. + void set_module_group_name(const std::string& name) { + module_metadata_.set_module_group_name(name); + } + void set_canonical_module_id(int64 id) { + module_metadata_.set_canonical_module_id(id); + } + void add_partitioned_module_id(int64 id) { + module_metadata_.add_partitioned_module_ids(id); + } + + // Setters for the current HloPassMetadata. + Status set_current_pass_name(const std::string& pass_name) { + return MutateCurrentHloPassMetadata( + [&pass_name](HloPassMetadata* pass_metadata) { + pass_metadata->set_pass_name(pass_name); + }); + } + Status set_current_pass_pipeline_name(const std::string& pipeline_name) { + return MutateCurrentHloPassMetadata( + [&pipeline_name](HloPassMetadata* pass_metadata) { + pass_metadata->set_pipeline_name(pipeline_name); + }); + } + Status add_current_pass_dump_filename(const std::string& dump_filename) { + return MutateCurrentHloPassMetadata( + [&dump_filename](HloPassMetadata* pass_metadata) { + pass_metadata->add_dump_filenames(dump_filename); + }); + } + Status set_current_pass_module_changed(bool module_changed) { + return MutateCurrentHloPassMetadata( + [&module_changed](HloPassMetadata* pass_metadata) { + pass_metadata->set_module_changed(module_changed); + }); + } + Status set_current_pass_module_id(int64 module_id) { + return MutateCurrentHloPassMetadata( + [&module_id](HloPassMetadata* pass_metadata) { + pass_metadata->set_module_id(module_id); + }); + } + Status add_current_pass_module_group_module_id(int64 module_id) { + return MutateCurrentHloPassMetadata( + [&module_id](HloPassMetadata* pass_metadata) { + pass_metadata->add_module_group_module_ids(module_id); + }); + } + + private: + // Gets mutable metadata for the currently running pass. If passes are nested, + // finds the deepest one still running. Returns NotFound if metadata for the + // currently running pass cannot be found. 
+  StatusOr<HloPassMetadata*> GetCurrentHloPassMetadata();
+
+  Status MutateCurrentHloPassMetadata(
+      const std::function<void(HloPassMetadata*)>& mutator);
+
+  HloModuleMetadataProto module_metadata_;
+  tensorflow::Env* env_;
+  int64 next_pass_id_ = 1;
+
+  // Stack of metadata for passes that are currently running. Size > 1 iff
+  // passes are nested.
+  std::vector<HloPassMetadata*> running_passes_;
+
+  // Metadata from before the module was partitioned, if applicable.
+  absl::optional<HloModuleMetadataProto> prepartitioning_metadata_ =
+      absl::nullopt;
+};
+
+}  // namespace xla
+
+#endif  // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_METADATA_H_
diff --git a/tensorflow/compiler/xla/service/hlo_module_metadata_test.cc b/tensorflow/compiler/xla/service/hlo_module_metadata_test.cc
new file mode 100644
index 00000000000..a426ad83f9f
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_module_metadata_test.cc
@@ -0,0 +1,145 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_module_metadata.h"
+
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+
+namespace xla {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::Property;
+using ::testing::StrEq;
+
+class TestEnv : public tensorflow::EnvWrapper {
+ public:
+  TestEnv() : EnvWrapper(Env::Default()) {}
+
+  uint64 NowMicros() const override { return current_micros_; }
+
+  void SetCurrentMicros(uint64 micros) { current_micros_ = micros; }
+
+ private:
+  uint64 current_micros_ = 1;
+};
+
+TEST(HloModuleMetadata, RecordsPassStart) {
+  TestEnv env;
+  HloModuleMetadata module_metadata(&env);
+  env.SetCurrentMicros(1234);
+  module_metadata.RecordPassStart();
+  EXPECT_THAT(
+      module_metadata.proto().pass_metadata(),
+      ElementsAre(Property(&HloPassMetadata::start_timestamp_usec, 1234)));
+}
+
+TEST(HloModuleMetadata, RecordsPassEnd) {
+  TestEnv env;
+  HloModuleMetadata module_metadata(&env);
+  module_metadata.RecordPassStart();
+  env.SetCurrentMicros(4321);
+  EXPECT_IS_OK(module_metadata.RecordPassEnd());
+  EXPECT_THAT(
+      module_metadata.proto().pass_metadata(),
+      ElementsAre(Property(&HloPassMetadata::end_timestamp_usec, 4321)));
+}
+
+TEST(HloModuleMetadata, RecordsPassEndInNestedMetadata) {
+  TestEnv env;
+  HloModuleMetadata module_metadata(&env);
+  module_metadata.RecordPassStart();
+  module_metadata.RecordPassStart();
+  env.SetCurrentMicros(111);
+  EXPECT_IS_OK(module_metadata.RecordPassEnd());
+  EXPECT_THAT(module_metadata.proto().pass_metadata(),
+              ElementsAre(Property(&HloPassMetadata::end_timestamp_usec, 0),
+                          Property(&HloPassMetadata::end_timestamp_usec, 111)));
+
+  env.SetCurrentMicros(222);
+  EXPECT_IS_OK(module_metadata.RecordPassEnd());
+  EXPECT_THAT(module_metadata.proto().pass_metadata(),
+              ElementsAre(Property(&HloPassMetadata::end_timestamp_usec, 222),
+                          Property(&HloPassMetadata::end_timestamp_usec, 111)));
+}
+
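[Editorial note: the tests above exercise the RecordPassStart/RecordPassEnd pairing directly. As a hedged illustration, not part of this patch, a driver outside HloPassPipeline could wrap a single pass the same way. The helper name below is an assumption, and the usual XLA status macros and headers are taken as available.]

// Illustrative only: the intended call pattern around one pass.
Status RunOnePassWithMetadata(HloPassInterface* pass, HloModule* module) {
  module->metadata()->RecordPassStart();
  TF_RETURN_IF_ERROR(
      module->metadata()->set_current_pass_name(std::string(pass->name())));
  TF_ASSIGN_OR_RETURN(bool changed, pass->Run(module));
  TF_RETURN_IF_ERROR(
      module->metadata()->set_current_pass_module_changed(changed));
  return module->metadata()->RecordPassEnd();
}

HloPassPipeline::RunPassesInternal later in this patch follows the same pattern, additionally recording the pipeline name and the filenames of any dumps.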
+TEST(HloModuleMetadata, RecordPassEndReturnsNotFound) { + HloModuleMetadata module_metadata(tensorflow::Env::Default()); + EXPECT_EQ(module_metadata.RecordPassEnd().code(), + tensorflow::error::NOT_FOUND); + + module_metadata.RecordPassStart(); + EXPECT_IS_OK(module_metadata.RecordPassEnd()); + EXPECT_EQ(module_metadata.RecordPassEnd().code(), + tensorflow::error::NOT_FOUND); +} + +TEST(HloModuleMetadata, SetsHloPassMetadataFields) { + HloModuleMetadata module_metadata(tensorflow::Env::Default()); + module_metadata.RecordPassStart(); + EXPECT_IS_OK(module_metadata.set_current_pass_name("fake name")); + EXPECT_THAT( + module_metadata.proto().pass_metadata(), + ElementsAre(Property(&HloPassMetadata::pass_name, StrEq("fake name")))); +} + +TEST(HloModuleMetadata, SetsHloPassMetadataFieldsInNestedMetadata) { + HloModuleMetadata module_metadata(tensorflow::Env::Default()); + module_metadata.RecordPassStart(); + module_metadata.RecordPassStart(); + EXPECT_IS_OK(module_metadata.set_current_pass_name("fake name")); + EXPECT_THAT( + module_metadata.proto().pass_metadata(), + ElementsAre(Property(&HloPassMetadata::pass_name, StrEq("")), + Property(&HloPassMetadata::pass_name, StrEq("fake name")))); +} + +TEST(HloModuleMetadata, SetterReturnsNotFound) { + HloModuleMetadata module_metadata(tensorflow::Env::Default()); + EXPECT_EQ(module_metadata.set_current_pass_name("fake name").code(), + tensorflow::error::NOT_FOUND); +} + +TEST(HloModuleMetadata, CopiesRunningPrepartitioningPasses) { + HloModuleMetadata old_module_metadata(tensorflow::Env::Default()); + old_module_metadata.RecordPassStart(); + EXPECT_IS_OK(old_module_metadata.set_current_pass_name("outer pass")); + + old_module_metadata.RecordPassStart(); + EXPECT_IS_OK(old_module_metadata.set_current_pass_name("finished pass")); + EXPECT_IS_OK(old_module_metadata.RecordPassEnd()); + + old_module_metadata.RecordPassStart(); + EXPECT_IS_OK(old_module_metadata.set_current_pass_name("inner pass")); + + HloModuleMetadata new_module_metadata(tensorflow::Env::Default()); + new_module_metadata.set_prepartitioning_metadata(old_module_metadata); + + // Passes that are still running go in the new module. + EXPECT_THAT( + new_module_metadata.proto().pass_metadata(), + ElementsAre(Property(&HloPassMetadata::pass_name, StrEq("outer pass")), + Property(&HloPassMetadata::pass_name, StrEq("inner pass")))); + + // Passes that finished go in the prepartitioning metadata. + EXPECT_THAT(new_module_metadata.prepartitioning_metadata()->pass_metadata(), + ElementsAre(Property(&HloPassMetadata::pass_name, + StrEq("finished pass")))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 74c385f16bd..41f907fc85e 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -31,6 +31,72 @@ limitations under the License. namespace xla { +namespace { + +void RecordPassStartMetadata(HloModule& module, const std::string& pass_name, + const std::string& pipeline_name) { + module.metadata()->RecordPassStart(); + // An HloPassMetadata was just created so Status should always be OK. 
+ TF_CHECK_OK(module.metadata()->set_current_pass_name(pass_name)); + TF_CHECK_OK(module.metadata()->set_current_pass_pipeline_name(pipeline_name)); +} + +void RecordPassStartMetadata(HloModuleGroup& module_group, + const std::string& pass_name, + const std::string& pipeline_name) { + for (HloModule* module : module_group.modules()) { + RecordPassStartMetadata(*module, pass_name, pipeline_name); + } +} + +Status AttemptRecordPassEndMetadata(HloModule& module, + const std::string& pass_name, + bool module_changed) { + // Module id is set here instead of RecordPassStartMetadata because it may + // change in the middle of the pass, and we want the final id. + TF_RETURN_IF_ERROR( + module.metadata()->set_current_pass_module_id(module.unique_id())); + TF_RETURN_IF_ERROR( + module.metadata()->set_current_pass_module_changed(module_changed)); + TF_RETURN_IF_ERROR(module.metadata()->RecordPassEnd()); + return Status::OK(); +} + +void RecordPassEndMetadata(HloModule& module, const std::string& pass_name, + bool module_changed) { + Status status = + AttemptRecordPassEndMetadata(module, pass_name, module_changed); + if (!status.ok()) { + LOG(FATAL) << status.error_message(); + } +} + +Status AttemptRecordPassEndMetadata(HloModuleGroup& module_group, + const std::string& pass_name, + bool module_changed) { + for (HloModule* module : module_group.modules()) { + for (HloModule* other_module : module_group.modules()) { + TF_RETURN_IF_ERROR( + module->metadata()->add_current_pass_module_group_module_id( + other_module->unique_id())); + } + TF_RETURN_IF_ERROR( + AttemptRecordPassEndMetadata(*module, pass_name, module_changed)); + } + return Status::OK(); +} + +void RecordPassEndMetadata(HloModuleGroup& module_group, + const std::string& pass_name, bool module_changed) { + Status status = + AttemptRecordPassEndMetadata(module_group, pass_name, module_changed); + if (!status.ok()) { + LOG(FATAL) << status.error_message(); + } +} + +} // namespace + template Status HloPassPipeline::RunInvariantCheckers( HloT* hlo, absl::string_view after_pass_name) { @@ -54,34 +120,48 @@ Status HloPassPipeline::RunInvariantCheckers( template StatusOr HloPassPipeline::RunPassesInternal( HloT* hlo, absl::Span passes) { - string last_pass_name = "pipeline-start"; - TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, last_pass_name)); + static constexpr absl::string_view kPipelineStart = "pipeline-start"; + static constexpr absl::string_view kPipelineEnd = "pipeline-end"; + std::string pipeline_name = std::string(name()); + + TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, kPipelineStart)); + + RecordPassStartMetadata(*hlo, std::string(kPipelineStart), pipeline_name); + MaybeDumpHloAndSaveFilenames(*hlo, + /*after_pass_name=*/kPipelineStart, + /*before_pass_name=*/passes.empty() + ? 
kPipelineEnd + : passes.front()->name()); + RecordPassEndMetadata(*hlo, std::string(kPipelineStart), + /*module_changed=*/false); + bool changed = false; - for (HloPassInterface* pass : passes) { + for (int i = 0; i < passes.size(); i++) { + HloPassInterface* pass = passes[i]; XLA_SCOPED_LOGGING_TIMER(absl::StrCat("HLO pass: ", pass->name())); - absl::string_view pass_name = pass->name(); + std::string pass_name = std::string(pass->name()); VLOG(1) << " HLO pass " << pass_name; VLOG(2) << " Module hash " << hlo->Hash(); - MaybeDumpHlo(*hlo, - /*after_pass_name=*/last_pass_name, - /*before_pass_name=*/pass_name); if (!pass->IsPassPipeline()) { compilation_stats_->StartPass(pass_name); } + RecordPassStartMetadata(*hlo, pass_name, pipeline_name); TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo)); + MaybeDumpHloAndSaveFilenames(*hlo, + /*after_pass_name=*/pass_name, + /*before_pass_name=*/i + 1 >= passes.size() + ? kPipelineEnd + : passes[i + 1]->name()); + RecordPassEndMetadata(*hlo, pass_name, pass_changed); changed |= pass_changed; if (pass_changed) { VLOG(3) << " Pass caused changes" << pass->name(); } TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass_name)); - last_pass_name = string(pass_name); if (!pass->IsPassPipeline()) { compilation_stats_->EndPass(pass_name); } } - MaybeDumpHlo(*hlo, - /*after_pass_name=*/last_pass_name, - /*before_pass_name=*/"pipeline-end"); return changed; } @@ -129,18 +209,23 @@ std::vector HloPassPipeline::GetEnabledPasses( return enabled_passes; } -void HloPassPipeline::MaybeDumpHlo(const HloModule& module, - absl::string_view after_pass_name, - absl::string_view before_pass_name) { - DumpHloModuleBetweenPassesIfEnabled(name(), before_pass_name, after_pass_name, - module); +void HloPassPipeline::MaybeDumpHloAndSaveFilenames( + HloModule& module, absl::string_view after_pass_name, + absl::string_view before_pass_name) { + for (const std::string& filename : DumpHloModuleBetweenPassesIfEnabled( + name(), before_pass_name, after_pass_name, module)) { + Status status = module.metadata()->add_current_pass_dump_filename(filename); + if (!status.ok()) { + LOG(FATAL) << status.error_message(); + } + } } -void HloPassPipeline::MaybeDumpHlo(const HloModuleGroup& module_group, - absl::string_view after_pass_name, - absl::string_view before_pass_name) { - for (const HloModule* module : module_group.modules()) { - MaybeDumpHlo(*module, after_pass_name, before_pass_name); +void HloPassPipeline::MaybeDumpHloAndSaveFilenames( + HloModuleGroup& module_group, absl::string_view after_pass_name, + absl::string_view before_pass_name) { + for (HloModule* module : module_group.modules()) { + MaybeDumpHloAndSaveFilenames(*module, after_pass_name, before_pass_name); } } diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h index 72549aaa681..13086880b21 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h @@ -90,12 +90,14 @@ class HloPassPipeline : public HloPassInterface { const DebugOptions& debug_options); // Maybe dumps the given module or module group depending on flag values - // contained in DebugOptions of module config. - void MaybeDumpHlo(const HloModuleGroup& module_group, - absl::string_view after_pass_name, - absl::string_view before_pass_name); - void MaybeDumpHlo(const HloModule& module, absl::string_view after_pass_name, - absl::string_view before_pass_name); + // contained in DebugOptions of module config. 
If it is dumped, saves the + // filenames of the dumps into module metadata. + void MaybeDumpHloAndSaveFilenames(HloModuleGroup& module_group, + absl::string_view after_pass_name, + absl::string_view before_pass_name); + void MaybeDumpHloAndSaveFilenames(HloModule& module, + absl::string_view after_pass_name, + absl::string_view before_pass_name); // Runs the invariant checker on the given HLO. HloT can be either HloModule // or HloModuleGroup. diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc index 20384b9da6b..c1e7d9fe1fb 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc @@ -26,6 +26,10 @@ limitations under the License. namespace xla { namespace { +using ::testing::ElementsAre; +using ::testing::SizeIs; +using ::testing::StrEq; + class HloPassPipelineTest : public HloTestBase { protected: StatusOr ParseModuleGroup( @@ -255,5 +259,43 @@ ENTRY main { ::testing::HasSubstr("Module group pass cannot be run on a module")); } +// Test that metadata is set when a module group goes through a pass pipeline. +TEST_F(HloPassPipelineTest, SetHloModuleMetadata) { + HloModuleGroup module_group(TestName()); + module_group.push_back(CreateNewVerifiedModule()); + module_group.push_back(CreateNewVerifiedModule()); + + HloPassPipeline pipeline(TestName()); + pipeline.AddPass(); + pipeline.AddPass(); + TF_ASSERT_OK(pipeline.RunOnModuleGroup(&module_group).status()); + ASSERT_THAT(module_group.modules(), SizeIs(2)); + + std::vector pass_names = {"pipeline-start", "baz2qux", + "foo2bar"}; + std::string pipeline_name = std::string(pipeline.name()); + for (const HloModule* module : module_group.modules()) { + const HloModuleMetadataProto& metadata = module->metadata().proto(); + EXPECT_EQ(metadata.canonical_module_id(), module->unique_id()); + EXPECT_EQ(metadata.module_group_name(), module_group.name()); + + ASSERT_THAT(metadata.pass_metadata(), SizeIs(3)); + for (int pass = 0; pass < metadata.pass_metadata().size(); pass++) { + const HloPassMetadata& pass_metadata = metadata.pass_metadata(pass); + EXPECT_NE(pass_metadata.pass_id(), 0); + EXPECT_THAT(pass_metadata.pass_name(), StrEq(pass_names[pass])); + EXPECT_THAT(pass_metadata.pipeline_name(), StrEq(pipeline_name)); + EXPECT_FALSE(pass_metadata.module_changed()); + EXPECT_EQ(pass_metadata.module_id(), module->unique_id()); + EXPECT_THAT(pass_metadata.module_group_module_ids(), + ElementsAre(module_group.module(0).unique_id(), + module_group.module(1).unique_id())); + EXPECT_GT(pass_metadata.start_timestamp_usec(), 0); + EXPECT_LT(pass_metadata.start_timestamp_usec(), + pass_metadata.end_timestamp_usec()); + } + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index f20a8be8de6..4642a76043f 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -261,6 +261,9 @@ message DebugOptions { // Max number of hlo module dumps in a directory. Set to < 0 for unbounded. int32 xla_dump_max_hlo_modules = 132; + // Dump HloModuleMetadata as a text proto for each HLO module. + bool xla_dump_module_metadata = 144; + // // END flags controlling dumping HLO modules. // @@ -297,7 +300,7 @@ message DebugOptions { // Enable detailed logging into vlog. bool xla_detailed_logging = 143; - // Next id: 144 + // Next id: 145 // Extra options to pass to the compilation backend (e.g. 
LLVM); specific // interpretation of these values is left to the backend.
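[Editorial note: for context, a minimal sketch of enabling the new dump from client code. The dump directory and the helper function below are placeholders, not part of this change; the same two options can also be passed through the XLA_FLAGS environment variable as --xla_dump_to and --xla_dump_module_metadata.]

#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"

// Builds an HloModuleConfig whose DebugOptions request the metadata dump
// added by this change. "/tmp/xla" is a placeholder directory.
xla::HloModuleConfig ConfigWithMetadataDump() {
  xla::DebugOptions opts = xla::GetDebugOptionsFromFlags();
  opts.set_xla_dump_to("/tmp/xla");
  opts.set_xla_dump_module_metadata(true);
  xla::HloModuleConfig config;
  config.set_debug_options(opts);
  // Modules compiled under this config produce one
  // module_####.metadata.textproto per canonical module id in /tmp/xla.
  return config;
}

Flags passed via XLA_FLAGS go through the same DebugOptions setters that debug_options_flags.cc registers at the top of this patch, so the two routes are equivalent.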