Add generic fallback optimized implementations for dilated DepthwiseConv.

PiperOrigin-RevId: 213350122
This commit is contained in:
Suharsh Sivakumar 2018-09-17 15:32:12 -07:00 committed by TensorFlower Gardener
parent aec9a70770
commit 3365cd1cc7
10 changed files with 281 additions and 145 deletions

View File

@ -509,6 +509,7 @@ tf_cc_test(
":builtin_ops",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_absl//absl/memory",
"@com_google_googletest//:gtest",
],
)

View File

@ -184,17 +184,7 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
const Dims<4>&, const float*, const Dims<4>&, int, int,
int, int, int, int, int, float, float, float*,
const Dims<4>&);
KernelType effective_kernel_type;
// TODO(suharshs): Currently only the reference implementation supports
// dilations.
if ((params->dilation_width_factor != 1) ||
(params->dilation_height_factor != 1)) {
effective_kernel_type = kReference;
} else {
effective_kernel_type = kernel_type;
}
if (effective_kernel_type == kReference) {
if (kernel_type == kReference) {
depthwise_conv = &reference_ops::DepthwiseConv;
} else {
depthwise_conv = &optimized_ops::DepthwiseConv;
@ -224,17 +214,7 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
int, int, int, int, int, int, int, int32, int32, int,
int32, int32, uint8*, const Dims<4>&);
KernelType effective_kernel_type;
// TODO(suharshs): Currently only the reference implementation supports
// dilations.
if ((params->dilation_width_factor != 1) ||
(params->dilation_height_factor != 1)) {
effective_kernel_type = kReference;
} else {
effective_kernel_type = kernel_type;
}
if (effective_kernel_type == kReference) {
if (kernel_type == kReference) {
depthwise_conv = &reference_ops::DepthwiseConv;
} else {
depthwise_conv = &optimized_ops::DepthwiseConv;

View File

@ -14,12 +14,24 @@ limitations under the License.
==============================================================================*/
#include <cstdarg>
#include <gtest/gtest.h>
#include "absl/memory/memory.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
#include "tensorflow/contrib/lite/model.h"
namespace tflite {
namespace ops {
namespace builtin {
// Forward declarations of the registration entry points for the three
// depthwise-convolution kernel variants this test exercises: reference,
// generic optimized, and NEON optimized. Each returns the TfLiteRegistration
// for its variant.
// NOTE(review): the definitions are assumed to be linked in from the builtin
// kernel library; confirm against the test's build dependencies.
TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_REF();
TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_GENERIC_OPT();
TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_NEON_OPT();
} // namespace builtin
} // namespace ops
namespace {
using ::testing::ElementsAreArray;
@ -28,9 +40,11 @@ class BaseDepthwiseConvolutionOpModel : public SingleOpModel {
public:
// TODO(ahentz): Also test different activation types, bias, padding types,
// stride values.
BaseDepthwiseConvolutionOpModel(const TensorData& input,
BaseDepthwiseConvolutionOpModel(TfLiteRegistration* registration,
const TensorData& input,
const TensorData& filter,
const TensorData& output,
Padding padding_type,
int dilation_factor = 1) {
input_ = AddInput(input);
filter_ = AddInput(filter);
@ -56,11 +70,14 @@ class BaseDepthwiseConvolutionOpModel : public SingleOpModel {
SetBuiltinOp(
BuiltinOperator_DEPTHWISE_CONV_2D,
BuiltinOptions_DepthwiseConv2DOptions,
CreateDepthwiseConv2DOptions(builder_, Padding_VALID, 1, 1, depth_mul,
CreateDepthwiseConv2DOptions(builder_, padding_type, 1, 1, depth_mul,
ActivationFunctionType_NONE,
dilation_factor, dilation_factor)
.Union());
resolver_ = absl::make_unique<SingleOpResolver>(
BuiltinOperator_DEPTHWISE_CONV_2D, registration);
BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)});
}
@ -86,10 +103,25 @@ class DepthwiseConvolutionOpModel : public BaseDepthwiseConvolutionOpModel {
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
};
TEST(DepthwiseConvolutionOpTest, SimpleTest) {
DepthwiseConvolutionOpModel m({TensorType_FLOAT32, {1, 3, 2, 2}},
// Table pairing a human-readable kernel tag with the registration returned by
// the corresponding Register_DEPTHWISE_CONVOLUTION_* function. The test
// fixtures below hand this table to SingleOpTest, and the
// INSTANTIATE_TEST_CASE_P calls derive their parameter values from it.
// The map is allocated with `new`; no matching `delete` appears alongside it.
// NOTE(review): presumably leaked on purpose to avoid static-destruction
// ordering issues; confirm against the project's style rules.
const auto kKernelMap = new std::map<string, TfLiteRegistration*>({
{"Reference", ops::builtin::Register_DEPTHWISE_CONVOLUTION_REF()},
{"GenericOptimized",
ops::builtin::Register_DEPTHWISE_CONVOLUTION_GENERIC_OPT()},
{"NeonOptimized", ops::builtin::Register_DEPTHWISE_CONVOLUTION_NEON_OPT()},
});
// Parameterized fixture for the float depthwise-convolution tests. Its only
// job is to expose kKernelMap to the SingleOpTest base class.
// NOTE(review): SingleOpTest presumably resolves the current test parameter
// (a kernel tag) through this map when GetRegistration() is called; confirm
// in kernels/test_util.h.
class DepthwiseConvolutionOpTest : public SingleOpTest {
protected:
// Returns the shared tag -> registration table defined above.
const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
return *kKernelMap;
}
};
TEST_P(DepthwiseConvolutionOpTest, SimpleTest) {
DepthwiseConvolutionOpModel m(GetRegistration(),
{TensorType_FLOAT32, {1, 3, 2, 2}},
{TensorType_FLOAT32, {1, 2, 2, 4}},
{TensorType_FLOAT32, {}});
{TensorType_FLOAT32, {}}, Padding_VALID);
m.SetInput({
1, 2, 7, 8, // column 1
@ -112,7 +144,7 @@ TEST(DepthwiseConvolutionOpTest, SimpleTest) {
}));
}
TEST(DepthwiseConvolutionOpTest, SimpleDilatedTest) {
TEST_P(DepthwiseConvolutionOpTest, SimpleDilatedTestPaddingValid) {
const int depth = 1;
const int image_width = 9;
const int image_height = 9;
@ -121,10 +153,11 @@ TEST(DepthwiseConvolutionOpTest, SimpleDilatedTest) {
const int filter_count = 1;
const int dilation_factor = 3;
DepthwiseConvolutionOpModel m(
GetRegistration(),
{TensorType_FLOAT32,
{image_batch_count, image_height, image_width, depth}},
{TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
{TensorType_FLOAT32, {}}, dilation_factor);
{TensorType_FLOAT32, {}}, Padding_VALID, dilation_factor);
// The image matrix is:
// | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
@ -164,6 +197,41 @@ TEST(DepthwiseConvolutionOpTest, SimpleDilatedTest) {
EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
}
// Float depthwise convolution using a dilated 2x2 filter together with SAME
// padding, run once for every kernel registration the fixture is
// instantiated with.
TEST_P(DepthwiseConvolutionOpTest, SimpleDilatedTestPaddingSame) {
  // Problem geometry, named so that the tensor shapes below read naturally.
  const int channels = 1;
  const int cols = 3;
  const int rows = 3;
  const int batches = 1;
  const int kernel_side = 2;
  const int kernel_count = 1;
  const int dilation = 2;

  // Shapes: input is {batch, height, width, depth}; filter is
  // {depth, height, width, count}. The output shape is left empty.
  DepthwiseConvolutionOpModel model(
      GetRegistration(),
      {TensorType_FLOAT32, {batches, rows, cols, channels}},
      {TensorType_FLOAT32, {channels, kernel_side, kernel_side, kernel_count}},
      {TensorType_FLOAT32, {}}, Padding_SAME, dilation);

  // Input image, every pixel set to one:
  //   | 1 | 1 | 1 |
  //   | 1 | 1 | 1 |
  //   | 1 | 1 | 1 |
  model.SetInput({1, 1, 1, 1, 1, 1, 1, 1, 1});

  // Filter taps:
  //   | 1 | 2 |
  //   | 3 | 4 |
  model.SetFilter({1, 2, 3, 4});

  // A zero bias keeps the expected values equal to the raw convolution sums.
  model.SetBias({0});

  model.Invoke();

  // Expected result, row by row:
  //   | 4 |  7 | 3 |
  //   | 6 | 10 | 4 |
  //   | 2 |  3 | 1 |
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray({4, 7, 3, 6, 10, 4, 2, 3, 1}));
}
class QuantizedDepthwiseConvolutionOpModel
: public BaseDepthwiseConvolutionOpModel {
public:
@ -188,13 +256,20 @@ class QuantizedDepthwiseConvolutionOpModel
}
};
// Parameterized fixture for the quantized (uint8) depthwise-convolution
// tests. It shares kKernelMap with the float fixture, so both suites use the
// same set of kernel registrations.
// NOTE(review): SingleOpTest presumably uses this map to turn the current
// test parameter into a registration; confirm in kernels/test_util.h.
class QuantizedDepthwiseConvolutionOpTest : public SingleOpTest {
protected:
// Returns the shared tag -> registration table.
const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
return *kKernelMap;
}
};
// In this test we set the input and output scales so that the results match
// exactly the 'non-quantized' version.
TEST(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) {
TEST_P(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) {
QuantizedDepthwiseConvolutionOpModel m(
{TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64},
GetRegistration(), {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64},
{TensorType_UINT8, {1, 2, 2, 4}, -63.5, 64},
{TensorType_UINT8, {}, -127, 128});
{TensorType_UINT8, {}, -127, 128}, Padding_VALID);
m.SetInput({
1, 2, 7, 8, // column 1
@ -224,15 +299,16 @@ TEST(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) {
}));
}
TEST(QuantizedDepthwiseConvolutionOpTest,
SimpleTestQuantizedFilterMultiplierGreaterThan1) {
TEST_P(QuantizedDepthwiseConvolutionOpTest,
SimpleTestQuantizedFilterMultiplierGreaterThan1) {
QuantizedDepthwiseConvolutionOpModel quant_op(
{TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64},
GetRegistration(), {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64},
{TensorType_UINT8, {1, 2, 2, 4}, -128.5, 128},
{TensorType_UINT8, {}, -127, 128});
DepthwiseConvolutionOpModel float_op({TensorType_FLOAT32, {1, 3, 2, 2}},
{TensorType_UINT8, {}, -127, 128}, Padding_VALID);
DepthwiseConvolutionOpModel float_op(GetRegistration(),
{TensorType_FLOAT32, {1, 3, 2, 2}},
{TensorType_FLOAT32, {1, 2, 2, 4}},
{TensorType_FLOAT32, {}});
{TensorType_FLOAT32, {}}, Padding_VALID);
std::initializer_list<float> input = {
1, 2, 7, 8, // column 1
@ -261,7 +337,7 @@ TEST(QuantizedDepthwiseConvolutionOpTest,
ElementsAreArray(ArrayFloatNear(float_op.GetOutput(), 1)));
}
TEST(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTest) {
TEST_P(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTestPaddingValid) {
const int depth = 1;
const int image_width = 9;
const int image_height = 9;
@ -270,6 +346,7 @@ TEST(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTest) {
const int filter_count = 1;
const int dilation_factor = 3;
QuantizedDepthwiseConvolutionOpModel m(
GetRegistration(),
{TensorType_UINT8,
{image_batch_count, image_height, image_width, depth},
0,
@ -278,7 +355,7 @@ TEST(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTest) {
{depth, filter_size, filter_size, filter_count},
0,
255},
{TensorType_UINT8, {}, 0, 255}, dilation_factor);
{TensorType_UINT8, {}, 0, 255}, Padding_VALID, dilation_factor);
// The image matrix is:
// | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
@ -319,6 +396,55 @@ TEST(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTest) {
ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
}
// Quantized counterpart of the dilated SAME-padding test: a 2x2 filter with
// dilation 2 applied to a 3x3 image of ones, run once for every kernel
// registration the fixture is instantiated with.
TEST_P(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTestPaddingSame) {
  // Problem geometry, named so that the tensor shapes below read naturally.
  const int channels = 1;
  const int cols = 3;
  const int rows = 3;
  const int batches = 1;
  const int kernel_side = 2;
  const int kernel_count = 1;
  const int dilation = 2;

  // Every tensor is declared with the range [0, 255]. Shapes: input is
  // {batch, height, width, depth}; filter is {depth, height, width, count}.
  QuantizedDepthwiseConvolutionOpModel model(
      GetRegistration(),
      {TensorType_UINT8, {batches, rows, cols, channels}, 0, 255},
      {TensorType_UINT8,
       {channels, kernel_side, kernel_side, kernel_count},
       0,
       255},
      {TensorType_UINT8, {}, 0, 255}, Padding_SAME, dilation);

  // Input image, every pixel set to one:
  //   | 1 | 1 | 1 |
  //   | 1 | 1 | 1 |
  //   | 1 | 1 | 1 |
  model.SetInput({1, 1, 1, 1, 1, 1, 1, 1, 1});

  // Filter taps:
  //   | 1 | 2 |
  //   | 3 | 4 |
  model.SetFilter({1, 2, 3, 4});

  // A zero bias keeps the expected values equal to the raw convolution sums.
  model.SetBias({0});

  model.Invoke();

  // Expected dequantized result, row by row:
  //   | 4 |  7 | 3 |
  //   | 6 | 10 | 4 |
  //   | 2 |  3 | 1 |
  EXPECT_THAT(model.GetDequantizedOutput(),
              ElementsAreArray({4, 7, 3, 6, 10, 4, 2, 3, 1}));
}
// Instantiate both parameterized suites with the values produced by
// SingleOpTest::GetKernelTags(*kKernelMap), so every TEST_P above runs once
// per returned tag.
// NOTE(review): the tags are presumably the keys of kKernelMap ("Reference",
// "GenericOptimized", "NeonOptimized"); confirm in kernels/test_util.h.
INSTANTIATE_TEST_CASE_P(
DepthwiseConvolutionOpTest, DepthwiseConvolutionOpTest,
::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
INSTANTIATE_TEST_CASE_P(
QuantizedDepthwiseConvolutionOpTest, QuantizedDepthwiseConvolutionOpTest,
::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
} // namespace
} // namespace tflite

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <vector>
#include <gtest/gtest.h>
#include "tensorflow/contrib/lite/kernels/internal/common.h"
#include "tensorflow/contrib/lite/kernels/internal/test_util.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
@ -28,23 +29,29 @@ namespace tflite {
namespace {
// Runs the DepthwiseConv and compares against the reference implementation.
template <FusedActivationFunctionType Ac>
void TestOneDepthwiseConv(const float* input_data, const Dims<4>& input_dims,
const float* filter_data, const Dims<4>& filter_dims,
const float* bias_data, const Dims<4>& bias_dims,
int stride, int pad_width, int pad_height,
int depth_multiplier, const Dims<4>& output_dims) {
int stride, int dilation_width_factor,
int dilation_height_factor, int pad_width,
int pad_height, int depth_multiplier,
float output_activation_min,
float output_activation_max,
const Dims<4>& output_dims) {
const int output_buffer_size = RequiredBufferSizeForDims(output_dims);
std::vector<float> output_data(output_buffer_size);
std::vector<float> reference_output_data(output_buffer_size);
reference_ops::DepthwiseConv<Ac>(input_data, input_dims, filter_data,
filter_dims, bias_data, bias_dims, stride,
pad_width, pad_height, depth_multiplier,
reference_output_data.data(), output_dims);
optimized_ops::DepthwiseConv<Ac>(input_data, input_dims, filter_data,
filter_dims, bias_data, bias_dims, stride,
pad_width, pad_height, depth_multiplier,
output_data.data(), output_dims);
reference_ops::DepthwiseConv(
input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
stride, stride, dilation_width_factor, dilation_height_factor, pad_width,
pad_height, depth_multiplier, output_activation_min,
output_activation_max, reference_output_data.data(), output_dims);
optimized_ops::DepthwiseConv(
input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
stride, stride, dilation_width_factor, dilation_height_factor, pad_width,
pad_height, depth_multiplier, output_activation_min,
output_activation_max, output_data.data(), output_dims);
double sum_abs_diff = 0;
float max_abs_val = 0;
for (int i = 0; i < output_buffer_size; i++) {
@ -59,27 +66,6 @@ void TestOneDepthwiseConv(const float* input_data, const Dims<4>& input_dims,
}
}
// Runtime-to-compile-time dispatcher for the fused activation type. It
// forwards all remaining arguments unchanged to the templated
// TestOneDepthwiseConv<AC_TYPE> whose template argument equals the `Ac` value
// supplied at run time.
//
// Each TOCO_HANDLE_CASE expansion compares `Ac` against one enumerator, runs
// the matching instantiation and returns. If `Ac` equals none of the four
// enumerators listed below, the function returns without running any
// comparison. The helper macro is undefined again once the cases are listed.
void TestOneDepthwiseConv(FusedActivationFunctionType Ac,
const float* input_data, const Dims<4>& input_dims,
const float* filter_data, const Dims<4>& filter_dims,
const float* bias_data, const Dims<4>& bias_dims,
int stride, int pad_width, int pad_height,
int depth_multiplier, const Dims<4>& output_dims) {
#define TOCO_HANDLE_CASE(AC_TYPE) \
if (AC_TYPE == Ac) { \
TestOneDepthwiseConv<AC_TYPE>(input_data, input_dims, filter_data, \
filter_dims, bias_data, bias_dims, stride, \
pad_width, pad_height, depth_multiplier, \
output_dims); \
return; \
}
TOCO_HANDLE_CASE(FusedActivationFunctionType::kNone)
TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu)
TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu1)
TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu6)
#undef TOCO_HANDLE_CASE
}
// This function picks some random DepthwiseConv params, which may or may not
// be legal. If they're not legal, it returns false. If they're legal,
// it runs the DepthwiseConv test and returns true. This allows the caller
@ -99,6 +85,16 @@ bool TryTestOneDepthwiseConv() {
const int depth_multiplier = ExponentialRandomPositiveInt(0.8f, 6, 50);
const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8);
const int output_depth = input_depth * depth_multiplier;
const int dilation_width_factor = RandomElement(std::vector<int>({1, 2, 4}));
const int dilation_height_factor = RandomElement(std::vector<int>({1, 2, 4}));
float output_activation_min, output_activation_max;
FusedActivationFunctionType ac =
RandomElement(std::vector<FusedActivationFunctionType>(
{FusedActivationFunctionType::kNone,
FusedActivationFunctionType::kRelu,
FusedActivationFunctionType::kRelu1,
FusedActivationFunctionType::kRelu6}));
GetActivationMinMax(ac, &output_activation_min, &output_activation_max);
// The optimized DepthwiseConv implementation currently uses a fixed-size
// accumulator buffer on the stack, with that size. This currently means
// that it does not support larger output depths. It CHECK's for it,
@ -109,10 +105,6 @@ bool TryTestOneDepthwiseConv() {
if (output_depth > kMaxSupportedOutputDepth) {
return false;
}
const auto ac = RandomElement(std::vector<FusedActivationFunctionType>(
{FusedActivationFunctionType::kNone, FusedActivationFunctionType::kRelu,
FusedActivationFunctionType::kRelu6,
FusedActivationFunctionType::kRelu1}));
Dims<4> input_dims_inference =
MakeDimsForInference(input_depth, input_width, input_height, batch);
Dims<4> output_dims_inference;
@ -120,7 +112,8 @@ bool TryTestOneDepthwiseConv() {
const auto padding_type =
UniformRandomInt(0, 1) ? PaddingType::kSame : PaddingType::kValid;
if (!ComputeConvSizes(input_dims_inference, output_depth, filter_width,
filter_height, stride, padding_type,
filter_height, stride, dilation_width_factor,
dilation_height_factor, padding_type,
&output_dims_inference, &pad_width, &pad_height)) {
return false;
}
@ -140,10 +133,12 @@ bool TryTestOneDepthwiseConv() {
FillRandom(&input_data, -input_amplitude, input_amplitude);
FillRandom(&filter_data, -filter_amplitude, filter_amplitude);
FillRandom(&bias_data, -bias_amplitude, bias_amplitude);
TestOneDepthwiseConv(ac, input_data.data(), input_dims_inference,
TestOneDepthwiseConv(input_data.data(), input_dims_inference,
filter_data.data(), filter_dims_inference,
bias_data.data(), bias_dims_inference, stride, pad_width,
pad_height, depth_multiplier, output_dims_inference);
bias_data.data(), bias_dims_inference, stride,
dilation_width_factor, dilation_height_factor, pad_width,
pad_height, depth_multiplier, output_activation_min,
output_activation_max, output_dims_inference);
return true;
}

View File

@ -199,6 +199,7 @@ void TestOneDepthwiseConv(
bool TryTestDepthwiseConv(int batch, int input_depth, int input_width,
int input_height, int filter_width, int filter_height,
int depth_multiplier, int stride,
int dilation_width_factor, int dilation_height_factor,
PaddingType padding_type) {
const int output_depth = input_depth * depth_multiplier;
// The optimized DepthwiseConv implementation currently uses a fixed-size
@ -231,7 +232,8 @@ bool TryTestDepthwiseConv(int batch, int input_depth, int input_width,
Dims<4> output_dims_inference;
int pad_width, pad_height;
if (!ComputeConvSizes(input_dims_inference, output_depth, filter_width,
filter_height, stride, padding_type,
filter_height, stride, dilation_width_factor,
dilation_height_factor, padding_type,
&output_dims_inference, &pad_width, &pad_height)) {
return false;
}
@ -274,12 +276,15 @@ bool TryTestOneDepthwiseConv() {
const int filter_height = ExponentialRandomPositiveInt(0.9f, 4, 10);
const int depth_multiplier = ExponentialRandomPositiveInt(0.8f, 6, 50);
const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8);
const int dilation_width_factor = RandomElement(std::vector<int>({1, 2, 4}));
const int dilation_height_factor = RandomElement(std::vector<int>({1, 2, 4}));
const auto padding_type =
UniformRandomInt(0, 1) ? PaddingType::kSame : PaddingType::kValid;
return TryTestDepthwiseConv(batch, input_depth, input_width, input_height,
filter_width, filter_height, depth_multiplier,
stride, padding_type);
stride, dilation_width_factor,
dilation_height_factor, padding_type);
}
// Tests parameters for the 3x3 filter kernel.
@ -292,6 +297,9 @@ bool TryTestOneDepthwiseConv3x3Filter() {
const int filter_height = 3;
const int depth_multiplier = 1;
const int stride = UniformRandomInt(1, 2);
// We don't support dilations in the 3x3 filter.
const int dilation_width_factor = 1;
const int dilation_height_factor = 1;
// Although the kernel supports only kValid padding, we test that kSame
// is using the correct code path.
const auto padding_type =
@ -299,7 +307,8 @@ bool TryTestOneDepthwiseConv3x3Filter() {
return TryTestDepthwiseConv(batch, input_depth, input_width, input_height,
filter_width, filter_height, depth_multiplier,
stride, padding_type);
stride, dilation_width_factor,
dilation_height_factor, padding_type);
}
void TestOneDepthwiseConv() {

View File

@ -761,7 +761,8 @@ struct FloatDepthwiseConvKernel<true, 4, 1> {
// Accumulates the effect of one row of the filter, on a segment of one row
// of the output, accessing the corresponding one row of the input.
template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
void FloatDepthwiseConvAccumRow(int stride, int input_depth, int input_width,
void FloatDepthwiseConvAccumRow(int stride, int dilation_factor,
int input_depth, int input_width,
const float* input_data, int pad_width,
int depth_multiplier, int filter_width,
const float* filter_data,
@ -835,10 +836,10 @@ void FloatDepthwiseConvAccumRow(int stride, int input_depth, int input_width,
// generic fallback of FloatDepthwiseConvAccumRow, portable, non-templatized.
inline void FloatDepthwiseConvAccumRowGeneric(
int stride, int input_depth, int input_width, const float* input_data,
int pad_width, int depth_multiplier, int filter_width,
const float* filter_data, int out_x_buffer_start, int out_x_buffer_end,
int output_depth, float* acc_buffer) {
int stride, int dilation_factor, int input_depth, int input_width,
const float* input_data, int pad_width, int depth_multiplier,
int filter_width, const float* filter_data, int out_x_buffer_start,
int out_x_buffer_end, int output_depth, float* acc_buffer) {
gemmlowp::ScopedProfilingLabel label("DepthwiseConvAccumRowGeneric (slow)");
#ifdef TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
@ -860,6 +861,7 @@ inline void FloatDepthwiseConvAccumRowGeneric(
<< "* stride = " << stride << "\n"
<< "* input_depth = " << input_depth << "\n"
<< "* depth_multiplier = " << depth_multiplier << "\n"
<< "* dilation_factor = " << dilation_factor << "\n"
<< "*\n"
<< "* Please do not hesitate to contact benoitjacob@ with this\n"
<< "* information.\n"
@ -869,14 +871,17 @@ inline void FloatDepthwiseConvAccumRowGeneric(
const float* filter_base_ptr = filter_data;
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
const int out_x_loop_start = std::max(
out_x_buffer_start, (pad_width - filter_x + stride - 1) / stride);
const int out_x_loop_end =
std::min(out_x_buffer_end,
(pad_width + input_width - filter_x + stride - 1) / stride);
out_x_buffer_start,
(pad_width - dilation_factor * filter_x + stride - 1) / stride);
const int out_x_loop_end = std::min(
out_x_buffer_end,
(pad_width + input_width - dilation_factor * filter_x + stride - 1) /
stride);
float* acc_buffer_ptr =
acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x;
const int in_x_origin =
(out_x_loop_start * stride) - pad_width + dilation_factor * filter_x;
const float* input_ptr = input_data + in_x_origin * input_depth;
const int input_ptr_increment = (stride - 1) * input_depth;
for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) {
@ -921,14 +926,14 @@ inline void DepthwiseConv(
const int depth_multiplier = params.depth_multiplier;
const float output_activation_min = params.float_activation_min;
const float output_activation_max = params.float_activation_max;
const int dilation_width_factor = params.dilation_width_factor;
const int dilation_height_factor = params.dilation_height_factor;
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
// TODO(suharshs): Optimized implementation of dilation depthwise conv need to
// be implemented.
TFLITE_DCHECK_EQ(params.dilation_width_factor, 1);
TFLITE_DCHECK_EQ(params.dilation_height_factor, 1);
const bool has_dilation = (params.dilation_width_factor != 1) ||
(params.dilation_height_factor != 1);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
@ -961,7 +966,7 @@ inline void DepthwiseConv(
FIXED_DEPTH_MULTIPLIER) \
if (!row_accum_func && (stride_width == 1 || ALLOW_STRIDED) && \
(input_depth == FIXED_INPUT_DEPTH || FIXED_INPUT_DEPTH == 0) && \
depth_multiplier == FIXED_DEPTH_MULTIPLIER) { \
depth_multiplier == FIXED_DEPTH_MULTIPLIER && !has_dilation) { \
row_accum_func = \
FloatDepthwiseConvAccumRow<ALLOW_STRIDED, FIXED_INPUT_DEPTH, \
FIXED_DEPTH_MULTIPLIER>; \
@ -1014,9 +1019,13 @@ inline void DepthwiseConv(
for (int b = 0; b < batches; ++b) {
for (int out_y = 0; out_y < output_height; ++out_y) {
const int in_y_origin = (out_y * stride_height) - pad_height;
const int filter_y_start = std::max(0, -in_y_origin);
const int filter_y_start =
std::max(0, (-in_y_origin + dilation_height_factor - 1) /
dilation_height_factor);
const int filter_y_end =
std::min(filter_height, input_height - in_y_origin);
std::min(filter_height,
(input_height - in_y_origin + dilation_height_factor - 1) /
dilation_height_factor);
for (int out_x_buffer_start = 0; out_x_buffer_start < output_width;
out_x_buffer_start += kOutputPixelsInAccBuffer) {
const int out_x_buffer_end = std::min(
@ -1032,9 +1041,9 @@ inline void DepthwiseConv(
// Accumulation loop. Most of the time should be spent in here.
for (int filter_y = filter_y_start; filter_y < filter_y_end;
++filter_y) {
const int in_y = in_y_origin + filter_y;
const int in_y = in_y_origin + dilation_height_factor * filter_y;
row_accum_func(
stride_width, input_depth, input_width,
stride_width, dilation_width_factor, input_depth, input_width,
input_data + in_y * input_height_stride + b * input_batch_stride,
pad_width, depth_multiplier, filter_width,
filter_data + filter_y * filter_height_stride, out_x_buffer_start,
@ -1096,11 +1105,6 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
float output_activation_min,
float output_activation_max, float* output_data,
const Dims<4>& output_dims) {
// TODO(suharshs): Optimized implementation of dilation depthwise conv need to
// be implemented.
TFLITE_DCHECK_EQ(dilation_width_factor, 1);
TFLITE_DCHECK_EQ(dilation_height_factor, 1);
tflite::DepthwiseParams op_params;
// Padding type is ignored, but still set.
op_params.padding_type = PaddingType::kSame;

View File

@ -1466,11 +1466,14 @@ struct QuantizedDepthwiseConvKernel<false, 12, 1> {
// Accumulates the effect of one row of the filter, on a segment of one row
// of the output, accessing the corresponding one row of the input.
template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
void QuantizedDepthwiseConvAccumRow(
int stride, int input_depth, int input_width, const uint8* input_data,
int16 input_offset, int pad_width, int depth_multiplier, int filter_width,
const uint8* filter_data, int16 filter_offset, int out_x_buffer_start,
int out_x_buffer_end, int output_depth, int32* acc_buffer) {
void QuantizedDepthwiseConvAccumRow(int stride, int dilation_factor,
int input_depth, int input_width,
const uint8* input_data, int16 input_offset,
int pad_width, int depth_multiplier,
int filter_width, const uint8* filter_data,
int16 filter_offset, int out_x_buffer_start,
int out_x_buffer_end, int output_depth,
int32* acc_buffer) {
#ifdef GEMMLOWP_PROFILING
gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__);
#endif
@ -1537,10 +1540,11 @@ void QuantizedDepthwiseConvAccumRow(
// generic fallback of DepthwiseConvAccumRow, portable, non-templatized.
inline void QuantizedDepthwiseConvAccumRowGeneric(
int stride, int input_depth, int input_width, const uint8* input_data,
int16 input_offset, int pad_width, int depth_multiplier, int filter_width,
const uint8* filter_data, int16 filter_offset, int out_x_buffer_start,
int out_x_buffer_end, int output_depth, int32* acc_buffer) {
int stride, int dilation_factor, int input_depth, int input_width,
const uint8* input_data, int16 input_offset, int pad_width,
int depth_multiplier, int filter_width, const uint8* filter_data,
int16 filter_offset, int out_x_buffer_start, int out_x_buffer_end,
int output_depth, int32* acc_buffer) {
gemmlowp::ScopedProfilingLabel label("DepthwiseConvAccumRowGeneric (slow)");
#ifdef TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
@ -1562,6 +1566,7 @@ inline void QuantizedDepthwiseConvAccumRowGeneric(
<< "* stride = " << stride << "\n"
<< "* input_depth = " << input_depth << "\n"
<< "* depth_multiplier = " << depth_multiplier << "\n"
<< "* dilation_factor = " << dilation_factor << "\n"
<< "*\n"
<< "* Please do not hesitate to contact benoitjacob@ with this\n"
<< "* information.\n"
@ -1571,14 +1576,17 @@ inline void QuantizedDepthwiseConvAccumRowGeneric(
const uint8* filter_base_ptr = filter_data;
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
const int out_x_loop_start = std::max(
out_x_buffer_start, (pad_width - filter_x + stride - 1) / stride);
const int out_x_loop_end =
std::min(out_x_buffer_end,
(pad_width + input_width - filter_x + stride - 1) / stride);
out_x_buffer_start,
(pad_width - dilation_factor * filter_x + stride - 1) / stride);
const int out_x_loop_end = std::min(
out_x_buffer_end,
(pad_width + input_width - dilation_factor * filter_x + stride - 1) /
stride);
int32* acc_buffer_ptr =
acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x;
const int in_x_origin =
(out_x_loop_start * stride) - pad_width + dilation_factor * filter_x;
const uint8* input_ptr = input_data + in_x_origin * input_depth;
const int input_ptr_increment = (stride - 1) * input_depth;
for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) {
@ -1688,15 +1696,11 @@ inline void DepthwiseConv(
const int32 output_offset = params.output_offset;
const int32 output_multiplier = params.output_multiplier;
const int output_shift = params.output_shift;
const int dilation_width_factor = params.dilation_width_factor;
const int dilation_height_factor = params.dilation_height_factor;
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
// TODO(suharshs): Optimized implementation of dilation depthwise conv need to
// be implemented.
TFLITE_DCHECK_EQ(params.dilation_width_factor, 1);
TFLITE_DCHECK_EQ(params.dilation_height_factor, 1);
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
const int batches = MatchingDim(input_shape, 0, output_shape, 0);
const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
@ -1714,14 +1718,18 @@ inline void DepthwiseConv(
TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
const bool has_dilation =
(dilation_width_factor != 1) || (dilation_height_factor != 1);
// Enable for arm64 except for the Nvidia Linux 4 Tegra (L4T) running on
// Jetson TX-2. This compiler does not support the offsetof() macro.
#if defined(__aarch64__) && !defined(GOOGLE_L4T)
// Call kernel optimized for depthwise convolutions using 3x3 filters if
// parameters are supported.
if (Fast3x3FilterKernelSupported(
input_shape, filter_shape, stride_width, stride_height, pad_width,
pad_height, depth_multiplier, output_shape, output_shift)) {
if (Fast3x3FilterKernelSupported(input_shape, filter_shape, stride_width,
stride_height, has_dilation, pad_width,
pad_height, depth_multiplier, output_shape,
output_shift)) {
DepthwiseConv3x3Filter(params, input_shape, input_data, filter_shape,
filter_data, bias_shape, bias_data, output_shape,
output_data);
@ -1748,7 +1756,7 @@ inline void DepthwiseConv(
FIXED_DEPTH_MULTIPLIER) \
if (!row_accum_func && (stride_width == 1 || ALLOW_STRIDED) && \
(input_depth == FIXED_INPUT_DEPTH || FIXED_INPUT_DEPTH == 0) && \
depth_multiplier == FIXED_DEPTH_MULTIPLIER) { \
depth_multiplier == FIXED_DEPTH_MULTIPLIER && !has_dilation) { \
row_accum_func = \
QuantizedDepthwiseConvAccumRow<ALLOW_STRIDED, FIXED_INPUT_DEPTH, \
FIXED_DEPTH_MULTIPLIER>; \
@ -1808,9 +1816,13 @@ inline void DepthwiseConv(
for (int b = 0; b < batches; ++b) {
for (int out_y = 0; out_y < output_height; ++out_y) {
const int in_y_origin = (out_y * stride_height) - pad_height;
const int filter_y_start = std::max(0, -in_y_origin);
const int filter_y_start =
std::max(0, (-in_y_origin + dilation_height_factor - 1) /
dilation_height_factor);
const int filter_y_end =
std::min(filter_height, input_height - in_y_origin);
std::min(filter_height,
(input_height - in_y_origin + dilation_height_factor - 1) /
dilation_height_factor);
for (int out_x_buffer_start = 0; out_x_buffer_start < output_width;
out_x_buffer_start += kOutputPixelsInAccBuffer) {
const int out_x_buffer_end = std::min(
@ -1826,9 +1838,9 @@ inline void DepthwiseConv(
// Accumulation loop. Most of the time should be spent in here.
for (int filter_y = filter_y_start; filter_y < filter_y_end;
++filter_y) {
const int in_y = in_y_origin + filter_y;
const int in_y = in_y_origin + dilation_height_factor * filter_y;
row_accum_func(
stride_width, input_depth, input_width,
stride_width, dilation_width_factor, input_depth, input_width,
input_data + in_y * input_height_stride + b * input_batch_stride,
input_offset, pad_width, depth_multiplier, filter_width,
filter_data + filter_y * filter_height_stride, filter_offset,

View File

@ -3176,8 +3176,8 @@ inline void DepthwiseConvHandlePadding(const uint8* input_data,
inline bool Fast3x3FilterKernelSupported(
const RuntimeShape& input_shape, const RuntimeShape& filter_shape,
int32 stride_width, int32 stride_height, int32 pad_width, int32 pad_height,
int32 depth_multiplier, const RuntimeShape& output_shape,
int32 stride_width, int32 stride_height, bool has_dilation, int32 pad_width,
int32 pad_height, int32 depth_multiplier, const RuntimeShape& output_shape,
int32 output_shift) {
const int32 input_height = input_shape.Dims(1);
const int32 input_width = input_shape.Dims(2);
@ -3193,7 +3193,7 @@ inline bool Fast3x3FilterKernelSupported(
(stride_height == 1 || stride_height == 2) &&
(stride_width == stride_height) && (pad_width == 0 || pad_width == 1) &&
(pad_height == 0 || pad_height == 1) && (pad_width == pad_height) &&
(input_depth % 8) == 0 && (output_shift > 0);
(input_depth % 8) == 0 && (output_shift > 0) && !has_dilation;
if (!supported) {
return false;

View File

@ -43,17 +43,21 @@ Dims<4> MakeDimsForInference(int depth, int width, int height, int batch) {
// This is copied from an internal function in propagate_fixed_sizes.cc.
bool ComputeConvSizes(Dims<4> input_dims, int output_depth, int filter_width,
int filter_height, int stride, PaddingType padding_type,
int filter_height, int stride, int dilation_width_factor,
int dilation_height_factor, PaddingType padding_type,
Dims<4>* output_dims, int* pad_width, int* pad_height) {
const int input_width = ArraySize(input_dims, 1);
const int input_height = ArraySize(input_dims, 2);
const int batch = ArraySize(input_dims, 3);
int dilated_filter_width = dilation_width_factor * (filter_width - 1) + 1;
int dilated_filter_height = dilation_height_factor * (filter_height - 1) + 1;
int output_height = 0;
int output_width = 0;
if (padding_type == PaddingType::kValid) {
output_height = (input_height + stride - filter_height) / stride;
output_width = (input_width + stride - filter_width) / stride;
output_height = (input_height + stride - dilated_filter_height) / stride;
output_width = (input_width + stride - dilated_filter_width) / stride;
} else if (padding_type == PaddingType::kSame) {
output_height = (input_height + stride - 1) / stride;
output_width = (input_width + stride - 1) / stride;
@ -65,9 +69,13 @@ bool ComputeConvSizes(Dims<4> input_dims, int output_depth, int filter_width,
return false;
}
*pad_height =
((output_height - 1) * stride + filter_height - input_height) / 2;
*pad_width = ((output_width - 1) * stride + filter_width - input_width) / 2;
*pad_height = std::max(
0, ((output_height - 1) * stride + dilated_filter_height - input_height) /
2);
*pad_width = std::max(
0,
((output_width - 1) * stride + dilated_filter_width - input_width) / 2);
*output_dims =
MakeDimsForInference(output_depth, output_width, output_height, batch);
return true;

View File

@ -31,7 +31,8 @@ Dims<4> MakeDimsForInference(int depth, int width, int height, int batch);
// Computes output and padding dimensions.
bool ComputeConvSizes(Dims<4> input_dims, int output_depth, int filter_width,
int filter_height, int stride, PaddingType padding_type,
int filter_height, int stride, int dilation_width_factor,
int dilation_height_factor, PaddingType padding_type,
Dims<4>* output_dims, int* pad_width, int* pad_height);
// Returns a mt19937 random engine.