From 9fcf66e16dfbd32ebf739bac3e2dd1db59c67360 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 9 Feb 2021 19:06:42 -0800 Subject: [PATCH] [XLA] Change xla::IsPermutation() not to check the size, and make its implementation O(n). xla::IsPermutation() does two things: it checks a permutation is actually a permutation, and checks it has a particular size. There's no need to do the latter in xla::IsPermutation because it can be done just as easily by checking .size(), and a number of callers don't need the size check. In addition, change the implementation to be O(n) time; there's no need to call std::is_permutation() which guarantees nothing better than O(n^2) behavior. PiperOrigin-RevId: 356643750 Change-Id: Ib177287556b5c22266185fa813798a4ad5392054 --- tensorflow/compiler/xla/BUILD | 34 +++++++++++------- tensorflow/compiler/xla/literal.cc | 2 +- tensorflow/compiler/xla/permutation_util.cc | 17 +++++---- tensorflow/compiler/xla/permutation_util.h | 8 +++-- .../compiler/xla/permutation_util_test.cc | 35 +++++++++++++++++++ .../xla/service/algebraic_simplifier.cc | 2 +- .../compiler/xla/service/shape_inference.cc | 2 +- 7 files changed, 75 insertions(+), 25 deletions(-) create mode 100644 tensorflow/compiler/xla/permutation_util_test.cc diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index a66f30a47c0..4b3b523b18d 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -251,6 +251,18 @@ cc_library( ], ) +tf_cc_test( + name = "util_test", + srcs = ["util_test.cc"], + deps = [ + ":test", + ":types", + ":util", + "//tensorflow/core:test_main", + "//tensorflow/core/platform:bfloat16", + ], +) + cc_library( name = "permutation_util", srcs = ["permutation_util.cc"], @@ -264,6 +276,16 @@ cc_library( ], ) +tf_cc_test( + name = "permutation_util_test", + srcs = ["permutation_util_test.cc"], + deps = [ + ":permutation_util", + ":test", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "protobuf_util", srcs = ["protobuf_util.cc"], @@ -282,18 +304,6 @@ cc_library( ], ) -tf_cc_test( - name = "util_test", - srcs = ["util_test.cc"], - deps = [ - ":test", - ":types", - ":util", - "//tensorflow/core:test_main", - "//tensorflow/core/platform:bfloat16", - ], -) - tf_cc_test( name = "iterator_util_test", srcs = ["iterator_util_test.cc"], diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 34e17f9917a..c9cf07652d3 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -831,7 +831,7 @@ StatusOr LiteralBase::Reshape( Literal LiteralBase::Transpose(absl::Span permutation) const { CHECK(shape().IsArray()) << "Tuple is not supported for transpose"; - CHECK(IsPermutation(permutation, shape().rank())) + CHECK(shape().rank() == permutation.size() && IsPermutation(permutation)) << "Given permutation is not a permutation of dimension numbers"; // To transpose the array, we just permute the dimensions and layout, and // do a straight memory copy of the raw data set. diff --git a/tensorflow/compiler/xla/permutation_util.cc b/tensorflow/compiler/xla/permutation_util.cc index bbba5dc9bde..a5db6019a5b 100644 --- a/tensorflow/compiler/xla/permutation_util.cc +++ b/tensorflow/compiler/xla/permutation_util.cc @@ -20,18 +20,20 @@ limitations under the License. namespace xla { -bool IsPermutation(absl::Span permutation, int64 rank) { - if (rank != permutation.size()) { - return false; +bool IsPermutation(absl::Span permutation) { + absl::InlinedVector seen(permutation.size(), false); + for (int64 p : permutation) { + if (p < 0 || p >= permutation.size() || seen[p]) { + return false; + } + seen[p] = true; } - absl::InlinedVector trivial_permutation(rank); - absl::c_iota(trivial_permutation, 0); - return absl::c_is_permutation(permutation, trivial_permutation); + return true; } std::vector InversePermutation( absl::Span input_permutation) { - DCHECK(IsPermutation(input_permutation, input_permutation.size())); + DCHECK(IsPermutation(input_permutation)); std::vector output_permutation(input_permutation.size(), -1); for (size_t i = 0; i < input_permutation.size(); ++i) { output_permutation.at(input_permutation.at(i)) = i; @@ -43,6 +45,7 @@ std::vector ComposePermutations(absl::Span p1, absl::Span p2) { CHECK_EQ(p1.size(), p2.size()); std::vector output; + output.reserve(p1.size()); for (size_t i = 0; i < p1.size(); ++i) { output.push_back(p1.at(p2.at(i))); } diff --git a/tensorflow/compiler/xla/permutation_util.h b/tensorflow/compiler/xla/permutation_util.h index 13ec3c0cd12..ee9e2ba4c62 100644 --- a/tensorflow/compiler/xla/permutation_util.h +++ b/tensorflow/compiler/xla/permutation_util.h @@ -26,8 +26,9 @@ limitations under the License. namespace xla { -// Checks whether permutation is a permutation of the [0, rank) integer range. -bool IsPermutation(absl::Span permutation, int64 rank); +// Returns true if permutation is a permutation of the integers +// [0, permutation.size()). +bool IsPermutation(absl::Span permutation); // Applies `permutation` on `input` and returns the permuted array. // For each i, output[permutation[i]] = input[i]. @@ -40,7 +41,8 @@ std::vector Permute( absl::Span permutation, const Container& input) { using T = typename Container::value_type; absl::Span data(input); - CHECK(IsPermutation(permutation, data.size())); + CHECK_EQ(permutation.size(), data.size()); + CHECK(IsPermutation(permutation)); std::vector output(data.size()); for (size_t i = 0; i < permutation.size(); ++i) { output[permutation[i]] = data[i]; diff --git a/tensorflow/compiler/xla/permutation_util_test.cc b/tensorflow/compiler/xla/permutation_util_test.cc new file mode 100644 index 00000000000..b489d288903 --- /dev/null +++ b/tensorflow/compiler/xla/permutation_util_test.cc @@ -0,0 +1,35 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/permutation_util.h" + +#include "tensorflow/compiler/xla/test.h" + +namespace xla { +namespace { + +TEST(PermutationUtilTest, IsPermutation) { + EXPECT_TRUE(IsPermutation({})); + EXPECT_TRUE(IsPermutation({0})); + EXPECT_FALSE(IsPermutation({-3})); + EXPECT_TRUE(IsPermutation({0, 1})); + EXPECT_FALSE(IsPermutation({1, 1})); + EXPECT_TRUE(IsPermutation({1, 0})); + EXPECT_TRUE(IsPermutation({3, 1, 0, 2})); + EXPECT_FALSE(IsPermutation({3, 0, 2})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index ba7bf93bee2..6c845730237 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -2189,7 +2189,7 @@ AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims( for (auto dim : lhs_contracting_dims) { permutation.push_back(transpose_dims[dim] - lhs_contracting_dims[0]); } - CHECK(IsPermutation(permutation, permutation.size())); + CHECK(IsPermutation(permutation)); auto new_lhs_contracting_dims = ComposePermutations(AsInt64Slice(lhs_contracting_dims), permutation); lhs_contracting_dims.Clear(); diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index df0c4ac5521..8f978a1dd30 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -3087,7 +3087,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Shape& operand, absl::Span dimensions) { TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose")); - if (!IsPermutation(dimensions, operand.rank())) { + if (dimensions.size() != operand.rank() || !IsPermutation(dimensions)) { return InvalidArgument( "Transpose dimensions [%s] are not a permutation of the operand " "dimensions (operand shape is %s).",