[XLA] Move HLO instruction OpMetadata dummy names to a new field.
PiperOrigin-RevId: 356388571 Change-Id: I2282ef5a236ffd651c588f717cdab13f09c7e8ac
This commit is contained in:
parent
358a61ab90
commit
777b088450
tensorflow/compiler/xla
@ -919,11 +919,13 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
|
||||
// function, and that they would be correlated to the same TF op. This might
|
||||
// not always be correct since HLO optimizations can cross TF op boundaries.
|
||||
// But still this seems to be better than nothing.
|
||||
bool overwrite_dummy_name =
|
||||
absl::StartsWith(new_instruction->metadata().op_name(), "DUMMY") &&
|
||||
!old_instruction->metadata().op_name().empty() &&
|
||||
!absl::StartsWith(old_instruction->metadata().op_name(), "DUMMY");
|
||||
if (new_instruction->metadata().op_name().empty() || overwrite_dummy_name) {
|
||||
bool overwrite_op_name = new_instruction->metadata().op_name().empty() &&
|
||||
!old_instruction->metadata().op_name().empty();
|
||||
bool overwrite_pass_id =
|
||||
new_instruction->metadata().op_name().empty() &&
|
||||
new_instruction->metadata().logical_creation_pass_id() == 0 &&
|
||||
old_instruction->metadata().logical_creation_pass_id() != 0;
|
||||
if (overwrite_op_name || overwrite_pass_id) {
|
||||
new_instruction->set_metadata(old_instruction->metadata());
|
||||
}
|
||||
if (new_instruction->frontend_attributes().map().empty()) {
|
||||
|
@ -1614,6 +1614,9 @@ class HloInstruction {
|
||||
void set_metadata_op_name(const std::string& name) {
|
||||
metadata_.set_op_name(name);
|
||||
}
|
||||
void set_logical_creation_pass_id(int64 pass_id) {
|
||||
metadata_.set_logical_creation_pass_id(pass_id);
|
||||
}
|
||||
const OpMetadata& metadata() const { return metadata_; }
|
||||
|
||||
// Set/get the computation containing this instruction. set_parent should only
|
||||
|
@ -105,8 +105,8 @@ void SetInstructionMetadata(HloModule& module) {
|
||||
if (instruction->metadata().creation_pass_id() == 0) {
|
||||
instruction->set_creation_pass_id(*pass_id);
|
||||
}
|
||||
if (instruction->metadata().op_name().empty()) {
|
||||
instruction->set_metadata_op_name(absl::StrCat("DUMMY_", *pass_id));
|
||||
if (instruction->metadata().logical_creation_pass_id() == 0) {
|
||||
instruction->set_logical_creation_pass_id(*pass_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -22,7 +22,6 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
using ::testing::StartsWith;
|
||||
using ::testing::StrEq;
|
||||
|
||||
class HloMetadataTest : public LocalClientTestBase {
|
||||
@ -59,8 +58,9 @@ TEST_F(HloMetadataTest, MetadataPropagation) {
|
||||
->module()
|
||||
.entry_computation()
|
||||
->root_instruction();
|
||||
EXPECT_EQ("add", instruction->metadata().op_type());
|
||||
EXPECT_EQ("my_sum_op", instruction->metadata().op_name());
|
||||
EXPECT_THAT(instruction->metadata().op_type(), StrEq("add"));
|
||||
EXPECT_THAT(instruction->metadata().op_name(), StrEq("my_sum_op"));
|
||||
EXPECT_NE(instruction->metadata().logical_creation_pass_id(), 0);
|
||||
}
|
||||
|
||||
TEST_F(HloMetadataTest, MetadataClearing) {
|
||||
@ -83,7 +83,7 @@ TEST_F(HloMetadataTest, MetadataClearing) {
|
||||
.entry_computation()
|
||||
->root_instruction();
|
||||
EXPECT_THAT(instruction->metadata().op_type(), StrEq(""));
|
||||
EXPECT_THAT(instruction->metadata().op_name(), StartsWith("DUMMY"));
|
||||
EXPECT_THAT(instruction->metadata().op_name(), StrEq(""));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -25,20 +25,6 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
namespace {
|
||||
|
||||
void RemoveDummyMetadataNames(HloModule* module) {
|
||||
for (xla::HloComputation* computation : module->computations()) {
|
||||
for (xla::HloInstruction* instruction : computation->instructions()) {
|
||||
if (absl::StartsWith(instruction->metadata().op_name(), "DUMMY")) {
|
||||
instruction->set_metadata_op_name("");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void LlvmIrGenTestBase::SetIrHook(bool match_optimized_ir) {
|
||||
auto llvm_compiler = GetLLVMCompiler();
|
||||
using std::placeholders::_1;
|
||||
@ -102,7 +88,6 @@ void LlvmIrGenTestBase::MatchOptimizedHlo(absl::string_view hlo,
|
||||
bool print_operand_shape) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
|
||||
GetOptimizedModule(hlo));
|
||||
RemoveDummyMetadataNames(optimized_module.get());
|
||||
HloPrintOptions print_opts;
|
||||
print_opts.set_print_operand_shape(print_operand_shape);
|
||||
StatusOr<bool> filecheck_result =
|
||||
|
@ -271,10 +271,6 @@ message OpMetadata {
|
||||
//
|
||||
// This name is often unique within a computation. Note: some frameworks
|
||||
// add auto-generated names if the user does not provide one.
|
||||
//
|
||||
// A dummy name may be assigned if op_name is empty in order to keep track of
|
||||
// where op_name first became empty. Dummy names begin with "DUMMY_" and may
|
||||
// include the current HloPassMetadata.pass_id.
|
||||
string op_name = 2;
|
||||
// Indicate a file and line that this op is associated to in a user's program.
|
||||
//
|
||||
@ -288,6 +284,12 @@ message OpMetadata {
|
||||
// object. Should never be copied between HLO instructions. Zero if unset and
|
||||
// -1 if the instruction was created before HLO passes began.
|
||||
int64 creation_pass_id = 6;
|
||||
|
||||
// HloPassMetadata.pass_id of the pass that created the logical functionality
|
||||
// that this HLO instruction represents. Should be copied between HLO
|
||||
// instructions that correspond across compilation passes. Zero if unset and
|
||||
// -1 if the instruction was created before HLO passes began.
|
||||
int64 logical_creation_pass_id = 7;
|
||||
}
|
||||
|
||||
// Profile data from the execution of a computation.
|
||||
|
Loading…
Reference in New Issue
Block a user