From 777b088450af840b914b909d950be29a75365e90 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 8 Feb 2021 17:15:23 -0800 Subject: [PATCH] [XLA] Move HLO instruction OpMetadata dummy names to a new field. PiperOrigin-RevId: 356388571 Change-Id: I2282ef5a236ffd651c588f717cdab13f09c7e8ac --- .../compiler/xla/service/hlo_computation.cc | 12 +++++++----- tensorflow/compiler/xla/service/hlo_instruction.h | 3 +++ .../compiler/xla/service/hlo_pass_pipeline.cc | 4 ++-- .../compiler/xla/tests/hlo_metadata_test.cc | 8 ++++---- .../compiler/xla/tests/llvm_irgen_test_base.cc | 15 --------------- tensorflow/compiler/xla/xla_data.proto | 10 ++++++---- 6 files changed, 22 insertions(+), 30 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index cbf2c49fa7c..d061ee69864 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -919,11 +919,13 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, // function, and that they would be correlated to the same TF op. This might // not always be correct since HLO optimizations can cross TF op boundaries. // But still this seems to be better than nothing. - bool overwrite_dummy_name = - absl::StartsWith(new_instruction->metadata().op_name(), "DUMMY") && - !old_instruction->metadata().op_name().empty() && - !absl::StartsWith(old_instruction->metadata().op_name(), "DUMMY"); - if (new_instruction->metadata().op_name().empty() || overwrite_dummy_name) { + bool overwrite_op_name = new_instruction->metadata().op_name().empty() && + !old_instruction->metadata().op_name().empty(); + bool overwrite_pass_id = + new_instruction->metadata().op_name().empty() && + new_instruction->metadata().logical_creation_pass_id() == 0 && + old_instruction->metadata().logical_creation_pass_id() != 0; + if (overwrite_op_name || overwrite_pass_id) { new_instruction->set_metadata(old_instruction->metadata()); } if (new_instruction->frontend_attributes().map().empty()) { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index f27c7de7090..da11d3e3367 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -1614,6 +1614,9 @@ class HloInstruction { void set_metadata_op_name(const std::string& name) { metadata_.set_op_name(name); } + void set_logical_creation_pass_id(int64 pass_id) { + metadata_.set_logical_creation_pass_id(pass_id); + } const OpMetadata& metadata() const { return metadata_; } // Set/get the computation containing this instruction. set_parent should only diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 6f25cb2e2f9..25b2df02d88 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -105,8 +105,8 @@ void SetInstructionMetadata(HloModule& module) { if (instruction->metadata().creation_pass_id() == 0) { instruction->set_creation_pass_id(*pass_id); } - if (instruction->metadata().op_name().empty()) { - instruction->set_metadata_op_name(absl::StrCat("DUMMY_", *pass_id)); + if (instruction->metadata().logical_creation_pass_id() == 0) { + instruction->set_logical_creation_pass_id(*pass_id); } } } diff --git a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc index 9b397dc7299..4188f43c9e8 100644 --- a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc +++ b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc @@ -22,7 +22,6 @@ limitations under the License. namespace xla { namespace { -using ::testing::StartsWith; using ::testing::StrEq; class HloMetadataTest : public LocalClientTestBase { @@ -59,8 +58,9 @@ TEST_F(HloMetadataTest, MetadataPropagation) { ->module() .entry_computation() ->root_instruction(); - EXPECT_EQ("add", instruction->metadata().op_type()); - EXPECT_EQ("my_sum_op", instruction->metadata().op_name()); + EXPECT_THAT(instruction->metadata().op_type(), StrEq("add")); + EXPECT_THAT(instruction->metadata().op_name(), StrEq("my_sum_op")); + EXPECT_NE(instruction->metadata().logical_creation_pass_id(), 0); } TEST_F(HloMetadataTest, MetadataClearing) { @@ -83,7 +83,7 @@ TEST_F(HloMetadataTest, MetadataClearing) { .entry_computation() ->root_instruction(); EXPECT_THAT(instruction->metadata().op_type(), StrEq("")); - EXPECT_THAT(instruction->metadata().op_name(), StartsWith("DUMMY")); + EXPECT_THAT(instruction->metadata().op_name(), StrEq("")); } } // namespace diff --git a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc index b4d8d3c8716..d10d54dab1c 100644 --- a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc +++ b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc @@ -25,20 +25,6 @@ limitations under the License. namespace xla { -namespace { - -void RemoveDummyMetadataNames(HloModule* module) { - for (xla::HloComputation* computation : module->computations()) { - for (xla::HloInstruction* instruction : computation->instructions()) { - if (absl::StartsWith(instruction->metadata().op_name(), "DUMMY")) { - instruction->set_metadata_op_name(""); - } - } - } -} - -} // namespace - void LlvmIrGenTestBase::SetIrHook(bool match_optimized_ir) { auto llvm_compiler = GetLLVMCompiler(); using std::placeholders::_1; @@ -102,7 +88,6 @@ void LlvmIrGenTestBase::MatchOptimizedHlo(absl::string_view hlo, bool print_operand_shape) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, GetOptimizedModule(hlo)); - RemoveDummyMetadataNames(optimized_module.get()); HloPrintOptions print_opts; print_opts.set_print_operand_shape(print_operand_shape); StatusOr filecheck_result = diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 7e636afa387..5f7950bb45c 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -271,10 +271,6 @@ message OpMetadata { // // This name is often unique within a computation. Note: some frameworks // add auto-generated names if the user does not provide one. - // - // A dummy name may be assigned if op_name is empty in order to keep track of - // where op_name first became empty. Dummy names begin with "DUMMY_" and may - // include the current HloPassMetadata.pass_id. string op_name = 2; // Indicate a file and line that this op is associated to in a user's program. // @@ -288,6 +284,12 @@ message OpMetadata { // object. Should never be copied between HLO instructions. Zero if unset and // -1 if the instruction was created before HLO passes began. int64 creation_pass_id = 6; + + // HloPassMetadata.pass_id of the pass that created the logical functionality + // that this HLO instruction represents. Should be copied between HLO + // instructions that correspond across compilation passes. Zero if unset and + // -1 if the instruction was created before HLO passes began. + int64 logical_creation_pass_id = 7; } // Profile data from the execution of a computation.