From ac26a80b0f004067b24e7a4a021baab2ba17a188 Mon Sep 17 00:00:00 2001
From: Sachin Joglekar
Date: Mon, 9 Mar 2020 09:15:46 -0700
Subject: [PATCH] Adds support for depth_multiplier case when DepthwiseConv
 resolves to Conv

PiperOrigin-RevId: 299856795
Change-Id: I53195af9fbffc6215c58feb4ae3cfb654f6a03d2
---
 .../hexagon/builders/conv_2d_builder.cc      | 70 ++++++++++++----
 .../hexagon/builders/tests/conv_test.cc      | 82 +++++++++++++++++++
 .../experimental/delegates/hexagon/utils.cc  | 15 +++-
 3 files changed, 147 insertions(+), 20 deletions(-)

diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/conv_2d_builder.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/conv_2d_builder.cc
index 809d6e7d7dc..85957706d57 100644
--- a/tensorflow/lite/experimental/delegates/hexagon/builders/conv_2d_builder.cc
+++ b/tensorflow/lite/experimental/delegates/hexagon/builders/conv_2d_builder.cc
@@ -125,6 +125,9 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
 
   // Input data tensor.
   const auto& data_tensor = context->tensors[inputs->data[0]];
+  int input_batch_size, input_height_size, input_width_size, input_depth_size;
+  GetDims(&input_batch_size, &input_height_size, &input_width_size,
+          &input_depth_size, data_tensor.dims);
   TF_LITE_ENSURE_STATUS(ComputeMinAndMaxQuantValues(
       data_tensor, &data_min_, &data_max_, std::numeric_limits<uint8_t>::min(),
       std::numeric_limits<uint8_t>::max()));
@@ -139,6 +142,7 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
   int stride_height = 0;
   int stride_width = 0;
   bool is_dilated_depthwise_conv = false;
+  int channel_multiplier = 1;
   if (op_node_.op_type == OP_Supernode_8x8p32to8) {
     const TfLiteConvParams* conv_params =
         reinterpret_cast<const TfLiteConvParams*>(builtin_data_);
@@ -153,6 +157,7 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
     stride_width = conv_params->stride_width;
     padding_type = conv_params->padding;
     activation = conv_params->activation;
+    channel_multiplier = conv_params->depth_multiplier;
     // We only support dilation for DepthwiseConv.
     if (conv_params->dilation_height_factor > 1 ||
         conv_params->dilation_width_factor > 1) {
@@ -176,22 +181,41 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
   // Transpose NHWC -> HWCN
   GetDims(&weights_batch_size, &weights_height_size, &weights_width_size,
           &weights_depth_size, weights_tensor.dims);
-  weight_shape_ = {weights_height_size, weights_width_size, weights_depth_size,
-                   weights_batch_size};
-  RuntimeShape nhwc_shape({weights_batch_size, weights_height_size,
-                           weights_width_size, weights_depth_size});
-  RuntimeShape hwcn_shape({weights_height_size, weights_width_size,
-                           weights_depth_size, weights_batch_size});
-  std::vector<uint8_t> hwcn(NumElements(&weights_tensor));
-  TransposeParams transpose_params;
-  transpose_params.perm_count = 4;
-  transpose_params.perm[0] = 1;
-  transpose_params.perm[1] = 2;
-  transpose_params.perm[2] = 3;
-  transpose_params.perm[3] = 0;
-  optimized_ops::Transpose<uint8_t>(transpose_params, nhwc_shape,
-                                    weights_tensor.data.uint8, hwcn_shape,
-                                    hwcn.data());
+  OpBuilder* const_weights_node = nullptr;
+  if (op_node_.op_type == OP_Supernode_8x8p32to8) {
+    // Hexagon lib expects the weight tensor in HWCN, TFLite uses NHWC.
+    // Transpose NHWC -> HWCN
+    weight_shape_ = {weights_height_size, weights_width_size,
+                     weights_depth_size, weights_batch_size};
+    RuntimeShape nhwc_shape({weights_batch_size, weights_height_size,
+                             weights_width_size, weights_depth_size});
+    RuntimeShape hwcn_shape({weights_height_size, weights_width_size,
+                             weights_depth_size, weights_batch_size});
+    std::vector<uint8_t> hwcn(NumElements(&weights_tensor));
+    TransposeParams transpose_params;
+    transpose_params.perm_count = 4;
+    transpose_params.perm[0] = 1;
+    transpose_params.perm[1] = 2;
+    transpose_params.perm[2] = 3;
+    transpose_params.perm[3] = 0;
+    optimized_ops::Transpose<uint8_t>(transpose_params, nhwc_shape,
+                                      weights_tensor.data.uint8, hwcn_shape,
+                                      hwcn.data());
+    const_weights_node = graph_builder_->AddConstNodeWithData(
+        weight_shape_.data(), (char*)hwcn.data(),
+        hwcn.size() * sizeof(hwcn[0]));
+  } else if (op_node_.op_type == OP_DepthwiseSupernode_8x8p32to8) {
+    // Hexagon treats depthwise conv like tf.nn.depthwise_conv2d, where the
+    // expected filter shape is [fh,fw,din,dmul].
+    // The data itself will remain the same, since TFLite's representation is
+    // just a 'flattening' of Hexagon's version.
+    const int channel_multiplier = weights_depth_size / input_depth_size;
+    weight_shape_ = {weights_height_size, weights_width_size, input_depth_size,
+                     channel_multiplier};
+    const_weights_node = graph_builder_->AddConstNodeWithData(
+        weight_shape_.data(), weights_tensor.data.raw,
+        NumElements(&weights_tensor) * sizeof(weights_tensor.data.uint8[0]));
+  }
   // Quantization params for Weights tensor.
   TF_LITE_ENSURE_STATUS(
       ComputeMinAndMaxQuantValues(weights_tensor, &weights_min_, &weights_max_,
@@ -201,8 +225,6 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
       quant_bound_shape.data(), (char*)&weights_min_, sizeof(weights_min_));
   auto* weights_max_const = graph_builder_->AddConstNodeWithData(
       quant_bound_shape.data(), (char*)&weights_max_, sizeof(weights_max_));
-  auto* const_weights_node = graph_builder_->AddConstNodeWithData(
-      weight_shape_.data(), (char*)hwcn.data(), hwcn.size() * sizeof(hwcn[0]));
   graph_builder_->AddTensorWithID(inputs->data[1], const_weights_node->GetID(),
                                   0);
 
@@ -256,6 +278,18 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
 
   auto* bias_max_const = graph_builder_->AddConstNodeWithData(
       quant_bound_shape.data(), (char*)&bias_max_, sizeof(bias_max_));
+  // TODO(b/143759564): Simplify this method when generalizing
+  // depth_multiplier support.
+  if (channel_multiplier > 1 && input_depth_size == 1) {
+    // Depthwise Conv with input_depth == 1 & channel_multiplier > 1 is
+    // equivalent to Conv.
+    SetOpType(OP_Supernode_8x8p32to8);
+  } else if (channel_multiplier > 1) {
+    TF_LITE_KERNEL_LOG(
+        context, "depth_multiplier > 1 not supported with input_depth > 1");
+    return kTfLiteError;
+  }
+
   TensorID output, output_min, output_max;
   if (is_dilated_depthwise_conv) {
     // For dilated Depthwise Conv, we convert this node into SpaceToBatchND, and
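[Editor's note, not part of the patch] The depthwise branch above relies on the
claim that TFLite's filter layout is "just a 'flattening' of Hexagon's version".
A minimal standalone sketch of why that holds, assuming row-major layouts: TFLite
stores a depthwise filter as [1, fh, fw, din * dmul], while Hexagon expects
[fh, fw, din, dmul]; every element maps to the same flat offset in both views,
so only the shape metadata changes and the raw buffer can be handed to
AddConstNodeWithData untouched. The helper names below are hypothetical.

#include <cassert>

// Flat offset into TFLite's depthwise filter [1, fh, fw, channels], where
// channels = din * dmul and TFLite's output channel index is c = i * dmul + m.
int TfLiteOffset(int h, int w, int c, int fw, int channels) {
  return (h * fw + w) * channels + c;
}

// Flat offset into Hexagon's expected filter layout [fh, fw, din, dmul].
int HexagonOffset(int h, int w, int i, int m, int fw, int din, int dmul) {
  return ((h * fw + w) * din + i) * dmul + m;
}

int main() {
  // Shapes matching the tests below: 5x5 filter, input depth 1, multiplier 3.
  const int fh = 5, fw = 5, din = 1, dmul = 3;
  for (int h = 0; h < fh; ++h)
    for (int w = 0; w < fw; ++w)
      for (int i = 0; i < din; ++i)
        for (int m = 0; m < dmul; ++m)
          assert(TfLiteOffset(h, w, i * dmul + m, fw, din * dmul) ==
                 HexagonOffset(h, w, i, m, fw, din, dmul));
  return 0;
}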
diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/conv_test.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/conv_test.cc
index f1e1686b27a..ba4b57001fb 100644
--- a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/conv_test.cc
+++ b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/conv_test.cc
@@ -256,4 +256,86 @@ TEST(QuantizedConvolutionOpModel, SimpleConvTestReLU6Activation) {
                       1e-5)));
 }
 
+// Depthwise Conv with multiplier > 1 but input depth==1 should resolve into a
+// Conv op.
+TEST(QuantizedConvolutionOpModel, DepthwiseConvWithMultiplier_InputDepth1) {
+  QuantizedConvolutionOpModel m(BuiltinOperator_DEPTHWISE_CONV_2D,
+                                {TensorType_UINT8, {1, 6, 6, 1}, -63.5, 64},
+                                {TensorType_UINT8, {1, 5, 5, 3}, -63.5, 64},
+                                {TensorType_UINT8, {}, -127, 128},
+                                Padding_VALID);
+  // clang-format off
+  m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0,
+              0, 0, 0, 1, 1, 1, 0, 0, 0,
+              0, 0, 0, 1, 1, 1, 0, 0, 0,
+              0, 0, 0, 0, 0, 0, 0, 0, 0});
+  m.SetFilter({1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5,
+               6, 7, 8, 9, 10,
+               6, 7, 8, 9, 10,
+               6, 7, 8, 9, 10,
+               6, 7, 8, 9, 10,
+               6, 7, 8, 9, 10,
+               1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5});
+  // clang-format on
+  m.SetBias({1, 2, 3});
+
+  // Reference output.
+  m.Invoke();
+  auto reference_output = m.GetDequantizedOutput<uint8_t>();
+
+  m.ApplyDelegateAndInvoke();
+  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
+              ElementsAreArray(ArrayFloatNear(reference_output, 1e-5)));
+}
+
+// Depthwise Conv with multiplier > 1 but input depth==1 should resolve into a
+// Conv op.
+TEST(QuantizedConvolutionOpModel,
+     DepthwiseConvWithMultiplier_InputDepth1_RELU) {
+  QuantizedConvolutionOpModel m(BuiltinOperator_DEPTHWISE_CONV_2D,
+                                {TensorType_UINT8, {1, 6, 6, 1}, -63.5, 64},
+                                {TensorType_UINT8, {1, 5, 5, 3}, -63.5, 64},
+                                {TensorType_UINT8, {}, -127, 128},
+                                Padding_VALID, /**dilation_factor**/ 1,
+                                /**stride**/ 2, ActivationFunctionType_RELU6);
+  // clang-format off
+  m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0,
+              0, 0, 0, 1, 1, 1, 0, 0, 0,
+              0, 0, 0, 1, 1, 1, 0, 0, 0,
+              0, 0, 0, 0, 0, 0, 0, 0, 0});
+  m.SetFilter({1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5,
+               6, 7, 8, 9, 10,
+               6, 7, 8, 9, 10,
+               6, 7, 8, 9, 10,
+               6, 7, 8, 9, 10,
+               6, 7, 8, 9, 10,
+               1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5});
+  // clang-format on
+  m.SetBias({1, 2, 3});
+
+  // Reference output.
+  m.Invoke();
+  auto reference_output = m.GetDequantizedOutput<uint8_t>();
+
+  m.ApplyDelegateAndInvoke();
+  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
+              ElementsAreArray(ArrayFloatNear(reference_output, 1e-5)));
+}
+
 }  // namespace tflite
diff --git a/tensorflow/lite/experimental/delegates/hexagon/utils.cc b/tensorflow/lite/experimental/delegates/hexagon/utils.cc
index 55736373da3..feff2080eaa 100644
--- a/tensorflow/lite/experimental/delegates/hexagon/utils.cc
+++ b/tensorflow/lite/experimental/delegates/hexagon/utils.cc
@@ -188,6 +188,8 @@ bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration,
       if (!InputsWithCorrectTypes(node, context,
                                   {kTfLiteUInt8, kTfLiteUInt8, kTfLiteInt32}))
         return false;
+
+      // Check dilation.
       const TfLiteDepthwiseConvParams* conv_params =
           reinterpret_cast<const TfLiteDepthwiseConvParams*>(
               node->builtin_data);
@@ -198,10 +200,19 @@ bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration,
         if (conv_params->stride_height != 1 || conv_params->stride_width != 1)
           return false;
       }
+
+      // We currently only support depth_multiplier > 1 when:
+      // 1. dilation_factor == 1 AND
+      // 2. input_depth == 1
+      // TODO(b/143759564): Add support for general case.
+      const auto& input = context->tensors[node->inputs->data[0]];
+      const bool supported_depth_multiplier =
+          conv_params->depth_multiplier == 1 ||
+          (!dilation && input.dims->size == 4 && input.dims->data[3] == 1);
+
       return (IsActivationReluOrNone(conv_params->activation) &&
               conv_params->stride_height <= 3 &&
-              conv_params->stride_width <= 3 &&
-              conv_params->depth_multiplier == 1);
+              conv_params->stride_width <= 3 && supported_depth_multiplier);
     }
     case kTfLiteBuiltinReshape: {
       if (node->inputs->size > 2 ||
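[Editor's note, not part of the patch] The utils.cc check above encodes the
identity that motivates this change: with input_depth == 1, a depthwise conv
with depth_multiplier == k computes
    out[y, x, m] = sum_{dy, dx} in[y + dy, x + dx, 0] * w[dy, dx, 0, m] + bias[m]
which is exactly a regular conv with k filters of depth 1, hence the
SetOpType(OP_Supernode_8x8p32to8) rewrite in conv_2d_builder.cc. Below is a
minimal sketch of the gating predicate restated standalone; the function and
parameter names are hypothetical, but the logic mirrors the fields the delegate
reads from TfLiteDepthwiseConvParams and the input tensor's dims.

#include <cassert>

// depth_multiplier == 1 is always accepted; depth_multiplier > 1 only when the
// op is not dilated and the rank-4 (NHWC) input has depth 1.
bool SupportedDepthMultiplier(int depth_multiplier, bool dilated,
                              int input_rank, int input_depth) {
  return depth_multiplier == 1 ||
         (!dilated && input_rank == 4 && input_depth == 1);
}

int main() {
  assert(SupportedDepthMultiplier(1, /*dilated=*/true, 4, 16));   // always ok
  assert(SupportedDepthMultiplier(3, /*dilated=*/false, 4, 1));   // the new case
  assert(!SupportedDepthMultiplier(3, /*dilated=*/false, 4, 8));  // rejected
  assert(!SupportedDepthMultiplier(3, /*dilated=*/true, 4, 1));   // rejected
  return 0;
}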