Port the depthwise conv kernel to the new TfLiteEvalTensor API.

PiperOrigin-RevId: 322860697
Change-Id: I7858a70fd0b756ae269c3df63312f09aa3643d5d
Authored by Nick Kreeger on 2020-07-23 14:00:26 -07:00; committed by TensorFlower Gardener
parent 2587c2a1f2
commit f300cac524
3 changed files with 60 additions and 56 deletions

tensorflow/lite/micro/kernels/BUILD

@@ -158,6 +158,7 @@ tflite_micro_cc_test(
         "depthwise_conv_test.cc",
     ],
     deps = [
+        ":kernel_runner",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/kernels/internal:tensor",
         "//tensorflow/lite/micro:op_resolvers",

tensorflow/lite/micro/kernels/depthwise_conv.cc

@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/kernels/padding.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
 
 namespace tflite {
 namespace ops {
@@ -165,8 +166,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 
 void EvalFloat(TfLiteContext* context, TfLiteNode* node,
                TfLiteDepthwiseConvParams* params, const OpData& data,
-               const TfLiteTensor* input, const TfLiteTensor* filter,
-               const TfLiteTensor* bias, TfLiteTensor* output) {
+               const TfLiteEvalTensor* input, const TfLiteEvalTensor* filter,
+               const TfLiteEvalTensor* bias, TfLiteEvalTensor* output) {
   float output_activation_min, output_activation_max;
   CalculateActivationRange(params->activation, &output_activation_min,
                            &output_activation_max);
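
Note: the signature change above is the core of the port. TfLiteTensor carries full allocation and quantization metadata, while TfLiteEvalTensor is the slimmed-down view TFLM hands kernels at invoke time. A minimal sketch of the struct as declared in tensorflow/lite/c/common.h around this revision (treat the exact field set as an assumption):

// Sketch (assumed layout) of the invoke-time tensor type.
typedef struct TfLiteEvalTensor {
  TfLitePtrUnion data;   // Raw buffer, accessed via data.f, data.int8, etc.
  TfLiteIntArray* dims;  // Tensor shape; not owned by this struct.
  TfLiteType type;       // Element type, e.g. kTfLiteFloat32 or kTfLiteInt8.
} TfLiteEvalTensor;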
@@ -185,17 +186,22 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
   op_params.float_activation_max = output_activation_max;
   tflite::reference_ops::DepthwiseConv(
-      op_params, GetTensorShape(input), GetTensorData<float>(input),
-      GetTensorShape(filter), GetTensorData<float>(filter),
-      GetTensorShape(bias), GetTensorData<float>(bias), GetTensorShape(output),
-      GetTensorData<float>(output));
+      op_params, tflite::micro::GetTensorShape(input),
+      tflite::micro::GetTensorData<float>(input),
+      tflite::micro::GetTensorShape(filter),
+      tflite::micro::GetTensorData<float>(filter),
+      tflite::micro::GetTensorShape(bias),
+      tflite::micro::GetTensorData<float>(bias),
+      tflite::micro::GetTensorShape(output),
+      tflite::micro::GetTensorData<float>(output));
 }
 
 void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
                              TfLiteDepthwiseConvParams* params,
-                             const OpData& data, const TfLiteTensor* input,
-                             const TfLiteTensor* filter,
-                             const TfLiteTensor* bias, TfLiteTensor* output) {
+                             const OpData& data, const TfLiteEvalTensor* input,
+                             const TfLiteEvalTensor* filter,
+                             const TfLiteEvalTensor* bias,
+                             TfLiteEvalTensor* output) {
   DepthwiseParams op_params;
   op_params.padding_type = PaddingType::kSame;
   op_params.padding_values.width = data.padding.width;
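
Note: every call site changes the same mechanical way: the unqualified GetTensorShape/GetTensorData helpers, which take TfLiteTensor, are swapped for the tflite::micro:: overloads that take TfLiteEvalTensor. A sketch of those helpers, assuming the tensorflow/lite/micro/kernels/kernel_util.h of this revision:

// Sketch (signatures assumed) of the TfLiteEvalTensor accessors.
namespace tflite {
namespace micro {

template <typename T>
const T* GetTensorData(const TfLiteEvalTensor* tensor) {
  return tensor != nullptr ? reinterpret_cast<const T*>(tensor->data.raw)
                           : nullptr;
}

// A null tensor (e.g. an omitted bias) yields an empty shape, so callers
// can pass optional inputs through unconditionally.
const RuntimeShape GetTensorShape(const TfLiteEvalTensor* tensor);

}  // namespace micro
}  // namespace tflite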
@@ -214,17 +220,21 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
   reference_integer_ops::DepthwiseConvPerChannel(
       op_params, data.per_channel_output_multiplier,
-      data.per_channel_output_shift, GetTensorShape(input),
-      GetTensorData<int8>(input), GetTensorShape(filter),
-      GetTensorData<int8>(filter), GetTensorShape(bias),
-      GetTensorData<int32>(bias), GetTensorShape(output),
-      GetTensorData<int8>(output));
+      data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
+      tflite::micro::GetTensorData<int8>(input),
+      tflite::micro::GetTensorShape(filter),
+      tflite::micro::GetTensorData<int8>(filter),
+      tflite::micro::GetTensorShape(bias),
+      tflite::micro::GetTensorData<int32>(bias),
+      tflite::micro::GetTensorShape(output),
+      tflite::micro::GetTensorData<int8>(output));
 }
 
 void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                    TfLiteDepthwiseConvParams* params, const OpData& data,
-                   const TfLiteTensor* input, const TfLiteTensor* filter,
-                   const TfLiteTensor* bias, TfLiteTensor* output) {
+                   const TfLiteEvalTensor* input,
+                   const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias,
+                   TfLiteEvalTensor* output) {
   const int32_t input_offset = -data.input_zero_point;
   const int32_t filter_offset = -data.filter_zero_point;
   const int32_t output_offset = data.output_zero_point;
@@ -249,10 +259,14 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
   op_params.output_shift = -data.output_shift;
   tflite::reference_ops::DepthwiseConv(
-      op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
-      GetTensorShape(filter), GetTensorData<uint8_t>(filter),
-      GetTensorShape(bias), GetTensorData<int32_t>(bias),
-      GetTensorShape(output), GetTensorData<uint8_t>(output));
+      op_params, tflite::micro::GetTensorShape(input),
+      tflite::micro::GetTensorData<uint8_t>(input),
+      tflite::micro::GetTensorShape(filter),
+      tflite::micro::GetTensorData<uint8_t>(filter),
+      tflite::micro::GetTensorShape(bias),
+      tflite::micro::GetTensorData<int32_t>(bias),
+      tflite::micro::GetTensorShape(output),
+      tflite::micro::GetTensorData<uint8_t>(output));
 }
 
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
@@ -263,11 +277,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data);
   const OpData& data = *(static_cast<const OpData*>(node->user_data));
 
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
-  const TfLiteTensor* bias =
-      (NumInputs(node) == 3) ? GetInput(context, node, kBiasTensor) : nullptr;
+  TfLiteEvalTensor* output =
+      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+  const TfLiteEvalTensor* input =
+      tflite::micro::GetEvalInput(context, node, kInputTensor);
+  const TfLiteEvalTensor* filter =
+      tflite::micro::GetEvalInput(context, node, kFilterTensor);
+  const TfLiteEvalTensor* bias =
+      (NumInputs(node) == 3)
+          ? tflite::micro::GetEvalInput(context, node, kBiasTensor)
+          : nullptr;
 
   // TODO(aselle): Consider whether float conv and quantized conv should be
   // separate ops to avoid dispatch overhead here.
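
Note: GetEvalInput and GetEvalOutput resolve tensor indices through the TfLiteContext::GetEvalTensor callback instead of the full TfLiteTensor array. A plausible sketch of what they expand to (assumed; the real helpers live in tensorflow/lite/micro/kernels/kernel_util.h):

// Sketch (assumed): index-based eval-tensor lookup via the context callback.
inline const TfLiteEvalTensor* GetEvalInput(const TfLiteContext* context,
                                            const TfLiteNode* node,
                                            int index) {
  return context->GetEvalTensor(context, node->inputs->data[index]);
}

inline TfLiteEvalTensor* GetEvalOutput(const TfLiteContext* context,
                                       const TfLiteNode* node, int index) {
  return context->GetEvalTensor(context, node->outputs->data[index]);
}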

tensorflow/lite/micro/kernels/depthwise_conv_test.cc

@@ -16,7 +16,7 @@ limitations under the License.
 
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
-#include "tensorflow/lite/micro/all_ops_resolver.h"
+#include "tensorflow/lite/micro/kernels/kernel_runner.h"
 #include "tensorflow/lite/micro/testing/micro_test.h"
 #include "tensorflow/lite/micro/testing/test_utils.h"
@@ -42,13 +42,16 @@ TfLiteStatus ValidateDepthwiseConvGoldens(
     const T* expected_output_data, int output_length,
     TfLiteDepthwiseConvParams* conv_params, float tolerance, int tensors_size,
     TfLiteTensor* tensors) {
-  TfLiteContext context;
-  PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
+  int inputs_array_data[] = {3, 0, 1, 2};
+  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
+  int outputs_array_data[] = {1, 3};
+  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
 
-  ::tflite::AllOpsResolver resolver;
-  const TfLiteRegistration* registration =
-      resolver.FindOp(tflite::BuiltinOperator_DEPTHWISE_CONV_2D);
-  TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+  const TfLiteRegistration registration =
+      tflite::ops::micro::Register_DEPTHWISE_CONV_2D();
+  micro::KernelRunner runner(
+      registration, tensors, tensors_size, inputs_array, outputs_array,
+      reinterpret_cast<void*>(conv_params), micro_test::reporter);
 
   int input_depth = tensors[0].dims->data[3];
   int output_depth = tensors[1].dims->data[3];
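
Note: KernelRunner replaces the hand-built TfLiteContext/TfLiteNode scaffolding the old test carried (deleted in the next hunk); it owns a fake context that provides GetEvalTensor, which is what lets the ported kernel run under test. Its public surface, sketched from how it is used here (exact signatures are an assumption):

// Sketch (assumed) of tensorflow/lite/micro/kernels/kernel_runner.h.
class KernelRunner {
 public:
  KernelRunner(const TfLiteRegistration& registration, TfLiteTensor* tensors,
               int tensors_size, TfLiteIntArray* inputs,
               TfLiteIntArray* outputs, void* builtin_data,
               ErrorReporter* error_reporter);

  // Runs registration.init and registration.prepare on the owned node,
  // returning the first non-ok status.
  TfLiteStatus InitAndPrepare(const char* init_data = nullptr);

  // Runs registration.invoke on the owned node.
  TfLiteStatus Invoke();
};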
@@ -60,32 +63,13 @@ TfLiteStatus ValidateDepthwiseConvGoldens(
   conv_params->depth_multiplier = depth_mul;
 
   const char* init_data = reinterpret_cast<const char*>(conv_params);
-  size_t init_data_size = 0;
-  void* user_data = nullptr;
-  if (registration->init) {
-    user_data = registration->init(&context, init_data, init_data_size);
-  }
-
-  int inputs_array_data[] = {3, 0, 1, 2};
-  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
-  int outputs_array_data[] = {1, 3};
-  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
-
-  TfLiteNode node;
-  node.inputs = inputs_array;
-  node.outputs = outputs_array;
-  node.user_data = user_data;
-  node.builtin_data = reinterpret_cast<void*>(conv_params);
-  node.custom_initial_data = nullptr;
-  node.custom_initial_data_size = 0;
-
-  if (registration->prepare) {
-    TF_LITE_ENSURE_OK(context, registration->prepare(&context, &node));
-  }
-  TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
-  TF_LITE_ENSURE_OK(context, registration->invoke(&context, &node));
-  if (registration->free) {
-    registration->free(&context, user_data);
+
+  // TODO(b/154240825): Use a test macro here which fails and returns.
+  TfLiteStatus status = runner.InitAndPrepare(init_data);
+  if (status != kTfLiteOk) {
+    return status;
   }
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
 
   const T* output_data = tflite::GetTensorData<T>(&tensors[kOutputTensorIndex]);
+
   for (int i = 0; i < output_length; ++i) {
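
Note: the diff is cut off inside the golden-value loop. In tests of this shape the body compares each output element against the expected value within the caller-supplied tolerance, presumably along these lines (assumed reconstruction, not part of the diff):

for (int i = 0; i < output_length; ++i) {
  TF_LITE_MICRO_EXPECT_NEAR(expected_output_data[i], output_data[i],
                            tolerance);
}
return kTfLiteOk;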