Merge pull request from firejq:patch-5

PiperOrigin-RevId: 343650542
Change-Id: Ia41ab95e5b3dda20c552a75f7784eabade14567c
TensorFlower Gardener 2020-11-21 07:27:48 -08:00
commit dad4331852
8 changed files with 189 additions and 85 deletions


@@ -36,7 +36,8 @@ namespace functor {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
typedef Eigen::GpuDevice GPUDevice;
// Functor for SegmentSumGPUOp.
// Functor for SegmentSumGPUOp & SegmentProdGPUOp & SegmentMaxGPUOp
// & SegmentMinGPUOp.
// output_rows: the number of output segments (unique segment ids in
// 'segment_ids').
// segment_ids_shape: shape of 'segment_ids' tensor.
@@ -45,8 +46,9 @@ typedef Eigen::GpuDevice GPUDevice;
// data_size: size of input data tensor.
// data: input data tensor.
// output: output reshaped to {output_rows, output.size/output_rows}
template <typename T, typename Index>
struct SegmentSumFunctor {
template <typename T, typename Index, typename InitialValueF,
typename ReductionF, typename AtomicReductionF>
struct SegmentReductionFunctor {
void operator()(OpKernelContext* ctx, const GPUDevice& d,
const Index output_rows, const TensorShape& segment_ids_shape,
typename TTypes<Index>::ConstFlat segment_ids,
@@ -66,9 +68,10 @@ struct UnsortedSegmentFunctor {
};
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// reduction functors for the gpu
// Atomic reduction functors for the gpu.
template <typename T>
struct SumOpGpu {
struct AtomicSumOpGpu {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
const T& value) {
GpuAtomicAdd(dest, value);
@@ -76,7 +79,7 @@ struct SumOpGpu {
};
template <typename T>
struct ProdOpGpu {
struct AtomicProdOpGpu {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
const T& value) {
GpuAtomicMul(dest, value);
@@ -84,7 +87,7 @@ struct ProdOpGpu {
};
template <typename T>
struct MaxOpGpu {
struct AtomicMaxOpGpu {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
const T& value) {
GpuAtomicMax(dest, value);
@@ -92,16 +95,49 @@ struct MaxOpGpu {
};
template <typename T>
struct MinOpGpu {
struct AtomicMinOpGpu {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
const T& value) {
GpuAtomicMin(dest, value);
}
};
// Non-atomic reduction functors for the gpu.
template <typename T>
struct NonAtomicSumOpGpu {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
const T& value) {
*dest += value;
}
};
template <typename T>
struct NonAtomicProdOpGpu {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
const T& value) {
*dest *= value;
}
};
template <typename T>
struct NonAtomicMaxOpGpu {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
const T& value) {
*dest = max(*dest, value);
}
};
template <typename T>
struct NonAtomicMinOpGpu {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
const T& value) {
*dest = min(*dest, value);
}
};
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// initial value functors
// Initial value functors.
template <typename T>
struct Zero {
EIGEN_STRONG_INLINE T operator()() const { return T(0); }
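Taken together, an initial-value functor and a non-atomic reduction functor describe an ordinary fold: start each output row at the reduction's identity, then fold every input element of that segment into it. A minimal host-side sketch of that pattern (illustrative only; NaiveSortedSegmentReduce is a hypothetical reference helper, not part of this change) using functors with the same shapes as those above:

#include <cstddef>
#include <vector>

// Serial reference for a sorted segment reduction. InitialValueF returns the
// reduction identity (the shape of functor::Zero<T> / functor::Lowest<T>) and
// ReductionF folds one value into an accumulator, matching the
// NonAtomic*OpGpu operator()(T*, const T&) shape.
template <typename T, typename Index, typename InitialValueF, typename ReductionF>
std::vector<T> NaiveSortedSegmentReduce(const std::vector<T>& data,
                                        const std::vector<Index>& segment_ids,
                                        Index output_rows) {
  std::vector<T> output(output_rows, InitialValueF()());  // fill with identity
  for (std::size_t i = 0; i < data.size(); ++i) {
    ReductionF()(&output[segment_ids[i]], data[i]);  // fold element into its segment
  }
  return output;
}

The GPU functor produces the same result but splits this loop into strips, falling back to the Atomic*OpGpu variants only where two strips can touch the same output row.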


@@ -31,14 +31,14 @@ namespace tensorflow {
using GPUDevice = Eigen::GpuDevice;
// SortedSegmentSumFunctor kernel reduces input data just as
// UnsortedSegmentSumCustomKernel does except that input data
// SortedSegmentReductionFunctor kernel reduces input data just as
// UnsortedSegmentReductionCustomKernel does except that input data
// is partitioned along the outer reduction dimension. This is
// because consecutive rows (elements in a row share the same
// outer dimension index) in the flattened 2D input data likely
// belong to the same segment in sorted segment sum operation.
// Therefore such partitioning strategy has two advantages over
// the UnsortedSegmentSumFunctor kernel:
// the UnsortedSegmentReductionFunctor kernel:
// 1. Each thread reduces across multiple rows before writing
// answers to the global memory, we can therefore
// write reduction results to global memory less often.
@@ -51,18 +51,19 @@ using GPUDevice = Eigen::GpuDevice;
// size OuterDimTileSize x 1. This strip runs across multiple
// rows of input data and all reduction elements share one inner
// dimension index.
template <typename T, typename Index, int OuterDimTileSize>
__global__ void SortedSegmentSumCustomKernel(
template <typename T, typename Index, int OuterDimTileSize, typename ReductionF,
typename AtomicReductionF>
__global__ void SortedSegmentReductionCustomKernel(
const Index input_outer_dim_size, const Index inner_dim_size,
const Index output_outer_dim_size, const Index* __restrict__ segment_ids,
const T* __restrict__ input, T* __restrict__ output,
const Index total_stripe_count) {
const Index total_stripe_count, const T initial_value) {
for (int stripe_index : GpuGridRangeX(total_stripe_count)) {
const Index segment_offset = stripe_index % inner_dim_size;
const Index input_outer_dim_index_base =
stripe_index / inner_dim_size * Index(OuterDimTileSize);
T sum = T(0);
T reduce_res = initial_value;
Index first_segment_id = segment_ids[input_outer_dim_index_base];
Index last_output_segment_id = output_outer_dim_size;
@@ -72,24 +73,25 @@ __global__ void SortedSegmentSumCustomKernel(
for (Index j = 0; j < actual_stripe_height; j++) {
Index current_output_segment_id =
segment_ids[input_outer_dim_index_base + j];
// Decide whether to write result to global memory.
// Result is only written to global memory if we move
// to another segment. Otherwise we can keep accumulating
// locally.
// Decide whether to write result to global memory. Result is only written
// to global memory if we move to another segment. Otherwise we can keep
// accumulating locally.
if (current_output_segment_id > last_output_segment_id) {
const Index output_index =
last_output_segment_id * inner_dim_size + segment_offset;
// decide whether to write result to global memory using atomic
// operations
// Decide whether to write result to global memory using atomic
// operations.
if (last_output_segment_id == first_segment_id) {
GpuAtomicAdd(output + output_index, sum);
AtomicReductionF()(output + output_index, reduce_res);
} else {
*(output + output_index) = sum;
ReductionF()(output + output_index, reduce_res);
}
sum = T(0);
reduce_res = initial_value;
}
sum += ldg(input + (input_outer_dim_index_base + j) * inner_dim_size +
segment_offset);
ReductionF()(
&reduce_res,
ldg(input + (input_outer_dim_index_base + j) * inner_dim_size +
segment_offset));
last_output_segment_id = current_output_segment_id;
}
// For the last result in a strip, always write using atomic operations
@@ -97,7 +99,7 @@ __global__ void SortedSegmentSumCustomKernel(
// the following strip.
const Index output_index =
last_output_segment_id * inner_dim_size + segment_offset;
GpuAtomicAdd(output + output_index, sum);
AtomicReductionF()(output + output_index, reduce_res);
}
}
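The write decision above can be summarized per strip: flushing the strip's first segment (which the previous strip may also touch) and the final flush at the end of the strip (which the next strip may continue) must use the atomic functor; segments wholly owned by one strip can be flushed with the plain functor. A small host-side sketch of that rule (illustrative only; ClassifyStripWrites is a hypothetical helper, not part of the kernel):

#include <cstdio>
#include <vector>

// For one strip of sorted segment ids, report which per-segment flushes could
// race with a neighbouring strip and therefore need AtomicReductionF.
void ClassifyStripWrites(const std::vector<int>& strip_segment_ids) {
  if (strip_segment_ids.empty()) return;
  const int first = strip_segment_ids.front();
  int current = first;
  for (int id : strip_segment_ids) {
    if (id > current) {
      // Moving to a new segment: flush 'current'. Only the strip's first
      // segment can also be touched by the previous strip.
      std::printf("segment %d: %s write\n", current,
                  current == first ? "atomic" : "plain");
      current = id;
    }
  }
  // The strip's last segment may continue into the next strip: always atomic.
  std::printf("segment %d: atomic write\n", current);
}

For a strip with segment ids {3, 3, 4, 4, 5} this prints an atomic write for segment 3, a plain write for segment 4, and an atomic write for segment 5, mirroring the kernel's choice between AtomicReductionF and ReductionF.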
@@ -126,25 +128,30 @@ __global__ void UnsortedSegmentCustomKernel(
namespace functor {
template <typename T, typename Index>
void SegmentSumFunctor<T, Index>::operator()(
OpKernelContext* ctx, const GPUDevice& d, const Index output_rows,
const TensorShape& segment_ids_shape,
typename TTypes<Index>::ConstFlat segment_ids, const Index data_size,
const T* data, typename TTypes<T, 2>::Tensor output) {
template <typename T, typename Index, typename InitialValueF,
typename ReductionF, typename AtomicReductionF>
void SegmentReductionFunctor<
T, Index, InitialValueF, ReductionF,
AtomicReductionF>::operator()(OpKernelContext* ctx, const GPUDevice& d,
const Index output_rows,
const TensorShape& segment_ids_shape,
typename TTypes<Index>::ConstFlat segment_ids,
const Index data_size, const T* data,
typename TTypes<T, 2>::Tensor output) {
if (output.size() == 0) {
return;
}
// Set 'output' to zeros.
// Set 'output' to initial value.
GpuLaunchConfig config = GetGpuLaunchConfig(output.size(), d);
TF_CHECK_OK(GpuLaunchKernel(SetZero<T>, config.block_count,
const T InitialValue = InitialValueF()();
TF_CHECK_OK(GpuLaunchKernel(SetToValue<T>, config.block_count,
config.thread_per_block, 0, d.stream(),
output.size(), output.data()));
output.size(), output.data(), InitialValue));
if (data_size == 0 || segment_ids_shape.num_elements() == 0) {
return;
}
// Launch kernel to compute sorted segment sum.
// Launch kernel to compute sorted segment reduction.
// Notes:
// *) 'input_total_size' is the total number of elements to process.
// *) 'segment_ids.shape' is a prefix of data's shape.
@@ -163,10 +170,12 @@ void SegmentSumFunctor<T, Index>::operator()(
config = GetGpuLaunchConfig(total_stripe_count, d);
TF_CHECK_OK(GpuLaunchKernel(
SortedSegmentSumCustomKernel<T, Index, OuterDimTileSize>,
SortedSegmentReductionCustomKernel<T, Index, OuterDimTileSize, ReductionF,
AtomicReductionF>,
config.block_count, config.thread_per_block, 0, d.stream(),
input_outer_dim_size, input_inner_dim_size, output_rows,
segment_ids.data(), data, output.data(), total_stripe_count));
segment_ids.data(), data, output.data(), total_stripe_count,
InitialValue));
}
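As a worked example of the tiling described in the notes above (the tile size of 8 and the concrete shapes below are assumptions for the arithmetic, not values taken from this hunk): a flattened input of input_outer_dim_size rows by inner_dim_size columns needs ceil(input_outer_dim_size / OuterDimTileSize) strips per inner-dimension index, and total_stripe_count is that number times inner_dim_size.

#include <cstdio>

// Illustrative arithmetic only; the sizes below are made up.
int main() {
  const int kOuterDimTileSize = 8;       // rows reduced per strip (assumed)
  const int input_outer_dim_size = 100;  // rows in the flattened input
  const int inner_dim_size = 16;         // elements per row
  const int strips_per_column =
      (input_outer_dim_size + kOuterDimTileSize - 1) / kOuterDimTileSize;  // 13
  const int total_stripe_count = strips_per_column * inner_dim_size;       // 208
  std::printf("total_stripe_count = %d\n", total_stripe_count);
  return 0;
}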
template <typename T, typename Index, typename InitialValueF,
@@ -207,8 +216,19 @@ struct UnsortedSegmentFunctor<GPUDevice, T, Index, InitialValueF, ReductionF> {
}
};
#define DEFINE_SORTED_GPU_SPECS_INDEX(T, Index) \
template struct SegmentSumFunctor<T, Index>
#define DEFINE_SORTED_GPU_SPECS_INDEX(T, Index) \
template struct SegmentReductionFunctor<T, Index, functor::Zero<T>, \
functor::NonAtomicSumOpGpu<T>, \
functor::AtomicSumOpGpu<T>>; \
template struct SegmentReductionFunctor<T, Index, functor::One<T>, \
functor::NonAtomicProdOpGpu<T>, \
functor::AtomicProdOpGpu<T>>; \
template struct SegmentReductionFunctor<T, Index, functor::Highest<T>, \
functor::NonAtomicMinOpGpu<T>, \
functor::AtomicMinOpGpu<T>>; \
template struct SegmentReductionFunctor<T, Index, functor::Lowest<T>, \
functor::NonAtomicMaxOpGpu<T>, \
functor::AtomicMaxOpGpu<T>>;
#define DEFINE_SORTED_GPU_SPECS(T) \
DEFINE_SORTED_GPU_SPECS_INDEX(T, int32); \
@@ -218,16 +238,16 @@ TF_CALL_GPU_NUMBER_TYPES(DEFINE_SORTED_GPU_SPECS);
#define DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, Index) \
template struct UnsortedSegmentFunctor< \
GPUDevice, T, Index, functor::Lowest<T>, functor::MaxOpGpu<T>>; \
GPUDevice, T, Index, functor::Lowest<T>, functor::AtomicMaxOpGpu<T>>; \
template struct UnsortedSegmentFunctor< \
GPUDevice, T, Index, functor::Highest<T>, functor::MinOpGpu<T>>; \
GPUDevice, T, Index, functor::Highest<T>, functor::AtomicMinOpGpu<T>>; \
template struct UnsortedSegmentFunctor<GPUDevice, T, Index, functor::One<T>, \
functor::ProdOpGpu<T>>;
functor::AtomicProdOpGpu<T>>;
// sum is the only op that supports all input types currently
// Sum is the only op that supports all input types currently.
#define DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, Index) \
template struct UnsortedSegmentFunctor< \
GPUDevice, T, Index, functor::Zero<T>, functor::SumOpGpu<T>>;
GPUDevice, T, Index, functor::Zero<T>, functor::AtomicSumOpGpu<T>>;
#define DEFINE_REAL_GPU_SPECS(T) \
DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, int32); \


@@ -206,24 +206,26 @@ class SegmentReductionOp : public OpKernel {
};
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// SegmentSumGPUOp is a segment sum operator implemented for GPU only.
// TODO: This implementation of SegmentSumGPUOp is sometimes slower than
// SegmentReductionGPUOp is a segment reduction operator implemented for GPU
// only.
// TODO: This implementation of SegmentReductionGPUOp is sometimes slower than
// its unsorted counterpart (mostly when problem size is small).
// This is due to the following two main reasons and a cost-effective way
// to resolve these problems is desirable.
// 1. Sorted segment sum requires a memory transfer from device to host in
// order to know the size of the output dimension whereas unsorted segment
// sum receives the size of the output dimension as an input parameter.
// 2. Sorted segment sum is essentially a tiled version of unsorted segment
// sum and therefore such optimization comes at an inherent cost. However
// such cost may not be justified when the problem size is small. When to
// use the tiled version or the untiled version depends on many factors
// including data alignments, ratio of calculation to memory traffic and
// obviously, the problem sizes.
template <class T, class Index>
class SegmentSumGPUOp : public AsyncOpKernel {
// 1. Sorted segment reduction requires a memory transfer from device to host
// in order to know the size of the output dimension whereas unsorted
// segment reduction receives the size of the output dimension as an input
// parameter.
// 2. Sorted segment reduction is essentially a tiled version of unsorted
// segment reduction and therefore such optimization comes at an inherent
// cost. However such cost may not be justified when the problem size is
// small. When to use the tiled version or the untiled version depends on
// many factors including data alignments, ratio of calculation to memory
// traffic and obviously, the problem sizes.
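A minimal sketch of point 1 (illustrative only, written with plain CUDA rather than the StreamExecutor calls this op actually uses; OutputRowsFromDeviceSegmentIds is a hypothetical helper): the row count of the output is last_segment_id + 1, and that id sits in device memory, so the op must copy it back to the host before it can allocate the output, which is why the op is asynchronous.

#include <cuda_runtime.h>
#include <cstdint>

// Hypothetical helper: read the last (largest) segment id back from the
// device to learn how many output rows a sorted segment reduction produces.
// The unsorted ops skip this copy because num_segments is an input.
int64_t OutputRowsFromDeviceSegmentIds(const int32_t* segment_ids_device,
                                       int64_t num_indices) {
  int32_t last_segment_id = 0;
  cudaMemcpy(&last_segment_id, segment_ids_device + num_indices - 1,
             sizeof(int32_t), cudaMemcpyDeviceToHost);  // synchronizes the host
  return static_cast<int64_t>(last_segment_id) + 1;     // segment ids are 0-based
}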
template <class T, class Index, class SegmentReductionFunctor>
class SegmentReductionGPUOp : public AsyncOpKernel {
public:
explicit SegmentSumGPUOp(OpKernelConstruction* context)
explicit SegmentReductionGPUOp(OpKernelConstruction* context)
: AsyncOpKernel(context) {}
void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
@@ -265,11 +267,11 @@ class SegmentSumGPUOp : public AsyncOpKernel {
->ThenMemcpy(output_rows_host.mutable_data(), output_rows_device,
sizeof(Index))
.ok(),
errors::Internal(
"SegmentSumGPUOp: failed to copy output_rows from device"),
errors::Internal(type_string() +
": failed to copy output_rows from device"),
done);
functor::SegmentSumFunctor<T, Index> functor_;
SegmentReductionFunctor functor_;
auto create_and_check_output = [context, output_rows_host, &input,
&segment_ids, &functor_, done]() {
// Ensure that within the callback, the proper GPU settings are


@@ -113,17 +113,39 @@ REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128);
#undef REGISTER_COMPLEX_CPU_KERNELS_ALL
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_SORTED_KERNELS(type, index_type) \
REGISTER_KERNEL_BUILDER(Name("SegmentSum") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<index_type>("Tindices"), \
SegmentSumGPUOp<type, index_type>)
#define REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
name, type, index_type, initial_value_functor, reduction_kernel_functor, \
atomic_reduction_kernel_functor) \
REGISTER_KERNEL_BUILDER( \
Name(name) \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<index_type>("Tindices"), \
SegmentReductionGPUOp< \
type, index_type, \
functor::SegmentReductionFunctor< \
type, index_type, initial_value_functor, \
reduction_kernel_functor, atomic_reduction_kernel_functor> >)
#define REGISTER_GPU_SORTED_KERNELS(type, index_type) \
REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
"SegmentSum", type, index_type, functor::Zero<type>, \
functor::NonAtomicSumOpGpu<type>, functor::AtomicSumOpGpu<type>); \
REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
"SegmentProd", type, index_type, functor::One<type>, \
functor::NonAtomicProdOpGpu<type>, functor::AtomicProdOpGpu<type>); \
REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
"SegmentMin", type, index_type, functor::Highest<type>, \
functor::NonAtomicMinOpGpu<type>, functor::AtomicMinOpGpu<type>); \
REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
"SegmentMax", type, index_type, functor::Lowest<type>, \
functor::NonAtomicMaxOpGpu<type>, functor::AtomicMaxOpGpu<type>);
#define REGISTER_GPU_SORTED_KERNELS_ALL(type) \
REGISTER_GPU_SORTED_KERNELS(type, int32)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SORTED_KERNELS_ALL);
#undef REGISTER_GPU_KERNEL_SORTEDSEGMENT
#undef REGISTER_GPU_SORTED_KERNELS
#undef REGISTER_GPU_SORTED_KERNELS_ALL
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
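For reference, a sketch of what one line of REGISTER_GPU_SORTED_KERNELS amounts to after substituting into REGISTER_GPU_KERNEL_SORTEDSEGMENT, shown for SegmentProd with float data and int32 indices (the further expansion of REGISTER_KERNEL_BUILDER itself is not reproduced):

// Approximate expansion of:
//   REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentProd", float, int32,
//       functor::One<float>, functor::NonAtomicProdOpGpu<float>,
//       functor::AtomicProdOpGpu<float>)
REGISTER_KERNEL_BUILDER(
    Name("SegmentProd")
        .Device(DEVICE_GPU)
        .TypeConstraint<float>("T")
        .TypeConstraint<int32>("Tindices"),
    SegmentReductionGPUOp<
        float, int32,
        functor::SegmentReductionFunctor<
            float, int32, functor::One<float>,
            functor::NonAtomicProdOpGpu<float>,
            functor::AtomicProdOpGpu<float> > >);

With this change, GPU kernels for sorted SegmentProd, SegmentMin, and SegmentMax are registered alongside the existing SegmentSum kernel for every type in TF_CALL_GPU_NUMBER_TYPES, with int32 indices in this file and int64 indices in the next one.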


@@ -63,17 +63,39 @@ REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128);
#undef REGISTER_COMPLEX_CPU_KERNELS_ALL
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_SORTED_KERNELS(type, index_type) \
REGISTER_KERNEL_BUILDER(Name("SegmentSum") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<index_type>("Tindices"), \
SegmentSumGPUOp<type, index_type>)
#define REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
name, type, index_type, initial_value_functor, reduction_kernel_functor, \
atomic_reduction_kernel_functor) \
REGISTER_KERNEL_BUILDER( \
Name(name) \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<index_type>("Tindices"), \
SegmentReductionGPUOp< \
type, index_type, \
functor::SegmentReductionFunctor< \
type, index_type, initial_value_functor, \
reduction_kernel_functor, atomic_reduction_kernel_functor> >)
#define REGISTER_GPU_SORTED_KERNELS(type, index_type) \
REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
"SegmentSum", type, index_type, functor::Zero<type>, \
functor::NonAtomicSumOpGpu<type>, functor::AtomicSumOpGpu<type>); \
REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
"SegmentProd", type, index_type, functor::One<type>, \
functor::NonAtomicProdOpGpu<type>, functor::AtomicProdOpGpu<type>); \
REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
"SegmentMin", type, index_type, functor::Highest<type>, \
functor::NonAtomicMinOpGpu<type>, functor::AtomicMinOpGpu<type>); \
REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
"SegmentMax", type, index_type, functor::Lowest<type>, \
functor::NonAtomicMaxOpGpu<type>, functor::AtomicMaxOpGpu<type>);
#define REGISTER_GPU_SORTED_KERNELS_ALL(type) \
REGISTER_GPU_SORTED_KERNELS(type, int64);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SORTED_KERNELS_ALL);
#undef REGISTER_GPU_KERNEL_SORTEDSEGMENT
#undef REGISTER_GPU_SORTED_KERNELS
#undef REGISTER_GPU_SORTED_KERNELS_ALL
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM


@@ -88,18 +88,18 @@ REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex128);
#define REGISTER_REAL_GPU_UNSORTED_KERNELS(type, index_type) \
REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", type, index_type, \
functor::Lowest<type>, \
functor::MaxOpGpu<type>); \
functor::AtomicMaxOpGpu<type>); \
REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", type, index_type, \
functor::Highest<type>, \
functor::MinOpGpu<type>); \
functor::AtomicMinOpGpu<type>); \
REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
functor::One<type>, \
functor::ProdOpGpu<type>);
functor::AtomicProdOpGpu<type>);
#define REGISTER_SUM_GPU_UNSORTED_KERNELS(type, index_type) \
REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type, \
functor::Zero<type>, \
functor::SumOpGpu<type>);
functor::AtomicSumOpGpu<type>);
#define REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL(type) \
REGISTER_REAL_GPU_UNSORTED_KERNELS(type, int32)


@@ -88,18 +88,18 @@ REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex128);
#define REGISTER_REAL_GPU_UNSORTED_KERNELS(type, index_type) \
REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", type, index_type, \
functor::Lowest<type>, \
functor::MaxOpGpu<type>); \
functor::AtomicMaxOpGpu<type>); \
REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", type, index_type, \
functor::Highest<type>, \
functor::MinOpGpu<type>); \
functor::AtomicMinOpGpu<type>); \
REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
functor::One<type>, \
functor::ProdOpGpu<type>);
functor::AtomicProdOpGpu<type>);
#define REGISTER_SUM_GPU_UNSORTED_KERNELS(type, index_type) \
REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type, \
functor::Zero<type>, \
functor::SumOpGpu<type>);
functor::AtomicSumOpGpu<type>);
#define REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL(type) \
REGISTER_REAL_GPU_UNSORTED_KERNELS(type, int64)


@@ -987,11 +987,13 @@ tf_py_test(
],
)
tf_py_test(
cuda_py_test(
name = "segment_reduction_ops_test",
size = "medium",
srcs = ["segment_reduction_ops_test.py"],
shard_count = 10,
# TODO (b/173835746): the test fails with XLA.
xla_enable_strict_auto_jit = False,
deps = [
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",