Update Conv2DShape function to handle filters that have data NCHW_VECT_C layout.
PiperOrigin-RevId: 163746769
This commit is contained in:
parent
efb7fb8e58
commit
104f349e9e
tensorflow/core/framework
@ -297,17 +297,18 @@ Status Conv2DShape(shape_inference::InferenceContext* c) {
|
||||
const int rank = GetTensorDimsFromSpatialDims(2, data_format);
|
||||
ShapeHandle input_shape;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));
|
||||
// The filter of a 2D convolution is always 4D.
|
||||
// The filter rank should match the input (4 for NCHW, 5 for NCHW_VECT_C).
|
||||
ShapeHandle filter_shape;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape));
|
||||
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape));
|
||||
std::vector<int32> strides;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
|
||||
|
||||
if (strides.size() != rank) {
|
||||
// strides.size() should be 4 (NCHW) even if the input is 5 (NCHW_VECT_C).
|
||||
if (strides.size() != 4) {
|
||||
return errors::InvalidArgument("Conv2D on data format ", data_format_str,
|
||||
" requires the stride attribute to contain ",
|
||||
rank, " values, but got: ", strides.size());
|
||||
" requires the stride attribute to contain"
|
||||
" 4 values, but got: ",
|
||||
strides.size());
|
||||
}
|
||||
|
||||
int32 stride_rows, stride_cols;
|
||||
@ -326,15 +327,29 @@ Status Conv2DShape(shape_inference::InferenceContext* c) {
|
||||
&batch_size_dim, &input_spatial_dims,
|
||||
&input_depth_dim, c));
|
||||
|
||||
DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0);
|
||||
DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1);
|
||||
DimensionHandle output_depth_dim = c->Dim(filter_shape, 3);
|
||||
DimensionHandle output_depth_dim, filter_rows_dim, filter_cols_dim,
|
||||
filter_input_depth_dim;
|
||||
// If the input format is NCHW_VECT_C, the filter format is assumed to be
|
||||
// OIHW_VECT_I, otherwise it is assumed to be HWIO.
|
||||
if (data_format == FORMAT_NCHW_VECT_C) {
|
||||
output_depth_dim = c->Dim(filter_shape, 0);
|
||||
TF_RETURN_IF_ERROR(c->Multiply(c->Dim(filter_shape, 1),
|
||||
c->Dim(filter_shape, 4),
|
||||
&filter_input_depth_dim));
|
||||
filter_rows_dim = c->Dim(filter_shape, 2);
|
||||
filter_cols_dim = c->Dim(filter_shape, 3);
|
||||
} else {
|
||||
filter_rows_dim = c->Dim(filter_shape, 0);
|
||||
filter_cols_dim = c->Dim(filter_shape, 1);
|
||||
filter_input_depth_dim = c->Dim(filter_shape, 2);
|
||||
output_depth_dim = c->Dim(filter_shape, 3);
|
||||
}
|
||||
|
||||
// Check that the input tensor and the filter tensor agree on the input
|
||||
// channel count.
|
||||
DimensionHandle unused;
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->Merge(input_depth_dim, c->Dim(filter_shape, 2), &unused));
|
||||
c->Merge(input_depth_dim, filter_input_depth_dim, &unused));
|
||||
|
||||
Padding padding;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
|
||||
|
@ -481,24 +481,24 @@ TEST(CommonShapeFnsTest, Conv2DShapeTest) {
|
||||
|
||||
// Tests for NCHW_VECT_C
|
||||
// 1x1 filter
|
||||
set_op({{1, 1, 1, 1, 1}}, "VALID", "NCHW_VECT_C");
|
||||
INFER_OK(op, "[1,1,2,2,4];[1,1,4,4]", "[d0_0,1,2,2,4]");
|
||||
set_op({{1, 1, 1, 1}}, "VALID", "NCHW_VECT_C");
|
||||
INFER_OK(op, "[1,1,2,2,4];[4,1,1,1,4]", "[d0_0,1,2,2,4]");
|
||||
|
||||
// 2x2 filter
|
||||
set_op({{1, 1, 1, 1, 1}}, "VALID", "NCHW_VECT_C");
|
||||
INFER_OK(op, "[1,1,2,2,4];[2,2,4,4]", "[d0_0,1,1,1,4]");
|
||||
set_op({{1, 1, 1, 1}}, "VALID", "NCHW_VECT_C");
|
||||
INFER_OK(op, "[1,1,2,2,4];[4,1,2,2,4]", "[d0_0,1,1,1,4]");
|
||||
|
||||
// 3x3 input, 1x1 filter, 2x2 stride
|
||||
set_op({{1, 1, 2, 2, 1}}, "VALID", "NCHW_VECT_C");
|
||||
INFER_OK(op, "[1,1,3,3,4];[1,1,4,8]", "[d0_0,2,2,2,4]");
|
||||
set_op({{1, 1, 2, 2}}, "VALID", "NCHW_VECT_C");
|
||||
INFER_OK(op, "[1,1,3,3,4];[8,1,1,1,4]", "[d0_0,2,2,2,4]");
|
||||
|
||||
// 3x3 input, 1x1 filter, 2x1 stride
|
||||
set_op({{1, 1, 2, 1, 1}}, "VALID", "NCHW_VECT_C");
|
||||
INFER_OK(op, "[1,1,3,3,4];[1,1,4,4]", "[d0_0,1,2,3,4]");
|
||||
set_op({{1, 1, 2, 1}}, "VALID", "NCHW_VECT_C");
|
||||
INFER_OK(op, "[1,1,3,3,4];[4,1,1,1,4]", "[d0_0,1,2,3,4]");
|
||||
|
||||
// 4x4 input, 2x1 filter, 1x2 stride
|
||||
set_op({{1, 1, 1, 2, 1}}, "VALID", "NCHW_VECT_C");
|
||||
INFER_OK(op, "[1,1,4,4,4];[2,1,4,4]", "[d0_0,1,3,2,4]");
|
||||
set_op({{1, 1, 1, 2}}, "VALID", "NCHW_VECT_C");
|
||||
INFER_OK(op, "[1,1,4,4,4];[4,1,2,1,4]", "[d0_0,1,3,2,4]");
|
||||
|
||||
// Some tests for "SAME" padding
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user