diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index e791ff9ff60..132c4d6dd68 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -2146,6 +2146,12 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group, "Stride must be 1 for batch and channel dimensions, at ", node_def.name()); } + // Channel dim must be static for DepthwiseConv2dNative since we use that + // value for num_groups at build time. + if (!params->use_implicit_batch && tensor->getDimensions().d[c_index] == -1) { + return errors::InvalidArgument("Channel dimension must be static, at ", + node_def.name()); + } const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); if (params->validation_only) return Status::OK(); @@ -2157,11 +2163,12 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group, } // Dimensions of transposed tensor. const auto tensor_dim = tensor->getDimensions(); + const int c_dim_size = tensor_dim.d[params->use_implicit_batch ? 0 : 1]; // group == 0 signifies that this is a depthwise convolution, so set // num_groups to size of input's channel dim. For a non-depthwise conv, // num_groups will be 1. - const int num_groups = (group == 0) ? tensor_dim.d[0] : group; + const int num_groups = (group == 0) ? c_dim_size : group; // For conv, TF weights are RSCK, and TRT expects KCRS. // For backprop, TF weights are RSKC, and TRT expects CKRS. diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index d4badd1cc03..57b2e13fad0 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -4037,15 +4037,16 @@ TEST_F(OpConverterTest, ConvertSlice) { } } -TEST_F(OpConverterTest, ConvertConv2D) { +TEST_P(OpConverterTest1, ConvertConv2D) { // Get nodedef for Conv2D layer. + DataType tf_type = tf_dtype; auto get_conv2d_nodedef = - [](std::vector strides = {1, 1, 1, 1}, string padding = "SAME", - string data_format = "NCHW", - std::vector dilations = {1, 1, 1, 1}) -> NodeDef { + [tf_type](std::vector strides = {1, 1, 1, 1}, + string padding = "SAME", string data_format = "NCHW", + std::vector dilations = {1, 1, 1, 1}) -> NodeDef { Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); - auto filter = ops::Placeholder(s.WithOpName("weights"), DT_FLOAT); + auto input = ops::Placeholder(s.WithOpName("input"), tf_type); + auto filter = ops::Placeholder(s.WithOpName("weights"), tf_type); ops::Conv2D::Attrs attrs = ops::Conv2D::Attrs().DataFormat(data_format).Dilations(dilations); auto conv2d = ops::Conv2D(s.WithOpName("my_conv2d"), input, filter, strides, @@ -4067,7 +4068,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { // Filter is tensor, should fail. Reset(); NodeDef node_def = get_conv2d_nodedef(); - AddTestTensor("input", {1, 2, 3}); + AddTestTensor("input", {3, 1, 2, 1}); AddTestTensor("weights", {3, 3, 1, 1}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, @@ -4077,7 +4078,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { // Filter is not 4D, should fail. Reset(); NodeDef node_def = get_conv2d_nodedef(); - AddTestTensor("input", {1, 2, 3}); + AddTestTensor("input", {1, 1, 2, 3}); AddTestWeights("weights", {3, 3, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, @@ -4088,7 +4089,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { Reset(); NodeDef node_def = get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NCHW", {1, 1, 1}); - AddTestTensor("input", {1, 2, 3}); + AddTestTensor("input", {1, 1, 2, 3}); AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, @@ -4099,7 +4100,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { Reset(); NodeDef node_def = get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NCHW", {1, 2, 1, 1}); - AddTestTensor("input", {1, 2, 3}); + AddTestTensor("input", {1, 1, 2, 3}); AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); RunValidationAndConversion(node_def, error::UNIMPLEMENTED, "Dilation rate must be 1 for batch and channel " @@ -4110,7 +4111,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { Reset(); NodeDef node_def = get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NHWC", {1, 1, 1, 2}); - AddTestTensor("input", {2, 3, 1}); + AddTestTensor("input", {1, 2, 3, 1}); AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); RunValidationAndConversion(node_def, error::UNIMPLEMENTED, "Dilation rate must be 1 for batch and channel " @@ -4121,7 +4122,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { Reset(); NodeDef node_def = get_conv2d_nodedef({1, 1, 1}, "SAME", "NCHW", {1, 1, 1, 1}); - AddTestTensor("input", {1, 2, 3}); + AddTestTensor("input", {1, 1, 2, 3}); AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, @@ -4132,12 +4133,23 @@ TEST_F(OpConverterTest, ConvertConv2D) { Reset(); NodeDef node_def = get_conv2d_nodedef({1, 2, 1, 1}, "SAME", "NCHW", {1, 1, 1, 1}); - AddTestTensor("input", {1, 2, 3}); + AddTestTensor("input", {1, 1, 2, 3}); AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, "Stride must be 1 for batch and channel dimensions, at my_conv2d"); } + if (trt_mode == TrtTestMode::kDynamicShape) { + Reset(); + NodeDef node_def = get_conv2d_nodedef(); + // Channel dim unknown, should fail. + AddTestTensorWithTFDims("input", {-1, -1, -1, -1}, + TfDataTypeToTrt(tf_type)); + AddTestWeights("weights", {1, 2, 1, 1}, {-1, 1}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Channel dimension must be static, at my_conv2d"); + } struct TestParams { std::vector input_dims; @@ -4155,7 +4167,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { // Ok. std::vector ok_params = { // Basic - TestParams{/*input_dims=*/{1, 2, 3}, + TestParams{/*input_dims=*/{1, 1, 2, 3}, /*input=*/{0, 1, 2, 3, 3, 4}, /*filter_dims=*/{1, 2, 1, 1}, /*filter=*/{-1, 1}, @@ -4163,10 +4175,10 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"VALID", /*data_format=*/"NCHW", /*dilations=*/{1, 1, 1, 1}, - /*expected_output_dims=*/{1, 2, 2}, + /*expected_output_dims=*/{1, 1, 2, 2}, /*expected_output=*/{1, 1, 0, 1}}, // SAME padding (Asymmetric) - TestParams{/*input_dims=*/{1, 2, 3}, + TestParams{/*input_dims=*/{1, 1, 2, 3}, /*input=*/{0, 1, 2, 3, 3, 4}, /*filter_dims=*/{1, 2, 1, 1}, /*filter=*/{-1, 1}, @@ -4174,10 +4186,10 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"SAME", /*data_format=*/"NCHW", /*dilations=*/{1, 1, 1, 1}, - /*expected_output_dims=*/{1, 2, 3}, + /*expected_output_dims=*/{1, 1, 2, 3}, /*expected_output=*/{1, 1, -2, 0, 1, -4}}, // SAME padding (Symmetric) - TestParams{/*input_dims=*/{1, 2, 3}, + TestParams{/*input_dims=*/{1, 1, 2, 3}, /*input=*/{0, 1, 2, 3, 3, 4}, /*filter_dims=*/{1, 3, 1, 1}, /*filter=*/{-1, 0, 1}, @@ -4185,10 +4197,10 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"SAME", /*data_format=*/"NCHW", /*dilations=*/{1, 1, 1, 1}, - /*expected_output_dims=*/{1, 2, 3}, + /*expected_output_dims=*/{1, 1, 2, 3}, /*expected_output=*/{1, 2, -1, 3, 1, -3}}, // NHWC - TestParams{/*input_dims=*/{2, 3, 1}, + TestParams{/*input_dims=*/{1, 2, 3, 1}, /*input=*/{0, 1, 2, 3, 3, 4}, /*filter_dims=*/{1, 2, 1, 1}, /*filter=*/{-1, 1}, @@ -4196,10 +4208,10 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"VALID", /*data_format=*/"NHWC", /*dilations=*/{1, 1, 1, 1}, - /*expected_output_dims=*/{2, 2, 1}, + /*expected_output_dims=*/{1, 2, 2, 1}, /*expected_output=*/{1, 1, 0, 1}}, // Dilated - TestParams{/*input_dims=*/{1, 2, 3}, + TestParams{/*input_dims=*/{1, 1, 2, 3}, /*input=*/{0, 1, 2, 3, 3, 4}, /*filter_dims=*/{1, 2, 1, 1}, /*filter=*/{-1, 1}, @@ -4207,10 +4219,10 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"VALID", /*data_format=*/"NCHW", /*dilations=*/{1, 1, 1, 2}, - /*expected_output_dims=*/{1, 2, 1}, + /*expected_output_dims=*/{1, 1, 2, 1}, /*expected_output=*/{2, 1}}, // Strided - TestParams{/*input_dims=*/{1, 2, 4}, + TestParams{/*input_dims=*/{1, 1, 2, 4}, /*input=*/{0, 1, 2, 2, 3, 4, 4, 7}, /*filter_dims=*/{1, 2, 1, 1}, /*filter=*/{-1, 1}, @@ -4218,7 +4230,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"VALID", /*data_format=*/"NCHW", /*dilations=*/{1, 1, 1, 1}, - /*expected_output_dims=*/{1, 2, 2}, + /*expected_output_dims=*/{1, 1, 2, 2}, /*expected_output=*/{1, 0, 1, 3}}, }; @@ -4227,23 +4239,22 @@ TEST_F(OpConverterTest, ConvertConv2D) { NodeDef node_def = get_conv2d_nodedef(ok_params[i].strides, ok_params[i].padding, ok_params[i].data_format, ok_params[i].dilations); - AddTestTensor("input", ok_params[i].input_dims); + std::vector partial_input_shape; + if (trt_mode == TrtTestMode::kDynamicShape) { + // The channel dim cannot have unknown size, fix that. + partial_input_shape.resize(ok_params[i].input_dims.size(), -1); + int channel_id = (ok_params[i].data_format == "NCHW") ? 1 : 3; + partial_input_shape[channel_id] = ok_params[i].input_dims[channel_id]; + } + + AddTestTensor("input", ok_params[i].input_dims, tf_dtype, + ok_params[i].input, partial_input_shape); AddTestWeights("weights", ok_params[i].filter_dims, ok_params[i].filter); - RunValidationAndConversion(node_def); - TRT_TensorOrWeights output; - TF_EXPECT_OK(GetTensorOrWeights("my_conv2d", &output)); - ASSERT_TRUE(output.is_tensor()); - ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, - output.tensor()->getDimensions()); - const DataVec input_data{{"input", AsTensor(ok_params[i].input)}}; - DataVec output_data{ - {"my_conv2d", - ConstructTensor(ok_params[i].expected_output.size())}}; - TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAreArray(ok_params[i].expected_output)); + TestOpConverter("my_conv2d", node_def, ok_params[i].expected_output_dims, + Status::OK(), Status::OK(), + ElementsAreArray(ok_params[i].expected_output)); } }