[XLA] Invert the permutation convention for xla::ShapeUtil::PermuteDimensions.
Most users seem to have wanted the alternate convention in the first place. PiperOrigin-RevId: 356810248 Change-Id: Iadbfc87129597b916a502abedcc6efe6f9fd926e
This commit is contained in:
parent
36f19f2bcf
commit
a85d5a5b1b
@ -837,9 +837,7 @@ Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
|
||||
// do a straight memory copy of the raw data set.
|
||||
// This is considerably faster than iterating over every array element using
|
||||
// the EachCell<>() and Set<>() APIs.
|
||||
std::vector<int64> inverse_permutation = InversePermutation(permutation);
|
||||
Shape permuted_shape =
|
||||
ShapeUtil::PermuteDimensions(inverse_permutation, shape());
|
||||
Shape permuted_shape = ShapeUtil::PermuteDimensions(permutation, shape());
|
||||
// Replace the layout with one affine to this shape, such that a
|
||||
// transpose operation can be performed by leaving the flat values
|
||||
// representation intact.
|
||||
@ -853,6 +851,7 @@ Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
|
||||
// dimension has within the transposed array, a layout is affine if
|
||||
// MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major
|
||||
// vector of the affine layout.
|
||||
std::vector<int64> inverse_permutation = InversePermutation(permutation);
|
||||
CHECK(LayoutUtil::IsDenseArray(permuted_shape));
|
||||
Layout* layout = permuted_shape.mutable_layout();
|
||||
layout->clear_minor_to_major();
|
||||
|
@ -80,8 +80,7 @@ Status CanonicalizeDot(HloInstruction* original_dot) {
|
||||
original_dnums.lhs_contracting_dimensions().end());
|
||||
HloInstruction* transposed_lhs =
|
||||
computation->AddInstruction(HloInstruction::CreateTranspose(
|
||||
ShapeUtil::PermuteDimensions(InversePermutation(lhs_transpose),
|
||||
lhs_shape),
|
||||
ShapeUtil::PermuteDimensions(lhs_transpose, lhs_shape),
|
||||
original_dot->mutable_operand(0), lhs_transpose));
|
||||
std::vector<int64> lhs_reshape_dims = batch_dim_sizes;
|
||||
if (lhs_non_contracting_size > 1) {
|
||||
@ -127,8 +126,7 @@ Status CanonicalizeDot(HloInstruction* original_dot) {
|
||||
rhs_non_contracting_dims.end());
|
||||
HloInstruction* transposed_rhs =
|
||||
computation->AddInstruction(HloInstruction::CreateTranspose(
|
||||
ShapeUtil::PermuteDimensions(InversePermutation(rhs_transpose),
|
||||
rhs_shape),
|
||||
ShapeUtil::PermuteDimensions(rhs_transpose, rhs_shape),
|
||||
original_dot->mutable_operand(1), rhs_transpose));
|
||||
|
||||
std::vector<int64> rhs_reshape_dims = batch_dim_sizes;
|
||||
|
@ -3097,7 +3097,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
|
||||
// Permute(dimensions,input) computes output[dimensions[i]]=input[i]. However,
|
||||
// we need output[i]=input[dimensions[i]] which is
|
||||
// Permute(Inverse(dimensions),input).
|
||||
return ShapeUtil::PermuteDimensions(InversePermutation(dimensions), operand);
|
||||
return ShapeUtil::PermuteDimensions(dimensions, operand);
|
||||
}
|
||||
|
||||
/* static */ StatusOr<Shape> ShapeInference::InferClampShape(
|
||||
|
@ -1058,11 +1058,12 @@ Status ForEachMutableSubshapeHelper(
|
||||
absl::Span<const int64> permutation, const Shape& shape) {
|
||||
Shape new_shape = shape;
|
||||
new_shape.clear_dimensions();
|
||||
for (auto dim : PermuteInverse(shape.dimensions(), permutation)) {
|
||||
for (auto dim : Permute(shape.dimensions(), permutation)) {
|
||||
new_shape.add_dimensions(dim);
|
||||
}
|
||||
auto inv_permutation = InversePermutation(permutation);
|
||||
for (int64 i = 0; i < shape.rank(); i++) {
|
||||
new_shape.set_dynamic_dimension(permutation[i],
|
||||
new_shape.set_dynamic_dimension(inv_permutation[i],
|
||||
shape.is_dynamic_dimension(i));
|
||||
}
|
||||
|
||||
@ -1100,12 +1101,12 @@ Status ForEachMutableSubshapeHelper(
|
||||
new_layout->set_format(DENSE);
|
||||
new_layout->clear_minor_to_major();
|
||||
for (auto index : ComposePermutations(
|
||||
permutation, AsInt64Slice(shape.layout().minor_to_major()))) {
|
||||
inv_permutation, AsInt64Slice(shape.layout().minor_to_major()))) {
|
||||
new_layout->add_minor_to_major(index);
|
||||
}
|
||||
// The permutation accepted by TransposeIsBitcast is the inverse of the
|
||||
// permutation here.
|
||||
CHECK(TransposeIsBitcast(shape, new_shape, InversePermutation(permutation)))
|
||||
CHECK(TransposeIsBitcast(shape, new_shape, permutation))
|
||||
<< "shape=" << HumanStringWithLayout(shape)
|
||||
<< ", new_shape=" << HumanStringWithLayout(new_shape)
|
||||
<< ", permutation={" << absl::StrJoin(permutation, ",") << "}";
|
||||
|
@ -597,13 +597,13 @@ class ShapeUtil {
|
||||
static Shape DropDegenerateDimensions(const Shape& shape);
|
||||
|
||||
// Permutes the dimensions by the given permutation, so
|
||||
// return_value.dimensions[permutation[i]] = argument.dimensions[i].
|
||||
// return_value.dimensions[i] = argument.dimensions[permutation[i]].
|
||||
//
|
||||
// Postcondition: For any valid permutation,
|
||||
//
|
||||
// !HasLayout(shape) ||
|
||||
// TransposeIsBitcast(shape, PermuteDimensions(permutation, shape),
|
||||
// InversePermutation(permutation)).
|
||||
// permutation).
|
||||
static Shape PermuteDimensions(absl::Span<const int64> permutation,
|
||||
const Shape& shape);
|
||||
|
||||
|
@ -725,11 +725,8 @@ TEST(ShapeUtilTest, PermuteDimensionsLayout) {
|
||||
SCOPED_TRACE(
|
||||
absl::StrCat("permutation=", absl::StrJoin(permutation, ",")));
|
||||
|
||||
// TransposeIsBitcast takes the inverse of the permutation that
|
||||
// PermuteDimensions takes.
|
||||
EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(
|
||||
s, ShapeUtil::PermuteDimensions(permutation, s),
|
||||
InversePermutation(permutation)));
|
||||
s, ShapeUtil::PermuteDimensions(permutation, s), permutation));
|
||||
} while (std::next_permutation(permutation.begin(), permutation.end()));
|
||||
} while (std::next_permutation(layout.begin(), layout.end()));
|
||||
}
|
||||
@ -756,9 +753,9 @@ TEST(ShapeUtilTest, PermuteDynamicDimensions) {
|
||||
|
||||
auto permuted = ShapeUtil::PermuteDimensions(permutation, shape);
|
||||
for (int i = 0; i < shape.rank(); i++) {
|
||||
EXPECT_EQ(permuted.dimensions(permutation[i]), shape.dimensions(i));
|
||||
EXPECT_EQ(permuted.is_dynamic_dimension(permutation[i]),
|
||||
shape.is_dynamic_dimension(i));
|
||||
EXPECT_EQ(permuted.dimensions(i), shape.dimensions(permutation[i]));
|
||||
EXPECT_EQ(permuted.is_dynamic_dimension(i),
|
||||
shape.is_dynamic_dimension(permutation[i]));
|
||||
}
|
||||
} while (std::next_permutation(permutation.begin(), permutation.end()));
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user