diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc
index 925ef82ddcb..7c061b3a734 100644
--- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc
+++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc
@@ -17,6 +17,7 @@ limitations under the License.
 #include <algorithm>
 #include <cstdint>
+#include <cstring>
 #include <limits>
 #include <memory>
 #include <vector>
@@ -397,6 +398,27 @@ class Subgraph {
     return kTfLiteOk;
   }
 
+  static TfLiteStatus CheckTransposedConvolutionParams(
+      TfLiteContext* context, const TfLiteTransposeConvParams* params,
+      int node_index) {
+    if (params->stride_width <= 0) {
+      if (context != nullptr) {
+        TF_LITE_KERNEL_LOG(context, "invalid stride width %d in node #%d",
+                           params->stride_width, node_index);
+      }
+      return kTfLiteError;
+    }
+    if (params->stride_height <= 0) {
+      if (context != nullptr) {
+        TF_LITE_KERNEL_LOG(context, "invalid stride height %d in node #%d",
+                           params->stride_height, node_index);
+      }
+      return kTfLiteError;
+    }
+
+    return kTfLiteOk;
+  }
+
   static TfLiteStatus CheckFullyConnectedParams(
       TfLiteContext* context, const TfLiteFullyConnectedParams* params,
       int node_index) {
@@ -672,6 +694,18 @@ class Subgraph {
             context->tensors, softmax_params, xnnpack_tensors);
       }
+      case kTfLiteBuiltinCustom: {
+        if (strcmp(registration->custom_name, "Convolution2DTransposeBias") ==
+            0) {
+          const TfLiteTransposeConvParams* deconv_params =
+              static_cast<const TfLiteTransposeConvParams*>(
+                  node->custom_initial_data);
+          return VisitMediaPipeDeconvolutionNode(
+              subgraph, context, node_index, node, context->tensors,
+              deconv_params, xnnpack_tensors);
+        }
+        return kTfLiteError;
+      }
       default:
         return kTfLiteError;
     }
@@ -1236,6 +1270,93 @@ class Subgraph {
     return kTfLiteOk;
   }
 
+  static TfLiteStatus VisitMediaPipeDeconvolutionNode(
+      xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
+      TfLiteNode* node, const TfLiteTensor* tensors,
+      const TfLiteTransposeConvParams* deconv_params,
+      const std::vector<uint32_t>& xnnpack_tensors) {
+    TF_LITE_ENSURE_STATUS(
+        CheckNumInputsAndOutputs(logging_context, node, 3, 1, node_index));
+
+    const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
+    TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
+        logging_context, input_tensor, node->inputs->data[0], node_index));
+    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 4,
+                                           node->inputs->data[0]));
+    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
+        logging_context, input_tensor, node->inputs->data[0], node_index));
+
+    const TfLiteTensor& filter_tensor = tensors[node->inputs->data[1]];
+    TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
+        logging_context, filter_tensor, node->inputs->data[1], node_index));
+    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, filter_tensor, 4,
+                                           node->inputs->data[1]));
+    TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
+        logging_context, filter_tensor, node->inputs->data[1], node_index));
+
+    const TfLiteTensor& bias_tensor = tensors[node->inputs->data[2]];
+    TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
+        logging_context, bias_tensor, node->inputs->data[2], node_index));
+    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, bias_tensor, 1,
+                                           node->inputs->data[2]));
+    TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
+        logging_context, bias_tensor, node->inputs->data[2], node_index));
+
+    const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
+    TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
+        logging_context, output_tensor, node->outputs->data[0], node_index));
+    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor, 4,
+                                           node->outputs->data[0]));
+    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
+        logging_context, output_tensor, node->outputs->data[0], node_index));
+
+    const int output_channels = filter_tensor.dims->data[0];
+    const int kernel_height = filter_tensor.dims->data[1];
+    const int kernel_width = filter_tensor.dims->data[2];
+    const int input_channels = filter_tensor.dims->data[3];
+
+    TF_LITE_ENSURE_STATUS(CheckTransposedConvolutionParams(
+        logging_context, deconv_params, node_index));
+
+    uint32_t flags = 0;
+    TF_LITE_ENSURE_STATUS(CalculatePadding(
+        logging_context, deconv_params->padding, &flags, node_index));
+
+    if (subgraph != nullptr) {
+      const xnn_status status = xnn_define_deconvolution_2d(
+          subgraph,
+          /*padding_top=*/0,
+          /*padding_right=*/0,
+          /*padding_bottom=*/0,
+          /*padding_left=*/0,
+          /*adjustment_height=*/0,
+          /*adjustment_width=*/0, static_cast<uint32_t>(kernel_height),
+          static_cast<uint32_t>(kernel_width),
+          static_cast<uint32_t>(deconv_params->stride_height),
+          static_cast<uint32_t>(deconv_params->stride_width),
+          /*dilation_height=*/1,
+          /*dilation_width=*/1,
+          /*groups=*/1,
+          /*group_input_channels=*/input_channels,
+          /*group_output_channels=*/output_channels,
+          /*output_min=*/-std::numeric_limits<float>::infinity(),
+          /*output_max=*/+std::numeric_limits<float>::infinity(),
+          /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
+          /*filter_id=*/xnnpack_tensors[node->inputs->data[1]],
+          /*bias_id=*/xnnpack_tensors[node->inputs->data[2]],
+          /*output_id=*/xnnpack_tensors[node->outputs->data[0]], flags);
+      if (status != xnn_status_success) {
+        TF_LITE_KERNEL_LOG(
+            logging_context,
+            "failed to delegate Convolution2DTransposeBias node #%d",
+            node_index);
+        return kTfLiteError;
+      }
+    }
+
+    return kTfLiteOk;
+  }
+
   static TfLiteStatus VisitMulNode(
       xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index,
       TfLiteNode* node, const TfLiteTensor* tensors,
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index b613bc569ff..9d7fc8be5e8 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -156,11 +156,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
     tf_http_archive(
         name = "XNNPACK",
-        sha256 = "f6eb0f1759eca187d922a72a3a12dfe1593bd09783aa4b67bee70630985eb832",
-        strip_prefix = "XNNPACK-38c07ec51af0cbacb255922fb6219df80c06df59",
+        sha256 = "41a0a396a5a9cb2171c1c7f6d7689316beaa6f638663161fc7f86450eba33070",
+        strip_prefix = "XNNPACK-5871703602c459b98c12be301c01255ae68a45e2",
         urls = [
-            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/XNNPACK/archive/38c07ec51af0cbacb255922fb6219df80c06df59.zip",
-            "https://github.com/google/XNNPACK/archive/38c07ec51af0cbacb255922fb6219df80c06df59.zip",
+            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/XNNPACK/archive/5871703602c459b98c12be301c01255ae68a45e2.zip",
+            "https://github.com/google/XNNPACK/archive/5871703602c459b98c12be301c01255ae68a45e2.zip",
         ],
     )