diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 04f7bd0d472..9924cbb49de 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -2805,7 +2805,7 @@ tf_kernel_library(
 tf_kernel_library(
     name = "topk_op",
     prefix = "topk_op",
-    deps = NN_DEPS,
+    deps = NN_DEPS + if_cuda(["@cub_archive//:cub"]),
 )
 
 tf_kernel_library(
@@ -4164,6 +4164,7 @@ filegroup(
         "tile_functor.h",
         "tile_ops_cpu_impl.h",
         "tile_ops_impl.h",
+        "topk_op.h",
         "training_op_helpers.h",
         "training_ops.h",
         "transpose_functor.h",
diff --git a/tensorflow/core/kernels/topk_op.cc b/tensorflow/core/kernels/topk_op.cc
index 5c89eaef5fe..05fee56335c 100644
--- a/tensorflow/core/kernels/topk_op.cc
+++ b/tensorflow/core/kernels/topk_op.cc
@@ -17,6 +17,8 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
+#include "tensorflow/core/kernels/topk_op.h"
+
 #include <algorithm>
 #include <numeric>
 #include <vector>
@@ -25,6 +27,7 @@ limitations under the License.
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/gtl/top_n.h"
 #include "tensorflow/core/util/work_sharder.h"
 
@@ -33,7 +36,7 @@ namespace tensorflow {
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
 
-template <typename T>
+template <typename Device, typename T>
 class TopK : public OpKernel {
  public:
   explicit TopK(OpKernelConstruction* context) : OpKernel(context) {
@@ -82,7 +85,25 @@ class TopK : public OpKernel {
     auto values = values_out->flat_inner_dims<T>();
     auto indices = indices_out->flat_inner_dims<int32>();
 
+    Status s = functor::TopKFunctor<Device, T>::Compute(
+        context, sorted_, k, input, num_rows, num_cols, values, indices);
+    OP_REQUIRES_OK(context, s);
+  }
+
+ private:
+  int k_;
+  bool sorted_;
+};
+
+namespace functor {
+
+template <typename T>
+struct TopKFunctor<CPUDevice, T> {
+  static EIGEN_ALWAYS_INLINE Status
+  Compute(OpKernelContext* context, bool sorted, int k,
+          const typename TTypes<T, 2>::ConstTensor& input, const int64 num_rows,
+          const int64 num_cols, typename TTypes<T, 2>::Tensor values,
+          typename TTypes<int, 2>::Tensor indices) {
     const CPUDevice& d = context->eigen_device<CPUDevice>();
 
     // Special case for k == 1.
@@ -108,7 +129,7 @@ class TopK : public OpKernel {
         }
       }
 
-      return;
+      return Status::OK();
     }
 
     auto SortIndices = [&, context](int start_batch, int limit_batch) {
@@ -117,7 +138,6 @@ class TopK : public OpKernel {
       const auto comp = [input_data](const int32 a, const int32 b) {
         return input_data[a] > input_data[b];
       };
-      gtl::TopN<int32, decltype(comp)> filter(k, comp);
       // TODO(ebrevdo): For large k < num_cols, instead of using
       // TopN, it may be faster to create a temporary vector of
       // values 0..num_cols - 1 and then use std::partial_sort_copy
@@ -130,13 +150,14 @@ class TopK : public OpKernel {
           std::sort(&indices(b, 0), &indices(b, k), comp);
         } else {
           // Use the TopN heap object to sort.
+          gtl::TopN<int32, decltype(comp)> filter(k, comp);
           filter.reserve(num_cols);
           for (int32 c = 0; c < num_cols; ++c) {
             filter.push(c);
          }
 
           int32 i = 0;
-          if (sorted_) {
+          if (sorted) {
             std::unique_ptr<std::vector<int32>> top_k(filter.Extract());
             for (auto top_k_it = top_k->begin(); top_k_it != top_k->end();
                  ++top_k_it, ++i) {
@@ -158,35 +179,75 @@ class TopK : public OpKernel {
 
     // Guesstimate of cost; 4*N*log(K) where N == num_cols.
     // If K == N, assume the cost is N*log(K + 1).
-    const int64 cmp_cost = 3 * Eigen::TensorOpCost::AddCost<int32>() +
-                           Eigen::TensorOpCost::AddCost<T>();
-    const int64 base_cost =
+    const double cmp_cost = 3 * Eigen::TensorOpCost::AddCost<int32>() +
+                            Eigen::TensorOpCost::AddCost<T>();
+    const double base_cost =
         cmp_cost *
-        static_cast<int64>(num_cols *
-                           Eigen::numext::log2(static_cast<float>(k + 1)));
-    const int64 sort_cost = (k == num_cols) ? base_cost : 4 * base_cost;
-    const int64 copy_cost = 2 * k * Eigen::TensorOpCost::AddCost<T>();
-    const int64 total_cost = sort_cost + copy_cost;
+        static_cast<double>(num_cols *
+                            Eigen::numext::log2(static_cast<float>(k + 1)));
+    const double sort_cost = (k == num_cols) ? base_cost : 4 * base_cost;
+    const double copy_cost = 2 * k * Eigen::TensorOpCost::AddCost<T>();
+    const double total_cost = sort_cost + copy_cost;
+    const int64 final_cost = (total_cost >= static_cast<double>(kint64max))
+                                 ? kint64max
+                                 : static_cast<int64>(total_cost);
     auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
     Shard(worker_threads.num_threads, worker_threads.workers, num_rows,
-          total_cost, SortIndices);
-  }
+          final_cost, SortIndices);
 
- private:
-  int k_;
-  bool sorted_;
+    return Status::OK();
+  }
 };
 
-#define REGISTER_KERNELS_NAME(name, type) \
-  REGISTER_KERNEL_BUILDER(                \
-      Name(#name).Device(DEVICE_CPU).TypeConstraint<type>("T"), TopK<type>)
+}  // namespace functor
+
+#define REGISTER_KERNELS_NAME(name, type)                       \
+  REGISTER_KERNEL_BUILDER(                                      \
+      Name(#name).Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+      TopK<CPUDevice, type>)
 
 #define REGISTER_KERNELS(type)       \
   REGISTER_KERNELS_NAME(TopK, type); \
   REGISTER_KERNELS_NAME(TopKV2, type)
 
 TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
-#undef REGISTER_KERNELS_TO_NAME
+#undef REGISTER_KERNELS_NAME
 #undef REGISTER_KERNELS
 
-}  // namespace tensorflow
+#ifdef GOOGLE_CUDA
+
+namespace functor {
+#define DECLARE_GPU_SPEC(T)                                                  \
+  template <>                                                                \
+  Status TopKFunctor<GPUDevice, T>::Compute(                                 \
+      OpKernelContext* context, bool sorted, int k,                          \
+      const typename TTypes<T, 2>::ConstTensor& input, const int64 num_rows, \
+      const int64 num_cols, typename TTypes<T, 2>::Tensor values,            \
+      typename TTypes<int, 2>::Tensor indices);                              \
+  extern template struct functor::TopKFunctor<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
+TF_CALL_INTEGRAL_TYPES(DECLARE_GPU_SPEC);
+
+#undef DECLARE_GPU_SPEC
+
+}  // namespace functor
+
+#define REGISTER_KERNELS(type)                                   \
+  REGISTER_KERNEL_BUILDER(                                       \
+      Name("TopK").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+      TopK<GPUDevice, type>)                                     \
+  REGISTER_KERNEL_BUILDER(Name("TopKV2")                         \
+                              .Device(DEVICE_GPU)                \
+                              .TypeConstraint<type>("T")         \
+                              .HostMemory("k"),                  \
+                          TopK<GPUDevice, type>)
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
+TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS);
+
+#undef REGISTER_KERNELS
+
+#endif  // end GOOGLE_CUDA
+
+}  // end namespace tensorflow
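The refactor above moves the CPU implementation behind a Device-templated functor so the same TopK kernel class can dispatch, at compile time, to a CPU or GPU specialization that lives in its own translation unit. A minimal standalone sketch of that pattern, with illustrative names rather than the TensorFlow types:

    // Minimal sketch of the device-specialized functor pattern
    // (TopKSketch, CPUDevice, GPUDevice are made-up names for illustration).
    #include <iostream>

    struct CPUDevice {};
    struct GPUDevice {};

    // The primary template is only declared; each device defines its own
    // specialization, which can be compiled separately (the GPU one behind
    // a GOOGLE_CUDA guard, as in the real change).
    template <typename Device, typename T>
    struct TopKSketch;

    template <typename T>
    struct TopKSketch<CPUDevice, T> {
      static void Compute(const T* in, int n) {
        std::cout << "CPU path over " << n << " elements\n";
      }
    };

    template <typename T>
    struct TopKSketch<GPUDevice, T> {
      static void Compute(const T* in, int n) {
        std::cout << "GPU path over " << n << " elements\n";
      }
    };

    int main() {
      float data[4] = {0.1f, 0.3f, 0.2f, 0.4f};
      TopKSketch<CPUDevice, float>::Compute(data, 4);  // resolved at compile time
      TopKSketch<GPUDevice, float>::Compute(data, 4);
    }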
diff --git a/tensorflow/core/kernels/topk_op.h b/tensorflow/core/kernels/topk_op.h
new file mode 100644
index 00000000000..a53e3ec8d4f
--- /dev/null
+++ b/tensorflow/core/kernels/topk_op.h
@@ -0,0 +1,42 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_TOPK_OP_H_
+#define TENSORFLOW_TOPK_OP_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+namespace functor {
+
+template <typename Device, typename T>
+struct TopKFunctor {
+  static Status Compute(OpKernelContext* context, bool sorted, int k,
+                        const typename TTypes<T, 2>::ConstTensor& input,
+                        const int64 num_rows, const int64 num_cols,
+                        typename TTypes<T, 2>::Tensor values,
+                        typename TTypes<int, 2>::Tensor indices);
+};
+
+}  // end namespace functor
+
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_TOPK_OP_H_
diff --git a/tensorflow/core/kernels/topk_op_gpu.cu.cc b/tensorflow/core/kernels/topk_op_gpu.cu.cc
new file mode 100644
index 00000000000..e4b4a3cb493
--- /dev/null
+++ b/tensorflow/core/kernels/topk_op_gpu.cu.cc
@@ -0,0 +1,573 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include <cmath>
+#include <vector>
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "external/cub_archive/cub/device/device_segmented_radix_sort.cuh"
+#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/topk_op.h"
+#include "tensorflow/core/lib/gtl/top_n.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+// Required for sorting Eigen::half.
+namespace cub {
+template <>
+struct NumericTraits<Eigen::half>
+    : BaseTraits<FLOATING_POINT, true, false, unsigned short, Eigen::half> {
+};
+}  // namespace cub
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace impl {
+
+enum class HeapType { kMinHeap, kMaxHeap };
+enum class PreferIndices { kLower, kHigher };
+
+template <typename T>
+struct Entry {
+  int index;
+  T value;
+
+  // Test-only.
+  static bool greater(const Entry<T>& a, const Entry<T>& b) {
+    if (a.value == b.value) {
+      return a.index < b.index;
+    }
+    return a.value > b.value;
+  }
+};
+
+template <typename T>
+struct LinearData {
+  typedef impl::Entry<T> Entry;
+
+  __device__ Entry& operator[](std::size_t index) const { return data[index]; }
+
+  __device__ int get_index(int i) const { return data[i].index; }
+  __device__ T get_value(int i) const { return data[i].value; }
+
+  Entry* const data;
+};
+
+template <typename T>
+struct IndirectLinearData {
+  typedef impl::Entry<T> Entry;
+
+  __device__ Entry& operator[](std::size_t index) const { return data[index]; }
+
+  __device__ int get_index(int i) const {
+    return backing_data[data[i].index].index;
+  }
+  __device__ T get_value(int i) const { return data[i].value; }
+
+  Entry* const data;
+  Entry* const backing_data;
+};
+
+#if GOOGLE_CUDA
+template <typename T>
+struct StridedData {
+  typedef impl::Entry<T> Entry;
+
+  __device__ Entry& operator[](std::size_t index) const {
+    return data[index * blockDim.x + threadIdx.x];
+  }
+
+  __device__ int get_index(int i) const { return (*this)[i].index; }
+  __device__ T get_value(int i) const { return (*this)[i].value; }
+
+  Entry* const data;
+};
+#endif
+
+// A heap of Entry<T> that can either work as a min-heap or as a max-heap.
+template <HeapType heapType, PreferIndices preferIndices,
+          template <typename> class Data, typename T>
+struct IndexedHeap {
+  typedef typename Data<T>::Entry Entry;
+  const Data<T> data;
+
+  __device__ bool is_above(int left, int right) {
+    T left_value = data.get_value(left);
+    T right_value = data.get_value(right);
+    if (left_value == right_value) {
+      if (preferIndices == PreferIndices::kLower) {
+        return data.get_index(left) < data.get_index(right);
+      } else {
+        return data.get_index(left) > data.get_index(right);
+      }
+    }
+    if (heapType == HeapType::kMinHeap) {
+      return left_value < right_value;
+    } else {
+      return left_value > right_value;
+    }
+  }
+
+  __device__ void assign(int i, const Entry& entry) { data[i] = entry; }
+
+  __device__ void push_up(int i) {
+    int child = i;
+    int parent;
+    for (; child > 0; child = parent) {
+      parent = (child - 1) / 2;
+      if (!is_above(child, parent)) {
+        // Heap property satisfied.
+        break;
+      }
+      swap(child, parent);
+    }
+  }
+
+  __device__ void swap(int a, int b) {
+    auto tmp = data[b];
+    data[b] = data[a];
+    data[a] = tmp;
+  }
+
+  __device__ void push_root_down(int k) { push_down(0, k); }
+
+  // MAX-HEAPIFY in Cormen
+  __device__ void push_down(int node, int k) {
+    while (true) {
+      const int left = 2 * node + 1;
+      const int right = left + 1;
+      int smallest = node;
+      if (left < k && is_above(left, smallest)) {
+        smallest = left;
+      }
+      if (right < k && is_above(right, smallest)) {
+        smallest = right;
+      }
+      if (smallest == node) {
+        break;
+      }
+      swap(smallest, node);
+      node = smallest;
+    }
+  }
+
+  // BUILD-MAX-HEAPIFY in Cormen
+  __device__ void build(int k) {
+    for (int node = (k - 1) / 2; node >= 0; node--) {
+      push_down(node, k);
+    }
+  }
+
+  // HEAP-EXTRACT-MAX in Cormen
+  __device__ void remove_root(int k) {
+    data[0] = data[k - 1];
+    push_root_down(k - 1);
+  }
+
+  // in-place HEAPSORT in Cormen
+  // This method destroys the heap property.
+  __device__ void sort(int k) {
+    for (int slot = k - 1; slot > 0; slot--) {
+      // This is like remove_root, but we insert the element at the end.
+      swap(slot, 0);
+      // The heap is now one element smaller.
+      push_root_down(/*k=*/slot);
+    }
+  }
+
+  __device__ void replace_root(const Entry& entry, int k) {
+    data[0] = entry;
+    push_root_down(k);
+  }
+
+  __device__ const Entry& root() { return data[0]; }
+};
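IndexedHeap is the device-side building block for the min-heap top-k strategy used below: keep a k-element min-heap of candidates and replace its root whenever a larger element arrives. A host-side sketch of the same idea using only the C++ standard library (illustrative only; the kernel's heap additionally tracks indices and breaks ties toward lower indices):

    // Host-side sketch of the min-heap top-k strategy (not the kernel code).
    #include <algorithm>
    #include <iostream>
    #include <vector>

    std::vector<float> TopKByMinHeap(const std::vector<float>& input, int k) {
      // Seed a min-heap with the first k elements; its root is the smallest
      // of the current top-k candidates.
      std::vector<float> heap(input.begin(), input.begin() + k);
      auto min_first = [](float a, float b) { return a > b; };  // min at front
      std::make_heap(heap.begin(), heap.end(), min_first);
      for (std::size_t i = k; i < input.size(); ++i) {
        if (input[i] > heap.front()) {
          // Replace the root and restore the heap property, like
          // replace_root above.
          std::pop_heap(heap.begin(), heap.end(), min_first);
          heap.back() = input[i];
          std::push_heap(heap.begin(), heap.end(), min_first);
        }
      }
      std::sort_heap(heap.begin(), heap.end(), min_first);  // descending
      return heap;
    }

    int main() {
      std::vector<float> v = {0.1f, 0.9f, 0.4f, 0.7f, 0.2f, 0.8f};
      for (float x : TopKByMinHeap(v, 3)) std::cout << x << " ";  // 0.9 0.8 0.7
    }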
+
+template <HeapType heapType, PreferIndices preferIndices,
+          template <typename> class Data, typename T>
+__device__ IndexedHeap<heapType, preferIndices, Data, T> make_indexed_heap(
+    typename Data<T>::Entry* data) {
+  return IndexedHeap<heapType, preferIndices, Data, T>{Data<T>{data}};
+}
+
+// heapTopK walks over [input, input+length) with `step_size` stride,
+// starting at `start_index`.
+// It builds a top-`k` heap that is stored in `heap_entries`, using the `Data`
+// accessor to access elements in `heap_entries`. If sorted=true, the elements
+// will be sorted at the end.
+template <typename T, template <typename> class Data = LinearData>
+__device__ void heapTopK(const T* __restrict__ input, int length, int k,
+                         Entry<T>* __restrict__ heap_entries,
+                         bool sorted = false, int start_index = 0,
+                         int step_size = 1) {
+  assert(k <= length);
+
+  auto heap =
+      make_indexed_heap<HeapType::kMinHeap, PreferIndices::kHigher, Data, T>(
+          heap_entries);
+
+  int heap_end_index = start_index + k * step_size;
+  if (heap_end_index > length) {
+    heap_end_index = length;
+  }
+  // Initialize the min-heap.
+  for (int index = start_index, slot = 0; index < heap_end_index;
+       index += step_size, slot++) {
+    heap.assign(slot, {index, input[index]});
+  }
+
+  heap.build(k);
+
+  // Now iterate over the remaining items.
+  // If an item is smaller than the min element, it is not amongst the top k.
+  // Otherwise, replace the min element with it and push upwards.
+  for (int index = heap_end_index; index < length; index += step_size) {
+    // We prefer elements with lower indices. This is given here.
+    // Later elements automatically have higher indices, so they can be
+    // discarded.
+    if (input[index] > heap.root().value) {
+      // This element should replace the min.
+      heap.replace_root({index, input[index]}, k);
+    }
+  }
+
+  // Sort if wanted.
+  if (sorted) {
+    heap.sort(k);
+  }
+}
+
+// mergeShards performs a top-k merge on `num_shards` sorted streams that are
+// stored in `entries` in a strided way:
+//  |s_1 1st|s_2 1st|...s_{num_shards} 1st|s_1 2nd|s_2 2nd|...
+// The overall top k elements are written to `top_k_values` and their indices
+// to `top_k_indices`.
+// `top_k_heap` is used as temporary storage for the merge heap.
+template <typename T>
+__device__ void mergeShards(int num_shards, int k,
+                            Entry<T>* __restrict__ entries,
+                            Entry<T>* __restrict__ top_k_heap, T* top_k_values,
+                            int* top_k_indices) {
+  // If k < num_shards, we can use a min-heap with k elements to get the top k
+  // of the sorted blocks.
+  // If k > num_shards, we can initialize a min-heap with the top element from
+  // each sorted block.
+  const int heap_size = k < num_shards ? k : num_shards;
+
+  // Min-heap part.
+  {
+    auto min_heap = IndexedHeap<HeapType::kMinHeap, PreferIndices::kHigher,
+                                IndirectLinearData, T>{
+        IndirectLinearData<T>{top_k_heap, entries}};
+    // Initialize the heap as a min-heap.
+    for (int slot = 0; slot < heap_size; slot++) {
+      min_heap.assign(slot, {slot, entries[slot].value});
+    }
+    min_heap.build(heap_size);
+
+    // Now perform top k with the remaining shards (if num_shards > heap_size).
+    for (int shard = heap_size; shard < num_shards; shard++) {
+      const auto entry = entries[shard];
+      const auto root = min_heap.root();
+      if (entry.value < root.value) {
+        continue;
+      }
+      if (entry.value == root.value &&
+          entry.index > entries[root.index].index) {
+        continue;
+      }
+      // This element should replace the min.
+      min_heap.replace_root({shard, entry.value}, heap_size);
+    }
+  }
+
+  // Max-part.
+  {
+    // Turn the min-heap into a max-heap in-place.
+    auto max_heap = IndexedHeap<HeapType::kMaxHeap, PreferIndices::kLower,
+                                IndirectLinearData, T>{
+        IndirectLinearData<T>{top_k_heap, entries}};
+    // Heapify into a max-heap.
+    max_heap.build(heap_size);
+
+    // Now extract the minimum k-1 times.
+    // k is treated specially.
+    const int last_k = k - 1;
+    for (int rank = 0; rank < last_k; rank++) {
+      const Entry<T>& max_element = max_heap.root();
+      top_k_values[rank] = max_element.value;
+      int shard_index = max_element.index;
+      top_k_indices[rank] = entries[shard_index].index;
+      int next_shard_index = shard_index + num_shards;
+      // For rank < k-1, each top k heap still contains at least 1 element,
+      // so we can draw a replacement.
+      max_heap.replace_root(
+          {next_shard_index, entries[next_shard_index].value}, heap_size);
+    }
+
+    // rank == last_k.
+    const Entry<T>& max_element = max_heap.root();
+    top_k_values[last_k] = max_element.value;
+    int shard_index = max_element.index;
+    top_k_indices[last_k] = entries[shard_index].index;
+  }
+}
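TopKKernel below invokes heapTopK with start_index = threadIdx.x and step_size = blockDim.x, so shard s visits elements s, s + num_shards, s + 2*num_shards, ...; mergeShards then reads the per-shard results back in exactly that strided layout. A tiny host-side illustration of the index pattern (assuming length 10 split across 3 shards):

    // Illustration of the strided shard walk (host code, not the kernel).
    #include <cstdio>

    int main() {
      const int length = 10, num_shards = 3;
      for (int s = 0; s < num_shards; ++s) {
        printf("shard %d:", s);
        // Each shard strides over the input by num_shards.
        for (int i = s; i < length; i += num_shards) printf(" %d", i);
        printf("\n");  // shard 0: 0 3 6 9 / shard 1: 1 4 7 / shard 2: 2 5 8
      }
    }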
+
+extern __shared__ char shared_memory[];
+
+template <typename T>
+__global__ void TopKKernel(const T* input, int length, int k, bool sorted,
+                           T* output, int* indices) {
+  const int batch_index = blockIdx.x;
+  const T* batch_input = input + batch_index * length;
+
+  const int thread_index = threadIdx.x;
+  const int thread_count = blockDim.x;
+
+  Entry<T>* shared_entries = (Entry<T>*)shared_memory;
+
+  heapTopK<T, StridedData>(batch_input, length, k, shared_entries, true,
+                           thread_index, thread_count);
+
+  __syncthreads();
+  if (thread_index == 0) {
+    const int offset = batch_index * k;
+    auto batch_output = output + offset;
+    auto batch_indices = indices + offset;
+    Entry<T>* top_k_heap = shared_entries + thread_count * k;
+
+    // TODO(blackhc): Erich says: Performance can likely be improved
+    // significantly by having the merge done by multiple threads rather than
+    // just one. ModernGPU has some nice primitives that could help with this.
+    mergeShards(thread_count, k, shared_entries, top_k_heap, batch_output,
+                batch_indices);
+  }
+}
+
+template <typename T>
+cudaError LaunchTopKKernel(cudaStream_t stream, int num_shards, const T* input,
+                           int batch_size, int length, int k, bool sorted,
+                           T* output, int* indices) {
+  // This code assumes that k is small enough that the computation
+  // fits inside shared memory (hard coded to 48KB). In practice this
+  // means k <= 3072 for T=float/int32 and k <= 2048 for T=double/int64.
+  // The calculation is:
+  //   k <= shared_memory_size / (2 * (sizeof(int) + sizeof(T))).
+
+  // Use as many shards as possible.
+  if (num_shards <= 0) {
+    constexpr auto shared_memory_size = 48 << 10;  // 48 KB
+    const auto heap_size = k * (sizeof(int) + sizeof(T));
+    // shared_memory_size = (num_shards + 1) * heap_size <=>
+    num_shards = shared_memory_size / heap_size - 1;
+    if (num_shards <= 0) {
+      num_shards = 1;
+    }
+    auto shard_size = length / num_shards;
+    auto min_shard_size = 2 * k;
+    if (shard_size < min_shard_size) {
+      num_shards = length / min_shard_size;
+    }
+    if (num_shards <= 0) {
+      num_shards = 1;
+    } else if (num_shards > 1024) {
+      num_shards = 1024;
+    }
+  }
+  // We are limited by the amount of shared memory we have per block.
+  auto shared_memory_size = (num_shards + 1) * k * sizeof(Entry<T>);
+
+  TopKKernel<<<batch_size, num_shards, shared_memory_size, stream>>>(
+      input, length, k, sorted, output, indices);
+  return cudaGetLastError();
+}
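A quick check of the shared-memory bound quoted in the comment above, under the hard-coded 48KB budget: each heap entry stores an int index plus a value, and with at least one shard plus the merge heap there are always at least two k-entry heaps resident:

    // Worked check of the k limits quoted in LaunchTopKKernel's comment.
    #include <cstdio>

    int main() {
      constexpr int shared_memory_size = 48 << 10;                 // 49152 bytes
      constexpr int entry_float = sizeof(int) + sizeof(float);     // 8 bytes
      constexpr int entry_double = sizeof(int) + sizeof(double);   // 12 bytes
      // (num_shards + 1) >= 2, so k <= shared_memory_size / (2 * entry_size).
      printf("max k for float/int32:  %d\n",
             shared_memory_size / (2 * entry_float));   // 3072
      printf("max k for double/int64: %d\n",
             shared_memory_size / (2 * entry_double));  // 2048
    }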
+
+struct SegmentOffsetCreator {
+  SegmentOffsetCreator(int num_cols) : num_cols_(num_cols) {}
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(
+      const Eigen::array<int, 1>& ix) const {
+    return ix[0] * num_cols_;
+  }
+  int num_cols_;
+};
+
+struct ColumnIndexCreator {
+  ColumnIndexCreator(int num_cols) : num_cols_(num_cols) {}
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(
+      const Eigen::array<int, 1>& ix) const {
+    return ix[0] % num_cols_;
+  }
+
+  int num_cols_;
+};
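LaunchSortKernel below calls cub::DeviceSegmentedRadixSort::SortPairsDescending twice, following cub's standard two-pass convention: the first call, with a null d_temp_storage, only writes the required temp_storage_bytes; the second call, given real scratch space, performs the sort. A minimal standalone sketch of that idiom using the unsegmented variant (error handling elided; the include path follows upstream cub rather than the @cub_archive layout used in this change):

    // Minimal sketch of cub's two-pass size-query-then-run idiom.
    #include <cub/device/device_radix_sort.cuh>
    #include <cuda_runtime.h>

    void SortDescending(const float* d_keys_in, float* d_keys_out,
                        const int* d_vals_in, int* d_vals_out, int n,
                        cudaStream_t stream) {
      void* d_temp = nullptr;
      size_t temp_bytes = 0;
      // Pass 1: d_temp == nullptr, so cub only computes temp_bytes.
      cub::DeviceRadixSort::SortPairsDescending(
          d_temp, temp_bytes, d_keys_in, d_keys_out, d_vals_in, d_vals_out, n,
          0, sizeof(float) * 8, stream);
      cudaMalloc(&d_temp, temp_bytes);
      // Pass 2: the same call with real scratch space performs the sort.
      cub::DeviceRadixSort::SortPairsDescending(
          d_temp, temp_bytes, d_keys_in, d_keys_out, d_vals_in, d_vals_out, n,
          0, sizeof(float) * 8, stream);
      cudaFree(d_temp);
    }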
+
+template <typename T>
+Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows,
+                        int num_cols, int k,
+                        typename TTypes<T, 2>::Tensor values,
+                        TTypes<int, 2>::Tensor indices) {
+  const GPUDevice& d = ctx->eigen_device<GPUDevice>();
+  auto stream = ctx->eigen_gpu_device().stream();
+  size_t temp_storage_bytes = -1;
+
+  // TODO(ebrevdo): Once cub supports iterators for the ValueT and
+  // segment_offsets, replace these tensors with iterators that
+  // directly return the correct value.
+  Tensor input_indices;
+  TF_RETURN_IF_ERROR(ctx->allocate_temp(
+      DT_INT32, TensorShape({num_rows, num_cols}), &input_indices));
+  auto input_indices_t = To32Bit(input_indices.flat<int32>());
+  input_indices_t.device(d) =
+      input_indices_t.generate(ColumnIndexCreator(num_cols));
+
+  Tensor segment_offsets;
+  TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_INT32, TensorShape({num_rows + 1}),
+                                        &segment_offsets));
+  auto segment_offsets_t = To32Bit(segment_offsets.flat<int32>());
+  segment_offsets_t.device(d) =
+      segment_offsets_t.generate(SegmentOffsetCreator(num_cols));
+
+  Tensor temp_values;
+  Tensor temp_indices;
+  T* sorted_values_ptr;
+  int* sorted_indices_ptr;
+  if (k == num_cols) {
+    // Doing a full sort, no intermediate values needed.
+    sorted_values_ptr = values.data();
+    sorted_indices_ptr = indices.data();
+  } else {
+    // Need to create intermediate values for sorting.
+    TF_RETURN_IF_ERROR(ctx->allocate_temp(
+        DT_INT32, TensorShape({num_rows, num_cols}), &temp_indices));
+    TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value,
+                                          TensorShape({num_rows, num_cols}),
+                                          &temp_values));
+    sorted_indices_ptr = temp_indices.flat<int32>().data();
+    sorted_values_ptr = temp_values.flat<T>().data();
+  }
+
+  auto err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
+      /* d_temp_storage */ nullptr,
+      /* temp_storage_bytes */ temp_storage_bytes,
+      /* d_keys_in */ input,
+      /* d_keys_out */ sorted_values_ptr,
+      /* d_values_in */ input_indices_t.data(),
+      /* d_values_out */ sorted_indices_ptr,
+      /* num_items */ num_cols * num_rows,
+      /* num_segments */ num_rows,
+      /* d_begin_offsets */ segment_offsets_t.data(),
+      /* d_end_offsets */ segment_offsets_t.data() + 1,
+      /* begin_bit */ 0,
+      /* end_bit */ sizeof(T) * 8,
+      /* stream */ stream);
+  if (err != cudaSuccess) {
+    return errors::Internal(
+        "TopKOp: Could not launch "
+        "cub::DeviceSegmentedRadixSort::SortPairsDescending to calculate "
+        "temp_storage_bytes, status: ",
+        cudaGetErrorString(err));
+  }
+  Tensor temp_storage;
+  TF_RETURN_IF_ERROR(ctx->allocate_temp(
+      DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
+      &temp_storage));
+  err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
+      /* d_temp_storage */ temp_storage.flat<int8>().data(),
+      /* temp_storage_bytes */ temp_storage_bytes,
+      /* d_keys_in */ input,
+      /* d_keys_out */ sorted_values_ptr,
+      /* d_values_in */ input_indices_t.data(),
+      /* d_values_out */ sorted_indices_ptr,
+      /* num_items */ num_cols * num_rows,
+      /* num_segments */ num_rows,
+      /* d_begin_offsets */ segment_offsets_t.data(),
+      /* d_end_offsets */ segment_offsets_t.data() + 1,
+      /* begin_bit */ 0,
+      /* end_bit */ sizeof(T) * 8,
+      /* stream */ stream);
+  if (err != cudaSuccess) {
+    return errors::Internal(
+        "TopKOp: Could not launch "
+        "cub::DeviceSegmentedRadixSort::SortPairsDescending to sort input, "
+        "temp_storage_bytes: ",
+        temp_storage_bytes, ", status: ", cudaGetErrorString(err));
+  }
+  if (k < num_cols) {
+    // Need to copy subsets of sorted_indices and sorted_outputs to
+    // indices and outputs.
+    const Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{0, 0};
+    const Eigen::DSizes<Eigen::DenseIndex, 2> slice_sizes{num_rows, k};
+    To32Bit(indices).device(d) =
+        To32Bit(temp_indices.matrix<int32>()).slice(slice_indices, slice_sizes);
+    To32Bit(values).device(d) =
+        To32Bit(temp_values.matrix<T>()).slice(slice_indices, slice_sizes);
+  }
+  return Status::OK();
+}
+
+}  // end namespace impl
+
+namespace functor {
+
+template <typename T>
+struct TopKFunctor<GPUDevice, T> {
+  static EIGEN_ALWAYS_INLINE Status
+  Compute(OpKernelContext* context, bool sorted, int k,
+          const typename TTypes<T, 2>::ConstTensor& input, const int64 num_rows,
+          const int64 num_cols, typename TTypes<T, 2>::Tensor values,
+          typename TTypes<int, 2>::Tensor indices) {
+    // For small k, use the heap implementation. For larger k, use
+    // the in-place cub sort. For k == num_cols, always use the
+    // in-place cub sort. The thresholds for n and k were determined
+    // empirically.
+    if (num_cols <= 1000 || k == num_cols || k >= 100) {
+      return impl::LaunchSortKernel(context, input.data(), num_rows, num_cols,
+                                    k, values, indices);
+    } else {
+      auto stream = context->eigen_gpu_device().stream();
+      auto err = impl::LaunchTopKKernel(stream, /* num_shards */ 0,
+                                        input.data(), num_rows, num_cols, k,
+                                        sorted, values.data(), indices.data());
+      if (err != cudaSuccess) {
+        return errors::Internal(
+            "Could not launch TopKKernel: ", cudaGetErrorString(err), ".");
+      } else {
+        return Status::OK();
+      }
+    }
+  }
+};
+
+}  // end namespace functor
+
+#define INSTANTIATE_TEMPLATE(type) \
+  template struct functor::TopKFunctor<GPUDevice, type>;
+
+TF_CALL_GPU_NUMBER_TYPES(INSTANTIATE_TEMPLATE);
+TF_CALL_INTEGRAL_TYPES(INSTANTIATE_TEMPLATE);
+#undef INSTANTIATE_TEMPLATE
+
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 64d0b8fa52e..2c206f330bb 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -835,7 +835,7 @@ tf_py_test(
     ],
 )
 
-tf_py_test(
+cuda_py_test(
     name = "topk_op_test",
     size = "small",
     srcs = ["topk_op_test.py"],
diff --git a/tensorflow/python/kernel_tests/topk_op_test.py b/tensorflow/python/kernel_tests/topk_op_test.py
index b3f737c8841..034b8be4dd2 100644
--- a/tensorflow/python/kernel_tests/topk_op_test.py
+++ b/tensorflow/python/kernel_tests/topk_op_test.py
@@ -18,13 +18,20 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import itertools
+import sys
+
 import numpy as np
 
+from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import resource_variable_ops
 import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
 from tensorflow.python.platform import test
 
@@ -36,25 +43,103 @@ class TopKTest(test.TestCase):
                      k,
                      expected_values,
                      expected_indices,
-                     sorted=True):
-    np_values = np.array(expected_values)
-    np_indices = np.array(expected_indices)
-    with self.test_session():
+                     sorted=True):  # pylint: disable=redefined-builtin
+    np_expected_values = np.array(expected_values)
+    np_expected_indices = np.array(expected_indices)
+    with self.test_session(use_gpu=True) as sess:
       values_op, indices_op = nn_ops.top_k(inputs, k, sorted=sorted)
-      values = values_op.eval()
-      indices = indices_op.eval()
-      self.assertShapeEqual(np_values, values_op)
-      self.assertShapeEqual(np_indices, indices_op)
-      self.assertAllEqual(np_indices, indices)
-      self.assertAllClose(np_values, values)
+      values, indices = sess.run([values_op, indices_op])
+
+      self.assertShapeEqual(np_expected_values, values_op)
+      self.assertShapeEqual(np_expected_indices, indices_op)
+
+      if sorted:
+        self.assertAllClose(np_expected_values, values)
+        # Do some special casing of equality of indices: if indices
+        # are not the same, but values are floating type, ensure that
+        # the values are within epsilon of each other.
+        if not np.issubdtype(np_expected_values.dtype, np.float):
+          # Values are not floating point type; check indices exactly.
+          self.assertAllEqual(np_expected_indices, indices)
+        else:
+          # Values are floating point; indices may be swapped for
+          # values near each other.
+          indices_not_equal = np_expected_indices != indices
+          if np.any(indices_not_equal):
+            values_unsure = values[indices_not_equal]
+            expected_values_unsure = np_expected_values[indices_not_equal]
+            self.assertAllClose(expected_values_unsure, values_unsure)
+      else:
+        np_inputs = np.array(inputs)
+
+        # Check that the indices are valid.
+        for result_index, src_index in np.ndenumerate(indices):
+          value = values[result_index]
+          expected_value = np_inputs[result_index[0], src_index]
+          np.testing.utils.assert_almost_equal(value, expected_value)
+
+        # Check that if two elements are equal, the lower-index element
+        # appears first.
+        shape = values.shape
+        for batch_index in range(shape[0]):
+          for index in range(shape[1] - 1):
+            if np.isclose(values[batch_index, index],
+                          values[batch_index, index + 1]):
+              self.assertLess(indices[batch_index, index],
+                              indices[batch_index, index + 1])
+
+        # Now check the results, ignoring order.
+        self.assertAllEqual(np.sort(np_expected_indices), np.sort(indices))
+        self.assertAllClose(np.sort(np_expected_values), np.sort(values))
 
   def testTop1(self):
     inputs = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.3, 0.3, 0.2]]
     self._validateTopK(inputs, 1, [[0.4], [0.3]], [[3], [1]])
 
   def testTop2(self):
-    inputs = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.3, 0.3, 0.2]]
-    self._validateTopK(inputs, 2, [[0.4, 0.3], [0.3, 0.3]], [[3, 1], [2, 1]])
+    inputs = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.3, 0.4, 0.2]]
+    self._validateTopK(inputs, 2, [[0.4, 0.3], [0.4, 0.3]], [[3, 1], [2, 1]])
+
+  def _testLargeSort(self, dtype):
+    b = 10
+    n = 5000
+    inputs = np.random.permutation(
+        np.linspace(0, 100, b * n, dtype=dtype)).reshape(b, n)
+    indices = np.argsort(-inputs, axis=1)
+    values = -np.sort(-inputs, axis=1)
+    self._validateTopK(inputs, n, values, indices)
+
+  def testLargeSort(self):
+    self._testLargeSort(np.float32)
+    self._testLargeSort(np.float16)
+
+  def _testLargeTopK(self, dtype):
+    b = 10
+    n = 5000
+    k = n - 1
+    inputs = np.random.permutation(
+        np.linspace(0, 100, b * n, dtype=dtype)).reshape(b, n)
+    indices = np.argsort(-inputs, axis=1)[:, :k]
+    values = -np.sort(-inputs, axis=1)[:, :k]
+    self._validateTopK(inputs, k, values, indices)
+
+  def testLargeTopK(self):
+    self._testLargeTopK(np.float32)
+    self._testLargeTopK(np.float16)
+
+  def _testMediumTopK(self, dtype):
+    b = 5
+    n = 500
+    k = 50
+    inputs = np.random.permutation(
+        np.linspace(0, 100, b * n, dtype=dtype)).reshape(b, n)
+    indices = np.argsort(-inputs, axis=1)[:, :k]
+    values = -np.sort(-inputs, axis=1)[:, :k]
+    self._validateTopK(inputs, k, values, indices)
+
+  def testMediumTopK(self):
+    self._testMediumTopK(np.float32)
+    self._testMediumTopK(np.float16)
 
   def testTopAll(self):
     inputs = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.3, 0.3, 0.2]]
@@ -79,7 +164,7 @@ class TopKTest(test.TestCase):
 
   def testKNegative(self):
     inputs = [[0.1, 0.2], [0.3, 0.4]]
-    with self.test_session():
+    with self.test_session(use_gpu=True):
       k = array_ops.placeholder(dtypes.int32)
       values, _ = nn_ops.top_k(inputs, k)
       with self.assertRaisesOpError("Need k >= 0, got -7"):
@@ -92,7 +177,7 @@ class TopKTest(test.TestCase):
       nn_ops.top_k(inputs, 4)
 
   def testTopKGradients(self):
-    with self.test_session() as sess:
+    with self.test_session(use_gpu=True) as sess:
       inputs = array_ops.placeholder(dtypes.int32, shape=[2, 5])
       values, _ = nn_ops.top_k(inputs, 3)
       grad = sess.run(
@@ -102,5 +187,33 @@ class TopKTest(test.TestCase):
     self.assertEqual(grad.tolist(), [[0, 0, 1, 3, 2], [0, 4, 0, 5, 6]])
 
 
+class TopKBenchmark(test.Benchmark):
+
+  def benchmarkTopK(self):
+    for (m, n, p, use_gpu) in itertools.product(
+        [128],
+        [10, 100, 1000, 10000, 100000],
+        [0.001, 0.01, 0.5, 0.99, 1.0],
+        [False, True]):
+      k = int(p * n)
+      if k == 0:
+        continue
+      name = "m_%d_n_%d_k_%g_use_gpu_%s" % (m, n, k, use_gpu)
+      device = "/%s:0" % ("gpu" if use_gpu else "cpu")
+      with ops.Graph().as_default():
+        with ops.device(device):
+          x = random_ops.random_uniform((m, n))
+          v = resource_variable_ops.ResourceVariable(x)
+          op = nn_ops.top_k(v, k)
+        with session.Session() as sess:
+          v.initializer.run()
+          r = self.run_op_benchmark(sess, op, min_iters=100, name=name)
+          gb_processed_input = m * n / 1.0e9
+          throughput = gb_processed_input / r["wall_time"]
+          print("Benchmark: %s \t wall_time: %0.03g s \t "
+                "Throughput: %0.03g GB/s" % (name, r["wall_time"], throughput))
+          sys.stdout.flush()
+
+
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/kernel_tests/where_op_test.py b/tensorflow/python/kernel_tests/where_op_test.py
index a428d26996b..3e1fa0a287b 100644
--- a/tensorflow/python/kernel_tests/where_op_test.py
+++ b/tensorflow/python/kernel_tests/where_op_test.py
@@ -100,7 +100,7 @@ class WhereOpTest(test.TestCase):
 
 class WhereBenchmark(test.Benchmark):
 
-  def benchmarkWhereCPU(self):
+  def benchmarkWhere(self):
     for (m, n, p, use_gpu) in itertools.product(
         [10], [10, 100, 1000, 10000, 100000, 1000000],