From 08790e73d1d83bbb48d73d6f358466a9c75b06a9 Mon Sep 17 00:00:00 2001 From: Chris Leary Date: Wed, 26 Jul 2017 15:28:52 -0700 Subject: [PATCH] [XLA] Fix a bug in cloning outfeeds, carried the wrong shape. PiperOrigin-RevId: 163265592 --- .../compiler/xla/service/hlo_instruction.cc | 2 +- .../xla/service/hlo_instruction_test.cc | 21 +++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index ed8a942d03a..c11fea09d14 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -912,7 +912,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( return CreateInfeed(shape, infeed_config()); case HloOpcode::kOutfeed: CHECK_EQ(new_operands.size(), 1); - return CreateOutfeed(shape, new_operands[0], outfeed_config()); + return CreateOutfeed(outfeed_shape_, new_operands[0], outfeed_config()); case HloOpcode::kBatchNormGrad: CHECK_EQ(new_operands.size(), 5); return CreateBatchNormGrad(shape, new_operands[0], new_operands[1], diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 5951c833dba..ced8417fcef 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -638,6 +638,27 @@ TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) { metadata, fusion->fused_expression_root()->operand(0)->metadata())); } +TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) { + HloComputation::Builder builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2({ + {1, 2}, + {3, 4}, + }))); + auto shape10 = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}); + auto shape01 = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {0, 1}); + auto outfeed10 = builder.AddInstruction( + HloInstruction::CreateOutfeed(shape10, constant, "")); + auto outfeed01 = builder.AddInstruction( + HloInstruction::CreateOutfeed(shape01, constant, "")); + + auto clone01 = builder.AddInstruction(outfeed01->Clone()); + auto clone10 = builder.AddInstruction(outfeed10->Clone()); + + EXPECT_TRUE(ShapeUtil::Equal(clone01->outfeed_shape(), shape01)); + EXPECT_TRUE(ShapeUtil::Equal(clone10->outfeed_shape(), shape10)); +} + TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { HloComputation::Builder builder(TestName()); // Create a fusion instruction containing a single unary operation.