[XLA] Split permutation utilities from xla/util.* into a new xla/permutation_util.*

Refactoring only, NFC intended.

PiperOrigin-RevId: 356624906
Change-Id: I46c25640331f9c4b27c6c82d60c9f2edda3b2833
This commit is contained in:
Peter Hawkins 2021-02-09 16:57:08 -08:00 committed by TensorFlower Gardener
parent fc108e0865
commit 231d3ab44d
28 changed files with 172 additions and 71 deletions

View File

@ -30,6 +30,7 @@ cc_library(
hdrs = ["conv_emitter.h"],
deps = [
":conv_emitter_transforms",
"//tensorflow/compiler/xla:permutation_util",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",

View File

@ -39,6 +39,7 @@ limitations under the License.
#include "mlir/Transforms/LoopUtils.h" // from @llvm-project
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter_transforms.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/window_util.h"

View File

@ -251,6 +251,19 @@ cc_library(
],
)
cc_library(
name = "permutation_util",
srcs = ["permutation_util.cc"],
hdrs = ["permutation_util.h"],
deps = [
":types",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "protobuf_util",
srcs = ["protobuf_util.cc"],
@ -312,6 +325,7 @@ cc_library(
],
visibility = ["//visibility:public"],
deps = [
":permutation_util",
":protobuf_util",
":status",
":status_macros",
@ -354,6 +368,7 @@ tf_cc_test(
name = "shape_util_test",
srcs = ["shape_util_test.cc"],
deps = [
":permutation_util",
":shape_util",
":status_macros",
":test",
@ -431,6 +446,7 @@ cc_library(
":array2d",
":array3d",
":array4d",
":permutation_util",
":shape_util",
":status_macros",
":types",

View File

@ -224,6 +224,7 @@ cc_library(
"//tensorflow/compiler/xla:execution_options_util",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:permutation_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"

View File

@ -0,0 +1,61 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/permutation_util.h"
#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
namespace xla {
bool IsPermutation(absl::Span<const int64> permutation, int64 rank) {
if (rank != permutation.size()) {
return false;
}
absl::InlinedVector<int64, 8> trivial_permutation(rank);
absl::c_iota(trivial_permutation, 0);
return absl::c_is_permutation(permutation, trivial_permutation);
}
std::vector<int64> InversePermutation(
absl::Span<const int64> input_permutation) {
DCHECK(IsPermutation(input_permutation, input_permutation.size()));
std::vector<int64> output_permutation(input_permutation.size(), -1);
for (size_t i = 0; i < input_permutation.size(); ++i) {
output_permutation.at(input_permutation.at(i)) = i;
}
return output_permutation;
}
std::vector<int64> ComposePermutations(absl::Span<const int64> p1,
absl::Span<const int64> p2) {
CHECK_EQ(p1.size(), p2.size());
std::vector<int64> output;
for (size_t i = 0; i < p1.size(); ++i) {
output.push_back(p1.at(p2.at(i)));
}
return output;
}
bool IsIdentityPermutation(absl::Span<const int64> permutation) {
for (int64 i = 0; i < permutation.size(); ++i) {
if (permutation[i] != i) {
return false;
}
}
return true;
}
} // namespace xla

View File

@ -0,0 +1,64 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Utilities for working with permutations.
#ifndef TENSORFLOW_COMPILER_XLA_PERMUTATION_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_PERMUTATION_UTIL_H_
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
// Checks whether permutation is a permutation of the [0, rank) integer range.
bool IsPermutation(absl::Span<const int64> permutation, int64 rank);
// Applies `permutation` on `input` and returns the permuted array.
// For each i, output[permutation[i]] = input[i].
//
// Precondition:
// 1. `permutation` is a permutation of 0..permutation.size()-1.
// 2. permutation.size() == input.size().
template <typename Container>
std::vector<typename Container::value_type> Permute(
absl::Span<const int64> permutation, const Container& input) {
using T = typename Container::value_type;
absl::Span<const T> data(input);
CHECK(IsPermutation(permutation, data.size()));
std::vector<T> output(data.size());
for (size_t i = 0; i < permutation.size(); ++i) {
output[permutation[i]] = data[i];
}
return output;
}
// Inverts a permutation, i.e., output_permutation[input_permutation[i]] = i.
std::vector<int64> InversePermutation(
absl::Span<const int64> input_permutation);
// Composes two permutations: output[i] = p1[p2[i]].
std::vector<int64> ComposePermutations(absl::Span<const int64> p1,
absl::Span<const int64> p2);
// Returns true iff permutation == {0, 1, 2, ...}.
bool IsIdentityPermutation(absl::Span<const int64> permutation);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PERMUTATION_UTIL_H_

View File

@ -215,6 +215,7 @@ cc_library(
hdrs = ["shape_inference.h"],
deps = [
":hlo",
"//tensorflow/compiler/xla:permutation_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@ -381,6 +382,7 @@ tf_cc_test(
":hlo_element_type_converter",
":hlo_evaluator",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:permutation_util",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
@ -2036,6 +2038,7 @@ cc_library(
"//tensorflow/compiler/xla:comparison_util",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:permutation_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:types",
@ -2627,6 +2630,7 @@ cc_library(
deps = [
":hlo",
":hlo_pass",
"//tensorflow/compiler/xla:permutation_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:types",
@ -2744,6 +2748,7 @@ cc_library(
deps = [
":hlo_pass",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:permutation_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
@ -3533,6 +3538,7 @@ cc_library(
":logical_buffer",
":tuple_points_to_analysis",
":tuple_simplifier",
"//tensorflow/compiler/xla:permutation_util",
"//tensorflow/compiler/xla:shape_layout",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@ -3785,6 +3791,7 @@ cc_library(
":hlo_pass",
":shape_inference",
"//tensorflow/compiler/xla:comparison_util",
"//tensorflow/compiler/xla:permutation_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
@ -4110,6 +4117,7 @@ tf_cc_test(
":pattern_matcher",
":pattern_matcher_gmock",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:permutation_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",

View File

@ -38,6 +38,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/overflow_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"

View File

@ -987,6 +987,7 @@ cc_library(
":cpu_runtime",
":ir_emission_utils",
":target_machine_features",
"//tensorflow/compiler/xla:permutation_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"

View File

@ -275,6 +275,7 @@ cc_library(
"//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo",
"//tensorflow/compiler/mlir/xla:type_to_shape",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:permutation_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@ -979,6 +980,7 @@ cc_library(
":backend_configs_cc",
":ir_emission_utils",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:permutation_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto_cc",

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"

View File

@ -58,6 +58,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"

View File

@ -34,6 +34,7 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"

View File

@ -94,6 +94,7 @@ cc_library(
hdrs = ["ir_array.h"],
deps = [
":llvm_util",
"//tensorflow/compiler/xla:permutation_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"

View File

@ -36,6 +36,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/overflow_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"

View File

@ -110,44 +110,6 @@ string Reindent(absl::string_view original,
});
}
bool IsPermutation(absl::Span<const int64> permutation, int64 rank) {
if (rank != permutation.size()) {
return false;
}
absl::InlinedVector<int64, 8> trivial_permutation(rank);
absl::c_iota(trivial_permutation, 0);
return absl::c_is_permutation(permutation, trivial_permutation);
}
std::vector<int64> InversePermutation(
absl::Span<const int64> input_permutation) {
DCHECK(IsPermutation(input_permutation, input_permutation.size()));
std::vector<int64> output_permutation(input_permutation.size(), -1);
for (size_t i = 0; i < input_permutation.size(); ++i) {
output_permutation.at(input_permutation.at(i)) = i;
}
return output_permutation;
}
std::vector<int64> ComposePermutations(absl::Span<const int64> p1,
absl::Span<const int64> p2) {
CHECK_EQ(p1.size(), p2.size());
std::vector<int64> output;
for (size_t i = 0; i < p1.size(); ++i) {
output.push_back(p1.at(p2.at(i)));
}
return output;
}
bool IsIdentityPermutation(absl::Span<const int64> permutation) {
for (int64 i = 0; i < permutation.size(); ++i) {
if (permutation[i] != i) {
return false;
}
}
return true;
}
string RoundTripFpToString(tensorflow::bfloat16 value) {
return absl::StrFormat("%.4g", static_cast<float>(value));
}

View File

@ -328,39 +328,6 @@ Status ResourceExhaustedStrCat(Args&&... concat) {
// uniformly replaced with "indentation".
string Reindent(absl::string_view original, absl::string_view indentation);
// Checks whether permutation is a permutation of the [0, rank) integer range.
bool IsPermutation(absl::Span<const int64> permutation, int64 rank);
// Applies `permutation` on `input` and returns the permuted array.
// For each i, output[permutation[i]] = input[i].
//
// Precondition:
// 1. `permutation` is a permutation of 0..permutation.size()-1.
// 2. permutation.size() == input.size().
template <typename Container>
std::vector<typename Container::value_type> Permute(
absl::Span<const int64> permutation, const Container& input) {
using T = typename Container::value_type;
absl::Span<const T> data(input);
CHECK(IsPermutation(permutation, data.size()));
std::vector<T> output(data.size());
for (size_t i = 0; i < permutation.size(); ++i) {
output[permutation[i]] = data[i];
}
return output;
}
// Inverts a permutation, i.e., output_permutation[input_permutation[i]] = i.
std::vector<int64> InversePermutation(
absl::Span<const int64> input_permutation);
// Composes two permutations: output[i] = p1[p2[i]].
std::vector<int64> ComposePermutations(absl::Span<const int64> p1,
absl::Span<const int64> p2);
// Returns true iff permutation == {0, 1, 2, ...}.
bool IsIdentityPermutation(absl::Span<const int64> permutation);
template <typename Container>
int64 PositionInContainer(const Container& container, int64 value) {
return std::distance(container.begin(), absl::c_find(container, value));