Get TopK op working on GPU again. Extend using cub's radix sort.
1. Undo rollback of Andreas Kirsch's initial implementation.
2. Use cub's segmented radix sort instead of Andreas' heap-based impl for
   large k and for small num_cols (thresholds of k=100, n=1000 determined
   empirically).
3. Use cub's segmented radix sort if k == num_cols (this case is always faster).
4. Added benchmarks.

Benchmarks show that the GPU implementation is up to 3x slower for small k
but can be 10x faster for large num_cols and k.

Benchmarks:
Benchmark: m_128_n_10_k_5_use_gpu_False          wall_time: 0.000166 s   Throughput: 0.0077 GB/s
Benchmark: m_128_n_10_k_5_use_gpu_True           wall_time: 0.000796 s   Throughput: 0.00161 GB/s
Benchmark: m_128_n_10_k_9_use_gpu_False          wall_time: 0.00017 s    Throughput: 0.00751 GB/s
Benchmark: m_128_n_10_k_9_use_gpu_True           wall_time: 0.000796 s   Throughput: 0.00161 GB/s
Benchmark: m_128_n_10_k_10_use_gpu_False         wall_time: 0.00017 s    Throughput: 0.00753 GB/s
Benchmark: m_128_n_10_k_10_use_gpu_True          wall_time: 0.000775 s   Throughput: 0.00165 GB/s
Benchmark: m_128_n_100_k_1_use_gpu_False         wall_time: 0.000155 s   Throughput: 0.0826 GB/s
Benchmark: m_128_n_100_k_1_use_gpu_True          wall_time: 0.000796 s   Throughput: 0.0161 GB/s
Benchmark: m_128_n_100_k_50_use_gpu_False        wall_time: 0.000247 s   Throughput: 0.0519 GB/s
Benchmark: m_128_n_100_k_50_use_gpu_True         wall_time: 0.0008 s     Throughput: 0.016 GB/s
Benchmark: m_128_n_100_k_99_use_gpu_False        wall_time: 0.000261 s   Throughput: 0.049 GB/s
Benchmark: m_128_n_100_k_99_use_gpu_True         wall_time: 0.000794 s   Throughput: 0.0161 GB/s
Benchmark: m_128_n_100_k_100_use_gpu_False       wall_time: 0.000239 s   Throughput: 0.0536 GB/s
Benchmark: m_128_n_100_k_100_use_gpu_True        wall_time: 0.000777 s   Throughput: 0.0165 GB/s
Benchmark: m_128_n_1000_k_1_use_gpu_False        wall_time: 0.000324 s   Throughput: 0.395 GB/s
Benchmark: m_128_n_1000_k_1_use_gpu_True         wall_time: 0.000916 s   Throughput: 0.14 GB/s
Benchmark: m_128_n_1000_k_10_use_gpu_False       wall_time: 0.00042 s    Throughput: 0.305 GB/s
Benchmark: m_128_n_1000_k_10_use_gpu_True        wall_time: 0.000902 s   Throughput: 0.142 GB/s
Benchmark: m_128_n_1000_k_500_use_gpu_False      wall_time: 0.0011 s     Throughput: 0.116 GB/s
Benchmark: m_128_n_1000_k_500_use_gpu_True       wall_time: 0.00097 s    Throughput: 0.132 GB/s
Benchmark: m_128_n_1000_k_990_use_gpu_False      wall_time: 0.00133 s    Throughput: 0.0962 GB/s
Benchmark: m_128_n_1000_k_990_use_gpu_True       wall_time: 0.000993 s   Throughput: 0.129 GB/s
Benchmark: m_128_n_1000_k_1000_use_gpu_False     wall_time: 0.00102 s    Throughput: 0.126 GB/s
Benchmark: m_128_n_1000_k_1000_use_gpu_True      wall_time: 0.000964 s   Throughput: 0.133 GB/s
Benchmark: m_128_n_10000_k_10_use_gpu_False      wall_time: 0.002 s      Throughput: 0.64 GB/s
Benchmark: m_128_n_10000_k_10_use_gpu_True       wall_time: 0.00288 s    Throughput: 0.445 GB/s
Benchmark: m_128_n_10000_k_100_use_gpu_False     wall_time: 0.00233 s    Throughput: 0.549 GB/s
Benchmark: m_128_n_10000_k_100_use_gpu_True      wall_time: 0.00325 s    Throughput: 0.394 GB/s
Benchmark: m_128_n_10000_k_5000_use_gpu_False    wall_time: 0.0127 s     Throughput: 0.101 GB/s
Benchmark: m_128_n_10000_k_5000_use_gpu_True     wall_time: 0.00381 s    Throughput: 0.336 GB/s
Benchmark: m_128_n_10000_k_9900_use_gpu_False    wall_time: 0.015 s      Throughput: 0.0853 GB/s
Benchmark: m_128_n_10000_k_9900_use_gpu_True     wall_time: 0.00438 s    Throughput: 0.292 GB/s
Benchmark: m_128_n_10000_k_10000_use_gpu_False   wall_time: 0.0104 s     Throughput: 0.123 GB/s
Benchmark: m_128_n_10000_k_10000_use_gpu_True    wall_time: 0.00427 s    Throughput: 0.3 GB/s
Benchmark: m_128_n_100000_k_100_use_gpu_False    wall_time: 0.0148 s     Throughput: 0.865 GB/s
Benchmark: m_128_n_100000_k_100_use_gpu_True     wall_time: 0.0262 s     Throughput: 0.488 GB/s
Benchmark: m_128_n_100000_k_1000_use_gpu_False   wall_time: 0.0201 s     Throughput: 0.636 GB/s
Benchmark: m_128_n_100000_k_1000_use_gpu_True    wall_time: 0.0263 s     Throughput: 0.486 GB/s
Benchmark: m_128_n_100000_k_50000_use_gpu_False  wall_time: 0.214 s      Throughput: 0.0599 GB/s
Benchmark: m_128_n_100000_k_50000_use_gpu_True   wall_time: 0.0322 s     Throughput: 0.398 GB/s
Benchmark: m_128_n_100000_k_99000_use_gpu_False  wall_time: 0.262 s      Throughput: 0.0489 GB/s
Benchmark: m_128_n_100000_k_99000_use_gpu_True   wall_time: 0.0377 s     Throughput: 0.34 GB/s
Benchmark: m_128_n_100000_k_100000_use_gpu_False wall_time: 0.118 s      Throughput: 0.108 GB/s
Benchmark: m_128_n_100000_k_100000_use_gpu_True  wall_time: 0.0365 s     Throughput: 0.351 GB/s

END_PUBLIC

BEGIN_PUBLIC
Automated g4 rollback of changelist 157169178

PiperOrigin-RevId: 161124193
parent 0597c4189a
commit 0b5cce367c
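For illustration, a minimal snippet that exercises both GPU code paths introduced by this change, assuming a CUDA-enabled TensorFlow build of this era (the 1.x graph API); the shapes are chosen to land on either side of the empirical k=100 / n=1000 thresholds described above:

import tensorflow as tf

with tf.device("/gpu:0"):
    x = tf.random_uniform((128, 10000))
    # k=10 with num_cols=10000 stays below both thresholds, so the
    # heap-based TopKKernel handles it; k=5000 (>= 100) is dispatched to
    # cub::DeviceSegmentedRadixSort instead.
    small_k = tf.nn.top_k(x, k=10, sorted=True)
    large_k = tf.nn.top_k(x, k=5000, sorted=True)

with tf.Session() as sess:
    small_values, small_indices = sess.run(small_k)
    large_values, large_indices = sess.run(large_k)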
tensorflow/core/kernels/BUILD

@@ -2805,7 +2805,7 @@ tf_kernel_library(
 tf_kernel_library(
     name = "topk_op",
     prefix = "topk_op",
-    deps = NN_DEPS,
+    deps = NN_DEPS + if_cuda(["@cub_archive//:cub"]),
 )
 
 tf_kernel_library(
@@ -4164,6 +4164,7 @@ filegroup(
         "tile_functor.h",
         "tile_ops_cpu_impl.h",
         "tile_ops_impl.h",
+        "topk_op.h",
         "training_op_helpers.h",
         "training_ops.h",
         "transpose_functor.h",
tensorflow/core/kernels/topk_op.cc

@@ -17,6 +17,8 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
+#include "tensorflow/core/kernels/topk_op.h"
+
 #include <algorithm>
 #include <numeric>
 #include <vector>
@@ -25,6 +27,7 @@ limitations under the License.
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/gtl/top_n.h"
 #include "tensorflow/core/util/work_sharder.h"
 
@@ -33,7 +36,7 @@ namespace tensorflow {
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
 
-template <typename T>
+template <typename Device, typename T>
 class TopK : public OpKernel {
  public:
  explicit TopK(OpKernelConstruction* context) : OpKernel(context) {
@@ -82,7 +85,25 @@ class TopK : public OpKernel {
 
    auto values = values_out->flat_inner_dims<T>();
    auto indices = indices_out->flat_inner_dims<int32>();
+    Status s = functor::TopKFunctor<Device, T>::Compute(
+        context, sorted_, k, input, num_rows, num_cols, values, indices);
+    OP_REQUIRES_OK(context, s);
+  }
 
+ private:
+  int k_;
+  bool sorted_;
+};
+
+namespace functor {
+
+template <typename T>
+struct TopKFunctor<CPUDevice, T> {
+  static EIGEN_ALWAYS_INLINE Status
+  Compute(OpKernelContext* context, bool sorted, int k,
+          const typename TTypes<T, 2>::ConstTensor& input, const int64 num_rows,
+          const int64 num_cols, typename TTypes<T, 2>::Tensor values,
+          typename TTypes<int, 2>::Tensor indices) {
    const CPUDevice& d = context->eigen_device<CPUDevice>();
 
    // Special case for k == 1.
@@ -108,7 +129,7 @@ class TopK : public OpKernel {
      }
    }
 
-    return;
+    return Status::OK();
  }
 
  auto SortIndices = [&, context](int start_batch, int limit_batch) {
@@ -117,7 +138,6 @@ class TopK : public OpKernel {
      const auto comp = [input_data](const int32 a, const int32 b) {
        return input_data[a] > input_data[b];
      };
-      gtl::TopN<int32, decltype(comp)> filter(k, comp);
      // TODO(ebrevdo): For large k < num_cols, instead of using
      // TopN, it may be faster to create a temporary vector of
      // values 0..num_cols - 1 and then use std::partial_sort_copy
@@ -130,13 +150,14 @@ class TopK : public OpKernel {
        std::sort(&indices(b, 0), &indices(b, k), comp);
      } else {
        // Use the TopN heap object to sort.
+        gtl::TopN<int32, decltype(comp)> filter(k, comp);
        filter.reserve(num_cols);
        for (int32 c = 0; c < num_cols; ++c) {
          filter.push(c);
        }
 
        int32 i = 0;
-        if (sorted_) {
+        if (sorted) {
          std::unique_ptr<std::vector<int32>> top_k(filter.Extract());
          for (auto top_k_it = top_k->begin(); top_k_it != top_k->end();
               ++top_k_it, ++i) {
@@ -158,35 +179,75 @@ class TopK : public OpKernel {
 
    // Guesstimate of cost; 4*N*log(K) where N == num_cols.
    // If K == N, assume the cost is N*log(K + 1).
-    const int64 cmp_cost = 3 * Eigen::TensorOpCost::AddCost<int32>() +
+    const double cmp_cost = 3 * Eigen::TensorOpCost::AddCost<int32>() +
                           Eigen::TensorOpCost::AddCost<T>();
-    const int64 base_cost =
+    const double base_cost =
        cmp_cost *
-        static_cast<int64>(num_cols *
+        static_cast<double>(num_cols *
                           Eigen::numext::log2(static_cast<float>(k + 1)));
-    const int64 sort_cost = (k == num_cols) ? base_cost : 4 * base_cost;
-    const int64 copy_cost = 2 * k * Eigen::TensorOpCost::AddCost<T>();
-    const int64 total_cost = sort_cost + copy_cost;
+    const double sort_cost = (k == num_cols) ? base_cost : 4 * base_cost;
+    const double copy_cost = 2 * k * Eigen::TensorOpCost::AddCost<T>();
+    const double total_cost = sort_cost + copy_cost;
+    const int64 final_cost = (total_cost >= static_cast<double>(kint64max))
+                                 ? kint64max
+                                 : static_cast<int64>(total_cost);
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, num_rows,
-          total_cost, SortIndices);
-  }
+          final_cost, SortIndices);
 
- private:
-  int k_;
-  bool sorted_;
+    return Status::OK();
+  }
 };
 
-#define REGISTER_KERNELS_NAME(name, type) \
-  REGISTER_KERNEL_BUILDER( \
-      Name(#name).Device(DEVICE_CPU).TypeConstraint<type>("T"), TopK<type>)
+}  // namespace functor
+
+#define REGISTER_KERNELS_NAME(name, type)                       \
+  REGISTER_KERNEL_BUILDER(                                      \
+      Name(#name).Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+      TopK<CPUDevice, type>)
 
 #define REGISTER_KERNELS(type)       \
  REGISTER_KERNELS_NAME(TopK, type); \
  REGISTER_KERNELS_NAME(TopKV2, type)
 
 TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
-#undef REGISTER_KERNELS_TO_NAME
+#undef REGISTER_KERNELS_NAME
 #undef REGISTER_KERNELS
 
-}  // namespace tensorflow
+#ifdef GOOGLE_CUDA
+
+namespace functor {
+#define DECLARE_GPU_SPEC(T)                                                  \
+  template <>                                                                \
+  Status TopKFunctor<GPUDevice, T>::Compute(                                 \
+      OpKernelContext* context, bool sorted, int k,                          \
+      const typename TTypes<T, 2>::ConstTensor& input, const int64 num_rows, \
+      const int64 num_cols, typename TTypes<T, 2>::Tensor values,            \
+      typename TTypes<int, 2>::Tensor indices);                              \
+  extern template struct functor::TopKFunctor<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
+TF_CALL_INTEGRAL_TYPES(DECLARE_GPU_SPEC);
+
+#undef DECLARE_GPU_SPEC
+
+}  // namespace functor
+
+#define REGISTER_KERNELS(type)                                   \
+  REGISTER_KERNEL_BUILDER(                                       \
+      Name("TopK").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+      TopK<GPUDevice, type>)                                     \
+  REGISTER_KERNEL_BUILDER(Name("TopKV2")                         \
+                              .Device(DEVICE_GPU)                \
+                              .TypeConstraint<type>("T")         \
+                              .HostMemory("k"),                  \
+                          TopK<GPUDevice, type>)
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
+TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS);
+
+#undef REGISTER_KERNELS
+
+#endif  // end GOOGLE_CUDA
+
+}  // end namespace tensorflow
tensorflow/core/kernels/topk_op.h (new file, 42 lines)

/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_TOPK_OP_H_
#define TENSORFLOW_TOPK_OP_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

namespace functor {

template <typename Device, typename T>
struct TopKFunctor {
  static Status Compute(OpKernelContext* context, bool sorted, int k,
                        const typename TTypes<T, 2>::ConstTensor& input,
                        const int64 num_rows, const int64 num_cols,
                        typename TTypes<T, 2>::Tensor values,
                        typename TTypes<int, 2>::Tensor indices);
};

}  // end namespace functor

}  // end namespace tensorflow

#endif  // TENSORFLOW_TOPK_OP_H_
tensorflow/core/kernels/topk_op_gpu.cu.cc (new file, 573 lines)

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include <cmath>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "external/cub_archive/cub/device/device_segmented_radix_sort.cuh"
#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/topk_op.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

// Required for sorting Eigen::half
namespace cub {
template <>
struct NumericTraits<Eigen::half>
    : BaseTraits<FLOATING_POINT, true, false, unsigned short int, Eigen::half> {
};
}  // namespace cub

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace impl {

enum class HeapType { kMinHeap, kMaxHeap };
enum class PreferIndices { kLower, kHigher };

template <typename T>
struct Entry {
  int index;
  T value;

  // Test-only.
  static bool greater(const Entry<T>& a, const Entry<T>& b) {
    if (a.value == b.value) {
      return a.index < b.index;
    }
    return a.value > b.value;
  }
};

template <typename T>
struct LinearData {
  typedef impl::Entry<T> Entry;

  __device__ Entry& operator[](std::size_t index) const { return data[index]; }

  __device__ int get_index(int i) const { return data[i].index; }
  __device__ T get_value(int i) const { return data[i].value; }

  Entry* const data;
};

template <typename T>
struct IndirectLinearData {
  typedef impl::Entry<T> Entry;

  __device__ Entry& operator[](std::size_t index) const { return data[index]; }

  __device__ int get_index(int i) const {
    return backing_data[data[i].index].index;
  }
  __device__ T get_value(int i) const { return data[i].value; }

  Entry* const data;
  Entry* const backing_data;
};

#if GOOGLE_CUDA
template <typename T>
struct StridedData {
  typedef impl::Entry<T> Entry;

  __device__ Entry& operator[](std::size_t index) const {
    return data[index * blockDim.x + threadIdx.x];
  }

  __device__ int get_index(int i) const { return (*this)[i].index; }
  __device__ T get_value(int i) const { return (*this)[i].value; }

  Entry* const data;
};
#endif

// A heap of Entry<T> that can either work as a min-heap or as a max-heap.
template <HeapType heapType, PreferIndices preferIndices,
          template <typename> class Data, typename T>
struct IndexedHeap {
  typedef typename Data<T>::Entry Entry;
  const Data<T> data;

  __device__ bool is_above(int left, int right) {
    T left_value = data.get_value(left);
    T right_value = data.get_value(right);
    if (left_value == right_value) {
      if (preferIndices == PreferIndices::kLower) {
        return data.get_index(left) < data.get_index(right);
      } else {
        return data.get_index(left) > data.get_index(right);
      }
    }
    if (heapType == HeapType::kMinHeap) {
      return left_value < right_value;
    } else {
      return left_value > right_value;
    }
  }

  __device__ void assign(int i, const Entry& entry) { data[i] = entry; }

  __device__ void push_up(int i) {
    int child = i;
    int parent;
    for (; child > 0; child = parent) {
      parent = (child - 1) / 2;
      if (!is_above(child, parent)) {
        // Heap property satisfied.
        break;
      }
      swap(child, parent);
    }
  }

  __device__ void swap(int a, int b) {
    auto tmp = data[b];
    data[b] = data[a];
    data[a] = tmp;
  }

  __device__ void push_root_down(int k) { push_down(0, k); }

  // MAX-HEAPIFY in Cormen
  __device__ void push_down(int node, int k) {
    while (true) {
      const int left = 2 * node + 1;
      const int right = left + 1;
      int smallest = node;
      if (left < k && is_above(left, smallest)) {
        smallest = left;
      }
      if (right < k && is_above(right, smallest)) {
        smallest = right;
      }
      if (smallest == node) {
        break;
      }
      swap(smallest, node);
      node = smallest;
    }
  }

  // BUILD-MAX-HEAPIFY in Cormen
  __device__ void build(int k) {
    for (int node = (k - 1) / 2; node >= 0; node--) {
      push_down(node, k);
    }
  }

  // HEAP-EXTRACT-MAX in Cormen
  __device__ void remove_root(int k) {
    data[0] = data[k - 1];
    push_root_down(k - 1);
  }

  // in-place HEAPSORT in Cormen
  // This method destroys the heap property.
  __device__ void sort(int k) {
    for (int slot = k - 1; slot > 0; slot--) {
      // This is like remove_root but we insert the element at the end.
      swap(slot, 0);
      // Heap is now an element smaller.
      push_root_down(/*k=*/slot);
    }
  }

  __device__ void replace_root(const Entry& entry, int k) {
    data[0] = entry;
    push_root_down(k);
  }

  __device__ const Entry& root() { return data[0]; }
};

template <HeapType heapType, PreferIndices preferIndices,
          template <typename> class Data, typename T>
__device__ IndexedHeap<heapType, preferIndices, Data, T> make_indexed_heap(
    typename Data<T>::Entry* data) {
  return IndexedHeap<heapType, preferIndices, Data, T>{Data<T>{data}};
}

// heapTopK walks over [input, input+length) with `step_size` stride starting at
// `start_index`.
// It builds a top-`k` heap that is stored in `heap_entries` using `Accessor` to
// access elements in `heap_entries`. If sorted=true, the elements will be
// sorted at the end.
template <typename T, template <typename> class Data = LinearData>
__device__ void heapTopK(const T* __restrict__ input, int length, int k,
                         Entry<T>* __restrict__ heap_entries,
                         bool sorted = false, int start_index = 0,
                         int step_size = 1) {
  assert(k <= length);

  auto heap =
      make_indexed_heap<HeapType::kMinHeap, PreferIndices::kHigher, Data, T>(
          heap_entries);

  int heap_end_index = start_index + k * step_size;
  if (heap_end_index > length) {
    heap_end_index = length;
  }
  // Initialize the min-heap.
  for (int index = start_index, slot = 0; index < heap_end_index;
       index += step_size, slot++) {
    heap.assign(slot, {index, input[index]});
  }

  heap.build(k);

  // Now iterate over the remaining items.
  // If an item is smaller than the min element, it is not amongst the top k.
  // Otherwise, replace the min element with it and push upwards.
  for (int index = heap_end_index; index < length; index += step_size) {
    // We prefer elements with lower indices. This is given here.
    // Later elements automatically have higher indices, so can be discarded.
    if (input[index] > heap.root().value) {
      // This element should replace the min.
      heap.replace_root({index, input[index]}, k);
    }
  }

  // Sort if wanted.
  if (sorted) {
    heap.sort(k);
  }
}

// mergeShards performs a top-k merge on `num_shards` many sorted streams that
// are sorted and stored in `entries` in a strided way:
// |s_1 1st|s_2 1st|...s_{num_shards} 1st|s_1 2nd|s_2 2nd|...
// The overall top k elements are written to `top_k_values` and their indices
// to top_k_indices.
// `top_k_heap` is used as temporary storage for the merge heap.
template <typename T>
__device__ void mergeShards(int num_shards, int k,
                            Entry<T>* __restrict__ entries,
                            Entry<T>* __restrict__ top_k_heap, T* top_k_values,
                            int* top_k_indices) {
  // If k < num_shards, we can use a min-heap with k elements to get the top k
  // of the sorted blocks.
  // If k > num_shards, we can initialize a min-heap with the top element from
  // each sorted block.
  const int heap_size = k < num_shards ? k : num_shards;

  // Min-heap part.
  {
    auto min_heap = IndexedHeap<HeapType::kMinHeap, PreferIndices::kHigher,
                                IndirectLinearData, T>{
        IndirectLinearData<T>{top_k_heap, entries}};
    // Initialize the heap as a min-heap.
    for (int slot = 0; slot < heap_size; slot++) {
      min_heap.assign(slot, {slot, entries[slot].value});
    }
    min_heap.build(heap_size);

    // Now perform top k with the remaining shards (if num_shards > heap_size).
    for (int shard = heap_size; shard < num_shards; shard++) {
      const auto entry = entries[shard];
      const auto root = min_heap.root();
      if (entry.value < root.value) {
        continue;
      }
      if (entry.value == root.value &&
          entry.index > entries[root.index].index) {
        continue;
      }
      // This element should replace the min.
      min_heap.replace_root({shard, entry.value}, heap_size);
    }
  }

  // Max-part.
  {
    // Turn the min-heap into a max-heap in-place.
    auto max_heap = IndexedHeap<HeapType::kMaxHeap, PreferIndices::kLower,
                                IndirectLinearData, T>{
        IndirectLinearData<T>{top_k_heap, entries}};
    // Heapify into a max heap.
    max_heap.build(heap_size);

    // Now extract the minimum k-1 times.
    // k is treated specially.
    const int last_k = k - 1;
    for (int rank = 0; rank < last_k; rank++) {
      const Entry<T>& max_element = max_heap.root();
      top_k_values[rank] = max_element.value;
      int shard_index = max_element.index;
      top_k_indices[rank] = entries[shard_index].index;
      int next_shard_index = shard_index + num_shards;
      // For rank < k-1, each top k heap still contains at least 1 element,
      // so we can draw a replacement.
      max_heap.replace_root({next_shard_index, entries[next_shard_index].value},
                            heap_size);
    }

    // rank == last_k.
    const Entry<T>& max_element = max_heap.root();
    top_k_values[last_k] = max_element.value;
    int shard_index = max_element.index;
    top_k_indices[last_k] = entries[shard_index].index;
  }
}

extern __shared__ char shared_memory[];

template <typename T>
__global__ void TopKKernel(const T* input, int length, int k, bool sorted,
                           T* output, int* indices) {
  const int batch_index = blockIdx.x;
  const T* batch_input = input + batch_index * length;

  const int thread_index = threadIdx.x;
  const int thread_count = blockDim.x;

  Entry<T>* shared_entries = (Entry<T>*)shared_memory;

  heapTopK<T, StridedData>(batch_input, length, k, shared_entries, true,
                           thread_index, thread_count);

  __syncthreads();
  if (thread_index == 0) {
    const int offset = batch_index * k;
    auto batch_output = output + offset;
    auto batch_indices = indices + offset;
    Entry<T>* top_k_heap = shared_entries + thread_count * k;

    // TODO(blackhc): Erich says: Performance can likely be improved
    // significantly by having the merge be done by multiple threads rather than
    // just one. ModernGPU has some nice primitives that could help with this.
    mergeShards(thread_count, k, shared_entries, top_k_heap, batch_output,
                batch_indices);
  }
}

template <typename T>
cudaError LaunchTopKKernel(cudaStream_t stream, int num_shards, const T* input,
                           int batch_size, int length, int k, bool sorted,
                           T* output, int* indices) {
  // This code assumes that k is small enough that the computation
  // fits inside shared memory (hard coded to 48KB). In practice this
  // means k <= 3072 for T=float/int32 and k <= 2048 for T=double/int64.
  // The calculation is:
  // shared_memory_size / (2 * (sizeof(int) + sizeof(T))) < k.

  // Use as many shards as possible.
  if (num_shards <= 0) {
    constexpr auto shared_memory_size = 48 << 10;  // 48 KB
    const auto heap_size = k * (sizeof(int) + sizeof(T));
    // shared_memory_size = (num_shards + 1) * heap_size <=>
    num_shards = shared_memory_size / heap_size - 1;
    if (num_shards <= 0) {
      num_shards = 1;
    }
    auto shard_size = length / num_shards;
    auto min_shard_size = 2 * k;
    if (shard_size < min_shard_size) {
      num_shards = length / min_shard_size;
    }
    if (num_shards <= 0) {
      num_shards = 1;
    } else if (num_shards > 1024) {
      num_shards = 1024;
    }
  }
  // We are limited by the amount of shared memory we have per block.
  auto shared_memory_size = (num_shards + 1) * k * sizeof(Entry<T>);

  TopKKernel<<<batch_size, num_shards, shared_memory_size, stream>>>(
      input, length, k, sorted, output, indices);
  return cudaGetLastError();
}

struct SegmentOffsetCreator {
  SegmentOffsetCreator(int num_cols) : num_cols_(num_cols) {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(
      const Eigen::array<int, 1>& ix) const {
    return ix[0] * num_cols_;
  };
  int num_cols_;
};

struct ColumnIndexCreator {
  ColumnIndexCreator(int num_cols) : num_cols_(num_cols) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(
      const Eigen::array<int, 1>& ix) const {
    return ix[0] % num_cols_;
  }

  int num_cols_;
};

template <typename T>
Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows,
                        int num_cols, int k,
                        typename TTypes<T, 2>::Tensor values,
                        TTypes<int, 2>::Tensor indices) {
  const GPUDevice& d = ctx->eigen_device<GPUDevice>();
  auto stream = ctx->eigen_gpu_device().stream();
  size_t temp_storage_bytes = -1;

  // TODO(ebrevdo): Once cub supports iterators for the ValueT and
  // segment_offsets, replace these tensors with iterators that
  // directly return the correct value.
  Tensor input_indices;
  TF_RETURN_IF_ERROR(ctx->allocate_temp(
      DT_INT32, TensorShape({num_rows, num_cols}), &input_indices));
  auto input_indices_t = To32Bit(input_indices.flat<int32>());
  input_indices_t.device(d) =
      input_indices_t.generate(ColumnIndexCreator(num_cols));

  Tensor segment_offsets;
  TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_INT32, TensorShape({num_rows + 1}),
                                        &segment_offsets));
  auto segment_offsets_t = To32Bit(segment_offsets.flat<int32>());
  segment_offsets_t.device(d) =
      segment_offsets_t.generate(SegmentOffsetCreator(num_cols));

  Tensor temp_values;
  Tensor temp_indices;
  T* sorted_values_ptr;
  int* sorted_indices_ptr;
  if (k == num_cols) {
    // Doing a full sort, no intermediate values needed.
    sorted_values_ptr = values.data();
    sorted_indices_ptr = indices.data();
  } else {
    // Need to create intermediate values for sorting.
    TF_RETURN_IF_ERROR(ctx->allocate_temp(
        DT_INT32, TensorShape({num_rows, num_cols}), &temp_indices));
    TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value,
                                          TensorShape({num_rows, num_cols}),
                                          &temp_values));
    sorted_indices_ptr = temp_indices.flat<int32>().data();
    sorted_values_ptr = temp_values.flat<T>().data();
  }

  auto err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
      /* d_temp_storage */ nullptr,
      /* temp_storage_bytes */ temp_storage_bytes,
      /* d_keys_in */ input,
      /* d_keys_out */ sorted_values_ptr,
      /* d_values_in */ input_indices_t.data(),
      /* d_values_out */ sorted_indices_ptr,
      /* num_items */ num_cols * num_rows,
      /* num_segments */ num_rows,
      /* d_begin_offsets */ segment_offsets_t.data(),
      /* d_end_offsets */ segment_offsets_t.data() + 1,
      /* begin_bit */ 0,
      /* end_bit */ sizeof(T) * 8,
      /* stream */ stream);
  if (err != cudaSuccess) {
    return errors::Internal(
        "TopKOp: Could not launch "
        "cub::DeviceSegmentedRadixSort::SortPairsDescending to calculate "
        "temp_storage_bytes, status: ",
        cudaGetErrorString(err));
  }
  Tensor temp_storage;
  TF_RETURN_IF_ERROR(ctx->allocate_temp(
      DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
      &temp_storage));
  err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
      /* d_temp_storage */ temp_storage.flat<int8>().data(),
      /* temp_storage_bytes */ temp_storage_bytes,
      /* d_keys_in */ input,
      /* d_keys_out */ sorted_values_ptr,
      /* d_values_in */ input_indices_t.data(),
      /* d_values_out */ sorted_indices_ptr,
      /* num_items */ num_cols * num_rows,
      /* num_segments */ num_rows,
      /* d_begin_offsets */ segment_offsets_t.data(),
      /* d_end_offsets */ segment_offsets_t.data() + 1,
      /* begin_bit */ 0,
      /* end_bit */ sizeof(T) * 8,
      /* stream */ stream);
  if (err != cudaSuccess) {
    return errors::Internal(
        "TopKOp: Could not launch "
        "cub::DeviceSegmentedRadixSort::SortPairsDescending to sort input, "
        "temp_storage_bytes: ",
        temp_storage_bytes, ", status: ", cudaGetErrorString(err));
  }
  if (k < num_cols) {
    // Need to copy subsets of sorted_indices and sorted_outputs to
    // indices and outputs.
    const Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{0, 0};
    const Eigen::DSizes<Eigen::DenseIndex, 2> slice_sizes{num_rows, k};
    To32Bit(indices).device(d) =
        To32Bit(temp_indices.matrix<int32>()).slice(slice_indices, slice_sizes);
    To32Bit(values).device(d) =
        To32Bit(temp_values.matrix<T>()).slice(slice_indices, slice_sizes);
  }
  return Status::OK();
}

}  // end namespace impl

namespace functor {

template <typename T>
struct TopKFunctor<GPUDevice, T> {
  static EIGEN_ALWAYS_INLINE Status
  Compute(OpKernelContext* context, bool sorted, int k,
          const typename TTypes<T, 2>::ConstTensor& input, const int64 num_rows,
          const int64 num_cols, typename TTypes<T, 2>::Tensor values,
          typename TTypes<int, 2>::Tensor indices) {
    // For small k, use the heap implementation. For larger k, use
    // the in-place cub sort. For k == num_cols, always use the
    // in-place cub sort. The thresholds for n and k were determined
    // empirically.
    if (num_cols <= 1000 || k == num_cols || k >= 100) {
      return impl::LaunchSortKernel(context, input.data(), num_rows, num_cols,
                                    k, values, indices);
    } else {
      auto stream = context->eigen_gpu_device().stream();
      auto err = impl::LaunchTopKKernel(stream, /* num_shards */ 0,
                                        input.data(), num_rows, num_cols, k,
                                        sorted, values.data(), indices.data());
      if (err != cudaSuccess) {
        return errors::Internal(
            "Could not launch TopKKernel: ", cudaGetErrorString(err), ".");
      } else {
        return Status::OK();
      }
    }
  }
};

}  // end namespace functor

#define INSTANTIATE_TEMPLATE(type) \
  template struct functor::TopKFunctor<GPUDevice, type>;

TF_CALL_GPU_NUMBER_TYPES(INSTANTIATE_TEMPLATE);
TF_CALL_INTEGRAL_TYPES(INSTANTIATE_TEMPLATE);
#undef INSTANTIATE_TEMPLATE

}  // namespace tensorflow

#endif  // GOOGLE_CUDA
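As a sketch (not part of the change itself), the shard-count heuristic in LaunchTopKKernel above can be reproduced outside CUDA to see how many threads (shards) per block a given problem gets; pick_num_shards is a hypothetical helper name, and an Entry<float> is taken to be 8 bytes (an int index plus a float value), matching the hard-coded 48KB shared-memory budget:

def pick_num_shards(length, k, entry_bytes=8):
    # (num_shards + 1) heaps of k entries must fit in 48KB of shared memory.
    shared_memory_size = 48 << 10
    heap_size = k * entry_bytes
    num_shards = max(shared_memory_size // heap_size - 1, 1)
    # Each shard should cover at least 2*k input elements.
    if length // num_shards < 2 * k:
        num_shards = length // (2 * k)
    return min(max(num_shards, 1), 1024)

print(pick_num_shards(length=10000, k=10))  # -> 500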
tensorflow/python/kernel_tests/BUILD

@@ -835,7 +835,7 @@ tf_py_test(
     ],
 )
 
-tf_py_test(
+cuda_py_test(
     name = "topk_op_test",
     size = "small",
     srcs = ["topk_op_test.py"],
tensorflow/python/kernel_tests/topk_op_test.py

@@ -18,13 +18,20 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import itertools
+import sys
+
 import numpy as np
 
+from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import resource_variable_ops
 import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
 from tensorflow.python.platform import test
 
@@ -36,25 +43,103 @@ class TopKTest(test.TestCase):
                     k,
                     expected_values,
                     expected_indices,
-                    sorted=True):
-    np_values = np.array(expected_values)
-    np_indices = np.array(expected_indices)
-    with self.test_session():
+                    sorted=True):  # pylint: disable=redefined-builtin
+    np_expected_values = np.array(expected_values)
+    np_expected_indices = np.array(expected_indices)
+    with self.test_session(use_gpu=True) as sess:
       values_op, indices_op = nn_ops.top_k(inputs, k, sorted=sorted)
-      values = values_op.eval()
-      indices = indices_op.eval()
-      self.assertShapeEqual(np_values, values_op)
-      self.assertShapeEqual(np_indices, indices_op)
-      self.assertAllEqual(np_indices, indices)
-      self.assertAllClose(np_values, values)
+      values, indices = sess.run([values_op, indices_op])
+
+      self.assertShapeEqual(np_expected_values, values_op)
+      self.assertShapeEqual(np_expected_indices, indices_op)
+
+      if sorted:
+        self.assertAllClose(np_expected_values, values)
+        # Do some special casing of equality of indices: if indices
+        # are not the same, but values are floating type, ensure that
+        # the values are within epsilon of each other.
+        if not np.issubdtype(np_expected_values.dtype, np.float):
+          # Values are not floating point type; check indices exactly
+          self.assertAllEqual(np_expected_indices, indices)
+        else:
+          # Values are floating point; indices may be swapped for
+          # values near each other.
+          indices_not_equal = np_expected_indices != indices
+          if np.any(indices_not_equal):
+            values_unsure = values[indices_not_equal]
+            expected_values_unsure = expected_values[indices_not_equal]
+            self.assertAllClose(expected_values_unsure, values_unsure)
+      else:
+        np_inputs = np.array(inputs)
+
+        # Check that the indices are valid.
+        for result_index, src_index in np.ndenumerate(indices):
+          value = values[result_index]
+          expected_value = np_inputs[result_index[0], src_index]
+          np.testing.utils.assert_almost_equal(value, expected_value)
+
+        # Check that if two elements are equal, the lower-index element appears
+        # first.
+        shape = values.shape
+        for batch_index in range(shape[0]):
+          for index in range(shape[1] - 1):
+            if np.isclose(values[batch_index, index],
+                          values[batch_index, index + 1]):
+              self.assertLess(indices[batch_index, index],
+                              indices[batch_index, index + 1])
+
+        # Now check the results, ignoring order.
+        self.assertAllEqual(np.sort(np_expected_indices), np.sort(indices))
+        self.assertAllClose(np.sort(np_expected_values), np.sort(values))
 
   def testTop1(self):
     inputs = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.3, 0.3, 0.2]]
     self._validateTopK(inputs, 1, [[0.4], [0.3]], [[3], [1]])
 
   def testTop2(self):
-    inputs = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.3, 0.3, 0.2]]
-    self._validateTopK(inputs, 2, [[0.4, 0.3], [0.3, 0.3]], [[3, 1], [2, 1]])
+    inputs = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.3, 0.4, 0.2]]
+    self._validateTopK(inputs, 2, [[0.4, 0.3], [0.4, 0.3]], [[3, 1], [2, 1]])
+
+  def _testLargeSort(self, dtype):
+    b = 10
+    n = 5000
+    inputs = np.random.permutation(
+        np.linspace(0, 100, b * n, dtype=dtype)).reshape(b, n)
+    indices = np.argsort(-inputs, axis=1)
+    values = -np.sort(-inputs, axis=1)
+    self._validateTopK(inputs, n, values, indices)
+
+  def testLargeSort(self):
+    self._testLargeSort(np.float32)
+    self._testLargeSort(np.float16)
+
+  def _testLargeTopK(self, dtype):
+    b = 10
+    n = 5000
+    k = n - 1
+    inputs = np.random.permutation(
+        np.linspace(0, 100, b * n, dtype=dtype)).reshape(b, n)
+    indices = np.argsort(-inputs, axis=1)[:, :k]
+    values = -np.sort(-inputs, axis=1)[:, :k]
+    self._validateTopK(inputs, k, values, indices)
+
+  def testLargeTopK(self):
+    self._testLargeTopK(np.float32)
+    self._testLargeTopK(np.float16)
+
+  def _testMediumTopK(self, dtype):
+    b = 5
+    n = 500
+    k = 50
+    inputs = np.random.permutation(
+        np.linspace(0, 100, b * n, dtype=dtype)).reshape(b, n)
+    indices = np.argsort(-inputs, axis=1)[:, :k]
+    values = -np.sort(-inputs, axis=1)[:, :k]
+    self._validateTopK(inputs, k, values, indices)
+
+  def testMediumTopK(self):
+    self._testMediumTopK(np.float32)
+    self._testMediumTopK(np.float16)
 
   def testTopAll(self):
     inputs = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.3, 0.3, 0.2]]
@@ -79,7 +164,7 @@ class TopKTest(test.TestCase):
 
   def testKNegative(self):
     inputs = [[0.1, 0.2], [0.3, 0.4]]
-    with self.test_session():
+    with self.test_session(use_gpu=True):
      k = array_ops.placeholder(dtypes.int32)
      values, _ = nn_ops.top_k(inputs, k)
      with self.assertRaisesOpError("Need k >= 0, got -7"):
@@ -92,7 +177,7 @@ class TopKTest(test.TestCase):
      nn_ops.top_k(inputs, 4)
 
   def testTopKGradients(self):
-    with self.test_session() as sess:
+    with self.test_session(use_gpu=True) as sess:
      inputs = array_ops.placeholder(dtypes.int32, shape=[2, 5])
      values, _ = nn_ops.top_k(inputs, 3)
      grad = sess.run(
@@ -102,5 +187,33 @@ class TopKTest(test.TestCase):
    self.assertEqual(grad.tolist(), [[0, 0, 1, 3, 2], [0, 4, 0, 5, 6]])
 
 
+class TopKBenchmark(test.Benchmark):
+
+  def benchmarkTopK(self):
+    for (m, n, p, use_gpu) in itertools.product(
+        [128],
+        [10, 100, 1000, 10000, 100000],
+        [0.001, 0.01, 0.5, 0.99, 1.0],
+        [False, True]):
+      k = int(p * n)
+      if k == 0:
+        continue
+      name = "m_%d_n_%d_k_%g_use_gpu_%s" % (m, n, k, use_gpu)
+      device = "/%s:0" % ("gpu" if use_gpu else "cpu")
+      with ops.Graph().as_default():
+        with ops.device(device):
+          x = random_ops.random_uniform((m, n))
+          v = resource_variable_ops.ResourceVariable(x)
+          op = nn_ops.top_k(v, k)
+        with session.Session() as sess:
+          v.initializer.run()
+          r = self.run_op_benchmark(sess, op, min_iters=100, name=name)
+          gb_processed_input = m * n / 1.0e9
+          throughput = gb_processed_input / r["wall_time"]
+          print("Benchmark: %s \t wall_time: %0.03g s \t "
+                "Throughput: %0.03g GB/s" % (name, r["wall_time"], throughput))
+          sys.stdout.flush()
+
+
 if __name__ == "__main__":
   test.main()
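The TopKBenchmark added above is what produced the numbers quoted in the commit message. Assuming the stock TensorFlow benchmark runner, which selects benchmarks matching a --benchmarks regex when the test file is executed directly, an invocation would look like:

python topk_op_test.py --benchmarks=TopKBenchmark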
tensorflow/python/kernel_tests/where_op_test.py

@@ -100,7 +100,7 @@ class WhereOpTest(test.TestCase):
 
 class WhereBenchmark(test.Benchmark):
 
-  def benchmarkWhereCPU(self):
+  def benchmarkWhere(self):
    for (m, n, p, use_gpu) in itertools.product(
        [10],
        [10, 100, 1000, 10000, 100000, 1000000],