Another try

This commit is contained in:
Todd Wang 2017-06-21 17:51:55 -07:00 committed by GitHub
parent 49292d15d2
commit 2df6cd3acd

View File

@ -195,16 +195,23 @@ bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank);
// 2. permutation.size() == input.size().
template <template <typename...> class C, typename T>
std::vector<T> Permute(tensorflow::gtl::ArraySlice<int64> permutation,
C<T> input_) {
tensorflow::gtl::ArraySlice<T> input(input_);
CHECK(IsPermutation(permutation, input.size()));
std::vector<T> output(input.size());
C<T> input) {
tensorflow::gtl::ArraySlice<T> data(input);
CHECK(IsPermutation(permutation, data.size()));
std::vector<T> output(data.size());
for (size_t i = 0; i < permutation.size(); ++i) {
output[permutation[i]] = input[i];
output[permutation[i]] = data[i];
}
return output;
}
// Override of the above that works around compile failures with vectors.
template <typename T>
std::vector<T> Permute(tensorflow::gtl::ArraySlice<int64> permutation,
const std::vector<T>& input) {
return Permute<std::vector, T>(permutation, input);
}
// Inverts a permutation, i.e., output_permutation[input_permutation[i]] = i.
std::vector<int64> InversePermutation(
tensorflow::gtl::ArraySlice<int64> input_permutation);