Rename `Segment*ReductionF` to `*ReductionF`

Rename the GPU segment reduction functors: the `Segment*ReductionF` template
parameters become `*ReductionF`, the atomic functors `*AtomicOpGpu` become
`Atomic*OpGpu`, and the non-atomic functors `*OpGpu` become `NonAtomic*OpGpu`.
Add a REGISTER_KERNEL_BUILDER registration for int32 indices and fix a
compilation problem.
parent f85360f53b
commit 587385e0fd
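The renames settle on one convention: `NonAtomic*OpGpu` functors update memory that only the calling thread owns, while `Atomic*OpGpu` functors publish results through hardware atomics. The following standalone CUDA sketch is illustrative only (it is not TensorFlow code; the `ToySegmentSum` kernel is invented for the example, and plain `atomicAdd` stands in for TensorFlow's `GpuAtomicAdd` wrapper). It shows how such a functor pair can be combined the way the sorted-segment kernel in this diff combines them.

// Standalone sketch of the NonAtomic*/Atomic* functor split (not TF code).
#include <cstdio>

template <typename T>
struct NonAtomicSumOpGpu {
  __device__ __forceinline__ void operator()(T* dest, const T& value) const {
    *dest += value;  // safe only when no other thread writes *dest
  }
};

template <typename T>
struct AtomicSumOpGpu {
  __device__ __forceinline__ void operator()(T* dest, const T& value) const {
    atomicAdd(dest, value);  // safe when several threads may write *dest
  }
};

// Each thread accumulates its own partial sum with the non-atomic functor,
// then publishes it once with the atomic functor -- the same split the
// sorted-segment kernel below makes at segment boundaries.
template <typename T, typename ReductionF, typename AtomicReductionF>
__global__ void ToySegmentSum(const T* in, int n, T* out) {
  T partial = T(0);
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    ReductionF()(&partial, in[i]);
  }
  AtomicReductionF()(out, partial);
}

int main() {
  const float host_in[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  float* in = nullptr;
  float* out = nullptr;
  cudaMalloc(reinterpret_cast<void**>(&in), sizeof(host_in));
  cudaMalloc(reinterpret_cast<void**>(&out), sizeof(float));
  cudaMemcpy(in, host_in, sizeof(host_in), cudaMemcpyHostToDevice);
  cudaMemset(out, 0, sizeof(float));
  ToySegmentSum<float, NonAtomicSumOpGpu<float>, AtomicSumOpGpu<float>>
      <<<1, 4>>>(in, 8, out);
  float result = 0.f;
  cudaMemcpy(&result, out, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("sum = %f\n", result);  // expected: 36
  cudaFree(in);
  cudaFree(out);
  return 0;
}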
@@ -47,7 +47,7 @@ typedef Eigen::GpuDevice GPUDevice;
 // data: input data tensor.
 // output: output reshaped to {output_rows, output.size/output_rows}
 template <typename T, typename Index, typename InitialValueF,
-          typename SegmentReductionF, typename SegmentAtomicReductionF>
+          typename ReductionF, typename AtomicReductionF>
 struct SegmentReductionFunctor {
   void operator()(OpKernelContext* ctx, const GPUDevice& d,
                   const Index output_rows, const TensorShape& segment_ids_shape,
@@ -71,7 +71,7 @@ struct UnsortedSegmentFunctor {
 
 // Atomic reduction functors for the gpu.
 template <typename T>
-struct SumAtomicOpGpu {
+struct AtomicSumOpGpu {
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
                                                         const T& value) {
     GpuAtomicAdd(dest, value);
@@ -79,7 +79,7 @@ struct SumAtomicOpGpu {
 };
 
 template <typename T>
-struct ProdAtomicOpGpu {
+struct AtomicProdOpGpu {
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
                                                         const T& value) {
     GpuAtomicMul(dest, value);
@@ -87,7 +87,7 @@ struct ProdAtomicOpGpu {
 };
 
 template <typename T>
-struct MaxAtomicOpGpu {
+struct AtomicMaxOpGpu {
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
                                                         const T& value) {
     GpuAtomicMax(dest, value);
@@ -95,16 +95,16 @@ struct MaxAtomicOpGpu {
 };
 
 template <typename T>
-struct MinAtomicOpGpu {
+struct AtomicMinOpGpu {
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
                                                         const T& value) {
     GpuAtomicMin(dest, value);
   }
 };
 
-// Reduction functors for the gpu.
+// Non-atomic reduction functors for the gpu.
 template <typename T>
-struct SumOpGpu {
+struct NonAtomicSumOpGpu {
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
                                                         const T& value) {
     *dest += value;
@@ -112,7 +112,7 @@ struct SumOpGpu {
 };
 
 template <typename T>
-struct ProdOpGpu {
+struct NonAtomicProdOpGpu {
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
                                                         const T& value) {
     *dest *= value;
@@ -120,7 +120,7 @@ struct ProdOpGpu {
 };
 
 template <typename T>
-struct MaxOpGpu {
+struct NonAtomicMaxOpGpu {
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
                                                         const T& value) {
     *dest = max(*dest, value);
@@ -128,7 +128,7 @@ struct MaxOpGpu {
 };
 
 template <typename T>
-struct MinOpGpu {
+struct NonAtomicMinOpGpu {
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
                                                         const T& value) {
     *dest = min(*dest, value);
@@ -52,12 +52,12 @@ using GPUDevice = Eigen::GpuDevice;
 // rows of input data and all reduction elements share one inner
 // dimension index.
 template <typename T, typename Index, int OuterDimTileSize,
-          typename SegmentReductionF, typename SegmentAtomicReductionF>
+          typename ReductionF, typename AtomicReductionF>
 __global__ void SortedSegmentReductionCustomKernel(
     const Index input_outer_dim_size, const Index inner_dim_size,
     const Index output_outer_dim_size, const Index* __restrict__ segment_ids,
     const T* __restrict__ input, T* __restrict__ output,
-    const Index total_stripe_count,const T initial_value) {
+    const Index total_stripe_count, const T initial_value) {
   for (int stripe_index : GpuGridRangeX(total_stripe_count)) {
     const Index segment_offset = stripe_index % inner_dim_size;
     const Index input_outer_dim_index_base =
@@ -82,13 +82,13 @@ __global__ void SortedSegmentReductionCustomKernel(
         // Decide whether to write result to global memory using atomic
         // operations.
         if (last_output_segment_id == first_segment_id) {
-          SegmentAtomicReductionF()(output + output_index, reduce_res);
+          AtomicReductionF()(output + output_index, reduce_res);
         } else {
-          SegmentReductionF()(output + output_index, reduce_res);
+          ReductionF()(output + output_index, reduce_res);
         }
         reduce_res = initial_value;
       }
-      SegmentReductionF()(&reduce_res,
+      ReductionF()(&reduce_res,
                    ldg(input + (input_outer_dim_index_base + j)
                        * inner_dim_size + segment_offset));
       last_output_segment_id = current_output_segment_id;
@@ -98,7 +98,7 @@ __global__ void SortedSegmentReductionCustomKernel(
     // the following strip.
     const Index output_index =
        last_output_segment_id * inner_dim_size + segment_offset;
-    SegmentAtomicReductionF()(output + output_index, reduce_res);
+    AtomicReductionF()(output + output_index, reduce_res);
   }
 }
 
@@ -128,9 +128,9 @@ __global__ void UnsortedSegmentCustomKernel(
 namespace functor {
 
 template <typename T, typename Index, typename InitialValueF,
-          typename SegmentReductionF, typename SegmentAtomicReductionF>
-void SegmentReductionFunctor<T, Index, InitialValueF, SegmentReductionF,
-                             SegmentAtomicReductionF>::operator()(
+          typename ReductionF, typename AtomicReductionF>
+void SegmentReductionFunctor<T, Index, InitialValueF, ReductionF,
+                             AtomicReductionF>::operator()(
     OpKernelContext* ctx, const GPUDevice& d, const Index output_rows,
     const TensorShape& segment_ids_shape,
     typename TTypes<Index>::ConstFlat segment_ids, const Index data_size,
@@ -168,8 +168,7 @@ void SegmentReductionFunctor<T, Index, InitialValueF, SegmentReductionF,
   config = GetGpuLaunchConfig(total_stripe_count, d);
   TF_CHECK_OK(GpuLaunchKernel(
       SortedSegmentReductionCustomKernel<T, Index, OuterDimTileSize,
-                                         SegmentReductionF,
-                                         SegmentAtomicReductionF>,
+                                         ReductionF, AtomicReductionF>,
       config.block_count, config.thread_per_block, 0, d.stream(),
      input_outer_dim_size, input_inner_dim_size, output_rows,
      segment_ids.data(), data, output.data(), total_stripe_count,
@@ -216,13 +215,13 @@ struct UnsortedSegmentFunctor<GPUDevice, T, Index, InitialValueF, ReductionF> {
 
 #define DEFINE_SORTED_GPU_SPECS_INDEX(T, Index) \
   template struct SegmentReductionFunctor<T, Index, functor::Zero<T>, \
-      functor::SumOpGpu<T>, functor::SumAtomicOpGpu<T>>; \
+      functor::NonAtomicSumOpGpu<T>, functor::AtomicSumOpGpu<T>>; \
   template struct SegmentReductionFunctor<T, Index, functor::One<T>, \
-      functor::ProdOpGpu<T>, functor::ProdAtomicOpGpu<T>>; \
+      functor::NonAtomicProdOpGpu<T>, functor::AtomicProdOpGpu<T>>; \
   template struct SegmentReductionFunctor<T, Index, functor::Highest<T>, \
-      functor::MinOpGpu<T>, functor::MinAtomicOpGpu<T>>; \
+      functor::NonAtomicMinOpGpu<T>, functor::AtomicMinOpGpu<T>>; \
   template struct SegmentReductionFunctor<T, Index, functor::Lowest<T>, \
-      functor::MaxOpGpu<T>, functor::MaxAtomicOpGpu<T>>;
+      functor::NonAtomicMaxOpGpu<T>, functor::AtomicMaxOpGpu<T>>;
 
 #define DEFINE_SORTED_GPU_SPECS(T) \
   DEFINE_SORTED_GPU_SPECS_INDEX(T, int32); \
@@ -232,16 +231,16 @@ TF_CALL_GPU_NUMBER_TYPES(DEFINE_SORTED_GPU_SPECS);
 
 #define DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, Index) \
   template struct UnsortedSegmentFunctor< \
-      GPUDevice, T, Index, functor::Lowest<T>, functor::MaxAtomicOpGpu<T>>; \
+      GPUDevice, T, Index, functor::Lowest<T>, functor::AtomicMaxOpGpu<T>>; \
   template struct UnsortedSegmentFunctor< \
-      GPUDevice, T, Index, functor::Highest<T>, functor::MinAtomicOpGpu<T>>; \
+      GPUDevice, T, Index, functor::Highest<T>, functor::AtomicMinOpGpu<T>>; \
   template struct UnsortedSegmentFunctor<GPUDevice, T, Index, functor::One<T>, \
-                                         functor::ProdAtomicOpGpu<T>>;
+                                         functor::AtomicProdOpGpu<T>>;
 
 // Sum is the only op that supports all input types currently.
 #define DEFINE_SUM_UNSORTED_GPU_SPECS_INDEX(T, Index) \
   template struct UnsortedSegmentFunctor< \
-      GPUDevice, T, Index, functor::Zero<T>, functor::SumAtomicOpGpu<T>>;
+      GPUDevice, T, Index, functor::Zero<T>, functor::AtomicSumOpGpu<T>>;
 
 #define DEFINE_REAL_GPU_SPECS(T) \
   DEFINE_REAL_UNSORTED_GPU_SPECS_INDEX(T, int32); \
@@ -113,17 +113,44 @@ REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128);
 #undef REGISTER_COMPLEX_CPU_KERNELS_ALL
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-#define REGISTER_GPU_SORTED_KERNELS(type, index_type) \
-  REGISTER_KERNEL_BUILDER(Name("SegmentSum") \
-                              .Device(DEVICE_GPU) \
-                              .TypeConstraint<type>("T") \
-                              .TypeConstraint<index_type>("Tindices"), \
-                          SegmentSumGPUOp<type, index_type>)
+#define REGISTER_GPU_KERNEL_SORTEDSEGMENT( \
+    name, type, index_type, initial_value_functor, \
+    reduction_kernel_functor, atomic_reduction_kernel_functor) \
+  REGISTER_KERNEL_BUILDER( \
+      Name(name) \
+          .Device(DEVICE_GPU) \
+          .TypeConstraint<type>("T") \
+          .TypeConstraint<index_type>("Tindices"), \
+      SegmentReductionGPUOp< \
+          type, index_type, \
+          functor::SegmentReductionFunctor<type, index_type, \
+                                           initial_value_functor, \
+                                           reduction_kernel_functor, \
+                                           atomic_reduction_kernel_functor> >)
+
+#define REGISTER_GPU_SORTED_KERNELS(type, index_type) \
+  REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentSum", type, index_type, \
+                                    functor::Zero<type>, \
+                                    functor::NonAtomicSumOpGpu<type>, \
+                                    functor::AtomicSumOpGpu<type>); \
+  REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentProd", type, index_type, \
+                                    functor::One<type>, \
+                                    functor::NonAtomicProdOpGpu<type>, \
+                                    functor::AtomicProdOpGpu<type>); \
+  REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentMin", type, index_type, \
+                                    functor::Highest<type>, \
+                                    functor::NonAtomicMinOpGpu<type>, \
+                                    functor::AtomicMinOpGpu<type>); \
+  REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentMax", type, index_type, \
+                                    functor::Lowest<type>, \
+                                    functor::NonAtomicMaxOpGpu<type>, \
+                                    functor::AtomicMaxOpGpu<type>);
 
 #define REGISTER_GPU_SORTED_KERNELS_ALL(type) \
   REGISTER_GPU_SORTED_KERNELS(type, int32)
 
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SORTED_KERNELS_ALL);
+#undef REGISTER_GPU_KERNEL_SORTEDSEGMENT
 #undef REGISTER_GPU_SORTED_KERNELS
 #undef REGISTER_GPU_SORTED_KERNELS_ALL
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
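For reference, this is roughly what one expansion of the new REGISTER_GPU_KERNEL_SORTEDSEGMENT macro above produces for SegmentSum with float data and int32 indices. It is hand-expanded and whitespace-adjusted for readability, assumes the TensorFlow headers and namespaces already included in these files, and is illustrative only; the build relies on the macro itself.

// Approximate hand expansion of:
//   REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentSum", float, int32,
//                                     functor::Zero<float>,
//                                     functor::NonAtomicSumOpGpu<float>,
//                                     functor::AtomicSumOpGpu<float>);
REGISTER_KERNEL_BUILDER(
    Name("SegmentSum")
        .Device(DEVICE_GPU)
        .TypeConstraint<float>("T")
        .TypeConstraint<int32>("Tindices"),
    SegmentReductionGPUOp<
        float, int32,
        functor::SegmentReductionFunctor<float, int32, functor::Zero<float>,
                                         functor::NonAtomicSumOpGpu<float>,
                                         functor::AtomicSumOpGpu<float> > >);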
@@ -81,20 +81,20 @@ REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128);
 #define REGISTER_GPU_SORTED_KERNELS(type, index_type) \
   REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentSum", type, index_type, \
                                     functor::Zero<type>, \
-                                    functor::SumOpGpu<type>, \
-                                    functor::SumAtomicOpGpu<type>); \
+                                    functor::NonAtomicSumOpGpu<type>, \
+                                    functor::AtomicSumOpGpu<type>); \
   REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentProd", type, index_type, \
                                     functor::One<type>, \
-                                    functor::ProdOpGpu<type>, \
-                                    functor::ProdAtomicOpGpu<type>); \
+                                    functor::NonAtomicProdOpGpu<type>, \
+                                    functor::AtomicProdOpGpu<type>); \
   REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentMin", type, index_type, \
                                     functor::Highest<type>, \
-                                    functor::MinOpGpu<type>, \
-                                    functor::MinAtomicOpGpu<type>); \
+                                    functor::NonAtomicMinOpGpu<type>, \
+                                    functor::AtomicMinOpGpu<type>); \
   REGISTER_GPU_KERNEL_SORTEDSEGMENT("SegmentMax", type, index_type, \
                                     functor::Lowest<type>, \
-                                    functor::MaxOpGpu<type>, \
-                                    functor::MaxAtomicOpGpu<type>);
+                                    functor::NonAtomicMaxOpGpu<type>, \
+                                    functor::AtomicMaxOpGpu<type>);
 
 #define REGISTER_GPU_SORTED_KERNELS_ALL(type) \
   REGISTER_GPU_SORTED_KERNELS(type, int64);
@@ -88,18 +88,18 @@ REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex128);
 #define REGISTER_REAL_GPU_UNSORTED_KERNELS(type, index_type) \
   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", type, index_type, \
                                       functor::Lowest<type>, \
-                                      functor::MaxAtomicOpGpu<type>); \
+                                      functor::AtomicMaxOpGpu<type>); \
   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", type, index_type, \
                                       functor::Highest<type>, \
-                                      functor::MinAtomicOpGpu<type>); \
+                                      functor::AtomicMinOpGpu<type>); \
   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
                                       functor::One<type>, \
-                                      functor::ProdAtomicOpGpu<type>);
+                                      functor::AtomicProdOpGpu<type>);
 
 #define REGISTER_SUM_GPU_UNSORTED_KERNELS(type, index_type) \
   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type, \
                                       functor::Zero<type>, \
-                                      functor::SumAtomicOpGpu<type>);
+                                      functor::AtomicSumOpGpu<type>);
 
 #define REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL(type) \
   REGISTER_REAL_GPU_UNSORTED_KERNELS(type, int32)
@@ -88,18 +88,18 @@ REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex128);
 #define REGISTER_REAL_GPU_UNSORTED_KERNELS(type, index_type) \
   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMax", type, index_type, \
                                       functor::Lowest<type>, \
-                                      functor::MaxAtomicOpGpu<type>); \
+                                      functor::AtomicMaxOpGpu<type>); \
   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentMin", type, index_type, \
                                       functor::Highest<type>, \
-                                      functor::MinAtomicOpGpu<type>); \
+                                      functor::AtomicMinOpGpu<type>); \
   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentProd", type, index_type, \
                                       functor::One<type>, \
-                                      functor::ProdAtomicOpGpu<type>);
+                                      functor::AtomicProdOpGpu<type>);
 
 #define REGISTER_SUM_GPU_UNSORTED_KERNELS(type, index_type) \
   REGISTER_GPU_KERNEL_UNSORTEDSEGMENT("UnsortedSegmentSum", type, index_type, \
                                       functor::Zero<type>, \
-                                      functor::SumAtomicOpGpu<type>);
+                                      functor::AtomicSumOpGpu<type>);
 
 #define REGISTER_REAL_GPU_UNSORTED_KERNELS_ALL(type) \
   REGISTER_REAL_GPU_UNSORTED_KERNELS(type, int64)