Update Conv2DShape function to handle filters that have data NCHW_VECT_C layout.

PiperOrigin-RevId: 163746769
This commit is contained in:
A. Unique TensorFlower 2017-07-31 14:17:21 -07:00 committed by TensorFlower Gardener
parent efb7fb8e58
commit 104f349e9e
2 changed files with 35 additions and 20 deletions

View File

@ -297,17 +297,18 @@ Status Conv2DShape(shape_inference::InferenceContext* c) {
const int rank = GetTensorDimsFromSpatialDims(2, data_format);
ShapeHandle input_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));
// The filter of a 2D convolution is always 4D.
// The filter rank should match the input (4 for NCHW, 5 for NCHW_VECT_C).
ShapeHandle filter_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape));
std::vector<int32> strides;
TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
if (strides.size() != rank) {
// strides.size() should be 4 (NCHW) even if the input is 5 (NCHW_VECT_C).
if (strides.size() != 4) {
return errors::InvalidArgument("Conv2D on data format ", data_format_str,
" requires the stride attribute to contain ",
rank, " values, but got: ", strides.size());
" requires the stride attribute to contain"
" 4 values, but got: ",
strides.size());
}
int32 stride_rows, stride_cols;
@ -326,15 +327,29 @@ Status Conv2DShape(shape_inference::InferenceContext* c) {
&batch_size_dim, &input_spatial_dims,
&input_depth_dim, c));
DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0);
DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1);
DimensionHandle output_depth_dim = c->Dim(filter_shape, 3);
DimensionHandle output_depth_dim, filter_rows_dim, filter_cols_dim,
filter_input_depth_dim;
// If the input format is NCHW_VECT_C, the filter format is assumed to be
// OIHW_VECT_I, otherwise it is assumed to be HWIO.
if (data_format == FORMAT_NCHW_VECT_C) {
output_depth_dim = c->Dim(filter_shape, 0);
TF_RETURN_IF_ERROR(c->Multiply(c->Dim(filter_shape, 1),
c->Dim(filter_shape, 4),
&filter_input_depth_dim));
filter_rows_dim = c->Dim(filter_shape, 2);
filter_cols_dim = c->Dim(filter_shape, 3);
} else {
filter_rows_dim = c->Dim(filter_shape, 0);
filter_cols_dim = c->Dim(filter_shape, 1);
filter_input_depth_dim = c->Dim(filter_shape, 2);
output_depth_dim = c->Dim(filter_shape, 3);
}
// Check that the input tensor and the filter tensor agree on the input
// channel count.
DimensionHandle unused;
TF_RETURN_IF_ERROR(
c->Merge(input_depth_dim, c->Dim(filter_shape, 2), &unused));
c->Merge(input_depth_dim, filter_input_depth_dim, &unused));
Padding padding;
TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

View File

@ -481,24 +481,24 @@ TEST(CommonShapeFnsTest, Conv2DShapeTest) {
// Tests for NCHW_VECT_C
// 1x1 filter
set_op({{1, 1, 1, 1, 1}}, "VALID", "NCHW_VECT_C");
INFER_OK(op, "[1,1,2,2,4];[1,1,4,4]", "[d0_0,1,2,2,4]");
set_op({{1, 1, 1, 1}}, "VALID", "NCHW_VECT_C");
INFER_OK(op, "[1,1,2,2,4];[4,1,1,1,4]", "[d0_0,1,2,2,4]");
// 2x2 filter
set_op({{1, 1, 1, 1, 1}}, "VALID", "NCHW_VECT_C");
INFER_OK(op, "[1,1,2,2,4];[2,2,4,4]", "[d0_0,1,1,1,4]");
set_op({{1, 1, 1, 1}}, "VALID", "NCHW_VECT_C");
INFER_OK(op, "[1,1,2,2,4];[4,1,2,2,4]", "[d0_0,1,1,1,4]");
// 3x3 input, 1x1 filter, 2x2 stride
set_op({{1, 1, 2, 2, 1}}, "VALID", "NCHW_VECT_C");
INFER_OK(op, "[1,1,3,3,4];[1,1,4,8]", "[d0_0,2,2,2,4]");
set_op({{1, 1, 2, 2}}, "VALID", "NCHW_VECT_C");
INFER_OK(op, "[1,1,3,3,4];[8,1,1,1,4]", "[d0_0,2,2,2,4]");
// 3x3 input, 1x1 filter, 2x1 stride
set_op({{1, 1, 2, 1, 1}}, "VALID", "NCHW_VECT_C");
INFER_OK(op, "[1,1,3,3,4];[1,1,4,4]", "[d0_0,1,2,3,4]");
set_op({{1, 1, 2, 1}}, "VALID", "NCHW_VECT_C");
INFER_OK(op, "[1,1,3,3,4];[4,1,1,1,4]", "[d0_0,1,2,3,4]");
// 4x4 input, 2x1 filter, 1x2 stride
set_op({{1, 1, 1, 2, 1}}, "VALID", "NCHW_VECT_C");
INFER_OK(op, "[1,1,4,4,4];[2,1,4,4]", "[d0_0,1,3,2,4]");
set_op({{1, 1, 1, 2}}, "VALID", "NCHW_VECT_C");
INFER_OK(op, "[1,1,4,4,4];[4,1,2,1,4]", "[d0_0,1,3,2,4]");
// Some tests for "SAME" padding