Autotune bias_add_grad GPU kernel using two candidates: customized bias_add_grad and reduce_sum.
PiperOrigin-RevId: 200306546
This commit is contained in:
parent
19011fa34b
commit
abc55107eb
@ -3381,7 +3381,10 @@ tf_kernel_library(
|
||||
tf_kernel_library(
|
||||
name = "bias_op",
|
||||
prefix = "bias_op",
|
||||
deps = NN_DEPS,
|
||||
deps = NN_DEPS + if_cuda([
|
||||
":reduction_ops",
|
||||
"@cub_archive//:cub",
|
||||
]),
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
|
@ -29,6 +29,7 @@ limitations under the License.
|
||||
#if GOOGLE_CUDA
|
||||
#include "tensorflow/core/kernels/bias_op_gpu.h"
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#include "tensorflow/stream_executor/cuda/cuda_stream.h"
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
namespace tensorflow {
|
||||
@ -363,6 +364,40 @@ class BiasOp<GPUDevice, T> : public BinaryOp<T> {
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
|
||||
#undef REGISTER_GPU_KERNEL
|
||||
|
||||
struct BiasGradAutotuneGroup {
|
||||
static string name() { return "BiasGrad"; }
|
||||
};
|
||||
|
||||
class BiasAddGradGPUConfig {
|
||||
public:
|
||||
BiasAddGradGPUConfig() : mode_(BiasAddGradGPUMode::kReduction) {}
|
||||
string ToString() const {
|
||||
if (mode_ == BiasAddGradGPUMode::kNative) {
|
||||
return "native CUDA kernel.";
|
||||
}
|
||||
if (mode_ == BiasAddGradGPUMode::kReduction) {
|
||||
return "cub reduction kernel.";
|
||||
}
|
||||
return "unknown kernel.";
|
||||
}
|
||||
BiasAddGradGPUMode get_mode() const { return mode_; }
|
||||
void set_mode(BiasAddGradGPUMode val) { mode_ = val; }
|
||||
|
||||
bool operator==(const BiasAddGradGPUConfig& other) const {
|
||||
return this->mode_ == other.get_mode();
|
||||
}
|
||||
|
||||
bool operator!=(const BiasAddGradGPUConfig& other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
|
||||
private:
|
||||
BiasAddGradGPUMode mode_;
|
||||
};
|
||||
typedef AutoTuneSingleton<BiasGradAutotuneGroup, BiasAddParams,
|
||||
BiasAddGradGPUConfig>
|
||||
AutotuneBiasGrad;
|
||||
|
||||
template <typename T>
|
||||
class BiasGradOp<GPUDevice, T> : public OpKernel {
|
||||
public:
|
||||
@ -377,6 +412,49 @@ class BiasGradOp<GPUDevice, T> : public OpKernel {
|
||||
}
|
||||
}
|
||||
|
||||
void ComputeWithCustomKernel(OpKernelContext* context,
|
||||
const Tensor& output_backprop, int32 batch,
|
||||
int32 width, int32 height, int32 channel,
|
||||
Tensor* output) {
|
||||
BiasGradGPU<T>::compute(context->template eigen_device<Device>(),
|
||||
output_backprop.template flat<T>().data(),
|
||||
output->flat<T>().data(), batch, width, height,
|
||||
channel, data_format_);
|
||||
}
|
||||
|
||||
void ComputeWithReduceSum(OpKernelContext* context,
|
||||
const Tensor& output_backprop, int32 batch,
|
||||
int32 width, int32 height, int32 channel,
|
||||
Tensor* output) {
|
||||
if (data_format_ == FORMAT_NCHW) {
|
||||
int32 row_count = batch * channel;
|
||||
int32 col_count = height * width;
|
||||
Tensor temp_grad_outputs;
|
||||
// For 'NCHW' format, we perform reduction twice: first HW, then N.
|
||||
TensorShape temp_grad_output_shape{row_count, col_count};
|
||||
OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
|
||||
temp_grad_output_shape,
|
||||
&temp_grad_outputs));
|
||||
BiasGradGPU<T>::DoRowReduction(
|
||||
context, temp_grad_outputs.flat<T>().data(),
|
||||
output_backprop.template flat<T>().data(), row_count, col_count);
|
||||
|
||||
row_count = batch;
|
||||
col_count = channel;
|
||||
BiasGradGPU<T>::DoColReduction(context, output->flat<T>().data(),
|
||||
temp_grad_outputs.flat<T>().data(),
|
||||
row_count, col_count);
|
||||
} else {
|
||||
// For 'NHWC', we simply apply reduction once on NHW.
|
||||
int32 row_count = batch * height * width;
|
||||
int32 col_count = channel;
|
||||
BiasGradGPU<T>::DoColReduction(
|
||||
context, const_cast<T*>(output->flat<T>().data()),
|
||||
reinterpret_cast<const T*>(output_backprop.template flat<T>().data()),
|
||||
row_count, col_count);
|
||||
}
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor& output_backprop = context->input(0);
|
||||
|
||||
@ -396,11 +474,65 @@ class BiasGradOp<GPUDevice, T> : public OpKernel {
|
||||
se::DeviceMemoryBase output_ptr(output->flat<T>().data(),
|
||||
output->NumElements() * sizeof(T));
|
||||
stream->ThenMemZero(&output_ptr, output->NumElements() * sizeof(T));
|
||||
if (output_backprop.NumElements() > 0) {
|
||||
BiasGradGPU<T>::compute(context->template eigen_device<Device>(),
|
||||
output_backprop.template flat<T>().data(),
|
||||
output->flat<T>().data(), batch, width, height,
|
||||
channel, data_format_);
|
||||
if (output_backprop.NumElements() <= 0) return;
|
||||
|
||||
int device_id = stream->parent()->device_ordinal();
|
||||
DataType dtype = output_backprop.dtype();
|
||||
BiasAddParams bias_parameters = {
|
||||
{batch, height * width, channel},
|
||||
data_format_,
|
||||
dtype,
|
||||
device_id,
|
||||
};
|
||||
|
||||
// Autotune two algorithm: customized
|
||||
BiasAddGradGPUConfig algo_config;
|
||||
if (!AutotuneBiasGrad::GetInstance()->Find(bias_parameters, &algo_config)) {
|
||||
BiasGradGPUProfileResult best_result;
|
||||
// Initialize the timer.
|
||||
perftools::gputools::Timer timer(stream->parent());
|
||||
stream->InitTimer(&timer);
|
||||
stream->ThenStartTimer(&timer);
|
||||
ComputeWithCustomKernel(context, output_backprop, batch, width, height,
|
||||
channel, output);
|
||||
stream->ThenStopTimer(&timer);
|
||||
uint64 elapsed_microseconds = timer.Microseconds();
|
||||
VLOG(1) << "BiasAddGrad " << bias_parameters.ToString()
|
||||
<< " Native algo latency: " << elapsed_microseconds;
|
||||
if (elapsed_microseconds < best_result.elapsed_time()) {
|
||||
best_result.set_algorithm(BiasAddGradGPUMode::kNative);
|
||||
best_result.set_elapsed_time(elapsed_microseconds);
|
||||
}
|
||||
|
||||
// Try reduction and profile.
|
||||
stream->ThenStartTimer(&timer);
|
||||
ComputeWithReduceSum(context, output_backprop, batch, width, height,
|
||||
channel, output);
|
||||
stream->ThenStopTimer(&timer);
|
||||
|
||||
elapsed_microseconds = timer.Microseconds();
|
||||
VLOG(1) << "BiasAddGrad " << bias_parameters.ToString()
|
||||
<< " Reduction algo latency: " << elapsed_microseconds;
|
||||
if (elapsed_microseconds < best_result.elapsed_time()) {
|
||||
best_result.set_algorithm(BiasAddGradGPUMode::kReduction);
|
||||
best_result.set_elapsed_time(elapsed_microseconds);
|
||||
}
|
||||
|
||||
algo_config.set_mode(best_result.algorithm());
|
||||
AutotuneBiasGrad::GetInstance()->Insert(bias_parameters, algo_config);
|
||||
|
||||
// Results are already available during autotune, so no need to continue.
|
||||
return;
|
||||
}
|
||||
|
||||
// Choose the best algorithm based on autotune results.
|
||||
if (algo_config.get_mode() == BiasAddGradGPUMode::kReduction) {
|
||||
ComputeWithReduceSum(context, output_backprop, batch, width, height,
|
||||
channel, output);
|
||||
} else {
|
||||
// Default to the customized kernel.
|
||||
ComputeWithCustomKernel(context, output_backprop, batch, width, height,
|
||||
channel, output);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -24,6 +24,14 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/bias_op_gpu.h"
|
||||
#include "tensorflow/core/util/cuda_kernel_helper.h"
|
||||
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
#include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h"
|
||||
#include "tensorflow/core/kernels/reduction_ops_common.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
@ -239,6 +247,26 @@ void BiasGradGPU<T>::compute(const GPUDevice& d, const T* output_backprop,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void BiasGradGPU<T>::DoRowReduction(OpKernelContext* context, T* output,
|
||||
const T* input, int rows, int cols) {
|
||||
typedef const Eigen::array<TTypes<float>::Tensor::Index, 1>& ReductionAxes;
|
||||
Constants<GPUDevice> constants;
|
||||
cub::Sum op;
|
||||
functor::ReduceImpl<T, cub::Sum, T*, const T*, ReductionAxes>(
|
||||
context, output, input, 2, rows, cols, 1, 1, constants.kOne, op);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void BiasGradGPU<T>::DoColReduction(OpKernelContext* context, T* output,
|
||||
const T* input, int rows, int cols) {
|
||||
typedef const Eigen::array<TTypes<float>::Tensor::Index, 1>& ReductionAxes;
|
||||
Constants<GPUDevice> constants;
|
||||
cub::Sum op;
|
||||
functor::ReduceImpl<T, cub::Sum, T*, const T*, ReductionAxes>(
|
||||
context, output, input, 2, rows, cols, 1, 1, constants.kZero, op);
|
||||
}
|
||||
|
||||
#define DEFINE_GPU_SPECS(T) \
|
||||
template struct BiasGPU<T>; \
|
||||
template struct BiasGradGPU<T>;
|
||||
|
@ -19,7 +19,9 @@ limitations under the License.
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/kernels/gpu_utils.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -38,6 +40,91 @@ struct BiasGradGPU {
|
||||
static void compute(const GPUDevice& device, const T* output_backprop,
|
||||
T* bias_backprop, int32 batch, int32 height, int32 width,
|
||||
int32 channel, TensorFormat data_format);
|
||||
|
||||
static void DoRowReduction(OpKernelContext* context, T* output,
|
||||
const T* input, int rows, int cols);
|
||||
|
||||
static void DoColReduction(OpKernelContext* context, T* output,
|
||||
const T* input, int rows, int cols);
|
||||
};
|
||||
|
||||
enum class BiasAddGradGPUMode {
|
||||
kInvalid = 0,
|
||||
kNative = 1,
|
||||
kReduction = 2,
|
||||
};
|
||||
|
||||
// Describe the BiasGradGPU result from a perf experiment.
|
||||
//
|
||||
// Arguments:
|
||||
// algorithm: returns the method to use for bias add grad.
|
||||
// elapsed_time; returns the measured elapsed time in microseconds.
|
||||
class BiasGradGPUProfileResult {
|
||||
public:
|
||||
bool is_valid() const {
|
||||
return (algorithm_ != BiasAddGradGPUMode::kInvalid &&
|
||||
elapsed_time_ != std::numeric_limits<float>::max());
|
||||
}
|
||||
BiasAddGradGPUMode algorithm() const { return algorithm_; }
|
||||
void set_algorithm(BiasAddGradGPUMode val) { algorithm_ = val; }
|
||||
uint64 elapsed_time() const { return elapsed_time_; }
|
||||
void set_elapsed_time(uint64 val) { elapsed_time_ = val; }
|
||||
|
||||
private:
|
||||
BiasAddGradGPUMode algorithm_ = BiasAddGradGPUMode::kInvalid;
|
||||
uint64 elapsed_time_ = std::numeric_limits<uint64>::max();
|
||||
};
|
||||
|
||||
// Encapsulate all the shape information that is used in bias add grad
|
||||
// operations.
|
||||
class BiasAddParams {
|
||||
public:
|
||||
// We use a list to maintain both the shape value and the order (data format).
|
||||
using SpatialArray = gtl::InlinedVector<int64, 4>;
|
||||
BiasAddParams(const SpatialArray& in_shape, TensorFormat data_format,
|
||||
DataType dtype, int device_id)
|
||||
: in_shape_(in_shape),
|
||||
data_format_(data_format),
|
||||
dtype_(dtype),
|
||||
device_id_(device_id) {
|
||||
for (int64 val : in_shape_) {
|
||||
hash_code_ = Hash64Combine(hash_code_, val);
|
||||
}
|
||||
hash_code_ = Hash64Combine(hash_code_, data_format);
|
||||
hash_code_ = Hash64Combine(hash_code_, dtype);
|
||||
hash_code_ = Hash64Combine(hash_code_, device_id);
|
||||
}
|
||||
bool operator==(const BiasAddParams& other) const {
|
||||
return this->get_data_as_tuple() == other.get_data_as_tuple();
|
||||
}
|
||||
|
||||
bool operator!=(const BiasAddParams& other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
uint64 hash() const { return hash_code_; }
|
||||
|
||||
string ToString() const {
|
||||
// clang-format off
|
||||
return strings::StrCat(
|
||||
"(", str_util::Join(in_shape_, ", "), "), ",
|
||||
data_format_, ", ", dtype_, ", ", device_id_);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
protected:
|
||||
using ParamsDataType = std::tuple<SpatialArray, TensorFormat, DataType, int>;
|
||||
|
||||
ParamsDataType get_data_as_tuple() const {
|
||||
return std::make_tuple(in_shape_, data_format_, dtype_, device_id_);
|
||||
}
|
||||
|
||||
uint64 hash_code_ = 0;
|
||||
|
||||
private:
|
||||
SpatialArray in_shape_;
|
||||
TensorFormat data_format_;
|
||||
DataType dtype_;
|
||||
int device_id_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
Loading…
x
Reference in New Issue
Block a user