Support MediaPipe Convolution2DTransposeBias in XNNPACK delegate
PiperOrigin-RevId: 308960507 Change-Id: Ie3bfad63a2ff0934ba35f9f696be5f97ca5b2d9a
This commit is contained in:
parent
2bcda95578
commit
97e9e00993
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
@ -397,6 +398,27 @@ class Subgraph {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
static TfLiteStatus CheckTransposedConvolutionParams(
|
||||
TfLiteContext* context, const TfLiteTransposeConvParams* params,
|
||||
int node_index) {
|
||||
if (params->stride_width <= 0) {
|
||||
if (context != nullptr) {
|
||||
TF_LITE_KERNEL_LOG(context, "invalid stride width %d in node #%d",
|
||||
params->stride_width, node_index);
|
||||
}
|
||||
return kTfLiteError;
|
||||
}
|
||||
if (params->stride_height <= 0) {
|
||||
if (context != nullptr) {
|
||||
TF_LITE_KERNEL_LOG(context, "invalid stride height %d in node #%d",
|
||||
params->stride_height, node_index);
|
||||
}
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
static TfLiteStatus CheckFullyConnectedParams(
|
||||
TfLiteContext* context, const TfLiteFullyConnectedParams* params,
|
||||
int node_index) {
|
||||
@ -672,6 +694,18 @@ class Subgraph {
|
||||
context->tensors, softmax_params,
|
||||
xnnpack_tensors);
|
||||
}
|
||||
case kTfLiteBuiltinCustom: {
|
||||
if (strcmp(registration->custom_name, "Convolution2DTransposeBias") ==
|
||||
0) {
|
||||
const TfLiteTransposeConvParams* deconv_params =
|
||||
static_cast<const TfLiteTransposeConvParams*>(
|
||||
node->custom_initial_data);
|
||||
return VisitMediaPipeDeconvolutionNode(
|
||||
subgraph, context, node_index, node, context->tensors,
|
||||
deconv_params, xnnpack_tensors);
|
||||
}
|
||||
return kTfLiteError;
|
||||
}
|
||||
default:
|
||||
return kTfLiteError;
|
||||
}
|
||||
@ -1236,6 +1270,93 @@ class Subgraph {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
static TfLiteStatus VisitMediaPipeDeconvolutionNode(
|
||||
xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
|
||||
TfLiteNode* node, const TfLiteTensor* tensors,
|
||||
const TfLiteTransposeConvParams* deconv_params,
|
||||
const std::vector<uint32_t>& xnnpack_tensors) {
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
CheckNumInputsAndOutputs(logging_context, node, 3, 1, node_index));
|
||||
|
||||
const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
|
||||
TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
|
||||
logging_context, input_tensor, node->inputs->data[0], node_index));
|
||||
TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 4,
|
||||
node->inputs->data[0]));
|
||||
TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
|
||||
logging_context, input_tensor, node->inputs->data[0], node_index));
|
||||
|
||||
const TfLiteTensor& filter_tensor = tensors[node->inputs->data[1]];
|
||||
TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
|
||||
logging_context, filter_tensor, node->inputs->data[1], node_index));
|
||||
TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, filter_tensor, 4,
|
||||
node->inputs->data[1]));
|
||||
TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
|
||||
logging_context, filter_tensor, node->inputs->data[1], node_index));
|
||||
|
||||
const TfLiteTensor& bias_tensor = tensors[node->inputs->data[2]];
|
||||
TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
|
||||
logging_context, filter_tensor, node->inputs->data[2], node_index));
|
||||
TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, bias_tensor, 1,
|
||||
node->inputs->data[2]));
|
||||
TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
|
||||
logging_context, bias_tensor, node->inputs->data[2], node_index));
|
||||
|
||||
const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
|
||||
TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
|
||||
logging_context, output_tensor, node->outputs->data[0], node_index));
|
||||
TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor, 4,
|
||||
node->outputs->data[0]));
|
||||
TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
|
||||
logging_context, output_tensor, node->outputs->data[0], node_index));
|
||||
|
||||
const int output_channels = filter_tensor.dims->data[0];
|
||||
const int kernel_height = filter_tensor.dims->data[1];
|
||||
const int kernel_width = filter_tensor.dims->data[2];
|
||||
const int input_channels = filter_tensor.dims->data[3];
|
||||
|
||||
TF_LITE_ENSURE_STATUS(CheckTransposedConvolutionParams(
|
||||
logging_context, deconv_params, node_index));
|
||||
|
||||
uint32_t flags = 0;
|
||||
TF_LITE_ENSURE_STATUS(CalculatePadding(
|
||||
logging_context, deconv_params->padding, &flags, node_index));
|
||||
|
||||
if (subgraph != nullptr) {
|
||||
const xnn_status status = xnn_define_deconvolution_2d(
|
||||
subgraph,
|
||||
/*padding_top=*/0,
|
||||
/*padding_right=*/0,
|
||||
/*padding_bottom=*/0,
|
||||
/*padding_left=*/0,
|
||||
/*adjustment_height=*/0,
|
||||
/*adjustment_width=*/0, static_cast<uint32_t>(kernel_height),
|
||||
static_cast<uint32_t>(kernel_width),
|
||||
static_cast<uint32_t>(deconv_params->stride_height),
|
||||
static_cast<uint32_t>(deconv_params->stride_width),
|
||||
/*dilation_height=*/1,
|
||||
/*dilation_width=*/1,
|
||||
/*groups=*/1,
|
||||
/*group_input_channels=*/input_channels,
|
||||
/*group_output_channels=*/output_channels,
|
||||
/*output_min=*/-std::numeric_limits<float>::infinity(),
|
||||
/*output_max=*/+std::numeric_limits<float>::infinity(),
|
||||
/*input_id=*/xnnpack_tensors[node->inputs->data[0]],
|
||||
/*filter_id=*/xnnpack_tensors[node->inputs->data[1]],
|
||||
/*bias_id=*/xnnpack_tensors[node->inputs->data[2]],
|
||||
/*output_id=*/xnnpack_tensors[node->outputs->data[0]], flags);
|
||||
if (status != xnn_status_success) {
|
||||
TF_LITE_KERNEL_LOG(
|
||||
logging_context,
|
||||
"failed to delegate Convolution2DTransposeBias node #%d",
|
||||
node_index);
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
static TfLiteStatus VisitMulNode(
|
||||
xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
|
||||
TfLiteNode* node, const TfLiteTensor* tensors,
|
||||
|
||||
@ -156,11 +156,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
|
||||
|
||||
tf_http_archive(
|
||||
name = "XNNPACK",
|
||||
sha256 = "f6eb0f1759eca187d922a72a3a12dfe1593bd09783aa4b67bee70630985eb832",
|
||||
strip_prefix = "XNNPACK-38c07ec51af0cbacb255922fb6219df80c06df59",
|
||||
sha256 = "41a0a396a5a9cb2171c1c7f6d7689316beaa6f638663161fc7f86450eba33070",
|
||||
strip_prefix = "XNNPACK-5871703602c459b98c12be301c01255ae68a45e2",
|
||||
urls = [
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/XNNPACK/archive/38c07ec51af0cbacb255922fb6219df80c06df59.zip",
|
||||
"https://github.com/google/XNNPACK/archive/38c07ec51af0cbacb255922fb6219df80c06df59.zip",
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/XNNPACK/archive/5871703602c459b98c12be301c01255ae68a45e2.zip",
|
||||
"https://github.com/google/XNNPACK/archive/5871703602c459b98c12be301c01255ae68a45e2.zip",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user