[XLA] Fix a bug in cloning outfeeds, carried the wrong shape.

PiperOrigin-RevId: 163265592
This commit is contained in:
Chris Leary 2017-07-26 15:28:52 -07:00 committed by TensorFlower Gardener
parent 1bad826d6f
commit 08790e73d1
2 changed files with 22 additions and 1 deletions

View File

@ -912,7 +912,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
return CreateInfeed(shape, infeed_config());
case HloOpcode::kOutfeed:
CHECK_EQ(new_operands.size(), 1);
return CreateOutfeed(shape, new_operands[0], outfeed_config());
return CreateOutfeed(outfeed_shape_, new_operands[0], outfeed_config());
case HloOpcode::kBatchNormGrad:
CHECK_EQ(new_operands.size(), 5);
return CreateBatchNormGrad(shape, new_operands[0], new_operands[1],

View File

@ -638,6 +638,27 @@ TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) {
metadata, fusion->fused_expression_root()->operand(0)->metadata()));
}
TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) {
HloComputation::Builder builder(TestName());
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR2<float>({
{1, 2},
{3, 4},
})));
auto shape10 = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0});
auto shape01 = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {0, 1});
auto outfeed10 = builder.AddInstruction(
HloInstruction::CreateOutfeed(shape10, constant, ""));
auto outfeed01 = builder.AddInstruction(
HloInstruction::CreateOutfeed(shape01, constant, ""));
auto clone01 = builder.AddInstruction(outfeed01->Clone());
auto clone10 = builder.AddInstruction(outfeed10->Clone());
EXPECT_TRUE(ShapeUtil::Equal(clone01->outfeed_shape(), shape01));
EXPECT_TRUE(ShapeUtil::Equal(clone10->outfeed_shape(), shape10));
}
TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
HloComputation::Builder builder(TestName());
// Create a fusion instruction containing a single unary operation.