[XLA] Invert the permutation convention for xla::ShapeUtil::PermuteDimensions.
Most users seem to have wanted the alternate convention in the first place. PiperOrigin-RevId: 356810248 Change-Id: Iadbfc87129597b916a502abedcc6efe6f9fd926e
This commit is contained in:
parent
36f19f2bcf
commit
a85d5a5b1b
@ -837,9 +837,7 @@ Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
|
|||||||
// do a straight memory copy of the raw data set.
|
// do a straight memory copy of the raw data set.
|
||||||
// This is considerably faster than iterating over every array element using
|
// This is considerably faster than iterating over every array element using
|
||||||
// the EachCell<>() and Set<>() APIs.
|
// the EachCell<>() and Set<>() APIs.
|
||||||
std::vector<int64> inverse_permutation = InversePermutation(permutation);
|
Shape permuted_shape = ShapeUtil::PermuteDimensions(permutation, shape());
|
||||||
Shape permuted_shape =
|
|
||||||
ShapeUtil::PermuteDimensions(inverse_permutation, shape());
|
|
||||||
// Replace the layout with one affine to this shape, such that a
|
// Replace the layout with one affine to this shape, such that a
|
||||||
// transpose operation can be performed by leaving the flat values
|
// transpose operation can be performed by leaving the flat values
|
||||||
// representation intact.
|
// representation intact.
|
||||||
@ -853,6 +851,7 @@ Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
|
|||||||
// dimension has within the transposed array, a layout is affine if
|
// dimension has within the transposed array, a layout is affine if
|
||||||
// MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major
|
// MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major
|
||||||
// vector of the affine layout.
|
// vector of the affine layout.
|
||||||
|
std::vector<int64> inverse_permutation = InversePermutation(permutation);
|
||||||
CHECK(LayoutUtil::IsDenseArray(permuted_shape));
|
CHECK(LayoutUtil::IsDenseArray(permuted_shape));
|
||||||
Layout* layout = permuted_shape.mutable_layout();
|
Layout* layout = permuted_shape.mutable_layout();
|
||||||
layout->clear_minor_to_major();
|
layout->clear_minor_to_major();
|
||||||
|
@ -80,8 +80,7 @@ Status CanonicalizeDot(HloInstruction* original_dot) {
|
|||||||
original_dnums.lhs_contracting_dimensions().end());
|
original_dnums.lhs_contracting_dimensions().end());
|
||||||
HloInstruction* transposed_lhs =
|
HloInstruction* transposed_lhs =
|
||||||
computation->AddInstruction(HloInstruction::CreateTranspose(
|
computation->AddInstruction(HloInstruction::CreateTranspose(
|
||||||
ShapeUtil::PermuteDimensions(InversePermutation(lhs_transpose),
|
ShapeUtil::PermuteDimensions(lhs_transpose, lhs_shape),
|
||||||
lhs_shape),
|
|
||||||
original_dot->mutable_operand(0), lhs_transpose));
|
original_dot->mutable_operand(0), lhs_transpose));
|
||||||
std::vector<int64> lhs_reshape_dims = batch_dim_sizes;
|
std::vector<int64> lhs_reshape_dims = batch_dim_sizes;
|
||||||
if (lhs_non_contracting_size > 1) {
|
if (lhs_non_contracting_size > 1) {
|
||||||
@ -127,8 +126,7 @@ Status CanonicalizeDot(HloInstruction* original_dot) {
|
|||||||
rhs_non_contracting_dims.end());
|
rhs_non_contracting_dims.end());
|
||||||
HloInstruction* transposed_rhs =
|
HloInstruction* transposed_rhs =
|
||||||
computation->AddInstruction(HloInstruction::CreateTranspose(
|
computation->AddInstruction(HloInstruction::CreateTranspose(
|
||||||
ShapeUtil::PermuteDimensions(InversePermutation(rhs_transpose),
|
ShapeUtil::PermuteDimensions(rhs_transpose, rhs_shape),
|
||||||
rhs_shape),
|
|
||||||
original_dot->mutable_operand(1), rhs_transpose));
|
original_dot->mutable_operand(1), rhs_transpose));
|
||||||
|
|
||||||
std::vector<int64> rhs_reshape_dims = batch_dim_sizes;
|
std::vector<int64> rhs_reshape_dims = batch_dim_sizes;
|
||||||
|
@ -3097,7 +3097,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
|
|||||||
// Permute(dimensions,input) computes output[dimensions[i]]=input[i]. However,
|
// Permute(dimensions,input) computes output[dimensions[i]]=input[i]. However,
|
||||||
// we need output[i]=input[dimensions[i]] which is
|
// we need output[i]=input[dimensions[i]] which is
|
||||||
// Permute(Inverse(dimensions),input).
|
// Permute(Inverse(dimensions),input).
|
||||||
return ShapeUtil::PermuteDimensions(InversePermutation(dimensions), operand);
|
return ShapeUtil::PermuteDimensions(dimensions, operand);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* static */ StatusOr<Shape> ShapeInference::InferClampShape(
|
/* static */ StatusOr<Shape> ShapeInference::InferClampShape(
|
||||||
|
@ -1058,11 +1058,12 @@ Status ForEachMutableSubshapeHelper(
|
|||||||
absl::Span<const int64> permutation, const Shape& shape) {
|
absl::Span<const int64> permutation, const Shape& shape) {
|
||||||
Shape new_shape = shape;
|
Shape new_shape = shape;
|
||||||
new_shape.clear_dimensions();
|
new_shape.clear_dimensions();
|
||||||
for (auto dim : PermuteInverse(shape.dimensions(), permutation)) {
|
for (auto dim : Permute(shape.dimensions(), permutation)) {
|
||||||
new_shape.add_dimensions(dim);
|
new_shape.add_dimensions(dim);
|
||||||
}
|
}
|
||||||
|
auto inv_permutation = InversePermutation(permutation);
|
||||||
for (int64 i = 0; i < shape.rank(); i++) {
|
for (int64 i = 0; i < shape.rank(); i++) {
|
||||||
new_shape.set_dynamic_dimension(permutation[i],
|
new_shape.set_dynamic_dimension(inv_permutation[i],
|
||||||
shape.is_dynamic_dimension(i));
|
shape.is_dynamic_dimension(i));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1100,12 +1101,12 @@ Status ForEachMutableSubshapeHelper(
|
|||||||
new_layout->set_format(DENSE);
|
new_layout->set_format(DENSE);
|
||||||
new_layout->clear_minor_to_major();
|
new_layout->clear_minor_to_major();
|
||||||
for (auto index : ComposePermutations(
|
for (auto index : ComposePermutations(
|
||||||
permutation, AsInt64Slice(shape.layout().minor_to_major()))) {
|
inv_permutation, AsInt64Slice(shape.layout().minor_to_major()))) {
|
||||||
new_layout->add_minor_to_major(index);
|
new_layout->add_minor_to_major(index);
|
||||||
}
|
}
|
||||||
// The permutation accepted by TransposeIsBitcast is the inverse of the
|
// The permutation accepted by TransposeIsBitcast is the inverse of the
|
||||||
// permutation here.
|
// permutation here.
|
||||||
CHECK(TransposeIsBitcast(shape, new_shape, InversePermutation(permutation)))
|
CHECK(TransposeIsBitcast(shape, new_shape, permutation))
|
||||||
<< "shape=" << HumanStringWithLayout(shape)
|
<< "shape=" << HumanStringWithLayout(shape)
|
||||||
<< ", new_shape=" << HumanStringWithLayout(new_shape)
|
<< ", new_shape=" << HumanStringWithLayout(new_shape)
|
||||||
<< ", permutation={" << absl::StrJoin(permutation, ",") << "}";
|
<< ", permutation={" << absl::StrJoin(permutation, ",") << "}";
|
||||||
|
@ -597,13 +597,13 @@ class ShapeUtil {
|
|||||||
static Shape DropDegenerateDimensions(const Shape& shape);
|
static Shape DropDegenerateDimensions(const Shape& shape);
|
||||||
|
|
||||||
// Permutes the dimensions by the given permutation, so
|
// Permutes the dimensions by the given permutation, so
|
||||||
// return_value.dimensions[permutation[i]] = argument.dimensions[i].
|
// return_value.dimensions[i] = argument.dimensions[permutation[i]].
|
||||||
//
|
//
|
||||||
// Postcondition: For any valid permutation,
|
// Postcondition: For any valid permutation,
|
||||||
//
|
//
|
||||||
// !HasLayout(shape) ||
|
// !HasLayout(shape) ||
|
||||||
// TransposeIsBitcast(shape, PermuteDimensions(permutation, shape),
|
// TransposeIsBitcast(shape, PermuteDimensions(permutation, shape),
|
||||||
// InversePermutation(permutation)).
|
// permutation).
|
||||||
static Shape PermuteDimensions(absl::Span<const int64> permutation,
|
static Shape PermuteDimensions(absl::Span<const int64> permutation,
|
||||||
const Shape& shape);
|
const Shape& shape);
|
||||||
|
|
||||||
|
@ -725,11 +725,8 @@ TEST(ShapeUtilTest, PermuteDimensionsLayout) {
|
|||||||
SCOPED_TRACE(
|
SCOPED_TRACE(
|
||||||
absl::StrCat("permutation=", absl::StrJoin(permutation, ",")));
|
absl::StrCat("permutation=", absl::StrJoin(permutation, ",")));
|
||||||
|
|
||||||
// TransposeIsBitcast takes the inverse of the permutation that
|
|
||||||
// PermuteDimensions takes.
|
|
||||||
EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(
|
EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(
|
||||||
s, ShapeUtil::PermuteDimensions(permutation, s),
|
s, ShapeUtil::PermuteDimensions(permutation, s), permutation));
|
||||||
InversePermutation(permutation)));
|
|
||||||
} while (std::next_permutation(permutation.begin(), permutation.end()));
|
} while (std::next_permutation(permutation.begin(), permutation.end()));
|
||||||
} while (std::next_permutation(layout.begin(), layout.end()));
|
} while (std::next_permutation(layout.begin(), layout.end()));
|
||||||
}
|
}
|
||||||
@ -756,9 +753,9 @@ TEST(ShapeUtilTest, PermuteDynamicDimensions) {
|
|||||||
|
|
||||||
auto permuted = ShapeUtil::PermuteDimensions(permutation, shape);
|
auto permuted = ShapeUtil::PermuteDimensions(permutation, shape);
|
||||||
for (int i = 0; i < shape.rank(); i++) {
|
for (int i = 0; i < shape.rank(); i++) {
|
||||||
EXPECT_EQ(permuted.dimensions(permutation[i]), shape.dimensions(i));
|
EXPECT_EQ(permuted.dimensions(i), shape.dimensions(permutation[i]));
|
||||||
EXPECT_EQ(permuted.is_dynamic_dimension(permutation[i]),
|
EXPECT_EQ(permuted.is_dynamic_dimension(i),
|
||||||
shape.is_dynamic_dimension(i));
|
shape.is_dynamic_dimension(permutation[i]));
|
||||||
}
|
}
|
||||||
} while (std::next_permutation(permutation.begin(), permutation.end()));
|
} while (std::next_permutation(permutation.begin(), permutation.end()));
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user