Support MediaPipe Convolution2DTransposeBias in XNNPACK delegate

PiperOrigin-RevId: 308960507
Change-Id: Ie3bfad63a2ff0934ba35f9f696be5f97ca5b2d9a
This commit is contained in:
Marat Dukhan 2020-04-28 21:56:34 -07:00 committed by TensorFlower Gardener
parent 2bcda95578
commit 97e9e00993
2 changed files with 125 additions and 4 deletions

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <limits>
#include <memory>
#include <string>
@ -397,6 +398,27 @@ class Subgraph {
return kTfLiteOk;
}
static TfLiteStatus CheckTransposedConvolutionParams(
TfLiteContext* context, const TfLiteTransposeConvParams* params,
int node_index) {
if (params->stride_width <= 0) {
if (context != nullptr) {
TF_LITE_KERNEL_LOG(context, "invalid stride width %d in node #%d",
params->stride_width, node_index);
}
return kTfLiteError;
}
if (params->stride_height <= 0) {
if (context != nullptr) {
TF_LITE_KERNEL_LOG(context, "invalid stride height %d in node #%d",
params->stride_height, node_index);
}
return kTfLiteError;
}
return kTfLiteOk;
}
static TfLiteStatus CheckFullyConnectedParams(
TfLiteContext* context, const TfLiteFullyConnectedParams* params,
int node_index) {
@ -672,6 +694,18 @@ class Subgraph {
context->tensors, softmax_params,
xnnpack_tensors);
}
case kTfLiteBuiltinCustom: {
if (strcmp(registration->custom_name, "Convolution2DTransposeBias") ==
0) {
const TfLiteTransposeConvParams* deconv_params =
static_cast<const TfLiteTransposeConvParams*>(
node->custom_initial_data);
return VisitMediaPipeDeconvolutionNode(
subgraph, context, node_index, node, context->tensors,
deconv_params, xnnpack_tensors);
}
return kTfLiteError;
}
default:
return kTfLiteError;
}
@ -1236,6 +1270,93 @@ class Subgraph {
return kTfLiteOk;
}
static TfLiteStatus VisitMediaPipeDeconvolutionNode(
xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
TfLiteNode* node, const TfLiteTensor* tensors,
const TfLiteTransposeConvParams* deconv_params,
const std::vector<uint32_t>& xnnpack_tensors) {
TF_LITE_ENSURE_STATUS(
CheckNumInputsAndOutputs(logging_context, node, 3, 1, node_index));
const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
logging_context, input_tensor, node->inputs->data[0], node_index));
TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 4,
node->inputs->data[0]));
TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
logging_context, input_tensor, node->inputs->data[0], node_index));
const TfLiteTensor& filter_tensor = tensors[node->inputs->data[1]];
TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
logging_context, filter_tensor, node->inputs->data[1], node_index));
TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, filter_tensor, 4,
node->inputs->data[1]));
TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
logging_context, filter_tensor, node->inputs->data[1], node_index));
const TfLiteTensor& bias_tensor = tensors[node->inputs->data[2]];
TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
logging_context, filter_tensor, node->inputs->data[2], node_index));
TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, bias_tensor, 1,
node->inputs->data[2]));
TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
logging_context, bias_tensor, node->inputs->data[2], node_index));
const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
logging_context, output_tensor, node->outputs->data[0], node_index));
TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor, 4,
node->outputs->data[0]));
TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
logging_context, output_tensor, node->outputs->data[0], node_index));
const int output_channels = filter_tensor.dims->data[0];
const int kernel_height = filter_tensor.dims->data[1];
const int kernel_width = filter_tensor.dims->data[2];
const int input_channels = filter_tensor.dims->data[3];
TF_LITE_ENSURE_STATUS(CheckTransposedConvolutionParams(
logging_context, deconv_params, node_index));
uint32_t flags = 0;
TF_LITE_ENSURE_STATUS(CalculatePadding(
logging_context, deconv_params->padding, &flags, node_index));
if (subgraph != nullptr) {
const xnn_status status = xnn_define_deconvolution_2d(
subgraph,
/*padding_top=*/0,
/*padding_right=*/0,
/*padding_bottom=*/0,
/*padding_left=*/0,
/*adjustment_height=*/0,
/*adjustment_width=*/0, static_cast<uint32_t>(kernel_height),
static_cast<uint32_t>(kernel_width),
static_cast<uint32_t>(deconv_params->stride_height),
static_cast<uint32_t>(deconv_params->stride_width),
/*dilation_height=*/1,
/*dilation_width=*/1,
/*groups=*/1,
/*group_input_channels=*/input_channels,
/*group_output_channels=*/output_channels,
/*output_min=*/-std::numeric_limits<float>::infinity(),
/*output_max=*/+std::numeric_limits<float>::infinity(),
/*input_id=*/xnnpack_tensors[node->inputs->data[0]],
/*filter_id=*/xnnpack_tensors[node->inputs->data[1]],
/*bias_id=*/xnnpack_tensors[node->inputs->data[2]],
/*output_id=*/xnnpack_tensors[node->outputs->data[0]], flags);
if (status != xnn_status_success) {
TF_LITE_KERNEL_LOG(
logging_context,
"failed to delegate Convolution2DTransposeBias node #%d",
node_index);
return kTfLiteError;
}
}
return kTfLiteOk;
}
static TfLiteStatus VisitMulNode(
xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
TfLiteNode* node, const TfLiteTensor* tensors,

View File

@ -156,11 +156,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "XNNPACK",
sha256 = "f6eb0f1759eca187d922a72a3a12dfe1593bd09783aa4b67bee70630985eb832",
strip_prefix = "XNNPACK-38c07ec51af0cbacb255922fb6219df80c06df59",
sha256 = "41a0a396a5a9cb2171c1c7f6d7689316beaa6f638663161fc7f86450eba33070",
strip_prefix = "XNNPACK-5871703602c459b98c12be301c01255ae68a45e2",
urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/XNNPACK/archive/38c07ec51af0cbacb255922fb6219df80c06df59.zip",
"https://github.com/google/XNNPACK/archive/38c07ec51af0cbacb255922fb6219df80c06df59.zip",
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/XNNPACK/archive/5871703602c459b98c12be301c01255ae68a45e2.zip",
"https://github.com/google/XNNPACK/archive/5871703602c459b98c12be301c01255ae68a45e2.zip",
],
)