[XLA] Add creation_pass_id and dummy op_names to OpMetadata.
PiperOrigin-RevId: 347108311 Change-Id: Iaeb7cea0c049e5f538557b1103972705e1a3f0de
This commit is contained in:
parent
0d882ea469
commit
9dc5df5f0f
@ -144,6 +144,7 @@ bool InstrIsSetBound(const HloInstructionProto* instr_proto) {
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace internal {
|
||||
|
@ -324,7 +324,7 @@ TEST_F(CudnnFusedConvRewriterTest, PreservesMetadata) {
|
||||
input = f32[1,17,9,9] parameter(0)
|
||||
filter = f32[3,3,17,32] parameter(1)
|
||||
|
||||
conv = f32[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1, metadata={op_type="foo"}
|
||||
conv = f32[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1, metadata={op_type="foo" op_name="bar"}
|
||||
ROOT relu = f32[1,32,9,9] maximum(zeros, conv)
|
||||
})";
|
||||
|
||||
@ -337,9 +337,9 @@ TEST_F(CudnnFusedConvRewriterTest, PreservesMetadata) {
|
||||
backend().default_stream_executor(), backend().memory_allocator())
|
||||
.ConsumeValueOrDie()
|
||||
->ToString();
|
||||
EXPECT_THAT(
|
||||
optimized_hlo_string,
|
||||
::testing::ContainsRegex(R"(custom-call.*metadata=\{op_type="foo"\})"));
|
||||
EXPECT_THAT(optimized_hlo_string,
|
||||
::testing::ContainsRegex(
|
||||
R"(custom-call.*metadata=\{op_type="foo" op_name="bar"\})"));
|
||||
}
|
||||
|
||||
TEST_F(CudnnFusedConvRewriterTest, TestPreservesFeatureGroupCount) {
|
||||
|
@ -918,7 +918,11 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
|
||||
// function, and that they would be correlated to the same TF op. This might
|
||||
// not always be correct since HLO optimizations can cross TF op boundaries.
|
||||
// But still this seems to be better than nothing.
|
||||
if (new_instruction->metadata().op_name().empty()) {
|
||||
bool overwrite_dummy_name =
|
||||
absl::StartsWith(new_instruction->metadata().op_name(), "DUMMY") &&
|
||||
!old_instruction->metadata().op_name().empty() &&
|
||||
!absl::StartsWith(old_instruction->metadata().op_name(), "DUMMY");
|
||||
if (new_instruction->metadata().op_name().empty() || overwrite_dummy_name) {
|
||||
new_instruction->set_metadata(old_instruction->metadata());
|
||||
}
|
||||
if (new_instruction->frontend_attributes().map().empty()) {
|
||||
|
@ -1601,8 +1601,19 @@ class HloInstruction {
|
||||
const PrecisionConfig& precision_config() const;
|
||||
PrecisionConfig* mutable_precision_config();
|
||||
|
||||
// Sets the debug metadata for this instruction.
|
||||
void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
|
||||
// Sets the debug metadata for this instruction, excluding creation_pass_id,
|
||||
// which should never be copied anywhere.
|
||||
void set_metadata(const OpMetadata& metadata) {
|
||||
int64 creation_pass_id = metadata_.creation_pass_id();
|
||||
metadata_ = metadata;
|
||||
metadata_.set_creation_pass_id(creation_pass_id);
|
||||
}
|
||||
void set_creation_pass_id(int64 pass_id) {
|
||||
metadata_.set_creation_pass_id(pass_id);
|
||||
}
|
||||
void set_metadata_op_name(const std::string& name) {
|
||||
metadata_.set_op_name(name);
|
||||
}
|
||||
const OpMetadata& metadata() const { return metadata_; }
|
||||
|
||||
// Set/get the computation containing this instruction. set_parent should only
|
||||
|
@ -61,6 +61,12 @@ class HloModuleMetadata {
|
||||
module_metadata_.add_partitioned_module_ids(id);
|
||||
}
|
||||
|
||||
StatusOr<int64> current_pass_id() {
|
||||
TF_ASSIGN_OR_RETURN(HloPassMetadata * pass_metadata,
|
||||
GetCurrentHloPassMetadata());
|
||||
return pass_metadata->pass_id();
|
||||
}
|
||||
|
||||
// Setters for the current HloPassMetadata.
|
||||
Status set_current_pass_name(const std::string& pass_name) {
|
||||
return MutateCurrentHloPassMetadata(
|
||||
|
@ -67,7 +67,7 @@ void RecordPassEndMetadata(HloModule& module, const std::string& pass_name,
|
||||
Status status =
|
||||
AttemptRecordPassEndMetadata(module, pass_name, module_changed);
|
||||
if (!status.ok()) {
|
||||
LOG(FATAL) << status.error_message();
|
||||
LOG(FATAL) << status;
|
||||
}
|
||||
}
|
||||
|
||||
@ -91,7 +91,30 @@ void RecordPassEndMetadata(HloModuleGroup& module_group,
|
||||
Status status =
|
||||
AttemptRecordPassEndMetadata(module_group, pass_name, module_changed);
|
||||
if (!status.ok()) {
|
||||
LOG(FATAL) << status.error_message();
|
||||
LOG(FATAL) << status;
|
||||
}
|
||||
}
|
||||
|
||||
void SetInstructionMetadata(HloModule& module) {
|
||||
StatusOr<int64> pass_id = module.metadata()->current_pass_id();
|
||||
if (!pass_id.ok()) {
|
||||
LOG(FATAL) << pass_id.status();
|
||||
}
|
||||
for (xla::HloComputation* computation : module.computations()) {
|
||||
for (xla::HloInstruction* instruction : computation->instructions()) {
|
||||
if (instruction->metadata().creation_pass_id() == 0) {
|
||||
instruction->set_creation_pass_id(*pass_id);
|
||||
}
|
||||
if (instruction->metadata().op_name().empty()) {
|
||||
instruction->set_metadata_op_name(absl::StrCat("DUMMY_", *pass_id));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SetInstructionMetadata(HloModuleGroup& module_group) {
|
||||
for (HloModule* module : module_group.modules()) {
|
||||
SetInstructionMetadata(*module);
|
||||
}
|
||||
}
|
||||
|
||||
@ -127,6 +150,7 @@ StatusOr<bool> HloPassPipeline::RunPassesInternal(
|
||||
TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, kPipelineStart));
|
||||
|
||||
RecordPassStartMetadata(*hlo, std::string(kPipelineStart), pipeline_name);
|
||||
SetInstructionMetadata(*hlo);
|
||||
MaybeDumpHloAndSaveFilenames(*hlo,
|
||||
/*after_pass_name=*/kPipelineStart,
|
||||
/*before_pass_name=*/passes.empty()
|
||||
@ -147,6 +171,7 @@ StatusOr<bool> HloPassPipeline::RunPassesInternal(
|
||||
}
|
||||
RecordPassStartMetadata(*hlo, pass_name, pipeline_name);
|
||||
TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo));
|
||||
SetInstructionMetadata(*hlo);
|
||||
MaybeDumpHloAndSaveFilenames(*hlo,
|
||||
/*after_pass_name=*/pass_name,
|
||||
/*before_pass_name=*/i + 1 >= passes.size()
|
||||
@ -216,7 +241,7 @@ void HloPassPipeline::MaybeDumpHloAndSaveFilenames(
|
||||
name(), before_pass_name, after_pass_name, module)) {
|
||||
Status status = module.metadata()->add_current_pass_dump_filename(filename);
|
||||
if (!status.ok()) {
|
||||
LOG(FATAL) << status.error_message();
|
||||
LOG(FATAL) << status;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -4,7 +4,7 @@ HloModule Add
|
||||
ENTRY %Add (x: f32[2,2,2], y: f32[2,2,2]) -> f32[2,2,2] {
|
||||
%x = f32[2,2,2]{2,1,0} parameter(0)
|
||||
%y = f32[2,2,2]{2,1,0} parameter(1)
|
||||
ROOT %add = f32[2,2,2]{2,1,0} add(f32[2,2,2]{2,1,0} %x, f32[2,2,2]{2,1,0} %y)
|
||||
ROOT %add = f32[2,2,2]{2,1,0} add(f32[2,2,2]{2,1,0} %x, f32[2,2,2]{2,1,0} %y), metadata={op_name="original_tf_op"}
|
||||
}
|
||||
|
||||
// CHECK: ERRORS FOUND: [%add = f32[2,2,2]{2,1,0} add(f32[2,2,2]{2,1,0} %x, f32[2,2,2]{2,1,0} %y): failed for testing: lmhlo.add; failed for testing: std.return]
|
||||
// CHECK: ERRORS FOUND: [%add = f32[2,2,2]{2,1,0} add(f32[2,2,2]{2,1,0} %x, f32[2,2,2]{2,1,0} %y), metadata={op_name="original_tf_op"}: failed for testing: lmhlo.add; failed for testing: std.return]
|
||||
|
@ -22,6 +22,9 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
using ::testing::StartsWith;
|
||||
using ::testing::StrEq;
|
||||
|
||||
class HloMetadataTest : public LocalClientTestBase {
|
||||
protected:
|
||||
HloMetadataTest() {
|
||||
@ -79,9 +82,8 @@ TEST_F(HloMetadataTest, MetadataClearing) {
|
||||
->module()
|
||||
.entry_computation()
|
||||
->root_instruction();
|
||||
// We expect these to be empty (no metadata set).
|
||||
EXPECT_EQ("", instruction->metadata().op_type());
|
||||
EXPECT_EQ("", instruction->metadata().op_name());
|
||||
EXPECT_THAT(instruction->metadata().op_type(), StrEq(""));
|
||||
EXPECT_THAT(instruction->metadata().op_name(), StartsWith("DUMMY"));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -25,6 +25,20 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
namespace {
|
||||
|
||||
void RemoveDummyMetadataNames(HloModule* module) {
|
||||
for (xla::HloComputation* computation : module->computations()) {
|
||||
for (xla::HloInstruction* instruction : computation->instructions()) {
|
||||
if (absl::StartsWith(instruction->metadata().op_name(), "DUMMY")) {
|
||||
instruction->set_metadata_op_name("");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void LlvmIrGenTestBase::SetIrHook(bool match_optimized_ir) {
|
||||
auto llvm_compiler = GetLLVMCompiler();
|
||||
using std::placeholders::_1;
|
||||
@ -88,6 +102,7 @@ void LlvmIrGenTestBase::MatchOptimizedHlo(absl::string_view hlo,
|
||||
bool print_operand_shape) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
|
||||
GetOptimizedModule(hlo));
|
||||
RemoveDummyMetadataNames(optimized_module.get());
|
||||
HloPrintOptions print_opts;
|
||||
print_opts.set_print_operand_shape(print_operand_shape);
|
||||
StatusOr<bool> filecheck_result =
|
||||
|
@ -271,6 +271,10 @@ message OpMetadata {
|
||||
//
|
||||
// This name is often unique within a computation. Note: some frameworks
|
||||
// add auto-generated names if the user does not provide one.
|
||||
//
|
||||
// A dummy name may be assigned if op_name is empty in order to keep track of
|
||||
// where op_name first became empty. Dummy names begin with "DUMMY_" and may
|
||||
// include the current HloPassMetadata.pass_id.
|
||||
string op_name = 2;
|
||||
// Indicate a file and line that this op is associated to in a user's program.
|
||||
//
|
||||
@ -279,6 +283,11 @@ message OpMetadata {
|
||||
int32 source_line = 4;
|
||||
|
||||
repeated ProfileType profile_type = 5;
|
||||
|
||||
// HloPassMetadata.pass_id of the pass that created this HLO instruction
|
||||
// object. Should never be copied between HLO instructions. Zero if unset and
|
||||
// -1 if the instruction was created before HLO passes began.
|
||||
int64 creation_pass_id = 6;
|
||||
}
|
||||
|
||||
// Profile data from the execution of a computation.
|
||||
|
Loading…
x
Reference in New Issue
Block a user