Add explicit padding to depthwise conv2d.

Roll-forward of 52d33f117e.

It was rolled back in 092ae742c3 because it broke an MLIR test. I fixed this by modifying tf_generated_ops.td. Other than modifying tf_generated_ops.td and syncing, I made no other changes.

PiperOrigin-RevId: 302720901
Change-Id: I409a5be7a2fcda8deae69b988fdebf8be4c19488
This commit is contained in:
Reed Wanderman-Milne 2020-03-24 12:22:27 -07:00 committed by TensorFlower Gardener
parent f19161ecb7
commit 8052d528e9
18 changed files with 639 additions and 92 deletions

View File

@ -1746,7 +1746,8 @@ horizontal and vertices strides, `strides = [1, stride, stride, 1]`.
TF_FpTensor:$filter,
I64ArrayAttr:$strides,
TF_AnyStrAttrOf<["SAME", "VALID"]>:$padding,
TF_AnyStrAttrOf<["SAME", "VALID", "EXPLICIT"]>:$padding,
DefaultValuedAttr<I64ArrayAttr, "{}">:$explicit_paddings,
DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format,
DefaultValuedAttr<I64ArrayAttr, "{1, 1, 1, 1}">:$dilations
);

View File

@ -1,8 +1,4 @@
op {
graph_op_name: "DepthwiseConv2dNative"
deprecation_message: "Use nn.depthwise_conv2d instead"
endpoint {
name: "nn.depthwise_conv2d_native"
deprecation_version: 2
}
visibility: HIDDEN
}

View File

@ -1,11 +1,4 @@
op {
graph_op_name: "DepthwiseConv2dNativeBackpropFilter"
endpoint {
name: "nn.depthwise_conv2d_native_backprop_filter"
deprecated: true
deprecation_version: 2
}
endpoint {
name: "nn.depthwise_conv2d_backprop_filter"
}
visibility: HIDDEN
}

View File

@ -1,11 +1,4 @@
op {
graph_op_name: "DepthwiseConv2dNativeBackpropInput"
endpoint {
name: "nn.depthwise_conv2d_native_backprop_input"
deprecated: true
deprecation_version: 2
}
endpoint {
name: "nn.depthwise_conv2d_backprop_input"
}
visibility: HIDDEN
}

View File

@ -29,9 +29,9 @@ namespace tensorflow {
namespace shape_inference {
// The V2 version computes windowed output size with arbitrary dilation_rate,
// while the original version only handles the cases where dilation_rates equal
// to 1.
// The V2 version computes windowed output size with arbitrary dilation_rate and
// explicit padding, while the original version only handles the cases where
// dilation_rates equal to 1 and the padding is SAME or VALID.
Status GetWindowedOutputSizeFromDimsV2(
shape_inference::InferenceContext* c,
shape_inference::DimensionHandle input_size,
@ -822,7 +822,10 @@ Status Conv3DShape(shape_inference::InferenceContext* c) {
return Status::OK();
}
Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) {
namespace {
Status DepthwiseConv2DNativeShapeImpl(shape_inference::InferenceContext* c,
bool supports_explicit_padding) {
ShapeHandle input_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
ShapeHandle filter_shape;
@ -850,13 +853,17 @@ Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) {
dilations.size());
}
string data_format;
Status s = c->GetAttr("data_format", &data_format);
string data_format_str;
Status s = c->GetAttr("data_format", &data_format_str);
TensorFormat data_format;
if (!s.ok() || !FormatFromString(data_format_str, &data_format)) {
data_format = FORMAT_NHWC;
}
int32 stride_rows;
int32 stride_cols;
int32 dilation_rows;
int32 dilation_cols;
if (s.ok() && data_format == "NCHW") {
if (data_format == FORMAT_NCHW) {
// Canonicalize input shape to NHWC so the shape inference code below can
// process it.
input_shape =
@ -892,20 +899,41 @@ Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) {
Padding padding;
TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
std::vector<int64> explicit_paddings;
if (supports_explicit_padding) {
Status status = c->GetAttr("explicit_paddings", &explicit_paddings);
// Use the default value, which is an empty list, if the attribute is not
// found. Otherwise return the error to the caller.
if (!status.ok() && !errors::IsNotFound(status)) {
return status;
}
TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings,
/*num_dims=*/4, data_format));
} else {
DCHECK(padding != Padding::EXPLICIT);
}
// TODO(mrry,shlens): Raise an error if the stride would cause
// information in the input to be ignored. This will require a change
// in the kernel implementation.
DimensionHandle output_rows, output_cols;
int64 pad_rows_before = -1, pad_rows_after = -1;
int64 pad_cols_before = -1, pad_cols_after = -1;
if (padding == Padding::EXPLICIT) {
GetExplicitPaddingForDim(explicit_paddings, data_format, 'H',
&pad_rows_before, &pad_rows_after);
GetExplicitPaddingForDim(explicit_paddings, data_format, 'W',
&pad_cols_before, &pad_cols_after);
}
TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
c, in_rows_dim, filter_rows_dim, dilation_rows, stride_rows, padding, -1,
-1, &output_rows));
c, in_rows_dim, filter_rows_dim, dilation_rows, stride_rows, padding,
pad_rows_before, pad_rows_after, &output_rows));
TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
c, in_cols_dim, filter_cols_dim, dilation_cols, stride_cols, padding, -1,
-1, &output_cols));
c, in_cols_dim, filter_cols_dim, dilation_cols, stride_cols, padding,
pad_cols_before, pad_cols_after, &output_cols));
ShapeHandle output_shape;
if (data_format == "NCHW") {
if (data_format == FORMAT_NCHW) {
output_shape =
c->MakeShape({batch_size_dim, output_depth, output_rows, output_cols});
} else {
@ -916,6 +944,17 @@ Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) {
return Status::OK();
}
}; // namespace
Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) {
return DepthwiseConv2DNativeShapeImpl(c, false);
}
Status DepthwiseConv2DNativeShapeWithExplicitPadding(
shape_inference::InferenceContext* c) {
return DepthwiseConv2DNativeShapeImpl(c, true);
}
Status AvgPoolShape(shape_inference::InferenceContext* c) {
string data_format_str;
TensorFormat data_format;

View File

@ -129,7 +129,13 @@ Status Conv2DShape(shape_inference::InferenceContext* c);
// Shape function for Conv3D-like operations.
Status Conv3DShape(shape_inference::InferenceContext* c);
// Shape function for DepthwiseConv2D-like operations.
// Shape function for DepthwiseConv2D-like operations that support explicit
// padding.
Status DepthwiseConv2DNativeShapeWithExplicitPadding(
shape_inference::InferenceContext* c);
// Shape function for DepthwiseConv2D-like operations that do not support
// explicit padding.
Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c);
// Shape function for AvgPool-like operations.

View File

@ -116,13 +116,20 @@ typedef Eigen::GpuDevice GPUDevice;
errors::InvalidArgument( \
label, ": depth_multiplier * in_depth not equal to out_depth")); \
const auto stride = stride_; \
int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0; \
OP_REQUIRES_OK(context, \
GetWindowedOutputSize(input_rows, filter_rows, stride, \
padding_, &out_rows, &pad_rows)); \
OP_REQUIRES_OK(context, \
GetWindowedOutputSize(input_cols, filter_cols, stride, \
padding_, &out_cols, &pad_cols)); \
int64 out_rows = 0, out_cols = 0, pad_top = 0, pad_bottom = 0, pad_left = 0, \
pad_right = 0; \
if (padding_ == Padding::EXPLICIT) { \
GetExplicitPaddingForDim(explicit_paddings_, data_format_, 'H', &pad_top, \
&pad_bottom); \
GetExplicitPaddingForDim(explicit_paddings_, data_format_, 'W', &pad_left, \
&pad_right); \
} \
OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( \
input_rows, filter_rows, stride_, padding_, \
&out_rows, &pad_top, &pad_bottom)); \
OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( \
input_cols, filter_cols, stride_, padding_, \
&out_cols, &pad_left, &pad_right)); \
OP_REQUIRES( \
context, output_rows == out_rows, \
errors::InvalidArgument( \
@ -142,8 +149,8 @@ typedef Eigen::GpuDevice GPUDevice;
args.filter_cols = filter_cols; \
args.depth_multiplier = depth_multiplier; \
args.stride = stride; \
args.pad_rows = pad_rows; \
args.pad_cols = pad_cols; \
args.pad_rows = pad_top; \
args.pad_cols = pad_left; \
args.out_rows = out_rows; \
args.out_cols = out_cols; \
args.out_depth = out_depth; \
@ -151,7 +158,7 @@ typedef Eigen::GpuDevice GPUDevice;
<< input_rows << ", " << input_cols << ", " << in_depth \
<< "]; Filter: [" << filter_rows << ", " << filter_cols << ", " \
<< in_depth << ", " << depth_multiplier << "]; stride = " << stride \
<< ", pad_rows = " << pad_rows << ", pad_cols = " << pad_cols \
<< ", pad_rows = " << pad_top << ", pad_cols = " << pad_left \
<< ", output: [" << batch << ", " << out_rows << ", " << out_cols \
<< ", " << out_depth << "]";
@ -566,6 +573,10 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
OP_REQUIRES_OK(context,
context->GetAttr("explicit_paddings", &explicit_paddings_));
OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
/*num_dims=*/4, data_format_));
// For in_depth == 1 and grouped convolutions.
use_cudnn_ = CanUseCudnn() && std::is_same<Device, GPUDevice>::value;
@ -628,7 +639,7 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
<< filter_cols << ", " << in_depth << ", " << depth_multiplier
<< "]; Output: [" << batch << ", " << out_rows << ", " << out_cols
<< ", " << out_depth << "], stride = " << stride_
<< ", pad_rows = " << pad_rows << ", pad_cols = " << pad_cols
<< ", pad_rows = " << pad_top << ", pad_cols = " << pad_left
<< ", Use cuDNN: " << use_cudnn;
if (use_cudnn) {
@ -652,8 +663,8 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
// conv is supported.
launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop,
reshaped_filter, /*row_dilation=*/1, /*col_dilation=*/1,
stride_, stride_, padding_, /*explicit_paddings=*/{},
in_backprop, data_format_);
stride_, stride_, padding_, explicit_paddings_, in_backprop,
data_format_);
return;
}
@ -671,6 +682,7 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
private:
std::vector<int32> strides_;
Padding padding_;
std::vector<int64> explicit_paddings_;
TensorFormat data_format_;
int64 stride_;
@ -1055,6 +1067,10 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
OP_REQUIRES_OK(context,
context->GetAttr("explicit_paddings", &explicit_paddings_));
OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
/*num_dims=*/4, data_format_));
// For in_depth == 1 and grouped convolutions.
use_cudnn_ = CanUseCudnn() && std::is_same<Device, GPUDevice>::value;
@ -1123,7 +1139,7 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
<< filter_cols << ", " << in_depth << ", " << depth_multiplier
<< "]; Output: [" << batch << ", " << out_rows << ", " << out_cols
<< ", " << out_depth << "], stride = " << stride_
<< ", pad_rows = " << pad_rows << ", pad_cols = " << pad_cols
<< ", pad_rows = " << pad_top << ", pad_cols = " << pad_left
<< ", Use cuDNN: " << use_cudnn;
if (use_cudnn) {
@ -1148,8 +1164,7 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
// conv is supported.
launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop, input,
/*row_dilation=*/1, /*col_dilation=*/1, stride_, stride_,
padding_, /*explicit_paddings=*/{}, &reshaped_filter,
data_format_);
padding_, explicit_paddings_, &reshaped_filter, data_format_);
return;
}
@ -1167,6 +1182,7 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
private:
std::vector<int32> strides_;
Padding padding_;
std::vector<int64> explicit_paddings_;
TensorFormat data_format_;
int64 stride_;

View File

@ -293,6 +293,10 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
OP_REQUIRES_OK(context,
context->GetAttr("explicit_paddings", &explicit_paddings_));
OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
/*num_dims=*/4, data_format_));
// For in_depth == 1 and grouped convolutions.
use_cudnn_ = CanUseCudnn() && std::is_same<Device, GPUDevice>::value;
@ -357,13 +361,20 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
// The first dimension for input is batch.
const int32 batch = input.dim_size(0);
int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
OP_REQUIRES_OK(context,
GetWindowedOutputSize(input_rows, filter_rows, stride_,
padding_, &out_rows, &pad_rows));
OP_REQUIRES_OK(context,
GetWindowedOutputSize(input_cols, filter_cols, stride_,
padding_, &out_cols, &pad_cols));
int64 out_rows = 0, out_cols = 0, pad_top = 0, pad_bottom = 0, pad_left = 0,
pad_right = 0;
if (padding_ == Padding::EXPLICIT) {
GetExplicitPaddingForDim(explicit_paddings_, data_format_, 'H', &pad_top,
&pad_bottom);
GetExplicitPaddingForDim(explicit_paddings_, data_format_, 'W', &pad_left,
&pad_right);
}
OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
input_rows, filter_rows, stride_, padding_,
&out_rows, &pad_top, &pad_bottom));
OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
input_cols, filter_cols, stride_, padding_,
&out_cols, &pad_left, &pad_right));
TensorShape out_shape =
ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth);
OP_REQUIRES(
@ -398,7 +409,7 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
<< filter_cols << ", " << in_depth << ", " << depth_multiplier
<< "]; Output: [" << batch << ", " << out_rows << ", " << out_cols
<< ", " << out_depth << "], stride = " << stride_
<< ", pad_rows = " << pad_rows << ", pad_cols = " << pad_cols
<< ", pad_top = " << pad_top << ", pad_left = " << pad_left
<< ", Use cuDNN: " << use_cudnn;
if (use_cudnn) {
@ -422,7 +433,7 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
// conv is supported.
launcher_(context, use_cudnn_, cudnn_use_autotune_, input,
reshaped_filter, /*row_dilation=*/1, /*col_dilation=*/1,
stride_, stride_, padding_, /*explicit_paddings=*/{}, output,
stride_, stride_, padding_, explicit_paddings_, output,
data_format_);
return;
}
@ -436,8 +447,8 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
args.filter_cols = filter_cols;
args.depth_multiplier = depth_multiplier;
args.stride = stride_;
args.pad_rows = pad_rows;
args.pad_cols = pad_cols;
args.pad_rows = pad_top;
args.pad_cols = pad_left;
args.out_rows = out_rows;
args.out_cols = out_cols;
args.out_depth = out_depth;
@ -455,6 +466,7 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
private:
std::vector<int32> strides_;
Padding padding_;
std::vector<int64> explicit_paddings_;
TensorFormat data_format_;
int64 stride_; // in height/width dimension.

View File

@ -32,8 +32,8 @@ struct DepthwiseArgs {
int filter_cols;
int depth_multiplier;
int stride;
int pad_rows;
int pad_cols;
int pad_rows; // Amount of padding to the top of the input
int pad_cols; // Amount of padding to the left of the input
// Output layer dimensions
int out_rows;

View File

@ -558,10 +558,11 @@ REGISTER_OP("DepthwiseConv2dNative")
.Output("output: T")
.Attr("T: {half, bfloat16, float, double}")
.Attr("strides: list(int)")
.Attr(GetPaddingAttrString())
.Attr(GetPaddingAttrStringWithExplicit())
.Attr(GetExplicitPaddingsAttrString())
.Attr(GetConvnetDataFormatAttrString())
.Attr("dilations: list(int) = [1, 1, 1, 1]")
.SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
.SetShapeFn(shape_inference::DepthwiseConv2DNativeShapeWithExplicitPadding);
REGISTER_OP("DepthwiseConv2dNativeBackpropInput")
.Input("input_sizes: int32")
@ -570,7 +571,8 @@ REGISTER_OP("DepthwiseConv2dNativeBackpropInput")
.Output("output: T")
.Attr("T: {half, bfloat16, float, double}")
.Attr("strides: list(int)")
.Attr(GetPaddingAttrString())
.Attr(GetPaddingAttrStringWithExplicit())
.Attr(GetExplicitPaddingsAttrString())
.Attr(GetConvnetDataFormatAttrString())
.Attr("dilations: list(int) = [1, 1, 1, 1]")
.SetShapeFn([](InferenceContext* c) {
@ -588,7 +590,8 @@ REGISTER_OP("DepthwiseConv2dNativeBackpropFilter")
.Output("output: T")
.Attr("T: {half, bfloat16, float, double}")
.Attr("strides: list(int)")
.Attr(GetPaddingAttrString())
.Attr(GetPaddingAttrStringWithExplicit())
.Attr(GetExplicitPaddingsAttrString())
.Attr(GetConvnetDataFormatAttrString())
.Attr("dilations: list(int) = [1, 1, 1, 1]")
.SetShapeFn([](InferenceContext* c) {

View File

@ -31,12 +31,12 @@ class NodeDef;
// Padding: the padding we apply to the input tensor along the rows and columns
// dimensions. This is usually used to make sure that the spatial dimensions do
// not shrink when we progress with convolutions. Two types of padding are
// not shrink when we progress with convolutions. Three types of padding are
// supported:
// VALID: No padding is carried out.
// SAME: The pad value is computed so that the output will have the same
// dimensions as the input.
// EXPLICIT: The user specifies the pad values in the explicit_padding
// EXPLICIT: The user specifies the pad values in the explicit_paddings
// attribute.
// The padded area is zero-filled.
enum Padding {

View File

@ -2941,6 +2941,7 @@ cuda_py_test(
size = "medium", # http://b/30603882
timeout = "long",
srcs = ["depthwise_conv_op_test.py"],
shard_count = 3,
# TODO(b/118842098): Re-enable this test in Kokoro.
tags = ["no_oss"],
deps = [

View File

@ -166,6 +166,43 @@ def ConfigsToTest():
]
def ConfigsToTestExplicit():
"""Iterator for different convolution shapes, strides and explicit paddings.
Returns:
List of tuples (input_size, filter_size, out_size, stride, padding,
dilations), the depthwise convolution parameters.
"""
def Config(input_size, filter_size, out_size, stride=1, padding=None,
dilations=None):
return input_size, filter_size, out_size, stride, padding, dilations
return [
Config([4, 5, 5, 48], [1, 1, 48, 2], [4, 8, 12, 96],
padding=[[1, 2], [3, 4]]),
Config([4, 1, 1, 3], [3, 3, 3, 2], [4, 29, 39, 6],
padding=[[10, 20], [15, 25]]),
Config([4, 9, 27, 8], [3, 3, 8, 1], [4, 14, 31, 8],
padding=[[3, 4], [4, 2]]),
Config([4, 31, 31, 7], [3, 3, 7, 1], [4, 29, 29, 7],
padding=[[0, 0], [0, 0]]),
Config([3, 299, 299, 3], [3, 2, 3, 8], [3, 150, 153, 24], 2,
padding=[[1, 2], [3, 5]]),
Config([5, 183, 183, 1], [5, 5, 1, 2], [5, 62, 60, 2], 3,
padding=[[3, 2], [1, 0]]),
Config([5, 29, 31, 1], [5, 4, 1, 2], [5, 26, 23, 2],
padding=[[3, 2], [1, 0]], dilations=[2, 3]),
# These cases test the kernels in depthwise_conv_op_gpu.h which are used
# if the input size is small.
Config([4, 5, 5, 48], [3, 3, 48, 1], [4, 5, 5, 48],
padding=[[0, 2], [0, 2]]),
Config([1, 8, 7, 2], [8, 7, 2, 1], [1, 8, 7, 2],
padding=[[0, 7], [3, 3]]),
Config([2, 4, 3, 2], [3, 2, 2, 1], [2, 4, 3, 2],
padding=[[2, 0], [1, 0]]),
]
def CheckGradConfigsToTest():
"""Iterator for different convolution shapes, strides and paddings.
@ -194,6 +231,39 @@ def CheckGradConfigsToTest():
]
def CheckGradConfigsToTestExplicit():
"""Iterator for different convolution shapes, strides and explicit paddings.
compute_gradient_error() is very expensive. So the configs should be
relatively small.
Returns:
List of tuples (input_size, filter_size, out_size, stride, padding,
dilations), the depthwise convolution parameters.
"""
def Config(input_size, filter_size, out_size, stride=1, padding=None,
dilations=None):
return input_size, filter_size, out_size, stride, padding, dilations
return [
Config([2, 5, 8, 1], [4, 4, 1, 2], [2, 3, 10, 2],
padding=[[0, 1], [2, 3]]),
Config([4, 5, 5, 1], [2, 2, 1, 2], [4, 4, 5, 2], 2,
padding=[[3, 1], [5, 0]]),
Config([2, 4, 4, 2], [3, 1, 2, 2], [2, 7, 11, 4],
padding=[[4, 1], [3, 4]]),
Config([1, 15, 15, 2], [1, 3, 2, 1], [1, 18, 23, 2],
padding=[[3, 0], [2, 8]]),
Config([2, 15, 16, 1], [3, 3, 1, 2], [2, 5, 8, 2], 3,
padding=[[0, 0], [10, 0]]),
Config([2, 5, 8, 1], [3, 4, 1, 2], [2, 5, 10, 2],
padding=[[3, 1], [2, 3]], dilations=[2, 1]),
# These cases test the kernels in depthwise_conv_op_gpu.h which are used
# if the input size is small.
Config([2, 4, 3, 2], [3, 2, 2, 1], [2, 4, 3, 2],
padding=[[2, 0], [1, 0]]),
]
class DepthwiseConv2DTest(test.TestCase):
# This tests depthwise_conv2d and depthwise_conv2d_native
@ -235,6 +305,8 @@ class DepthwiseConv2DTest(test.TestCase):
x2 = np.array(x2).reshape(filter_in_sizes)
# Compute reference result
strides = [1, stride, stride, 1]
if isinstance(padding, list):
padding = [(0, 0)] + padding + [(0, 0)]
np_result = _DepthwiseConv2dNumpy(x1, x2, strides, padding, "NHWC",
dilations)
@ -255,6 +327,8 @@ class DepthwiseConv2DTest(test.TestCase):
# Ex. [4, 5, 5, 48] to [4, 48, 5, 5]
t1 = array_ops.transpose(t1, [0, 3, 1, 2])
strides = [1, 1, stride, stride]
if isinstance(padding, list):
padding = [padding[0], padding[3], padding[1], padding[2]]
# depthwise_conv2d_native does not support dilations except on TPUs.
if dilations is None:
@ -384,6 +458,23 @@ class DepthwiseConv2DTest(test.TestCase):
data_format="NCHW",
dilations=dilations)
@test_util.run_v1_only("b/120545219")
def testDepthwiseConv2DExplicit(self):
for index, (input_size, filter_size, _, stride,
padding, dilations) in enumerate(ConfigsToTestExplicit()):
tf_logging.info(
"Testing DepthwiseConv2D, %dth config: %r * %r, stride: %d, padding: "
"%s", index, input_size, filter_size, stride, padding)
# double datatype is currently not supported for convolution ops
# on the ROCm platform
optional_float64 = [] if test.is_built_with_rocm() else [dtypes.float64]
data_formats = ["NHWC", "NCHW"] if test.is_gpu_available() else ["NHWC"]
for data_type in [dtypes.float16, dtypes.float32] + optional_float64:
for data_format in data_formats:
self._VerifyValues(
input_size, filter_size, stride, padding, data_type, use_gpu=True,
data_format=data_format, dilations=dilations)
# This is testing against hand calculated results.
def _VerifyHandValues(self, tensor_in_sizes, filter_in_sizes, stride, padding,
@ -530,6 +621,8 @@ class DepthwiseConv2DTest(test.TestCase):
native_input = input_tensor
strides = [1, stride, stride, 1]
if isinstance(padding, list):
padding = [(0, 0)] + padding + [(0, 0)]
if data_format == "NCHW":
# Transpose from NHWC input to NCHW
# Ex. [4, 5, 5, 48] to [4, 48, 5, 5]
@ -541,6 +634,8 @@ class DepthwiseConv2DTest(test.TestCase):
output_shape[0], output_shape[3], output_shape[1], output_shape[2]
]
strides = [1, 1, stride, stride]
if isinstance(padding, list):
padding = [padding[0], padding[3], padding[1], padding[2]]
with sess.graph._kernel_label_map({
"DepthwiseConv2dNative": "cudnn_grouped_convolution",
@ -666,6 +761,32 @@ class DepthwiseConv2DTest(test.TestCase):
data_format="NCHW",
dilations=dilations)
@test_util.run_v1_only("b/120545219")
def testDepthwiseConv2DInputGradExplicit(self):
for index, (input_size, filter_size, output_size, stride, padding,
dilations) in enumerate(CheckGradConfigsToTestExplicit()):
tf_logging.info(
"Testing DepthwiseConv2DInputGradExplicit, %dth config: %r * %r, "
"stride: %d, padding: %s", index, input_size, filter_size, stride,
padding)
# double datatype is currently not supported for convolution ops
# on the ROCm platform
optional_float64 = [] if test.is_built_with_rocm() else [dtypes.float64]
data_formats = ["NHWC", "NCHW"] if test.is_gpu_available() else ["NHWC"]
for data_type in [dtypes.float16, dtypes.float32] + optional_float64:
for data_format in data_formats:
self._ConstructAndTestGradient(
input_size,
filter_size,
output_size,
stride,
padding,
data_type,
test_input=True,
use_gpu=True,
data_format=data_format,
dilations=dilations)
@test_util.run_v1_only("b/120545219")
@test_util.run_cuda_only
def testDepthwiseConv2DFilterGradCudnn(self):
@ -750,10 +871,38 @@ class DepthwiseConv2DTest(test.TestCase):
data_format="NCHW",
dilations=dilations)
@test_util.run_v1_only("b/120545219")
def testDepthwiseConv2DFilterGradExplicit(self):
for index, (input_size, filter_size, output_size, stride, padding,
dilations) in enumerate(CheckGradConfigsToTestExplicit()):
tf_logging.info(
"Testing DepthwiseConv2DFilterGradExplicit, %dth config: %r * %r, "
"stride: %d, padding: %s", index, input_size, filter_size, stride,
padding)
# double datatype is currently not supported for convolution ops
# on the ROCm platform
optional_float64 = [] if test.is_built_with_rocm() else [dtypes.float64]
data_formats = ["NHWC", "NCHW"] if test.is_gpu_available() else ["NHWC"]
for data_type in [dtypes.float16, dtypes.float32] + optional_float64:
for data_format in data_formats:
self._ConstructAndTestGradient(
input_size,
filter_size,
output_size,
stride,
padding,
data_type,
test_input=False,
use_gpu=True,
data_format=data_format,
dilations=dilations)
def _CompareBackpropInput(self, input_sizes, filter_sizes, output_sizes,
stride, padding, dtype):
x1 = np.random.rand(*filter_sizes).astype(dtype)
x2 = np.random.rand(*output_sizes).astype(dtype)
if isinstance(padding, list):
padding = [(0, 0)] + padding + [(0, 0)]
def _GetVal(use_gpu):
with self.cached_session(use_gpu=use_gpu):
@ -788,10 +937,30 @@ class DepthwiseConv2DTest(test.TestCase):
self._CompareBackpropInput(input_size, filter_size, output_size, stride,
padding, "float64")
def testDepthwiseConv2DInputGradExplicitCompare(self):
for index, (input_size, filter_size, output_size, stride,
padding, dilations) in enumerate(ConfigsToTestExplicit()):
if dilations:
continue
tf_logging.info(
"Testing DepthwiseConv2DInputGradCompare, %dth config: %r * %r, "
"stride: %d, padding: %s", index, input_size, filter_size, stride,
padding)
self._CompareBackpropInput(input_size, filter_size, output_size, stride,
padding, "float32")
# double datatype is currently not supported for convolution ops
# on the ROCm platform
if test.is_built_with_rocm():
continue
self._CompareBackpropInput(input_size, filter_size, output_size, stride,
padding, "float64")
def _CompareBackpropFilter(self, input_sizes, filter_sizes, output_sizes,
stride, padding, dtype):
x0 = np.random.rand(*input_sizes).astype(dtype)
x2 = np.random.rand(*output_sizes).astype(dtype)
if isinstance(padding, list):
padding = [(0, 0)] + padding + [(0, 0)]
def _GetVal(use_gpu):
with self.cached_session(use_gpu=use_gpu):
@ -826,6 +995,24 @@ class DepthwiseConv2DTest(test.TestCase):
self._CompareBackpropFilter(input_size, filter_size, output_size, stride,
padding, "float64")
def testDepthwiseConv2DFilterGradExplicitCompare(self):
for index, (input_size, filter_size, output_size, stride,
padding, dilations) in enumerate(ConfigsToTestExplicit()):
if dilations:
continue
tf_logging.info(
"Testing DepthwiseConv2DFilterGradCompare, %dth config: %r * %r, "
"stride: %d, padding: %s", index, input_size, filter_size, stride,
padding)
self._CompareBackpropFilter(input_size, filter_size, output_size, stride,
padding, "float32")
# double datatype is currently not supported for convolution ops
# on the ROCm platform
if test.is_built_with_rocm():
continue
self._CompareBackpropFilter(input_size, filter_size, output_size, stride,
padding, "float64")
if __name__ == "__main__":
test.main()

View File

@ -104,20 +104,22 @@ def _DepthwiseConv2dNativeBackpropInputGrad(op, grad):
"""
return [
None,
nn_ops.depthwise_conv2d_native_backprop_filter(
gen_nn_ops.depthwise_conv2d_native_backprop_filter(
grad,
array_ops.shape(op.inputs[1]),
op.inputs[2],
dilations=op.get_attr("dilations"),
strides=op.get_attr("strides"),
padding=op.get_attr("padding"),
explicit_paddings=op.get_attr("explicit_paddings"),
data_format=op.get_attr("data_format")),
nn_ops.depthwise_conv2d_native(
gen_nn_ops.depthwise_conv2d_native(
grad,
op.inputs[1],
dilations=op.get_attr("dilations"),
strides=op.get_attr("strides"),
padding=op.get_attr("padding"),
explicit_paddings=op.get_attr("explicit_paddings"),
data_format=op.get_attr("data_format"))
]
@ -125,20 +127,22 @@ def _DepthwiseConv2dNativeBackpropInputGrad(op, grad):
@ops.RegisterGradient("DepthwiseConv2dNativeBackpropFilter")
def _DepthwiseConv2dNativeBackpropFilterGrad(op, grad):
return [
nn_ops.depthwise_conv2d_native_backprop_input(
gen_nn_ops.depthwise_conv2d_native_backprop_input(
array_ops.shape(op.inputs[0]),
grad,
op.inputs[2],
dilations=op.get_attr("dilations"),
strides=op.get_attr("strides"),
padding=op.get_attr("padding"),
explicit_paddings=op.get_attr("explicit_paddings"),
data_format=op.get_attr("data_format")), None,
nn_ops.depthwise_conv2d_native(
gen_nn_ops.depthwise_conv2d_native(
op.inputs[0],
grad,
dilations=op.get_attr("dilations"),
strides=op.get_attr("strides"),
padding=op.get_attr("padding"),
explicit_paddings=op.get_attr("explicit_paddings"),
data_format=op.get_attr("data_format"))
]
@ -606,21 +610,23 @@ def _Conv2DGrad(op, grad):
@ops.RegisterGradient("DepthwiseConv2dNative")
def _DepthwiseConv2dNativeGrad(op, grad):
return [
nn_ops.depthwise_conv2d_native_backprop_input(
gen_nn_ops.depthwise_conv2d_native_backprop_input(
array_ops.shape(op.inputs[0]),
op.inputs[1],
grad,
dilations=op.get_attr("dilations"),
strides=op.get_attr("strides"),
padding=op.get_attr("padding"),
explicit_paddings=op.get_attr("explicit_paddings"),
data_format=op.get_attr("data_format")),
nn_ops.depthwise_conv2d_native_backprop_filter(
gen_nn_ops.depthwise_conv2d_native_backprop_filter(
op.inputs[0],
array_ops.shape(op.inputs[1]),
grad,
dilations=op.get_attr("dilations"),
strides=op.get_attr("strides"),
padding=op.get_attr("padding"),
explicit_paddings=op.get_attr("explicit_paddings"),
data_format=op.get_attr("data_format"))
]

View File

@ -740,14 +740,50 @@ def depthwise_conv2d(input,
convolution, in which case all values in the `strides` tensor must be equal
to 1.
Usage Example:
>>> x = np.array([
... [1., 2.],
... [3., 4.],
... [5., 6.]
... ], dtype=np.float32).reshape((1, 3, 2, 1))
>>> kernel = np.array([
... [1., 2.],
... [3., 4]
... ], dtype=np.float32).reshape((2, 1, 1, 2))
>>> tf.compat.v1.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1],
... padding='VALID').numpy()
array([[[[10., 14.],
[14., 20.]],
[[18., 26.],
[22., 32.]]]], dtype=float32)
>>> tf.compat.v1.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1],
... padding=[[0, 0], [1, 0], [1, 0], [0, 0]]
... ).numpy()
array([[[[ 0., 0.],
[ 3., 4.],
[ 6., 8.]],
[[ 0., 0.],
[10., 14.],
[14., 20.]],
[[ 0., 0.],
[18., 26.],
[22., 32.]]]], dtype=float32)
Args:
input: 4-D with shape according to `data_format`.
filter: 4-D with shape
`[filter_height, filter_width, in_channels, channel_multiplier]`.
strides: 1-D of size 4. The stride of the sliding window for each
dimension of `input`.
padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
See the "returns" section of `tf.nn.convolution` for details.
padding: Controls how to pad the image before applying the convolution. Can
be the string `"SAME"` or `"VALID"` indicating the type of padding
algorithm to use, or a list indicating the explicit paddings at the start
and end of each dimension. When explicit padding is used and data_format
is `"NHWC"`, this should be in the form `[[0, 0], [pad_top, pad_bottom],
[pad_left, pad_right], [0, 0]]`. When explicit padding used and
data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
[pad_top, pad_bottom], [pad_left, pad_right]]`.
rate: 1-D of size 2. The dilation rate in which we sample input values
across the `height` and `width` dimensions in atrous convolution. If it is
greater than 1, then all values of strides must be 1.
@ -830,14 +866,49 @@ def depthwise_conv2d_v2(input,
convolution, in which case all values in the `strides` tensor must be equal
to 1.
Usage Example:
>>> x = np.array([
... [1., 2.],
... [3., 4.],
... [5., 6.]
... ], dtype=np.float32).reshape((1, 3, 2, 1))
>>> kernel = np.array([
... [1., 2.],
... [3., 4]
... ], dtype=np.float32).reshape((2, 1, 1, 2))
>>> tf.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1],
... padding='VALID').numpy()
array([[[[10., 14.],
[14., 20.]],
[[18., 26.],
[22., 32.]]]], dtype=float32)
>>> tf.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1],
... padding=[[0, 0], [1, 0], [1, 0], [0, 0]]).numpy()
array([[[[ 0., 0.],
[ 3., 4.],
[ 6., 8.]],
[[ 0., 0.],
[10., 14.],
[14., 20.]],
[[ 0., 0.],
[18., 26.],
[22., 32.]]]], dtype=float32)
Args:
input: 4-D with shape according to `data_format`.
filter: 4-D with shape
`[filter_height, filter_width, in_channels, channel_multiplier]`.
strides: 1-D of size 4. The stride of the sliding window for each
dimension of `input`.
padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
See the "returns" section of `tf.nn.convolution` for details.
padding: Controls how to pad the image before applying the convolution. Can
be the string `"SAME"` or `"VALID"` indicating the type of padding
algorithm to use, or a list indicating the explicit paddings at the start
and end of each dimension. When explicit padding is used and data_format
is `"NHWC"`, this should be in the form `[[0, 0], [pad_top, pad_bottom],
[pad_left, pad_right], [0, 0]]`. When explicit padding used and
data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
[pad_top, pad_bottom], [pad_left, pad_right]]`.
data_format: The data format for input. Either "NHWC" (default) or "NCHW".
dilations: 1-D of size 2. The dilation rate in which we sample input values
across the `height` and `width` dimensions in atrous convolution. If it is

View File

@ -557,6 +557,8 @@ class _WithSpaceToBatch(object):
self.call = build_op(num_spatial_dims, padding)
return
padding, explicit_paddings = convert_padding(padding)
# We have two padding contributions. The first is used for converting "SAME"
# to "VALID". The second is required so that the height and width of the
# zero-padded value tensor are multiples of rate.
@ -577,6 +579,14 @@ class _WithSpaceToBatch(object):
self.base_paddings = None
elif padding == "VALID":
self.base_paddings = np.zeros([num_spatial_dims, 2], np.int32)
elif padding == "EXPLICIT":
base_paddings = (np.array(explicit_paddings)
.reshape([num_spatial_dims + 2, 2]))
# Remove batch and channel dimensions
if data_format is not None and data_format.startswith("NC"):
self.base_paddings = base_paddings[2:]
else:
self.base_paddings = base_paddings[1:-1]
else:
raise ValueError("Invalid padding method %r" % padding)
@ -1528,7 +1538,7 @@ def atrous_conv2d(value, filters, rate, padding, name=None):
name=name)
def _convert_padding(padding):
def convert_padding(padding):
"""Converts Python padding to C++ padding for ops which take EXPLICIT padding.
Args:
@ -1857,7 +1867,7 @@ def conv2d_v2(input, # pylint: disable=redefined-builtin
... [[1], [3], [2], [2], [3]],
... [[1], [1], [3], [3], [0]],
... [[2], [2], [0], [1], [1]],
... [[0], [0], [3], [1], [2]], ]])
... [[0], [0], [3], [1], [2]], ]])
>>> kernel_in = np.array([
... [ [[2, 0.1]], [[3, 0.2]] ],
... [ [[0, 0.3]],[[1, 0.4]] ], ])
@ -1996,7 +2006,7 @@ def conv2d( # pylint: disable=redefined-builtin,dangerous-default-value
"""
filter = deprecation.deprecated_argument_lookup(
"filters", filters, "filter", filter)
padding, explicit_paddings = _convert_padding(padding)
padding, explicit_paddings = convert_padding(padding)
if data_format is None:
data_format = "NHWC"
channel_index = 1 if data_format.startswith("NC") else 3
@ -2068,7 +2078,7 @@ def conv2d_backprop_filter( # pylint: disable=redefined-builtin,dangerous-defau
Returns:
A `Tensor`. Has the same type as `input`.
"""
padding, explicit_paddings = _convert_padding(padding)
padding, explicit_paddings = convert_padding(padding)
return gen_nn_ops.conv2d_backprop_filter(
input, filter_sizes, out_backprop, strides, padding, use_cudnn_on_gpu,
explicit_paddings, data_format, dilations, name)
@ -2132,7 +2142,7 @@ def conv2d_backprop_input( # pylint: disable=redefined-builtin,dangerous-defaul
"""
filter = deprecation.deprecated_argument_lookup(
"filters", filters, "filter", filter)
padding, explicit_paddings = _convert_padding(padding)
padding, explicit_paddings = convert_padding(padding)
return gen_nn_ops.conv2d_backprop_input(
input_sizes, filter, out_backprop, strides, padding, use_cudnn_on_gpu,
explicit_paddings, data_format, dilations, name)
@ -2449,6 +2459,219 @@ def atrous_conv2d_transpose(value,
input=value, crops=batch_to_space_crop, block_size=rate)
@tf_export(v1=["nn.depthwise_conv2d_native"])
@deprecation.deprecated_endpoints("nn.depthwise_conv2d_native")
def depthwise_conv2d_native( # pylint: disable=redefined-builtin,dangerous-default-value
input,
filter,
strides,
padding,
data_format="NHWC",
dilations=[1, 1, 1, 1],
name=None):
r"""Computes a 2-D depthwise convolution.
Given an input tensor of shape `[batch, in_height, in_width, in_channels]`
and a filter / kernel tensor of shape
`[filter_height, filter_width, in_channels, channel_multiplier]`, containing
`in_channels` convolutional filters of depth 1, `depthwise_conv2d` applies
a different filter to each input channel (expanding from 1 channel to
`channel_multiplier` channels for each), then concatenates the results
together. Thus, the output has `in_channels * channel_multiplier` channels.
```
for k in 0..in_channels-1
for q in 0..channel_multiplier-1
output[b, i, j, k * channel_multiplier + q] =
sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, k] *
filter[di, dj, k, q]
```
Must have `strides[0] = strides[3] = 1`. For the most common case of the same
horizontal and vertices strides, `strides = [1, stride, stride, 1]`.
Args:
input: A `Tensor`. Must be one of the following types: `half`, `bfloat16`,
`float32`, `float64`.
filter: A `Tensor`. Must have the same type as `input`.
strides: A list of `ints`. 1-D of length 4. The stride of the sliding
window for each dimension of `input`.
padding: Controls how to pad the image before applying the convolution. Can
be the string `"SAME"` or `"VALID"` indicating the type of padding
algorithm to use, or a list indicating the explicit paddings at the start
and end of each dimension. When explicit padding is used and data_format
is `"NHWC"`, this should be in the form `[[0, 0], [pad_top, pad_bottom],
[pad_left, pad_right], [0, 0]]`. When explicit padding used and
data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
[pad_top, pad_bottom], [pad_left, pad_right]]`.
data_format: An optional `string` from: `"NHWC", "NCHW"`. Defaults to
`"NHWC"`. Specify the data format of the input and output data. With the
default format "NHWC", the data is stored in the order of: [batch, height,
width, channels].
Alternatively, the format could be "NCHW", the data storage order of:
[batch, channels, height, width].
dilations: An optional list of `ints`. Defaults to `[1, 1, 1, 1]`. 1-D
tensor of length 4. The dilation factor for each dimension of `input`. If
set to k > 1, there will be k-1 skipped cells between each filter element
on that dimension. The dimension order is determined by the value of
`data_format`, see above for details. Dilations in the batch and depth
dimensions must be 1.
name: A name for the operation (optional).
Returns:
A `Tensor`. Has the same type as `input`.
"""
padding, explicit_paddings = convert_padding(padding)
return gen_nn_ops.depthwise_conv2d_native(
input,
filter,
strides,
padding,
explicit_paddings=explicit_paddings,
data_format=data_format,
dilations=dilations,
name=name)
@tf_export(
"nn.depthwise_conv2d_backprop_input",
v1=[
"nn.depthwise_conv2d_native_backprop_input",
"nn.depthwise_conv2d_backprop_input"
])
@deprecation.deprecated_endpoints("nn.depthwise_conv2d_native_backprop_input")
def depthwise_conv2d_native_backprop_input( # pylint: disable=redefined-builtin,dangerous-default-value
input_sizes,
filter,
out_backprop,
strides,
padding,
data_format="NHWC",
dilations=[1, 1, 1, 1],
name=None):
r"""Computes the gradients of depthwise convolution with respect to the input.
Args:
input_sizes: A `Tensor` of type `int32`. An integer vector representing the
shape of `input`, based on `data_format`. For example, if `data_format`
is 'NHWC' then `input` is a 4-D `[batch, height, width, channels]` tensor.
filter: A `Tensor`. Must be one of the following types: `half`, `bfloat16`,
`float32`, `float64`. 4-D with shape `[filter_height, filter_width,
in_channels, depthwise_multiplier]`.
out_backprop: A `Tensor`. Must have the same type as `filter`. 4-D with
shape based on `data_format`. For example, if `data_format` is 'NHWC'
then out_backprop shape is `[batch, out_height, out_width, out_channels]`.
Gradients w.r.t. the output of the convolution.
strides: A list of `ints`. The stride of the sliding window for each
dimension of the input of the convolution.
padding: Controls how to pad the image before applying the convolution. Can
be the string `"SAME"` or `"VALID"` indicating the type of padding
algorithm to use, or a list indicating the explicit paddings at the start
and end of each dimension. When explicit padding is used and data_format
is `"NHWC"`, this should be in the form `[[0, 0], [pad_top, pad_bottom],
[pad_left, pad_right], [0, 0]]`. When explicit padding used and
data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
[pad_top, pad_bottom], [pad_left, pad_right]]`.
data_format: An optional `string` from: `"NHWC", "NCHW"`. Defaults to
`"NHWC"`. Specify the data format of the input and output data. With the
default format "NHWC", the data is stored in the order of: [batch, height,
width, channels].
Alternatively, the format could be "NCHW", the data storage order of:
[batch, channels, height, width].
dilations: An optional list of `ints`. Defaults to `[1, 1, 1, 1]`. 1-D
tensor of length 4. The dilation factor for each dimension of `input`. If
set to k > 1, there will be k-1 skipped cells between each filter element
on that dimension. The dimension order is determined by the value of
`data_format`, see above for details. Dilations in the batch and depth
dimensions must be 1.
name: A name for the operation (optional).
Returns:
A `Tensor`. Has the same type as `filter`.
"""
padding, explicit_paddings = convert_padding(padding)
return gen_nn_ops.depthwise_conv2d_native_backprop_input(
input_sizes,
filter,
out_backprop,
strides,
padding,
explicit_paddings=explicit_paddings,
data_format=data_format,
dilations=dilations,
name=name)
@tf_export(
"nn.depthwise_conv2d_backprop_filter",
v1=[
"nn.depthwise_conv2d_native_backprop_filter",
"nn.depthwise_conv2d_backprop_filter"
])
@deprecation.deprecated_endpoints("nn.depthwise_conv2d_native_backprop_filter")
def depthwise_conv2d_native_backprop_filter( # pylint: disable=redefined-builtin,dangerous-default-value
input,
filter_sizes,
out_backprop,
strides,
padding,
data_format="NHWC",
dilations=[1, 1, 1, 1],
name=None):
r"""Computes the gradients of depthwise convolution with respect to the filter.
Args:
input: A `Tensor`. Must be one of the following types: `half`, `bfloat16`,
`float32`, `float64`. 4-D with shape based on `data_format`. For example,
if `data_format` is 'NHWC' then `input` is a 4-D `[batch, in_height,
in_width, in_channels]` tensor.
filter_sizes: A `Tensor` of type `int32`. An integer vector representing the
tensor shape of `filter`, where `filter` is a 4-D `[filter_height,
filter_width, in_channels, depthwise_multiplier]` tensor.
out_backprop: A `Tensor`. Must have the same type as `input`. 4-D with shape
based on `data_format`. For example, if `data_format` is 'NHWC' then
out_backprop shape is `[batch, out_height, out_width, out_channels]`.
Gradients w.r.t. the output of the convolution.
strides: A list of `ints`. The stride of the sliding window for each
dimension of the input of the convolution.
padding: Controls how to pad the image before applying the convolution. Can
be the string `"SAME"` or `"VALID"` indicating the type of padding
algorithm to use, or a list indicating the explicit paddings at the start
and end of each dimension. When explicit padding is used and data_format
is `"NHWC"`, this should be in the form `[[0, 0], [pad_top, pad_bottom],
[pad_left, pad_right], [0, 0]]`. When explicit padding used and
data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
[pad_top, pad_bottom], [pad_left, pad_right]]`.
data_format: An optional `string` from: `"NHWC", "NCHW"`. Defaults to
`"NHWC"`. Specify the data format of the input and output data. With the
default format "NHWC", the data is stored in the order of: [batch, height,
width, channels].
Alternatively, the format could be "NCHW", the data storage order of:
[batch, channels, height, width].
dilations: An optional list of `ints`. Defaults to `[1, 1, 1, 1]`. 1-D
tensor of length 4. The dilation factor for each dimension of `input`. If
set to k > 1, there will be k-1 skipped cells between each filter element
on that dimension. The dimension order is determined by the value of
`data_format`, see above for details. Dilations in the batch and depth
dimensions must be 1.
name: A name for the operation (optional).
Returns:
A `Tensor`. Has the same type as `input`.
"""
padding, explicit_paddings = convert_padding(padding)
return gen_nn_ops.depthwise_conv2d_native_backprop_filter(
input,
filter_sizes,
out_backprop,
strides,
padding,
explicit_paddings=explicit_paddings,
data_format=data_format,
dilations=dilations,
name=name)
@tf_export("nn.conv3d", v1=[])
def conv3d_v2(input, # pylint: disable=redefined-builtin,missing-docstring
filters,

View File

@ -1098,15 +1098,15 @@ tf_module {
}
member_method {
name: "DepthwiseConv2dNative"
argspec: "args=[\'input\', \'filter\', \'strides\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'[1, 1, 1, 1]\', \'None\'], "
argspec: "args=[\'input\', \'filter\', \'strides\', \'padding\', \'explicit_paddings\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'NHWC\', \'[1, 1, 1, 1]\', \'None\'], "
}
member_method {
name: "DepthwiseConv2dNativeBackpropFilter"
argspec: "args=[\'input\', \'filter_sizes\', \'out_backprop\', \'strides\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'[1, 1, 1, 1]\', \'None\'], "
argspec: "args=[\'input\', \'filter_sizes\', \'out_backprop\', \'strides\', \'padding\', \'explicit_paddings\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'NHWC\', \'[1, 1, 1, 1]\', \'None\'], "
}
member_method {
name: "DepthwiseConv2dNativeBackpropInput"
argspec: "args=[\'input_sizes\', \'filter\', \'out_backprop\', \'strides\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'[1, 1, 1, 1]\', \'None\'], "
argspec: "args=[\'input_sizes\', \'filter\', \'out_backprop\', \'strides\', \'padding\', \'explicit_paddings\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'NHWC\', \'[1, 1, 1, 1]\', \'None\'], "
}
member_method {
name: "Dequantize"

View File

@ -1098,15 +1098,15 @@ tf_module {
}
member_method {
name: "DepthwiseConv2dNative"
argspec: "args=[\'input\', \'filter\', \'strides\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'[1, 1, 1, 1]\', \'None\'], "
argspec: "args=[\'input\', \'filter\', \'strides\', \'padding\', \'explicit_paddings\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'NHWC\', \'[1, 1, 1, 1]\', \'None\'], "
}
member_method {
name: "DepthwiseConv2dNativeBackpropFilter"
argspec: "args=[\'input\', \'filter_sizes\', \'out_backprop\', \'strides\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'[1, 1, 1, 1]\', \'None\'], "
argspec: "args=[\'input\', \'filter_sizes\', \'out_backprop\', \'strides\', \'padding\', \'explicit_paddings\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'NHWC\', \'[1, 1, 1, 1]\', \'None\'], "
}
member_method {
name: "DepthwiseConv2dNativeBackpropInput"
argspec: "args=[\'input_sizes\', \'filter\', \'out_backprop\', \'strides\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'[1, 1, 1, 1]\', \'None\'], "
argspec: "args=[\'input_sizes\', \'filter\', \'out_backprop\', \'strides\', \'padding\', \'explicit_paddings\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'NHWC\', \'[1, 1, 1, 1]\', \'None\'], "
}
member_method {
name: "Dequantize"