Another try
This commit is contained in:
parent
49292d15d2
commit
2df6cd3acd
@ -195,16 +195,23 @@ bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank);
|
||||
// 2. permutation.size() == input.size().
|
||||
template <template <typename...> class C, typename T>
|
||||
std::vector<T> Permute(tensorflow::gtl::ArraySlice<int64> permutation,
|
||||
C<T> input_) {
|
||||
tensorflow::gtl::ArraySlice<T> input(input_);
|
||||
CHECK(IsPermutation(permutation, input.size()));
|
||||
std::vector<T> output(input.size());
|
||||
C<T> input) {
|
||||
tensorflow::gtl::ArraySlice<T> data(input);
|
||||
CHECK(IsPermutation(permutation, data.size()));
|
||||
std::vector<T> output(data.size());
|
||||
for (size_t i = 0; i < permutation.size(); ++i) {
|
||||
output[permutation[i]] = input[i];
|
||||
output[permutation[i]] = data[i];
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
// Override of the above that works around compile failures with vectors.
|
||||
template <typename T>
|
||||
std::vector<T> Permute(tensorflow::gtl::ArraySlice<int64> permutation,
|
||||
const std::vector<T>& input) {
|
||||
return Permute<std::vector, T>(permutation, input);
|
||||
}
|
||||
|
||||
// Inverts a permutation, i.e., output_permutation[input_permutation[i]] = i.
|
||||
std::vector<int64> InversePermutation(
|
||||
tensorflow::gtl::ArraySlice<int64> input_permutation);
|
||||
|
Loading…
Reference in New Issue
Block a user