Adds support for depth_multiplier case when DepthwiseConv resolves to Conv
PiperOrigin-RevId: 299856795 Change-Id: I53195af9fbffc6215c58feb4ae3cfb654f6a03d2
This commit is contained in:
parent
1d27131082
commit
ac26a80b0f
@ -125,6 +125,9 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
|
||||
|
||||
// Input data tensor.
|
||||
const auto& data_tensor = context->tensors[inputs->data[0]];
|
||||
int input_batch_size, input_height_size, input_width_size, input_depth_size;
|
||||
GetDims(&input_batch_size, &input_height_size, &input_width_size,
|
||||
&input_depth_size, data_tensor.dims);
|
||||
TF_LITE_ENSURE_STATUS(ComputeMinAndMaxQuantValues(
|
||||
data_tensor, &data_min_, &data_max_, std::numeric_limits<uint8_t>::min(),
|
||||
std::numeric_limits<uint8_t>::max()));
|
||||
@ -139,6 +142,7 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
|
||||
int stride_height = 0;
|
||||
int stride_width = 0;
|
||||
bool is_dilated_depthwise_conv = false;
|
||||
int channel_multiplier = 1;
|
||||
if (op_node_.op_type == OP_Supernode_8x8p32to8) {
|
||||
const TfLiteConvParams* conv_params =
|
||||
reinterpret_cast<const TfLiteConvParams*>(builtin_data_);
|
||||
@ -153,6 +157,7 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
|
||||
stride_width = conv_params->stride_width;
|
||||
padding_type = conv_params->padding;
|
||||
activation = conv_params->activation;
|
||||
channel_multiplier = conv_params->depth_multiplier;
|
||||
// We only support dilation for DepthwiseConv.
|
||||
if (conv_params->dilation_height_factor > 1 ||
|
||||
conv_params->dilation_width_factor > 1) {
|
||||
@ -176,22 +181,41 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
|
||||
// Transpose NHWC -> HWCN
|
||||
GetDims(&weights_batch_size, &weights_height_size, &weights_width_size,
|
||||
&weights_depth_size, weights_tensor.dims);
|
||||
weight_shape_ = {weights_height_size, weights_width_size, weights_depth_size,
|
||||
weights_batch_size};
|
||||
RuntimeShape nhwc_shape({weights_batch_size, weights_height_size,
|
||||
weights_width_size, weights_depth_size});
|
||||
RuntimeShape hwcn_shape({weights_height_size, weights_width_size,
|
||||
weights_depth_size, weights_batch_size});
|
||||
std::vector<uint8_t> hwcn(NumElements(&weights_tensor));
|
||||
TransposeParams transpose_params;
|
||||
transpose_params.perm_count = 4;
|
||||
transpose_params.perm[0] = 1;
|
||||
transpose_params.perm[1] = 2;
|
||||
transpose_params.perm[2] = 3;
|
||||
transpose_params.perm[3] = 0;
|
||||
optimized_ops::Transpose<uint8_t>(transpose_params, nhwc_shape,
|
||||
weights_tensor.data.uint8, hwcn_shape,
|
||||
hwcn.data());
|
||||
OpBuilder* const_weights_node = nullptr;
|
||||
if (op_node_.op_type == OP_Supernode_8x8p32to8) {
|
||||
// Hexagon lib expects the weight tensor in HWCN, TFLite uses NHWC.
|
||||
// Transpose NHWC -> HWCN
|
||||
weight_shape_ = {weights_height_size, weights_width_size,
|
||||
weights_depth_size, weights_batch_size};
|
||||
RuntimeShape nhwc_shape({weights_batch_size, weights_height_size,
|
||||
weights_width_size, weights_depth_size});
|
||||
RuntimeShape hwcn_shape({weights_height_size, weights_width_size,
|
||||
weights_depth_size, weights_batch_size});
|
||||
std::vector<uint8_t> hwcn(NumElements(&weights_tensor));
|
||||
TransposeParams transpose_params;
|
||||
transpose_params.perm_count = 4;
|
||||
transpose_params.perm[0] = 1;
|
||||
transpose_params.perm[1] = 2;
|
||||
transpose_params.perm[2] = 3;
|
||||
transpose_params.perm[3] = 0;
|
||||
optimized_ops::Transpose<uint8_t>(transpose_params, nhwc_shape,
|
||||
weights_tensor.data.uint8, hwcn_shape,
|
||||
hwcn.data());
|
||||
const_weights_node = graph_builder_->AddConstNodeWithData(
|
||||
weight_shape_.data(), (char*)hwcn.data(),
|
||||
hwcn.size() * sizeof(hwcn[0]));
|
||||
} else if (op_node_.op_type == OP_DepthwiseSupernode_8x8p32to8) {
|
||||
// Hexagon treats depthwise conv like tf.nn.depthwise_conv2d, where the
|
||||
// expected filter shape is [fh,fw,din,dmul].
|
||||
// The data itself will remain the same, since TFLite's representation is
|
||||
// just a 'flattening' of Hexagon's version.
|
||||
const int channel_multiplier = weights_depth_size / input_depth_size;
|
||||
weight_shape_ = {weights_height_size, weights_width_size, input_depth_size,
|
||||
channel_multiplier};
|
||||
const_weights_node = graph_builder_->AddConstNodeWithData(
|
||||
weight_shape_.data(), weights_tensor.data.raw,
|
||||
NumElements(&weights_tensor) * sizeof(weights_tensor.data.uint8[0]));
|
||||
}
|
||||
// Quantization params for Weights tensor.
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
ComputeMinAndMaxQuantValues(weights_tensor, &weights_min_, &weights_max_,
|
||||
@ -201,8 +225,6 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
|
||||
quant_bound_shape.data(), (char*)&weights_min_, sizeof(weights_min_));
|
||||
auto* weights_max_const = graph_builder_->AddConstNodeWithData(
|
||||
quant_bound_shape.data(), (char*)&weights_max_, sizeof(weights_max_));
|
||||
auto* const_weights_node = graph_builder_->AddConstNodeWithData(
|
||||
weight_shape_.data(), (char*)hwcn.data(), hwcn.size() * sizeof(hwcn[0]));
|
||||
graph_builder_->AddTensorWithID(inputs->data[1], const_weights_node->GetID(),
|
||||
0);
|
||||
|
||||
@ -256,6 +278,18 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
|
||||
auto* bias_max_const = graph_builder_->AddConstNodeWithData(
|
||||
quant_bound_shape.data(), (char*)&bias_max_, sizeof(bias_max_));
|
||||
|
||||
// TODO(b/143759564): Simplify this method when depth_multiplier support needs
|
||||
// generalizing.
|
||||
if (channel_multiplier > 1 && input_depth_size == 1) {
|
||||
// Depthwise Conv with input_depth == 1 & channel_multiplier > 1 is
|
||||
// equivalent to Conv.
|
||||
SetOpType(OP_Supernode_8x8p32to8);
|
||||
} else if (channel_multiplier > 1) {
|
||||
TF_LITE_KERNEL_LOG(
|
||||
context, "depth_multiplier > 1 not supported with input_depth > 1");
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
TensorID output, output_min, output_max;
|
||||
if (is_dilated_depthwise_conv) {
|
||||
// For dilated Depthwise Conv, we convert this node into SpaceToBatchND, and
|
||||
|
@ -256,4 +256,86 @@ TEST(QuantizedConvolutionOpModel, SimpleConvTestReLU6Activation) {
|
||||
1e-5)));
|
||||
}
|
||||
|
||||
// Depthwise Conv with multiplier > 1 but input depth==1 should resolve into a
|
||||
// Conv op.
|
||||
TEST(QuantizedConvolutionOpModel, DepthwiseConvWithMultiplier_InputDepth1) {
|
||||
QuantizedConvolutionOpModel m(BuiltinOperator_DEPTHWISE_CONV_2D,
|
||||
{TensorType_UINT8, {1, 6, 6, 1}, -63.5, 64},
|
||||
{TensorType_UINT8, {1, 5, 5, 3}, -63.5, 64},
|
||||
{TensorType_UINT8, {}, -127, 128},
|
||||
Padding_VALID);
|
||||
// clang-format off
|
||||
m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 1, 1, 1, 0, 0, 0,
|
||||
0, 0, 0, 1, 1, 1, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0});
|
||||
m.SetFilter({1, 2, 3, 4, 5,
|
||||
1, 2, 3, 4, 5,
|
||||
1, 2, 3, 4, 5,
|
||||
1, 2, 3, 4, 5,
|
||||
1, 2, 3, 4, 5,
|
||||
6, 7, 8, 9, 10,
|
||||
6, 7, 8, 9, 10,
|
||||
6, 7, 8, 9, 10,
|
||||
6, 7, 8, 9, 10,
|
||||
6, 7, 8, 9, 10,
|
||||
1, 2, 3, 4, 5,
|
||||
1, 2, 3, 4, 5,
|
||||
1, 2, 3, 4, 5,
|
||||
1, 2, 3, 4, 5,
|
||||
1, 2, 3, 4, 5});
|
||||
// clang-format on
|
||||
m.SetBias({1, 2, 3});
|
||||
|
||||
// Reference output.
|
||||
m.Invoke();
|
||||
auto reference_output = m.GetDequantizedOutput();
|
||||
|
||||
m.ApplyDelegateAndInvoke();
|
||||
EXPECT_THAT(m.GetDequantizedOutput(),
|
||||
ElementsAreArray(ArrayFloatNear(reference_output, 1e-5)));
|
||||
}
|
||||
|
||||
// Depthwise Conv with multiplier > 1 but input depth==1 should resolve into a
|
||||
// Conv op.
|
||||
TEST(QuantizedConvolutionOpModel,
|
||||
DepthwiseConvWithMultiplier_InputDepth1_RELU) {
|
||||
QuantizedConvolutionOpModel m(BuiltinOperator_DEPTHWISE_CONV_2D,
|
||||
{TensorType_UINT8, {1, 6, 6, 1}, -63.5, 64},
|
||||
{TensorType_UINT8, {1, 5, 5, 3}, -63.5, 64},
|
||||
{TensorType_UINT8, {}, -127, 128},
|
||||
Padding_VALID, /**dilation_factor**/ 1,
|
||||
/**stride**/ 2, ActivationFunctionType_RELU6);
|
||||
// clang-format off
|
||||
m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 1, 1, 1, 0, 0, 0,
|
||||
0, 0, 0, 1, 1, 1, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0});
|
||||
m.SetFilter({1, 2, 3, 4, 5,
|
||||
1, 2, 3, 4, 5,
|
||||
1, 2, 3, 4, 5,
|
||||
1, 2, 3, 4, 5,
|
||||
1, 2, 3, 4, 5,
|
||||
6, 7, 8, 9, 10,
|
||||
6, 7, 8, 9, 10,
|
||||
6, 7, 8, 9, 10,
|
||||
6, 7, 8, 9, 10,
|
||||
6, 7, 8, 9, 10,
|
||||
1, 2, 3, 4, 5,
|
||||
1, 2, 3, 4, 5,
|
||||
1, 2, 3, 4, 5,
|
||||
1, 2, 3, 4, 5,
|
||||
1, 2, 3, 4, 5});
|
||||
// clang-format on
|
||||
m.SetBias({1, 2, 3});
|
||||
|
||||
// Reference output.
|
||||
m.Invoke();
|
||||
auto reference_output = m.GetDequantizedOutput();
|
||||
|
||||
m.ApplyDelegateAndInvoke();
|
||||
EXPECT_THAT(m.GetDequantizedOutput(),
|
||||
ElementsAreArray(ArrayFloatNear(reference_output, 1e-5)));
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
@ -188,6 +188,8 @@ bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration,
|
||||
if (!InputsWithCorrectTypes(node, context,
|
||||
{kTfLiteUInt8, kTfLiteUInt8, kTfLiteInt32}))
|
||||
return false;
|
||||
|
||||
// Check dilation.
|
||||
const TfLiteDepthwiseConvParams* conv_params =
|
||||
reinterpret_cast<const TfLiteDepthwiseConvParams*>(
|
||||
node->builtin_data);
|
||||
@ -198,10 +200,19 @@ bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration,
|
||||
if (conv_params->stride_height != 1 || conv_params->stride_width != 1)
|
||||
return false;
|
||||
}
|
||||
|
||||
// We currently only support depth_multiplier > 1 when:
|
||||
// 1. dilation_factor == 1 AND
|
||||
// 2. input_depth == 1
|
||||
// TODO(b/143759564): Add support for general case.
|
||||
const auto& input = context->tensors[node->inputs->data[0]];
|
||||
const bool supported_depth_multiplier =
|
||||
conv_params->depth_multiplier == 1 ||
|
||||
(!dilation && input.dims->size == 4 && input.dims->data[3] == 1);
|
||||
|
||||
return (IsActivationReluOrNone(conv_params->activation) &&
|
||||
conv_params->stride_height <= 3 &&
|
||||
conv_params->stride_width <= 3 &&
|
||||
conv_params->depth_multiplier == 1);
|
||||
conv_params->stride_width <= 3 && supported_depth_multiplier);
|
||||
}
|
||||
case kTfLiteBuiltinReshape: {
|
||||
if (node->inputs->size > 2 ||
|
||||
|
Loading…
Reference in New Issue
Block a user