From abc55107eb7a03fe3d83f95fd5e1b8e4def90826 Mon Sep 17 00:00:00 2001
From: Mingxing Tan
Date: Tue, 12 Jun 2018 17:02:51 -0700
Subject: [PATCH] Autotune bias_add_grad GPU kernel using two candidates:
 customized bias_add_grad and reduce_sum.

PiperOrigin-RevId: 200306546
---
 tensorflow/core/kernels/BUILD             |   5 +-
 tensorflow/core/kernels/bias_op.cc        | 142 +++++++++++++++++++++-
 tensorflow/core/kernels/bias_op_gpu.cu.cc |  28 +++++
 tensorflow/core/kernels/bias_op_gpu.h     |  87 +++++++++++++
 4 files changed, 256 insertions(+), 6 deletions(-)

diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 5e4c8a78b06..6487cd39712 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3381,7 +3381,10 @@ tf_kernel_library(
 tf_kernel_library(
     name = "bias_op",
     prefix = "bias_op",
-    deps = NN_DEPS,
+    deps = NN_DEPS + if_cuda([
+        ":reduction_ops",
+        "@cub_archive//:cub",
+    ]),
 )
 
 tf_kernel_library(
diff --git a/tensorflow/core/kernels/bias_op.cc b/tensorflow/core/kernels/bias_op.cc
index 9fda7169a8b..127c3a1ca16 100644
--- a/tensorflow/core/kernels/bias_op.cc
+++ b/tensorflow/core/kernels/bias_op.cc
@@ -29,6 +29,7 @@ limitations under the License.
 #if GOOGLE_CUDA
 #include "tensorflow/core/kernels/bias_op_gpu.h"
 #include "tensorflow/core/platform/stream_executor.h"
+#include "tensorflow/stream_executor/cuda/cuda_stream.h"
 #endif  // GOOGLE_CUDA
 
 namespace tensorflow {
@@ -363,6 +364,40 @@ class BiasOp<GPUDevice, T> : public BinaryOp<T> {
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
 #undef REGISTER_GPU_KERNEL
 
+struct BiasGradAutotuneGroup {
+  static string name() { return "BiasGrad"; }
+};
+
+class BiasAddGradGPUConfig {
+ public:
+  BiasAddGradGPUConfig() : mode_(BiasAddGradGPUMode::kReduction) {}
+  string ToString() const {
+    if (mode_ == BiasAddGradGPUMode::kNative) {
+      return "native CUDA kernel.";
+    }
+    if (mode_ == BiasAddGradGPUMode::kReduction) {
+      return "cub reduction kernel.";
+    }
+    return "unknown kernel.";
+  }
+  BiasAddGradGPUMode get_mode() const { return mode_; }
+  void set_mode(BiasAddGradGPUMode val) { mode_ = val; }
+
+  bool operator==(const BiasAddGradGPUConfig& other) const {
+    return this->mode_ == other.get_mode();
+  }
+
+  bool operator!=(const BiasAddGradGPUConfig& other) const {
+    return !(*this == other);
+  }
+
+ private:
+  BiasAddGradGPUMode mode_;
+};
+typedef AutoTuneSingleton<BiasGradAutotuneGroup, BiasAddParams,
+                          BiasAddGradGPUConfig>
+    AutotuneBiasGrad;
+
 template <typename T>
 class BiasGradOp<GPUDevice, T> : public OpKernel {
  public:
@@ -377,6 +412,49 @@ class BiasGradOp<GPUDevice, T> : public OpKernel {
     }
   }
 
+  void ComputeWithCustomKernel(OpKernelContext* context,
+                               const Tensor& output_backprop, int32 batch,
+                               int32 width, int32 height, int32 channel,
+                               Tensor* output) {
+    BiasGradGPU<T>::compute(context->template eigen_device<Device>(),
+                            output_backprop.template flat<T>().data(),
+                            output->flat<T>().data(), batch, width, height,
+                            channel, data_format_);
+  }
+
+  void ComputeWithReduceSum(OpKernelContext* context,
+                            const Tensor& output_backprop, int32 batch,
+                            int32 width, int32 height, int32 channel,
+                            Tensor* output) {
+    if (data_format_ == FORMAT_NCHW) {
+      int32 row_count = batch * channel;
+      int32 col_count = height * width;
+      Tensor temp_grad_outputs;
+      // For 'NCHW' format, we perform reduction twice: first HW, then N.
+      TensorShape temp_grad_output_shape{row_count, col_count};
+      OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
+                                                     temp_grad_output_shape,
+                                                     &temp_grad_outputs));
+      BiasGradGPU<T>::DoRowReduction(
+          context, temp_grad_outputs.flat<T>().data(),
+          output_backprop.template flat<T>().data(), row_count, col_count);
+
+      row_count = batch;
+      col_count = channel;
+      BiasGradGPU<T>::DoColReduction(context, output->flat<T>().data(),
+                                     temp_grad_outputs.flat<T>().data(),
+                                     row_count, col_count);
+    } else {
+      // For 'NHWC', we simply apply reduction once on NHW.
+      int32 row_count = batch * height * width;
+      int32 col_count = channel;
+      BiasGradGPU<T>::DoColReduction(
+          context, const_cast<T*>(output->flat<T>().data()),
+          reinterpret_cast<const T*>(output_backprop.template flat<T>().data()),
+          row_count, col_count);
+    }
+  }
+
   void Compute(OpKernelContext* context) override {
     const Tensor& output_backprop = context->input(0);
 
@@ -396,11 +474,65 @@ class BiasGradOp<GPUDevice, T> : public OpKernel {
     se::DeviceMemoryBase output_ptr(output->flat<T>().data(),
                                     output->NumElements() * sizeof(T));
     stream->ThenMemZero(&output_ptr, output->NumElements() * sizeof(T));
-    if (output_backprop.NumElements() > 0) {
-      BiasGradGPU<T>::compute(context->template eigen_device<Device>(),
-                              output_backprop.template flat<T>().data(),
-                              output->flat<T>().data(), batch, width, height,
-                              channel, data_format_);
+    if (output_backprop.NumElements() <= 0) return;
+
+    int device_id = stream->parent()->device_ordinal();
+    DataType dtype = output_backprop.dtype();
+    BiasAddParams bias_parameters = {
+        {batch, height * width, channel},
+        data_format_,
+        dtype,
+        device_id,
+    };
+
+    // Autotune two algorithms: the customized kernel and the cub reduction.
+    BiasAddGradGPUConfig algo_config;
+    if (!AutotuneBiasGrad::GetInstance()->Find(bias_parameters, &algo_config)) {
+      BiasGradGPUProfileResult best_result;
+      // Initialize the timer.
+      perftools::gputools::Timer timer(stream->parent());
+      stream->InitTimer(&timer);
+      stream->ThenStartTimer(&timer);
+      ComputeWithCustomKernel(context, output_backprop, batch, width, height,
+                              channel, output);
+      stream->ThenStopTimer(&timer);
+      uint64 elapsed_microseconds = timer.Microseconds();
+      VLOG(1) << "BiasAddGrad " << bias_parameters.ToString()
+              << " Native algo latency: " << elapsed_microseconds;
+      if (elapsed_microseconds < best_result.elapsed_time()) {
+        best_result.set_algorithm(BiasAddGradGPUMode::kNative);
+        best_result.set_elapsed_time(elapsed_microseconds);
+      }
+
+      // Try reduction and profile.
+      stream->ThenStartTimer(&timer);
+      ComputeWithReduceSum(context, output_backprop, batch, width, height,
+                           channel, output);
+      stream->ThenStopTimer(&timer);
+
+      elapsed_microseconds = timer.Microseconds();
+      VLOG(1) << "BiasAddGrad " << bias_parameters.ToString()
+              << " Reduction algo latency: " << elapsed_microseconds;
+      if (elapsed_microseconds < best_result.elapsed_time()) {
+        best_result.set_algorithm(BiasAddGradGPUMode::kReduction);
+        best_result.set_elapsed_time(elapsed_microseconds);
+      }
+
+      algo_config.set_mode(best_result.algorithm());
+      AutotuneBiasGrad::GetInstance()->Insert(bias_parameters, algo_config);
+
+      // Results are already available during autotune, so no need to continue.
+      return;
+    }
+
+    // Choose the best algorithm based on autotune results.
+    if (algo_config.get_mode() == BiasAddGradGPUMode::kReduction) {
+      ComputeWithReduceSum(context, output_backprop, batch, width, height,
+                           channel, output);
+    } else {
+      // Default to the customized kernel.
+      ComputeWithCustomKernel(context, output_backprop, batch, width, height,
+                              channel, output);
     }
   }
 
diff --git a/tensorflow/core/kernels/bias_op_gpu.cu.cc b/tensorflow/core/kernels/bias_op_gpu.cu.cc
index 754b93b073a..1a7211a7cba 100644
--- a/tensorflow/core/kernels/bias_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/bias_op_gpu.cu.cc
@@ -24,6 +24,14 @@ limitations under the License.
 #include "tensorflow/core/kernels/bias_op_gpu.h"
 #include "tensorflow/core/util/cuda_kernel_helper.h"
 
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/types.h"
+
+#include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h"
+#include "tensorflow/core/kernels/reduction_ops_common.h"
+
 namespace tensorflow {
 
 typedef Eigen::GpuDevice GPUDevice;
@@ -239,6 +247,26 @@ void BiasGradGPU<T>::compute(const GPUDevice& d, const T* output_backprop,
   }
 }
 
+template <typename T>
+void BiasGradGPU<T>::DoRowReduction(OpKernelContext* context, T* output,
+                                    const T* input, int rows, int cols) {
+  typedef const Eigen::array<TTypes<float>::Tensor::Index, 1>& ReductionAxes;
+  Constants<GPUDevice> constants;
+  cub::Sum op;
+  functor::ReduceImpl<T, cub::Sum, T*, const T*, ReductionAxes>(
+      context, output, input, 2, rows, cols, 1, 1, constants.kOne, op);
+}
+
+template <typename T>
+void BiasGradGPU<T>::DoColReduction(OpKernelContext* context, T* output,
+                                    const T* input, int rows, int cols) {
+  typedef const Eigen::array<TTypes<float>::Tensor::Index, 1>& ReductionAxes;
+  Constants<GPUDevice> constants;
+  cub::Sum op;
+  functor::ReduceImpl<T, cub::Sum, T*, const T*, ReductionAxes>(
+      context, output, input, 2, rows, cols, 1, 1, constants.kZero, op);
+}
+
 #define DEFINE_GPU_SPECS(T)   \
   template struct BiasGPU<T>; \
   template struct BiasGradGPU<T>;
diff --git a/tensorflow/core/kernels/bias_op_gpu.h b/tensorflow/core/kernels/bias_op_gpu.h
index 9f14cc296f6..60c274c8262 100644
--- a/tensorflow/core/kernels/bias_op_gpu.h
+++ b/tensorflow/core/kernels/bias_op_gpu.h
@@ -19,7 +19,9 @@ limitations under the License.
 #define EIGEN_USE_GPU
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/gpu_utils.h"
 #include "tensorflow/core/util/tensor_format.h"
 
 namespace tensorflow {
@@ -38,6 +40,91 @@ struct BiasGradGPU {
   static void compute(const GPUDevice& device, const T* output_backprop,
                       T* bias_backprop, int32 batch, int32 height, int32 width,
                       int32 channel, TensorFormat data_format);
+
+  static void DoRowReduction(OpKernelContext* context, T* output,
+                             const T* input, int rows, int cols);
+
+  static void DoColReduction(OpKernelContext* context, T* output,
+                             const T* input, int rows, int cols);
+};
+
+enum class BiasAddGradGPUMode {
+  kInvalid = 0,
+  kNative = 1,
+  kReduction = 2,
+};
+
+// Describes the BiasGradGPU result from a perf experiment.
+//
+// Arguments:
+//   algorithm: returns the method to use for bias add grad.
+//   elapsed_time: returns the measured elapsed time in microseconds.
+class BiasGradGPUProfileResult {
+ public:
+  bool is_valid() const {
+    return (algorithm_ != BiasAddGradGPUMode::kInvalid &&
+            elapsed_time_ != std::numeric_limits<uint64>::max());
+  }
+  BiasAddGradGPUMode algorithm() const { return algorithm_; }
+  void set_algorithm(BiasAddGradGPUMode val) { algorithm_ = val; }
+  uint64 elapsed_time() const { return elapsed_time_; }
+  void set_elapsed_time(uint64 val) { elapsed_time_ = val; }
+
+ private:
+  BiasAddGradGPUMode algorithm_ = BiasAddGradGPUMode::kInvalid;
+  uint64 elapsed_time_ = std::numeric_limits<uint64>::max();
+};
+
+// Encapsulates all the shape information that is used in bias add grad
+// operations.
+class BiasAddParams {
+ public:
+  // We use a list to maintain both the shape value and the order (data format).
+  using SpatialArray = gtl::InlinedVector<int64, 4>;
+  BiasAddParams(const SpatialArray& in_shape, TensorFormat data_format,
+                DataType dtype, int device_id)
+      : in_shape_(in_shape),
+        data_format_(data_format),
+        dtype_(dtype),
+        device_id_(device_id) {
+    for (int64 val : in_shape_) {
+      hash_code_ = Hash64Combine(hash_code_, val);
+    }
+    hash_code_ = Hash64Combine(hash_code_, data_format);
+    hash_code_ = Hash64Combine(hash_code_, dtype);
+    hash_code_ = Hash64Combine(hash_code_, device_id);
+  }
+  bool operator==(const BiasAddParams& other) const {
+    return this->get_data_as_tuple() == other.get_data_as_tuple();
+  }
+
+  bool operator!=(const BiasAddParams& other) const {
+    return !(*this == other);
+  }
+  uint64 hash() const { return hash_code_; }
+
+  string ToString() const {
+    // clang-format off
+    return strings::StrCat(
+        "(", str_util::Join(in_shape_, ", "), "), ",
+        data_format_, ", ", dtype_, ", ", device_id_);
+    // clang-format on
+  }
+
+ protected:
+  using ParamsDataType = std::tuple<SpatialArray, TensorFormat, DataType, int>;
+
+  ParamsDataType get_data_as_tuple() const {
+    return std::make_tuple(in_shape_, data_format_, dtype_, device_id_);
+  }
+
+  uint64 hash_code_ = 0;
+
+ private:
+  SpatialArray in_shape_;
+  TensorFormat data_format_;
+  DataType dtype_;
+  int device_id_;
 };
 
 }  // namespace tensorflow