Adds support for depth_multiplier case when DepthwiseConv resolves to Conv

PiperOrigin-RevId: 299856795
Change-Id: I53195af9fbffc6215c58feb4ae3cfb654f6a03d2
This commit is contained in:
Sachin Joglekar 2020-03-09 09:15:46 -07:00 committed by TensorFlower Gardener
parent 1d27131082
commit ac26a80b0f
3 changed files with 147 additions and 20 deletions

View File

@ -125,6 +125,9 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
// Input data tensor.
const auto& data_tensor = context->tensors[inputs->data[0]];
int input_batch_size, input_height_size, input_width_size, input_depth_size;
GetDims(&input_batch_size, &input_height_size, &input_width_size,
&input_depth_size, data_tensor.dims);
TF_LITE_ENSURE_STATUS(ComputeMinAndMaxQuantValues(
data_tensor, &data_min_, &data_max_, std::numeric_limits<uint8_t>::min(),
std::numeric_limits<uint8_t>::max()));
@ -139,6 +142,7 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
int stride_height = 0;
int stride_width = 0;
bool is_dilated_depthwise_conv = false;
int channel_multiplier = 1;
if (op_node_.op_type == OP_Supernode_8x8p32to8) {
const TfLiteConvParams* conv_params =
reinterpret_cast<const TfLiteConvParams*>(builtin_data_);
@ -153,6 +157,7 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
stride_width = conv_params->stride_width;
padding_type = conv_params->padding;
activation = conv_params->activation;
channel_multiplier = conv_params->depth_multiplier;
// We only support dilation for DepthwiseConv.
if (conv_params->dilation_height_factor > 1 ||
conv_params->dilation_width_factor > 1) {
@ -176,22 +181,41 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
// Transpose NHWC -> HWCN
GetDims(&weights_batch_size, &weights_height_size, &weights_width_size,
&weights_depth_size, weights_tensor.dims);
weight_shape_ = {weights_height_size, weights_width_size, weights_depth_size,
weights_batch_size};
RuntimeShape nhwc_shape({weights_batch_size, weights_height_size,
weights_width_size, weights_depth_size});
RuntimeShape hwcn_shape({weights_height_size, weights_width_size,
weights_depth_size, weights_batch_size});
std::vector<uint8_t> hwcn(NumElements(&weights_tensor));
TransposeParams transpose_params;
transpose_params.perm_count = 4;
transpose_params.perm[0] = 1;
transpose_params.perm[1] = 2;
transpose_params.perm[2] = 3;
transpose_params.perm[3] = 0;
optimized_ops::Transpose<uint8_t>(transpose_params, nhwc_shape,
weights_tensor.data.uint8, hwcn_shape,
hwcn.data());
OpBuilder* const_weights_node = nullptr;
if (op_node_.op_type == OP_Supernode_8x8p32to8) {
// Hexagon lib expects the weight tensor in HWCN, TFLite uses NHWC.
// Transpose NHWC -> HWCN
weight_shape_ = {weights_height_size, weights_width_size,
weights_depth_size, weights_batch_size};
RuntimeShape nhwc_shape({weights_batch_size, weights_height_size,
weights_width_size, weights_depth_size});
RuntimeShape hwcn_shape({weights_height_size, weights_width_size,
weights_depth_size, weights_batch_size});
std::vector<uint8_t> hwcn(NumElements(&weights_tensor));
TransposeParams transpose_params;
transpose_params.perm_count = 4;
transpose_params.perm[0] = 1;
transpose_params.perm[1] = 2;
transpose_params.perm[2] = 3;
transpose_params.perm[3] = 0;
optimized_ops::Transpose<uint8_t>(transpose_params, nhwc_shape,
weights_tensor.data.uint8, hwcn_shape,
hwcn.data());
const_weights_node = graph_builder_->AddConstNodeWithData(
weight_shape_.data(), (char*)hwcn.data(),
hwcn.size() * sizeof(hwcn[0]));
} else if (op_node_.op_type == OP_DepthwiseSupernode_8x8p32to8) {
// Hexagon treats depthwise conv like tf.nn.depthwise_conv2d, where the
// expected filter shape is [fh,fw,din,dmul].
// The data itself will remain the same, since TFLite's representation is
// just a 'flattening' of Hexagon's version.
const int channel_multiplier = weights_depth_size / input_depth_size;
weight_shape_ = {weights_height_size, weights_width_size, input_depth_size,
channel_multiplier};
const_weights_node = graph_builder_->AddConstNodeWithData(
weight_shape_.data(), weights_tensor.data.raw,
NumElements(&weights_tensor) * sizeof(weights_tensor.data.uint8[0]));
}
// Quantization params for Weights tensor.
TF_LITE_ENSURE_STATUS(
ComputeMinAndMaxQuantValues(weights_tensor, &weights_min_, &weights_max_,
@ -201,8 +225,6 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
quant_bound_shape.data(), (char*)&weights_min_, sizeof(weights_min_));
auto* weights_max_const = graph_builder_->AddConstNodeWithData(
quant_bound_shape.data(), (char*)&weights_max_, sizeof(weights_max_));
auto* const_weights_node = graph_builder_->AddConstNodeWithData(
weight_shape_.data(), (char*)hwcn.data(), hwcn.size() * sizeof(hwcn[0]));
graph_builder_->AddTensorWithID(inputs->data[1], const_weights_node->GetID(),
0);
@ -256,6 +278,18 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
auto* bias_max_const = graph_builder_->AddConstNodeWithData(
quant_bound_shape.data(), (char*)&bias_max_, sizeof(bias_max_));
// TODO(b/143759564): Simplify this method when depth_multiplier support needs
// generalizing.
if (channel_multiplier > 1 && input_depth_size == 1) {
// Depthwise Conv with input_depth == 1 & channel_multiplier > 1 is
// equivalent to Conv.
SetOpType(OP_Supernode_8x8p32to8);
} else if (channel_multiplier > 1) {
TF_LITE_KERNEL_LOG(
context, "depth_multiplier > 1 not supported with input_depth > 1");
return kTfLiteError;
}
TensorID output, output_min, output_max;
if (is_dilated_depthwise_conv) {
// For dilated Depthwise Conv, we convert this node into SpaceToBatchND, and

View File

@ -256,4 +256,86 @@ TEST(QuantizedConvolutionOpModel, SimpleConvTestReLU6Activation) {
1e-5)));
}
// Depthwise Conv with multiplier > 1 but input depth==1 should resolve into a
// Conv op.
TEST(QuantizedConvolutionOpModel, DepthwiseConvWithMultiplier_InputDepth1) {
QuantizedConvolutionOpModel m(BuiltinOperator_DEPTHWISE_CONV_2D,
{TensorType_UINT8, {1, 6, 6, 1}, -63.5, 64},
{TensorType_UINT8, {1, 5, 5, 3}, -63.5, 64},
{TensorType_UINT8, {}, -127, 128},
Padding_VALID);
// clang-format off
m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 1, 1, 1, 0, 0, 0,
0, 0, 0, 1, 1, 1, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0});
m.SetFilter({1, 2, 3, 4, 5,
1, 2, 3, 4, 5,
1, 2, 3, 4, 5,
1, 2, 3, 4, 5,
1, 2, 3, 4, 5,
6, 7, 8, 9, 10,
6, 7, 8, 9, 10,
6, 7, 8, 9, 10,
6, 7, 8, 9, 10,
6, 7, 8, 9, 10,
1, 2, 3, 4, 5,
1, 2, 3, 4, 5,
1, 2, 3, 4, 5,
1, 2, 3, 4, 5,
1, 2, 3, 4, 5});
// clang-format on
m.SetBias({1, 2, 3});
// Reference output.
m.Invoke();
auto reference_output = m.GetDequantizedOutput();
m.ApplyDelegateAndInvoke();
EXPECT_THAT(m.GetDequantizedOutput(),
ElementsAreArray(ArrayFloatNear(reference_output, 1e-5)));
}
// Depthwise Conv with multiplier > 1 but input depth==1 should resolve into a
// Conv op.
TEST(QuantizedConvolutionOpModel,
DepthwiseConvWithMultiplier_InputDepth1_RELU) {
QuantizedConvolutionOpModel m(BuiltinOperator_DEPTHWISE_CONV_2D,
{TensorType_UINT8, {1, 6, 6, 1}, -63.5, 64},
{TensorType_UINT8, {1, 5, 5, 3}, -63.5, 64},
{TensorType_UINT8, {}, -127, 128},
Padding_VALID, /**dilation_factor**/ 1,
/**stride**/ 2, ActivationFunctionType_RELU6);
// clang-format off
m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 1, 1, 1, 0, 0, 0,
0, 0, 0, 1, 1, 1, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0});
m.SetFilter({1, 2, 3, 4, 5,
1, 2, 3, 4, 5,
1, 2, 3, 4, 5,
1, 2, 3, 4, 5,
1, 2, 3, 4, 5,
6, 7, 8, 9, 10,
6, 7, 8, 9, 10,
6, 7, 8, 9, 10,
6, 7, 8, 9, 10,
6, 7, 8, 9, 10,
1, 2, 3, 4, 5,
1, 2, 3, 4, 5,
1, 2, 3, 4, 5,
1, 2, 3, 4, 5,
1, 2, 3, 4, 5});
// clang-format on
m.SetBias({1, 2, 3});
// Reference output.
m.Invoke();
auto reference_output = m.GetDequantizedOutput();
m.ApplyDelegateAndInvoke();
EXPECT_THAT(m.GetDequantizedOutput(),
ElementsAreArray(ArrayFloatNear(reference_output, 1e-5)));
}
} // namespace tflite

View File

@ -188,6 +188,8 @@ bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration,
if (!InputsWithCorrectTypes(node, context,
{kTfLiteUInt8, kTfLiteUInt8, kTfLiteInt32}))
return false;
// Check dilation.
const TfLiteDepthwiseConvParams* conv_params =
reinterpret_cast<const TfLiteDepthwiseConvParams*>(
node->builtin_data);
@ -198,10 +200,19 @@ bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration,
if (conv_params->stride_height != 1 || conv_params->stride_width != 1)
return false;
}
// We currently only support depth_multiplier > 1 when:
// 1. dilation_factor == 1 AND
// 2. input_depth == 1
// TODO(b/143759564): Add support for general case.
const auto& input = context->tensors[node->inputs->data[0]];
const bool supported_depth_multiplier =
conv_params->depth_multiplier == 1 ||
(!dilation && input.dims->size == 4 && input.dims->data[3] == 1);
return (IsActivationReluOrNone(conv_params->activation) &&
conv_params->stride_height <= 3 &&
conv_params->stride_width <= 3 &&
conv_params->depth_multiplier == 1);
conv_params->stride_width <= 3 && supported_depth_multiplier);
}
case kTfLiteBuiltinReshape: {
if (node->inputs->size > 2 ||