[XLA] Add creation_pass_id and dummy op_names to OpMetadata.

PiperOrigin-RevId: 347108311
Change-Id: Iaeb7cea0c049e5f538557b1103972705e1a3f0de
This commit is contained in:
A. Unique TensorFlower 2020-12-11 17:42:58 -08:00 committed by TensorFlower Gardener
parent 0d882ea469
commit 9dc5df5f0f
10 changed files with 88 additions and 15 deletions

View File

@ -144,6 +144,7 @@ bool InstrIsSetBound(const HloInstructionProto* instr_proto) {
}
return false;
}
} // namespace
namespace internal {

View File

@ -324,7 +324,7 @@ TEST_F(CudnnFusedConvRewriterTest, PreservesMetadata) {
input = f32[1,17,9,9] parameter(0)
filter = f32[3,3,17,32] parameter(1)
conv = f32[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1, metadata={op_type="foo"}
conv = f32[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1, metadata={op_type="foo" op_name="bar"}
ROOT relu = f32[1,32,9,9] maximum(zeros, conv)
})";
@ -337,9 +337,9 @@ TEST_F(CudnnFusedConvRewriterTest, PreservesMetadata) {
backend().default_stream_executor(), backend().memory_allocator())
.ConsumeValueOrDie()
->ToString();
EXPECT_THAT(
optimized_hlo_string,
::testing::ContainsRegex(R"(custom-call.*metadata=\{op_type="foo"\})"));
EXPECT_THAT(optimized_hlo_string,
::testing::ContainsRegex(
R"(custom-call.*metadata=\{op_type="foo" op_name="bar"\})"));
}
TEST_F(CudnnFusedConvRewriterTest, TestPreservesFeatureGroupCount) {

View File

@ -918,7 +918,11 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
// function, and that they would be correlated to the same TF op. This might
// not always be correct since HLO optimizations can cross TF op boundaries.
// But still this seems to be better than nothing.
if (new_instruction->metadata().op_name().empty()) {
bool overwrite_dummy_name =
absl::StartsWith(new_instruction->metadata().op_name(), "DUMMY") &&
!old_instruction->metadata().op_name().empty() &&
!absl::StartsWith(old_instruction->metadata().op_name(), "DUMMY");
if (new_instruction->metadata().op_name().empty() || overwrite_dummy_name) {
new_instruction->set_metadata(old_instruction->metadata());
}
if (new_instruction->frontend_attributes().map().empty()) {

View File

@ -1601,8 +1601,19 @@ class HloInstruction {
const PrecisionConfig& precision_config() const;
PrecisionConfig* mutable_precision_config();
// Sets the debug metadata for this instruction.
void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
// Sets the debug metadata for this instruction, excluding creation_pass_id,
// which should never be copied anywhere.
void set_metadata(const OpMetadata& metadata) {
int64 creation_pass_id = metadata_.creation_pass_id();
metadata_ = metadata;
metadata_.set_creation_pass_id(creation_pass_id);
}
void set_creation_pass_id(int64 pass_id) {
metadata_.set_creation_pass_id(pass_id);
}
void set_metadata_op_name(const std::string& name) {
metadata_.set_op_name(name);
}
const OpMetadata& metadata() const { return metadata_; }
// Set/get the computation containing this instruction. set_parent should only

View File

@ -61,6 +61,12 @@ class HloModuleMetadata {
module_metadata_.add_partitioned_module_ids(id);
}
StatusOr<int64> current_pass_id() {
TF_ASSIGN_OR_RETURN(HloPassMetadata * pass_metadata,
GetCurrentHloPassMetadata());
return pass_metadata->pass_id();
}
// Setters for the current HloPassMetadata.
Status set_current_pass_name(const std::string& pass_name) {
return MutateCurrentHloPassMetadata(

View File

@ -67,7 +67,7 @@ void RecordPassEndMetadata(HloModule& module, const std::string& pass_name,
Status status =
AttemptRecordPassEndMetadata(module, pass_name, module_changed);
if (!status.ok()) {
LOG(FATAL) << status.error_message();
LOG(FATAL) << status;
}
}
@ -91,7 +91,30 @@ void RecordPassEndMetadata(HloModuleGroup& module_group,
Status status =
AttemptRecordPassEndMetadata(module_group, pass_name, module_changed);
if (!status.ok()) {
LOG(FATAL) << status.error_message();
LOG(FATAL) << status;
}
}
void SetInstructionMetadata(HloModule& module) {
StatusOr<int64> pass_id = module.metadata()->current_pass_id();
if (!pass_id.ok()) {
LOG(FATAL) << pass_id.status();
}
for (xla::HloComputation* computation : module.computations()) {
for (xla::HloInstruction* instruction : computation->instructions()) {
if (instruction->metadata().creation_pass_id() == 0) {
instruction->set_creation_pass_id(*pass_id);
}
if (instruction->metadata().op_name().empty()) {
instruction->set_metadata_op_name(absl::StrCat("DUMMY_", *pass_id));
}
}
}
}
void SetInstructionMetadata(HloModuleGroup& module_group) {
for (HloModule* module : module_group.modules()) {
SetInstructionMetadata(*module);
}
}
@ -127,6 +150,7 @@ StatusOr<bool> HloPassPipeline::RunPassesInternal(
TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, kPipelineStart));
RecordPassStartMetadata(*hlo, std::string(kPipelineStart), pipeline_name);
SetInstructionMetadata(*hlo);
MaybeDumpHloAndSaveFilenames(*hlo,
/*after_pass_name=*/kPipelineStart,
/*before_pass_name=*/passes.empty()
@ -147,6 +171,7 @@ StatusOr<bool> HloPassPipeline::RunPassesInternal(
}
RecordPassStartMetadata(*hlo, pass_name, pipeline_name);
TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo));
SetInstructionMetadata(*hlo);
MaybeDumpHloAndSaveFilenames(*hlo,
/*after_pass_name=*/pass_name,
/*before_pass_name=*/i + 1 >= passes.size()
@ -216,7 +241,7 @@ void HloPassPipeline::MaybeDumpHloAndSaveFilenames(
name(), before_pass_name, after_pass_name, module)) {
Status status = module.metadata()->add_current_pass_dump_filename(filename);
if (!status.ok()) {
LOG(FATAL) << status.error_message();
LOG(FATAL) << status;
}
}
}

View File

@ -4,7 +4,7 @@ HloModule Add
ENTRY %Add (x: f32[2,2,2], y: f32[2,2,2]) -> f32[2,2,2] {
%x = f32[2,2,2]{2,1,0} parameter(0)
%y = f32[2,2,2]{2,1,0} parameter(1)
ROOT %add = f32[2,2,2]{2,1,0} add(f32[2,2,2]{2,1,0} %x, f32[2,2,2]{2,1,0} %y)
ROOT %add = f32[2,2,2]{2,1,0} add(f32[2,2,2]{2,1,0} %x, f32[2,2,2]{2,1,0} %y), metadata={op_name="original_tf_op"}
}
// CHECK: ERRORS FOUND: [%add = f32[2,2,2]{2,1,0} add(f32[2,2,2]{2,1,0} %x, f32[2,2,2]{2,1,0} %y): failed for testing: lmhlo.add; failed for testing: std.return]
// CHECK: ERRORS FOUND: [%add = f32[2,2,2]{2,1,0} add(f32[2,2,2]{2,1,0} %x, f32[2,2,2]{2,1,0} %y), metadata={op_name="original_tf_op"}: failed for testing: lmhlo.add; failed for testing: std.return]

View File

@ -22,6 +22,9 @@ limitations under the License.
namespace xla {
namespace {
using ::testing::StartsWith;
using ::testing::StrEq;
class HloMetadataTest : public LocalClientTestBase {
protected:
HloMetadataTest() {
@ -79,9 +82,8 @@ TEST_F(HloMetadataTest, MetadataClearing) {
->module()
.entry_computation()
->root_instruction();
// We expect these to be empty (no metadata set).
EXPECT_EQ("", instruction->metadata().op_type());
EXPECT_EQ("", instruction->metadata().op_name());
EXPECT_THAT(instruction->metadata().op_type(), StrEq(""));
EXPECT_THAT(instruction->metadata().op_name(), StartsWith("DUMMY"));
}
} // namespace

View File

@ -25,6 +25,20 @@ limitations under the License.
namespace xla {
namespace {
void RemoveDummyMetadataNames(HloModule* module) {
for (xla::HloComputation* computation : module->computations()) {
for (xla::HloInstruction* instruction : computation->instructions()) {
if (absl::StartsWith(instruction->metadata().op_name(), "DUMMY")) {
instruction->set_metadata_op_name("");
}
}
}
}
} // namespace
void LlvmIrGenTestBase::SetIrHook(bool match_optimized_ir) {
auto llvm_compiler = GetLLVMCompiler();
using std::placeholders::_1;
@ -88,6 +102,7 @@ void LlvmIrGenTestBase::MatchOptimizedHlo(absl::string_view hlo,
bool print_operand_shape) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
GetOptimizedModule(hlo));
RemoveDummyMetadataNames(optimized_module.get());
HloPrintOptions print_opts;
print_opts.set_print_operand_shape(print_operand_shape);
StatusOr<bool> filecheck_result =

View File

@ -271,6 +271,10 @@ message OpMetadata {
//
// This name is often unique within a computation. Note: some frameworks
// add auto-generated names if the user does not provide one.
//
// A dummy name may be assigned if op_name is empty in order to keep track of
// where op_name first became empty. Dummy names begin with "DUMMY_" and may
// include the current HloPassMetadata.pass_id.
string op_name = 2;
// Indicate a file and line that this op is associated to in a user's program.
//
@ -279,6 +283,11 @@ message OpMetadata {
int32 source_line = 4;
repeated ProfileType profile_type = 5;
// HloPassMetadata.pass_id of the pass that created this HLO instruction
// object. Should never be copied between HLO instructions. Zero if unset and
// -1 if the instruction was created before HLO passes began.
int64 creation_pass_id = 6;
}
// Profile data from the execution of a computation.