Another try
This commit is contained in:
parent
49292d15d2
commit
2df6cd3acd
@ -195,16 +195,23 @@ bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank);
|
|||||||
// 2. permutation.size() == input.size().
|
// 2. permutation.size() == input.size().
|
||||||
template <template <typename...> class C, typename T>
|
template <template <typename...> class C, typename T>
|
||||||
std::vector<T> Permute(tensorflow::gtl::ArraySlice<int64> permutation,
|
std::vector<T> Permute(tensorflow::gtl::ArraySlice<int64> permutation,
|
||||||
C<T> input_) {
|
C<T> input) {
|
||||||
tensorflow::gtl::ArraySlice<T> input(input_);
|
tensorflow::gtl::ArraySlice<T> data(input);
|
||||||
CHECK(IsPermutation(permutation, input.size()));
|
CHECK(IsPermutation(permutation, data.size()));
|
||||||
std::vector<T> output(input.size());
|
std::vector<T> output(data.size());
|
||||||
for (size_t i = 0; i < permutation.size(); ++i) {
|
for (size_t i = 0; i < permutation.size(); ++i) {
|
||||||
output[permutation[i]] = input[i];
|
output[permutation[i]] = data[i];
|
||||||
}
|
}
|
||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Override of the above that works around compile failures with vectors.
|
||||||
|
template <typename T>
|
||||||
|
std::vector<T> Permute(tensorflow::gtl::ArraySlice<int64> permutation,
|
||||||
|
const std::vector<T>& input) {
|
||||||
|
return Permute<std::vector, T>(permutation, input);
|
||||||
|
}
|
||||||
|
|
||||||
// Inverts a permutation, i.e., output_permutation[input_permutation[i]] = i.
|
// Inverts a permutation, i.e., output_permutation[input_permutation[i]] = i.
|
||||||
std::vector<int64> InversePermutation(
|
std::vector<int64> InversePermutation(
|
||||||
tensorflow::gtl::ArraySlice<int64> input_permutation);
|
tensorflow::gtl::ArraySlice<int64> input_permutation);
|
||||||
|
Loading…
Reference in New Issue
Block a user