diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 42d5c1d1550..32aa81d1f42 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -195,16 +195,23 @@ bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank); // 2. permutation.size() == input.size(). template <template <typename...> class C, typename T> std::vector<T> Permute(tensorflow::gtl::ArraySlice<int64> permutation, - C<T> input_) { - tensorflow::gtl::ArraySlice<T> input(input_); - CHECK(IsPermutation(permutation, input.size())); - std::vector<T> output(input.size()); + C<T> input) { + tensorflow::gtl::ArraySlice<T> data(input); + CHECK(IsPermutation(permutation, data.size())); + std::vector<T> output(data.size()); for (size_t i = 0; i < permutation.size(); ++i) { - output[permutation[i]] = input[i]; + output[permutation[i]] = data[i]; } return output; } +// Override of the above that works around compile failures with vectors. +template <typename T> +std::vector<T> Permute(tensorflow::gtl::ArraySlice<int64> permutation, + const std::vector<T>& input) { + return Permute<std::vector, T>(permutation, input); +} + // Inverts a permutation, i.e., output_permutation[input_permutation[i]] = i. std::vector<int64> InversePermutation( tensorflow::gtl::ArraySlice<int64> input_permutation);