Enable Conv2D op conversion in dynamic shape mode
This commit is contained in:
parent
d5f62ec58a
commit
50605edea5
|
@ -2146,6 +2146,12 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
|
||||||
"Stride must be 1 for batch and channel dimensions, at ",
|
"Stride must be 1 for batch and channel dimensions, at ",
|
||||||
node_def.name());
|
node_def.name());
|
||||||
}
|
}
|
||||||
|
// Channel dim must be static for DepthwiseConv2dNative since we use that
|
||||||
|
// value for num_groups at build time.
|
||||||
|
if (!params->use_implicit_batch && tensor->getDimensions().d[c_index] == -1) {
|
||||||
|
return errors::InvalidArgument("Channel dimension must be static, at ",
|
||||||
|
node_def.name());
|
||||||
|
}
|
||||||
const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
|
const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
|
||||||
if (params->validation_only) return Status::OK();
|
if (params->validation_only) return Status::OK();
|
||||||
|
|
||||||
|
@ -2157,11 +2163,12 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
|
||||||
}
|
}
|
||||||
// Dimensions of transposed tensor.
|
// Dimensions of transposed tensor.
|
||||||
const auto tensor_dim = tensor->getDimensions();
|
const auto tensor_dim = tensor->getDimensions();
|
||||||
|
const int c_dim_size = tensor_dim.d[params->use_implicit_batch ? 0 : 1];
|
||||||
|
|
||||||
// group == 0 signifies that this is a depthwise convolution, so set
|
// group == 0 signifies that this is a depthwise convolution, so set
|
||||||
// num_groups to size of input's channel dim. For a non-depthwise conv,
|
// num_groups to size of input's channel dim. For a non-depthwise conv,
|
||||||
// num_groups will be 1.
|
// num_groups will be 1.
|
||||||
const int num_groups = (group == 0) ? tensor_dim.d[0] : group;
|
const int num_groups = (group == 0) ? c_dim_size : group;
|
||||||
|
|
||||||
// For conv, TF weights are RSCK, and TRT expects KCRS.
|
// For conv, TF weights are RSCK, and TRT expects KCRS.
|
||||||
// For backprop, TF weights are RSKC, and TRT expects CKRS.
|
// For backprop, TF weights are RSKC, and TRT expects CKRS.
|
||||||
|
|
|
@ -4037,15 +4037,16 @@ TEST_F(OpConverterTest, ConvertSlice) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(OpConverterTest, ConvertConv2D) {
|
TEST_P(OpConverterTest1, ConvertConv2D) {
|
||||||
// Get nodedef for Conv2D layer.
|
// Get nodedef for Conv2D layer.
|
||||||
|
DataType tf_type = tf_dtype;
|
||||||
auto get_conv2d_nodedef =
|
auto get_conv2d_nodedef =
|
||||||
[](std::vector<int> strides = {1, 1, 1, 1}, string padding = "SAME",
|
[tf_type](std::vector<int> strides = {1, 1, 1, 1},
|
||||||
string data_format = "NCHW",
|
string padding = "SAME", string data_format = "NCHW",
|
||||||
std::vector<int> dilations = {1, 1, 1, 1}) -> NodeDef {
|
std::vector<int> dilations = {1, 1, 1, 1}) -> NodeDef {
|
||||||
Scope s = Scope::NewRootScope();
|
Scope s = Scope::NewRootScope();
|
||||||
auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
|
auto input = ops::Placeholder(s.WithOpName("input"), tf_type);
|
||||||
auto filter = ops::Placeholder(s.WithOpName("weights"), DT_FLOAT);
|
auto filter = ops::Placeholder(s.WithOpName("weights"), tf_type);
|
||||||
ops::Conv2D::Attrs attrs =
|
ops::Conv2D::Attrs attrs =
|
||||||
ops::Conv2D::Attrs().DataFormat(data_format).Dilations(dilations);
|
ops::Conv2D::Attrs().DataFormat(data_format).Dilations(dilations);
|
||||||
auto conv2d = ops::Conv2D(s.WithOpName("my_conv2d"), input, filter, strides,
|
auto conv2d = ops::Conv2D(s.WithOpName("my_conv2d"), input, filter, strides,
|
||||||
|
@ -4067,7 +4068,7 @@ TEST_F(OpConverterTest, ConvertConv2D) {
|
||||||
// Filter is tensor, should fail.
|
// Filter is tensor, should fail.
|
||||||
Reset();
|
Reset();
|
||||||
NodeDef node_def = get_conv2d_nodedef();
|
NodeDef node_def = get_conv2d_nodedef();
|
||||||
AddTestTensor("input", {1, 2, 3});
|
AddTestTensor("input", {3, 1, 2, 1});
|
||||||
AddTestTensor("weights", {3, 3, 1, 1});
|
AddTestTensor("weights", {3, 3, 1, 1});
|
||||||
RunValidationAndConversion(
|
RunValidationAndConversion(
|
||||||
node_def, error::UNIMPLEMENTED,
|
node_def, error::UNIMPLEMENTED,
|
||||||
|
@ -4077,7 +4078,7 @@ TEST_F(OpConverterTest, ConvertConv2D) {
|
||||||
// Filter is not 4D, should fail.
|
// Filter is not 4D, should fail.
|
||||||
Reset();
|
Reset();
|
||||||
NodeDef node_def = get_conv2d_nodedef();
|
NodeDef node_def = get_conv2d_nodedef();
|
||||||
AddTestTensor("input", {1, 2, 3});
|
AddTestTensor("input", {1, 1, 2, 3});
|
||||||
AddTestWeights<float>("weights", {3, 3, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
|
AddTestWeights<float>("weights", {3, 3, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||||||
RunValidationAndConversion(
|
RunValidationAndConversion(
|
||||||
node_def, error::INVALID_ARGUMENT,
|
node_def, error::INVALID_ARGUMENT,
|
||||||
|
@ -4088,7 +4089,7 @@ TEST_F(OpConverterTest, ConvertConv2D) {
|
||||||
Reset();
|
Reset();
|
||||||
NodeDef node_def =
|
NodeDef node_def =
|
||||||
get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NCHW", {1, 1, 1});
|
get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NCHW", {1, 1, 1});
|
||||||
AddTestTensor("input", {1, 2, 3});
|
AddTestTensor("input", {1, 1, 2, 3});
|
||||||
AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
|
AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||||||
RunValidationAndConversion(
|
RunValidationAndConversion(
|
||||||
node_def, error::INVALID_ARGUMENT,
|
node_def, error::INVALID_ARGUMENT,
|
||||||
|
@ -4099,7 +4100,7 @@ TEST_F(OpConverterTest, ConvertConv2D) {
|
||||||
Reset();
|
Reset();
|
||||||
NodeDef node_def =
|
NodeDef node_def =
|
||||||
get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NCHW", {1, 2, 1, 1});
|
get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NCHW", {1, 2, 1, 1});
|
||||||
AddTestTensor("input", {1, 2, 3});
|
AddTestTensor("input", {1, 1, 2, 3});
|
||||||
AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
|
AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||||||
RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
|
RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
|
||||||
"Dilation rate must be 1 for batch and channel "
|
"Dilation rate must be 1 for batch and channel "
|
||||||
|
@ -4110,7 +4111,7 @@ TEST_F(OpConverterTest, ConvertConv2D) {
|
||||||
Reset();
|
Reset();
|
||||||
NodeDef node_def =
|
NodeDef node_def =
|
||||||
get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NHWC", {1, 1, 1, 2});
|
get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NHWC", {1, 1, 1, 2});
|
||||||
AddTestTensor("input", {2, 3, 1});
|
AddTestTensor("input", {1, 2, 3, 1});
|
||||||
AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
|
AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||||||
RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
|
RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
|
||||||
"Dilation rate must be 1 for batch and channel "
|
"Dilation rate must be 1 for batch and channel "
|
||||||
|
@ -4121,7 +4122,7 @@ TEST_F(OpConverterTest, ConvertConv2D) {
|
||||||
Reset();
|
Reset();
|
||||||
NodeDef node_def =
|
NodeDef node_def =
|
||||||
get_conv2d_nodedef({1, 1, 1}, "SAME", "NCHW", {1, 1, 1, 1});
|
get_conv2d_nodedef({1, 1, 1}, "SAME", "NCHW", {1, 1, 1, 1});
|
||||||
AddTestTensor("input", {1, 2, 3});
|
AddTestTensor("input", {1, 1, 2, 3});
|
||||||
AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
|
AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||||||
RunValidationAndConversion(
|
RunValidationAndConversion(
|
||||||
node_def, error::INVALID_ARGUMENT,
|
node_def, error::INVALID_ARGUMENT,
|
||||||
|
@ -4132,12 +4133,23 @@ TEST_F(OpConverterTest, ConvertConv2D) {
|
||||||
Reset();
|
Reset();
|
||||||
NodeDef node_def =
|
NodeDef node_def =
|
||||||
get_conv2d_nodedef({1, 2, 1, 1}, "SAME", "NCHW", {1, 1, 1, 1});
|
get_conv2d_nodedef({1, 2, 1, 1}, "SAME", "NCHW", {1, 1, 1, 1});
|
||||||
AddTestTensor("input", {1, 2, 3});
|
AddTestTensor("input", {1, 1, 2, 3});
|
||||||
AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
|
AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||||||
RunValidationAndConversion(
|
RunValidationAndConversion(
|
||||||
node_def, error::UNIMPLEMENTED,
|
node_def, error::UNIMPLEMENTED,
|
||||||
"Stride must be 1 for batch and channel dimensions, at my_conv2d");
|
"Stride must be 1 for batch and channel dimensions, at my_conv2d");
|
||||||
}
|
}
|
||||||
|
if (trt_mode == TrtTestMode::kDynamicShape) {
|
||||||
|
Reset();
|
||||||
|
NodeDef node_def = get_conv2d_nodedef();
|
||||||
|
// Channel dim unknown, should fail.
|
||||||
|
AddTestTensorWithTFDims("input", {-1, -1, -1, -1},
|
||||||
|
TfDataTypeToTrt(tf_type));
|
||||||
|
AddTestWeights<float>("weights", {1, 2, 1, 1}, {-1, 1});
|
||||||
|
RunValidationAndConversion(
|
||||||
|
node_def, error::INVALID_ARGUMENT,
|
||||||
|
"Channel dimension must be static, at my_conv2d");
|
||||||
|
}
|
||||||
|
|
||||||
struct TestParams {
|
struct TestParams {
|
||||||
std::vector<int> input_dims;
|
std::vector<int> input_dims;
|
||||||
|
@ -4155,7 +4167,7 @@ TEST_F(OpConverterTest, ConvertConv2D) {
|
||||||
// Ok.
|
// Ok.
|
||||||
std::vector<TestParams> ok_params = {
|
std::vector<TestParams> ok_params = {
|
||||||
// Basic
|
// Basic
|
||||||
TestParams{/*input_dims=*/{1, 2, 3},
|
TestParams{/*input_dims=*/{1, 1, 2, 3},
|
||||||
/*input=*/{0, 1, 2, 3, 3, 4},
|
/*input=*/{0, 1, 2, 3, 3, 4},
|
||||||
/*filter_dims=*/{1, 2, 1, 1},
|
/*filter_dims=*/{1, 2, 1, 1},
|
||||||
/*filter=*/{-1, 1},
|
/*filter=*/{-1, 1},
|
||||||
|
@ -4163,10 +4175,10 @@ TEST_F(OpConverterTest, ConvertConv2D) {
|
||||||
/*padding=*/"VALID",
|
/*padding=*/"VALID",
|
||||||
/*data_format=*/"NCHW",
|
/*data_format=*/"NCHW",
|
||||||
/*dilations=*/{1, 1, 1, 1},
|
/*dilations=*/{1, 1, 1, 1},
|
||||||
/*expected_output_dims=*/{1, 2, 2},
|
/*expected_output_dims=*/{1, 1, 2, 2},
|
||||||
/*expected_output=*/{1, 1, 0, 1}},
|
/*expected_output=*/{1, 1, 0, 1}},
|
||||||
// SAME padding (Asymmetric)
|
// SAME padding (Asymmetric)
|
||||||
TestParams{/*input_dims=*/{1, 2, 3},
|
TestParams{/*input_dims=*/{1, 1, 2, 3},
|
||||||
/*input=*/{0, 1, 2, 3, 3, 4},
|
/*input=*/{0, 1, 2, 3, 3, 4},
|
||||||
/*filter_dims=*/{1, 2, 1, 1},
|
/*filter_dims=*/{1, 2, 1, 1},
|
||||||
/*filter=*/{-1, 1},
|
/*filter=*/{-1, 1},
|
||||||
|
@ -4174,10 +4186,10 @@ TEST_F(OpConverterTest, ConvertConv2D) {
|
||||||
/*padding=*/"SAME",
|
/*padding=*/"SAME",
|
||||||
/*data_format=*/"NCHW",
|
/*data_format=*/"NCHW",
|
||||||
/*dilations=*/{1, 1, 1, 1},
|
/*dilations=*/{1, 1, 1, 1},
|
||||||
/*expected_output_dims=*/{1, 2, 3},
|
/*expected_output_dims=*/{1, 1, 2, 3},
|
||||||
/*expected_output=*/{1, 1, -2, 0, 1, -4}},
|
/*expected_output=*/{1, 1, -2, 0, 1, -4}},
|
||||||
// SAME padding (Symmetric)
|
// SAME padding (Symmetric)
|
||||||
TestParams{/*input_dims=*/{1, 2, 3},
|
TestParams{/*input_dims=*/{1, 1, 2, 3},
|
||||||
/*input=*/{0, 1, 2, 3, 3, 4},
|
/*input=*/{0, 1, 2, 3, 3, 4},
|
||||||
/*filter_dims=*/{1, 3, 1, 1},
|
/*filter_dims=*/{1, 3, 1, 1},
|
||||||
/*filter=*/{-1, 0, 1},
|
/*filter=*/{-1, 0, 1},
|
||||||
|
@ -4185,10 +4197,10 @@ TEST_F(OpConverterTest, ConvertConv2D) {
|
||||||
/*padding=*/"SAME",
|
/*padding=*/"SAME",
|
||||||
/*data_format=*/"NCHW",
|
/*data_format=*/"NCHW",
|
||||||
/*dilations=*/{1, 1, 1, 1},
|
/*dilations=*/{1, 1, 1, 1},
|
||||||
/*expected_output_dims=*/{1, 2, 3},
|
/*expected_output_dims=*/{1, 1, 2, 3},
|
||||||
/*expected_output=*/{1, 2, -1, 3, 1, -3}},
|
/*expected_output=*/{1, 2, -1, 3, 1, -3}},
|
||||||
// NHWC
|
// NHWC
|
||||||
TestParams{/*input_dims=*/{2, 3, 1},
|
TestParams{/*input_dims=*/{1, 2, 3, 1},
|
||||||
/*input=*/{0, 1, 2, 3, 3, 4},
|
/*input=*/{0, 1, 2, 3, 3, 4},
|
||||||
/*filter_dims=*/{1, 2, 1, 1},
|
/*filter_dims=*/{1, 2, 1, 1},
|
||||||
/*filter=*/{-1, 1},
|
/*filter=*/{-1, 1},
|
||||||
|
@ -4196,10 +4208,10 @@ TEST_F(OpConverterTest, ConvertConv2D) {
|
||||||
/*padding=*/"VALID",
|
/*padding=*/"VALID",
|
||||||
/*data_format=*/"NHWC",
|
/*data_format=*/"NHWC",
|
||||||
/*dilations=*/{1, 1, 1, 1},
|
/*dilations=*/{1, 1, 1, 1},
|
||||||
/*expected_output_dims=*/{2, 2, 1},
|
/*expected_output_dims=*/{1, 2, 2, 1},
|
||||||
/*expected_output=*/{1, 1, 0, 1}},
|
/*expected_output=*/{1, 1, 0, 1}},
|
||||||
// Dilated
|
// Dilated
|
||||||
TestParams{/*input_dims=*/{1, 2, 3},
|
TestParams{/*input_dims=*/{1, 1, 2, 3},
|
||||||
/*input=*/{0, 1, 2, 3, 3, 4},
|
/*input=*/{0, 1, 2, 3, 3, 4},
|
||||||
/*filter_dims=*/{1, 2, 1, 1},
|
/*filter_dims=*/{1, 2, 1, 1},
|
||||||
/*filter=*/{-1, 1},
|
/*filter=*/{-1, 1},
|
||||||
|
@ -4207,10 +4219,10 @@ TEST_F(OpConverterTest, ConvertConv2D) {
|
||||||
/*padding=*/"VALID",
|
/*padding=*/"VALID",
|
||||||
/*data_format=*/"NCHW",
|
/*data_format=*/"NCHW",
|
||||||
/*dilations=*/{1, 1, 1, 2},
|
/*dilations=*/{1, 1, 1, 2},
|
||||||
/*expected_output_dims=*/{1, 2, 1},
|
/*expected_output_dims=*/{1, 1, 2, 1},
|
||||||
/*expected_output=*/{2, 1}},
|
/*expected_output=*/{2, 1}},
|
||||||
// Strided
|
// Strided
|
||||||
TestParams{/*input_dims=*/{1, 2, 4},
|
TestParams{/*input_dims=*/{1, 1, 2, 4},
|
||||||
/*input=*/{0, 1, 2, 2, 3, 4, 4, 7},
|
/*input=*/{0, 1, 2, 2, 3, 4, 4, 7},
|
||||||
/*filter_dims=*/{1, 2, 1, 1},
|
/*filter_dims=*/{1, 2, 1, 1},
|
||||||
/*filter=*/{-1, 1},
|
/*filter=*/{-1, 1},
|
||||||
|
@ -4218,7 +4230,7 @@ TEST_F(OpConverterTest, ConvertConv2D) {
|
||||||
/*padding=*/"VALID",
|
/*padding=*/"VALID",
|
||||||
/*data_format=*/"NCHW",
|
/*data_format=*/"NCHW",
|
||||||
/*dilations=*/{1, 1, 1, 1},
|
/*dilations=*/{1, 1, 1, 1},
|
||||||
/*expected_output_dims=*/{1, 2, 2},
|
/*expected_output_dims=*/{1, 1, 2, 2},
|
||||||
/*expected_output=*/{1, 0, 1, 3}},
|
/*expected_output=*/{1, 0, 1, 3}},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -4227,23 +4239,22 @@ TEST_F(OpConverterTest, ConvertConv2D) {
|
||||||
NodeDef node_def =
|
NodeDef node_def =
|
||||||
get_conv2d_nodedef(ok_params[i].strides, ok_params[i].padding,
|
get_conv2d_nodedef(ok_params[i].strides, ok_params[i].padding,
|
||||||
ok_params[i].data_format, ok_params[i].dilations);
|
ok_params[i].data_format, ok_params[i].dilations);
|
||||||
AddTestTensor("input", ok_params[i].input_dims);
|
std::vector<int> partial_input_shape;
|
||||||
|
if (trt_mode == TrtTestMode::kDynamicShape) {
|
||||||
|
// The channel dim cannot have unknown size, fix that.
|
||||||
|
partial_input_shape.resize(ok_params[i].input_dims.size(), -1);
|
||||||
|
int channel_id = (ok_params[i].data_format == "NCHW") ? 1 : 3;
|
||||||
|
partial_input_shape[channel_id] = ok_params[i].input_dims[channel_id];
|
||||||
|
}
|
||||||
|
|
||||||
|
AddTestTensor("input", ok_params[i].input_dims, tf_dtype,
|
||||||
|
ok_params[i].input, partial_input_shape);
|
||||||
AddTestWeights<float>("weights", ok_params[i].filter_dims,
|
AddTestWeights<float>("weights", ok_params[i].filter_dims,
|
||||||
ok_params[i].filter);
|
ok_params[i].filter);
|
||||||
RunValidationAndConversion(node_def);
|
|
||||||
TRT_TensorOrWeights output;
|
|
||||||
TF_EXPECT_OK(GetTensorOrWeights("my_conv2d", &output));
|
|
||||||
ASSERT_TRUE(output.is_tensor());
|
|
||||||
ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims,
|
|
||||||
output.tensor()->getDimensions());
|
|
||||||
|
|
||||||
const DataVec input_data{{"input", AsTensor<float>(ok_params[i].input)}};
|
TestOpConverter("my_conv2d", node_def, ok_params[i].expected_output_dims,
|
||||||
DataVec output_data{
|
Status::OK(), Status::OK(),
|
||||||
{"my_conv2d",
|
ElementsAreArray(ok_params[i].expected_output));
|
||||||
ConstructTensor<float>(ok_params[i].expected_output.size())}};
|
|
||||||
TF_EXPECT_OK(BuildAndRun(input_data, &output_data));
|
|
||||||
EXPECT_THAT(GetSpanForData<float>(output_data[0]),
|
|
||||||
ElementsAreArray(ok_params[i].expected_output));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue