Autotune bias_add_grad GPU kernel using two candidates: customized bias_add_grad and reduce_sum.

PiperOrigin-RevId: 200306546
Mingxing Tan 2018-06-12 17:02:51 -07:00 committed by TensorFlower Gardener
parent 19011fa34b
commit abc55107eb
4 changed files with 256 additions and 6 deletions
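The change follows the usual profile-and-cache autotune pattern: key each problem by its shape, data format, dtype, and device; time every candidate kernel the first time that key is seen; remember the winner; and dispatch straight to it on later calls. A minimal, self-contained C++ sketch of that control flow (generic stand-in names, not the TensorFlow API; the real code below keys on BiasAddParams, caches in an AutoTuneSingleton, and times with a GPU timer):

#include <chrono>
#include <functional>
#include <map>
#include <vector>

enum class Candidate { kNative, kReduction };  // cf. BiasAddGradGPUMode below
using Key = std::vector<long long>;            // cf. BiasAddParams (shape, format, dtype, device)

Candidate AutotuneOnce(const Key& key,
                       const std::function<void()>& native,
                       const std::function<void()>& reduction) {
  static std::map<Key, Candidate> cache;  // plays the role of AutoTuneSingleton
  auto found = cache.find(key);
  if (found != cache.end()) return found->second;  // already tuned: dispatch directly

  auto time_one = [](const std::function<void()>& fn) {
    auto start = std::chrono::steady_clock::now();
    fn();  // run the candidate once
    return std::chrono::steady_clock::now() - start;
  };
  Candidate best = (time_one(native) <= time_one(reduction))
                       ? Candidate::kNative
                       : Candidate::kReduction;
  cache[key] = best;  // remember the winner for this key
  return best;
}

Note that during the profiling pass both candidates have already written the output, which is why the real Compute() below returns early after inserting the autotune result.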


@@ -3381,7 +3381,10 @@ tf_kernel_library(
tf_kernel_library(
name = "bias_op",
prefix = "bias_op",
deps = NN_DEPS,
deps = NN_DEPS + if_cuda([
":reduction_ops",
"@cub_archive//:cub",
]),
)
tf_kernel_library(


@@ -29,6 +29,7 @@ limitations under the License.
#if GOOGLE_CUDA
#include "tensorflow/core/kernels/bias_op_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/stream_executor/cuda/cuda_stream.h"
#endif // GOOGLE_CUDA
namespace tensorflow {
@@ -363,6 +364,40 @@ class BiasOp<GPUDevice, T> : public BinaryOp<T> {
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
struct BiasGradAutotuneGroup {
static string name() { return "BiasGrad"; }
};
class BiasAddGradGPUConfig {
public:
BiasAddGradGPUConfig() : mode_(BiasAddGradGPUMode::kReduction) {}
string ToString() const {
if (mode_ == BiasAddGradGPUMode::kNative) {
return "native CUDA kernel.";
}
if (mode_ == BiasAddGradGPUMode::kReduction) {
return "cub reduction kernel.";
}
return "unknown kernel.";
}
BiasAddGradGPUMode get_mode() const { return mode_; }
void set_mode(BiasAddGradGPUMode val) { mode_ = val; }
bool operator==(const BiasAddGradGPUConfig& other) const {
return this->mode_ == other.get_mode();
}
bool operator!=(const BiasAddGradGPUConfig& other) const {
return !(*this == other);
}
private:
BiasAddGradGPUMode mode_;
};
typedef AutoTuneSingleton<BiasGradAutotuneGroup, BiasAddParams,
BiasAddGradGPUConfig>
AutotuneBiasGrad;
template <typename T>
class BiasGradOp<GPUDevice, T> : public OpKernel {
public:
@@ -377,6 +412,49 @@ class BiasGradOp<GPUDevice, T> : public OpKernel {
}
}
void ComputeWithCustomKernel(OpKernelContext* context,
const Tensor& output_backprop, int32 batch,
int32 width, int32 height, int32 channel,
Tensor* output) {
BiasGradGPU<T>::compute(context->template eigen_device<Device>(),
output_backprop.template flat<T>().data(),
output->flat<T>().data(), batch, width, height,
channel, data_format_);
}
void ComputeWithReduceSum(OpKernelContext* context,
const Tensor& output_backprop, int32 batch,
int32 width, int32 height, int32 channel,
Tensor* output) {
if (data_format_ == FORMAT_NCHW) {
int32 row_count = batch * channel;
int32 col_count = height * width;
Tensor temp_grad_outputs;
// For 'NCHW' format, the channel dimension sits between N and HW, so we
// perform the reduction twice: first over HW, then over N.
TensorShape temp_grad_output_shape{row_count, col_count};
OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
temp_grad_output_shape,
&temp_grad_outputs));
BiasGradGPU<T>::DoRowReduction(
context, temp_grad_outputs.flat<T>().data(),
output_backprop.template flat<T>().data(), row_count, col_count);
row_count = batch;
col_count = channel;
BiasGradGPU<T>::DoColReduction(context, output->flat<T>().data(),
temp_grad_outputs.flat<T>().data(),
row_count, col_count);
} else {
// For 'NHWC', we simply apply reduction once on NHW.
int32 row_count = batch * height * width;
int32 col_count = channel;
BiasGradGPU<T>::DoColReduction(context, output->flat<T>().data(),
output_backprop.template flat<T>().data(),
row_count, col_count);
}
}
void Compute(OpKernelContext* context) override {
const Tensor& output_backprop = context->input(0);
@@ -396,11 +474,65 @@ class BiasGradOp<GPUDevice, T> : public OpKernel {
se::DeviceMemoryBase output_ptr(output->flat<T>().data(),
output->NumElements() * sizeof(T));
stream->ThenMemZero(&output_ptr, output->NumElements() * sizeof(T));
if (output_backprop.NumElements() > 0) {
BiasGradGPU<T>::compute(context->template eigen_device<Device>(),
output_backprop.template flat<T>().data(),
output->flat<T>().data(), batch, width, height,
channel, data_format_);
if (output_backprop.NumElements() <= 0) return;
int device_id = stream->parent()->device_ordinal();
DataType dtype = output_backprop.dtype();
BiasAddParams bias_parameters = {
{batch, height * width, channel},
data_format_,
dtype,
device_id,
};
// Autotune two algorithms: the customized kernel and the CUB reduction kernel.
BiasAddGradGPUConfig algo_config;
if (!AutotuneBiasGrad::GetInstance()->Find(bias_parameters, &algo_config)) {
BiasGradGPUProfileResult best_result;
// Initialize the timer.
perftools::gputools::Timer timer(stream->parent());
stream->InitTimer(&timer);
stream->ThenStartTimer(&timer);
ComputeWithCustomKernel(context, output_backprop, batch, width, height,
channel, output);
stream->ThenStopTimer(&timer);
uint64 elapsed_microseconds = timer.Microseconds();
VLOG(1) << "BiasAddGrad " << bias_parameters.ToString()
<< " Native algo latency: " << elapsed_microseconds;
if (elapsed_microseconds < best_result.elapsed_time()) {
best_result.set_algorithm(BiasAddGradGPUMode::kNative);
best_result.set_elapsed_time(elapsed_microseconds);
}
// Try reduction and profile.
stream->ThenStartTimer(&timer);
ComputeWithReduceSum(context, output_backprop, batch, width, height,
channel, output);
stream->ThenStopTimer(&timer);
elapsed_microseconds = timer.Microseconds();
VLOG(1) << "BiasAddGrad " << bias_parameters.ToString()
<< " Reduction algo latency: " << elapsed_microseconds;
if (elapsed_microseconds < best_result.elapsed_time()) {
best_result.set_algorithm(BiasAddGradGPUMode::kReduction);
best_result.set_elapsed_time(elapsed_microseconds);
}
algo_config.set_mode(best_result.algorithm());
AutotuneBiasGrad::GetInstance()->Insert(bias_parameters, algo_config);
// Results are already available during autotune, so no need to continue.
return;
}
// Choose the best algorithm based on autotune results.
if (algo_config.get_mode() == BiasAddGradGPUMode::kReduction) {
ComputeWithReduceSum(context, output_backprop, batch, width, height,
channel, output);
} else {
// Default to the customized kernel.
ComputeWithCustomKernel(context, output_backprop, batch, width, height,
channel, output);
}
}
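To make the NCHW branch of ComputeWithReduceSum above concrete, consider an illustrative gradient of shape [N, C, H, W] = [32, 64, 7, 7] (numbers chosen purely as an example):

pass 1: view output_backprop as [N*C, H*W] = [2048, 49] and row-reduce over H*W, giving 2048 partial sums, one per (n, c) pair;
pass 2: view those partial sums as [N, C] = [32, 64] and column-reduce over N, giving the final bias gradient of length C = 64.

For NHWC, a single column reduction over the [N*H*W, C] view yields the length-C result directly.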


@@ -24,6 +24,14 @@ limitations under the License.
#include "tensorflow/core/kernels/bias_op_gpu.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h"
#include "tensorflow/core/kernels/reduction_ops_common.h"
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
@@ -239,6 +247,26 @@ void BiasGradGPU<T>::compute(const GPUDevice& d, const T* output_backprop,
}
}
template <typename T>
void BiasGradGPU<T>::DoRowReduction(OpKernelContext* context, T* output,
const T* input, int rows, int cols) {
typedef const Eigen::array<TTypes<float>::Tensor::Index, 1>& ReductionAxes;
Constants<GPUDevice> constants;
cub::Sum op;
functor::ReduceImpl<T, cub::Sum, T*, const T*, ReductionAxes>(
context, output, input, 2, rows, cols, 1, 1, constants.kOne, op);
}
template <typename T>
void BiasGradGPU<T>::DoColReduction(OpKernelContext* context, T* output,
const T* input, int rows, int cols) {
typedef const Eigen::array<TTypes<float>::Tensor::Index, 1>& ReductionAxes;
Constants<GPUDevice> constants;
cub::Sum op;
functor::ReduceImpl<T, cub::Sum, T*, const T*, ReductionAxes>(
context, output, input, 2, rows, cols, 1, 1, constants.kZero, op);
}
#define DEFINE_GPU_SPECS(T) \
template struct BiasGPU<T>; \
template struct BiasGradGPU<T>;
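For reference, the kZero/kOne constants passed to ReduceImpl above select which axis of the [rows, cols] view is summed: kOne reduces the inner (column) axis and kZero the outer (row) axis. A tiny CPU-only sketch of the semantics DoRowReduction and DoColReduction are expected to provide (plain loops standing in for the CUB-based kernels):

#include <vector>

// out[r] = sum over c of in[r * cols + c]   (reduce the column axis; cf. kOne)
void RowReduce(const std::vector<float>& in, std::vector<float>* out,
               int rows, int cols) {
  out->assign(rows, 0.0f);
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < cols; ++c) (*out)[r] += in[r * cols + c];
}

// out[c] = sum over r of in[r * cols + c]   (reduce the row axis; cf. kZero)
void ColReduce(const std::vector<float>& in, std::vector<float>* out,
               int rows, int cols) {
  out->assign(cols, 0.0f);
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < cols; ++c) (*out)[c] += in[r * cols + c];
}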


@@ -19,7 +19,9 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
@@ -38,6 +40,91 @@ struct BiasGradGPU {
static void compute(const GPUDevice& device, const T* output_backprop,
T* bias_backprop, int32 batch, int32 height, int32 width,
int32 channel, TensorFormat data_format);
static void DoRowReduction(OpKernelContext* context, T* output,
const T* input, int rows, int cols);
static void DoColReduction(OpKernelContext* context, T* output,
const T* input, int rows, int cols);
};
enum class BiasAddGradGPUMode {
kInvalid = 0,
kNative = 1,
kReduction = 2,
};
// Describe the BiasGradGPU result from a perf experiment.
//
// Arguments:
// algorithm: returns the method to use for bias add grad.
// elapsed_time: returns the measured elapsed time in microseconds.
class BiasGradGPUProfileResult {
public:
bool is_valid() const {
return (algorithm_ != BiasAddGradGPUMode::kInvalid &&
elapsed_time_ != std::numeric_limits<uint64>::max());
}
BiasAddGradGPUMode algorithm() const { return algorithm_; }
void set_algorithm(BiasAddGradGPUMode val) { algorithm_ = val; }
uint64 elapsed_time() const { return elapsed_time_; }
void set_elapsed_time(uint64 val) { elapsed_time_ = val; }
private:
BiasAddGradGPUMode algorithm_ = BiasAddGradGPUMode::kInvalid;
uint64 elapsed_time_ = std::numeric_limits<uint64>::max();
};
// Encapsulate all the shape information that is used in bias add grad
// operations.
class BiasAddParams {
public:
// We use a list to maintain both the shape values and their order (data format).
using SpatialArray = gtl::InlinedVector<int64, 4>;
BiasAddParams(const SpatialArray& in_shape, TensorFormat data_format,
DataType dtype, int device_id)
: in_shape_(in_shape),
data_format_(data_format),
dtype_(dtype),
device_id_(device_id) {
for (int64 val : in_shape_) {
hash_code_ = Hash64Combine(hash_code_, val);
}
hash_code_ = Hash64Combine(hash_code_, data_format);
hash_code_ = Hash64Combine(hash_code_, dtype);
hash_code_ = Hash64Combine(hash_code_, device_id);
}
bool operator==(const BiasAddParams& other) const {
return this->get_data_as_tuple() == other.get_data_as_tuple();
}
bool operator!=(const BiasAddParams& other) const {
return !(*this == other);
}
uint64 hash() const { return hash_code_; }
string ToString() const {
// clang-format off
return strings::StrCat(
"(", str_util::Join(in_shape_, ", "), "), ",
data_format_, ", ", dtype_, ", ", device_id_);
// clang-format on
}
protected:
using ParamsDataType = std::tuple<SpatialArray, TensorFormat, DataType, int>;
ParamsDataType get_data_as_tuple() const {
return std::make_tuple(in_shape_, data_format_, dtype_, device_id_);
}
uint64 hash_code_ = 0;
private:
SpatialArray in_shape_;
TensorFormat data_format_;
DataType dtype_;
int device_id_;
};
} // namespace tensorflow