Merge pull request #39204 from tfeher:trt_conv2D_dynamic_shape
PiperOrigin-RevId: 313205518
Change-Id: I06df91df11f6009740bbe466bc22afbeb29e9981
commit 00664cef68
@@ -2146,6 +2146,12 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
         "Stride must be 1 for batch and channel dimensions, at ",
         node_def.name());
   }
+  // Channel dim must be static for DepthwiseConv2dNative since we use that
+  // value for num_groups at build time.
+  if (!params->use_implicit_batch && tensor->getDimensions().d[c_index] == -1) {
+    return errors::InvalidArgument("Channel dimension must be static, at ",
+                                   node_def.name());
+  }
   const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
   if (params->validation_only) return Status::OK();
 
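The new check exists because TensorRT reports any dimension that is only known at runtime as -1 at build time, and ConvertConv2DHelper needs a concrete channel count to derive num_groups for DepthwiseConv2dNative. A minimal stand-alone sketch of that situation (the Dims struct below is a stand-in for nvinfer1::Dims, not the converter's own type; the channel index is passed in rather than derived from data_format):

    #include <cstdio>

    struct Dims {       // stand-in for nvinfer1::Dims
      int nbDims;
      int d[8];
    };

    bool ChannelIsStatic(const Dims& dims, int c_index) {
      // TensorRT reports runtime-only dimensions as -1 at build time, so a
      // value needed during conversion (num_groups for a depthwise conv)
      // cannot be taken from a dynamic channel dim.
      return dims.d[c_index] != -1;
    }

    int main() {
      Dims dynamic{4, {-1, -1, 28, 28}};  // batch and channel unknown
      std::printf("%d\n", ChannelIsStatic(dynamic, /*c_index=*/1));  // prints 0
    }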
@@ -2157,11 +2163,12 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
   }
   // Dimensions of transposed tensor.
   const auto tensor_dim = tensor->getDimensions();
+  const int c_dim_size = tensor_dim.d[params->use_implicit_batch ? 0 : 1];
 
   // group == 0 signifies that this is a depthwise convolution, so set
   // num_groups to size of input's channel dim. For a non-depthwise conv,
   // num_groups will be 1.
-  const int num_groups = (group == 0) ? tensor_dim.d[0] : group;
+  const int num_groups = (group == 0) ? c_dim_size : group;
 
   // For conv, TF weights are RSCK, and TRT expects KCRS.
   // For backprop, TF weights are RSKC, and TRT expects CKRS.
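The RSCK/KCRS comment above is the reason the converter transposes the weights at all: TensorFlow stores Conv2D filters as height × width × in-channels × out-channels, while TensorRT expects out-channels × in-channels × height × width. An illustrative index-remapping sketch (not the converter's actual transpose code, which operates on TRT_ShapedWeights):

    #include <vector>

    // Reorder a flat Conv2D weight buffer from RSCK to KCRS.
    std::vector<float> RsckToKcrs(const std::vector<float>& rsck,
                                  int R, int S, int C, int K) {
      std::vector<float> kcrs(rsck.size());
      for (int r = 0; r < R; ++r) {
        for (int s = 0; s < S; ++s) {
          for (int c = 0; c < C; ++c) {
            for (int k = 0; k < K; ++k) {
              // RSCK offset: ((r * S + s) * C + c) * K + k
              // KCRS offset: ((k * C + c) * R + r) * S + s
              kcrs[((k * C + c) * R + r) * S + s] =
                  rsck[((r * S + s) * C + c) * K + k];
            }
          }
        }
      }
      return kcrs;
    }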
@@ -4037,15 +4037,16 @@ TEST_F(OpConverterTest, ConvertSlice) {
   }
 }
 
-TEST_F(OpConverterTest, ConvertConv2D) {
+TEST_P(OpConverterTest1, ConvertConv2D) {
   // Get nodedef for Conv2D layer.
+  DataType tf_type = tf_dtype;
   auto get_conv2d_nodedef =
-      [](std::vector<int> strides = {1, 1, 1, 1}, string padding = "SAME",
-         string data_format = "NCHW",
-         std::vector<int> dilations = {1, 1, 1, 1}) -> NodeDef {
+      [tf_type](std::vector<int> strides = {1, 1, 1, 1},
+                string padding = "SAME", string data_format = "NCHW",
+                std::vector<int> dilations = {1, 1, 1, 1}) -> NodeDef {
     Scope s = Scope::NewRootScope();
-    auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
-    auto filter = ops::Placeholder(s.WithOpName("weights"), DT_FLOAT);
+    auto input = ops::Placeholder(s.WithOpName("input"), tf_type);
+    auto filter = ops::Placeholder(s.WithOpName("weights"), tf_type);
     ops::Conv2D::Attrs attrs =
         ops::Conv2D::Attrs().DataFormat(data_format).Dilations(dilations);
     auto conv2d = ops::Conv2D(s.WithOpName("my_conv2d"), input, filter, strides,
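The switch from TEST_F to TEST_P means the same Conv2D test body now runs once per test parameter (trt_mode, tf_dtype), with the NodeDef built from the captured tf_type instead of a hard-coded DT_FLOAT. A generic googletest sketch of the pattern, using a hypothetical fixture rather than TF's OpConverterTest1:

    #include <gtest/gtest.h>

    // Hypothetical fixture: the parameter stands in for tf_dtype / trt_mode.
    class HypotheticalConvTest : public ::testing::TestWithParam<int> {};

    TEST_P(HypotheticalConvTest, RunsOncePerParam) {
      const int bits = GetParam();  // e.g. 32 or 16, standing in for the precision
      EXPECT_GT(bits, 0);
    }

    // One instantiation expands the body above into a test per value.
    INSTANTIATE_TEST_SUITE_P(Precisions, HypotheticalConvTest,
                             ::testing::Values(32, 16));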
@@ -4067,7 +4068,7 @@ TEST_F(OpConverterTest, ConvertConv2D) {
     // Filter is tensor, should fail.
     Reset();
     NodeDef node_def = get_conv2d_nodedef();
-    AddTestTensor("input", {1, 2, 3});
+    AddTestTensor("input", {3, 1, 2, 1});
     AddTestTensor("weights", {3, 3, 1, 1});
     RunValidationAndConversion(
         node_def, error::UNIMPLEMENTED,
@@ -4077,7 +4078,7 @@ TEST_F(OpConverterTest, ConvertConv2D) {
     // Filter is not 4D, should fail.
     Reset();
     NodeDef node_def = get_conv2d_nodedef();
-    AddTestTensor("input", {1, 2, 3});
+    AddTestTensor("input", {1, 1, 2, 3});
     AddTestWeights<float>("weights", {3, 3, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
     RunValidationAndConversion(
         node_def, error::INVALID_ARGUMENT,
@@ -4088,7 +4089,7 @@ TEST_F(OpConverterTest, ConvertConv2D) {
     Reset();
     NodeDef node_def =
         get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NCHW", {1, 1, 1});
-    AddTestTensor("input", {1, 2, 3});
+    AddTestTensor("input", {1, 1, 2, 3});
     AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
     RunValidationAndConversion(
         node_def, error::INVALID_ARGUMENT,
@@ -4099,7 +4100,7 @@ TEST_F(OpConverterTest, ConvertConv2D) {
     Reset();
     NodeDef node_def =
         get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NCHW", {1, 2, 1, 1});
-    AddTestTensor("input", {1, 2, 3});
+    AddTestTensor("input", {1, 1, 2, 3});
     AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
                                "Dilation rate must be 1 for batch and channel "
@@ -4110,7 +4111,7 @@ TEST_F(OpConverterTest, ConvertConv2D) {
     Reset();
     NodeDef node_def =
         get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NHWC", {1, 1, 1, 2});
-    AddTestTensor("input", {2, 3, 1});
+    AddTestTensor("input", {1, 2, 3, 1});
     AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
     RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
                                "Dilation rate must be 1 for batch and channel "
@@ -4121,7 +4122,7 @@ TEST_F(OpConverterTest, ConvertConv2D) {
     Reset();
     NodeDef node_def =
         get_conv2d_nodedef({1, 1, 1}, "SAME", "NCHW", {1, 1, 1, 1});
-    AddTestTensor("input", {1, 2, 3});
+    AddTestTensor("input", {1, 1, 2, 3});
     AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
     RunValidationAndConversion(
         node_def, error::INVALID_ARGUMENT,
@@ -4132,12 +4133,23 @@ TEST_F(OpConverterTest, ConvertConv2D) {
     Reset();
     NodeDef node_def =
         get_conv2d_nodedef({1, 2, 1, 1}, "SAME", "NCHW", {1, 1, 1, 1});
-    AddTestTensor("input", {1, 2, 3});
+    AddTestTensor("input", {1, 1, 2, 3});
     AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
     RunValidationAndConversion(
         node_def, error::UNIMPLEMENTED,
         "Stride must be 1 for batch and channel dimensions, at my_conv2d");
   }
+  if (trt_mode == TrtTestMode::kDynamicShape) {
+    Reset();
+    NodeDef node_def = get_conv2d_nodedef();
+    // Channel dim unknown, should fail.
+    AddTestTensorWithTFDims("input", {-1, -1, -1, -1},
+                            TfDataTypeToTrt(tf_type));
+    AddTestWeights<float>("weights", {1, 2, 1, 1}, {-1, 1});
+    RunValidationAndConversion(
+        node_def, error::INVALID_ARGUMENT,
+        "Channel dimension must be static, at my_conv2d");
+  }
 
   struct TestParams {
     std::vector<int> input_dims;
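For context, the "channel dim unknown" case above corresponds to declaring a network input whose dimensions are all deferred to runtime. A hedged sketch of how that looks at the TensorRT API level (assumes an explicit-batch network on TensorRT 6+; this is not the test harness code):

    #include <NvInfer.h>

    nvinfer1::ITensor* AddFullyDynamicInput(nvinfer1::INetworkDefinition* network) {
      // Every dimension unknown: TensorRT reports each as -1 at build time, so a
      // converter that needs the channel count (e.g. num_groups for a depthwise
      // convolution) has nothing to work with and must reject the node.
      nvinfer1::Dims4 dims{-1, -1, -1, -1};
      return network->addInput("input", nvinfer1::DataType::kFLOAT, dims);
    }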
@@ -4155,7 +4167,7 @@ TEST_F(OpConverterTest, ConvertConv2D) {
   // Ok.
   std::vector<TestParams> ok_params = {
       // Basic
-      TestParams{/*input_dims=*/{1, 2, 3},
+      TestParams{/*input_dims=*/{1, 1, 2, 3},
                  /*input=*/{0, 1, 2, 3, 3, 4},
                  /*filter_dims=*/{1, 2, 1, 1},
                  /*filter=*/{-1, 1},
@@ -4163,10 +4175,10 @@ TEST_F(OpConverterTest, ConvertConv2D) {
                  /*padding=*/"VALID",
                  /*data_format=*/"NCHW",
                  /*dilations=*/{1, 1, 1, 1},
-                 /*expected_output_dims=*/{1, 2, 2},
+                 /*expected_output_dims=*/{1, 1, 2, 2},
                  /*expected_output=*/{1, 1, 0, 1}},
       // SAME padding (Asymmetric)
-      TestParams{/*input_dims=*/{1, 2, 3},
+      TestParams{/*input_dims=*/{1, 1, 2, 3},
                  /*input=*/{0, 1, 2, 3, 3, 4},
                  /*filter_dims=*/{1, 2, 1, 1},
                  /*filter=*/{-1, 1},
@@ -4174,10 +4186,10 @@ TEST_F(OpConverterTest, ConvertConv2D) {
                  /*padding=*/"SAME",
                  /*data_format=*/"NCHW",
                  /*dilations=*/{1, 1, 1, 1},
-                 /*expected_output_dims=*/{1, 2, 3},
+                 /*expected_output_dims=*/{1, 1, 2, 3},
                  /*expected_output=*/{1, 1, -2, 0, 1, -4}},
       // SAME padding (Symmetric)
-      TestParams{/*input_dims=*/{1, 2, 3},
+      TestParams{/*input_dims=*/{1, 1, 2, 3},
                  /*input=*/{0, 1, 2, 3, 3, 4},
                  /*filter_dims=*/{1, 3, 1, 1},
                  /*filter=*/{-1, 0, 1},
@@ -4185,10 +4197,10 @@ TEST_F(OpConverterTest, ConvertConv2D) {
                  /*padding=*/"SAME",
                  /*data_format=*/"NCHW",
                  /*dilations=*/{1, 1, 1, 1},
-                 /*expected_output_dims=*/{1, 2, 3},
+                 /*expected_output_dims=*/{1, 1, 2, 3},
                  /*expected_output=*/{1, 2, -1, 3, 1, -3}},
       // NHWC
-      TestParams{/*input_dims=*/{2, 3, 1},
+      TestParams{/*input_dims=*/{1, 2, 3, 1},
                  /*input=*/{0, 1, 2, 3, 3, 4},
                  /*filter_dims=*/{1, 2, 1, 1},
                  /*filter=*/{-1, 1},
@@ -4196,10 +4208,10 @@ TEST_F(OpConverterTest, ConvertConv2D) {
                  /*padding=*/"VALID",
                  /*data_format=*/"NHWC",
                  /*dilations=*/{1, 1, 1, 1},
-                 /*expected_output_dims=*/{2, 2, 1},
+                 /*expected_output_dims=*/{1, 2, 2, 1},
                  /*expected_output=*/{1, 1, 0, 1}},
       // Dilated
-      TestParams{/*input_dims=*/{1, 2, 3},
+      TestParams{/*input_dims=*/{1, 1, 2, 3},
                  /*input=*/{0, 1, 2, 3, 3, 4},
                  /*filter_dims=*/{1, 2, 1, 1},
                  /*filter=*/{-1, 1},
@@ -4207,10 +4219,10 @@ TEST_F(OpConverterTest, ConvertConv2D) {
                  /*padding=*/"VALID",
                  /*data_format=*/"NCHW",
                  /*dilations=*/{1, 1, 1, 2},
-                 /*expected_output_dims=*/{1, 2, 1},
+                 /*expected_output_dims=*/{1, 1, 2, 1},
                  /*expected_output=*/{2, 1}},
       // Strided
-      TestParams{/*input_dims=*/{1, 2, 4},
+      TestParams{/*input_dims=*/{1, 1, 2, 4},
                  /*input=*/{0, 1, 2, 2, 3, 4, 4, 7},
                  /*filter_dims=*/{1, 2, 1, 1},
                  /*filter=*/{-1, 1},
@@ -4218,7 +4230,7 @@ TEST_F(OpConverterTest, ConvertConv2D) {
                  /*padding=*/"VALID",
                  /*data_format=*/"NCHW",
                  /*dilations=*/{1, 1, 1, 1},
-                 /*expected_output_dims=*/{1, 2, 2},
+                 /*expected_output_dims=*/{1, 1, 2, 2},
                  /*expected_output=*/{1, 0, 1, 3}},
   };
 
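The updated expected_output_dims values can be sanity-checked with the usual convolution output-size formulas; a quick stand-alone check (not converter code) against a few of the cases above:

    #include <cstdio>

    // VALID padding: out = floor((in - dilated_kernel) / stride) + 1
    int ValidOut(int in, int filter, int stride, int dilation) {
      const int dilated = dilation * (filter - 1) + 1;
      return (in - dilated) / stride + 1;
    }

    // SAME padding: out = ceil(in / stride)
    int SameOut(int in, int stride) { return (in + stride - 1) / stride; }

    int main() {
      std::printf("%d\n", ValidOut(4, 2, 2, 1));  // "Strided" case, W: 4 -> 2
      std::printf("%d\n", ValidOut(3, 2, 1, 2));  // "Dilated" case, W: 3 -> 1
      std::printf("%d\n", SameOut(3, 1));         // "SAME (Symmetric)", W: 3 -> 3
    }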
@@ -4227,23 +4239,22 @@ TEST_F(OpConverterTest, ConvertConv2D) {
     NodeDef node_def =
         get_conv2d_nodedef(ok_params[i].strides, ok_params[i].padding,
                            ok_params[i].data_format, ok_params[i].dilations);
-    AddTestTensor("input", ok_params[i].input_dims);
+    std::vector<int> partial_input_shape;
+    if (trt_mode == TrtTestMode::kDynamicShape) {
+      // The channel dim cannot have unknown size, fix that.
+      partial_input_shape.resize(ok_params[i].input_dims.size(), -1);
+      int channel_id = (ok_params[i].data_format == "NCHW") ? 1 : 3;
+      partial_input_shape[channel_id] = ok_params[i].input_dims[channel_id];
+    }
+
+    AddTestTensor("input", ok_params[i].input_dims, tf_dtype,
+                  ok_params[i].input, partial_input_shape);
     AddTestWeights<float>("weights", ok_params[i].filter_dims,
                           ok_params[i].filter);
-    RunValidationAndConversion(node_def);
-    TRT_TensorOrWeights output;
-    TF_EXPECT_OK(GetTensorOrWeights("my_conv2d", &output));
-    ASSERT_TRUE(output.is_tensor());
-    ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims,
-                             output.tensor()->getDimensions());
-
-    const DataVec input_data{{"input", AsTensor<float>(ok_params[i].input)}};
-    DataVec output_data{
-        {"my_conv2d",
-         ConstructTensor<float>(ok_params[i].expected_output.size())}};
-    TF_EXPECT_OK(BuildAndRun(input_data, &output_data));
-    EXPECT_THAT(GetSpanForData<float>(output_data[0]),
-                ElementsAreArray(ok_params[i].expected_output));
+    TestOpConverter("my_conv2d", node_def, ok_params[i].expected_output_dims,
+                    Status::OK(), Status::OK(),
+                    ElementsAreArray(ok_params[i].expected_output));
   }
 }
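The partial_input_shape logic in the loop above keeps every input dimension unknown in dynamic-shape mode except the channel, which the converter requires to be static. A minimal sketch of that logic in isolation (hypothetical free function, mirroring the test code):

    #include <string>
    #include <vector>

    // Pin only the channel dimension; leave the rest as -1 (unknown).
    std::vector<int> MakePartialInputShape(const std::vector<int>& full_dims,
                                           const std::string& data_format) {
      std::vector<int> partial(full_dims.size(), -1);
      const int channel_id = (data_format == "NCHW") ? 1 : 3;
      partial[channel_id] = full_dims[channel_id];
      return partial;
    }

    // e.g. MakePartialInputShape({1, 1, 2, 3}, "NCHW") returns {-1, 1, -1, -1}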