From 104f349e9eda70a21533294d2fb09f0bf0828834 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Mon, 31 Jul 2017 14:17:21 -0700
Subject: [PATCH] Update Conv2DShape function to handle filters that have data
 NCHW_VECT_C layout.

PiperOrigin-RevId: 163746769
---
 tensorflow/core/framework/common_shape_fns.cc | 35 +++++++++++++------
 .../core/framework/common_shape_fns_test.cc   | 20 +++++------
 2 files changed, 35 insertions(+), 20 deletions(-)

diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 38024fcf68b..9df5cbdec06 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -297,17 +297,18 @@ Status Conv2DShape(shape_inference::InferenceContext* c) {
   const int rank = GetTensorDimsFromSpatialDims(2, data_format);
   ShapeHandle input_shape;
   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));
-  // The filter of a 2D convolution is always 4D.
+  // The filter rank should match the input (4 for NCHW, 5 for NCHW_VECT_C).
   ShapeHandle filter_shape;
-  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape));
-
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape));
   std::vector<int32> strides;
   TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
 
-  if (strides.size() != rank) {
+  // strides.size() should be 4 (NCHW) even if the input is 5 (NCHW_VECT_C).
+  if (strides.size() != 4) {
     return errors::InvalidArgument("Conv2D on data format ", data_format_str,
-                                   " requires the stride attribute to contain ",
-                                   rank, " values, but got: ", strides.size());
+                                   " requires the stride attribute to contain"
+                                   " 4 values, but got: ",
+                                   strides.size());
   }
 
   int32 stride_rows, stride_cols;
@@ -326,15 +327,29 @@ Status Conv2DShape(shape_inference::InferenceContext* c) {
                                          &batch_size_dim, &input_spatial_dims,
                                          &input_depth_dim, c));
 
-  DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0);
-  DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1);
-  DimensionHandle output_depth_dim = c->Dim(filter_shape, 3);
+  DimensionHandle output_depth_dim, filter_rows_dim, filter_cols_dim,
+      filter_input_depth_dim;
+  // If the input format is NCHW_VECT_C, the filter format is assumed to be
+  // OIHW_VECT_I, otherwise it is assumed to be HWIO.
+  if (data_format == FORMAT_NCHW_VECT_C) {
+    output_depth_dim = c->Dim(filter_shape, 0);
+    TF_RETURN_IF_ERROR(c->Multiply(c->Dim(filter_shape, 1),
+                                   c->Dim(filter_shape, 4),
+                                   &filter_input_depth_dim));
+    filter_rows_dim = c->Dim(filter_shape, 2);
+    filter_cols_dim = c->Dim(filter_shape, 3);
+  } else {
+    filter_rows_dim = c->Dim(filter_shape, 0);
+    filter_cols_dim = c->Dim(filter_shape, 1);
+    filter_input_depth_dim = c->Dim(filter_shape, 2);
+    output_depth_dim = c->Dim(filter_shape, 3);
+  }
 
   // Check that the input tensor and the filter tensor agree on the input
   // channel count.
   DimensionHandle unused;
   TF_RETURN_IF_ERROR(
-      c->Merge(input_depth_dim, c->Dim(filter_shape, 2), &unused));
+      c->Merge(input_depth_dim, filter_input_depth_dim, &unused));
 
   Padding padding;
   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc
index 37e211ad683..416478f8542 100644
--- a/tensorflow/core/framework/common_shape_fns_test.cc
+++ b/tensorflow/core/framework/common_shape_fns_test.cc
@@ -481,24 +481,24 @@ TEST(CommonShapeFnsTest, Conv2DShapeTest) {
 
   // Tests for NCHW_VECT_C
   // 1x1 filter
-  set_op({{1, 1, 1, 1, 1}}, "VALID", "NCHW_VECT_C");
-  INFER_OK(op, "[1,1,2,2,4];[1,1,4,4]", "[d0_0,1,2,2,4]");
+  set_op({{1, 1, 1, 1}}, "VALID", "NCHW_VECT_C");
+  INFER_OK(op, "[1,1,2,2,4];[4,1,1,1,4]", "[d0_0,1,2,2,4]");
 
   // 2x2 filter
-  set_op({{1, 1, 1, 1, 1}}, "VALID", "NCHW_VECT_C");
-  INFER_OK(op, "[1,1,2,2,4];[2,2,4,4]", "[d0_0,1,1,1,4]");
+  set_op({{1, 1, 1, 1}}, "VALID", "NCHW_VECT_C");
+  INFER_OK(op, "[1,1,2,2,4];[4,1,2,2,4]", "[d0_0,1,1,1,4]");
 
   // 3x3 input, 1x1 filter, 2x2 stride
-  set_op({{1, 1, 2, 2, 1}}, "VALID", "NCHW_VECT_C");
-  INFER_OK(op, "[1,1,3,3,4];[1,1,4,8]", "[d0_0,2,2,2,4]");
+  set_op({{1, 1, 2, 2}}, "VALID", "NCHW_VECT_C");
+  INFER_OK(op, "[1,1,3,3,4];[8,1,1,1,4]", "[d0_0,2,2,2,4]");
 
   // 3x3 input, 1x1 filter, 2x1 stride
-  set_op({{1, 1, 2, 1, 1}}, "VALID", "NCHW_VECT_C");
-  INFER_OK(op, "[1,1,3,3,4];[1,1,4,4]", "[d0_0,1,2,3,4]");
+  set_op({{1, 1, 2, 1}}, "VALID", "NCHW_VECT_C");
+  INFER_OK(op, "[1,1,3,3,4];[4,1,1,1,4]", "[d0_0,1,2,3,4]");
 
   // 4x4 input, 2x1 filter, 1x2 stride
-  set_op({{1, 1, 1, 2, 1}}, "VALID", "NCHW_VECT_C");
-  INFER_OK(op, "[1,1,4,4,4];[2,1,4,4]", "[d0_0,1,3,2,4]");
+  set_op({{1, 1, 1, 2}}, "VALID", "NCHW_VECT_C");
+  INFER_OK(op, "[1,1,4,4,4];[4,1,2,1,4]", "[d0_0,1,3,2,4]");
 
   // Some tests for "SAME" padding