Adds support for depth_multiplier case when DepthwiseConv resolves to Conv
PiperOrigin-RevId: 299856795 Change-Id: I53195af9fbffc6215c58feb4ae3cfb654f6a03d2
This commit is contained in:
		
							parent
							
								
									1d27131082
								
							
						
					
					
						commit
						ac26a80b0f
					
				| @ -125,6 +125,9 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, | |||||||
| 
 | 
 | ||||||
|   // Input data tensor.
 |   // Input data tensor.
 | ||||||
|   const auto& data_tensor = context->tensors[inputs->data[0]]; |   const auto& data_tensor = context->tensors[inputs->data[0]]; | ||||||
|  |   int input_batch_size, input_height_size, input_width_size, input_depth_size; | ||||||
|  |   GetDims(&input_batch_size, &input_height_size, &input_width_size, | ||||||
|  |           &input_depth_size, data_tensor.dims); | ||||||
|   TF_LITE_ENSURE_STATUS(ComputeMinAndMaxQuantValues( |   TF_LITE_ENSURE_STATUS(ComputeMinAndMaxQuantValues( | ||||||
|       data_tensor, &data_min_, &data_max_, std::numeric_limits<uint8_t>::min(), |       data_tensor, &data_min_, &data_max_, std::numeric_limits<uint8_t>::min(), | ||||||
|       std::numeric_limits<uint8_t>::max())); |       std::numeric_limits<uint8_t>::max())); | ||||||
| @ -139,6 +142,7 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, | |||||||
|   int stride_height = 0; |   int stride_height = 0; | ||||||
|   int stride_width = 0; |   int stride_width = 0; | ||||||
|   bool is_dilated_depthwise_conv = false; |   bool is_dilated_depthwise_conv = false; | ||||||
|  |   int channel_multiplier = 1; | ||||||
|   if (op_node_.op_type == OP_Supernode_8x8p32to8) { |   if (op_node_.op_type == OP_Supernode_8x8p32to8) { | ||||||
|     const TfLiteConvParams* conv_params = |     const TfLiteConvParams* conv_params = | ||||||
|         reinterpret_cast<const TfLiteConvParams*>(builtin_data_); |         reinterpret_cast<const TfLiteConvParams*>(builtin_data_); | ||||||
| @ -153,6 +157,7 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, | |||||||
|     stride_width = conv_params->stride_width; |     stride_width = conv_params->stride_width; | ||||||
|     padding_type = conv_params->padding; |     padding_type = conv_params->padding; | ||||||
|     activation = conv_params->activation; |     activation = conv_params->activation; | ||||||
|  |     channel_multiplier = conv_params->depth_multiplier; | ||||||
|     // We only support dilation for DepthwiseConv.
 |     // We only support dilation for DepthwiseConv.
 | ||||||
|     if (conv_params->dilation_height_factor > 1 || |     if (conv_params->dilation_height_factor > 1 || | ||||||
|         conv_params->dilation_width_factor > 1) { |         conv_params->dilation_width_factor > 1) { | ||||||
| @ -176,22 +181,41 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, | |||||||
|   // Transpose NHWC -> HWCN
 |   // Transpose NHWC -> HWCN
 | ||||||
|   GetDims(&weights_batch_size, &weights_height_size, &weights_width_size, |   GetDims(&weights_batch_size, &weights_height_size, &weights_width_size, | ||||||
|           &weights_depth_size, weights_tensor.dims); |           &weights_depth_size, weights_tensor.dims); | ||||||
|   weight_shape_ = {weights_height_size, weights_width_size, weights_depth_size, |   OpBuilder* const_weights_node = nullptr; | ||||||
|                    weights_batch_size}; |   if (op_node_.op_type == OP_Supernode_8x8p32to8) { | ||||||
|   RuntimeShape nhwc_shape({weights_batch_size, weights_height_size, |     // Hexagon lib expects the weight tensor in HWCN, TFLite uses NHWC.
 | ||||||
|                            weights_width_size, weights_depth_size}); |     // Transpose NHWC -> HWCN
 | ||||||
|   RuntimeShape hwcn_shape({weights_height_size, weights_width_size, |     weight_shape_ = {weights_height_size, weights_width_size, | ||||||
|                            weights_depth_size, weights_batch_size}); |                      weights_depth_size, weights_batch_size}; | ||||||
|   std::vector<uint8_t> hwcn(NumElements(&weights_tensor)); |     RuntimeShape nhwc_shape({weights_batch_size, weights_height_size, | ||||||
|   TransposeParams transpose_params; |                              weights_width_size, weights_depth_size}); | ||||||
|   transpose_params.perm_count = 4; |     RuntimeShape hwcn_shape({weights_height_size, weights_width_size, | ||||||
|   transpose_params.perm[0] = 1; |                              weights_depth_size, weights_batch_size}); | ||||||
|   transpose_params.perm[1] = 2; |     std::vector<uint8_t> hwcn(NumElements(&weights_tensor)); | ||||||
|   transpose_params.perm[2] = 3; |     TransposeParams transpose_params; | ||||||
|   transpose_params.perm[3] = 0; |     transpose_params.perm_count = 4; | ||||||
|   optimized_ops::Transpose<uint8_t>(transpose_params, nhwc_shape, |     transpose_params.perm[0] = 1; | ||||||
|                                     weights_tensor.data.uint8, hwcn_shape, |     transpose_params.perm[1] = 2; | ||||||
|                                     hwcn.data()); |     transpose_params.perm[2] = 3; | ||||||
|  |     transpose_params.perm[3] = 0; | ||||||
|  |     optimized_ops::Transpose<uint8_t>(transpose_params, nhwc_shape, | ||||||
|  |                                       weights_tensor.data.uint8, hwcn_shape, | ||||||
|  |                                       hwcn.data()); | ||||||
|  |     const_weights_node = graph_builder_->AddConstNodeWithData( | ||||||
|  |         weight_shape_.data(), (char*)hwcn.data(), | ||||||
|  |         hwcn.size() * sizeof(hwcn[0])); | ||||||
|  |   } else if (op_node_.op_type == OP_DepthwiseSupernode_8x8p32to8) { | ||||||
|  |     // Hexagon treats depthwise conv like tf.nn.depthwise_conv2d, where the
 | ||||||
|  |     // expected filter shape is [fh,fw,din,dmul].
 | ||||||
|  |     // The data itself will remain the same, since TFLite's representation is
 | ||||||
|  |     // just a 'flattening' of Hexagon's version.
 | ||||||
|  |     const int channel_multiplier = weights_depth_size / input_depth_size; | ||||||
|  |     weight_shape_ = {weights_height_size, weights_width_size, input_depth_size, | ||||||
|  |                      channel_multiplier}; | ||||||
|  |     const_weights_node = graph_builder_->AddConstNodeWithData( | ||||||
|  |         weight_shape_.data(), weights_tensor.data.raw, | ||||||
|  |         NumElements(&weights_tensor) * sizeof(weights_tensor.data.uint8[0])); | ||||||
|  |   } | ||||||
|   // Quantization params for Weights tensor.
 |   // Quantization params for Weights tensor.
 | ||||||
|   TF_LITE_ENSURE_STATUS( |   TF_LITE_ENSURE_STATUS( | ||||||
|       ComputeMinAndMaxQuantValues(weights_tensor, &weights_min_, &weights_max_, |       ComputeMinAndMaxQuantValues(weights_tensor, &weights_min_, &weights_max_, | ||||||
| @ -201,8 +225,6 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, | |||||||
|       quant_bound_shape.data(), (char*)&weights_min_, sizeof(weights_min_)); |       quant_bound_shape.data(), (char*)&weights_min_, sizeof(weights_min_)); | ||||||
|   auto* weights_max_const = graph_builder_->AddConstNodeWithData( |   auto* weights_max_const = graph_builder_->AddConstNodeWithData( | ||||||
|       quant_bound_shape.data(), (char*)&weights_max_, sizeof(weights_max_)); |       quant_bound_shape.data(), (char*)&weights_max_, sizeof(weights_max_)); | ||||||
|   auto* const_weights_node = graph_builder_->AddConstNodeWithData( |  | ||||||
|       weight_shape_.data(), (char*)hwcn.data(), hwcn.size() * sizeof(hwcn[0])); |  | ||||||
|   graph_builder_->AddTensorWithID(inputs->data[1], const_weights_node->GetID(), |   graph_builder_->AddTensorWithID(inputs->data[1], const_weights_node->GetID(), | ||||||
|                                   0); |                                   0); | ||||||
| 
 | 
 | ||||||
| @ -256,6 +278,18 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, | |||||||
|   auto* bias_max_const = graph_builder_->AddConstNodeWithData( |   auto* bias_max_const = graph_builder_->AddConstNodeWithData( | ||||||
|       quant_bound_shape.data(), (char*)&bias_max_, sizeof(bias_max_)); |       quant_bound_shape.data(), (char*)&bias_max_, sizeof(bias_max_)); | ||||||
| 
 | 
 | ||||||
|  |   // TODO(b/143759564): Simplify this method when depth_multiplier support needs
 | ||||||
|  |   // generalizing.
 | ||||||
|  |   if (channel_multiplier > 1 && input_depth_size == 1) { | ||||||
|  |     // Depthwise Conv with input_depth == 1 & channel_multiplier > 1 is
 | ||||||
|  |     // equivalent to Conv.
 | ||||||
|  |     SetOpType(OP_Supernode_8x8p32to8); | ||||||
|  |   } else if (channel_multiplier > 1) { | ||||||
|  |     TF_LITE_KERNEL_LOG( | ||||||
|  |         context, "depth_multiplier > 1 not supported with input_depth > 1"); | ||||||
|  |     return kTfLiteError; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   TensorID output, output_min, output_max; |   TensorID output, output_min, output_max; | ||||||
|   if (is_dilated_depthwise_conv) { |   if (is_dilated_depthwise_conv) { | ||||||
|     // For dilated Depthwise Conv, we convert this node into SpaceToBatchND, and
 |     // For dilated Depthwise Conv, we convert this node into SpaceToBatchND, and
 | ||||||
|  | |||||||
| @ -256,4 +256,86 @@ TEST(QuantizedConvolutionOpModel, SimpleConvTestReLU6Activation) { | |||||||
|                                             1e-5))); |                                             1e-5))); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // Depthwise Conv with multiplier > 1 but input depth==1 should resolve into a
 | ||||||
|  | // Conv op.
 | ||||||
|  | TEST(QuantizedConvolutionOpModel, DepthwiseConvWithMultiplier_InputDepth1) { | ||||||
|  |   QuantizedConvolutionOpModel m(BuiltinOperator_DEPTHWISE_CONV_2D, | ||||||
|  |                                 {TensorType_UINT8, {1, 6, 6, 1}, -63.5, 64}, | ||||||
|  |                                 {TensorType_UINT8, {1, 5, 5, 3}, -63.5, 64}, | ||||||
|  |                                 {TensorType_UINT8, {}, -127, 128}, | ||||||
|  |                                 Padding_VALID); | ||||||
|  |   // clang-format off
 | ||||||
|  |   m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0, | ||||||
|  |               0, 0, 0, 1, 1, 1, 0, 0, 0, | ||||||
|  |               0, 0, 0, 1, 1, 1, 0, 0, 0, | ||||||
|  |               0, 0, 0, 0, 0, 0, 0, 0, 0}); | ||||||
|  |   m.SetFilter({1, 2, 3, 4, 5, | ||||||
|  |                1, 2, 3, 4, 5, | ||||||
|  |                1, 2, 3, 4, 5, | ||||||
|  |                1, 2, 3, 4, 5, | ||||||
|  |                1, 2, 3, 4, 5, | ||||||
|  |                6, 7, 8, 9, 10, | ||||||
|  |                6, 7, 8, 9, 10, | ||||||
|  |                6, 7, 8, 9, 10, | ||||||
|  |                6, 7, 8, 9, 10, | ||||||
|  |                6, 7, 8, 9, 10, | ||||||
|  |                1, 2, 3, 4, 5, | ||||||
|  |                1, 2, 3, 4, 5, | ||||||
|  |                1, 2, 3, 4, 5, | ||||||
|  |                1, 2, 3, 4, 5, | ||||||
|  |                1, 2, 3, 4, 5}); | ||||||
|  |   // clang-format on
 | ||||||
|  |   m.SetBias({1, 2, 3}); | ||||||
|  | 
 | ||||||
|  |   // Reference output.
 | ||||||
|  |   m.Invoke(); | ||||||
|  |   auto reference_output = m.GetDequantizedOutput(); | ||||||
|  | 
 | ||||||
|  |   m.ApplyDelegateAndInvoke(); | ||||||
|  |   EXPECT_THAT(m.GetDequantizedOutput(), | ||||||
|  |               ElementsAreArray(ArrayFloatNear(reference_output, 1e-5))); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Depthwise Conv with multiplier > 1 but input depth==1 should resolve into a
 | ||||||
|  | // Conv op.
 | ||||||
|  | TEST(QuantizedConvolutionOpModel, | ||||||
|  |      DepthwiseConvWithMultiplier_InputDepth1_RELU) { | ||||||
|  |   QuantizedConvolutionOpModel m(BuiltinOperator_DEPTHWISE_CONV_2D, | ||||||
|  |                                 {TensorType_UINT8, {1, 6, 6, 1}, -63.5, 64}, | ||||||
|  |                                 {TensorType_UINT8, {1, 5, 5, 3}, -63.5, 64}, | ||||||
|  |                                 {TensorType_UINT8, {}, -127, 128}, | ||||||
|  |                                 Padding_VALID, /**dilation_factor**/ 1, | ||||||
|  |                                 /**stride**/ 2, ActivationFunctionType_RELU6); | ||||||
|  |   // clang-format off
 | ||||||
|  |   m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0, | ||||||
|  |               0, 0, 0, 1, 1, 1, 0, 0, 0, | ||||||
|  |               0, 0, 0, 1, 1, 1, 0, 0, 0, | ||||||
|  |               0, 0, 0, 0, 0, 0, 0, 0, 0}); | ||||||
|  |   m.SetFilter({1, 2, 3, 4, 5, | ||||||
|  |                1, 2, 3, 4, 5, | ||||||
|  |                1, 2, 3, 4, 5, | ||||||
|  |                1, 2, 3, 4, 5, | ||||||
|  |                1, 2, 3, 4, 5, | ||||||
|  |                6, 7, 8, 9, 10, | ||||||
|  |                6, 7, 8, 9, 10, | ||||||
|  |                6, 7, 8, 9, 10, | ||||||
|  |                6, 7, 8, 9, 10, | ||||||
|  |                6, 7, 8, 9, 10, | ||||||
|  |                1, 2, 3, 4, 5, | ||||||
|  |                1, 2, 3, 4, 5, | ||||||
|  |                1, 2, 3, 4, 5, | ||||||
|  |                1, 2, 3, 4, 5, | ||||||
|  |                1, 2, 3, 4, 5}); | ||||||
|  |   // clang-format on
 | ||||||
|  |   m.SetBias({1, 2, 3}); | ||||||
|  | 
 | ||||||
|  |   // Reference output.
 | ||||||
|  |   m.Invoke(); | ||||||
|  |   auto reference_output = m.GetDequantizedOutput(); | ||||||
|  | 
 | ||||||
|  |   m.ApplyDelegateAndInvoke(); | ||||||
|  |   EXPECT_THAT(m.GetDequantizedOutput(), | ||||||
|  |               ElementsAreArray(ArrayFloatNear(reference_output, 1e-5))); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| }  // namespace tflite
 | }  // namespace tflite
 | ||||||
|  | |||||||
| @ -188,6 +188,8 @@ bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration, | |||||||
|       if (!InputsWithCorrectTypes(node, context, |       if (!InputsWithCorrectTypes(node, context, | ||||||
|                                   {kTfLiteUInt8, kTfLiteUInt8, kTfLiteInt32})) |                                   {kTfLiteUInt8, kTfLiteUInt8, kTfLiteInt32})) | ||||||
|         return false; |         return false; | ||||||
|  | 
 | ||||||
|  |       // Check dilation.
 | ||||||
|       const TfLiteDepthwiseConvParams* conv_params = |       const TfLiteDepthwiseConvParams* conv_params = | ||||||
|           reinterpret_cast<const TfLiteDepthwiseConvParams*>( |           reinterpret_cast<const TfLiteDepthwiseConvParams*>( | ||||||
|               node->builtin_data); |               node->builtin_data); | ||||||
| @ -198,10 +200,19 @@ bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration, | |||||||
|         if (conv_params->stride_height != 1 || conv_params->stride_width != 1) |         if (conv_params->stride_height != 1 || conv_params->stride_width != 1) | ||||||
|           return false; |           return false; | ||||||
|       } |       } | ||||||
|  | 
 | ||||||
|  |       // We currently only support depth_multiplier > 1 when:
 | ||||||
|  |       // 1. dilation_factor == 1 AND
 | ||||||
|  |       // 2. input_depth == 1
 | ||||||
|  |       // TODO(b/143759564): Add support for general case.
 | ||||||
|  |       const auto& input = context->tensors[node->inputs->data[0]]; | ||||||
|  |       const bool supported_depth_multiplier = | ||||||
|  |           conv_params->depth_multiplier == 1 || | ||||||
|  |           (!dilation && input.dims->size == 4 && input.dims->data[3] == 1); | ||||||
|  | 
 | ||||||
|       return (IsActivationReluOrNone(conv_params->activation) && |       return (IsActivationReluOrNone(conv_params->activation) && | ||||||
|               conv_params->stride_height <= 3 && |               conv_params->stride_height <= 3 && | ||||||
|               conv_params->stride_width <= 3 && |               conv_params->stride_width <= 3 && supported_depth_multiplier); | ||||||
|               conv_params->depth_multiplier == 1); |  | ||||||
|     } |     } | ||||||
|     case kTfLiteBuiltinReshape: { |     case kTfLiteBuiltinReshape: { | ||||||
|       if (node->inputs->size > 2 || |       if (node->inputs->size > 2 || | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user