Add GPU implementation for tf.segment_sum. (#11630)
* Add GPU implementation for tf.segment_sum. * Refactor segment sum to compute asynchronously. * Add GPU tests and change test datatype to accommodate GPU kernel input requirement. * Add benchmarks. * Benchmark results against baseline unsorted impl. The columns are: datatypes, outer dimension, output outer dimension, inner dimension, execution time difference against baseline impl as a percentage of the time taken by the baseline impl (positive values mean faster execution than baseline). fp32 512 256 512 -11.8632 fp64 512 256 512 -10.1854 fp32 512 256 2048 -6.2147 fp64 512 256 2048 -0.106 fp32 512 256 8192 -0.0867 fp64 512 256 8192 3.6285 fp32 512 256 1120 -10.9163 fp64 512 256 1120 -1.3509 fp32 512 256 1215 -2.0428 fp64 512 256 1215 -7.884 fp32 512 256 1856 -7.3159 fp64 512 256 1856 -0.1011 fp32 512 256 1302 -3.7802 fp64 512 256 1302 4.6675 fp32 512 256 1329 -13.2275 fp64 512 256 1329 4.3505 fp32 512 256 1531 -7.5993 fp64 512 256 1531 -3.5371 fp32 512 256 1313 -5.0677 fp64 512 256 1313 -1.368 fp32 512 256 1672 -0.0907 fp64 512 256 1672 4.5809 fp32 512 256 1851 -10.2862 fp64 512 256 1851 2.3119 fp32 512 256 1584 -10.3406 fp64 512 256 1584 1.4481 fp32 512 64 512 -6.3544 fp64 512 64 512 -18.4343 fp32 512 64 2048 -9.6639 fp64 512 64 2048 0.0714 fp32 512 64 8192 1.2097 fp64 512 64 8192 10.4839 fp32 512 64 1120 -20.1102 fp64 512 64 1120 -5.0784 fp32 512 64 1215 -6.4061 fp64 512 64 1215 -14.1781 fp32 512 64 1856 2.4221 fp64 512 64 1856 -8.4205 fp32 512 64 1302 -10.2403 fp64 512 64 1302 -7.0577 fp32 512 64 1329 -10.515 fp64 512 64 1329 -4.4899 fp32 512 64 1531 -18.4045 fp64 512 64 1531 9.5982 fp32 512 64 1313 -17.0858 fp64 512 64 1313 -2.02 fp32 512 64 1672 -6.1997 fp64 512 64 1672 1.1616 fp32 512 64 1851 -2.3241 fp64 512 64 1851 -1.0585 fp32 512 64 1584 -7.2199 fp64 512 64 1584 -1.0865 fp32 512 16 512 -14.8754 fp64 512 16 512 -10.987 fp32 512 16 2048 -18.3725 fp64 512 16 2048 1.5949 fp32 512 16 8192 2.163 fp64 512 16 8192 12.5431 fp32 512 16 1120 2.2558 fp64 512 16 1120 -2.7135 fp32 512 16 1215 -12.7228 fp64 512 16 1215 11.0343 fp32 512 16 1856 -6.986 fp64 512 16 1856 7.0687 fp32 512 16 1302 -9.3881 fp64 512 16 1302 -8.2974 fp32 512 16 1329 -11.5103 fp64 512 16 1329 18.9707 fp32 512 16 1531 -14.3721 fp64 512 16 1531 9.6774 fp32 512 16 1313 -16.5546 fp64 512 16 1313 11.7528 fp32 512 16 1672 -10.3689 fp64 512 16 1672 15.1197 fp32 512 16 1851 -11.6021 fp64 512 16 1851 8.2983 fp32 512 16 1584 -13.6702 fp64 512 16 1584 9.4635 fp32 2048 1024 512 -5.6482 fp64 2048 1024 512 1.45 fp32 2048 1024 2048 -0.066 fp64 2048 1024 2048 3.6549 fp32 2048 1024 8192 4.3953 fp64 2048 1024 8192 5.0636 fp32 2048 1024 1120 2.3119 fp64 2048 1024 1120 3.4102 fp32 2048 1024 1215 1.6251 fp64 2048 1024 1215 2.4538 fp32 2048 1024 1856 1.4219 fp64 2048 1024 1856 5.2966 fp32 2048 1024 1302 0.4938 fp64 2048 1024 1302 3.6871 fp32 2048 1024 1329 2.7753 fp64 2048 1024 1329 4.2955 fp32 2048 1024 1531 1.8766 fp64 2048 1024 1531 4.4579 fp32 2048 1024 1313 0.6639 fp64 2048 1024 1313 4.5556 fp32 2048 1024 1672 1.1072 fp64 2048 1024 1672 3.8653 fp32 2048 1024 1851 1.1566 fp64 2048 1024 1851 3.6434 fp32 2048 1024 1584 0.7806 fp64 2048 1024 1584 4.3265 fp32 2048 256 512 -10.7236 fp64 2048 256 512 1.011 fp32 2048 256 2048 2.2321 fp64 2048 256 2048 12.2771 fp32 2048 256 8192 8.0287 fp64 2048 256 8192 15.4497 fp32 2048 256 1120 -8.1388 fp64 2048 256 1120 5.8003 fp32 2048 256 1215 1.709 fp64 2048 256 1215 12.4369 fp32 2048 256 1856 5.1844 fp64 2048 256 1856 14.2236 fp32 2048 256 1302 2.8457 fp64 2048 256 1302 10.5728 fp32 2048 256 1329 -2.547 fp64 2048 256 1329 12.1123 fp32 2048 256 1531 2.4946 fp64 2048 256 1531 12.2398 fp32 2048 256 1313 6.1621 fp64 2048 256 1313 9.857 fp32 2048 256 1672 2.176 fp64 2048 256 1672 9.8899 fp32 2048 256 1851 4.6307 fp64 2048 256 1851 15.0223 fp32 2048 256 1584 3.5238 fp64 2048 256 1584 10.3181 fp32 2048 64 512 -11.5325 fp64 2048 64 512 8.5141 fp32 2048 64 2048 0.6066 fp64 2048 64 2048 25.8166 fp32 2048 64 8192 15.5994 fp64 2048 64 8192 29.453 fp32 2048 64 1120 1.5933 fp64 2048 64 1120 17.1686 fp32 2048 64 1215 -11.8064 fp64 2048 64 1215 21.7897 fp32 2048 64 1856 3.3061 fp64 2048 64 1856 17.6379 fp32 2048 64 1302 -1.201 fp64 2048 64 1302 26.775 fp32 2048 64 1329 -1.377 fp64 2048 64 1329 23.6142 fp32 2048 64 1531 0.9212 fp64 2048 64 1531 16.7177 fp32 2048 64 1313 2.8448 fp64 2048 64 1313 26.824 fp32 2048 64 1672 1.5334 fp64 2048 64 1672 23.7874 fp32 2048 64 1851 0.1934 fp64 2048 64 1851 25.1446 fp32 2048 64 1584 -2.8748 fp64 2048 64 1584 22.3902 fp32 8192 4096 512 0.0512 fp64 8192 4096 512 2.8049 fp32 8192 4096 2048 3.6683 fp64 8192 4096 2048 5.7372 fp32 8192 4096 8192 6.2501 fp64 8192 4096 8192 5.6644 fp32 8192 4096 1120 3.4347 fp64 8192 4096 1120 5.9099 fp32 8192 4096 1215 4.0591 fp64 8192 4096 1215 6.2049 fp32 8192 4096 1856 4.5046 fp64 8192 4096 1856 5.9 fp32 8192 4096 1302 3.8744 fp64 8192 4096 1302 5.74 fp32 8192 4096 1329 3.9169 fp64 8192 4096 1329 6.302 fp32 8192 4096 1531 5.0479 fp64 8192 4096 1531 6.048 fp32 8192 4096 1313 3.5261 fp64 8192 4096 1313 6.0544 fp32 8192 4096 1672 4.6081 fp64 8192 4096 1672 5.2568 fp32 8192 4096 1851 4.2022 fp64 8192 4096 1851 6.0934 fp32 8192 4096 1584 3.3852 fp64 8192 4096 1584 5.6772 fp32 8192 1024 512 3.7405 fp64 8192 1024 512 16.4627 fp32 8192 1024 2048 8.3918 fp64 8192 1024 2048 18.5254 fp32 8192 1024 8192 13.7773 fp64 8192 1024 8192 17.4314 fp32 8192 1024 1120 6.2023 fp64 8192 1024 1120 16.689 fp32 8192 1024 1215 9.5441 fp64 8192 1024 1215 19.7246 fp32 8192 1024 1856 9.864 fp64 8192 1024 1856 18.2895 fp32 8192 1024 1302 7.3145 fp64 8192 1024 1302 19.8528 fp32 8192 1024 1329 9.6131 fp64 8192 1024 1329 19.5526 fp32 8192 1024 1531 8.9847 fp64 8192 1024 1531 20.3696 fp32 8192 1024 1313 7.2819 fp64 8192 1024 1313 20.5361 fp32 8192 1024 1672 11.8095 fp64 8192 1024 1672 18.3047 fp32 8192 1024 1851 12.1042 fp64 8192 1024 1851 21.8124 fp32 8192 1024 1584 9.6549 fp64 8192 1024 1584 18.1818 fp32 8192 256 512 8.2649 fp64 8192 256 512 20.9372 fp32 8192 256 2048 15.6297 fp64 8192 256 2048 35.6407 fp32 8192 256 8192 21.7055 fp64 8192 256 8192 37.225 fp32 8192 256 1120 8.322 fp64 8192 256 1120 33.6497 fp32 8192 256 1215 12.9148 fp64 8192 256 1215 40.0554 fp32 8192 256 1856 12.2226 fp64 8192 256 1856 36.2642 fp32 8192 256 1302 12.2956 fp64 8192 256 1302 40.4711 fp32 8192 256 1329 10.2045 fp64 8192 256 1329 38.4891 fp32 8192 256 1531 14.9187 fp64 8192 256 1531 40.7874 fp32 8192 256 1313 9.5106 fp64 8192 256 1313 42.1367 fp32 8192 256 1672 15.2577 fp64 8192 256 1672 36.7527 fp32 8192 256 1851 15.668 fp64 8192 256 1851 40.2035 fp32 8192 256 1584 14.126 fp64 8192 256 1584 32.7602
This commit is contained in:
parent
674db81731
commit
b5214cab61
@ -2589,7 +2589,9 @@ tf_kernel_library(
|
||||
tf_kernel_library(
|
||||
name = "segment_reduction_ops",
|
||||
prefix = "segment_reduction_ops",
|
||||
deps = MATH_DEPS,
|
||||
deps = MATH_DEPS + if_cuda([
|
||||
":cuda_solvers",
|
||||
]),
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
|
@ -16,6 +16,9 @@ limitations under the License.
|
||||
// See docs in ../ops/math_ops.cc.
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
#if GOOGLE_CUDA
|
||||
#define EIGEN_USE_GPU
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
#include "tensorflow/core/kernels/segment_reduction_ops.h"
|
||||
#include <vector>
|
||||
@ -32,6 +35,15 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/util/util.h"
|
||||
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
|
||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
||||
#include "tensorflow/core/platform/cuda.h"
|
||||
|
||||
using ::perftools::gputools::cuda::ScopedActivateExecutorContext;
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
@ -183,6 +195,105 @@ class SegmentReductionOp : public OpKernel {
|
||||
}
|
||||
};
|
||||
|
||||
#ifdef GOOGLE_CUDA
|
||||
// SegmentSumGPUOp is a segment sum operator implemented for GPU only.
|
||||
// TODO: This implementation of SegmentSumGPUOp is sometimes slower than
|
||||
// its unsorted counterpart (mostly when problem size is small).
|
||||
// This is due to the following two main reasons and a cost-effective way
|
||||
// to resolve these problems is desirable.
|
||||
// 1. Sorted segment sum requires a memory transfer from device to host in
|
||||
// order to know the size of the output dimension whereas unsorted segment
|
||||
// sum receives the size of the output dimension as an input parameter.
|
||||
// 2. Sorted segment sum is essentially a tiled version of unsorted segment
|
||||
// sum and therefore such optimization comes at an inherent cost. However
|
||||
// such cost may not be justified when the problem size is small. When to
|
||||
// use the tiled version or the untiled version depends on many factors
|
||||
// including data alignments, ratio of calculation to memory traffic and
|
||||
// obviously, the problem sizes.
|
||||
template <class T, class Index>
|
||||
class SegmentSumGPUOp : public AsyncOpKernel {
|
||||
public:
|
||||
explicit SegmentSumGPUOp(OpKernelConstruction* context)
|
||||
: AsyncOpKernel(context) {}
|
||||
|
||||
void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
|
||||
const Tensor& input = context->input(0);
|
||||
const Tensor& segment_ids = context->input(1);
|
||||
|
||||
OP_REQUIRES_ASYNC(
|
||||
context, TensorShapeUtils::IsVector(segment_ids.shape()),
|
||||
errors::InvalidArgument("segment_ids should be a vector."), done);
|
||||
|
||||
const int64 num_indices = segment_ids.NumElements();
|
||||
OP_REQUIRES_ASYNC(
|
||||
context, num_indices == input.dim_size(0),
|
||||
errors::InvalidArgument(
|
||||
"segment_ids should be the same size as dimension 0 of"
|
||||
" input."),
|
||||
done);
|
||||
|
||||
if (num_indices == 0) {
|
||||
TensorShape output_shape = input.shape();
|
||||
output_shape.set_dim(0, 0);
|
||||
|
||||
Tensor* output = nullptr;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
context, context->allocate_output(0, output_shape, &output), done);
|
||||
done();
|
||||
return;
|
||||
}
|
||||
|
||||
perftools::gputools::DeviceMemoryBase output_rows_device(
|
||||
(void*)(segment_ids.template flat<Index>().data() + (num_indices - 1)));
|
||||
ScratchSpace<Index> output_rows_host(context, 1, /* on_host */ true);
|
||||
|
||||
auto stream = context->op_device_context()->stream();
|
||||
OP_REQUIRES_ASYNC(
|
||||
context, stream
|
||||
->ThenMemcpy(output_rows_host.mutable_data(),
|
||||
output_rows_device, sizeof(Index))
|
||||
.ok(),
|
||||
errors::Internal(
|
||||
"SegmentSumGPUOp: failed to copy output_rows from device"),
|
||||
done);
|
||||
|
||||
functor::SegmentSumFunctor<T, Index> functor_;
|
||||
auto create_and_check_output = [context, output_rows_host, &input,
|
||||
&segment_ids, &functor_, done]() {
|
||||
// Ensure that within the callback, the proper GPU settings are
|
||||
// configured.
|
||||
auto stream = context->op_device_context()->stream();
|
||||
ScopedActivateExecutorContext scoped_activation{stream->parent()};
|
||||
|
||||
Index output_rows = *output_rows_host.data();
|
||||
output_rows++;
|
||||
OP_REQUIRES_ASYNC(context, output_rows > 0,
|
||||
errors::InvalidArgument("segment ids must be >= 0"),
|
||||
done);
|
||||
|
||||
TensorShape output_shape = input.shape();
|
||||
output_shape.set_dim(0, output_rows);
|
||||
|
||||
Tensor* output = nullptr;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
context, context->allocate_output(0, output_shape, &output), done);
|
||||
|
||||
auto output_flat = output->flat_outer_dims<T>();
|
||||
auto data_ptr = input.template flat<T>().data();
|
||||
auto segment_flat = segment_ids.flat<Index>();
|
||||
functor_(context, context->eigen_device<GPUDevice>(), output_rows,
|
||||
segment_ids.shape(), segment_flat, input.NumElements(), data_ptr,
|
||||
output_flat);
|
||||
|
||||
done();
|
||||
};
|
||||
|
||||
context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
|
||||
stream, create_and_check_output);
|
||||
}
|
||||
};
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
#define REGISTER_CPU_KERNEL_SEGMENT(name, functor, type, index_type, \
|
||||
default_value) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
@ -227,6 +338,23 @@ REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128);
|
||||
#undef REGISTER_REAL_CPU_KERNELS_ALL
|
||||
#undef REGISTER_COMPLEX_CPU_KERNELS_ALL
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#define REGISTER_GPU_SORTED_KERNELS(type, index_type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("SegmentSum") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<index_type>("Tindices"), \
|
||||
SegmentSumGPUOp<type, index_type>)
|
||||
|
||||
#define REGISTER_GPU_SORTED_KERNELS_ALL(type) \
|
||||
REGISTER_GPU_SORTED_KERNELS(type, int32); \
|
||||
REGISTER_GPU_SORTED_KERNELS(type, int64);
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SORTED_KERNELS_ALL);
|
||||
#undef REGISTER_GPU_SORTED_KERNELS
|
||||
#undef REGISTER_GPU_SORTED_KERNELS_ALL
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
namespace functor {
|
||||
|
||||
// UnsortedSegmentSumFunctor implementation for CPUDevice.
|
||||
|
@ -26,6 +26,28 @@ namespace tensorflow {
|
||||
class OpKernelContext;
|
||||
|
||||
namespace functor {
|
||||
|
||||
#ifdef GOOGLE_CUDA
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
// Functor for SegmentSumGPUOp.
|
||||
// 'output_rows': the number of output segments (unique segment ids in
|
||||
// 'segment_ids').
|
||||
// 'segment_ids_shape': shape of 'segment_ids' tensor.
|
||||
// 'segment_ids': unsorted map from input to output segment ids at which to
|
||||
// perform segment sum operation.
|
||||
// 'data_size': size of input data tensor.
|
||||
// 'data': input data tensor.
|
||||
// 'output': output reshaped to {output_rows, output.size/output_rows}
|
||||
template <typename T, typename Index>
|
||||
struct SegmentSumFunctor {
|
||||
void operator()(OpKernelContext* ctx, const GPUDevice& d,
|
||||
const Index output_rows, const TensorShape& segment_ids_shape,
|
||||
typename TTypes<Index>::ConstFlat segment_ids,
|
||||
const Index data_size, const T* data,
|
||||
typename TTypes<T, 2>::Tensor output);
|
||||
};
|
||||
#endif
|
||||
|
||||
// BaseFunctor for definition of UnsorteSegmentReductionOp
|
||||
// for usage without templates.
|
||||
template <typename Device, typename T, typename Index>
|
||||
|
@ -54,6 +54,77 @@ __device__ __forceinline__ void AccumulateInto(
|
||||
CudaAtomicAdd(dest_scalar + 1, value.imag());
|
||||
}
|
||||
|
||||
// SortedSegmentSumFunctor kernel reduces input data just as
|
||||
// UnsortedSegmentSumCustomKernel does except that input data
|
||||
// is partitioned along the outer reduction dimension. This is
|
||||
// because consecutive rows (elements in a row share the same
|
||||
// outer dimension index) in the flattened 2D input data likely
|
||||
// belong to the same segment in sorted segment sum operation.
|
||||
// Therefore such partitioning strategy has two advantages over
|
||||
// the UnsortedSegmentSumFunctor kernel:
|
||||
// 1. Each thread reduces across multiple rows before writing
|
||||
// answers to the global memory, we can therefore
|
||||
// write reduction results to global memory less often.
|
||||
// 2. We may know that the current thread is the only contributor
|
||||
// to an output element because of the increasing nature of segment
|
||||
// ids. In such cases, we do not need to use atomic operations
|
||||
// to write results to global memory.
|
||||
// In the flattened view of input data (with only outer and inner
|
||||
// dimension), every thread processes a strip of input data of
|
||||
// size OuterDimTileSize x 1. This strip runs across multiple
|
||||
// rows of input data and all reduction elements share one inner
|
||||
// dimension index.
|
||||
template <typename T, typename Index, int OuterDimTileSize>
|
||||
__global__ void SortedSegmentSumCustomKernel(const Index input_outer_dim_size,
|
||||
const Index inner_dim_size,
|
||||
const Index output_outer_dim_size,
|
||||
const Index* segment_ids,
|
||||
const T* input, T* output,
|
||||
const Index total_stripe_count) {
|
||||
CUDA_1D_KERNEL_LOOP(stripe_index, total_stripe_count) {
|
||||
const Index segment_offset = stripe_index % inner_dim_size;
|
||||
const Index input_outer_dim_index_base =
|
||||
stripe_index / inner_dim_size * Index(OuterDimTileSize);
|
||||
|
||||
T sum = T(0);
|
||||
Index first_segment_id = segment_ids[input_outer_dim_index_base];
|
||||
Index last_output_segment_id = output_outer_dim_size;
|
||||
|
||||
const Index actual_stripe_height =
|
||||
min(Index(OuterDimTileSize),
|
||||
input_outer_dim_size - input_outer_dim_index_base);
|
||||
for (Index j = 0; j < actual_stripe_height; j++) {
|
||||
Index current_output_segment_id =
|
||||
segment_ids[input_outer_dim_index_base + j];
|
||||
// Decide whether to write result to global memory.
|
||||
// Result is only written to global memory if we move
|
||||
// to another segment. Otherwise we can keep accumulating
|
||||
// locally.
|
||||
if (current_output_segment_id > last_output_segment_id) {
|
||||
const Index output_index =
|
||||
last_output_segment_id * inner_dim_size + segment_offset;
|
||||
// decide whether to write result to global memory using atomic
|
||||
// operations
|
||||
if (last_output_segment_id == first_segment_id) {
|
||||
AccumulateInto<T>(output + output_index, sum);
|
||||
} else {
|
||||
*(output + output_index) = sum;
|
||||
}
|
||||
sum = T(0);
|
||||
}
|
||||
sum += ldg(input + (input_outer_dim_index_base + j) * inner_dim_size +
|
||||
segment_offset);
|
||||
last_output_segment_id = current_output_segment_id;
|
||||
}
|
||||
// For the last result in a strip, always write using atomic operations
|
||||
// due to possible race conditions with threads computing
|
||||
// the following strip.
|
||||
const Index output_index =
|
||||
last_output_segment_id * inner_dim_size + segment_offset;
|
||||
AccumulateInto<T>(output + output_index, sum);
|
||||
}
|
||||
}
|
||||
|
||||
// UnsortedSegmentSumFunctor kernel processes 'input_total_size' elements.
|
||||
// Each element is mapped from input to output by a combination of its
|
||||
// 'segment_ids' mapping and 'inner_dim_size'.
|
||||
@ -80,6 +151,47 @@ __global__ void UnsortedSegmentSumCustomKernel(
|
||||
|
||||
namespace functor {
|
||||
|
||||
template <typename T, typename Index>
|
||||
void SegmentSumFunctor<T, Index>::operator()(
|
||||
OpKernelContext* ctx, const GPUDevice& d, const Index output_rows,
|
||||
const TensorShape& segment_ids_shape,
|
||||
typename TTypes<Index>::ConstFlat segment_ids, const Index data_size,
|
||||
const T* data, typename TTypes<T, 2>::Tensor output) {
|
||||
if (output.size() == 0) {
|
||||
return;
|
||||
}
|
||||
// Set 'output' to zeros.
|
||||
CudaLaunchConfig config = GetCudaLaunchConfig(output.size(), d);
|
||||
SetZero<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||
output.size(), output.data());
|
||||
if (data_size == 0 || segment_ids_shape.num_elements() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Launch kernel to compute sorted segment sum.
|
||||
// Notes:
|
||||
// *) 'input_total_size' is the total number of elements to process.
|
||||
// *) 'segment_ids.shape' is a prefix of data's shape.
|
||||
// *) 'input_outer_dim_size' is the total number of segments to process.
|
||||
const Index input_total_size = data_size;
|
||||
const Index input_outer_dim_size = segment_ids.dimension(0);
|
||||
const Index input_inner_dim_size = input_total_size / input_outer_dim_size;
|
||||
|
||||
const int OuterDimTileSize = 8;
|
||||
|
||||
const Index input_outer_dim_num_stripe =
|
||||
Eigen::divup(input_outer_dim_size, Index(OuterDimTileSize));
|
||||
|
||||
const Index total_stripe_count =
|
||||
input_inner_dim_size * input_outer_dim_num_stripe;
|
||||
|
||||
config = GetCudaLaunchConfig(total_stripe_count, d);
|
||||
SortedSegmentSumCustomKernel<T, Index, OuterDimTileSize><<<
|
||||
config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||
input_outer_dim_size, input_inner_dim_size, output_rows,
|
||||
segment_ids.data(), data, output.data(), total_stripe_count);
|
||||
};
|
||||
|
||||
// UnsortedSegmentSumFunctor implementation for GPUDevice.
|
||||
template <typename T, typename Index>
|
||||
struct UnsortedSegmentSumFunctor<GPUDevice, T, Index>: UnsortedSegmentBaseFunctor<GPUDevice, T, Index> {
|
||||
@ -117,6 +229,15 @@ struct UnsortedSegmentSumFunctor<GPUDevice, T, Index>: UnsortedSegmentBaseFuncto
|
||||
}
|
||||
};
|
||||
|
||||
#define DEFINE_SORTED_GPU_SPECS_INDEX(T, Index) \
|
||||
template struct SegmentSumFunctor<T, Index>
|
||||
|
||||
#define DEFINE_SORTED_GPU_SPECS(T) \
|
||||
DEFINE_SORTED_GPU_SPECS_INDEX(T, int32); \
|
||||
DEFINE_SORTED_GPU_SPECS_INDEX(T, int64);
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(DEFINE_SORTED_GPU_SPECS);
|
||||
|
||||
#define DEFINE_GPU_SPECS_INDEX(T, Index) \
|
||||
template struct UnsortedSegmentSumFunctor<GPUDevice, T, Index>
|
||||
|
||||
|
@ -683,13 +683,15 @@ cuda_py_test(
|
||||
|
||||
tf_py_test(
|
||||
name = "segment_reduction_ops_test",
|
||||
size = "small",
|
||||
size = "medium",
|
||||
srcs = ["segment_reduction_ops_test.py"],
|
||||
additional_deps = [
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python:client",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python:nn_grad",
|
||||
],
|
||||
)
|
||||
|
@ -18,12 +18,17 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import itertools
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes as dtypes_lib
|
||||
from tensorflow.python.ops import gradient_checker
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -107,8 +112,8 @@ class SegmentReductionOpTest(SegmentReductionHelper):
|
||||
curr_ops_list = complex_ops_list
|
||||
else:
|
||||
curr_ops_list = ops_list
|
||||
|
||||
with self.test_session(use_gpu=False):
|
||||
for use_gpu in [True, False]:
|
||||
with self.test_session(use_gpu=use_gpu):
|
||||
tf_x, np_x = self._input(shape, dtype=dtype)
|
||||
for np_op1, np_op2, tf_op in curr_ops_list:
|
||||
np_ans = self._segmentReduce(indices, np_x, np_op1, np_op2)
|
||||
@ -130,7 +135,8 @@ class SegmentReductionOpTest(SegmentReductionHelper):
|
||||
|
||||
def testSegmentIdsSize(self):
|
||||
shape = [4, 4]
|
||||
with self.test_session():
|
||||
for use_gpu in [True, False]:
|
||||
with self.test_session(use_gpu=use_gpu):
|
||||
tf_x, _ = self._input(shape)
|
||||
indices = [0, 1]
|
||||
s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
|
||||
@ -140,16 +146,18 @@ class SegmentReductionOpTest(SegmentReductionHelper):
|
||||
def testSegmentIdsValid(self):
|
||||
# This is a baseline for the following SegmentIdsInvalid* tests.
|
||||
shape = [4, 4]
|
||||
with self.test_session():
|
||||
tf_x, _ = self._input(shape)
|
||||
for use_gpu in [True, False]:
|
||||
with self.test_session(use_gpu=use_gpu):
|
||||
tf_x, _ = self._input(shape, dtype=dtypes_lib.float32)
|
||||
indices = [0, 0, 0, 1]
|
||||
result = math_ops.segment_sum(data=tf_x, segment_ids=indices).eval()
|
||||
self.assertAllEqual([[15, 18, 21, 24], [13, 14, 15, 16]], result)
|
||||
|
||||
def testSegmentIdsGreaterThanZero(self):
|
||||
shape = [4, 4]
|
||||
with self.test_session():
|
||||
tf_x, np_x = self._input(shape)
|
||||
for use_gpu in [True, False]:
|
||||
with self.test_session(use_gpu=use_gpu):
|
||||
tf_x, np_x = self._input(shape, dtype=dtypes_lib.float32)
|
||||
indices = [1, 1, 2, 2]
|
||||
np_ans = self._segmentReduce(indices, np_x, np.add)
|
||||
s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
|
||||
@ -158,8 +166,9 @@ class SegmentReductionOpTest(SegmentReductionHelper):
|
||||
|
||||
def testSegmentIdsHole(self):
|
||||
shape = [4, 4]
|
||||
with self.test_session():
|
||||
tf_x, np_x = self._input(shape)
|
||||
for use_gpu in [True, False]:
|
||||
with self.test_session(use_gpu=use_gpu):
|
||||
tf_x, np_x = self._input(shape, dtype=dtypes_lib.float32)
|
||||
indices = [0, 0, 3, 3]
|
||||
np_ans = self._segmentReduce(indices, np_x, np.add)
|
||||
s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
|
||||
@ -199,8 +208,9 @@ class SegmentReductionOpTest(SegmentReductionHelper):
|
||||
|
||||
def testSegmentIdsInvalid4(self):
|
||||
shape = [4, 4]
|
||||
with self.test_session():
|
||||
tf_x, _ = self._input(shape)
|
||||
for use_gpu in [True, False]:
|
||||
with self.test_session(use_gpu=use_gpu):
|
||||
tf_x, _ = self._input(shape, dtype=dtypes_lib.float32)
|
||||
indices = [0, 0, 0, -1]
|
||||
s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
|
||||
with self.assertRaisesOpError("segment ids must be >= 0"):
|
||||
@ -208,8 +218,9 @@ class SegmentReductionOpTest(SegmentReductionHelper):
|
||||
|
||||
def testSegmentIdsInvalid5(self):
|
||||
shape = [4, 4]
|
||||
with self.test_session():
|
||||
tf_x, _ = self._input(shape)
|
||||
for use_gpu in [True, False]:
|
||||
with self.test_session(use_gpu=use_gpu):
|
||||
tf_x, _ = self._input(shape, dtype=dtypes_lib.float32)
|
||||
indices = [0, 0, 0, -2]
|
||||
s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
|
||||
with self.assertRaisesOpError("segment ids must be >= 0"):
|
||||
@ -635,6 +646,64 @@ class SparseSegmentReductionOpTest(SparseSegmentReductionHelper):
|
||||
with self.assertRaisesOpError(r"Segment id 0 out of range \[0, 0\)"):
|
||||
s.eval()
|
||||
|
||||
class SegmentReductionOpBenchmark(test.Benchmark):
|
||||
outer_dim_options = [2**x for x in range(9, 14, 2)]
|
||||
ratio_options = [2**x for x in range(1, 6, 2)]
|
||||
inner_dim_options = [2**x for x in range(9, 14, 2)]
|
||||
#randomly generated sizes with less alignments
|
||||
inner_dim_options += [1120, 1215, 1856, 1302, 1329, 1531, 1313, 1672, 1851, 1584]
|
||||
dtype_options = [np.float32, np.float64]
|
||||
options = (outer_dim_options,
|
||||
ratio_options, inner_dim_options, dtype_options)
|
||||
op_functors = [lambda vc, vs, seg_ids:
|
||||
("sorted", math_ops.segment_sum(vc, vs)),
|
||||
lambda vc, vs, seg_ids:
|
||||
("unsorted", math_ops.unsorted_segment_sum(vc, vs, seg_ids[-1]+1))]
|
||||
repeat = 10
|
||||
|
||||
def _npTypeToStr(self, t):
|
||||
if t == np.float32:
|
||||
return "fp32"
|
||||
if t == np.float64:
|
||||
return "fp64"
|
||||
|
||||
def _runGraph(self, op_functor, outer_dim, ratio, inner_dim, dtype):
|
||||
output_outer_dim = int(outer_dim/ratio)
|
||||
const = np.random.randint(5, size=(outer_dim, inner_dim))
|
||||
seg_ids = np.sort(np.random.randint(
|
||||
output_outer_dim, size=outer_dim))
|
||||
vs = variables.Variable(seg_ids.astype(np.int32))
|
||||
with ops.device("/gpu:0"):
|
||||
vc = variables.Variable(const.astype(dtype))
|
||||
name, op = op_functor(vc, vs, seg_ids)
|
||||
with session.Session() as sess:
|
||||
variables.global_variables_initializer().run()
|
||||
r = self.run_op_benchmark(sess, op, min_iters=self.repeat,
|
||||
name="_".join(map(str,
|
||||
[name,
|
||||
outer_dim,
|
||||
ratio,
|
||||
inner_dim,
|
||||
self._npTypeToStr(dtype)])))
|
||||
return name, r["wall_time"]
|
||||
|
||||
def benchmarkSegmentSumGPU(self):
|
||||
if not test.is_gpu_available(cuda_only=True):
|
||||
return
|
||||
for outer_dim, ratio, inner_dim, dtype in itertools.product(*self.options):
|
||||
output_outer_dim = int(outer_dim/ratio)
|
||||
op_functor = self.op_functors[0]
|
||||
with ops.Graph().as_default():
|
||||
self._runGraph(op_functor, outer_dim, ratio, inner_dim, dtype)
|
||||
|
||||
def benchmarkUnsortedSegmentSumGPU(self):
|
||||
if not test.is_gpu_available(cuda_only=True):
|
||||
return
|
||||
for outer_dim, ratio, inner_dim, dtype in itertools.product(*self.options):
|
||||
output_outer_dim = int(outer_dim/ratio)
|
||||
op_functor = self.op_functors[1]
|
||||
with ops.Graph().as_default():
|
||||
self._runGraph(op_functor, outer_dim, ratio, inner_dim, dtype)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user