Merge pull request #39429 from ekuznetsov139:google-upstream-topk

PiperOrigin-RevId: 337792038
Change-Id: I76a454570e8614558d0723d69e25a222cb4ca84e
TensorFlower Gardener 2020-10-18 22:12:36 -07:00
commit f99034972d
14 changed files with 60 additions and 37 deletions

tensorflow/core/kernels/gpu_prim.h

@@ -35,13 +35,15 @@ namespace gpuprim = ::cub;
 #include "rocm/include/hipcub/hipcub.hpp"
 namespace gpuprim = ::hipcub;
 // Required for sorting Eigen::half
 namespace rocprim {
 namespace detail {
 template <>
 struct radix_key_codec_base<Eigen::half>
-    : radix_key_codec_floating<Eigen::half, unsigned short> {};
+    : radix_key_codec_floating<Eigen::half, uint16_t> {};
 }; // namespace detail
 }; // namespace rocprim
-#endif // GOOGLE_CUDA
+#endif // TENSORFLOW_USE_ROCM
 #endif // TENSORFLOW_CORE_KERNELS_GPU_PRIM_H_
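Background for the hunk above: hipcub's radix sort can only order key types for which rocprim supplies a codec, so gpu_prim.h specializes radix_key_codec_base for Eigen::half as a 16-bit floating-point key (now spelled uint16_t rather than unsigned short). Below is a minimal sketch of what that specialization enables through the gpuprim alias; the helper name, IndexT, and the buffer parameters are hypothetical, and this is not TensorFlow's actual call site.

    // Sketch only: with the codec specialization visible, Eigen::half keys can
    // be radix-sorted through gpuprim just like float or double keys.
    #include <cstddef>
    #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
    #include "tensorflow/core/kernels/gpu_prim.h"

    template <typename IndexT>
    auto SortHalfKeysDescending(const Eigen::half* d_keys_in,
                                Eigen::half* d_keys_out,
                                const IndexT* d_values_in, IndexT* d_values_out,
                                int num_items, void* d_temp_storage,
                                std::size_t& temp_storage_bytes) {
      // A first call with d_temp_storage == nullptr only reports the scratch
      // size in temp_storage_bytes; a second call with real scratch sorts.
      return gpuprim::DeviceRadixSort::SortPairsDescending(
          d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out,
          d_values_in, d_values_out, num_items);
    }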

tensorflow/core/kernels/in_topk_op_test.cc

@@ -76,9 +76,9 @@ static Graph* InTopK(int num_targets, int num_classes, T top_k) {
 BM_InTopK(int64, 64, 1000, 10, cpu);
 BM_InTopK(int64, 64, 10000, 10, cpu);
-#ifdef GOOGLE_CUDA
+#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM)
 BM_InTopK(int64, 64, 1000, 10, gpu);
 BM_InTopK(int64, 64, 10000, 10, gpu);
-#endif // GOOGLE_CUDA
+#endif // defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM)
 } // namespace tensorflow

tensorflow/core/kernels/topk_op.cc

@@ -244,7 +244,7 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
 #undef REGISTER_KERNELS_NAME
 #undef REGISTER_KERNELS
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 namespace functor {
 #define DECLARE_GPU_SPEC(T) \
@@ -277,6 +277,6 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
 TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS);
 #undef REGISTER_KERNELS
-#endif // end GOOGLE_CUDA
+#endif // end GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 } // end namespace tensorflow
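Widening this guard means the GPU functor declarations (DECLARE_GPU_SPEC) and the DEVICE_GPU kernel registrations (REGISTER_KERNELS) in topk_op.cc now compile into ROCm builds as well as CUDA builds. A rough, hand-expanded illustration of what the guarded block amounts to for one type follows; the macro bodies themselves are unchanged by this commit, so treat the exact builder chain as an approximation rather than the literal expansion.

    // Approximate expansion of DECLARE_GPU_SPEC / REGISTER_KERNELS for
    // T = float, as seen from inside topk_op.cc (illustrative only).
    #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    namespace functor {
    // Promise that the GPU functor is explicitly instantiated elsewhere --
    // in the per-type topk_op_gpu_*.cu.cc files further down in this commit.
    extern template struct TopKFunctor<GPUDevice, float>;
    }  // namespace functor

    // Register the GPU kernel; "k" stays in host memory because its value is
    // needed on the host to shape the outputs and pick the launch configuration.
    REGISTER_KERNEL_BUILDER(Name("TopKV2")
                                .Device(DEVICE_GPU)
                                .TypeConstraint<float>("T")
                                .HostMemory("k"),
                            TopKOp<GPUDevice, float>);
    #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM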

tensorflow/core/kernels/topk_op_gpu.h

@@ -15,11 +15,12 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_KERNELS_TOPK_OP_GPU_H_
 #define TENSORFLOW_CORE_KERNELS_TOPK_OP_GPU_H_
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 #include <cmath>
 #include <string>
 #include <vector>
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -39,7 +40,7 @@ limitations under the License.
 namespace cub {
 template <>
 struct NumericTraits<Eigen::half>
-    : BaseTraits<FLOATING_POINT, true, false, unsigned short, Eigen::half> {};
+    : BaseTraits<FLOATING_POINT, true, false, uint16_t, Eigen::half> {};
 } // namespace cub
 #endif // GOOGLE_CUDA
@@ -107,7 +108,7 @@ struct StridedData {
   Entry* const data;
 };
-#endif
+#endif // GOOGLE_CUDA
 // A heap of Entry<T> that can either work as a min-heap or as a max-heap.
 template <HeapType heapType, PreferIndices preferIndices,
@@ -115,6 +116,7 @@ template <HeapType heapType, PreferIndices preferIndices,
 struct IndexedHeap {
   typedef typename Data<T>::Entry Entry;
   const Data<T> data;
+  __device__ IndexedHeap(const Data<T>& d) : data(d) {}
   __device__ bool is_above(int left, int right) {
     T left_value = data.get_value(left);
@@ -337,12 +339,21 @@ __device__ void mergeShards(int num_shards, int k,
   }
 }
+#if GOOGLE_CUDA
 extern __shared__ char shared_memory[];
+#endif // GOOGLE_CUDA
 template <typename T>
-__global__ void TopKKernel(const T* __restrict__ input, int length, int k,
-                           bool sorted, T* __restrict__ output,
-                           int* __restrict__ indices) {
+#if TENSORFLOW_USE_ROCM
+__attribute__((amdgpu_flat_work_group_size(1, 256)))
+#endif // TENSORFLOW_USE_ROCM
+__global__ void
+TopKKernel(const T* __restrict__ input, int length, int k, bool sorted,
+           T* __restrict__ output, int* __restrict__ indices) {
+#if TENSORFLOW_USE_ROCM
+  HIP_DYNAMIC_SHARED(char, shared_memory);
+#endif // TENSORFLOW_USE_ROCM
   const int batch_index = blockIdx.x;
   const T* batch_input = input + batch_index * length;
@@ -370,7 +381,7 @@ __global__ void TopKKernel(const T* __restrict__ input, int length, int k,
 }
 template <typename T>
-cudaError LaunchTopKKernel(const cudaStream_t& stream, int num_shards,
+cudaError LaunchTopKKernel(const gpuStream_t& stream, int num_shards,
                            const T* input, int batch_size, int length, int k,
                            bool sorted, T* output, int* indices) {
   // This code assumes that k is small enough that the computation
@@ -395,9 +406,17 @@ cudaError LaunchTopKKernel(const cudaStream_t& stream, int num_shards,
   }
   if (num_shards <= 0) {
     num_shards = 1;
+#if GOOGLE_CUDA
   } else if (num_shards > 1024) {
     num_shards = 1024;
   }
+#elif TENSORFLOW_USE_ROCM
+  // ROCm can't execute with 1024 and requires an explicit
+  // amdgpu_flat_work_group_size attribute with >256
+  } else if (num_shards > 256) {
+    num_shards = 256;
+  }
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   }
   // We are limited by the amount of shared memory we have per block.
   auto shared_memory_size = (num_shards + 1) * k * sizeof(Entry<T>);
@@ -567,6 +586,6 @@ struct TopKFunctor<GPUDevice, T> {
 } // end namespace functor
 } // namespace tensorflow
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #endif // TENSORFLOW_CORE_KERNELS_TOPK_OP_GPU_H_
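Summary of the launch-side changes in this header: under ROCm the dynamically sized shared buffer is declared inside the kernel with HIP's HIP_DYNAMIC_SHARED macro instead of a file-scope extern __shared__ array, the kernel carries amdgpu_flat_work_group_size(1, 256), and num_shards is therefore capped at 256 instead of CUDA's 1024-thread block limit. The sketch below restates the sizing arithmetic described by the comments above; EntrySketch and the helper names are illustrative stand-ins, not the types used in topk_op_gpu.h.

    #include <cstddef>

    // Stand-in for the k-entry heap element each shard keeps in shared memory.
    struct EntrySketch {
      int index;
      float value;
    };

    // Each shard (one thread) keeps a k-entry heap in dynamic shared memory;
    // the formula reserves one extra heap's worth of space for the merge step,
    // hence (num_shards + 1) * k entries.
    inline std::size_t TopKSharedMemoryBytes(int num_shards, int k) {
      return static_cast<std::size_t>(num_shards + 1) * k * sizeof(EntrySketch);
    }

    // CUDA blocks may use up to 1024 threads, so num_shards caps at 1024.
    // Under ROCm the kernel is built with amdgpu_flat_work_group_size(1, 256),
    // so a block larger than 256 threads would fail to launch; hence 256.
    inline int ClampNumShards(int num_shards) {
    #if TENSORFLOW_USE_ROCM
      constexpr int kMaxShards = 256;
    #else
      constexpr int kMaxShards = 1024;
    #endif
      if (num_shards < 1) return 1;
      return num_shards > kMaxShards ? kMaxShards : num_shards;
    }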

tensorflow/core/kernels/topk_op_gpu_double.cu.cc

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 #include "tensorflow/core/kernels/topk_op.h"
@@ -25,4 +25,4 @@ using Eigen::GpuDevice;
 template struct functor::TopKFunctor<GPUDevice, double>;
 } // namespace tensorflow
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

tensorflow/core/kernels/topk_op_gpu_float.cu.cc

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 #include "tensorflow/core/kernels/topk_op.h"
@@ -25,4 +25,4 @@ using Eigen::GpuDevice;
 template struct functor::TopKFunctor<GPUDevice, float>;
 } // namespace tensorflow
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

tensorflow/core/kernels/topk_op_gpu_half.cu.cc

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 #include "tensorflow/core/kernels/topk_op.h"
@@ -25,4 +25,4 @@ using Eigen::GpuDevice;
 template struct functor::TopKFunctor<GPUDevice, Eigen::half>;
 } // namespace tensorflow
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

tensorflow/core/kernels/topk_op_gpu_int16.cu.cc

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 #include "tensorflow/core/kernels/topk_op.h"
@@ -25,4 +25,4 @@ using Eigen::GpuDevice;
 template struct functor::TopKFunctor<GPUDevice, int16>;
 } // namespace tensorflow
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

tensorflow/core/kernels/topk_op_gpu_int32.cu.cc

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 #include "tensorflow/core/kernels/topk_op.h"
@@ -25,4 +25,4 @@ using Eigen::GpuDevice;
 template struct functor::TopKFunctor<GPUDevice, int32>;
 } // namespace tensorflow
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

tensorflow/core/kernels/topk_op_gpu_int64.cu.cc

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 #include "tensorflow/core/kernels/topk_op.h"
@@ -25,4 +25,4 @@ using Eigen::GpuDevice;
 template struct functor::TopKFunctor<GPUDevice, int64>;
 } // namespace tensorflow
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

tensorflow/core/kernels/topk_op_gpu_int8.cu.cc

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 #include "tensorflow/core/kernels/topk_op.h"
@@ -25,4 +25,4 @@ using Eigen::GpuDevice;
 template struct functor::TopKFunctor<GPUDevice, int8>;
 } // namespace tensorflow
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

tensorflow/core/kernels/topk_op_gpu_uint32.cu.cc

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 #include "tensorflow/core/kernels/topk_op.h"
@@ -27,4 +27,4 @@ template struct functor::TopKFunctor<GPUDevice, uint32>;
 } // namespace tensorflow
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

tensorflow/core/kernels/topk_op_gpu_uint8.cu.cc

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 #include "tensorflow/core/kernels/topk_op.h"
@@ -25,4 +25,4 @@ using Eigen::GpuDevice;
 template struct functor::TopKFunctor<GPUDevice, uint8>;
 } // namespace tensorflow
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

tensorflow/python/kernel_tests/topk_op_test.py

@@ -102,11 +102,13 @@ class TopKTest(test.TestCase):
     self._validateTopK(inputs, 2, [[0.4, 0.3], [0.4, 0.3]], [[3, 1], [2, 1]])

   def testTop3(self):
-    k = 5
-    inputs = np.random.permutation(np.linspace(0, 100, 6140, dtype=np.float64))
-    indices = np.argsort(-inputs)[:k]
-    values = -np.sort(-inputs)[:k]
-    self._validateTopK(inputs, k, values, indices)
+    for k in range(3, 11, 2):
+      for dim in range(512, 12288, 512):
+        inputs = np.random.permutation(
+            np.linspace(0, 100, dim, dtype=np.float64))
+        indices = np.argsort(-inputs)[:k]
+        values = -np.sort(-inputs)[:k]
+        self._validateTopK(inputs, k, values, indices)

   def testTop1AllNan(self):
     inputs = [[np.NaN, np.NaN], [np.NaN, np.NaN]]