Avoid overflow in reduction GPU kernel for large tensors, see issue #22123.
Improve launch code.

PiperOrigin-RevId: 221097986
commit fc44600e5c (parent df41194786)
@@ -218,7 +218,11 @@ __global__ void RowReduceKernel(
     T in, outT out, int num_rows, int num_cols, Op op,
     typename std::iterator_traits<T>::value_type initVal) {
   typedef typename std::iterator_traits<T>::value_type value_type;
-  const int row = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
+  // Defensive index computation to avoid integer overflow.
+  assert(blockDim.x % 32 == 0);
+  int warps_per_block = blockDim.x / 32;
+  int warp_index = threadIdx.x / 32;
+  const int row = blockIdx.x * warps_per_block + warp_index;
   const int lane = threadIdx.x % 32;
 
   if (num_cols == 1) {
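Why the old formula overflowed: blockIdx.x, blockDim.x, and threadIdx.x are 32-bit unsigned built-ins, so the intermediate product blockIdx.x * blockDim.x wraps modulo 2^32 once a launch covers more than 2^32 threads, before the division by 32 can shrink the value. Multiplying the block index by warps_per_block instead keeps the intermediate 32x smaller. A minimal host-side sketch of the two formulas (the constants are illustrative, not from the patch):

    #include <cstdio>

    int main() {
      // Mimic CUDA's 32-bit unsigned built-ins for a very large launch.
      unsigned block_idx = 5000000u, block_dim = 1024u, thread_idx = 65u;

      // Old formula: 5,000,000 * 1024 = 5.12e9 wraps modulo 2^32 before the
      // division by 32, yielding a bogus row index.
      unsigned old_row = (block_idx * block_dim + thread_idx) / 32;

      // New formula: the product is 32x smaller and stays in range.
      unsigned warps_per_block = block_dim / 32;
      unsigned warp_index = thread_idx / 32;
      unsigned new_row = block_idx * warps_per_block + warp_index;

      printf("old_row=%u (wrong), new_row=%u\n", old_row, new_row);
      return 0;
    }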
@@ -526,27 +530,27 @@ void LaunchScalarReduction(OpKernelContext* ctx, OUT_T out, IN_T in,
         init);
     return;
   }
-  std::size_t temp_storage_bytes = 0;
+  size_t temp_storage_bytes = 0;
 
-  Tensor temp_storage;
-  // written as a loop because it reduces clutter
-  // first pass allocates memory, second launches kernel(s)
-  for (int i = 0; i < 2; ++i) {
-    auto success = cub::DeviceReduce::Reduce(
-        i == 0 ? nullptr : temp_storage.flat<int8_t>().data(),
-        temp_storage_bytes, in, out, in_size, op, init, cu_stream);
+  auto reduce = [&](void* temp_storage_ptr) {
+    auto success =
+        cub::DeviceReduce::Reduce(temp_storage_ptr, temp_storage_bytes, in, out,
+                                  in_size, op, init, cu_stream);
 
     OP_REQUIRES(
         ctx, success == 0,
         errors::Internal("CUB reduce error", cudaGetErrorString(success)));
+  };
 
-    if (i == 0)
-      OP_REQUIRES_OK(
-          ctx,
-          ctx->allocate_temp(
-              DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
-              &temp_storage));
-  }
+  reduce(nullptr);  // Get required amount of temp storage.
+
+  Tensor temp_storage;
+  OP_REQUIRES_OK(
+      ctx, ctx->allocate_temp(
+               DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
+               &temp_storage));
+
+  reduce(temp_storage.flat<int8_t>().data());  // Do reduction.
 }
 
 template <typename T, typename Op, typename OUT_T, typename IN_T>
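For context, the reduce lambda wraps CUB's standard two-call protocol: called with a null temp pointer, cub::DeviceReduce::Reduce only writes the required scratch size into temp_storage_bytes; called again with a real buffer, it launches the reduction. A standalone sketch of that protocol outside TensorFlow (raw cudaMalloc instead of allocate_temp, and the Sum convenience overload in place of the templated Reduce):

    #include <cub/cub.cuh>

    int main() {
      const int n = 1 << 20;
      float *d_in, *d_out;
      cudaMalloc(&d_in, n * sizeof(float));
      cudaMalloc(&d_out, sizeof(float));
      cudaMemset(d_in, 0, n * sizeof(float));

      void* d_temp = nullptr;
      size_t temp_bytes = 0;
      // Pass 1: d_temp == nullptr, so CUB only reports the size it needs.
      cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n);
      cudaMalloc(&d_temp, temp_bytes);
      // Pass 2: same arguments with a real buffer; the reduction runs.
      cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n);
      cudaDeviceSynchronize();

      cudaFree(d_temp);
      cudaFree(d_in);
      cudaFree(d_out);
      return 0;
    }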
@@ -569,25 +573,26 @@ void LaunchRowReduction(OpKernelContext* ctx, OUT_T out, IN_T in, int num_rows,
   cub::TransformInputIterator<int, RowOffset, cub::CountingInputIterator<int>>
       transform_iter(counting_iter, row_offset_op);
 
-  std::size_t temp_storage_bytes = 0;
-  Tensor temp_storage;
-  for (int i = 0; i < 2; ++i) {
-    auto success = cub::DeviceSegmentedReduce::Reduce(
-        i == 0 ? nullptr : temp_storage.flat<int8_t>().data(),
-        temp_storage_bytes, in, out, num_rows, transform_iter,
+  size_t temp_storage_bytes = 0;
+  auto reduce = [&](void* temp_storage_ptr) {
+    auto success = cub::DeviceSegmentedReduce::Reduce(
+        temp_storage_ptr, temp_storage_bytes, in, out, num_rows, transform_iter,
         transform_iter + 1, op, init, cu_stream);
 
     OP_REQUIRES(ctx, success == 0,
                 errors::Internal("CUB segmented reduce error",
                                  cudaGetErrorString(success)));
+  };
 
-    if (i == 0)
-      OP_REQUIRES_OK(
-          ctx,
-          ctx->allocate_temp(
-              DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
-              &temp_storage));
-  }
+  reduce(nullptr);  // Get required amount of temp storage.
+
+  Tensor temp_storage;
+  OP_REQUIRES_OK(
+      ctx, ctx->allocate_temp(
+               DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
+               &temp_storage));
+
+  reduce(temp_storage.flat<int8_t>().data());  // Do reduction.
 }
 
 template <typename T, typename Op, typename OUT_T, typename IN_T>
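The segmented variant needs begin/end offsets for each row. Rather than materializing an offsets array, the launcher composes a counting iterator with a row-to-offset functor, so transform_iter[i] == i * num_cols and transform_iter + 1 doubles as the end-offset sequence. A hedged sketch of the same construction (RowOffset and SumRows here are stand-ins I wrote, not the TensorFlow definitions):

    #include <cub/cub.cuh>

    struct RowOffset {
      int num_cols;
      __host__ __device__ int operator()(int row) const {
        return row * num_cols;  // element offset of the start of this row
      }
    };

    void SumRows(const float* d_in, float* d_out, int num_rows, int num_cols) {
      cub::CountingInputIterator<int> counting(0);  // 0, 1, 2, ...
      cub::TransformInputIterator<int, RowOffset,
                                  cub::CountingInputIterator<int>>
          offsets(counting, RowOffset{num_cols});  // 0, num_cols, 2*num_cols...

      void* d_temp = nullptr;
      size_t temp_bytes = 0;
      // Same two-call protocol: size query, then the actual segmented reduce.
      cub::DeviceSegmentedReduce::Sum(d_temp, temp_bytes, d_in, d_out,
                                      num_rows, offsets, offsets + 1);
      cudaMalloc(&d_temp, temp_bytes);
      cub::DeviceSegmentedReduce::Sum(d_temp, temp_bytes, d_in, d_out,
                                      num_rows, offsets, offsets + 1);
      cudaFree(d_temp);
    }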
@@ -720,25 +725,25 @@ void Launch3DXZReduction(OpKernelContext* ctx, OUT_T out, IN_T in, int extent_x,
       gather_iter);
 
   std::size_t temp_storage_bytes = 0;
-  Tensor temp_storage;
-
-  for (int i = 0; i < 2; ++i) {
-    auto success = cub::DeviceSegmentedReduce::Reduce(
-        i == 0 ? nullptr : temp_storage.flat<int8_t>().data(),
-        temp_storage_bytes, permute_iter, out, extent_y, transform_iter,
-        transform_iter + 1, op, init, cu_stream);
+  auto reduce = [&](void* temp_storage_ptr) {
+    auto success = cub::DeviceSegmentedReduce::Reduce(
+        temp_storage_ptr, temp_storage_bytes, permute_iter, out, extent_y,
+        transform_iter, transform_iter + 1, op, init, cu_stream);
 
     OP_REQUIRES(ctx, success == 0,
                 errors::Internal("CUB segmented reduce error",
                                  cudaGetErrorString(success)));
+  };
 
-    if (i == 0)
-      OP_REQUIRES_OK(
-          ctx,
-          ctx->allocate_temp(
-              DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
-              &temp_storage));
-  }
+  reduce(nullptr);  // Get required amount of temp storage.
+
+  Tensor temp_storage;
+  OP_REQUIRES_OK(
+      ctx, ctx->allocate_temp(
+               DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
+               &temp_storage));
+
+  reduce(temp_storage.flat<int8_t>().data());  // Do reduction.
 }
 
 namespace reduction_op_helper {
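The 3D XZ launcher gets the same query-allocate-reduce restructuring. The only moving part that differs from the row case is the input iterator: permute_iter (built from the gather_iter visible in the hunk context) presumably presents the input in a gathered order where all x,z elements sharing a y index form one contiguous segment, so the segmented reduce over extent_y segments collapses the x and z dimensions at once.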