Tidy up Conv2D converter
commit db95d28671
parent 1051c37705
@@ -491,6 +491,7 @@ cuda_py_tests(
         "test/binary_tensor_weight_broadcast_test.py",
         "test/concatenation_test.py",
         "test/const_broadcast_test.py",
+        "test/conv2d_test.py",
         "test/identity_output_test.py",
         "test/manual_test.py",
         "test/memory_alignment_test.py",
@@ -62,14 +62,14 @@ limitations under the License.
 
 #define TFTRT_RETURN_ERROR_IF_FALSE(status, node) \
   do {                                            \
     if ((status) == false) {                      \
       TFTRT_INTERNAL_ERROR_AT_NODE(node);         \
     }                                             \
   } while (0)
 
 #define TFTRT_RETURN_ERROR_IF_NULLPTR(ptr, node) \
   do {                                           \
     if ((ptr) == nullptr) {                      \
       TFTRT_INTERNAL_ERROR_AT_NODE(node);        \
     }                                            \
   } while (0)
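An aside on the two error macros shown as unchanged context above: the do { ... } while (0) wrapper is the standard idiom for making a multi-statement macro parse as a single statement, so the trailing semicolon at the call site works even inside an unbraced if/else. A minimal illustration in plain C++, with a made-up LOG_IF_NEGATIVE macro that is not part of this commit:

#include <cstdio>

// Hypothetical macro, for illustration only: same do/while(0) shape as the
// TFTRT_RETURN_ERROR_IF_* macros above.
#define LOG_IF_NEGATIVE(x)                \
  do {                                    \
    if ((x) < 0) {                        \
      std::printf("negative: %d\n", (x)); \
    }                                     \
  } while (0)

int main() {
  int v = -1;
  // Because the expansion is a single statement, the else still binds.
  if (v != 0)
    LOG_IF_NEGATIVE(v);
  else
    std::printf("zero\n");
  return 0;
}

Had the macro expanded to a bare { ... } block, the semicolon after LOG_IF_NEGATIVE(v) would terminate the if statement and the else would not compile.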
@@ -1567,63 +1567,19 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) {
         node_def.name());
   }
   TRT_ShapedWeights weights_rsck = inputs.at(1).weights();
-  VLOG(2) << "weight shape: " << weights_rsck.DebugString();
   if (weights_rsck.shape_.nbDims != 4) {
     return tensorflow::errors::Internal(
         "Conv2D expects kernel of dimension 4, at: " + node_def.name());
   }
-  if (params->validation_only) return tensorflow::Status::OK();
-
-  const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
   TFAttrs attrs(node_def);
-
-  int c_index = 1;
-  int h_index = 2;
-  int w_index = 3;
   auto data_format = attrs.get<string>("data_format");
-  if (data_format == "NHWC") {
-    TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
-        const_cast<nvinfer1::ITensor*>(tensor), {0, 3, 1, 2}, &tensor));
-    c_index = 3;
-    h_index = 1;
-    w_index = 2;
-    // TODO(jie): transpose it
-  }
-
-  // tensor after transpose (NCHW)
-  const auto tensor_dim = tensor->getDimensions();
-
-  int num_groups = group;
-  if (num_groups == 0) num_groups = tensor_dim.d[0];  // depthwise convolution
-  VLOG(2) << "groups count: " << num_groups;
-
-  if (params->converter->precision_mode() == FP16MODE) {
-    weights_rsck =
-        ConvertFP32ToFP16(params->weight_store, inputs.at(1).weights());
-  }
-
-  TRT_ShapedWeights weights =
-      params->weight_store->GetTempWeights(weights_rsck);
-  ReorderRSCKToKCRS(weights_rsck, &weights, num_groups);
-  TRT_ShapedWeights biases(weights.type_);
-  const int noutput = weights.shape_.d[0] * num_groups;
-  nvinfer1::DimsHW kernel_size;
-  kernel_size.h() = weights.shape_.d[2];
-  kernel_size.w() = weights.shape_.d[3];
-  VLOG(2) << "RSCK: " << weights.DebugString();
-  VLOG(2) << "kernel size: " << kernel_size.h() << ", " << kernel_size.w();
-
-  // TODO(jie): stride. (NHWC/NCHW)
-  const auto tf_stride = attrs.get<std::vector<int>>("strides");
-  VLOG(2) << "h_INDEX" << h_index << ", w_index " << w_index;
-  VLOG(2) << "stride: " << tf_stride[0] << tf_stride[1] << tf_stride[2]
-          << tf_stride[3];
-  const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
-
+  int c_index = (data_format == "NHWC") ? 3 : 1;
+  int h_index = (data_format == "NHWC") ? 1 : 2;
+  int w_index = (data_format == "NHWC") ? 2 : 3;
   auto tf_dilations = attrs.get<std::vector<int>>("dilations");
-  if ((int)tf_dilations.size() != 4) {
+  if (tf_dilations.size() != 4) {
     return tensorflow::errors::InvalidArgument(
-        "Convolution dilations field must specify 4 dimensions " +
+        "Convolution dilations field must specify 4 dimensions, at ",
         node_def.name());
   }
   if (tf_dilations[0] != 1 || tf_dilations[c_index] != 1) {
@@ -1632,13 +1588,47 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) {
         node_def.name());
   }
   nvinfer1::DimsHW dilation(tf_dilations[h_index], tf_dilations[w_index]);
+  const auto tf_stride = attrs.get<std::vector<int>>("strides");
+  if (tf_stride[0] != 1 || tf_stride[c_index] != 1) {
+    return tensorflow::errors::Unimplemented(
+        "Stride must be 1 for batch and channel dimensions, at ",
+        node_def.name());
+  }
+  const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
+  if (params->validation_only) return tensorflow::Status::OK();
 
+  const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
+
+  // Transpose to NCHW (NCHW is required for IConvLayer).
+  const bool need_transpose = (data_format == "NHWC");
+  if (need_transpose) {
+    TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
+        const_cast<nvinfer1::ITensor*>(tensor), {0, 3, 1, 2}, &tensor));
+  }
+  // Dimensions of transposed tensor.
+  const auto tensor_dim = tensor->getDimensions();
+
+  // This is a depthwise convolution when num_groups is 0. Otherwise, num_groups
+  // will be 1.
+  int num_groups = group;
+  if (num_groups == 0) num_groups = tensor_dim.d[0];
+
+  if (params->converter->precision_mode() == FP16MODE) {
+    weights_rsck =
+        ConvertFP32ToFP16(params->weight_store, inputs.at(1).weights());
+  }
+  TRT_ShapedWeights weights =
+      params->weight_store->GetTempWeights(weights_rsck);
+  ReorderRSCKToKCRS(weights_rsck, &weights, num_groups);
+  TRT_ShapedWeights biases(weights.type_);
+  const int noutput = weights.shape_.d[0] * num_groups;
+  nvinfer1::DimsHW kernel_size;
+  kernel_size.h() = weights.shape_.d[2];
+  kernel_size.w() = weights.shape_.d[3];
+
+  // Add padding.
   std::vector<std::pair<int, int>> padding;
-  // TODO(jie): padding.
   if (attrs.get<string>("padding") == "SAME") {
-    // This is NCHW tensor with no batch dimension.
-    // 1 -> h
-    // 2 -> w
     nvinfer1::DimsHW effective_kernel_size = kernel_size;
     effective_kernel_size.h() += (kernel_size.h() - 1) * (dilation.h() - 1);
     effective_kernel_size.w() += (kernel_size.w() - 1) * (dilation.w() - 1);
@@ -1648,13 +1638,9 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) {
   } else {
     padding = {{0, 0}, {0, 0}};
   }
-
   if (padding[0].first != padding[0].second ||
       padding[1].first != padding[1].second) {
-    // TODO(jie): handle asymmetric padding
-    VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second
-            << padding[1].first << padding[1].second;
-    VLOG(2) << "TENSOR before: " << DebugString(tensor->getDimensions());
+    // Handle asymmetric padding.
     auto pad_layer = params->converter->network()->addPadding(
         *const_cast<nvinfer1::ITensor*>(tensor),
         nvinfer1::DimsHW(padding[0].first, padding[1].first),
@@ -1664,25 +1650,23 @@ tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) {
         const_cast<nvinfer1::ITensor*>(tensor), pad_layer->getOutput(0));
     padding = {{0, 0}, {0, 0}};
     tensor = pad_layer->getOutput(0);
-    VLOG(2) << "TENSOR after: " << DebugString(tensor->getDimensions());
   }
 
+  // Add convolution.
   nvinfer1::IConvolutionLayer* layer =
       params->converter->network()->addConvolution(
           *const_cast<nvinfer1::ITensor*>(tensor), noutput, kernel_size,
          weights.GetTrtWeights(), biases.GetTrtWeights());
   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
 
   layer->setStride(stride);
   layer->setPadding({padding[0].first, padding[1].first});
   layer->setName(node_def.name().c_str());
   layer->setNbGroups(num_groups);
   layer->setDilation(dilation);
   const nvinfer1::ITensor* output_tensor = layer->getOutput(0);
-  VLOG(2) << "TENSOR out: " << DebugString(output_tensor->getDimensions());
-  VLOG(2) << "data_format: " << data_format;
-  if (data_format == "NHWC") {
-    // TODO(jie): transpose it back!
+  // Restore transpose.
+  if (need_transpose) {
     TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
         const_cast<nvinfer1::ITensor*>(output_tensor), {0, 2, 3, 1},
         &output_tensor));
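The substantive change in the converter hunks above is ordering: the dilation and stride checks now run before the validation_only early return, so validation mode can reject unsupported convolutions without building any TensorRT layers, and only a real conversion reaches the transpose, weight-reorder, and layer-building code. A minimal sketch of that validate-then-build pattern, using hypothetical names (Status, Convert) rather than the real TF-TRT types:

#include <iostream>
#include <string>

// Hypothetical stand-in for tensorflow::Status, for illustration only.
struct Status {
  bool ok;
  std::string msg;
};

Status Convert(bool attrs_supported, bool validation_only) {
  // 1. Validate attributes first, so unsupported ops are rejected in both
  //    validation and conversion modes.
  if (!attrs_supported) return {false, "unsupported attribute"};
  // 2. In validation mode, stop before touching the network.
  if (validation_only) return {true, "validated"};
  // 3. Only a real conversion reaches the layer-building code.
  return {true, "layers built"};
}

int main() {
  std::cout << Convert(false, true).msg << "\n";  // unsupported attribute
  std::cout << Convert(true, true).msg << "\n";   // validated
  std::cout << Convert(true, false).msg << "\n";  // layers built
  return 0;
}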
@@ -88,7 +88,7 @@ class Conv2DNCHWTest(trt_test.TfTrtIntegrationTestBase):
 
   def ExpectedEnginesToBuild(self, run_params):
     """Return the expected engines to build."""
-    return ["my_trt_op_0"]
+    return ["TRTEngineOp_0"]
 
 
 class Conv2DStridedNCHWTest(trt_test.TfTrtIntegrationTestBase):
@@ -128,7 +128,7 @@ class Conv2DStridedNCHWTest(trt_test.TfTrtIntegrationTestBase):
 
   def ExpectedEnginesToBuild(self, run_params):
     """Return the expected engines to build."""
-    return ["my_trt_op_0"]
+    return ["TRTEngineOp_0"]
 
 
 class Conv2DNHWCTest(trt_test.TfTrtIntegrationTestBase):
@@ -165,7 +165,7 @@ class Conv2DNHWCTest(trt_test.TfTrtIntegrationTestBase):
 
   def ExpectedEnginesToBuild(self, run_params):
     """Return the expected engines to build."""
-    return ["my_trt_op_0"]
+    return ["TRTEngineOp_0"]
 
 
 if __name__ == "__main__":
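Finally, a self-contained sketch of the data_format index mapping the refactor introduces: TF layout attributes such as "strides" and "dilations" carry one entry per input dimension, and the tidied code derives the channel/height/width positions from data_format with ternaries instead of mutating defaults inside an if block. Plain C++ with hypothetical attribute values, not TensorFlow code:

#include <cassert>
#include <string>
#include <vector>

int main() {
  const std::string data_format = "NHWC";  // or "NCHW"
  // Same mapping as the converter: NHWC -> c/h/w at 3/1/2, NCHW -> 1/2/3.
  const int c_index = (data_format == "NHWC") ? 3 : 1;
  const int h_index = (data_format == "NHWC") ? 1 : 2;
  const int w_index = (data_format == "NHWC") ? 2 : 3;

  // Hypothetical NHWC "strides" attribute: stride 2 in both h and w.
  const std::vector<int> tf_stride = {1, 2, 2, 1};
  // Mirrors the converter's new check: batch and channel strides must be 1.
  assert(tf_stride[0] == 1 && tf_stride[c_index] == 1);
  assert(tf_stride[h_index] == 2 && tf_stride[w_index] == 2);
  return 0;
}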