[XLA] Fix a bug in cloning outfeeds, carried the wrong shape.
PiperOrigin-RevId: 163265592
This commit is contained in:
parent
1bad826d6f
commit
08790e73d1
@ -912,7 +912,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
|
||||
return CreateInfeed(shape, infeed_config());
|
||||
case HloOpcode::kOutfeed:
|
||||
CHECK_EQ(new_operands.size(), 1);
|
||||
return CreateOutfeed(shape, new_operands[0], outfeed_config());
|
||||
return CreateOutfeed(outfeed_shape_, new_operands[0], outfeed_config());
|
||||
case HloOpcode::kBatchNormGrad:
|
||||
CHECK_EQ(new_operands.size(), 5);
|
||||
return CreateBatchNormGrad(shape, new_operands[0], new_operands[1],
|
||||
|
@ -638,6 +638,27 @@ TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) {
|
||||
metadata, fusion->fused_expression_root()->operand(0)->metadata()));
|
||||
}
|
||||
|
||||
TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) {
|
||||
HloComputation::Builder builder(TestName());
|
||||
auto constant = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR2<float>({
|
||||
{1, 2},
|
||||
{3, 4},
|
||||
})));
|
||||
auto shape10 = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0});
|
||||
auto shape01 = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {0, 1});
|
||||
auto outfeed10 = builder.AddInstruction(
|
||||
HloInstruction::CreateOutfeed(shape10, constant, ""));
|
||||
auto outfeed01 = builder.AddInstruction(
|
||||
HloInstruction::CreateOutfeed(shape01, constant, ""));
|
||||
|
||||
auto clone01 = builder.AddInstruction(outfeed01->Clone());
|
||||
auto clone10 = builder.AddInstruction(outfeed10->Clone());
|
||||
|
||||
EXPECT_TRUE(ShapeUtil::Equal(clone01->outfeed_shape(), shape01));
|
||||
EXPECT_TRUE(ShapeUtil::Equal(clone10->outfeed_shape(), shape10));
|
||||
}
|
||||
|
||||
TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
|
||||
HloComputation::Builder builder(TestName());
|
||||
// Create a fusion instruction containing a single unary operation.
|
||||
|
Loading…
Reference in New Issue
Block a user