[XLA] Use ShapeUtil::PermuteDimensions
in ShapeInference::InferTransposeShape
Change: 144282152
This commit is contained in:
parent
873473ef0f
commit
401387665d
@ -1319,9 +1319,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
// Permute(dimensions,input) computes output[dimensions[i]]=input[i]. However,
|
||||
// we need output[i]=input[dimensions[i]] which is
|
||||
// Permute(Inverse(dimensions),input).
|
||||
return ShapeUtil::MakeShape(operand.element_type(),
|
||||
Permute(InversePermutation(dimensions),
|
||||
AsInt64Slice(operand.dimensions())));
|
||||
return ShapeUtil::PermuteDimensions(InversePermutation(dimensions), operand);
|
||||
}
|
||||
|
||||
/* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
|
||||
|
@ -1125,8 +1125,8 @@ TEST_F(ShapeInferenceTest, Transpose) {
|
||||
ShapeInference::InferTransposeShape(a_shape, {1, 2, 3, 0});
|
||||
EXPECT_IS_OK(inferred_shape_and_status);
|
||||
Shape inferred_shape = inferred_shape_and_status.ValueOrDie();
|
||||
EXPECT_TRUE(ShapeUtil::Equal(inferred_shape,
|
||||
ShapeUtil::MakeShape(F32, {3, 4, 5, 2})));
|
||||
EXPECT_TRUE(ShapeUtil::Compatible(inferred_shape,
|
||||
ShapeUtil::MakeShape(F32, {3, 4, 5, 2})));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
x
Reference in New Issue
Block a user