Adds support for the depth_multiplier case when DepthwiseConv resolves to Conv

PiperOrigin-RevId: 299856795
Change-Id: I53195af9fbffc6215c58feb4ae3cfb654f6a03d2

parent 1d27131082
commit ac26a80b0f
@@ -125,6 +125,9 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
   // Input data tensor.
   const auto& data_tensor = context->tensors[inputs->data[0]];
+  int input_batch_size, input_height_size, input_width_size, input_depth_size;
+  GetDims(&input_batch_size, &input_height_size, &input_width_size,
+          &input_depth_size, data_tensor.dims);
   TF_LITE_ENSURE_STATUS(ComputeMinAndMaxQuantValues(
       data_tensor, &data_min_, &data_max_, std::numeric_limits<uint8_t>::min(),
       std::numeric_limits<uint8_t>::max()));
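The new lines above call GetDims, a helper from the delegate's builder utilities whose definition is not part of this diff. A minimal stand-in sketch, assuming the helper right-aligns a tensor's dims into NHWC order and pads missing leading dims with 1 (the name GetDimsSketch and this exact behavior are assumptions for illustration, not the delegate's actual code):

#include "tensorflow/lite/c/common.h"

// Hypothetical stand-in for the delegate's GetDims helper. Assumption: the
// real helper normalizes up-to-4D dims into batch/height/width/depth,
// padding missing leading dims with 1.
void GetDimsSketch(int* batch, int* height, int* width, int* depth,
                   const TfLiteIntArray* dims) {
  int d[4] = {1, 1, 1, 1};
  for (int i = 0; i < dims->size && i < 4; ++i) {
    d[4 - dims->size + i] = dims->data[i];
  }
  *batch = d[0];
  *height = d[1];
  *width = d[2];
  *depth = d[3];
}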
@@ -139,6 +142,7 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
   int stride_height = 0;
   int stride_width = 0;
   bool is_dilated_depthwise_conv = false;
+  int channel_multiplier = 1;
   if (op_node_.op_type == OP_Supernode_8x8p32to8) {
     const TfLiteConvParams* conv_params =
         reinterpret_cast<const TfLiteConvParams*>(builtin_data_);
@@ -153,6 +157,7 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
     stride_width = conv_params->stride_width;
     padding_type = conv_params->padding;
     activation = conv_params->activation;
+    channel_multiplier = conv_params->depth_multiplier;
     // We only support dilation for DepthwiseConv.
     if (conv_params->dilation_height_factor > 1 ||
         conv_params->dilation_width_factor > 1) {
@@ -176,8 +181,12 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
   // Transpose NHWC -> HWCN
   GetDims(&weights_batch_size, &weights_height_size, &weights_width_size,
           &weights_depth_size, weights_tensor.dims);
-  weight_shape_ = {weights_height_size, weights_width_size, weights_depth_size,
-                   weights_batch_size};
+  OpBuilder* const_weights_node = nullptr;
+  if (op_node_.op_type == OP_Supernode_8x8p32to8) {
+    // Hexagon lib expects the weight tensor in HWCN, TFLite uses NHWC.
+    // Transpose NHWC -> HWCN
+    weight_shape_ = {weights_height_size, weights_width_size,
+                     weights_depth_size, weights_batch_size};
   RuntimeShape nhwc_shape({weights_batch_size, weights_height_size,
                            weights_width_size, weights_depth_size});
   RuntimeShape hwcn_shape({weights_height_size, weights_width_size,
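The next hunk's Transpose call uses a transpose_params defined in code between these hunks (not shown in the diff). A hypothetical setup, assuming TFLite's TransposeParams semantics where output axis i is taken from input axis perm[i], so NHWC to HWCN is the permutation {1, 2, 3, 0}:

#include "tensorflow/lite/kernels/internal/types.h"

// Hypothetical sketch of the permutation the transpose relies on; the real
// transpose_params is built in code outside the hunks shown here.
tflite::TransposeParams MakeNhwcToHwcnParams() {
  tflite::TransposeParams params;
  params.perm_count = 4;
  params.perm[0] = 1;  // output H taken from input axis 1 (H of NHWC)
  params.perm[1] = 2;  // output W taken from input axis 2 (W)
  params.perm[2] = 3;  // output C taken from input axis 3 (C)
  params.perm[3] = 0;  // output N taken from input axis 0 (N)
  return params;
}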
@@ -192,6 +201,21 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
     optimized_ops::Transpose<uint8_t>(transpose_params, nhwc_shape,
                                       weights_tensor.data.uint8, hwcn_shape,
                                       hwcn.data());
+    const_weights_node = graph_builder_->AddConstNodeWithData(
+        weight_shape_.data(), (char*)hwcn.data(),
+        hwcn.size() * sizeof(hwcn[0]));
+  } else if (op_node_.op_type == OP_DepthwiseSupernode_8x8p32to8) {
+    // Hexagon treats depthwise conv like tf.nn.depthwise_conv2d, where the
+    // expected filter shape is [fh,fw,din,dmul].
+    // The data itself will remain the same, since TFLite's representation is
+    // just a 'flattening' of Hexagon's version.
+    const int channel_multiplier = weights_depth_size / input_depth_size;
+    weight_shape_ = {weights_height_size, weights_width_size, input_depth_size,
+                     channel_multiplier};
+    const_weights_node = graph_builder_->AddConstNodeWithData(
+        weight_shape_.data(), weights_tensor.data.raw,
+        NumElements(&weights_tensor) * sizeof(weights_tensor.data.uint8[0]));
+  }
   // Quantization params for Weights tensor.
   TF_LITE_ENSURE_STATUS(
       ComputeMinAndMaxQuantValues(weights_tensor, &weights_min_, &weights_max_,
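The 'flattening' claim in the new comments can be checked directly: the row-major offset of element (ky, kx, ci, mi) in Hexagon's [fh, fw, din, dmul] filter equals the offset of TFLite's depthwise output channel ci * dmul + mi in its [1, fh, fw, din * dmul] filter, so only the shape metadata changes while the bytes stay put. A small self-contained check (standalone sketch, not delegate code):

#include <cassert>

int main() {
  const int fh = 3, fw = 5, din = 4, dmul = 2;
  for (int ky = 0; ky < fh; ++ky)
    for (int kx = 0; kx < fw; ++kx)
      for (int ci = 0; ci < din; ++ci)
        for (int mi = 0; mi < dmul; ++mi) {
          // Hexagon view: [fh, fw, din, dmul], element (ky, kx, ci, mi).
          const int hexagon = ((ky * fw + kx) * din + ci) * dmul + mi;
          // TFLite view: [1, fh, fw, din * dmul]; depthwise output channel
          // ordering is ci * dmul + mi.
          const int tflite = (ky * fw + kx) * (din * dmul) + ci * dmul + mi;
          assert(hexagon == tflite);  // identical buffer offsets
        }
  return 0;
}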
@@ -201,8 +225,6 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
       quant_bound_shape.data(), (char*)&weights_min_, sizeof(weights_min_));
   auto* weights_max_const = graph_builder_->AddConstNodeWithData(
       quant_bound_shape.data(), (char*)&weights_max_, sizeof(weights_max_));
-  auto* const_weights_node = graph_builder_->AddConstNodeWithData(
-      weight_shape_.data(), (char*)hwcn.data(), hwcn.size() * sizeof(hwcn[0]));
   graph_builder_->AddTensorWithID(inputs->data[1], const_weights_node->GetID(),
                                   0);
 
@@ -256,6 +278,18 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
   auto* bias_max_const = graph_builder_->AddConstNodeWithData(
       quant_bound_shape.data(), (char*)&bias_max_, sizeof(bias_max_));
 
+  // TODO(b/143759564): Simplify this method when depth_multiplier support needs
+  // generalizing.
+  if (channel_multiplier > 1 && input_depth_size == 1) {
+    // Depthwise Conv with input_depth == 1 & channel_multiplier > 1 is
+    // equivalent to Conv.
+    SetOpType(OP_Supernode_8x8p32to8);
+  } else if (channel_multiplier > 1) {
+    TF_LITE_KERNEL_LOG(
+        context, "depth_multiplier > 1 not supported with input_depth > 1");
+    return kTfLiteError;
+  }
+
   TensorID output, output_min, output_max;
   if (is_dilated_depthwise_conv) {
     // For dilated Depthwise Conv, we convert this node into SpaceToBatchND, and
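The op-type swap above is sound because, with a single input channel, a depthwise conv with multiplier m and a standard conv with m depth-1 filters evaluate the same sums; and with din == 1, Hexagon's depthwise filter layout [fh, fw, 1, m] coincides byte-for-byte with its HWCN conv layout. A standalone sketch of the arithmetic (hypothetical helper, not delegate code):

#include <cassert>
#include <vector>

// Valid padding, stride 1, single-channel [h, w] input, filter buffer laid
// out as [fh, fw, 1, m]. Output channel c correlates the one input channel
// with filter slice c, which is exactly what a standard conv with m filters
// of depth 1 computes; there is only one input channel to select, so the
// depthwise and standard readings of this loop are the same expression.
std::vector<float> ConvDepth1(const std::vector<float>& in, int h, int w,
                              const std::vector<float>& filt, int fh, int fw,
                              int m) {
  const int oh = h - fh + 1, ow = w - fw + 1;
  std::vector<float> out(oh * ow * m, 0.f);
  for (int y = 0; y < oh; ++y)
    for (int x = 0; x < ow; ++x)
      for (int c = 0; c < m; ++c)
        for (int ky = 0; ky < fh; ++ky)
          for (int kx = 0; kx < fw; ++kx)
            out[(y * ow + x) * m + c] +=
                in[(y + ky) * w + (x + kx)] * filt[(ky * fw + kx) * m + c];
  return out;
}

int main() {
  // 3x3 single-channel input, 2x2 filter with multiplier 2.
  const std::vector<float> in = {1, 2, 3, 4, 5, 6, 7, 8, 9};
  const std::vector<float> filt = {1, 0, 0, 1, 1, 0, 0, 1};  // [2, 2, 1, 2]
  const auto out = ConvDepth1(in, 3, 3, filt, 2, 2, 2);
  assert(out.size() == 2 * 2 * 2);
  assert(out[0] == 1 + 4);  // channel 0 at (0,0): picks in(0,0) and in(1,0)
  assert(out[1] == 2 + 5);  // channel 1 at (0,0): picks in(0,1) and in(1,1)
  return 0;
}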
@@ -256,4 +256,86 @@ TEST(QuantizedConvolutionOpModel, SimpleConvTestReLU6Activation) {
                   1e-5)));
 }
 
+// Depthwise Conv with multiplier > 1 but input depth==1 should resolve into a
+// Conv op.
+TEST(QuantizedConvolutionOpModel, DepthwiseConvWithMultiplier_InputDepth1) {
+  QuantizedConvolutionOpModel m(BuiltinOperator_DEPTHWISE_CONV_2D,
+                                {TensorType_UINT8, {1, 6, 6, 1}, -63.5, 64},
+                                {TensorType_UINT8, {1, 5, 5, 3}, -63.5, 64},
+                                {TensorType_UINT8, {}, -127, 128},
+                                Padding_VALID);
+  // clang-format off
+  m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0,
+              0, 0, 0, 1, 1, 1, 0, 0, 0,
+              0, 0, 0, 1, 1, 1, 0, 0, 0,
+              0, 0, 0, 0, 0, 0, 0, 0, 0});
+  m.SetFilter({1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5,
+               6, 7, 8, 9, 10,
+               6, 7, 8, 9, 10,
+               6, 7, 8, 9, 10,
+               6, 7, 8, 9, 10,
+               6, 7, 8, 9, 10,
+               1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5});
+  // clang-format on
+  m.SetBias({1, 2, 3});
+
+  // Reference output.
+  m.Invoke();
+  auto reference_output = m.GetDequantizedOutput();
+
+  m.ApplyDelegateAndInvoke();
+  EXPECT_THAT(m.GetDequantizedOutput(),
+              ElementsAreArray(ArrayFloatNear(reference_output, 1e-5)));
+}
+
+// Same as above, but with a RELU6 activation and stride 2.
+TEST(QuantizedConvolutionOpModel,
+     DepthwiseConvWithMultiplier_InputDepth1_RELU) {
+  QuantizedConvolutionOpModel m(BuiltinOperator_DEPTHWISE_CONV_2D,
+                                {TensorType_UINT8, {1, 6, 6, 1}, -63.5, 64},
+                                {TensorType_UINT8, {1, 5, 5, 3}, -63.5, 64},
+                                {TensorType_UINT8, {}, -127, 128},
+                                Padding_VALID, /**dilation_factor**/ 1,
+                                /**stride**/ 2, ActivationFunctionType_RELU6);
+  // clang-format off
+  m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0,
+              0, 0, 0, 1, 1, 1, 0, 0, 0,
+              0, 0, 0, 1, 1, 1, 0, 0, 0,
+              0, 0, 0, 0, 0, 0, 0, 0, 0});
+  m.SetFilter({1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5,
+               6, 7, 8, 9, 10,
+               6, 7, 8, 9, 10,
+               6, 7, 8, 9, 10,
+               6, 7, 8, 9, 10,
+               6, 7, 8, 9, 10,
+               1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5,
+               1, 2, 3, 4, 5});
+  // clang-format on
+  m.SetBias({1, 2, 3});
+
+  // Reference output.
+  m.Invoke();
+  auto reference_output = m.GetDequantizedOutput();
+
+  m.ApplyDelegateAndInvoke();
+  EXPECT_THAT(m.GetDequantizedOutput(),
+              ElementsAreArray(ArrayFloatNear(reference_output, 1e-5)));
+}
+
 }  // namespace tflite
@@ -188,6 +188,8 @@ bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration,
       if (!InputsWithCorrectTypes(node, context,
                                   {kTfLiteUInt8, kTfLiteUInt8, kTfLiteInt32}))
         return false;
+
+      // Check dilation.
       const TfLiteDepthwiseConvParams* conv_params =
           reinterpret_cast<const TfLiteDepthwiseConvParams*>(
               node->builtin_data);
@@ -198,10 +200,19 @@ bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration,
         if (conv_params->stride_height != 1 || conv_params->stride_width != 1)
           return false;
       }
+
+      // We currently only support depth_multiplier > 1 when:
+      // 1. dilation_factor == 1 AND
+      // 2. input_depth == 1
+      // TODO(b/143759564): Add support for general case.
+      const auto& input = context->tensors[node->inputs->data[0]];
+      const bool supported_depth_multiplier =
+          conv_params->depth_multiplier == 1 ||
+          (!dilation && input.dims->size == 4 && input.dims->data[3] == 1);
+
       return (IsActivationReluOrNone(conv_params->activation) &&
               conv_params->stride_height <= 3 &&
-              conv_params->stride_width <= 3 &&
-              conv_params->depth_multiplier == 1);
+              conv_params->stride_width <= 3 && supported_depth_multiplier);
     }
     case kTfLiteBuiltinReshape: {
       if (node->inputs->size > 2 ||
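Restated outside the switch for clarity, the new rule admits depth_multiplier > 1 only for non-dilated nodes whose 4-D input has depth 1. A standalone sketch of the predicate (names are local to this sketch; the dilation parameter stands in for the bool computed earlier in this function):

#include <cassert>

// Mirrors the supported_depth_multiplier check added in utils.cc above.
bool SupportedDepthMultiplier(int depth_multiplier, bool dilation,
                              int input_rank, int input_last_dim) {
  return depth_multiplier == 1 ||
         (!dilation && input_rank == 4 && input_last_dim == 1);
}

int main() {
  assert(SupportedDepthMultiplier(1, /*dilation=*/true, 4, 8));  // always ok
  assert(SupportedDepthMultiplier(3, false, 4, 1));   // the newly added case
  assert(!SupportedDepthMultiplier(3, true, 4, 1));   // dilated: rejected
  assert(!SupportedDepthMultiplier(3, false, 4, 8));  // input depth > 1: rejected
  return 0;
}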