diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 38024fcf68b..9df5cbdec06 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -297,17 +297,18 @@ Status Conv2DShape(shape_inference::InferenceContext* c) { const int rank = GetTensorDimsFromSpatialDims(2, data_format); ShapeHandle input_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape)); - // The filter of a 2D convolution is always 4D. + // The filter rank should match the input (4 for NCHW, 5 for NCHW_VECT_C). ShapeHandle filter_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape)); - + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape)); std::vector strides; TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); - if (strides.size() != rank) { + // strides.size() should be 4 (NCHW) even if the input is 5 (NCHW_VECT_C). + if (strides.size() != 4) { return errors::InvalidArgument("Conv2D on data format ", data_format_str, - " requires the stride attribute to contain ", - rank, " values, but got: ", strides.size()); + " requires the stride attribute to contain" + " 4 values, but got: ", + strides.size()); } int32 stride_rows, stride_cols; @@ -326,15 +327,29 @@ Status Conv2DShape(shape_inference::InferenceContext* c) { &batch_size_dim, &input_spatial_dims, &input_depth_dim, c)); - DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0); - DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1); - DimensionHandle output_depth_dim = c->Dim(filter_shape, 3); + DimensionHandle output_depth_dim, filter_rows_dim, filter_cols_dim, + filter_input_depth_dim; + // If the input format is NCHW_VECT_C, the filter format is assumed to be + // OIHW_VECT_I, otherwise it is assumed to be HWIO. + if (data_format == FORMAT_NCHW_VECT_C) { + output_depth_dim = c->Dim(filter_shape, 0); + TF_RETURN_IF_ERROR(c->Multiply(c->Dim(filter_shape, 1), + c->Dim(filter_shape, 4), + &filter_input_depth_dim)); + filter_rows_dim = c->Dim(filter_shape, 2); + filter_cols_dim = c->Dim(filter_shape, 3); + } else { + filter_rows_dim = c->Dim(filter_shape, 0); + filter_cols_dim = c->Dim(filter_shape, 1); + filter_input_depth_dim = c->Dim(filter_shape, 2); + output_depth_dim = c->Dim(filter_shape, 3); + } // Check that the input tensor and the filter tensor agree on the input // channel count. DimensionHandle unused; TF_RETURN_IF_ERROR( - c->Merge(input_depth_dim, c->Dim(filter_shape, 2), &unused)); + c->Merge(input_depth_dim, filter_input_depth_dim, &unused)); Padding padding; TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc index 37e211ad683..416478f8542 100644 --- a/tensorflow/core/framework/common_shape_fns_test.cc +++ b/tensorflow/core/framework/common_shape_fns_test.cc @@ -481,24 +481,24 @@ TEST(CommonShapeFnsTest, Conv2DShapeTest) { // Tests for NCHW_VECT_C // 1x1 filter - set_op({{1, 1, 1, 1, 1}}, "VALID", "NCHW_VECT_C"); - INFER_OK(op, "[1,1,2,2,4];[1,1,4,4]", "[d0_0,1,2,2,4]"); + set_op({{1, 1, 1, 1}}, "VALID", "NCHW_VECT_C"); + INFER_OK(op, "[1,1,2,2,4];[4,1,1,1,4]", "[d0_0,1,2,2,4]"); // 2x2 filter - set_op({{1, 1, 1, 1, 1}}, "VALID", "NCHW_VECT_C"); - INFER_OK(op, "[1,1,2,2,4];[2,2,4,4]", "[d0_0,1,1,1,4]"); + set_op({{1, 1, 1, 1}}, "VALID", "NCHW_VECT_C"); + INFER_OK(op, "[1,1,2,2,4];[4,1,2,2,4]", "[d0_0,1,1,1,4]"); // 3x3 input, 1x1 filter, 2x2 stride - set_op({{1, 1, 2, 2, 1}}, "VALID", "NCHW_VECT_C"); - INFER_OK(op, "[1,1,3,3,4];[1,1,4,8]", "[d0_0,2,2,2,4]"); + set_op({{1, 1, 2, 2}}, "VALID", "NCHW_VECT_C"); + INFER_OK(op, "[1,1,3,3,4];[8,1,1,1,4]", "[d0_0,2,2,2,4]"); // 3x3 input, 1x1 filter, 2x1 stride - set_op({{1, 1, 2, 1, 1}}, "VALID", "NCHW_VECT_C"); - INFER_OK(op, "[1,1,3,3,4];[1,1,4,4]", "[d0_0,1,2,3,4]"); + set_op({{1, 1, 2, 1}}, "VALID", "NCHW_VECT_C"); + INFER_OK(op, "[1,1,3,3,4];[4,1,1,1,4]", "[d0_0,1,2,3,4]"); // 4x4 input, 2x1 filter, 1x2 stride - set_op({{1, 1, 1, 2, 1}}, "VALID", "NCHW_VECT_C"); - INFER_OK(op, "[1,1,4,4,4];[2,1,4,4]", "[d0_0,1,3,2,4]"); + set_op({{1, 1, 1, 2}}, "VALID", "NCHW_VECT_C"); + INFER_OK(op, "[1,1,4,4,4];[4,1,2,1,4]", "[d0_0,1,3,2,4]"); // Some tests for "SAME" padding