diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 7ca2832ec7d..1970b213c9b 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -178,6 +178,7 @@ cc_test( deps = [ ":hlo", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index d15b8236bba..1ede4e963f6 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -435,6 +435,7 @@ HloInstruction::CreateSelectAndScatter( auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); instruction->fusion_kind_ = fusion_kind; instruction->set_parent(fused_root->parent()); + instruction->set_metadata(fused_root->metadata()); instruction->CloneAndFuseInternal(fused_root); instruction->CheckFusionInstruction(); return instruction; @@ -858,6 +859,7 @@ std::unique_ptr HloInstruction::Clone(const string& suffix) { CloneWithNewOperands(shape_, operands_); clone->name_ = name() + "." + suffix; clone->set_parent(parent()); + clone->set_metadata(metadata_); return clone; } diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 050fceca9c3..eeabc61ec82 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -609,6 +610,25 @@ TEST_F(HloInstructionTest, ChainFusionOp) { UnorderedElementsAre(fusion.get(), exp1.get())); } +TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) { + // Create a chain of fused unary ops. + auto constant = + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); + auto exp1 = + HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get()); + auto exp2 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1.get()); + OpMetadata metadata; + metadata.set_op_name("tf_op"); + exp1->set_metadata(metadata); + exp2->set_metadata(metadata); + + auto fusion = HloInstruction::CreateFusion( + r0f32_, HloInstruction::FusionKind::kLoop, exp2.get()); + auto* fused = fusion->FuseInstruction(exp1.get()); + EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fusion->metadata())); + EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fused->metadata())); +} + TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { // Create a fusion instruction containing a single unary operation. const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});