From a85d5a5b1b1fd0ccd4732184823b1215bb26aacc Mon Sep 17 00:00:00 2001
From: Peter Hawkins <phawkins@google.com>
Date: Wed, 10 Feb 2021 13:14:58 -0800
Subject: [PATCH] [XLA] Invert the permutation convention for
 xla::ShapeUtil::PermuteDimensions.

Most users seem to have wanted the alternate convention in the first place.

PiperOrigin-RevId: 356810248
Change-Id: Iadbfc87129597b916a502abedcc6efe6f9fd926e
---
 tensorflow/compiler/xla/literal.cc                 |  5 ++---
 tensorflow/compiler/xla/service/dot_decomposer.cc  |  6 ++----
 tensorflow/compiler/xla/service/shape_inference.cc |  2 +-
 tensorflow/compiler/xla/shape_util.cc              |  9 +++++----
 tensorflow/compiler/xla/shape_util.h               |  4 ++--
 tensorflow/compiler/xla/shape_util_test.cc         | 11 ++++-------
 6 files changed, 16 insertions(+), 21 deletions(-)

diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc
index c9cf07652d3..d15f78c41e0 100644
--- a/tensorflow/compiler/xla/literal.cc
+++ b/tensorflow/compiler/xla/literal.cc
@@ -837,9 +837,7 @@ Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
   // do a straight memory copy of the raw data set.
   // This is considerably faster than iterating over every array element using
   // the EachCell<>() and Set<>() APIs.
-  std::vector<int64> inverse_permutation = InversePermutation(permutation);
-  Shape permuted_shape =
-      ShapeUtil::PermuteDimensions(inverse_permutation, shape());
+  Shape permuted_shape = ShapeUtil::PermuteDimensions(permutation, shape());
   // Replace the layout with one affine to this shape, such that a
   // transpose operation can be performed by leaving the flat values
   // representation intact.
@@ -853,6 +851,7 @@ Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
   // dimension has within the transposed array, a layout is affine if
   // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major
   // vector of the affine layout.
+  std::vector<int64> inverse_permutation = InversePermutation(permutation);
   CHECK(LayoutUtil::IsDenseArray(permuted_shape));
   Layout* layout = permuted_shape.mutable_layout();
   layout->clear_minor_to_major();
diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc
index e21ecedc951..71452fad46b 100644
--- a/tensorflow/compiler/xla/service/dot_decomposer.cc
+++ b/tensorflow/compiler/xla/service/dot_decomposer.cc
@@ -80,8 +80,7 @@ Status CanonicalizeDot(HloInstruction* original_dot) {
                        original_dnums.lhs_contracting_dimensions().end());
   HloInstruction* transposed_lhs =
       computation->AddInstruction(HloInstruction::CreateTranspose(
-          ShapeUtil::PermuteDimensions(InversePermutation(lhs_transpose),
-                                       lhs_shape),
+          ShapeUtil::PermuteDimensions(lhs_transpose, lhs_shape),
           original_dot->mutable_operand(0), lhs_transpose));
   std::vector<int64> lhs_reshape_dims = batch_dim_sizes;
   if (lhs_non_contracting_size > 1) {
@@ -127,8 +126,7 @@ Status CanonicalizeDot(HloInstruction* original_dot) {
                        rhs_non_contracting_dims.end());
   HloInstruction* transposed_rhs =
       computation->AddInstruction(HloInstruction::CreateTranspose(
-          ShapeUtil::PermuteDimensions(InversePermutation(rhs_transpose),
-                                       rhs_shape),
+          ShapeUtil::PermuteDimensions(rhs_transpose, rhs_shape),
           original_dot->mutable_operand(1), rhs_transpose));
 
   std::vector<int64> rhs_reshape_dims = batch_dim_sizes;
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 8f978a1dd30..73da11b63c3 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -3097,7 +3097,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
   // Permute(dimensions,input) computes output[dimensions[i]]=input[i]. However,
   // we need output[i]=input[dimensions[i]] which is
   // Permute(Inverse(dimensions),input).
-  return ShapeUtil::PermuteDimensions(InversePermutation(dimensions), operand);
+  return ShapeUtil::PermuteDimensions(dimensions, operand);
 }
 
 /* static */ StatusOr<Shape> ShapeInference::InferClampShape(
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index f376a88c0c3..830914c0d35 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -1058,11 +1058,12 @@ Status ForEachMutableSubshapeHelper(
     absl::Span<const int64> permutation, const Shape& shape) {
   Shape new_shape = shape;
   new_shape.clear_dimensions();
-  for (auto dim : PermuteInverse(shape.dimensions(), permutation)) {
+  for (auto dim : Permute(shape.dimensions(), permutation)) {
     new_shape.add_dimensions(dim);
   }
+  auto inv_permutation = InversePermutation(permutation);
   for (int64 i = 0; i < shape.rank(); i++) {
-    new_shape.set_dynamic_dimension(permutation[i],
+    new_shape.set_dynamic_dimension(inv_permutation[i],
                                     shape.is_dynamic_dimension(i));
   }
 
@@ -1100,12 +1101,12 @@ Status ForEachMutableSubshapeHelper(
     new_layout->set_format(DENSE);
     new_layout->clear_minor_to_major();
     for (auto index : ComposePermutations(
-             permutation, AsInt64Slice(shape.layout().minor_to_major()))) {
+             inv_permutation, AsInt64Slice(shape.layout().minor_to_major()))) {
       new_layout->add_minor_to_major(index);
     }
     // The permutation accepted by TransposeIsBitcast is the inverse of the
     // permutation here.
-    CHECK(TransposeIsBitcast(shape, new_shape, InversePermutation(permutation)))
+    CHECK(TransposeIsBitcast(shape, new_shape, permutation))
         << "shape=" << HumanStringWithLayout(shape)
         << ", new_shape=" << HumanStringWithLayout(new_shape)
         << ", permutation={" << absl::StrJoin(permutation, ",") << "}";
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 57793f6e284..e0841a30808 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -597,13 +597,13 @@ class ShapeUtil {
   static Shape DropDegenerateDimensions(const Shape& shape);
 
   // Permutes the dimensions by the given permutation, so
-  // return_value.dimensions[permutation[i]] = argument.dimensions[i].
+  // return_value.dimensions[i] = argument.dimensions[permutation[i]].
   //
   // Postcondition: For any valid permutation,
   //
   //   !HasLayout(shape) ||
   //   TransposeIsBitcast(shape, PermuteDimensions(permutation, shape),
-  //                      InversePermutation(permutation)).
+  //                      permutation).
   static Shape PermuteDimensions(absl::Span<const int64> permutation,
                                  const Shape& shape);
 
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc
index 3cb261034bf..ec4c9246a65 100644
--- a/tensorflow/compiler/xla/shape_util_test.cc
+++ b/tensorflow/compiler/xla/shape_util_test.cc
@@ -725,11 +725,8 @@ TEST(ShapeUtilTest, PermuteDimensionsLayout) {
       SCOPED_TRACE(
           absl::StrCat("permutation=", absl::StrJoin(permutation, ",")));
 
-      // TransposeIsBitcast takes the inverse of the permutation that
-      // PermuteDimensions takes.
       EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(
-          s, ShapeUtil::PermuteDimensions(permutation, s),
-          InversePermutation(permutation)));
+          s, ShapeUtil::PermuteDimensions(permutation, s), permutation));
     } while (std::next_permutation(permutation.begin(), permutation.end()));
   } while (std::next_permutation(layout.begin(), layout.end()));
 }
@@ -756,9 +753,9 @@ TEST(ShapeUtilTest, PermuteDynamicDimensions) {
 
       auto permuted = ShapeUtil::PermuteDimensions(permutation, shape);
       for (int i = 0; i < shape.rank(); i++) {
-        EXPECT_EQ(permuted.dimensions(permutation[i]), shape.dimensions(i));
-        EXPECT_EQ(permuted.is_dynamic_dimension(permutation[i]),
-                  shape.is_dynamic_dimension(i));
+        EXPECT_EQ(permuted.dimensions(i), shape.dimensions(permutation[i]));
+        EXPECT_EQ(permuted.is_dynamic_dimension(i),
+                  shape.is_dynamic_dimension(permutation[i]));
       }
     } while (std::next_permutation(permutation.begin(), permutation.end()));
 }
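For reference, the convention change above means ShapeUtil::PermuteDimensions
now computes output.dimensions[i] = input.dimensions[permutation[i]], where it
previously computed output.dimensions[permutation[i]] = input.dimensions[i].
The standalone C++ sketch below illustrates the two conventions and why most
callers could drop their InversePermutation wrappers. It is a minimal
illustration only: PermuteDims, Inverse, and Dims are hypothetical names local
to this sketch, not XLA APIs (Inverse mirrors what XLA's InversePermutation
helper computes).

    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    using Dims = std::vector<int64_t>;

    // New convention: position i of the result takes input dimension
    // permutation[i], i.e. out[i] = dims[permutation[i]].
    Dims PermuteDims(const Dims& permutation, const Dims& dims) {
      Dims out(dims.size());
      for (size_t i = 0; i < dims.size(); ++i) {
        out[i] = dims[permutation[i]];
      }
      return out;
    }

    // Inverse permutation: inv[permutation[i]] = i, mirroring what XLA's
    // InversePermutation helper computes.
    Dims Inverse(const Dims& permutation) {
      Dims inv(permutation.size());
      for (size_t i = 0; i < permutation.size(); ++i) {
        inv[permutation[i]] = static_cast<int64_t>(i);
      }
      return inv;
    }

    int main() {
      const Dims dims = {2, 3, 5};
      const Dims perm = {2, 0, 1};

      // New convention: [2,3,5] permuted by {2,0,1} gives [5,2,3].
      assert(PermuteDims(perm, dims) == (Dims{5, 2, 3}));

      // The old convention (output[permutation[i]] = input[i]) is what you
      // get by permuting with the inverse, {1,2,0}, which yields [3,5,2].
      // This is why callers such as dot_decomposer.cc no longer need to
      // wrap their argument in InversePermutation.
      assert(PermuteDims(Inverse(perm), dims) == (Dims{3, 5, 2}));
      return 0;
    }

Both asserts pass under any C++11 compiler, matching the updated contract in
shape_util.h and the rewritten expectations in PermuteDynamicDimensions.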