[XLA] Split permutation utilities from xla/util.* into a new xla/permutation_util.*
Refactoring only, NFC intended. PiperOrigin-RevId: 356624906 Change-Id: I46c25640331f9c4b27c6c82d60c9f2edda3b2833
This commit is contained in:
parent
fc108e0865
commit
231d3ab44d
@ -30,6 +30,7 @@ cc_library(
|
||||
hdrs = ["conv_emitter.h"],
|
||||
deps = [
|
||||
":conv_emitter_transforms",
|
||||
"//tensorflow/compiler/xla:permutation_util",
|
||||
"//tensorflow/compiler/xla:window_util",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
|
||||
|
@ -39,6 +39,7 @@ limitations under the License.
|
||||
#include "mlir/Transforms/LoopUtils.h" // from @llvm-project
|
||||
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter_transforms.h"
|
||||
#include "tensorflow/compiler/xla/permutation_util.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
||||
#include "tensorflow/compiler/xla/window_util.h"
|
||||
|
||||
|
@ -251,6 +251,19 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "permutation_util",
|
||||
srcs = ["permutation_util.cc"],
|
||||
hdrs = ["permutation_util.h"],
|
||||
deps = [
|
||||
":types",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "protobuf_util",
|
||||
srcs = ["protobuf_util.cc"],
|
||||
@ -312,6 +325,7 @@ cc_library(
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":permutation_util",
|
||||
":protobuf_util",
|
||||
":status",
|
||||
":status_macros",
|
||||
@ -354,6 +368,7 @@ tf_cc_test(
|
||||
name = "shape_util_test",
|
||||
srcs = ["shape_util_test.cc"],
|
||||
deps = [
|
||||
":permutation_util",
|
||||
":shape_util",
|
||||
":status_macros",
|
||||
":test",
|
||||
@ -431,6 +446,7 @@ cc_library(
|
||||
":array2d",
|
||||
":array3d",
|
||||
":array4d",
|
||||
":permutation_util",
|
||||
":shape_util",
|
||||
":status_macros",
|
||||
":types",
|
||||
|
@ -224,6 +224,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:execution_options_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:permutation_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
|
@ -35,6 +35,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/comparison_util.h"
|
||||
#include "tensorflow/compiler/xla/execution_options_util.h"
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/permutation_util.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
|
||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/index_util.h"
|
||||
#include "tensorflow/compiler/xla/permutation_util.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
|
61
tensorflow/compiler/xla/permutation_util.cc
Normal file
61
tensorflow/compiler/xla/permutation_util.cc
Normal file
@ -0,0 +1,61 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/permutation_util.h"
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/container/inlined_vector.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
bool IsPermutation(absl::Span<const int64> permutation, int64 rank) {
|
||||
if (rank != permutation.size()) {
|
||||
return false;
|
||||
}
|
||||
absl::InlinedVector<int64, 8> trivial_permutation(rank);
|
||||
absl::c_iota(trivial_permutation, 0);
|
||||
return absl::c_is_permutation(permutation, trivial_permutation);
|
||||
}
|
||||
|
||||
std::vector<int64> InversePermutation(
|
||||
absl::Span<const int64> input_permutation) {
|
||||
DCHECK(IsPermutation(input_permutation, input_permutation.size()));
|
||||
std::vector<int64> output_permutation(input_permutation.size(), -1);
|
||||
for (size_t i = 0; i < input_permutation.size(); ++i) {
|
||||
output_permutation.at(input_permutation.at(i)) = i;
|
||||
}
|
||||
return output_permutation;
|
||||
}
|
||||
|
||||
std::vector<int64> ComposePermutations(absl::Span<const int64> p1,
|
||||
absl::Span<const int64> p2) {
|
||||
CHECK_EQ(p1.size(), p2.size());
|
||||
std::vector<int64> output;
|
||||
for (size_t i = 0; i < p1.size(); ++i) {
|
||||
output.push_back(p1.at(p2.at(i)));
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
bool IsIdentityPermutation(absl::Span<const int64> permutation) {
|
||||
for (int64 i = 0; i < permutation.size(); ++i) {
|
||||
if (permutation[i] != i) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace xla
|
64
tensorflow/compiler/xla/permutation_util.h
Normal file
64
tensorflow/compiler/xla/permutation_util.h
Normal file
@ -0,0 +1,64 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Utilities for working with permutations.
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PERMUTATION_UTIL_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PERMUTATION_UTIL_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Checks whether permutation is a permutation of the [0, rank) integer range.
|
||||
bool IsPermutation(absl::Span<const int64> permutation, int64 rank);
|
||||
|
||||
// Applies `permutation` on `input` and returns the permuted array.
|
||||
// For each i, output[permutation[i]] = input[i].
|
||||
//
|
||||
// Precondition:
|
||||
// 1. `permutation` is a permutation of 0..permutation.size()-1.
|
||||
// 2. permutation.size() == input.size().
|
||||
template <typename Container>
|
||||
std::vector<typename Container::value_type> Permute(
|
||||
absl::Span<const int64> permutation, const Container& input) {
|
||||
using T = typename Container::value_type;
|
||||
absl::Span<const T> data(input);
|
||||
CHECK(IsPermutation(permutation, data.size()));
|
||||
std::vector<T> output(data.size());
|
||||
for (size_t i = 0; i < permutation.size(); ++i) {
|
||||
output[permutation[i]] = data[i];
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
// Inverts a permutation, i.e., output_permutation[input_permutation[i]] = i.
|
||||
std::vector<int64> InversePermutation(
|
||||
absl::Span<const int64> input_permutation);
|
||||
|
||||
// Composes two permutations: output[i] = p1[p2[i]].
|
||||
std::vector<int64> ComposePermutations(absl::Span<const int64> p1,
|
||||
absl::Span<const int64> p2);
|
||||
|
||||
// Returns true iff permutation == {0, 1, 2, ...}.
|
||||
bool IsIdentityPermutation(absl::Span<const int64> permutation);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PERMUTATION_UTIL_H_
|
@ -215,6 +215,7 @@ cc_library(
|
||||
hdrs = ["shape_inference.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
"//tensorflow/compiler/xla:permutation_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -381,6 +382,7 @@ tf_cc_test(
|
||||
":hlo_element_type_converter",
|
||||
":hlo_evaluator",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:permutation_util",
|
||||
"//tensorflow/compiler/xla:reference_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
@ -2036,6 +2038,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:comparison_util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:permutation_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
@ -2627,6 +2630,7 @@ cc_library(
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_pass",
|
||||
"//tensorflow/compiler/xla:permutation_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
@ -2744,6 +2748,7 @@ cc_library(
|
||||
deps = [
|
||||
":hlo_pass",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:permutation_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
@ -3533,6 +3538,7 @@ cc_library(
|
||||
":logical_buffer",
|
||||
":tuple_points_to_analysis",
|
||||
":tuple_simplifier",
|
||||
"//tensorflow/compiler/xla:permutation_util",
|
||||
"//tensorflow/compiler/xla:shape_layout",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
@ -3785,6 +3791,7 @@ cc_library(
|
||||
":hlo_pass",
|
||||
":shape_inference",
|
||||
"//tensorflow/compiler/xla:comparison_util",
|
||||
"//tensorflow/compiler/xla:permutation_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
@ -4110,6 +4117,7 @@ tf_cc_test(
|
||||
":pattern_matcher",
|
||||
":pattern_matcher_gmock",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:permutation_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
|
@ -38,6 +38,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/overflow_util.h"
|
||||
#include "tensorflow/compiler/xla/permutation_util.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
|
@ -987,6 +987,7 @@ cc_library(
|
||||
":cpu_runtime",
|
||||
":ir_emission_utils",
|
||||
":target_machine_features",
|
||||
"//tensorflow/compiler/xla:permutation_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/permutation_util.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/compiler/xla/permutation_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
|
@ -275,6 +275,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo",
|
||||
"//tensorflow/compiler/mlir/xla:type_to_shape",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:permutation_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -979,6 +980,7 @@ cc_library(
|
||||
":backend_configs_cc",
|
||||
":ir_emission_utils",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:permutation_util",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:window_util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/permutation_util.h"
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
|
||||
|
@ -58,6 +58,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/permutation_util.h"
|
||||
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
||||
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/permutation_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/permutation_util.h"
|
||||
#include "tensorflow/compiler/xla/reference_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/compiler/xla/comparison_util.h"
|
||||
#include "tensorflow/compiler/xla/permutation_util.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
|
@ -34,6 +34,7 @@ limitations under the License.
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/map_util.h"
|
||||
#include "tensorflow/compiler/xla/permutation_util.h"
|
||||
#include "tensorflow/compiler/xla/service/call_graph.h"
|
||||
#include "tensorflow/compiler/xla/service/computation_layout.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
|
||||
|
@ -94,6 +94,7 @@ cc_library(
|
||||
hdrs = ["ir_array.h"],
|
||||
deps = [
|
||||
":llvm_util",
|
||||
"//tensorflow/compiler/xla:permutation_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "llvm/IR/Instructions.h"
|
||||
#include "llvm/IR/Value.h"
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/permutation_util.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
|
@ -36,6 +36,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/permutation_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/compiler/xla/permutation_util.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
|
@ -35,6 +35,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/index_util.h"
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/overflow_util.h"
|
||||
#include "tensorflow/compiler/xla/permutation_util.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/permutation_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
|
@ -110,44 +110,6 @@ string Reindent(absl::string_view original,
|
||||
});
|
||||
}
|
||||
|
||||
bool IsPermutation(absl::Span<const int64> permutation, int64 rank) {
|
||||
if (rank != permutation.size()) {
|
||||
return false;
|
||||
}
|
||||
absl::InlinedVector<int64, 8> trivial_permutation(rank);
|
||||
absl::c_iota(trivial_permutation, 0);
|
||||
return absl::c_is_permutation(permutation, trivial_permutation);
|
||||
}
|
||||
|
||||
std::vector<int64> InversePermutation(
|
||||
absl::Span<const int64> input_permutation) {
|
||||
DCHECK(IsPermutation(input_permutation, input_permutation.size()));
|
||||
std::vector<int64> output_permutation(input_permutation.size(), -1);
|
||||
for (size_t i = 0; i < input_permutation.size(); ++i) {
|
||||
output_permutation.at(input_permutation.at(i)) = i;
|
||||
}
|
||||
return output_permutation;
|
||||
}
|
||||
|
||||
std::vector<int64> ComposePermutations(absl::Span<const int64> p1,
|
||||
absl::Span<const int64> p2) {
|
||||
CHECK_EQ(p1.size(), p2.size());
|
||||
std::vector<int64> output;
|
||||
for (size_t i = 0; i < p1.size(); ++i) {
|
||||
output.push_back(p1.at(p2.at(i)));
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
bool IsIdentityPermutation(absl::Span<const int64> permutation) {
|
||||
for (int64 i = 0; i < permutation.size(); ++i) {
|
||||
if (permutation[i] != i) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
string RoundTripFpToString(tensorflow::bfloat16 value) {
|
||||
return absl::StrFormat("%.4g", static_cast<float>(value));
|
||||
}
|
||||
|
@ -328,39 +328,6 @@ Status ResourceExhaustedStrCat(Args&&... concat) {
|
||||
// uniformly replaced with "indentation".
|
||||
string Reindent(absl::string_view original, absl::string_view indentation);
|
||||
|
||||
// Checks whether permutation is a permutation of the [0, rank) integer range.
|
||||
bool IsPermutation(absl::Span<const int64> permutation, int64 rank);
|
||||
|
||||
// Applies `permutation` on `input` and returns the permuted array.
|
||||
// For each i, output[permutation[i]] = input[i].
|
||||
//
|
||||
// Precondition:
|
||||
// 1. `permutation` is a permutation of 0..permutation.size()-1.
|
||||
// 2. permutation.size() == input.size().
|
||||
template <typename Container>
|
||||
std::vector<typename Container::value_type> Permute(
|
||||
absl::Span<const int64> permutation, const Container& input) {
|
||||
using T = typename Container::value_type;
|
||||
absl::Span<const T> data(input);
|
||||
CHECK(IsPermutation(permutation, data.size()));
|
||||
std::vector<T> output(data.size());
|
||||
for (size_t i = 0; i < permutation.size(); ++i) {
|
||||
output[permutation[i]] = data[i];
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
// Inverts a permutation, i.e., output_permutation[input_permutation[i]] = i.
|
||||
std::vector<int64> InversePermutation(
|
||||
absl::Span<const int64> input_permutation);
|
||||
|
||||
// Composes two permutations: output[i] = p1[p2[i]].
|
||||
std::vector<int64> ComposePermutations(absl::Span<const int64> p1,
|
||||
absl::Span<const int64> p2);
|
||||
|
||||
// Returns true iff permutation == {0, 1, 2, ...}.
|
||||
bool IsIdentityPermutation(absl::Span<const int64> permutation);
|
||||
|
||||
template <typename Container>
|
||||
int64 PositionInContainer(const Container& container, int64 value) {
|
||||
return std::distance(container.begin(), absl::c_find(container, value));
|
||||
|
Loading…
Reference in New Issue
Block a user