Expose the launchpad for conv2d backprop, and unify the launchpads for conv2d and depthwise_conv to match the example in the documentation (see ./extend/adding_an_op.md)
PiperOrigin-RevId: 167480081
parent 0302320e11
commit 07356b48e4
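Note for readers following along with ./extend/adding_an_op.md: the diff below applies the device-specialized "launcher" pattern, in which the OpKernel's Compute() stays device-agnostic and forwards to a LaunchConv2DBackprop*Op<Device, T> functor specialization (CPU path via Eigen, GPU path via cuDNN). The following is a minimal, self-contained sketch of that pattern only; the type names, signature, and helper names are simplified stand-ins for illustration, not the TensorFlow API used in this commit.

// Standalone illustration (not TensorFlow code) of the launcher-functor
// pattern: the kernel body is device-agnostic and the per-device work lives
// in template specializations selected at compile time.
#include <cstdio>

struct CPUDevice {};  // stand-ins for Eigen::ThreadPoolDevice / Eigen::GpuDevice
struct GPUDevice {};

// Primary template: each device provides its own specialization.
template <typename Device, typename T>
struct LaunchBackpropOp;

template <typename T>
struct LaunchBackpropOp<CPUDevice, T> {
  void operator()(const T* grad, T* out, int n) {
    // A real CPU specialization would invoke the Eigen spatial-convolution
    // backprop functor here; this just copies to keep the sketch runnable.
    for (int i = 0; i < n; ++i) out[i] = grad[i];
    std::puts("CPU launcher");
  }
};

template <typename T>
struct LaunchBackpropOp<GPUDevice, T> {
  void operator()(const T* grad, T* out, int n) {
    // A real GPU specialization would set up descriptors and call cuDNN.
    for (int i = 0; i < n; ++i) out[i] = grad[i];
    std::puts("GPU launcher");
  }
};

// Device-agnostic "Compute": delegates to the right launcher.
template <typename Device, typename T>
void Compute(const T* grad, T* out, int n) {
  LaunchBackpropOp<Device, T>()(grad, out, n);
}

int main() {
  float grad[4] = {1, 2, 3, 4}, out[4];
  Compute<CPUDevice, float>(grad, out, 4);
  Compute<GPUDevice, float>(grad, out, 4);
}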
@@ -91,6 +91,20 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename T>
struct LaunchConv2DBackpropFilterOp<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
                  const Tensor& out_backprop, const Tensor& input,
                  int row_stride, int col_stride, const Padding& padding,
                  Tensor* filter_backprop, TensorFormat data_format) {
    const CPUDevice& d = ctx->eigen_device<CPUDevice>();
    functor::SpatialConvolutionBackwardKernel<CPUDevice, T>()(
        d, filter_backprop->tensor<T, 4>(), input.tensor<T, 4>(),
        out_backprop.tensor<T, 4>(), filter_backprop->dim_size(0),
        filter_backprop->dim_size(1), row_stride, col_stride);
  }
};

#ifdef TENSORFLOW_USE_LIBXSMM
template <typename Device, class T>
struct LaunchXsmmBackwardFilter {
@@ -237,11 +251,9 @@ class Conv2DFastBackpropFilterOp : public OpKernel {
    }
#endif

    functor::SpatialConvolutionBackwardKernel<Device, T>()(
        context->eigen_device<Device>(), filter_backprop->tensor<T, 4>(),
        input.tensor<T, 4>(), out_backprop.tensor<T, 4>(),
        dims.spatial_dims[0].filter_size, dims.spatial_dims[1].filter_size,
        dims.spatial_dims[0].stride, dims.spatial_dims[1].stride);
    LaunchConv2DBackpropFilterOp<Device, T>()(
        context, false, false, out_backprop, input, dims.spatial_dims[0].stride,
        dims.spatial_dims[1].stride, padding_, filter_backprop, data_format_);
  }

 private:
@@ -495,15 +507,10 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
    OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
    use_cudnn_ &= CanUseCudnn();
    cudnn_use_autotune_ = CudnnUseAutotune();
    cudnn_disable_conv_1x1_optimization_ = CudnnDisableConv1x1Optimization();
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    using perftools::gputools::dnn::AlgorithmConfig;
    using perftools::gputools::dnn::AlgorithmType;
    using perftools::gputools::dnn::ProfileResult;
    using perftools::gputools::dnn::kDefaultAlgorithm;
    const Tensor& input = context->input(0);
    const Tensor& filter_sizes = context->input(1);
    const Tensor& out_backprop = context->input(2);
@@ -512,340 +519,22 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
                errors::InvalidArgument(
                    "Conv2DBackpropFilter: filter_sizes input must be 1-dim, not ",
                    filter_sizes.dims()));
    const TensorShape& input_shape = input.shape();
    TensorShape filter_shape;
    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                filter_sizes.vec<int32>(), &filter_shape));

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(context,
                   ConvBackpropComputeDimensions(
                       "Conv2DSlowBackpropFilter", /*num_spatial_dims=*/2,
                       input.shape(), filter_shape, out_backprop.shape(),
                       strides_, padding_, data_format_, &dims));

    Tensor* filter_backprop = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, filter_shape, &filter_backprop));

    const int padding_rows =
        (padding_ == VALID)
            ? 0
            : std::max<int>(0, (dims.spatial_dims[0].output_size - 1) *
                                       dims.spatial_dims[0].stride +
                                   dims.spatial_dims[0].filter_size -
                                   dims.spatial_dims[0].input_size);
    const int padding_cols =
        (padding_ == VALID)
            ? 0
            : std::max<int>(0, (dims.spatial_dims[1].output_size - 1) *
                                       dims.spatial_dims[1].stride +
                                   dims.spatial_dims[1].filter_size -
                                   dims.spatial_dims[1].input_size);
    // For now we take the stride from the second and third dimensions only (we
    // do not support striding on the batch or depth dimension).
    const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
    const int stride_cols = GetTensorDim(strides_, data_format_, 'W');

    // TODO(zhengxq): cuDNN only supports equal padding on both sides, so only
    // calling it when that is true. Remove this check when (if?) cuDNN starts
    // supporting different padding.
    bool rows_odd = (padding_rows % 2 != 0);
    bool cols_odd = (padding_cols % 2 != 0);

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    if (!use_cudnn_) {
      context->SetStatus(errors::Unimplemented(
          "Conv2DBackprop for GPU is not currently supported "
          "without cudnn"));
      return;
    }

    if (!cudnn_disable_conv_1x1_optimization_ &&
        dims.spatial_dims[0].filter_size == 1 &&
        dims.spatial_dims[1].filter_size == 1 &&
        dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 &&
        data_format_ == FORMAT_NHWC) {
      const uint64 m = dims.in_depth;
      const uint64 k = dims.batch_size * dims.spatial_dims[0].input_size *
                       dims.spatial_dims[1].input_size;
      const uint64 n = dims.out_depth;

      // The shape of output backprop is
      // [batch, out_rows, out_cols, out_depth]
      // From cublas's perspective, it is: n x k
      auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                  out_backprop.template flat<T>().size());

      // The shape of input is
      // [batch, in_rows, in_cols, in_depth],
      // From cublas's perspective, it is: m x k
      auto b_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                  input.template flat<T>().size());

      // the shape of the filter backprop from the conv_2d should be
      // [1, 1, in_depth, out_depth]
      // From cublas's perspective, it is: n x m
      auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
                                  filter_backprop->template flat<T>().size());

      bool blas_launch_status =
          stream
              ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose,
                             perftools::gputools::blas::Transpose::kTranspose,
                             n, m, k, 1.0f, a_ptr, n, b_ptr, m, 0.0f, &c_ptr, n)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                            ", n=", n, ", k=", k));
      }
      return;
    } else if (dims.spatial_dims[0].filter_size ==
                   dims.spatial_dims[0].input_size &&
               dims.spatial_dims[1].filter_size ==
                   dims.spatial_dims[1].input_size &&
               padding_ == VALID && data_format_ == FORMAT_NHWC) {
      // The input data and filter have the same height/width, so call cublas
      // directly.
      const uint64 m = dims.spatial_dims[0].input_size *
                       dims.spatial_dims[1].input_size * dims.in_depth;
      const uint64 k = dims.batch_size;
      const uint64 n = dims.out_depth;

      auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                  input.template flat<T>().size());
      auto b_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                  out_backprop.template flat<T>().size());
      auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
                                  filter_backprop->template flat<T>().size());

      bool blas_launch_status =
          stream
              ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose,
                             perftools::gputools::blas::Transpose::kTranspose,
                             n, m, k, 1.0f, b_ptr, n, a_ptr, m, 0.0f, &c_ptr, n)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                            ", n=", n, ", k=", k));
      }
      return;
    }

    Tensor compatible_input;
    if (rows_odd || cols_odd) {
      // If a padding dimension is odd, we have one more element on the right
      // side or the bottom side. This is unsupported in cudnn. Therefore,
      // we pad that extra element and make it compatible.
      OP_REQUIRES_OK(
          context,
          context->allocate_temp(
              DataTypeToEnum<T>::value,
              ShapeFromFormat(data_format_, dims.batch_size,
                              dims.spatial_dims[0].input_size + rows_odd,
                              dims.spatial_dims[1].input_size + cols_odd,
                              dims.in_depth),
              &compatible_input));

      functor::PadInput<GPUDevice, T, int, 4>()(
          context->template eigen_device<GPUDevice>(),
          To32Bit(input.tensor<T, 4>()), {{0, 0}}, {{rows_odd, cols_odd}},
          To32Bit(compatible_input.tensor<T, 4>()), data_format_);
    } else {
      compatible_input = input;
    }

    CHECK(padding_rows >= 0 && padding_cols >= 0)
        << "Negative row or col paddings: (" << padding_rows << ", "
        << padding_cols << ")";
    perftools::gputools::dnn::BatchDescriptor input_desc;
    input_desc.set_count(dims.batch_size)
        .set_height(GetTensorDim(compatible_input, data_format_, 'H'))
        .set_width(GetTensorDim(compatible_input, data_format_, 'W'))
        .set_feature_map_count(dims.in_depth)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
    perftools::gputools::dnn::BatchDescriptor output_desc;
    output_desc.set_count(dims.batch_size)
        .set_height(dims.spatial_dims[0].output_size)
        .set_width(dims.spatial_dims[1].output_size)
        .set_feature_map_count(dims.out_depth)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
    perftools::gputools::dnn::FilterDescriptor filter_desc;
    filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
        .set_input_filter_width(dims.spatial_dims[1].filter_size)
        .set_input_feature_map_count(dims.in_depth)
        .set_output_feature_map_count(dims.out_depth);
    perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
    conv_desc.set_vertical_filter_stride(dims.spatial_dims[0].stride)
        .set_horizontal_filter_stride(dims.spatial_dims[1].stride)
        .set_zero_padding_height(padding_rows / 2)
        .set_zero_padding_width(padding_cols / 2);

    // NOTE(zhengxq):
    // cuDNN only supports the following layouts :
    // Input : B x D x R x C
    // Filter : OD x ID x R x C
    // Whereas, we have
    // Input : B x R x C x D
    // Filter : R x C x ID x OD
    // TransformFilter performs (R x C x ID x OD) => (OD x ID x R x C)
    // The first TransformDepth performs
    // (B x R x C x D) => (B x D x R x C).
    // Since the tensor returned from cuDNN is B x D x R x C also,
    // the second TransformDepth performs
    // (B x D x R x C) => (B x R x C x D).

    Tensor pre_transformed_filter_backprop;
    OP_REQUIRES_OK(context, context->allocate_temp(
                                DataTypeToEnum<T>::value,
                                TensorShape({dims.out_depth, dims.in_depth,
                                             dims.spatial_dims[0].filter_size,
                                             dims.spatial_dims[1].filter_size}),
                                &pre_transformed_filter_backprop));

    Tensor transformed_out_backprop;
    if (data_format_ == FORMAT_NHWC) {
      TensorShape nchw_shape = ShapeFromFormat(
          FORMAT_NCHW, dims.batch_size, dims.spatial_dims[0].output_size,
          dims.spatial_dims[1].output_size, dims.out_depth);
      if (dims.out_depth > 1) {
        OP_REQUIRES_OK(context, context->allocate_temp(
                                    DataTypeToEnum<T>::value, nchw_shape,
                                    &transformed_out_backprop));
        functor::NHWCToNCHW<Device, T, 4>()(
            context->eigen_device<Device>(), out_backprop.tensor<T, 4>(),
            transformed_out_backprop.tensor<T, 4>());
      } else {
        // If depth <= 1, just reshape.
        CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
      }
    } else {
      transformed_out_backprop = out_backprop;
    }

    Tensor transformed_input;
    if (data_format_ == FORMAT_NHWC) {
      TensorShape nchw_shape = ShapeFromFormat(
          FORMAT_NCHW, GetTensorDim(compatible_input, data_format_, 'N'),
          GetTensorDim(compatible_input, data_format_, 'H'),
          GetTensorDim(compatible_input, data_format_, 'W'),
          GetTensorDim(compatible_input, data_format_, 'C'));
      if (nchw_shape.dim_size(1) > 1) {
        OP_REQUIRES_OK(context,
                       context->allocate_temp(DataTypeToEnum<T>::value,
                                              nchw_shape, &transformed_input));
        functor::NHWCToNCHW<Device, T, 4>()(
            context->eigen_device<Device>(),
            const_cast<const Tensor&>(compatible_input).tensor<T, 4>(),
            transformed_input.tensor<T, 4>());
      } else {
        // If depth <= 1, just reshape.
        CHECK(transformed_input.CopyFrom(compatible_input, nchw_shape));
      }
    } else {
      transformed_input = compatible_input;
    }

    auto out_backprop_ptr =
        AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
                       transformed_out_backprop.template flat<T>().size());
    auto filter_backprop_ptr = AsDeviceMemory(
        pre_transformed_filter_backprop.template flat<T>().data(),
        pre_transformed_filter_backprop.template flat<T>().size());
    auto input_ptr =
        AsDeviceMemory(transformed_input.template flat<T>().data(),
                       transformed_input.template flat<T>().size());

    static int64 ConvolveBackwardFilterScratchSize = GetCudnnWorkspaceLimit(
        "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB by default
    );
    int device_id = stream->parent()->device_ordinal();
    DataType dtype = input.dtype();
    ConvParameters conv_parameters = {
        dims.batch_size,                      // batch
        dims.in_depth,                        // in_depths
        {{input_desc.height(),                // in_rows
          input_desc.width()}},               // in_cols
        dims.out_depth,                       // out_depths
        {{dims.spatial_dims[0].filter_size,   // filter_rows
          dims.spatial_dims[1].filter_size}}, // filter_cols
        {{dims.spatial_dims[0].stride,        // stride_rows
          dims.spatial_dims[1].stride}},      // stride_cols
        {{padding_rows,                       // padding_rows
          padding_cols}},                     // padding_cols
        dtype,                                // tensor datatype
        device_id,                            // device_id
    };
    AlgorithmConfig algorithm_config;
    if (cudnn_use_autotune_ && !AutoTuneConvBwdFilter::GetInstance()->Find(
                                   conv_parameters, &algorithm_config)) {
      std::vector<AlgorithmType> algorithms;
      CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(
          conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
      ProfileResult best_result;
      ProfileResult best_result_no_scratch;
      for (auto profile_algorithm : algorithms) {
        // TODO(zhengxq): profile each algorithm multiple times to better
        // accuracy.
        CudnnScratchAllocator scratch_allocator(
            ConvolveBackwardFilterScratchSize, context);
        ProfileResult profile_result;
        bool cudnn_launch_status =
            stream
                ->ThenConvolveBackwardFilterWithAlgorithm(
                    input_desc, input_ptr, output_desc, out_backprop_ptr,
                    conv_desc, filter_desc, &filter_backprop_ptr,
                    &scratch_allocator, AlgorithmConfig(profile_algorithm),
                    &profile_result)
                .ok();
        if (cudnn_launch_status) {
          if (profile_result.is_valid()) {
            if (profile_result.elapsed_time_in_ms() <
                best_result.elapsed_time_in_ms()) {
              best_result = profile_result;
            }
            if (scratch_allocator.TotalByteSize() == 0 &&
                profile_result.elapsed_time_in_ms() <
                    best_result_no_scratch.elapsed_time_in_ms()) {
              best_result_no_scratch = profile_result;
            }
          }
        }
      }
      OP_REQUIRES(context,
                  best_result.is_valid() || best_result_no_scratch.is_valid(),
                  errors::NotFound("No algorithm worked!"));
      if (best_result.is_valid()) {
        algorithm_config.set_algorithm(best_result.algorithm());
      }
      if (best_result_no_scratch.is_valid()) {
        algorithm_config.set_algorithm_no_scratch(
            best_result_no_scratch.algorithm());
      }
      AutoTuneConvBwdFilter::GetInstance()->Insert(conv_parameters,
                                                   algorithm_config);
    }
    CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
                                            context);
    bool cudnn_launch_status =
        stream
            ->ThenConvolveBackwardFilterWithAlgorithm(
                input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc,
                filter_desc, &filter_backprop_ptr, &scratch_allocator,
                algorithm_config, nullptr)
            .ok();

    if (!cudnn_launch_status) {
      context->SetStatus(errors::Internal(
          "cuDNN Backward Filter function launch failure : input shape(",
          input_shape.DebugString(), ") filter shape(",
          filter_shape.DebugString(), ")"));
      return;
    }

    auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
    functor::ReverseTransformFilter<Device, T, 4>()(
        context->eigen_device<Device>(),
        toConstTensor(pre_transformed_filter_backprop).template tensor<T, 4>(),
        filter_backprop->tensor<T, 4>());
    launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop, input,
              stride_rows, stride_cols, padding_, filter_backprop,
              data_format_);
  }

 private:
@@ -853,12 +542,351 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
  Padding padding_;
  bool use_cudnn_;
  TensorFormat data_format_;
  LaunchConv2DBackpropFilterOp<Device, T> launcher_;
  bool cudnn_use_autotune_;
  bool cudnn_disable_conv_1x1_optimization_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv2DSlowBackpropFilterOp);
};

template <typename T>
void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
    OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
    const Tensor& out_backprop, const Tensor& input, int row_stride,
    int col_stride, const Padding& padding, Tensor* filter_backprop,
    TensorFormat data_format) {
  using perftools::gputools::dnn::AlgorithmConfig;
  using perftools::gputools::dnn::AlgorithmType;
  using perftools::gputools::dnn::ProfileResult;

  std::vector<int32> strides(4, 1);
  strides[GetTensorDimIndex(data_format, 'H')] = row_stride;
  strides[GetTensorDimIndex(data_format, 'W')] = col_stride;
  TensorShape filter_shape = filter_backprop->shape();

  ConvBackpropDimensions dims;
  OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensions(
                          "Conv2DSlowBackpropFilter", /*num_spatial_dims=*/2,
                          input.shape(), filter_shape, out_backprop.shape(),
                          strides, padding, data_format, &dims));

  const int padding_rows =
      (padding == VALID)
          ? 0
          : std::max<int>(0, (dims.spatial_dims[0].output_size - 1) *
                                     dims.spatial_dims[0].stride +
                                 dims.spatial_dims[0].filter_size -
                                 dims.spatial_dims[0].input_size);
  const int padding_cols =
      (padding == VALID)
          ? 0
          : std::max<int>(0, (dims.spatial_dims[1].output_size - 1) *
                                     dims.spatial_dims[1].stride +
                                 dims.spatial_dims[1].filter_size -
                                 dims.spatial_dims[1].input_size);

  // TODO(zhengxq): cuDNN only supports equal padding on both sides, so only
  // calling it when that is true. Remove this check when (if?) cuDNN starts
  // supporting different padding.
  bool rows_odd = (padding_rows % 2 != 0);
  bool cols_odd = (padding_cols % 2 != 0);

  auto* stream = ctx->op_device_context()->stream();
  OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

  if (!use_cudnn) {
    ctx->SetStatus(errors::Unimplemented(
        "Conv2DBackprop for GPU is not currently supported "
        "without cudnn"));
    return;
  }

  bool cudnn_disable_conv_1x1_optimization_ = CudnnDisableConv1x1Optimization();
  if (!cudnn_disable_conv_1x1_optimization_ &&
      dims.spatial_dims[0].filter_size == 1 &&
      dims.spatial_dims[1].filter_size == 1 &&
      dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 &&
      data_format == FORMAT_NHWC) {
    const uint64 m = dims.in_depth;
    const uint64 k = dims.batch_size * dims.spatial_dims[0].input_size *
                     dims.spatial_dims[1].input_size;
    const uint64 n = dims.out_depth;

    // The shape of output backprop is
    // [batch, out_rows, out_cols, out_depth]
    // From cublas's perspective, it is: n x k
    auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                out_backprop.template flat<T>().size());

    // The shape of input is
    // [batch, in_rows, in_cols, in_depth],
    // From cublas's perspective, it is: m x k
    auto b_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                input.template flat<T>().size());

    // the shape of the filter backprop from the conv_2d should be
    // [1, 1, in_depth, out_depth]
    // From cublas's perspective, it is: n x m
    auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
                                filter_backprop->template flat<T>().size());

    bool blas_launch_status =
        stream
            ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose,
                           perftools::gputools::blas::Transpose::kTranspose, n,
                           m, k, 1.0f, a_ptr, n, b_ptr, m, 0.0f, &c_ptr, n)
            .ok();
    if (!blas_launch_status) {
      ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                      ", n=", n, ", k=", k));
    }
    return;
  } else if (dims.spatial_dims[0].filter_size ==
                 dims.spatial_dims[0].input_size &&
             dims.spatial_dims[1].filter_size ==
                 dims.spatial_dims[1].input_size &&
             padding == VALID && data_format == FORMAT_NHWC) {
    // The input data and filter have the same height/width, so call cublas
    // directly.
    const uint64 m = dims.spatial_dims[0].input_size *
                     dims.spatial_dims[1].input_size * dims.in_depth;
    const uint64 k = dims.batch_size;
    const uint64 n = dims.out_depth;

    auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                input.template flat<T>().size());
    auto b_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                out_backprop.template flat<T>().size());
    auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
                                filter_backprop->template flat<T>().size());

    bool blas_launch_status =
        stream
            ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose,
                           perftools::gputools::blas::Transpose::kTranspose, n,
                           m, k, 1.0f, b_ptr, n, a_ptr, m, 0.0f, &c_ptr, n)
            .ok();
    if (!blas_launch_status) {
      ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                      ", n=", n, ", k=", k));
    }
    return;
  }

  Tensor compatible_input;
  if (rows_odd || cols_odd) {
    // If a padding dimension is odd, we have one more element on the right
    // side or the bottom side. This is unsupported in cudnn. Therefore,
    // we pad that extra element and make it compatible.
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(
                 DataTypeToEnum<T>::value,
                 ShapeFromFormat(data_format, dims.batch_size,
                                 dims.spatial_dims[0].input_size + rows_odd,
                                 dims.spatial_dims[1].input_size + cols_odd,
                                 dims.in_depth),
                 &compatible_input));

    functor::PadInput<GPUDevice, T, int, 4>()(
        ctx->template eigen_device<GPUDevice>(), To32Bit(input.tensor<T, 4>()),
        {{0, 0}}, {{rows_odd, cols_odd}},
        To32Bit(compatible_input.tensor<T, 4>()), data_format);
  } else {
    compatible_input = input;
  }

  CHECK(padding_rows >= 0 && padding_cols >= 0)
      << "Negative row or col paddings: (" << padding_rows << ", "
      << padding_cols << ")";
  perftools::gputools::dnn::BatchDescriptor input_desc;
  input_desc.set_count(dims.batch_size)
      .set_height(GetTensorDim(compatible_input, data_format, 'H'))
      .set_width(GetTensorDim(compatible_input, data_format, 'W'))
      .set_feature_map_count(dims.in_depth)
      .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
  perftools::gputools::dnn::BatchDescriptor output_desc;
  output_desc.set_count(dims.batch_size)
      .set_height(dims.spatial_dims[0].output_size)
      .set_width(dims.spatial_dims[1].output_size)
      .set_feature_map_count(dims.out_depth)
      .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
  perftools::gputools::dnn::FilterDescriptor filter_desc;
  filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
      .set_input_filter_width(dims.spatial_dims[1].filter_size)
      .set_input_feature_map_count(dims.in_depth)
      .set_output_feature_map_count(dims.out_depth);
  perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
  conv_desc.set_vertical_filter_stride(dims.spatial_dims[0].stride)
      .set_horizontal_filter_stride(dims.spatial_dims[1].stride)
      .set_zero_padding_height(padding_rows / 2)
      .set_zero_padding_width(padding_cols / 2);

  // NOTE(zhengxq):
  // cuDNN only supports the following layouts :
  // Input : B x D x R x C
  // Filter : OD x ID x R x C
  // Whereas, we have
  // Input : B x R x C x D
  // Filter : R x C x ID x OD
  // TransformFilter performs (R x C x ID x OD) => (OD x ID x R x C)
  // The first TransformDepth performs
  // (B x R x C x D) => (B x D x R x C).
  // Since the tensor returned from cuDNN is B x D x R x C also,
  // the second TransformDepth performs
  // (B x D x R x C) => (B x R x C x D).

  Tensor pre_transformed_filter_backprop;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                              TensorShape({dims.out_depth, dims.in_depth,
                                           dims.spatial_dims[0].filter_size,
                                           dims.spatial_dims[1].filter_size}),
                              &pre_transformed_filter_backprop));

  Tensor transformed_out_backprop;
  if (data_format == FORMAT_NHWC) {
    TensorShape nchw_shape = ShapeFromFormat(
        FORMAT_NCHW, dims.batch_size, dims.spatial_dims[0].output_size,
        dims.spatial_dims[1].output_size, dims.out_depth);
    if (dims.out_depth > 1) {
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
                                        &transformed_out_backprop));
      functor::NHWCToNCHW<GPUDevice, T, 4>()(
          ctx->eigen_device<GPUDevice>(), out_backprop.tensor<T, 4>(),
          transformed_out_backprop.tensor<T, 4>());
    } else {
      // If depth <= 1, just reshape.
      CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
    }
  } else {
    transformed_out_backprop = out_backprop;
  }

  Tensor transformed_input;
  if (data_format == FORMAT_NHWC) {
    TensorShape nchw_shape = ShapeFromFormat(
        FORMAT_NCHW, GetTensorDim(compatible_input, data_format, 'N'),
        GetTensorDim(compatible_input, data_format, 'H'),
        GetTensorDim(compatible_input, data_format, 'W'),
        GetTensorDim(compatible_input, data_format, 'C'));
    if (nchw_shape.dim_size(1) > 1) {
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             nchw_shape, &transformed_input));
      functor::NHWCToNCHW<GPUDevice, T, 4>()(
          ctx->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(compatible_input).tensor<T, 4>(),
          transformed_input.tensor<T, 4>());
    } else {
      // If depth <= 1, just reshape.
      CHECK(transformed_input.CopyFrom(compatible_input, nchw_shape));
    }
  } else {
    transformed_input = compatible_input;
  }

  auto out_backprop_ptr =
      AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
                     transformed_out_backprop.template flat<T>().size());
  auto filter_backprop_ptr =
      AsDeviceMemory(pre_transformed_filter_backprop.template flat<T>().data(),
                     pre_transformed_filter_backprop.template flat<T>().size());
  auto input_ptr = AsDeviceMemory(transformed_input.template flat<T>().data(),
                                  transformed_input.template flat<T>().size());

  static int64 ConvolveBackwardFilterScratchSize = GetCudnnWorkspaceLimit(
      "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB by default
  );
  int device_id = stream->parent()->device_ordinal();
  DataType dtype = input.dtype();
  ConvParameters conv_parameters = {
      dims.batch_size,                      // batch
      dims.in_depth,                        // in_depths
      {{input_desc.height(),                // in_rows
        input_desc.width()}},               // in_cols
      dims.out_depth,                       // out_depths
      {{dims.spatial_dims[0].filter_size,   // filter_rows
        dims.spatial_dims[1].filter_size}}, // filter_cols
      {{dims.spatial_dims[0].stride,        // stride_rows
        dims.spatial_dims[1].stride}},      // stride_cols
      {{padding_rows,                       // padding_rows
        padding_cols}},                     // padding_cols
      dtype,                                // tensor datatype
      device_id,                            // device_id
  };
  AlgorithmConfig algorithm_config;
  if (cudnn_use_autotune && !AutoTuneConvBwdFilter::GetInstance()->Find(
                                conv_parameters, &algorithm_config)) {
    std::vector<AlgorithmType> algorithms;
    CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(
        conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
    ProfileResult best_result;
    ProfileResult best_result_no_scratch;
    for (auto profile_algorithm : algorithms) {
      // TODO(zhengxq): profile each algorithm multiple times to better
      // accuracy.
      CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
                                              ctx);
      ProfileResult profile_result;
      bool cudnn_launch_status =
          stream
              ->ThenConvolveBackwardFilterWithAlgorithm(
                  input_desc, input_ptr, output_desc, out_backprop_ptr,
                  conv_desc, filter_desc, &filter_backprop_ptr,
                  &scratch_allocator, AlgorithmConfig(profile_algorithm),
                  &profile_result)
              .ok();
      if (cudnn_launch_status) {
        if (profile_result.is_valid()) {
          if (profile_result.elapsed_time_in_ms() <
              best_result.elapsed_time_in_ms()) {
            best_result = profile_result;
          }
          if (scratch_allocator.TotalByteSize() == 0 &&
              profile_result.elapsed_time_in_ms() <
                  best_result_no_scratch.elapsed_time_in_ms()) {
            best_result_no_scratch = profile_result;
          }
        }
      }
    }
    OP_REQUIRES(ctx,
                best_result.is_valid() || best_result_no_scratch.is_valid(),
                errors::NotFound("No algorithm worked!"));
    if (best_result.is_valid()) {
      algorithm_config.set_algorithm(best_result.algorithm());
    }
    if (best_result_no_scratch.is_valid()) {
      algorithm_config.set_algorithm_no_scratch(
          best_result_no_scratch.algorithm());
    }
    AutoTuneConvBwdFilter::GetInstance()->Insert(conv_parameters,
                                                 algorithm_config);
  }
  CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
                                          ctx);
  bool cudnn_launch_status =
      stream
          ->ThenConvolveBackwardFilterWithAlgorithm(
              input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc,
              filter_desc, &filter_backprop_ptr, &scratch_allocator,
              algorithm_config, nullptr)
          .ok();

  if (!cudnn_launch_status) {
    ctx->SetStatus(errors::Internal(
        "cuDNN Backward Filter function launch failure : input shape(",
        input.shape().DebugString(), ") filter shape(",
        filter_shape.DebugString(), ")"));
    return;
  }

  auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
  functor::ReverseTransformFilter<GPUDevice, T, 4>()(
      ctx->eigen_device<GPUDevice>(),
      toConstTensor(pre_transformed_filter_backprop).template tensor<T, 4>(),
      filter_backprop->tensor<T, 4>());
}

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T) \
@@ -97,29 +97,17 @@ typedef Eigen::GpuDevice GPUDevice;
// for CPU for now since nvcc times out when trying to compile them.
// TODO(yangke): enable them for GPUs when we have a faster compiler.

template <typename Device, class T>
struct LaunchBackwardInputConvolution {
  bool operator()(OpKernelContext* context, const Device&,
                  typename TTypes<T, 4>::Tensor,
                  typename TTypes<T, 4>::ConstTensor,
                  typename TTypes<T, 4>::ConstTensor, int, int, int, int,
                  TensorFormat) const {
    return false;
  }
};

template <>
struct LaunchBackwardInputConvolution<CPUDevice, float> {
  bool operator()(OpKernelContext* context, const CPUDevice& d,
                  typename TTypes<float, 4>::Tensor input_backward,
                  typename TTypes<float, 4>::ConstTensor kernel,
                  typename TTypes<float, 4>::ConstTensor output_backward,
                  int input_rows, int input_cols, int row_stride,
                  int col_stride, TensorFormat data_format) const {
    functor::SpatialConvolutionBackwardInput<CPUDevice, float>()(
        d, input_backward, kernel, output_backward, input_rows, input_cols,
        row_stride, col_stride);
    return true;
template <typename T>
struct LaunchConv2DBackpropInputOp<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
                  const Tensor& out_backprop, const Tensor& filter,
                  int row_stride, int col_stride, const Padding& padding,
                  Tensor* in_backprop, TensorFormat data_format) {
    const CPUDevice& d = ctx->eigen_device<CPUDevice>();
    functor::SpatialConvolutionBackwardInput<CPUDevice, T>()(
        d, in_backprop->tensor<T, 4>(), filter.tensor<T, 4>(),
        out_backprop.tensor<T, 4>(), in_backprop->dim_size(1),
        in_backprop->dim_size(2), row_stride, col_stride);
  }
};
@@ -268,11 +256,10 @@ class Conv2DFastBackpropInputOp : public OpKernel {
    }
#endif

    LaunchBackwardInputConvolution<Device, T>()(
        context, context->eigen_device<Device>(), in_backprop->tensor<T, 4>(),
        filter.tensor<T, 4>(), out_backprop.tensor<T, 4>(),
        dims.spatial_dims[0].input_size, dims.spatial_dims[1].input_size,
        dims.spatial_dims[0].stride, dims.spatial_dims[1].stride, data_format_);
    LaunchConv2DBackpropInputOp<Device, T>()(
        context, false, false, out_backprop, filter,
        dims.spatial_dims[0].stride, dims.spatial_dims[1].stride, padding_,
        in_backprop, data_format_);
  }

 private:
@@ -600,10 +587,6 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
  }

  void Compute(OpKernelContext* context) override {
    using perftools::gputools::dnn::AlgorithmConfig;
    using perftools::gputools::dnn::AlgorithmType;
    using perftools::gputools::dnn::ProfileResult;
    using perftools::gputools::dnn::kDefaultAlgorithm;
    const Tensor& input_sizes = context->input(0);
    const Tensor& filter = context->input(1);
    const Tensor& out_backprop = context->input(2);
@@ -615,340 +598,18 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
    TensorShape input_shape;
    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                input_sizes.vec<int32>(), &input_shape));
    const TensorShape& filter_shape = filter.shape();

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(
        context, ConvBackpropComputeDimensions(
                     "Conv2DSlowBackpropInput", /*num_spatial_dims=*/2,
                     input_shape, filter_shape, out_backprop.shape(), strides_,
                     padding_, data_format_, &dims));

    Tensor* in_backprop = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_shape, &in_backprop));

    const int padding_rows =
        (padding_ == VALID)
            ? 0
            : std::max<int>(0, (dims.spatial_dims[0].output_size - 1) *
                                       dims.spatial_dims[0].stride +
                                   dims.spatial_dims[0].filter_size -
                                   dims.spatial_dims[0].input_size);
    const int padding_cols =
        (padding_ == VALID)
            ? 0
            : std::max<int>(0, (dims.spatial_dims[1].output_size - 1) *
                                       dims.spatial_dims[1].stride +
                                   dims.spatial_dims[1].filter_size -
                                   dims.spatial_dims[1].input_size);
    // For now we take the stride from the second and third dimensions only (we
    // do not support striding on the batch or depth dimension).
    const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
    const int stride_cols = GetTensorDim(strides_, data_format_, 'W');

    // TODO(keveman): cuDNN only supports equal padding on both sides, so only
    // calling it when that is true. Remove this check when (if?) cuDNN starts
    // supporting different padding.
    bool rows_odd = (padding_rows % 2 != 0);
    bool cols_odd = (padding_cols % 2 != 0);

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    if (!use_cudnn_) {
      context->SetStatus(errors::Unimplemented(
          "Conv2DBackpropInput for GPU is not currently supported "
          "without cudnn"));
      return;
    }

    if (dims.spatial_dims[0].filter_size == 1 &&
        dims.spatial_dims[1].filter_size == 1 &&
        dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 &&
        data_format_ == FORMAT_NHWC) {
      // 1x1 filter, so call cublas directly.
      const uint64 m = dims.batch_size * dims.spatial_dims[0].input_size *
                       dims.spatial_dims[1].input_size;
      const uint64 k = dims.out_depth;
      const uint64 n = dims.in_depth;

      auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                  out_backprop.template flat<T>().size());
      auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                  filter.template flat<T>().size());
      auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
                                  in_backprop->template flat<T>().size());

      auto transpose = perftools::gputools::blas::Transpose::kTranspose;
      auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;

      bool blas_launch_status =
          stream
              ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
                             a_ptr, k, 0.0f, &c_ptr, n)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                            ", n=", n, ", k=", k));
      }
      return;
    } else if (dims.spatial_dims[0].filter_size ==
                   dims.spatial_dims[0].input_size &&
               dims.spatial_dims[1].filter_size ==
                   dims.spatial_dims[1].input_size &&
               padding_ == VALID && data_format_ == FORMAT_NHWC) {
      // The input data and filter have the same height/width, so call cublas
      // directly.
      const uint64 m = dims.batch_size;
      const uint64 k = dims.out_depth;
      const uint64 n = dims.spatial_dims[0].input_size *
                       dims.spatial_dims[1].input_size * dims.in_depth;

      auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                  out_backprop.template flat<T>().size());
      auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                  filter.template flat<T>().size());
      auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
                                  in_backprop->template flat<T>().size());

      auto transpose = perftools::gputools::blas::Transpose::kTranspose;
      auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;

      bool blas_launch_status =
          stream
              ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
                             a_ptr, k, 0.0f, &c_ptr, n)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                            ", n=", n, ", k=", k));
      }
      return;
    }

    TensorShape compatible_input_shape;
    if (rows_odd || cols_odd) {
      // If a padding dimension is odd, we have one more element on the right
      // side or the bottom side. This is unsupported in cudnn. Therefore,
      // we pad that extra element and make it compatible.
      compatible_input_shape = ShapeFromFormat(
          data_format_, dims.batch_size,
          dims.spatial_dims[0].input_size + rows_odd,
          dims.spatial_dims[1].input_size + cols_odd, dims.in_depth);
    } else {
      compatible_input_shape = input_shape;
    }

    CHECK(padding_rows >= 0 && padding_cols >= 0)
        << "Negative row or col paddings: (" << padding_rows << ", "
        << padding_cols << ")";
    perftools::gputools::dnn::BatchDescriptor input_desc;
    input_desc.set_count(dims.batch_size)
        .set_height(GetTensorDim(compatible_input_shape, data_format_, 'H'))
        .set_width(GetTensorDim(compatible_input_shape, data_format_, 'W'))
        .set_feature_map_count(dims.in_depth)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
    perftools::gputools::dnn::BatchDescriptor output_desc;
    output_desc.set_count(dims.batch_size)
        .set_height(dims.spatial_dims[0].output_size)
        .set_width(dims.spatial_dims[1].output_size)
        .set_feature_map_count(dims.out_depth)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
    perftools::gputools::dnn::FilterDescriptor filter_desc;
    filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
        .set_input_filter_width(dims.spatial_dims[1].filter_size)
        .set_input_feature_map_count(dims.in_depth)
        .set_output_feature_map_count(dims.out_depth);
    perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
    conv_desc.set_vertical_filter_stride(dims.spatial_dims[0].stride)
        .set_horizontal_filter_stride(dims.spatial_dims[1].stride)
        .set_zero_padding_height(padding_rows / 2)
        .set_zero_padding_width(padding_cols / 2);

    // NOTE(keveman):
    // cuDNN only supports the following layouts :
    // Input : B x D x R x C
    // Filter : OD x ID x R x C
    // Whereas, we have
    // Input : B x R x C x D
    // Filter : R x C x ID x OD
    // TransformFilter performs (R x C x ID x OD) => (OD x ID x R x C)
    // The first TransformDepth performs
    // (B x R x C x D) => (B x D x R x C).
    // Since the tensor returned from cuDNN is B x D x R x C also,
    // the second TransformDepth performs
    // (B x D x R x C) => (B x R x C x D).
    Tensor transformed_filter;
    OP_REQUIRES_OK(context, context->allocate_temp(
                                DataTypeToEnum<T>::value,
                                TensorShape({dims.out_depth, dims.in_depth,
                                             dims.spatial_dims[0].filter_size,
                                             dims.spatial_dims[1].filter_size}),
                                &transformed_filter));

    functor::TransformFilter<Device, T, int, 4>()(
        context->eigen_device<Device>(), To32Bit(filter.tensor<T, 4>()),
        To32Bit(transformed_filter.tensor<T, 4>()));

    Tensor transformed_out_backprop;
    if (data_format_ == FORMAT_NHWC) {
      TensorShape nchw_shape = ShapeFromFormat(
          FORMAT_NCHW, dims.batch_size, dims.spatial_dims[0].output_size,
          dims.spatial_dims[1].output_size, dims.out_depth);
      if (dims.out_depth > 1) {
        OP_REQUIRES_OK(context, context->allocate_temp(
                                    DataTypeToEnum<T>::value, nchw_shape,
                                    &transformed_out_backprop));
        functor::NHWCToNCHW<Device, T, 4>()(
            context->eigen_device<Device>(), out_backprop.tensor<T, 4>(),
            transformed_out_backprop.tensor<T, 4>());
      } else {
        // If depth <= 1, then just reshape.
        CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
      }
    } else {
      transformed_out_backprop = out_backprop;
    }

    Tensor pre_transformed_in_backprop;
    OP_REQUIRES_OK(
        context,
        context->allocate_temp(
            DataTypeToEnum<T>::value,
            ShapeFromFormat(
                FORMAT_NCHW,
                GetTensorDim(compatible_input_shape, data_format_, 'N'),
                GetTensorDim(compatible_input_shape, data_format_, 'H'),
                GetTensorDim(compatible_input_shape, data_format_, 'W'),
                GetTensorDim(compatible_input_shape, data_format_, 'C')),
            &pre_transformed_in_backprop));

    auto out_backprop_ptr =
        AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
                       transformed_out_backprop.template flat<T>().size());
    auto filter_ptr =
        AsDeviceMemory(transformed_filter.template flat<T>().data(),
                       transformed_filter.template flat<T>().size());
    auto in_backprop_ptr =
        AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
                       pre_transformed_in_backprop.template flat<T>().size());

    static int64 ConvolveBackwardDataScratchSize = GetCudnnWorkspaceLimit(
        "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB by default
    );
    CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
                                            context);
    int device_id = stream->parent()->device_ordinal();
    DataType dtype = out_backprop.dtype();
    ConvParameters conv_parameters = {
        dims.batch_size,                      // batch
        dims.in_depth,                        // in_depths
        {{input_desc.height(),                // in_rows
          input_desc.width()}},               // in_cols
        dims.out_depth,                       // out_depths
        {{dims.spatial_dims[0].filter_size,   // filter_rows
          dims.spatial_dims[1].filter_size}}, // filter_cols
        {{dims.spatial_dims[0].stride,        // stride_rows
          dims.spatial_dims[1].stride}},      // stride_cols
        {{padding_rows,                       // padding_rows
          padding_cols}},                     // padding_cols
        dtype,                                // tensor data type
        device_id,                            // device_id
    };
    AlgorithmConfig algorithm_config;
    if (cudnn_use_autotune_ && !AutoTuneConvBwdData::GetInstance()->Find(
                                   conv_parameters, &algorithm_config)) {
      std::vector<AlgorithmType> algorithms;
      CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
          conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
      ProfileResult best_result;
      ProfileResult best_result_no_scratch;
      for (auto profile_algorithm : algorithms) {
        // TODO(zhengxq): profile each algorithm multiple times to better
        // accuracy.
        CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
                                                context);
        ProfileResult profile_result;
        bool cudnn_launch_status =
            stream
                ->ThenConvolveBackwardDataWithAlgorithm(
                    filter_desc, filter_ptr, output_desc, out_backprop_ptr,
                    conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
                    AlgorithmConfig(profile_algorithm), &profile_result)
                .ok();
        if (cudnn_launch_status) {
          if (profile_result.is_valid()) {
            if (profile_result.elapsed_time_in_ms() <
                best_result.elapsed_time_in_ms()) {
              best_result = profile_result;
            }
            if (scratch_allocator.TotalByteSize() == 0 &&
                profile_result.elapsed_time_in_ms() <
                    best_result_no_scratch.elapsed_time_in_ms()) {
              best_result_no_scratch = profile_result;
            }
          }
        }
      }
      OP_REQUIRES(context,
                  best_result.is_valid() || best_result_no_scratch.is_valid(),
                  errors::NotFound("No algorithm worked!"));
      if (best_result.is_valid()) {
        algorithm_config.set_algorithm(best_result.algorithm());
      }
      if (best_result_no_scratch.is_valid()) {
        algorithm_config.set_algorithm_no_scratch(
            best_result_no_scratch.algorithm());
      }
      AutoTuneConvBwdData::GetInstance()->Insert(conv_parameters,
                                                 algorithm_config);
    }
    bool cudnn_launch_status =
        stream
            ->ThenConvolveBackwardDataWithAlgorithm(
                filter_desc, filter_ptr, output_desc, out_backprop_ptr,
                conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
                algorithm_config, nullptr)
            .ok();

    if (!cudnn_launch_status) {
      context->SetStatus(errors::Internal(
          "cuDNN Backward Data function launch failure : input shape(",
          input_shape.DebugString(), ") filter shape(",
          filter_shape.DebugString(), ")"));
      return;
    }

    if (rows_odd || cols_odd) {
      Tensor in_backprop_remove_padding;
      OP_REQUIRES_OK(
          context,
          context->allocate_temp(
              DataTypeToEnum<T>::value,
              ShapeFromFormat(FORMAT_NCHW,
                              GetTensorDim(input_shape, data_format_, 'N'),
                              GetTensorDim(input_shape, data_format_, 'H'),
                              GetTensorDim(input_shape, data_format_, 'W'),
                              GetTensorDim(input_shape, data_format_, 'C')),
              &in_backprop_remove_padding));

      // Remove the padding for odd rows or cols.
      functor::PadInput<GPUDevice, T, int, 4>()(
          context->template eigen_device<GPUDevice>(),
          To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop)
                      .tensor<T, 4>()),
          {{0, 0}}, {{-rows_odd, -cols_odd}},
          To32Bit(in_backprop_remove_padding.tensor<T, 4>()), FORMAT_NCHW);

      pre_transformed_in_backprop = in_backprop_remove_padding;
    }

    if (data_format_ == FORMAT_NHWC) {
      auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
      functor::NCHWToNHWC<Device, T, 4>()(
          context->eigen_device<Device>(),
          toConstTensor(pre_transformed_in_backprop).template tensor<T, 4>(),
          in_backprop->tensor<T, 4>());
    } else {
      *in_backprop = pre_transformed_in_backprop;
    }
    launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop, filter,
              stride_rows, stride_cols, padding_, in_backprop, data_format_);
  }

 private:
@ -956,11 +617,354 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
|
||||
Padding padding_;
|
||||
bool use_cudnn_;
|
||||
TensorFormat data_format_;
|
||||
LaunchConv2DBackpropInputOp<Device, T> launcher_;
|
||||
bool cudnn_use_autotune_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(Conv2DSlowBackpropInputOp);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
|
||||
OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
|
||||
const Tensor& out_backprop, const Tensor& filter, int row_stride,
|
||||
int col_stride, const Padding& padding, Tensor* in_backprop,
|
||||
TensorFormat data_format) {
|
||||
using perftools::gputools::dnn::AlgorithmConfig;
|
||||
using perftools::gputools::dnn::AlgorithmType;
|
||||
using perftools::gputools::dnn::ProfileResult;
|
||||
|
||||
std::vector<int32> strides(4, 1);
|
||||
strides[GetTensorDimIndex(data_format, 'H')] = row_stride;
|
||||
strides[GetTensorDimIndex(data_format, 'W')] = col_stride;
|
||||
TensorShape input_shape = in_backprop->shape();
|
||||
|
||||
const TensorShape& filter_shape = filter.shape();
|
||||
ConvBackpropDimensions dims;
|
||||
OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensions(
|
||||
"Conv2DSlowBackpropInput", /*num_spatial_dims=*/2,
|
||||
input_shape, filter_shape, out_backprop.shape(),
|
||||
strides, padding, data_format, &dims));
|
||||
|
||||
const int padding_rows =
|
||||
(padding == VALID)
|
||||
? 0
|
||||
: std::max<int>(0, (dims.spatial_dims[0].output_size - 1) *
|
||||
dims.spatial_dims[0].stride +
|
||||
dims.spatial_dims[0].filter_size -
|
||||
dims.spatial_dims[0].input_size);
|
||||
const int padding_cols =
|
||||
(padding == VALID)
|
||||
? 0
|
||||
: std::max<int>(0, (dims.spatial_dims[1].output_size - 1) *
|
||||
dims.spatial_dims[1].stride +
|
||||
dims.spatial_dims[1].filter_size -
|
||||
dims.spatial_dims[1].input_size);
|
||||
|
||||
// TODO(keveman): cuDNN only supports equal padding on both sides, so only
|
||||
// calling it when that is true. Remove this check when (if?) cuDNN starts
|
||||
// supporting different padding.
|
||||
bool rows_odd = (padding_rows % 2 != 0);
|
||||
bool cols_odd = (padding_cols % 2 != 0);
|
||||
|
||||
auto* stream = ctx->op_device_context()->stream();
|
||||
OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
|
||||
|
||||
if (!use_cudnn) {
|
||||
ctx->SetStatus(errors::Unimplemented(
|
||||
"Conv2DBackpropInput for GPU is not currently supported "
|
||||
"without cudnn"));
|
||||
return;
|
||||
}
|
||||
|
||||
if (dims.spatial_dims[0].filter_size == 1 &&
|
||||
dims.spatial_dims[1].filter_size == 1 &&
|
||||
dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 &&
|
||||
data_format == FORMAT_NHWC) {
|
||||
// 1x1 filter, so call cublas directly.
|
||||
const uint64 m = dims.batch_size * dims.spatial_dims[0].input_size *
|
||||
dims.spatial_dims[1].input_size;
|
||||
const uint64 k = dims.out_depth;
|
||||
const uint64 n = dims.in_depth;
|
||||
|
||||
auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
|
||||
out_backprop.template flat<T>().size());
|
||||
auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
|
||||
filter.template flat<T>().size());
|
||||
auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
|
||||
in_backprop->template flat<T>().size());
|
||||
|
||||
auto transpose = perftools::gputools::blas::Transpose::kTranspose;
|
||||
auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
|
||||
|
||||
bool blas_launch_status =
|
||||
stream
|
||||
->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
|
||||
a_ptr, k, 0.0f, &c_ptr, n)
|
||||
.ok();
|
||||
if (!blas_launch_status) {
|
||||
ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
|
||||
", n=", n, ", k=", k));
|
||||
}
|
||||
return;
|
||||
} else if (dims.spatial_dims[0].filter_size ==
|
||||
dims.spatial_dims[0].input_size &&
|
||||
dims.spatial_dims[1].filter_size ==
|
||||
dims.spatial_dims[1].input_size &&
|
||||
padding == VALID && data_format == FORMAT_NHWC) {
|
||||
// The input data and filter have the same height/width, so call cublas
|
||||
// directly.
|
||||
const uint64 m = dims.batch_size;
|
||||
const uint64 k = dims.out_depth;
|
||||
const uint64 n = dims.spatial_dims[0].input_size *
|
||||
dims.spatial_dims[1].input_size * dims.in_depth;
|
||||
|
||||
auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
|
||||
out_backprop.template flat<T>().size());
|
||||
auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
|
||||
filter.template flat<T>().size());
|
||||
auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
|
||||
in_backprop->template flat<T>().size());
|
||||
|
||||
auto transpose = perftools::gputools::blas::Transpose::kTranspose;
|
||||
auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
|
||||
|
||||
bool blas_launch_status =
|
||||
stream
|
||||
->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
|
||||
a_ptr, k, 0.0f, &c_ptr, n)
|
||||
.ok();
|
||||
if (!blas_launch_status) {
|
||||
ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
|
||||
", n=", n, ", k=", k));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
TensorShape compatible_input_shape;
|
||||
if (rows_odd || cols_odd) {
|
||||
// If a padding dimension is odd, we have one more element on the right
|
||||
// side or the bottom side. This is unsupported in cudnn. Therefore,
|
||||
// we pad that extra element and make it compatible.
compatible_input_shape = ShapeFromFormat(
data_format, dims.batch_size,
dims.spatial_dims[0].input_size + rows_odd,
dims.spatial_dims[1].input_size + cols_odd, dims.in_depth);
} else {
compatible_input_shape = input_shape;
}

CHECK(padding_rows >= 0 && padding_cols >= 0)
<< "Negative row or col paddings: (" << padding_rows << ", "
<< padding_cols << ")";
perftools::gputools::dnn::BatchDescriptor input_desc;
input_desc.set_count(dims.batch_size)
.set_height(GetTensorDim(compatible_input_shape, data_format, 'H'))
.set_width(GetTensorDim(compatible_input_shape, data_format, 'W'))
.set_feature_map_count(dims.in_depth)
.set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
perftools::gputools::dnn::BatchDescriptor output_desc;
output_desc.set_count(dims.batch_size)
.set_height(dims.spatial_dims[0].output_size)
.set_width(dims.spatial_dims[1].output_size)
.set_feature_map_count(dims.out_depth)
.set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
perftools::gputools::dnn::FilterDescriptor filter_desc;
filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
.set_input_filter_width(dims.spatial_dims[1].filter_size)
.set_input_feature_map_count(dims.in_depth)
.set_output_feature_map_count(dims.out_depth);
perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
conv_desc.set_vertical_filter_stride(dims.spatial_dims[0].stride)
.set_horizontal_filter_stride(dims.spatial_dims[1].stride)
.set_zero_padding_height(padding_rows / 2)
.set_zero_padding_width(padding_cols / 2);

// NOTE(keveman):
// cuDNN only supports the following layouts :
// Input : B x D x R x C
// Filter : OD x ID x R x C
// Whereas, we have
// Input : B x R x C x D
// Filter : R x C x ID x OD
// TransformFilter performs (R x C x ID x OD) => (OD x ID x R x C)
// The first TransformDepth performs
// (B x R x C x D) => (B x D x R x C).
// Since the tensor returned from cuDNN is B x D x R x C also,
// the second TransformDepth performs
// (B x D x R x C) => (B x R x C x D).
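// As a purely illustrative example: an NHWC out_backprop of shape
// [32, 14, 14, 64] and an HWIO filter of shape [3, 3, 16, 64] become a
// [32, 64, 14, 14] NCHW tensor and a [64, 16, 3, 3] OIHW filter before the
// cuDNN call, and the resulting in_backprop is transformed back to NHWC at
// the end.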
Tensor transformed_filter;
OP_REQUIRES_OK(
ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
TensorShape({dims.out_depth, dims.in_depth,
dims.spatial_dims[0].filter_size,
dims.spatial_dims[1].filter_size}),
&transformed_filter));

functor::TransformFilter<GPUDevice, T, int, 4>()(
ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
To32Bit(transformed_filter.tensor<T, 4>()));

Tensor transformed_out_backprop;
if (data_format == FORMAT_NHWC) {
TensorShape nchw_shape = ShapeFromFormat(
FORMAT_NCHW, dims.batch_size, dims.spatial_dims[0].output_size,
dims.spatial_dims[1].output_size, dims.out_depth);
if (dims.out_depth > 1) {
OP_REQUIRES_OK(ctx,
ctx->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
&transformed_out_backprop));
functor::NHWCToNCHW<GPUDevice, T, 4>()(
ctx->eigen_device<GPUDevice>(), out_backprop.tensor<T, 4>(),
transformed_out_backprop.tensor<T, 4>());
} else {
// If depth <= 1, then just reshape.
CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
}
} else {
transformed_out_backprop = out_backprop;
}

Tensor pre_transformed_in_backprop;
OP_REQUIRES_OK(
ctx, ctx->allocate_temp(
DataTypeToEnum<T>::value,
ShapeFromFormat(
FORMAT_NCHW,
GetTensorDim(compatible_input_shape, data_format, 'N'),
GetTensorDim(compatible_input_shape, data_format, 'H'),
GetTensorDim(compatible_input_shape, data_format, 'W'),
GetTensorDim(compatible_input_shape, data_format, 'C')),
&pre_transformed_in_backprop));

auto out_backprop_ptr =
AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
transformed_out_backprop.template flat<T>().size());
auto filter_ptr =
AsDeviceMemory(transformed_filter.template flat<T>().data(),
transformed_filter.template flat<T>().size());
auto in_backprop_ptr =
AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
pre_transformed_in_backprop.template flat<T>().size());

static int64 ConvolveBackwardDataScratchSize = GetCudnnWorkspaceLimit(
"TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB by default
);
CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize, ctx);
int device_id = stream->parent()->device_ordinal();
DataType dtype = out_backprop.dtype();
ConvParameters conv_parameters = {
dims.batch_size, // batch
dims.in_depth, // in_depths
{{input_desc.height(), // in_rows
input_desc.width()}}, // in_cols
dims.out_depth, // out_depths
{{dims.spatial_dims[0].filter_size, // filter_rows
dims.spatial_dims[1].filter_size}}, // filter_cols
{{dims.spatial_dims[0].stride, // stride_rows
dims.spatial_dims[1].stride}}, // stride_cols
{{padding_rows, // padding_rows
padding_cols}}, // padding_cols
dtype, // tensor data type
device_id, // device_id
};
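// These parameters form the key under which the chosen algorithm is cached
// in AutoTuneConvBwdData below, so the same configuration does not have to
// be re-profiled on subsequent calls.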
AlgorithmConfig algorithm_config;
if (cudnn_use_autotune && !AutoTuneConvBwdData::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
std::vector<AlgorithmType> algorithms;
CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
ProfileResult best_result;
ProfileResult best_result_no_scratch;
for (auto profile_algorithm : algorithms) {
// TODO(zhengxq): profile each algorithm multiple times for better accuracy.
CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
ctx);
ProfileResult profile_result;
bool cudnn_launch_status =
stream
->ThenConvolveBackwardDataWithAlgorithm(
filter_desc, filter_ptr, output_desc, out_backprop_ptr,
conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
AlgorithmConfig(profile_algorithm), &profile_result)
.ok();
if (cudnn_launch_status) {
if (profile_result.is_valid()) {
if (profile_result.elapsed_time_in_ms() <
best_result.elapsed_time_in_ms()) {
best_result = profile_result;
}
if (scratch_allocator.TotalByteSize() == 0 &&
profile_result.elapsed_time_in_ms() <
best_result_no_scratch.elapsed_time_in_ms()) {
best_result_no_scratch = profile_result;
}
}
}
}
OP_REQUIRES(ctx,
best_result.is_valid() || best_result_no_scratch.is_valid(),
errors::NotFound("No algorithm worked!"));
if (best_result.is_valid()) {
algorithm_config.set_algorithm(best_result.algorithm());
}
if (best_result_no_scratch.is_valid()) {
algorithm_config.set_algorithm_no_scratch(
best_result_no_scratch.algorithm());
}
AutoTuneConvBwdData::GetInstance()->Insert(conv_parameters,
algorithm_config);
}
bool cudnn_launch_status =
stream
->ThenConvolveBackwardDataWithAlgorithm(
filter_desc, filter_ptr, output_desc, out_backprop_ptr, conv_desc,
input_desc, &in_backprop_ptr, &scratch_allocator,
algorithm_config, nullptr)
.ok();

if (!cudnn_launch_status) {
ctx->SetStatus(errors::Internal(
"cuDNN Backward Data function launch failure : input shape(",
input_shape.DebugString(), ") filter shape(",
filter_shape.DebugString(), ")"));
return;
}

if (rows_odd || cols_odd) {
Tensor in_backprop_remove_padding;
OP_REQUIRES_OK(
ctx, ctx->allocate_temp(
DataTypeToEnum<T>::value,
ShapeFromFormat(FORMAT_NCHW,
GetTensorDim(input_shape, data_format, 'N'),
GetTensorDim(input_shape, data_format, 'H'),
GetTensorDim(input_shape, data_format, 'W'),
GetTensorDim(input_shape, data_format, 'C')),
&in_backprop_remove_padding));

// Remove the padding for odd rows or cols.
functor::PadInput<GPUDevice, T, int, 4>()(
ctx->template eigen_device<GPUDevice>(),
To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop)
.tensor<T, 4>()),
{{0, 0}}, {{-rows_odd, -cols_odd}},
To32Bit(in_backprop_remove_padding.tensor<T, 4>()), FORMAT_NCHW);

pre_transformed_in_backprop = in_backprop_remove_padding;
}

if (data_format == FORMAT_NHWC) {
auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
functor::NCHWToNHWC<GPUDevice, T, 4>()(
ctx->eigen_device<GPUDevice>(),
toConstTensor(pre_transformed_in_backprop).template tensor<T, 4>(),
in_backprop->tensor<T, 4>());
} else {
*in_backprop = pre_transformed_in_backprop;
}
}

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T) \
@ -168,6 +168,43 @@ limitations under the License.

namespace tensorflow {

// Forward declaration.
class OpKernelContext;

template <typename Device, typename T>
struct LaunchConv2DBackpropInputOp {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
const Tensor& out_backprop, const Tensor& filter,
int row_stride, int col_stride, const Padding& padding,
Tensor* in_backprop, TensorFormat data_format);
};

template <typename Device, typename T>
struct LaunchConv2DBackpropFilterOp {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
const Tensor& out_backprop, const Tensor& input,
int row_stride, int col_stride, const Padding& padding,
Tensor* filter_backprop, TensorFormat data_format);
};

#ifdef GOOGLE_CUDA
template <typename T>
struct LaunchConv2DBackpropInputOp<Eigen::GpuDevice, T> {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
const Tensor& input, const Tensor& filter, int row_stride,
int col_stride, const Padding& padding, Tensor* output,
TensorFormat data_format);
};

template <typename T>
struct LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T> {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
const Tensor& out_backprop, const Tensor& input,
int row_stride, int col_stride, const Padding& padding,
Tensor* filter_backprop, TensorFormat data_format);
};
#endif // GOOGLE_CUDA

// Information about a single spatial dimension for a convolution
// backpropagation.
struct ConvBackpropSpatialDimension {
@ -58,10 +58,10 @@ typedef Eigen::GpuDevice GPUDevice;
namespace {
template <typename Device, typename T>
struct LaunchGeneric {
static void launch(OpKernelContext* ctx, const Tensor& input,
const Tensor& filter, int row_stride, int col_stride,
const Eigen::PaddingType& padding, Tensor* output,
TensorFormat data_format) {
void operator()(OpKernelContext* ctx, const Tensor& input,
const Tensor& filter, int row_stride, int col_stride,
const Padding& padding, Tensor* output,
TensorFormat data_format) {
CHECK(data_format == FORMAT_NHWC) << "Generic conv implementation only "
"supports NHWC tensor format for now.";
if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride == 1 &&
@ -86,8 +86,7 @@ struct LaunchGeneric {
filter.shaped<T, 2>({filter.dim_size(2), filter.dim_size(3)}),
dim_pair);
} else if (filter.dim_size(0) == input.dim_size(1) &&
filter.dim_size(1) == input.dim_size(2) &&
padding == Eigen::PADDING_VALID) {
filter.dim_size(1) == input.dim_size(2) && padding == VALID) {
// If the input data and filter have the same height/width,
// the 2D convolution is reduced to matrix multiplication.
const int k = // Length of reduction dimension.
@ -104,28 +103,26 @@ struct LaunchGeneric {
functor::SpatialConvolution<Device, T>()(
ctx->eigen_device<Device>(), output->tensor<T, 4>(),
input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride, col_stride,
padding);
BrainPadding2EigenPadding(padding));
}
}
};
} // namespace

template <typename T>
class LaunchConv2DOp<CPUDevice, T> {
public:
void launch(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
const Tensor& input, const Tensor& filter, int row_stride,
int col_stride, const Eigen::PaddingType& padding, Tensor* output,
TensorFormat data_format) {
struct LaunchConv2DOp<CPUDevice, T> {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
const Tensor& input, const Tensor& filter, int row_stride,
int col_stride, const Padding& padding, Tensor* output,
TensorFormat data_format) {
if (data_format != FORMAT_NHWC) {
ctx->SetStatus(
errors::Unimplemented("Generic conv implementation only supports "
"NHWC tensor format for now."));
return;
}
LaunchGeneric<CPUDevice, T>::launch(ctx, input, filter, row_stride,
col_stride, padding, output,
data_format);
LaunchGeneric<CPUDevice, T>()(ctx, input, filter, row_stride, col_stride,
padding, output, data_format);
}
};

@ -387,9 +384,8 @@ class Conv2DOp : public BinaryOp<T> {
return;
}

launcher_.launch(context, use_cudnn_, cudnn_use_autotune_, input, filter,
stride_rows, stride_cols,
BrainPadding2EigenPadding(padding_), output, data_format_);
launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter,
stride_rows, stride_cols, padding_, output, data_format_);
}

private:
@ -445,10 +441,10 @@ typedef AutoTuneSingleton<ConvAutoTuneGroup, ConvParameters,
AutoTuneConv;

template <typename T>
void LaunchConv2DOp<GPUDevice, T>::launch(
void LaunchConv2DOp<GPUDevice, T>::operator()(
OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
const Tensor& input_param, const Tensor& filter, int row_stride,
int col_stride, const Eigen::PaddingType& padding, Tensor* output,
int col_stride, const Padding& padding, Tensor* output,
TensorFormat data_format) {
using perftools::gputools::dnn::AlgorithmConfig;
using perftools::gputools::dnn::AlgorithmType;
@ -492,8 +488,8 @@ void LaunchConv2DOp<GPUDevice, T>::launch(
}
return;
} else if (filter.dim_size(0) == input.dim_size(1) &&
filter.dim_size(1) == input.dim_size(2) &&
padding == Eigen::PADDING_VALID && data_format == FORMAT_NHWC) {
filter.dim_size(1) == input.dim_size(2) && padding == VALID &&
data_format == FORMAT_NHWC) {
// The input data and filter have the same height/width, so call cublas
// directly.
const uint64 m = input.dim_size(0);
@ -32,14 +32,23 @@ namespace tensorflow {
class OpKernelContext;

template <typename Device, typename T>
class LaunchConv2DOp {
public:
void launch(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
const Tensor& input, const Tensor& filter, int row_stride,
int col_stride, const Eigen::PaddingType& padding, Tensor* output,
TensorFormat data_format);
struct LaunchConv2DOp {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
const Tensor& input, const Tensor& filter, int row_stride,
int col_stride, const Padding& padding, Tensor* output,
TensorFormat data_format);
};

#ifdef GOOGLE_CUDA
template <typename T>
struct LaunchConv2DOp<Eigen::GpuDevice, T> {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
const Tensor& input, const Tensor& filter, int row_stride,
int col_stride, const Padding& padding, Tensor* output,
TensorFormat data_format);
};
#endif // GOOGLE_CUDA

// Used to keep track of persistent memory buffers used within the op.
// It uses malloc and free to avoid the time cost of initializing the memory.
template <class T, size_t size>
@ -55,17 +64,6 @@ struct Im2ColBufferResource : public ResourceBase {
string DebugString() { return "Im2ColBufferResource"; }
};

#ifdef GOOGLE_CUDA
template <typename T>
class LaunchConv2DOp<Eigen::GpuDevice, T> {
public:
void launch(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
const Tensor& input, const Tensor& filter, int row_stride,
int col_stride, const Eigen::PaddingType& padding, Tensor* output,
TensorFormat data_format);
};
#endif // GOOGLE_CUDA

} // namespace tensorflow

#endif // TENSORFLOW_KERNELS_CONV_OPS_H
@ -361,19 +361,15 @@ static void ComputeBackpropInput(const DepthwiseArgs& args,
}
}

// Kernels to compute the input backprop for depthwise convolution.
template <typename Device, typename T>
struct LaunchDepthwiseConvBackpropInputOp;

// Computes the depthwise conv2d backprop input of 'out_backprop' by
// 'depthwise_filter' and stores the result in 'in_backprop'.
template <typename T>
struct LaunchDepthwiseConvBackpropInputOp<CPUDevice, T> {
typedef typename Eigen::internal::packet_traits<T>::type Packet;

static void launch(OpKernelContext* ctx, const DepthwiseArgs& args,
const T* out_backprop, const T* depthwise_filter,
T* in_backprop, TensorFormat data_format) {
void operator()(OpKernelContext* ctx, const DepthwiseArgs& args,
const T* out_backprop, const T* depthwise_filter,
T* in_backprop, TensorFormat data_format) {
OP_REQUIRES(
ctx, data_format == FORMAT_NHWC,
errors::Unimplemented(
@ -514,27 +510,8 @@ static void DepthwiseConvBackpropInputReference(const DepthwiseArgs& args,

#if GOOGLE_CUDA

template <typename T>
struct DepthwiseConv2dBackpropInputGPULaunch {
static void Run(const GPUDevice& d, const DepthwiseArgs args,
const T* out_backprop, const T* filter, T* in_backprop,
TensorFormat data_format);
};

template <typename T>
struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, T> {
static void launch(OpKernelContext* ctx, const DepthwiseArgs args,
const T* out_backprop, const T* filter, T* in_backprop,
TensorFormat data_format) {
const GPUDevice& d = ctx->eigen_device<GPUDevice>();
DepthwiseConv2dBackpropInputGPULaunch<T>().Run(
d, args, out_backprop, filter, in_backprop, data_format);
auto stream = ctx->op_device_context()->stream();
OP_REQUIRES(ctx, stream->ok(), errors::Internal("Launch of gpu kernel for "
"DepthwiseConv2dBackpropInp"
"utGPULaunch failed"));
}
};
extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, float>;
extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, double>;

#endif // GOOGLE_CUDA

@ -598,7 +575,7 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
if (input_shape.num_elements() == 0) {
return;
}
LaunchDepthwiseConvBackpropInputOp<Device, T>::launch(
LaunchDepthwiseConvBackpropInputOp<Device, T>()(
context, args, out_backprop_ptr, filter_ptr, in_backprop_ptr,
data_format_);
}
@ -744,9 +721,9 @@ template <typename T>
struct LaunchDepthwiseConvBackpropFilterOp<CPUDevice, T> {
typedef typename Eigen::internal::packet_traits<T>::type Packet;

static void launch(OpKernelContext* ctx, const DepthwiseArgs& args,
const T* out_backprop, const T* input, T* filter_backprop,
TensorFormat data_format) {
void operator()(OpKernelContext* ctx, const DepthwiseArgs& args,
const T* out_backprop, const T* input, T* filter_backprop,
TensorFormat data_format) {
OP_REQUIRES(
ctx, data_format == FORMAT_NHWC,
errors::Unimplemented(
@ -907,35 +884,8 @@ static void DepthwiseConvBackpropFilterReference(const DepthwiseArgs& args,

#if GOOGLE_CUDA

template <typename T>
struct DepthwiseConv2dBackpropFilterGPULaunch {
static void Run(const GPUDevice& d, const DepthwiseArgs args,
const T* out_backprop, const T* input, T* filter_backprop,
TensorFormat data_format);
};

template <typename T>
struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, T> {
static void launch(OpKernelContext* ctx, const DepthwiseArgs args,
const T* out_backprop, const T* input, T* filter_backprop,
TensorFormat data_format) {
const GPUDevice& d = ctx->eigen_device<GPUDevice>();
auto stream = ctx->op_device_context()->stream();

// Initialize the results to 0.
int num_filter_backprop =
args.filter_rows * args.filter_cols * args.out_depth;
perftools::gputools::DeviceMemoryBase filter_bp_ptr(filter_backprop,
num_filter_backprop);
stream->ThenMemset32(&filter_bp_ptr, 0, num_filter_backprop * sizeof(T));

DepthwiseConv2dBackpropFilterGPULaunch<T>().Run(
d, args, out_backprop, input, filter_backprop, data_format);
OP_REQUIRES(ctx, stream->ok(), errors::Internal("Launch of gpu kernel for "
"DepthwiseConv2dBackpropFil"
"terGPULaunch failed"));
}
};
extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, float>;
extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, double>;

#endif // GOOGLE_CUDA

@ -1001,7 +951,7 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
if (filter_shape.num_elements() == 0) {
return;
}
LaunchDepthwiseConvBackpropFilterOp<Device, T>::launch(
LaunchDepthwiseConvBackpropFilterOp<Device, T>()(
context, args, out_backprop_ptr, input_ptr, filter_backprop_ptr,
data_format_);
}
@ -54,9 +54,6 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
struct LaunchDepthwiseConvOp;

// Computes the vectorized product of 'input_buffer' and 'filter' and stores
// result in 'output' at location specified by 'out_r' and 'out_c'.
//
@ -156,9 +153,9 @@ template <typename T>
struct LaunchDepthwiseConvOp<CPUDevice, T> {
typedef typename Eigen::internal::packet_traits<T>::type Packet;

static void launch(OpKernelContext* ctx, const DepthwiseArgs& args,
const T* input, const T* depthwise_filter, T* output,
TensorFormat data_format) {
void operator()(OpKernelContext* ctx, const DepthwiseArgs& args,
const T* input, const T* depthwise_filter, T* output,
TensorFormat data_format) {
OP_REQUIRES(
ctx, data_format == FORMAT_NHWC,
errors::Unimplemented(
@ -248,27 +245,9 @@ extern template class LaunchConv2DOp<CPUDevice, float>;

#if GOOGLE_CUDA

template <typename T>
struct DepthwiseConv2dGPULaunch {
static void Run(const GPUDevice& d, const DepthwiseArgs args, const T* input,
const T* filter, T* output, TensorFormat data_format);
};

template <typename T>
struct LaunchDepthwiseConvOp<GPUDevice, T> {
static void launch(OpKernelContext* ctx, const DepthwiseArgs args,
const T* input, const T* filter, T* output,
TensorFormat data_format) {
const GPUDevice& d = ctx->eigen_device<GPUDevice>();
DepthwiseConv2dGPULaunch<T>().Run(d, args, input, filter, output,
data_format);
auto stream = ctx->op_device_context()->stream();
OP_REQUIRES(
ctx, stream->ok(),
errors::Internal(
"Launch of gpu kernel for DepthwiseConv2dGPULaunch failed"));
}
};
// Extern template instantiated in depthwise_conv_op_gpu.cc.
extern template struct LaunchDepthwiseConvOp<GPUDevice, float>;
extern template struct LaunchDepthwiseConvOp<GPUDevice, double>;

// Extern template instantiated in conv_ops.cc.
extern template class LaunchConv2DOp<GPUDevice, float>;
@ -393,9 +372,8 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
// If in_depth==1, this operation is just a standard convolution, so
// invoke that op.
if (std::is_same<T, float>::value && in_depth == 1) {
launcher_.launch(context, use_cudnn_, cudnn_use_autotune_, input, filter,
stride_, stride_, BrainPadding2EigenPadding(padding_),
output, data_format_);
launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter,
stride_, stride_, padding_, output, data_format_);
return;
}

@ -417,8 +395,8 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
auto input_ptr = input.template flat<T>().data();
auto filter_ptr = filter.template flat<T>().data();
auto output_ptr = output->template flat<T>().data();
LaunchDepthwiseConvOp<Device, T>::launch(
context, args, input_ptr, filter_ptr, output_ptr, data_format_);
LaunchDepthwiseConvOp<Device, T>()(context, args, input_ptr, filter_ptr,
output_ptr, data_format_);
}

private:
@ -56,6 +56,53 @@ struct DepthwiseArgs {
out_depth(0) {}
};

// Forward declaration.
class OpKernelContext;

template <typename Device, typename T>
struct LaunchDepthwiseConvOp {
void operator()(OpKernelContext* ctx, const DepthwiseArgs& args,
const T* input, const T* filter, T* output,
TensorFormat data_format);
};

template <typename Device, typename T>
struct LaunchDepthwiseConvBackpropInputOp {
void operator()(OpKernelContext* ctx, const DepthwiseArgs& args,
const T* out_backprop, const T* filter, T* in_backprop,
TensorFormat data_format);
};

template <typename Device, typename T>
struct LaunchDepthwiseConvBackpropFilterOp {
void operator()(OpKernelContext* ctx, const DepthwiseArgs& args,
const T* out_backprop, const T* input, T* filter_backprop,
TensorFormat data_format);
};

#if GOOGLE_CUDA
template <typename T>
struct LaunchDepthwiseConvOp<Eigen::GpuDevice, T> {
void operator()(OpKernelContext* ctx, const DepthwiseArgs args,
const T* input, const T* filter, T* output,
TensorFormat data_format);
};

template <typename T>
struct LaunchDepthwiseConvBackpropInputOp<Eigen::GpuDevice, T> {
void operator()(class OpKernelContext* ctx, const DepthwiseArgs& args,
const T* out_backprop, const T* filter, T* in_backprop,
TensorFormat data_format);
};

template <typename T>
struct LaunchDepthwiseConvBackpropFilterOp<Eigen::GpuDevice, T> {
void operator()(class OpKernelContext* ctx, const DepthwiseArgs& args,
const T* out_backprop, const T* input, T* filter_backprop,
TensorFormat data_format);
};
#endif

} // namespace tensorflow

namespace tensorflow {
@ -17,6 +17,7 @@ limitations under the License.
#define EIGEN_USE_GPU

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/depthwise_conv_op.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
@ -689,21 +690,27 @@ void LaunchDepthwiseConv2dGPU(const GpuDevice& d, const DepthwiseArgs args,

// A simple launch pad to launch the Cuda kernel for depthwise convolution.
template <typename T>
struct DepthwiseConv2dGPULaunch {
static void Run(const GpuDevice& d, const DepthwiseArgs args, const T* input,
const T* filter, T* output, TensorFormat data_format) {
if (args.filter_rows == 3 && args.filter_cols == 3) {
LaunchDepthwiseConv2dGPU<T, 3, 3>(d, args, input, filter, output,
void LaunchDepthwiseConvOp<GPUDevice, T>::operator()(OpKernelContext* ctx,
const DepthwiseArgs args,
const T* input,
const T* filter, T* output,
TensorFormat data_format) {
const GPUDevice& d = ctx->eigen_device<GPUDevice>();
if (args.filter_rows == 3 && args.filter_cols == 3) {
LaunchDepthwiseConv2dGPU<T, 3, 3>(d, args, input, filter, output,
data_format);
} else {
LaunchDepthwiseConv2dGPU<T, -1, -1>(d, args, input, filter, output,
data_format);
} else {
LaunchDepthwiseConv2dGPU<T, -1, -1>(d, args, input, filter, output,
data_format);
}
}
};
auto stream = ctx->op_device_context()->stream();
OP_REQUIRES(ctx, stream->ok(),
errors::Internal(
"Launch of gpu kernel for DepthwiseConv2dGPULaunch failed"));
}

template struct DepthwiseConv2dGPULaunch<float>;
template struct DepthwiseConv2dGPULaunch<double>;
template struct LaunchDepthwiseConvOp<GPUDevice, float>;
template struct LaunchDepthwiseConvOp<GPUDevice, double>;

// A Cuda kernel to compute the depthwise convolution backprop w.r.t. input.
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
@ -893,22 +900,26 @@ void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& d,

// A simple launch pad to launch the Cuda kernel for depthwise convolution.
template <typename T>
struct DepthwiseConv2dBackpropInputGPULaunch {
static void Run(const GpuDevice& d, const DepthwiseArgs args,
const T* out_backprop, const T* filter, T* in_backprop,
TensorFormat data_format) {
if (args.filter_rows == 3 && args.filter_cols == 3) {
LaunchDepthwiseConv2dBackpropInputGPU<T, 3, 3>(
d, args, out_backprop, filter, in_backprop, data_format);
} else {
LaunchDepthwiseConv2dBackpropInputGPU<T, -1, -1>(
d, args, out_backprop, filter, in_backprop, data_format);
}
void LaunchDepthwiseConvBackpropInputOp<GPUDevice, T>::operator()(
OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop,
const T* filter, T* in_backprop, TensorFormat data_format) {
const GPUDevice& d = ctx->eigen_device<GPUDevice>();
if (args.filter_rows == 3 && args.filter_cols == 3) {
LaunchDepthwiseConv2dBackpropInputGPU<T, 3, 3>(
d, args, out_backprop, filter, in_backprop, data_format);
} else {
LaunchDepthwiseConv2dBackpropInputGPU<T, -1, -1>(
d, args, out_backprop, filter, in_backprop, data_format);
}
};
auto stream = ctx->op_device_context()->stream();
OP_REQUIRES(ctx, stream->ok(),
errors::Internal("Launch of gpu kernel for "
"DepthwiseConv2dBackpropInp"
"utGPULaunch failed"));
}

template struct DepthwiseConv2dBackpropInputGPULaunch<float>;
template struct DepthwiseConv2dBackpropInputGPULaunch<double>;
template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, float>;
template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, double>;

// A Cuda kernel to compute the depthwise convolution backprop w.r.t. filter.
template <typename T, int kKnownFilterWidth, int kKnownFilterHeight,
@ -1580,21 +1591,33 @@ void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& d,

// A simple launch pad to launch the Cuda kernel for depthwise convolution.
template <typename T>
struct DepthwiseConv2dBackpropFilterGPULaunch {
static void Run(const GpuDevice& d, const DepthwiseArgs args,
const T* out_backprop, const T* input, T* filter_backprop,
TensorFormat data_format) {
if (args.filter_rows == 3 && args.filter_cols == 3) {
LaunchDepthwiseConv2dBackpropFilterGPU<T, 3, 3>(
d, args, out_backprop, input, filter_backprop, data_format);
} else {
LaunchDepthwiseConv2dBackpropFilterGPU<T, -1, -1>(
d, args, out_backprop, input, filter_backprop, data_format);
}
}
};
void LaunchDepthwiseConvBackpropFilterOp<GPUDevice, T>::operator()(
OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop,
const T* input, T* filter_backprop, TensorFormat data_format) {
const GPUDevice& d = ctx->eigen_device<GPUDevice>();
auto stream = ctx->op_device_context()->stream();

template struct DepthwiseConv2dBackpropFilterGPULaunch<float>;
template struct DepthwiseConv2dBackpropFilterGPULaunch<double>;
// Initialize the results to 0.
int num_filter_backprop =
args.filter_rows * args.filter_cols * args.out_depth;
perftools::gputools::DeviceMemoryBase filter_bp_ptr(filter_backprop,
num_filter_backprop);
stream->ThenMemset32(&filter_bp_ptr, 0, num_filter_backprop * sizeof(T));

if (args.filter_rows == 3 && args.filter_cols == 3) {
LaunchDepthwiseConv2dBackpropFilterGPU<T, 3, 3>(
d, args, out_backprop, input, filter_backprop, data_format);
} else {
LaunchDepthwiseConv2dBackpropFilterGPU<T, -1, -1>(
d, args, out_backprop, input, filter_backprop, data_format);
}
OP_REQUIRES(ctx, stream->ok(),
errors::Internal("Launch of gpu kernel for "
"DepthwiseConv2dBackpropFil"
"terGPULaunch failed"));
}

template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, float>;
template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, double>;
} // namespace tensorflow
#endif // GOOGLE_CUDA
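
Taken together, these hunks replace the per-device static `launch` methods with functor-style launch pads that are declared once in the headers, specialized per device, and invoked as temporaries, mirroring the launch-pad example in ./extend/adding_an_op.md. A minimal, self-contained sketch of that calling pattern follows; `LaunchExampleOp`, the device tags, and the argument list are hypothetical names chosen only to illustrate the shape of the interface and are not part of this change:

#include <iostream>

// Hypothetical device tags standing in for Eigen::ThreadPoolDevice / GpuDevice.
struct CPUDeviceTag {};
struct GPUDeviceTag {};

// Primary template: one launch pad per op, specialized per device.
template <typename Device, typename T>
struct LaunchExampleOp {
  void operator()(const T* in, T* out, int n);  // defined by each specialization
};

// CPU specialization does the actual work.
template <typename T>
struct LaunchExampleOp<CPUDeviceTag, T> {
  void operator()(const T* in, T* out, int n) {
    for (int i = 0; i < n; ++i) out[i] = in[i] * T(2);
  }
};

// A kernel-style caller picks the launch pad by device and invokes it as a
// temporary, just like LaunchDepthwiseConvOp<Device, T>()(...) above.
template <typename Device, typename T>
void Compute(const T* in, T* out, int n) {
  LaunchExampleOp<Device, T>()(in, out, n);
}

int main() {
  const float in[3] = {1.f, 2.f, 3.f};
  float out[3];
  Compute<CPUDeviceTag, float>(in, out, 3);
  std::cout << out[0] << " " << out[1] << " " << out[2] << "\n";  // prints: 2 4 6
}

Specializing a functor per device, rather than exposing a static `launch` method, is what lets conv_grad_ops.h and depthwise_conv_op.h above forward-declare the GPU variants while their definitions live in the .cu.cc files.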