[XLA] Use ShapeUtil::PermuteDimensions in ShapeInference::InferTransposeShape

Change: 144282152
This commit is contained in:
A. Unique TensorFlower 2017-01-11 20:09:44 -08:00 committed by TensorFlower Gardener
parent 873473ef0f
commit 401387665d
2 changed files with 3 additions and 5 deletions

View File

@ -1319,9 +1319,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
// Permute(dimensions,input) computes output[dimensions[i]]=input[i]. However,
// we need output[i]=input[dimensions[i]] which is
// Permute(Inverse(dimensions),input).
return ShapeUtil::MakeShape(operand.element_type(),
Permute(InversePermutation(dimensions),
AsInt64Slice(operand.dimensions())));
return ShapeUtil::PermuteDimensions(InversePermutation(dimensions), operand);
}
/* static */ StatusOr<Shape> ShapeInference::InferSelectShape(

View File

@ -1125,8 +1125,8 @@ TEST_F(ShapeInferenceTest, Transpose) {
ShapeInference::InferTransposeShape(a_shape, {1, 2, 3, 0});
EXPECT_IS_OK(inferred_shape_and_status);
Shape inferred_shape = inferred_shape_and_status.ValueOrDie();
EXPECT_TRUE(ShapeUtil::Equal(inferred_shape,
ShapeUtil::MakeShape(F32, {3, 4, 5, 2})));
EXPECT_TRUE(ShapeUtil::Compatible(inferred_shape,
ShapeUtil::MakeShape(F32, {3, 4, 5, 2})));
}
} // namespace