[XLA] Use ShapeUtil::PermuteDimensions
in ShapeInference::InferTransposeShape
Change: 144282152
This commit is contained in:
parent
873473ef0f
commit
401387665d
@ -1319,9 +1319,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
|||||||
// Permute(dimensions,input) computes output[dimensions[i]]=input[i]. However,
|
// Permute(dimensions,input) computes output[dimensions[i]]=input[i]. However,
|
||||||
// we need output[i]=input[dimensions[i]] which is
|
// we need output[i]=input[dimensions[i]] which is
|
||||||
// Permute(Inverse(dimensions),input).
|
// Permute(Inverse(dimensions),input).
|
||||||
return ShapeUtil::MakeShape(operand.element_type(),
|
return ShapeUtil::PermuteDimensions(InversePermutation(dimensions), operand);
|
||||||
Permute(InversePermutation(dimensions),
|
|
||||||
AsInt64Slice(operand.dimensions())));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
|
/* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
|
||||||
|
@ -1125,8 +1125,8 @@ TEST_F(ShapeInferenceTest, Transpose) {
|
|||||||
ShapeInference::InferTransposeShape(a_shape, {1, 2, 3, 0});
|
ShapeInference::InferTransposeShape(a_shape, {1, 2, 3, 0});
|
||||||
EXPECT_IS_OK(inferred_shape_and_status);
|
EXPECT_IS_OK(inferred_shape_and_status);
|
||||||
Shape inferred_shape = inferred_shape_and_status.ValueOrDie();
|
Shape inferred_shape = inferred_shape_and_status.ValueOrDie();
|
||||||
EXPECT_TRUE(ShapeUtil::Equal(inferred_shape,
|
EXPECT_TRUE(ShapeUtil::Compatible(inferred_shape,
|
||||||
ShapeUtil::MakeShape(F32, {3, 4, 5, 2})));
|
ShapeUtil::MakeShape(F32, {3, 4, 5, 2})));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
Loading…
x
Reference in New Issue
Block a user