[XLA] Change xla::IsPermutation() not to check the size, and make its implementation O(n).

xla::IsPermutation() does two things: it checks a permutation is actually a permutation, and checks it has a particular size. There's no need to do the latter in xla::IsPermutation because it can be done just as easily by checking .size(), and a number of callers don't need the size check.

In addition, change the implementation to be O(n) time; there's no need to call std::is_permutation() which guarantees nothing better than O(n^2) behavior.

PiperOrigin-RevId: 356643750
Change-Id: Ib177287556b5c22266185fa813798a4ad5392054
This commit is contained in:
Peter Hawkins 2021-02-09 19:06:42 -08:00 committed by TensorFlower Gardener
parent 02e6913295
commit 9fcf66e16d
7 changed files with 75 additions and 25 deletions

View File

@ -251,6 +251,18 @@ cc_library(
],
)
tf_cc_test(
name = "util_test",
srcs = ["util_test.cc"],
deps = [
":test",
":types",
":util",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:bfloat16",
],
)
cc_library(
name = "permutation_util",
srcs = ["permutation_util.cc"],
@ -264,6 +276,16 @@ cc_library(
],
)
tf_cc_test(
name = "permutation_util_test",
srcs = ["permutation_util_test.cc"],
deps = [
":permutation_util",
":test",
"//tensorflow/core:test_main",
],
)
cc_library(
name = "protobuf_util",
srcs = ["protobuf_util.cc"],
@ -282,18 +304,6 @@ cc_library(
],
)
tf_cc_test(
name = "util_test",
srcs = ["util_test.cc"],
deps = [
":test",
":types",
":util",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:bfloat16",
],
)
tf_cc_test(
name = "iterator_util_test",
srcs = ["iterator_util_test.cc"],

View File

@ -831,7 +831,7 @@ StatusOr<Literal> LiteralBase::Reshape(
Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
CHECK(shape().IsArray()) << "Tuple is not supported for transpose";
CHECK(IsPermutation(permutation, shape().rank()))
CHECK(shape().rank() == permutation.size() && IsPermutation(permutation))
<< "Given permutation is not a permutation of dimension numbers";
// To transpose the array, we just permute the dimensions and layout, and
// do a straight memory copy of the raw data set.

View File

@ -20,18 +20,20 @@ limitations under the License.
namespace xla {
bool IsPermutation(absl::Span<const int64> permutation, int64 rank) {
if (rank != permutation.size()) {
return false;
bool IsPermutation(absl::Span<const int64> permutation) {
absl::InlinedVector<bool, 8> seen(permutation.size(), false);
for (int64 p : permutation) {
if (p < 0 || p >= permutation.size() || seen[p]) {
return false;
}
seen[p] = true;
}
absl::InlinedVector<int64, 8> trivial_permutation(rank);
absl::c_iota(trivial_permutation, 0);
return absl::c_is_permutation(permutation, trivial_permutation);
return true;
}
std::vector<int64> InversePermutation(
absl::Span<const int64> input_permutation) {
DCHECK(IsPermutation(input_permutation, input_permutation.size()));
DCHECK(IsPermutation(input_permutation));
std::vector<int64> output_permutation(input_permutation.size(), -1);
for (size_t i = 0; i < input_permutation.size(); ++i) {
output_permutation.at(input_permutation.at(i)) = i;
@ -43,6 +45,7 @@ std::vector<int64> ComposePermutations(absl::Span<const int64> p1,
absl::Span<const int64> p2) {
CHECK_EQ(p1.size(), p2.size());
std::vector<int64> output;
output.reserve(p1.size());
for (size_t i = 0; i < p1.size(); ++i) {
output.push_back(p1.at(p2.at(i)));
}

View File

@ -26,8 +26,9 @@ limitations under the License.
namespace xla {
// Checks whether permutation is a permutation of the [0, rank) integer range.
bool IsPermutation(absl::Span<const int64> permutation, int64 rank);
// Returns true if permutation is a permutation of the integers
// [0, permutation.size()).
bool IsPermutation(absl::Span<const int64> permutation);
// Applies `permutation` on `input` and returns the permuted array.
// For each i, output[permutation[i]] = input[i].
@ -40,7 +41,8 @@ std::vector<typename Container::value_type> Permute(
absl::Span<const int64> permutation, const Container& input) {
using T = typename Container::value_type;
absl::Span<const T> data(input);
CHECK(IsPermutation(permutation, data.size()));
CHECK_EQ(permutation.size(), data.size());
CHECK(IsPermutation(permutation));
std::vector<T> output(data.size());
for (size_t i = 0; i < permutation.size(); ++i) {
output[permutation[i]] = data[i];

View File

@ -0,0 +1,35 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/test.h"
namespace xla {
namespace {
TEST(PermutationUtilTest, IsPermutation) {
EXPECT_TRUE(IsPermutation({}));
EXPECT_TRUE(IsPermutation({0}));
EXPECT_FALSE(IsPermutation({-3}));
EXPECT_TRUE(IsPermutation({0, 1}));
EXPECT_FALSE(IsPermutation({1, 1}));
EXPECT_TRUE(IsPermutation({1, 0}));
EXPECT_TRUE(IsPermutation({3, 1, 0, 2}));
EXPECT_FALSE(IsPermutation({3, 0, 2}));
}
} // namespace
} // namespace xla

View File

@ -2189,7 +2189,7 @@ AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims(
for (auto dim : lhs_contracting_dims) {
permutation.push_back(transpose_dims[dim] - lhs_contracting_dims[0]);
}
CHECK(IsPermutation(permutation, permutation.size()));
CHECK(IsPermutation(permutation));
auto new_lhs_contracting_dims =
ComposePermutations(AsInt64Slice(lhs_contracting_dims), permutation);
lhs_contracting_dims.Clear();

View File

@ -3087,7 +3087,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const Shape& operand, absl::Span<const int64> dimensions) {
TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose"));
if (!IsPermutation(dimensions, operand.rank())) {
if (dimensions.size() != operand.rank() || !IsPermutation(dimensions)) {
return InvalidArgument(
"Transpose dimensions [%s] are not a permutation of the operand "
"dimensions (operand shape is %s).",