[XLA] Invert the permutation convention for xla::ShapeUtil::PermuteDimensions.

Most users seem to have wanted the alternate convention in the first place.

PiperOrigin-RevId: 356810248
Change-Id: Iadbfc87129597b916a502abedcc6efe6f9fd926e
This commit is contained in:
Peter Hawkins 2021-02-10 13:14:58 -08:00 committed by TensorFlower Gardener
parent 36f19f2bcf
commit a85d5a5b1b
6 changed files with 16 additions and 21 deletions

View File

@ -837,9 +837,7 @@ Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
// do a straight memory copy of the raw data set.
// This is considerably faster than iterating over every array element using
// the EachCell<>() and Set<>() APIs.
std::vector<int64> inverse_permutation = InversePermutation(permutation);
Shape permuted_shape =
ShapeUtil::PermuteDimensions(inverse_permutation, shape());
Shape permuted_shape = ShapeUtil::PermuteDimensions(permutation, shape());
// Replace the layout with one affine to this shape, such that a
// transpose operation can be performed by leaving the flat values
// representation intact.
@ -853,6 +851,7 @@ Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
// dimension has within the transposed array, a layout is affine if
// MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major
// vector of the affine layout.
std::vector<int64> inverse_permutation = InversePermutation(permutation);
CHECK(LayoutUtil::IsDenseArray(permuted_shape));
Layout* layout = permuted_shape.mutable_layout();
layout->clear_minor_to_major();

View File

@ -80,8 +80,7 @@ Status CanonicalizeDot(HloInstruction* original_dot) {
original_dnums.lhs_contracting_dimensions().end());
HloInstruction* transposed_lhs =
computation->AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::PermuteDimensions(InversePermutation(lhs_transpose),
lhs_shape),
ShapeUtil::PermuteDimensions(lhs_transpose, lhs_shape),
original_dot->mutable_operand(0), lhs_transpose));
std::vector<int64> lhs_reshape_dims = batch_dim_sizes;
if (lhs_non_contracting_size > 1) {
@ -127,8 +126,7 @@ Status CanonicalizeDot(HloInstruction* original_dot) {
rhs_non_contracting_dims.end());
HloInstruction* transposed_rhs =
computation->AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::PermuteDimensions(InversePermutation(rhs_transpose),
rhs_shape),
ShapeUtil::PermuteDimensions(rhs_transpose, rhs_shape),
original_dot->mutable_operand(1), rhs_transpose));
std::vector<int64> rhs_reshape_dims = batch_dim_sizes;

View File

@ -3097,7 +3097,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
// Permute(dimensions,input) computes output[dimensions[i]]=input[i]. However,
// we need output[i]=input[dimensions[i]] which is
// Permute(Inverse(dimensions),input).
return ShapeUtil::PermuteDimensions(InversePermutation(dimensions), operand);
return ShapeUtil::PermuteDimensions(dimensions, operand);
}
/* static */ StatusOr<Shape> ShapeInference::InferClampShape(

View File

@ -1058,11 +1058,12 @@ Status ForEachMutableSubshapeHelper(
absl::Span<const int64> permutation, const Shape& shape) {
Shape new_shape = shape;
new_shape.clear_dimensions();
for (auto dim : PermuteInverse(shape.dimensions(), permutation)) {
for (auto dim : Permute(shape.dimensions(), permutation)) {
new_shape.add_dimensions(dim);
}
auto inv_permutation = InversePermutation(permutation);
for (int64 i = 0; i < shape.rank(); i++) {
new_shape.set_dynamic_dimension(permutation[i],
new_shape.set_dynamic_dimension(inv_permutation[i],
shape.is_dynamic_dimension(i));
}
@ -1100,12 +1101,12 @@ Status ForEachMutableSubshapeHelper(
new_layout->set_format(DENSE);
new_layout->clear_minor_to_major();
for (auto index : ComposePermutations(
permutation, AsInt64Slice(shape.layout().minor_to_major()))) {
inv_permutation, AsInt64Slice(shape.layout().minor_to_major()))) {
new_layout->add_minor_to_major(index);
}
// The permutation accepted by TransposeIsBitcast is the inverse of the
// permutation here.
CHECK(TransposeIsBitcast(shape, new_shape, InversePermutation(permutation)))
CHECK(TransposeIsBitcast(shape, new_shape, permutation))
<< "shape=" << HumanStringWithLayout(shape)
<< ", new_shape=" << HumanStringWithLayout(new_shape)
<< ", permutation={" << absl::StrJoin(permutation, ",") << "}";

View File

@ -597,13 +597,13 @@ class ShapeUtil {
static Shape DropDegenerateDimensions(const Shape& shape);
// Permutes the dimensions by the given permutation, so
// return_value.dimensions[permutation[i]] = argument.dimensions[i].
// return_value.dimensions[i] = argument.dimensions[permutation[i]].
//
// Postcondition: For any valid permutation,
//
// !HasLayout(shape) ||
// TransposeIsBitcast(shape, PermuteDimensions(permutation, shape),
// InversePermutation(permutation)).
// permutation).
static Shape PermuteDimensions(absl::Span<const int64> permutation,
const Shape& shape);

View File

@ -725,11 +725,8 @@ TEST(ShapeUtilTest, PermuteDimensionsLayout) {
SCOPED_TRACE(
absl::StrCat("permutation=", absl::StrJoin(permutation, ",")));
// TransposeIsBitcast takes the inverse of the permutation that
// PermuteDimensions takes.
EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(
s, ShapeUtil::PermuteDimensions(permutation, s),
InversePermutation(permutation)));
s, ShapeUtil::PermuteDimensions(permutation, s), permutation));
} while (std::next_permutation(permutation.begin(), permutation.end()));
} while (std::next_permutation(layout.begin(), layout.end()));
}
@ -756,9 +753,9 @@ TEST(ShapeUtilTest, PermuteDynamicDimensions) {
auto permuted = ShapeUtil::PermuteDimensions(permutation, shape);
for (int i = 0; i < shape.rank(); i++) {
EXPECT_EQ(permuted.dimensions(permutation[i]), shape.dimensions(i));
EXPECT_EQ(permuted.is_dynamic_dimension(permutation[i]),
shape.is_dynamic_dimension(i));
EXPECT_EQ(permuted.dimensions(i), shape.dimensions(permutation[i]));
EXPECT_EQ(permuted.is_dynamic_dimension(i),
shape.is_dynamic_dimension(permutation[i]));
}
} while (std::next_permutation(permutation.begin(), permutation.end()));
}