[XLA] Change xla::IsPermutation() not to check the size, and make its implementation O(n).
xla::IsPermutation() does two things: it checks a permutation is actually a permutation, and checks it has a particular size. There's no need to do the latter in xla::IsPermutation because it can be done just as easily by checking .size(), and a number of callers don't need the size check. In addition, change the implementation to be O(n) time; there's no need to call std::is_permutation() which guarantees nothing better than O(n^2) behavior. PiperOrigin-RevId: 356643750 Change-Id: Ib177287556b5c22266185fa813798a4ad5392054
This commit is contained in:
parent
02e6913295
commit
9fcf66e16d
@ -251,6 +251,18 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "util_test",
|
||||
srcs = ["util_test.cc"],
|
||||
deps = [
|
||||
":test",
|
||||
":types",
|
||||
":util",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/platform:bfloat16",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "permutation_util",
|
||||
srcs = ["permutation_util.cc"],
|
||||
@ -264,6 +276,16 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "permutation_util_test",
|
||||
srcs = ["permutation_util_test.cc"],
|
||||
deps = [
|
||||
":permutation_util",
|
||||
":test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "protobuf_util",
|
||||
srcs = ["protobuf_util.cc"],
|
||||
@ -282,18 +304,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "util_test",
|
||||
srcs = ["util_test.cc"],
|
||||
deps = [
|
||||
":test",
|
||||
":types",
|
||||
":util",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/platform:bfloat16",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "iterator_util_test",
|
||||
srcs = ["iterator_util_test.cc"],
|
||||
|
@ -831,7 +831,7 @@ StatusOr<Literal> LiteralBase::Reshape(
|
||||
|
||||
Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
|
||||
CHECK(shape().IsArray()) << "Tuple is not supported for transpose";
|
||||
CHECK(IsPermutation(permutation, shape().rank()))
|
||||
CHECK(shape().rank() == permutation.size() && IsPermutation(permutation))
|
||||
<< "Given permutation is not a permutation of dimension numbers";
|
||||
// To transpose the array, we just permute the dimensions and layout, and
|
||||
// do a straight memory copy of the raw data set.
|
||||
|
@ -20,18 +20,20 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
bool IsPermutation(absl::Span<const int64> permutation, int64 rank) {
|
||||
if (rank != permutation.size()) {
|
||||
return false;
|
||||
bool IsPermutation(absl::Span<const int64> permutation) {
|
||||
absl::InlinedVector<bool, 8> seen(permutation.size(), false);
|
||||
for (int64 p : permutation) {
|
||||
if (p < 0 || p >= permutation.size() || seen[p]) {
|
||||
return false;
|
||||
}
|
||||
seen[p] = true;
|
||||
}
|
||||
absl::InlinedVector<int64, 8> trivial_permutation(rank);
|
||||
absl::c_iota(trivial_permutation, 0);
|
||||
return absl::c_is_permutation(permutation, trivial_permutation);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<int64> InversePermutation(
|
||||
absl::Span<const int64> input_permutation) {
|
||||
DCHECK(IsPermutation(input_permutation, input_permutation.size()));
|
||||
DCHECK(IsPermutation(input_permutation));
|
||||
std::vector<int64> output_permutation(input_permutation.size(), -1);
|
||||
for (size_t i = 0; i < input_permutation.size(); ++i) {
|
||||
output_permutation.at(input_permutation.at(i)) = i;
|
||||
@ -43,6 +45,7 @@ std::vector<int64> ComposePermutations(absl::Span<const int64> p1,
|
||||
absl::Span<const int64> p2) {
|
||||
CHECK_EQ(p1.size(), p2.size());
|
||||
std::vector<int64> output;
|
||||
output.reserve(p1.size());
|
||||
for (size_t i = 0; i < p1.size(); ++i) {
|
||||
output.push_back(p1.at(p2.at(i)));
|
||||
}
|
||||
|
@ -26,8 +26,9 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Checks whether permutation is a permutation of the [0, rank) integer range.
|
||||
bool IsPermutation(absl::Span<const int64> permutation, int64 rank);
|
||||
// Returns true if permutation is a permutation of the integers
|
||||
// [0, permutation.size()).
|
||||
bool IsPermutation(absl::Span<const int64> permutation);
|
||||
|
||||
// Applies `permutation` on `input` and returns the permuted array.
|
||||
// For each i, output[permutation[i]] = input[i].
|
||||
@ -40,7 +41,8 @@ std::vector<typename Container::value_type> Permute(
|
||||
absl::Span<const int64> permutation, const Container& input) {
|
||||
using T = typename Container::value_type;
|
||||
absl::Span<const T> data(input);
|
||||
CHECK(IsPermutation(permutation, data.size()));
|
||||
CHECK_EQ(permutation.size(), data.size());
|
||||
CHECK(IsPermutation(permutation));
|
||||
std::vector<T> output(data.size());
|
||||
for (size_t i = 0; i < permutation.size(); ++i) {
|
||||
output[permutation[i]] = data[i];
|
||||
|
35
tensorflow/compiler/xla/permutation_util_test.cc
Normal file
35
tensorflow/compiler/xla/permutation_util_test.cc
Normal file
@ -0,0 +1,35 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/permutation_util.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
TEST(PermutationUtilTest, IsPermutation) {
|
||||
EXPECT_TRUE(IsPermutation({}));
|
||||
EXPECT_TRUE(IsPermutation({0}));
|
||||
EXPECT_FALSE(IsPermutation({-3}));
|
||||
EXPECT_TRUE(IsPermutation({0, 1}));
|
||||
EXPECT_FALSE(IsPermutation({1, 1}));
|
||||
EXPECT_TRUE(IsPermutation({1, 0}));
|
||||
EXPECT_TRUE(IsPermutation({3, 1, 0, 2}));
|
||||
EXPECT_FALSE(IsPermutation({3, 0, 2}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
@ -2189,7 +2189,7 @@ AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims(
|
||||
for (auto dim : lhs_contracting_dims) {
|
||||
permutation.push_back(transpose_dims[dim] - lhs_contracting_dims[0]);
|
||||
}
|
||||
CHECK(IsPermutation(permutation, permutation.size()));
|
||||
CHECK(IsPermutation(permutation));
|
||||
auto new_lhs_contracting_dims =
|
||||
ComposePermutations(AsInt64Slice(lhs_contracting_dims), permutation);
|
||||
lhs_contracting_dims.Clear();
|
||||
|
@ -3087,7 +3087,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
|
||||
const Shape& operand, absl::Span<const int64> dimensions) {
|
||||
TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose"));
|
||||
|
||||
if (!IsPermutation(dimensions, operand.rank())) {
|
||||
if (dimensions.size() != operand.rank() || !IsPermutation(dimensions)) {
|
||||
return InvalidArgument(
|
||||
"Transpose dimensions [%s] are not a permutation of the operand "
|
||||
"dimensions (operand shape is %s).",
|
||||
|
Loading…
Reference in New Issue
Block a user