Get TopK op working on GPU again. Extend using cub's radix sort.

1. Undo rollback of Andreas Kirsch's initial implementation.
2. Use cub segmented radix sort if Andreas' heap-based impl
   for large k and small num_cols (thresholds of k=100, n=1000
   determined empirically).
3. Use cub segmented radix sort if k == num_cols (this case is always faster).
4. Added benchmarks.

Benchmarks show that the GPU implementation is up to 3x slower for small k but
can be 10x faster for large num_cols and k.

Benchmarks:

Benchmark: m_128_n_10_k_5_use_gpu_False          wall_time: 0.000166 s   Throughput: 0.0077 GB/s
Benchmark: m_128_n_10_k_5_use_gpu_True   wall_time: 0.000796 s   Throughput: 0.00161 GB/s
Benchmark: m_128_n_10_k_9_use_gpu_False          wall_time: 0.00017 s    Throughput: 0.00751 GB/s
Benchmark: m_128_n_10_k_9_use_gpu_True   wall_time: 0.000796 s   Throughput: 0.00161 GB/s
Benchmark: m_128_n_10_k_10_use_gpu_False         wall_time: 0.00017 s    Throughput: 0.00753 GB/s
Benchmark: m_128_n_10_k_10_use_gpu_True          wall_time: 0.000775 s   Throughput: 0.00165 GB/s
Benchmark: m_128_n_100_k_1_use_gpu_False         wall_time: 0.000155 s   Throughput: 0.0826 GB/s
Benchmark: m_128_n_100_k_1_use_gpu_True          wall_time: 0.000796 s   Throughput: 0.0161 GB/s
Benchmark: m_128_n_100_k_50_use_gpu_False        wall_time: 0.000247 s   Throughput: 0.0519 GB/s
Benchmark: m_128_n_100_k_50_use_gpu_True         wall_time: 0.0008 s     Throughput: 0.016 GB/s
Benchmark: m_128_n_100_k_99_use_gpu_False        wall_time: 0.000261 s   Throughput: 0.049 GB/s
Benchmark: m_128_n_100_k_99_use_gpu_True         wall_time: 0.000794 s   Throughput: 0.0161 GB/s
Benchmark: m_128_n_100_k_100_use_gpu_False       wall_time: 0.000239 s   Throughput: 0.0536 GB/s
Benchmark: m_128_n_100_k_100_use_gpu_True        wall_time: 0.000777 s   Throughput: 0.0165 GB/s
Benchmark: m_128_n_1000_k_1_use_gpu_False        wall_time: 0.000324 s   Throughput: 0.395 GB/s
Benchmark: m_128_n_1000_k_1_use_gpu_True         wall_time: 0.000916 s   Throughput: 0.14 GB/s
Benchmark: m_128_n_1000_k_10_use_gpu_False       wall_time: 0.00042 s    Throughput: 0.305 GB/s
Benchmark: m_128_n_1000_k_10_use_gpu_True        wall_time: 0.000902 s   Throughput: 0.142 GB/s
Benchmark: m_128_n_1000_k_500_use_gpu_False      wall_time: 0.0011 s     Throughput: 0.116 GB/s
Benchmark: m_128_n_1000_k_500_use_gpu_True       wall_time: 0.00097 s    Throughput: 0.132 GB/s
Benchmark: m_128_n_1000_k_990_use_gpu_False      wall_time: 0.00133 s    Throughput: 0.0962 GB/s
Benchmark: m_128_n_1000_k_990_use_gpu_True       wall_time: 0.000993 s   Throughput: 0.129 GB/s
Benchmark: m_128_n_1000_k_1000_use_gpu_False     wall_time: 0.00102 s    Throughput: 0.126 GB/s
Benchmark: m_128_n_1000_k_1000_use_gpu_True      wall_time: 0.000964 s   Throughput: 0.133 GB/s
Benchmark: m_128_n_10000_k_10_use_gpu_False      wall_time: 0.002 s      Throughput: 0.64 GB/s
Benchmark: m_128_n_10000_k_10_use_gpu_True       wall_time: 0.00288 s    Throughput: 0.445 GB/s
Benchmark: m_128_n_10000_k_100_use_gpu_False     wall_time: 0.00233 s    Throughput: 0.549 GB/s
Benchmark: m_128_n_10000_k_100_use_gpu_True      wall_time: 0.00325 s    Throughput: 0.394 GB/s
Benchmark: m_128_n_10000_k_5000_use_gpu_False    wall_time: 0.0127 s     Throughput: 0.101 GB/s
Benchmark: m_128_n_10000_k_5000_use_gpu_True     wall_time: 0.00381 s    Throughput: 0.336 GB/s
Benchmark: m_128_n_10000_k_9900_use_gpu_False    wall_time: 0.015 s      Throughput: 0.0853 GB/s
Benchmark: m_128_n_10000_k_9900_use_gpu_True     wall_time: 0.00438 s    Throughput: 0.292 GB/s
Benchmark: m_128_n_10000_k_10000_use_gpu_False   wall_time: 0.0104 s     Throughput: 0.123 GB/s
Benchmark: m_128_n_10000_k_10000_use_gpu_True    wall_time: 0.00427 s    Throughput: 0.3 GB/s
Benchmark: m_128_n_100000_k_100_use_gpu_False    wall_time: 0.0148 s     Throughput: 0.865 GB/s
Benchmark: m_128_n_100000_k_100_use_gpu_True     wall_time: 0.0262 s     Throughput: 0.488 GB/s
Benchmark: m_128_n_100000_k_1000_use_gpu_False   wall_time: 0.0201 s     Throughput: 0.636 GB/s
Benchmark: m_128_n_100000_k_1000_use_gpu_True    wall_time: 0.0263 s     Throughput: 0.486 GB/s
Benchmark: m_128_n_100000_k_50000_use_gpu_False          wall_time: 0.214 s      Throughput: 0.0599 GB/s
Benchmark: m_128_n_100000_k_50000_use_gpu_True   wall_time: 0.0322 s     Throughput: 0.398 GB/s
Benchmark: m_128_n_100000_k_99000_use_gpu_False          wall_time: 0.262 s      Throughput: 0.0489 GB/s
Benchmark: m_128_n_100000_k_99000_use_gpu_True   wall_time: 0.0377 s     Throughput: 0.34 GB/s
Benchmark: m_128_n_100000_k_100000_use_gpu_False         wall_time: 0.118 s      Throughput: 0.108 GB/s
Benchmark: m_128_n_100000_k_100000_use_gpu_True          wall_time: 0.0365 s     Throughput: 0.351 GB/s

END_PUBLIC

BEGIN_PUBLIC
Automated g4 rollback of changelist 157169178

PiperOrigin-RevId: 161124193
This commit is contained in:
Eugene Brevdo 2017-07-06 13:41:47 -07:00 committed by TensorFlower Gardener
parent 0597c4189a
commit 0b5cce367c
7 changed files with 829 additions and 39 deletions

View File

@ -2805,7 +2805,7 @@ tf_kernel_library(
tf_kernel_library(
name = "topk_op",
prefix = "topk_op",
deps = NN_DEPS,
deps = NN_DEPS + if_cuda(["@cub_archive//:cub"]),
)
tf_kernel_library(
@ -4164,6 +4164,7 @@ filegroup(
"tile_functor.h",
"tile_ops_cpu_impl.h",
"tile_ops_impl.h",
"topk_op.h",
"training_op_helpers.h",
"training_ops.h",
"transpose_functor.h",

View File

@ -17,6 +17,8 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "tensorflow/core/kernels/topk_op.h"
#include <algorithm>
#include <numeric>
#include <vector>
@ -25,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/util/work_sharder.h"
@ -33,7 +36,7 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
template <typename T>
template <typename Device, typename T>
class TopK : public OpKernel {
public:
explicit TopK(OpKernelConstruction* context) : OpKernel(context) {
@ -82,7 +85,25 @@ class TopK : public OpKernel {
auto values = values_out->flat_inner_dims<T>();
auto indices = indices_out->flat_inner_dims<int32>();
Status s = functor::TopKFunctor<Device, T>::Compute(
context, sorted_, k, input, num_rows, num_cols, values, indices);
OP_REQUIRES_OK(context, s);
}
private:
int k_;
bool sorted_;
};
namespace functor {
template <typename T>
struct TopKFunctor<CPUDevice, T> {
static EIGEN_ALWAYS_INLINE Status
Compute(OpKernelContext* context, bool sorted, int k,
const typename TTypes<T, 2>::ConstTensor& input, const int64 num_rows,
const int64 num_cols, typename TTypes<T, 2>::Tensor values,
typename TTypes<int, 2>::Tensor indices) {
const CPUDevice& d = context->eigen_device<CPUDevice>();
// Special case for k == 1.
@ -108,7 +129,7 @@ class TopK : public OpKernel {
}
}
return;
return Status::OK();
}
auto SortIndices = [&, context](int start_batch, int limit_batch) {
@ -117,7 +138,6 @@ class TopK : public OpKernel {
const auto comp = [input_data](const int32 a, const int32 b) {
return input_data[a] > input_data[b];
};
gtl::TopN<int32, decltype(comp)> filter(k, comp);
// TODO(ebrevdo): For large k < num_cols, instead of using
// TopN, it may be faster to create a temporary vector of
// values 0..num_cols - 1 and then use std::partial_sort_copy
@ -130,13 +150,14 @@ class TopK : public OpKernel {
std::sort(&indices(b, 0), &indices(b, k), comp);
} else {
// Use the TopN heap object to sort.
gtl::TopN<int32, decltype(comp)> filter(k, comp);
filter.reserve(num_cols);
for (int32 c = 0; c < num_cols; ++c) {
filter.push(c);
}
int32 i = 0;
if (sorted_) {
if (sorted) {
std::unique_ptr<std::vector<int32>> top_k(filter.Extract());
for (auto top_k_it = top_k->begin(); top_k_it != top_k->end();
++top_k_it, ++i) {
@ -158,35 +179,75 @@ class TopK : public OpKernel {
// Guesstimate of cost; 4*N*log(K) where N == num_cols.
// If K == N, assume the cost is N*log(K + 1).
const int64 cmp_cost = 3 * Eigen::TensorOpCost::AddCost<int32>() +
Eigen::TensorOpCost::AddCost<T>();
const int64 base_cost =
const double cmp_cost = 3 * Eigen::TensorOpCost::AddCost<int32>() +
Eigen::TensorOpCost::AddCost<T>();
const double base_cost =
cmp_cost *
static_cast<int64>(num_cols *
Eigen::numext::log2(static_cast<float>(k + 1)));
const int64 sort_cost = (k == num_cols) ? base_cost : 4 * base_cost;
const int64 copy_cost = 2 * k * Eigen::TensorOpCost::AddCost<T>();
const int64 total_cost = sort_cost + copy_cost;
static_cast<double>(num_cols *
Eigen::numext::log2(static_cast<float>(k + 1)));
const double sort_cost = (k == num_cols) ? base_cost : 4 * base_cost;
const double copy_cost = 2 * k * Eigen::TensorOpCost::AddCost<T>();
const double total_cost = sort_cost + copy_cost;
const int64 final_cost = (total_cost >= static_cast<double>(kint64max))
? kint64max
: static_cast<int64>(total_cost);
auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
Shard(worker_threads.num_threads, worker_threads.workers, num_rows,
total_cost, SortIndices);
}
final_cost, SortIndices);
private:
int k_;
bool sorted_;
return Status::OK();
}
};
#define REGISTER_KERNELS_NAME(name, type) \
REGISTER_KERNEL_BUILDER( \
Name(#name).Device(DEVICE_CPU).TypeConstraint<type>("T"), TopK<type>)
} // namespace functor
#define REGISTER_KERNELS_NAME(name, type) \
REGISTER_KERNEL_BUILDER( \
Name(#name).Device(DEVICE_CPU).TypeConstraint<type>("T"), \
TopK<CPUDevice, type>)
#define REGISTER_KERNELS(type) \
REGISTER_KERNELS_NAME(TopK, type); \
REGISTER_KERNELS_NAME(TopKV2, type)
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS_TO_NAME
#undef REGISTER_KERNELS_NAME
#undef REGISTER_KERNELS
} // namespace tensorflow
#ifdef GOOGLE_CUDA
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
Status TopKFunctor<GPUDevice, T>::Compute( \
OpKernelContext* context, bool sorted, int k, \
const typename TTypes<T, 2>::ConstTensor& input, const int64 num_rows, \
const int64 num_cols, typename TTypes<T, 2>::Tensor values, \
typename TTypes<int, 2>::Tensor indices); \
extern template struct functor::TopKFunctor<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_INTEGRAL_TYPES(DECLARE_GPU_SPEC);
#undef DECLARE_GPU_SPEC
} // namespace functor
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("TopK").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
TopK<GPUDevice, type>) \
REGISTER_KERNEL_BUILDER(Name("TopKV2") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.HostMemory("k"), \
TopK<GPUDevice, type>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#endif // end GOOGLE_CUDA
} // end namespace tensorflow

View File

@ -0,0 +1,42 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_TOPK_OP_H_
#define TENSORFLOW_TOPK_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
namespace functor {
template <typename Device, typename T>
struct TopKFunctor {
static Status Compute(OpKernelContext* context, bool sorted, int k,
const typename TTypes<T, 2>::ConstTensor& input,
const int64 num_rows, const int64 num_cols,
typename TTypes<T, 2>::Tensor values,
typename TTypes<int, 2>::Tensor indices);
};
} // end namespace functor
} // end namespace tensorflow
#endif // TENSORFLOW_TOPK_OP_H_

View File

@ -0,0 +1,573 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include <cmath>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "external/cub_archive/cub/device/device_segmented_radix_sort.cuh"
#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/topk_op.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
// Required for sorting Eigen::half
namespace cub {
template <>
struct NumericTraits<Eigen::half>
: BaseTraits<FLOATING_POINT, true, false, unsigned short int, Eigen::half> {
};
} // namespace cub
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
namespace impl {
enum class HeapType { kMinHeap, kMaxHeap };
enum class PreferIndices { kLower, kHigher };
template <typename T>
struct Entry {
int index;
T value;
// Test-only.
static bool greater(const Entry<T>& a, const Entry<T>& b) {
if (a.value == b.value) {
return a.index < b.index;
}
return a.value > b.value;
}
};
template <typename T>
struct LinearData {
typedef impl::Entry<T> Entry;
__device__ Entry& operator[](std::size_t index) const { return data[index]; }
__device__ int get_index(int i) const { return data[i].index; }
__device__ T get_value(int i) const { return data[i].value; }
Entry* const data;
};
template <typename T>
struct IndirectLinearData {
typedef impl::Entry<T> Entry;
__device__ Entry& operator[](std::size_t index) const { return data[index]; }
__device__ int get_index(int i) const {
return backing_data[data[i].index].index;
}
__device__ T get_value(int i) const { return data[i].value; }
Entry* const data;
Entry* const backing_data;
};
#if GOOGLE_CUDA
template <typename T>
struct StridedData {
typedef impl::Entry<T> Entry;
__device__ Entry& operator[](std::size_t index) const {
return data[index * blockDim.x + threadIdx.x];
}
__device__ int get_index(int i) const { return (*this)[i].index; }
__device__ T get_value(int i) const { return (*this)[i].value; }
Entry* const data;
};
#endif
// A heap of Entry<T> that can either work as a min-heap or as a max-heap.
template <HeapType heapType, PreferIndices preferIndices,
template <typename> class Data, typename T>
struct IndexedHeap {
typedef typename Data<T>::Entry Entry;
const Data<T> data;
__device__ bool is_above(int left, int right) {
T left_value = data.get_value(left);
T right_value = data.get_value(right);
if (left_value == right_value) {
if (preferIndices == PreferIndices::kLower) {
return data.get_index(left) < data.get_index(right);
} else {
return data.get_index(left) > data.get_index(right);
}
}
if (heapType == HeapType::kMinHeap) {
return left_value < right_value;
} else {
return left_value > right_value;
}
}
__device__ void assign(int i, const Entry& entry) { data[i] = entry; }
__device__ void push_up(int i) {
int child = i;
int parent;
for (; child > 0; child = parent) {
parent = (child - 1) / 2;
if (!is_above(child, parent)) {
// Heap property satisfied.
break;
}
swap(child, parent);
}
}
__device__ void swap(int a, int b) {
auto tmp = data[b];
data[b] = data[a];
data[a] = tmp;
}
__device__ void push_root_down(int k) { push_down(0, k); }
// MAX-HEAPIFY in Cormen
__device__ void push_down(int node, int k) {
while (true) {
const int left = 2 * node + 1;
const int right = left + 1;
int smallest = node;
if (left < k && is_above(left, smallest)) {
smallest = left;
}
if (right < k && is_above(right, smallest)) {
smallest = right;
}
if (smallest == node) {
break;
}
swap(smallest, node);
node = smallest;
}
}
// BUILD-MAX-HEAPIFY in Cormen
__device__ void build(int k) {
for (int node = (k - 1) / 2; node >= 0; node--) {
push_down(node, k);
}
}
// HEAP-EXTRACT-MAX in Cormen
__device__ void remove_root(int k) {
data[0] = data[k - 1];
push_root_down(k - 1);
}
// in-place HEAPSORT in Cormen
// This method destroys the heap property.
__device__ void sort(int k) {
for (int slot = k - 1; slot > 0; slot--) {
// This is like remove_root but we insert the element at the end.
swap(slot, 0);
// Heap is now an element smaller.
push_root_down(/*k=*/slot);
}
}
__device__ void replace_root(const Entry& entry, int k) {
data[0] = entry;
push_root_down(k);
}
__device__ const Entry& root() { return data[0]; }
};
template <HeapType heapType, PreferIndices preferIndices,
template <typename> class Data, typename T>
__device__ IndexedHeap<heapType, preferIndices, Data, T> make_indexed_heap(
typename Data<T>::Entry* data) {
return IndexedHeap<heapType, preferIndices, Data, T>{Data<T>{data}};
}
// heapTopK walks over [input, input+length) with `step_size` stride starting at
// `start_index`.
// It builds a top-`k` heap that is stored in `heap_entries` using `Accessor` to
// access elements in `heap_entries`. If sorted=true, the elements will be
// sorted at the end.
template <typename T, template <typename> class Data = LinearData>
__device__ void heapTopK(const T* __restrict__ input, int length, int k,
Entry<T>* __restrict__ heap_entries,
bool sorted = false, int start_index = 0,
int step_size = 1) {
assert(k <= length);
auto heap =
make_indexed_heap<HeapType::kMinHeap, PreferIndices::kHigher, Data, T>(
heap_entries);
int heap_end_index = start_index + k * step_size;
if (heap_end_index > length) {
heap_end_index = length;
}
// Initialize the min-heap.
for (int index = start_index, slot = 0; index < heap_end_index;
index += step_size, slot++) {
heap.assign(slot, {index, input[index]});
}
heap.build(k);
// Now iterate over the remaining items.
// If an item is smaller than the min element, it is not amongst the top k.
// Otherwise, replace the min element with it and push upwards.
for (int index = heap_end_index; index < length; index += step_size) {
// We prefer elements with lower indices. This is given here.
// Later elements automatically have higher indices, so can be discarded.
if (input[index] > heap.root().value) {
// This element should replace the min.
heap.replace_root({index, input[index]}, k);
}
}
// Sort if wanted.
if (sorted) {
heap.sort(k);
}
}
// mergeShards performs a top-k merge on `num_shards` many sorted streams that
// are sorted and stored in `entries` in a strided way:
// |s_1 1st|s_2 1st|...s_{num_shards} 1st|s_1 2nd|s_2 2nd|...
// The overall top k elements are written to `top_k_values` and their indices
// to top_k_indices.
// `top_k_heap` is used as temporary storage for the merge heap.
template <typename T>
__device__ void mergeShards(int num_shards, int k,
Entry<T>* __restrict__ entries,
Entry<T>* __restrict__ top_k_heap, T* top_k_values,
int* top_k_indices) {
// If k < num_shards, we can use a min-heap with k elements to get the top k
// of the sorted blocks.
// If k > num_shards, we can initialize a min-heap with the top element from
// each sorted block.
const int heap_size = k < num_shards ? k : num_shards;
// Min-heap part.
{
auto min_heap = IndexedHeap<HeapType::kMinHeap, PreferIndices::kHigher,
IndirectLinearData, T>{
IndirectLinearData<T>{top_k_heap, entries}};
// Initialize the heap as a min-heap.
for (int slot = 0; slot < heap_size; slot++) {
min_heap.assign(slot, {slot, entries[slot].value});
}
min_heap.build(heap_size);
// Now perform top k with the remaining shards (if num_shards > heap_size).
for (int shard = heap_size; shard < num_shards; shard++) {
const auto entry = entries[shard];
const auto root = min_heap.root();
if (entry.value < root.value) {
continue;
}
if (entry.value == root.value &&
entry.index > entries[root.index].index) {
continue;
}
// This element should replace the min.
min_heap.replace_root({shard, entry.value}, heap_size);
}
}
// Max-part.
{
// Turn the min-heap into a max-heap in-place.
auto max_heap = IndexedHeap<HeapType::kMaxHeap, PreferIndices::kLower,
IndirectLinearData, T>{
IndirectLinearData<T>{top_k_heap, entries}};
// Heapify into a max heap.
max_heap.build(heap_size);
// Now extract the minimum k-1 times.
// k is treated specially.
const int last_k = k - 1;
for (int rank = 0; rank < last_k; rank++) {
const Entry<T>& max_element = max_heap.root();
top_k_values[rank] = max_element.value;
int shard_index = max_element.index;
top_k_indices[rank] = entries[shard_index].index;
int next_shard_index = shard_index + num_shards;
// For rank < k-1, each top k heap still contains at least 1 element,
// so we can draw a replacement.
max_heap.replace_root({next_shard_index, entries[next_shard_index].value},
heap_size);
}
// rank == last_k.
const Entry<T>& max_element = max_heap.root();
top_k_values[last_k] = max_element.value;
int shard_index = max_element.index;
top_k_indices[last_k] = entries[shard_index].index;
}
}
extern __shared__ char shared_memory[];
template <typename T>
__global__ void TopKKernel(const T* input, int length, int k, bool sorted,
T* output, int* indices) {
const int batch_index = blockIdx.x;
const T* batch_input = input + batch_index * length;
const int thread_index = threadIdx.x;
const int thread_count = blockDim.x;
Entry<T>* shared_entries = (Entry<T>*)shared_memory;
heapTopK<T, StridedData>(batch_input, length, k, shared_entries, true,
thread_index, thread_count);
__syncthreads();
if (thread_index == 0) {
const int offset = batch_index * k;
auto batch_output = output + offset;
auto batch_indices = indices + offset;
Entry<T>* top_k_heap = shared_entries + thread_count * k;
// TODO(blackhc): Erich says: Performance can likely be improved
// significantly by having the merge be done by multiple threads rather than
// just one. ModernGPU has some nice primitives that could help with this.
mergeShards(thread_count, k, shared_entries, top_k_heap, batch_output,
batch_indices);
}
}
template <typename T>
cudaError LaunchTopKKernel(cudaStream_t stream, int num_shards, const T* input,
int batch_size, int length, int k, bool sorted,
T* output, int* indices) {
// This code assumes that k is small enough that the computation
// fits inside shared memory (hard coded to 48KB). In practice this
// means k <= 3072 for T=float/int32 and k <= 2048 for T=double/int64.
// The calculation is:
// shared_memory_size / (2 * (sizeof(int) + sizeof(T))) < k.
// Use as many shards as possible.
if (num_shards <= 0) {
constexpr auto shared_memory_size = 48 << 10; // 48 KB
const auto heap_size = k * (sizeof(int) + sizeof(T));
// shared_memory_size = (num_shards + 1) * heap_size <=>
num_shards = shared_memory_size / heap_size - 1;
if (num_shards <= 0) {
num_shards = 1;
}
auto shard_size = length / num_shards;
auto min_shard_size = 2 * k;
if (shard_size < min_shard_size) {
num_shards = length / min_shard_size;
}
if (num_shards <= 0) {
num_shards = 1;
} else if (num_shards > 1024) {
num_shards = 1024;
}
}
// We are limited by the amount of shared memory we have per block.
auto shared_memory_size = (num_shards + 1) * k * sizeof(Entry<T>);
TopKKernel<<<batch_size, num_shards, shared_memory_size, stream>>>(
input, length, k, sorted, output, indices);
return cudaGetLastError();
}
struct SegmentOffsetCreator {
SegmentOffsetCreator(int num_cols) : num_cols_(num_cols) {}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(
const Eigen::array<int, 1>& ix) const {
return ix[0] * num_cols_;
};
int num_cols_;
};
struct ColumnIndexCreator {
ColumnIndexCreator(int num_cols) : num_cols_(num_cols) {}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(
const Eigen::array<int, 1>& ix) const {
return ix[0] % num_cols_;
}
int num_cols_;
};
template <typename T>
Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows,
int num_cols, int k,
typename TTypes<T, 2>::Tensor values,
TTypes<int, 2>::Tensor indices) {
const GPUDevice& d = ctx->eigen_device<GPUDevice>();
auto stream = ctx->eigen_gpu_device().stream();
size_t temp_storage_bytes = -1;
// TODO(ebrevdo): Once cub supports iterators for the ValueT and
// segment_offsets, replace these tensors with iterators that
// directly return the correct value.
Tensor input_indices;
TF_RETURN_IF_ERROR(ctx->allocate_temp(
DT_INT32, TensorShape({num_rows, num_cols}), &input_indices));
auto input_indices_t = To32Bit(input_indices.flat<int32>());
input_indices_t.device(d) =
input_indices_t.generate(ColumnIndexCreator(num_cols));
Tensor segment_offsets;
TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_INT32, TensorShape({num_rows + 1}),
&segment_offsets));
auto segment_offsets_t = To32Bit(segment_offsets.flat<int32>());
segment_offsets_t.device(d) =
segment_offsets_t.generate(SegmentOffsetCreator(num_cols));
Tensor temp_values;
Tensor temp_indices;
T* sorted_values_ptr;
int* sorted_indices_ptr;
if (k == num_cols) {
// Doing a full sort, no intermediate values needed.
sorted_values_ptr = values.data();
sorted_indices_ptr = indices.data();
} else {
// Need to create intermediate values for sorting.
TF_RETURN_IF_ERROR(ctx->allocate_temp(
DT_INT32, TensorShape({num_rows, num_cols}), &temp_indices));
TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value,
TensorShape({num_rows, num_cols}),
&temp_values));
sorted_indices_ptr = temp_indices.flat<int32>().data();
sorted_values_ptr = temp_values.flat<T>().data();
}
auto err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
/* d_temp_storage */ nullptr,
/* temp_storage_bytes */ temp_storage_bytes,
/* d_keys_in */ input,
/* d_keys_out */ sorted_values_ptr,
/* d_values_in */ input_indices_t.data(),
/* d_values_out */ sorted_indices_ptr,
/* num_items */ num_cols * num_rows,
/* num_segments */ num_rows,
/* d_begin_offsets */ segment_offsets_t.data(),
/* d_end_offsets */ segment_offsets_t.data() + 1,
/* begin_bit */ 0,
/* end_bit */ sizeof(T) * 8,
/* stream */ stream);
if (err != cudaSuccess) {
return errors::Internal(
"TopKOp: Could not launch "
"cub::DeviceSegmentedRadixSort::SortPairsDescending to calculate "
"temp_storage_bytes, status: ",
cudaGetErrorString(err));
}
Tensor temp_storage;
TF_RETURN_IF_ERROR(ctx->allocate_temp(
DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
&temp_storage));
err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
/* d_temp_storage */ temp_storage.flat<int8>().data(),
/* temp_storage_bytes */ temp_storage_bytes,
/* d_keys_in */ input,
/* d_keys_out */ sorted_values_ptr,
/* d_values_in */ input_indices_t.data(),
/* d_values_out */ sorted_indices_ptr,
/* num_items */ num_cols * num_rows,
/* num_segments */ num_rows,
/* d_begin_offsets */ segment_offsets_t.data(),
/* d_end_offsets */ segment_offsets_t.data() + 1,
/* begin_bit */ 0,
/* end_bit */ sizeof(T) * 8,
/* stream */ stream);
if (err != cudaSuccess) {
return errors::Internal(
"TopKOp: Could not launch "
"cub::DeviceSegmentedRadixSort::SortPairsDescending to sort input, "
"temp_storage_bytes: ",
temp_storage_bytes, ", status: ", cudaGetErrorString(err));
}
if (k < num_cols) {
// Need to copy subsets of sorted_indices and sorted_outputs to
// indices and outputs.
const Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{0, 0};
const Eigen::DSizes<Eigen::DenseIndex, 2> slice_sizes{num_rows, k};
To32Bit(indices).device(d) =
To32Bit(temp_indices.matrix<int32>()).slice(slice_indices, slice_sizes);
To32Bit(values).device(d) =
To32Bit(temp_values.matrix<T>()).slice(slice_indices, slice_sizes);
}
return Status::OK();
}
} // end namespace impl
namespace functor {
template <typename T>
struct TopKFunctor<GPUDevice, T> {
static EIGEN_ALWAYS_INLINE Status
Compute(OpKernelContext* context, bool sorted, int k,
const typename TTypes<T, 2>::ConstTensor& input, const int64 num_rows,
const int64 num_cols, typename TTypes<T, 2>::Tensor values,
typename TTypes<int, 2>::Tensor indices) {
// For small k, use the heap implementation. For larger k, use
// the in-place cub sort. For k == num_cols, always use the
// in-place cub sort. The thresholds for n and k were determined
// empirically.
if (num_cols <= 1000 || k == num_cols || k >= 100) {
return impl::LaunchSortKernel(context, input.data(), num_rows, num_cols,
k, values, indices);
} else {
auto stream = context->eigen_gpu_device().stream();
auto err = impl::LaunchTopKKernel(stream, /* num_shards */ 0,
input.data(), num_rows, num_cols, k,
sorted, values.data(), indices.data());
if (err != cudaSuccess) {
return errors::Internal(
"Could not launch TopKKernel: ", cudaGetErrorString(err), ".");
} else {
return Status::OK();
}
}
}
};
} // end namespace functor
#define INSTANTIATE_TEMPLATE(type) \
template struct functor::TopKFunctor<GPUDevice, type>;
TF_CALL_GPU_NUMBER_TYPES(INSTANTIATE_TEMPLATE);
TF_CALL_INTEGRAL_TYPES(INSTANTIATE_TEMPLATE);
#undef INSTANTIATE_TEMPLATE
} // namespace tensorflow
#endif // GOOGLE_CUDA

View File

@ -835,7 +835,7 @@ tf_py_test(
],
)
tf_py_test(
cuda_py_test(
name = "topk_op_test",
size = "small",
srcs = ["topk_op_test.py"],

View File

@ -18,13 +18,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import sys
import numpy as np
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
@ -36,25 +43,103 @@ class TopKTest(test.TestCase):
k,
expected_values,
expected_indices,
sorted=True):
np_values = np.array(expected_values)
np_indices = np.array(expected_indices)
with self.test_session():
sorted=True): # pylint: disable=redefined-builtin
np_expected_values = np.array(expected_values)
np_expected_indices = np.array(expected_indices)
with self.test_session(use_gpu=True) as sess:
values_op, indices_op = nn_ops.top_k(inputs, k, sorted=sorted)
values = values_op.eval()
indices = indices_op.eval()
self.assertShapeEqual(np_values, values_op)
self.assertShapeEqual(np_indices, indices_op)
self.assertAllEqual(np_indices, indices)
self.assertAllClose(np_values, values)
values, indices = sess.run([values_op, indices_op])
self.assertShapeEqual(np_expected_values, values_op)
self.assertShapeEqual(np_expected_indices, indices_op)
if sorted:
self.assertAllClose(np_expected_values, values)
# Do some special casing of equality of indices: if indices
# are not the same, but values are floating type, ensure that
# the values are within epsilon of each other.
if not np.issubdtype(np_expected_values.dtype, np.float):
# Values are not floating point type; check indices exactly
self.assertAllEqual(np_expected_indices, indices)
else:
# Values are floating point; indices may be swapped for
# values near each other.
indices_not_equal = np_expected_indices != indices
if np.any(indices_not_equal):
values_unsure = values[indices_not_equal]
expected_values_unsure = expected_values[indices_not_equal]
self.assertAllClose(expected_values_unsure, values_unsure)
else:
np_inputs = np.array(inputs)
# Check that the indices are valid.
for result_index, src_index in np.ndenumerate(indices):
value = values[result_index]
expected_value = np_inputs[result_index[0], src_index]
np.testing.utils.assert_almost_equal(value, expected_value)
# Check that if two elements are equal, the lower-index element appears
# first.
shape = values.shape
for batch_index in range(shape[0]):
for index in range(shape[1] - 1):
if np.isclose(values[batch_index, index],
values[batch_index, index + 1]):
self.assertLess(indices[batch_index, index],
indices[batch_index, index + 1])
# Now check the results, ignoring order.
self.assertAllEqual(np.sort(np_expected_indices), np.sort(indices))
self.assertAllClose(np.sort(np_expected_values), np.sort(values))
def testTop1(self):
inputs = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.3, 0.3, 0.2]]
self._validateTopK(inputs, 1, [[0.4], [0.3]], [[3], [1]])
def testTop2(self):
inputs = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.3, 0.3, 0.2]]
self._validateTopK(inputs, 2, [[0.4, 0.3], [0.3, 0.3]], [[3, 1], [2, 1]])
inputs = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.3, 0.4, 0.2]]
self._validateTopK(inputs, 2, [[0.4, 0.3], [0.4, 0.3]], [[3, 1], [2, 1]])
def _testLargeSort(self, dtype):
b = 10
n = 5000
inputs = np.random.permutation(
np.linspace(0, 100, b * n, dtype=dtype)).reshape(b, n)
indices = np.argsort(-inputs, axis=1)
values = -np.sort(-inputs, axis=1)
self._validateTopK(inputs, n, values, indices)
def testLargeSort(self):
self._testLargeSort(np.float32)
self._testLargeSort(np.float16)
def _testLargeTopK(self, dtype):
b = 10
n = 5000
k = n - 1
inputs = np.random.permutation(
np.linspace(0, 100, b * n, dtype=dtype)).reshape(b, n)
indices = np.argsort(-inputs, axis=1)[:, :k]
values = -np.sort(-inputs, axis=1)[:, :k]
self._validateTopK(inputs, k, values, indices)
def testLargeTopK(self):
self._testLargeTopK(np.float32)
self._testLargeTopK(np.float16)
def _testMediumTopK(self, dtype):
b = 5
n = 500
k = 50
inputs = np.random.permutation(
np.linspace(0, 100, b * n, dtype=dtype)).reshape(b, n)
indices = np.argsort(-inputs, axis=1)[:, :k]
values = -np.sort(-inputs, axis=1)[:, :k]
self._validateTopK(inputs, k, values, indices)
def testMediumTopK(self):
self._testMediumTopK(np.float32)
self._testMediumTopK(np.float16)
def testTopAll(self):
inputs = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.3, 0.3, 0.2]]
@ -79,7 +164,7 @@ class TopKTest(test.TestCase):
def testKNegative(self):
inputs = [[0.1, 0.2], [0.3, 0.4]]
with self.test_session():
with self.test_session(use_gpu=True):
k = array_ops.placeholder(dtypes.int32)
values, _ = nn_ops.top_k(inputs, k)
with self.assertRaisesOpError("Need k >= 0, got -7"):
@ -92,7 +177,7 @@ class TopKTest(test.TestCase):
nn_ops.top_k(inputs, 4)
def testTopKGradients(self):
with self.test_session() as sess:
with self.test_session(use_gpu=True) as sess:
inputs = array_ops.placeholder(dtypes.int32, shape=[2, 5])
values, _ = nn_ops.top_k(inputs, 3)
grad = sess.run(
@ -102,5 +187,33 @@ class TopKTest(test.TestCase):
self.assertEqual(grad.tolist(), [[0, 0, 1, 3, 2], [0, 4, 0, 5, 6]])
class TopKBenchmark(test.Benchmark):
def benchmarkTopK(self):
for (m, n, p, use_gpu) in itertools.product(
[128],
[10, 100, 1000, 10000, 100000],
[0.001, 0.01, 0.5, 0.99, 1.0],
[False, True]):
k = int(p * n)
if k == 0:
continue
name = "m_%d_n_%d_k_%g_use_gpu_%s" % (m, n, k, use_gpu)
device = "/%s:0" % ("gpu" if use_gpu else "cpu")
with ops.Graph().as_default():
with ops.device(device):
x = random_ops.random_uniform((m, n))
v = resource_variable_ops.ResourceVariable(x)
op = nn_ops.top_k(v, k)
with session.Session() as sess:
v.initializer.run()
r = self.run_op_benchmark(sess, op, min_iters=100, name=name)
gb_processed_input = m * n / 1.0e9
throughput = gb_processed_input / r["wall_time"]
print("Benchmark: %s \t wall_time: %0.03g s \t "
"Throughput: %0.03g GB/s" % (name, r["wall_time"], throughput))
sys.stdout.flush()
if __name__ == "__main__":
test.main()

View File

@ -100,7 +100,7 @@ class WhereOpTest(test.TestCase):
class WhereBenchmark(test.Benchmark):
def benchmarkWhereCPU(self):
def benchmarkWhere(self):
for (m, n, p, use_gpu) in itertools.product(
[10],
[10, 100, 1000, 10000, 100000, 1000000],