Merge pull request #39204 from tfeher:trt_conv2D_dynamic_shape

PiperOrigin-RevId: 313205518
Change-Id: I06df91df11f6009740bbe466bc22afbeb29e9981
This commit is contained in:
TensorFlower Gardener 2020-05-26 09:35:38 -07:00
commit 00664cef68
2 changed files with 58 additions and 40 deletions

View File

@ -2146,6 +2146,12 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
"Stride must be 1 for batch and channel dimensions, at ",
node_def.name());
}
// Channel dim must be static for DepthwiseConv2dNative since we use that
// value for num_groups at build time.
if (!params->use_implicit_batch && tensor->getDimensions().d[c_index] == -1) {
return errors::InvalidArgument("Channel dimension must be static, at ",
node_def.name());
}
const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
if (params->validation_only) return Status::OK();
@ -2157,11 +2163,12 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
}
// Dimensions of transposed tensor.
const auto tensor_dim = tensor->getDimensions();
const int c_dim_size = tensor_dim.d[params->use_implicit_batch ? 0 : 1];
// group == 0 signifies that this is a depthwise convolution, so set
// num_groups to size of input's channel dim. For a non-depthwise conv,
// num_groups will be 1.
const int num_groups = (group == 0) ? tensor_dim.d[0] : group;
const int num_groups = (group == 0) ? c_dim_size : group;
// For conv, TF weights are RSCK, and TRT expects KCRS.
// For backprop, TF weights are RSKC, and TRT expects CKRS.

View File

@ -4037,15 +4037,16 @@ TEST_F(OpConverterTest, ConvertSlice) {
}
}
TEST_F(OpConverterTest, ConvertConv2D) {
TEST_P(OpConverterTest1, ConvertConv2D) {
// Get nodedef for Conv2D layer.
DataType tf_type = tf_dtype;
auto get_conv2d_nodedef =
[](std::vector<int> strides = {1, 1, 1, 1}, string padding = "SAME",
string data_format = "NCHW",
std::vector<int> dilations = {1, 1, 1, 1}) -> NodeDef {
[tf_type](std::vector<int> strides = {1, 1, 1, 1},
string padding = "SAME", string data_format = "NCHW",
std::vector<int> dilations = {1, 1, 1, 1}) -> NodeDef {
Scope s = Scope::NewRootScope();
auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
auto filter = ops::Placeholder(s.WithOpName("weights"), DT_FLOAT);
auto input = ops::Placeholder(s.WithOpName("input"), tf_type);
auto filter = ops::Placeholder(s.WithOpName("weights"), tf_type);
ops::Conv2D::Attrs attrs =
ops::Conv2D::Attrs().DataFormat(data_format).Dilations(dilations);
auto conv2d = ops::Conv2D(s.WithOpName("my_conv2d"), input, filter, strides,
@ -4067,7 +4068,7 @@ TEST_F(OpConverterTest, ConvertConv2D) {
// Filter is tensor, should fail.
Reset();
NodeDef node_def = get_conv2d_nodedef();
AddTestTensor("input", {1, 2, 3});
AddTestTensor("input", {3, 1, 2, 1});
AddTestTensor("weights", {3, 3, 1, 1});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
@ -4077,7 +4078,7 @@ TEST_F(OpConverterTest, ConvertConv2D) {
// Filter is not 4D, should fail.
Reset();
NodeDef node_def = get_conv2d_nodedef();
AddTestTensor("input", {1, 2, 3});
AddTestTensor("input", {1, 1, 2, 3});
AddTestWeights<float>("weights", {3, 3, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
@ -4088,7 +4089,7 @@ TEST_F(OpConverterTest, ConvertConv2D) {
Reset();
NodeDef node_def =
get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NCHW", {1, 1, 1});
AddTestTensor("input", {1, 2, 3});
AddTestTensor("input", {1, 1, 2, 3});
AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
@ -4099,7 +4100,7 @@ TEST_F(OpConverterTest, ConvertConv2D) {
Reset();
NodeDef node_def =
get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NCHW", {1, 2, 1, 1});
AddTestTensor("input", {1, 2, 3});
AddTestTensor("input", {1, 1, 2, 3});
AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
"Dilation rate must be 1 for batch and channel "
@ -4110,7 +4111,7 @@ TEST_F(OpConverterTest, ConvertConv2D) {
Reset();
NodeDef node_def =
get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NHWC", {1, 1, 1, 2});
AddTestTensor("input", {2, 3, 1});
AddTestTensor("input", {1, 2, 3, 1});
AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
"Dilation rate must be 1 for batch and channel "
@ -4121,7 +4122,7 @@ TEST_F(OpConverterTest, ConvertConv2D) {
Reset();
NodeDef node_def =
get_conv2d_nodedef({1, 1, 1}, "SAME", "NCHW", {1, 1, 1, 1});
AddTestTensor("input", {1, 2, 3});
AddTestTensor("input", {1, 1, 2, 3});
AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
@ -4132,12 +4133,23 @@ TEST_F(OpConverterTest, ConvertConv2D) {
Reset();
NodeDef node_def =
get_conv2d_nodedef({1, 2, 1, 1}, "SAME", "NCHW", {1, 1, 1, 1});
AddTestTensor("input", {1, 2, 3});
AddTestTensor("input", {1, 1, 2, 3});
AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
"Stride must be 1 for batch and channel dimensions, at my_conv2d");
}
if (trt_mode == TrtTestMode::kDynamicShape) {
Reset();
NodeDef node_def = get_conv2d_nodedef();
// Channel dim unknown, should fail.
AddTestTensorWithTFDims("input", {-1, -1, -1, -1},
TfDataTypeToTrt(tf_type));
AddTestWeights<float>("weights", {1, 2, 1, 1}, {-1, 1});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"Channel dimension must be static, at my_conv2d");
}
struct TestParams {
std::vector<int> input_dims;
@ -4155,7 +4167,7 @@ TEST_F(OpConverterTest, ConvertConv2D) {
// Ok.
std::vector<TestParams> ok_params = {
// Basic
TestParams{/*input_dims=*/{1, 2, 3},
TestParams{/*input_dims=*/{1, 1, 2, 3},
/*input=*/{0, 1, 2, 3, 3, 4},
/*filter_dims=*/{1, 2, 1, 1},
/*filter=*/{-1, 1},
@ -4163,10 +4175,10 @@ TEST_F(OpConverterTest, ConvertConv2D) {
/*padding=*/"VALID",
/*data_format=*/"NCHW",
/*dilations=*/{1, 1, 1, 1},
/*expected_output_dims=*/{1, 2, 2},
/*expected_output_dims=*/{1, 1, 2, 2},
/*expected_output=*/{1, 1, 0, 1}},
// SAME padding (Asymmetric)
TestParams{/*input_dims=*/{1, 2, 3},
TestParams{/*input_dims=*/{1, 1, 2, 3},
/*input=*/{0, 1, 2, 3, 3, 4},
/*filter_dims=*/{1, 2, 1, 1},
/*filter=*/{-1, 1},
@ -4174,10 +4186,10 @@ TEST_F(OpConverterTest, ConvertConv2D) {
/*padding=*/"SAME",
/*data_format=*/"NCHW",
/*dilations=*/{1, 1, 1, 1},
/*expected_output_dims=*/{1, 2, 3},
/*expected_output_dims=*/{1, 1, 2, 3},
/*expected_output=*/{1, 1, -2, 0, 1, -4}},
// SAME padding (Symmetric)
TestParams{/*input_dims=*/{1, 2, 3},
TestParams{/*input_dims=*/{1, 1, 2, 3},
/*input=*/{0, 1, 2, 3, 3, 4},
/*filter_dims=*/{1, 3, 1, 1},
/*filter=*/{-1, 0, 1},
@ -4185,10 +4197,10 @@ TEST_F(OpConverterTest, ConvertConv2D) {
/*padding=*/"SAME",
/*data_format=*/"NCHW",
/*dilations=*/{1, 1, 1, 1},
/*expected_output_dims=*/{1, 2, 3},
/*expected_output_dims=*/{1, 1, 2, 3},
/*expected_output=*/{1, 2, -1, 3, 1, -3}},
// NHWC
TestParams{/*input_dims=*/{2, 3, 1},
TestParams{/*input_dims=*/{1, 2, 3, 1},
/*input=*/{0, 1, 2, 3, 3, 4},
/*filter_dims=*/{1, 2, 1, 1},
/*filter=*/{-1, 1},
@ -4196,10 +4208,10 @@ TEST_F(OpConverterTest, ConvertConv2D) {
/*padding=*/"VALID",
/*data_format=*/"NHWC",
/*dilations=*/{1, 1, 1, 1},
/*expected_output_dims=*/{2, 2, 1},
/*expected_output_dims=*/{1, 2, 2, 1},
/*expected_output=*/{1, 1, 0, 1}},
// Dilated
TestParams{/*input_dims=*/{1, 2, 3},
TestParams{/*input_dims=*/{1, 1, 2, 3},
/*input=*/{0, 1, 2, 3, 3, 4},
/*filter_dims=*/{1, 2, 1, 1},
/*filter=*/{-1, 1},
@ -4207,10 +4219,10 @@ TEST_F(OpConverterTest, ConvertConv2D) {
/*padding=*/"VALID",
/*data_format=*/"NCHW",
/*dilations=*/{1, 1, 1, 2},
/*expected_output_dims=*/{1, 2, 1},
/*expected_output_dims=*/{1, 1, 2, 1},
/*expected_output=*/{2, 1}},
// Strided
TestParams{/*input_dims=*/{1, 2, 4},
TestParams{/*input_dims=*/{1, 1, 2, 4},
/*input=*/{0, 1, 2, 2, 3, 4, 4, 7},
/*filter_dims=*/{1, 2, 1, 1},
/*filter=*/{-1, 1},
@ -4218,7 +4230,7 @@ TEST_F(OpConverterTest, ConvertConv2D) {
/*padding=*/"VALID",
/*data_format=*/"NCHW",
/*dilations=*/{1, 1, 1, 1},
/*expected_output_dims=*/{1, 2, 2},
/*expected_output_dims=*/{1, 1, 2, 2},
/*expected_output=*/{1, 0, 1, 3}},
};
@ -4227,23 +4239,22 @@ TEST_F(OpConverterTest, ConvertConv2D) {
NodeDef node_def =
get_conv2d_nodedef(ok_params[i].strides, ok_params[i].padding,
ok_params[i].data_format, ok_params[i].dilations);
AddTestTensor("input", ok_params[i].input_dims);
std::vector<int> partial_input_shape;
if (trt_mode == TrtTestMode::kDynamicShape) {
// The channel dim cannot have unknown size, fix that.
partial_input_shape.resize(ok_params[i].input_dims.size(), -1);
int channel_id = (ok_params[i].data_format == "NCHW") ? 1 : 3;
partial_input_shape[channel_id] = ok_params[i].input_dims[channel_id];
}
AddTestTensor("input", ok_params[i].input_dims, tf_dtype,
ok_params[i].input, partial_input_shape);
AddTestWeights<float>("weights", ok_params[i].filter_dims,
ok_params[i].filter);
RunValidationAndConversion(node_def);
TRT_TensorOrWeights output;
TF_EXPECT_OK(GetTensorOrWeights("my_conv2d", &output));
ASSERT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims,
output.tensor()->getDimensions());
const DataVec input_data{{"input", AsTensor<float>(ok_params[i].input)}};
DataVec output_data{
{"my_conv2d",
ConstructTensor<float>(ok_params[i].expected_output.size())}};
TF_EXPECT_OK(BuildAndRun(input_data, &output_data));
EXPECT_THAT(GetSpanForData<float>(output_data[0]),
ElementsAreArray(ok_params[i].expected_output));
TestOpConverter("my_conv2d", node_def, ok_params[i].expected_output_dims,
Status::OK(), Status::OK(),
ElementsAreArray(ok_params[i].expected_output));
}
}