Clean up merge mess from CL 255758354, which accidentally undid a lot of changes.

PiperOrigin-RevId: 256343541
commit 28d194b368 (parent b2ee39a25c)
Author: A. Unique TensorFlower
Date: 2019-07-03 04:46:10 -07:00
Committed by: TensorFlower Gardener
3 changed files with 69 additions and 57 deletions


@@ -231,7 +231,7 @@ void LSTMBlockCellFpropWithCUDA(
     typename TTypes<T>::Matrix co, typename TTypes<T>::Matrix icfo,
     typename TTypes<T>::Matrix h, int batch_size, int cell_size,
     int input_size) {
-  const cudaStream_t& cu_stream = GetGpuStream(ctx);
+  const auto& cu_stream = GetGpuStream(ctx);

   // Concatenate xh = [x, h].
   //
@@ -370,7 +370,7 @@ void LSTMBlockCellBpropWithCUDA(
     typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,
     typename TTypes<T>::Vec wco_grad, const int batch_size, const int cell_size,
     const bool use_peephole) {
-  const cudaStream_t& cu_stream = GetGpuStream(ctx);
+  const auto& cu_stream = GetGpuStream(ctx);

   dim3 block_dim_2d(std::min(batch_size, 8), 32);
   dim3 grid_dim_2d(Eigen::divup(batch_size, static_cast<int>(block_dim_2d.x)),
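The `const cudaStream_t&` to `const auto&` change above repeats throughout this commit: GetGpuStream returns a CUDA stream on CUDA builds and a HIP stream on ROCm builds, and `auto` lets the same source line bind to either. A minimal sketch of the idea, with a hypothetical stand-in for the real GetGpuStream helper (GetGpuStreamSketch is illustrative, not TensorFlow's API):

    #if GOOGLE_CUDA
    #include <cuda_runtime.h>
    using gpuStream_t = cudaStream_t;
    #elif TENSORFLOW_USE_ROCM
    #include <hip/hip_runtime.h>
    using gpuStream_t = hipStream_t;
    #endif

    gpuStream_t GetGpuStreamSketch() { return nullptr; }  // default stream

    void Caller() {
      // Deduces cudaStream_t or hipStream_t, so this compiles on both stacks.
      const auto& cu_stream = GetGpuStreamSketch();
      (void)cu_stream;
    }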


@@ -436,7 +436,7 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows,
                         typename TTypes<T, 2>::Tensor values,
                         TTypes<int, 2>::Tensor indices) {
   const GPUDevice& d = ctx->eigen_device<GPUDevice>();
-  const cudaStream_t& cu_stream = GetGpuStream(ctx);
+  const auto& cu_stream = GetGpuStream(ctx);
   size_t temp_storage_bytes = -1;

   // TODO(ebrevdo): Once cub supports iterators for ValueT replace that tensor
@@ -550,7 +550,7 @@ struct TopKFunctor<GPUDevice, T> {
       return impl::LaunchSortKernel(context, input.data(), num_rows, num_cols,
                                     k, values, indices);
     } else {
-      const cudaStream_t& cu_stream = GetGpuStream(context);
+      const auto& cu_stream = GetGpuStream(context);
       auto err = impl::LaunchTopKKernel(cu_stream, /* num_shards */ 0,
                                         input.data(), num_rows, num_cols, k,
                                         sorted, values.data(), indices.data());


@@ -16,15 +16,19 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_KERNELS_WHERE_OP_GPU_CU_H_
 #define TENSORFLOW_CORE_KERNELS_WHERE_OP_GPU_CU_H_

-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

 #define EIGEN_USE_GPU

 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#if GOOGLE_CUDA
 #include "third_party/cub/device/device_reduce.cuh"
 #include "third_party/cub/device/device_select.cuh"
 #include "third_party/cub/iterator/counting_input_iterator.cuh"
 #include "third_party/cub/iterator/transform_input_iterator.cuh"
+#elif TENSORFLOW_USE_ROCM
+#include "external/rocprim_archive/hipcub/include/hipcub/hipcub.hpp"
+#endif
 #include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor_types.h"
@@ -33,6 +37,12 @@ limitations under the License.
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/gpu_kernel_helper.h"

+#if GOOGLE_CUDA
+namespace gpuprim = ::cub;
+#elif TENSORFLOW_USE_ROCM
+namespace gpuprim = ::hipcub;
+#endif
+
 namespace tensorflow {

 typedef Eigen::GpuDevice GPUDevice;
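The gpuprim alias added above is the pivot of the whole port: every cub:: qualifier below becomes gpuprim::, which resolves to CUB under CUDA and to hipCUB under ROCm with no further source edits. A toy illustration of dispatch by namespace alias, independent of cub (both backends here are stand-ins):

    namespace backend_a { inline int Sum(int x, int y) { return x + y; } }
    namespace backend_b { inline int Sum(int x, int y) { return x + y; } }

    #if GOOGLE_CUDA
    namespace gpuprim = ::backend_a;  // stands in for ::cub
    #else
    namespace gpuprim = ::backend_b;  // stands in for ::hipcub
    #endif

    int Use() { return gpuprim::Sum(1, 2); }  // one spelling, two backends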
@@ -46,7 +56,7 @@ __global__ void PropagateWhereIndicesKernel(
   // TODO(ebrevdo): Use a multi-dimensional loop, increasing the
   // dimensions of individual indices manually, instead of relying on
   // a scalar loop variable and using integer division.
-  CUDA_1D_KERNEL_LOOP(i, output_rows) {
+  GPU_1D_KERNEL_LOOP(i, output_rows) {
     TIndex index_value = ldg(output + NDIM * i);
 #pragma unroll
     for (int c = 0; c < NDIM; ++c) {
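CUDA_1D_KERNEL_LOOP to GPU_1D_KERNEL_LOOP is a rename to the vendor-neutral spelling; both expand to a grid-stride loop. A simplified stand-in for the expansion (TensorFlow's actual macro goes through its GpuGridRange helpers, so this is an assumption about its shape, not the literal definition):

    // Grid-stride loop: a fixed-size grid covers an arbitrary element count.
    #define GPU_1D_KERNEL_LOOP_SKETCH(i, n)                         \
      for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
           i += blockDim.x * gridDim.x)

    __global__ void ScaleKernel(float* data, int n, float alpha) {
      GPU_1D_KERNEL_LOOP_SKETCH(i, n) { data[i] *= alpha; }
    }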
@@ -69,27 +79,28 @@ struct IsNonzero {

 template <typename T, typename TIndex>
 struct CubDeviceReduceCount {
-  cudaError_t operator()(void* d_temp_storage, size_t& temp_storage_bytes,
-                         const T* d_in, TIndex* d_out, int num_items,
-                         cudaStream_t stream = 0,
-                         bool debug_synchronous = false) {
+  gpuError_t operator()(void* d_temp_storage, size_t& temp_storage_bytes,
+                        const T* d_in, TIndex* d_out, int num_items,
+                        gpuStream_t stream = 0,
+                        bool debug_synchronous = false) {
     IsNonzero<T> is_nonzero;
-    cub::TransformInputIterator<bool, IsNonzero<T>, const T*> is_nonzero_iter(
-        d_in, is_nonzero);
-    return cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes,
-                                  is_nonzero_iter, d_out, num_items, stream,
-                                  debug_synchronous);
+    gpuprim::TransformInputIterator<bool, IsNonzero<T>, const T*>
+        is_nonzero_iter(d_in, is_nonzero);
+    return gpuprim::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes,
+                                      is_nonzero_iter, d_out, num_items, stream,
+                                      debug_synchronous);
   }
 };

 template <typename TIndex>
 struct CubDeviceReduceCount<bool, TIndex> {
-  cudaError_t operator()(void* d_temp_storage, size_t& temp_storage_bytes,
-                         const bool* d_in, TIndex* d_out, int num_items,
-                         cudaStream_t stream = 0,
-                         bool debug_synchronous = false) {
-    return cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, d_in,
-                                  d_out, num_items, stream, debug_synchronous);
+  gpuError_t operator()(void* d_temp_storage, size_t& temp_storage_bytes,
+                        const bool* d_in, TIndex* d_out, int num_items,
+                        gpuStream_t stream = 0,
+                        bool debug_synchronous = false) {
+    return gpuprim::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, d_in,
+                                      d_out, num_items, stream,
+                                      debug_synchronous);
   }
 };
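Both CubDeviceReduceCount overloads follow cub's two-phase calling convention, which NumTrue::Compute below depends on: the first call passes a null d_temp_storage so the library only reports its workspace size, and a second identical call does the work. A standalone sketch of the pattern in plain CUDA plus CUB (outside TensorFlow, which allocates the scratch tensor via the op context rather than cudaMalloc):

    #include <cuda_runtime.h>
    #include <cub/device/device_reduce.cuh>

    // Sums d_in[0..n) into *d_out, allocating cub's scratch space on the fly.
    cudaError_t SumOnDevice(const float* d_in, float* d_out, int n,
                            cudaStream_t stream) {
      void* d_temp = nullptr;
      size_t temp_bytes = 0;
      // Phase 1: d_temp == nullptr, so this call only sets temp_bytes.
      cudaError_t err =
          cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n, stream);
      if (err != cudaSuccess) return err;
      err = cudaMalloc(&d_temp, temp_bytes);
      if (err != cudaSuccess) return err;
      // Phase 2: same arguments with a real workspace; the reduction runs now.
      err = cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n, stream);
      cudaFree(d_temp);
      return err;
    }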
@@ -100,16 +111,16 @@ struct CubDeviceSelectFlaggedCounter;
 template <typename T, typename TIndex, typename OutputIterator>
 struct CubDeviceSelectFlaggedCounter<T, TIndex, OutputIterator,
                                      false /*IsConvertibleToBool*/> {
-  cudaError_t operator()(void* d_temp_storage, size_t& temp_storage_bytes,
-                         const T* d_flags, OutputIterator d_out,
-                         TIndex* d_num_selected_out, int num_items,
-                         cudaStream_t stream = 0,
-                         bool debug_synchronous = false) {
-    cub::CountingInputIterator<TIndex> select_counter(0);
+  gpuError_t operator()(void* d_temp_storage, size_t& temp_storage_bytes,
+                        const T* d_flags, OutputIterator d_out,
+                        TIndex* d_num_selected_out, int num_items,
+                        gpuStream_t stream = 0,
+                        bool debug_synchronous = false) {
+    gpuprim::CountingInputIterator<TIndex> select_counter(0);
     IsNonzero<T> is_nonzero;
-    cub::TransformInputIterator<bool, IsNonzero<T>, const T*> is_nonzero_iter(
-        d_flags, is_nonzero);
-    return cub::DeviceSelect::Flagged(
+    gpuprim::TransformInputIterator<bool, IsNonzero<T>, const T*>
+        is_nonzero_iter(d_flags, is_nonzero);
+    return gpuprim::DeviceSelect::Flagged(
         d_temp_storage, temp_storage_bytes, select_counter /*d_in*/,
         is_nonzero_iter /*d_flags*/, d_out, d_num_selected_out, num_items,
         stream, debug_synchronous);
@@ -119,13 +130,13 @@ struct CubDeviceSelectFlaggedCounter<T, TIndex, OutputIterator,
 template <typename T, typename TIndex, typename OutputIterator>
 struct CubDeviceSelectFlaggedCounter<T, TIndex, OutputIterator,
                                      true /*IsConvertibleToBool*/> {
-  cudaError_t operator()(void* d_temp_storage, size_t& temp_storage_bytes,
-                         const T* d_flags, OutputIterator d_out,
-                         TIndex* d_num_selected_out, int num_items,
-                         cudaStream_t stream = 0,
-                         bool debug_synchronous = false) {
-    cub::CountingInputIterator<TIndex> select_counter(0);
-    return cub::DeviceSelect::Flagged(
+  gpuError_t operator()(void* d_temp_storage, size_t& temp_storage_bytes,
+                        const T* d_flags, OutputIterator d_out,
+                        TIndex* d_num_selected_out, int num_items,
+                        gpuStream_t stream = 0,
+                        bool debug_synchronous = false) {
+    gpuprim::CountingInputIterator<TIndex> select_counter(0);
+    return gpuprim::DeviceSelect::Flagged(
         d_temp_storage, temp_storage_bytes, select_counter /*d_in*/, d_flags,
         d_out, d_num_selected_out, num_items, stream, debug_synchronous);
   }
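The select_counter trick in both specializations is the heart of Where: DeviceSelect::Flagged reads its items from a CountingInputIterator (0, 1, 2, ...) and keeps the ones whose flag is set, so the output is the flat indices of nonzero inputs rather than the inputs themselves. A standalone sketch in plain CUDA plus CUB (names are illustrative; error handling trimmed):

    #include <cuda_runtime.h>
    #include <cub/device/device_select.cuh>
    #include <cub/iterator/counting_input_iterator.cuh>

    // Writes the positions of set flags to d_out_indices and the count to
    // *d_num_selected.
    cudaError_t SelectNonzeroIndices(const bool* d_flags, int* d_out_indices,
                                     int* d_num_selected, int num_items,
                                     cudaStream_t stream) {
      cub::CountingInputIterator<int> counter(0);  // yields 0, 1, 2, ...
      void* d_temp = nullptr;
      size_t temp_bytes = 0;
      // First call sizes the workspace, second call performs the selection.
      cub::DeviceSelect::Flagged(d_temp, temp_bytes, counter, d_flags,
                                 d_out_indices, d_num_selected, num_items,
                                 stream);
      cudaMalloc(&d_temp, temp_bytes);
      cudaError_t err = cub::DeviceSelect::Flagged(
          d_temp, temp_bytes, counter, d_flags, d_out_indices, d_num_selected,
          num_items, stream);
      cudaFree(d_temp);
      return err;
    }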
@@ -139,7 +150,7 @@ struct NumTrue<GPUDevice, T, TIndex> {
       OpKernelContext* ctx, const GPUDevice& d,
       typename TTypes<T>::ConstFlat input,
       typename TTypes<TIndex>::Scalar num_true) {
-    const cudaStream_t& cu_stream = GetGpuStream(ctx);
+    const auto& cu_stream = GetGpuStream(ctx);

     std::size_t temp_storage_bytes = 0;
     const T* input_data = input.data();
@@ -154,11 +165,11 @@ struct NumTrue<GPUDevice, T, TIndex> {
         /*num_items*/ input.size(),
         /*stream*/ cu_stream);

-    if (first_success != cudaSuccess) {
+    if (first_success != gpuSuccess) {
       return errors::Internal(
-          "WhereOp: Could not launch cub::DeviceReduce::Sum to calculate "
+          "WhereOp: Could not launch gpuprim::DeviceReduce::Sum to calculate "
           "temp_storage_bytes, status: ",
-          cudaGetErrorString(first_success));
+          GpuGetErrorString(first_success));
     }

     Tensor temp_storage;
@@ -173,11 +184,11 @@ struct NumTrue<GPUDevice, T, TIndex> {
         /*num_items*/ input.size(),
         /*stream*/ cu_stream);

-    if (second_success != cudaSuccess) {
+    if (second_success != gpuSuccess) {
       return errors::Internal(
-          "WhereOp: Could not launch cub::DeviceReduce::Sum to count "
+          "WhereOp: Could not launch gpuprim::DeviceReduce::Sum to count "
           "number of true / nonzero indices. temp_storage_bytes: ",
-          temp_storage_bytes, ", status: ", cudaGetErrorString(second_success));
+          temp_storage_bytes, ", status: ", GpuGetErrorString(second_success));
     }

     return Status::OK();
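The cudaSuccess/cudaGetErrorString to gpuSuccess/GpuGetErrorString substitutions in the two hunks above follow the same aliasing scheme as gpuprim. A sketch of the kind of aliases involved (the names match this diff, but the definitions are simplified assumptions, not TensorFlow's actual porting headers):

    #if GOOGLE_CUDA
    #include <cuda_runtime.h>
    using gpuError_t = cudaError_t;
    #define gpuSuccess cudaSuccess
    inline const char* GpuGetErrorString(cudaError_t e) {
      return cudaGetErrorString(e);
    }
    #elif TENSORFLOW_USE_ROCM
    #include <hip/hip_runtime.h>
    using gpuError_t = hipError_t;
    #define gpuSuccess hipSuccess
    inline const char* GpuGetErrorString(hipError_t e) {
      return hipGetErrorString(e);
    }
    #endif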
@@ -266,7 +277,7 @@ struct Where<GPUDevice, NDIM, T, TIndex> {
       return Status::OK();
     }

-    const cudaStream_t& cu_stream = GetGpuStream(ctx);
+    const auto& cu_stream = GetGpuStream(ctx);

     std::size_t temp_storage_bytes = 0;
@@ -290,11 +301,12 @@ struct Where<GPUDevice, NDIM, T, TIndex> {
         /*d_num_selected_out*/ found_true_device,
         /*num_items*/ input.size(),
         /*stream*/ cu_stream);
-    if (first_success != cudaSuccess) {
+    if (first_success != gpuSuccess) {
       return errors::Internal(
-          "WhereOp: Could not launch cub::DeviceSelect::Flagged to calculate "
+          "WhereOp: Could not launch gpuprim::DeviceSelect::Flagged to "
+          "calculate "
           "temp_storage_bytes, status: ",
-          cudaGetErrorString(first_success));
+          GpuGetErrorString(first_success));
     }

     Tensor temp_storage;
@@ -310,11 +322,11 @@ struct Where<GPUDevice, NDIM, T, TIndex> {
         /*num_items*/ input.size(),
         /*stream*/ cu_stream);

-    if (second_success != cudaSuccess) {
+    if (second_success != gpuSuccess) {
       return errors::Internal(
-          "WhereOp: Could not launch cub::DeviceSelect::Flagged to copy "
+          "WhereOp: Could not launch gpuprim::DeviceSelect::Flagged to copy "
           "indices out, status: ",
-          cudaGetErrorString(second_success));
+          GpuGetErrorString(second_success));
     }

     // TODO(ebrevdo): Find a way to synchronously copy back data from
@@ -323,11 +335,11 @@ struct Where<GPUDevice, NDIM, T, TIndex> {
     const Eigen::array<TIndex, NDIM> strides =
         CalculateStrides<TIndex, T, NDIM>(input);
     const TIndex output_rows = output.dimension(0);
-    GpuLaunchConfig config = GetCudaLaunchConfig(output_rows, d);
-    TF_CHECK_OK(CudaLaunchKernel(PropagateWhereIndicesKernel<NDIM, TIndex>,
-                                 config.block_count, config.thread_per_block, 0,
-                                 d.stream(), output_rows, strides,
-                                 output.data()));
+    GpuLaunchConfig config = GetGpuLaunchConfig(output_rows, d);
+    TF_CHECK_OK(GpuLaunchKernel(PropagateWhereIndicesKernel<NDIM, TIndex>,
+                                config.block_count, config.thread_per_block, 0,
+                                d.stream(), output_rows, strides,
+                                output.data()));

     return Status::OK();
   }
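The final rename swaps GetCudaLaunchConfig/CudaLaunchKernel for their Gpu-prefixed twins; GpuLaunchKernel hides CUDA's <<<grid, block, shared_mem, stream>>> launch syntax behind an ordinary function call so the same line also builds under HIP. A simplified CUDA-only stand-in (TensorFlow's real helper returns Status and handles ROCm as well):

    #include <cuda_runtime.h>

    // Sketch: wraps the chevron launch and surfaces the launch error.
    template <typename... Ts, typename... Args>
    cudaError_t GpuLaunchKernelSketch(void (*kernel)(Ts...), dim3 grid,
                                      dim3 block, size_t shared_mem,
                                      cudaStream_t stream, Args... args) {
      kernel<<<grid, block, shared_mem, stream>>>(args...);
      return cudaGetLastError();
    }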
@@ -349,6 +361,6 @@ TF_CALL_WHERE_GPU_TYPES(DECLARE_GPU_SPEC);

 }  // namespace tensorflow

-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

 #endif  // TENSORFLOW_CORE_KERNELS_WHERE_OP_GPU_CU_H_