Add prelu op for micro

PiperOrigin-RevId: 259473219
This commit is contained in:
A. Unique TensorFlower 2019-07-22 22:46:03 -07:00 committed by TensorFlower Gardener
parent 1de23834be
commit 8281648f9c
8 changed files with 416 additions and 47 deletions

View File

@ -19,6 +19,7 @@ cc_library(
"elementwise.cc",
"fully_connected.cc",
"pooling.cc",
"prelu.cc",
"softmax.cc",
],
hdrs = [
@ -59,6 +60,7 @@ cc_library(
"fully_connected.cc",
"pooling.cc",
"portable_optimized/depthwise_conv.cc",
"prelu.cc",
"softmax.cc",
],
hdrs = [
@ -179,3 +181,16 @@ tflite_micro_cc_test(
"//tensorflow/lite/experimental/micro/testing:micro_test",
],
)
tflite_micro_cc_test(
name = "prelu_test",
srcs = [
"prelu_test.cc",
],
deps = [
":all_ops_resolver",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/experimental/micro:micro_framework",
"//tensorflow/lite/experimental/micro/testing:micro_test",
],
)

View File

@ -23,6 +23,7 @@ TfLiteRegistration* Register_CONV_2D();
TfLiteRegistration* Register_AVERAGE_POOL_2D();
TfLiteRegistration* Register_MAX_POOL_2D();
TfLiteRegistration* Register_ABS();
TfLiteRegistration* Register_PRELU();
AllOpsResolver::AllOpsResolver() {
AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D());
@ -34,6 +35,7 @@ AllOpsResolver::AllOpsResolver() {
AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D());
AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, Register_AVERAGE_POOL_2D());
AddBuiltin(BuiltinOperator_ABS, Register_ABS());
AddBuiltin(BuiltinOperator_PRELU, Register_PRELU());
}
} // namespace micro

View File

@ -0,0 +1,114 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/prelu.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace micro {
namespace activations {
TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
inline void BroadcastPrelu4DSlowFloat(
const RuntimeShape& unextended_input1_shape, const float* input1_data,
const RuntimeShape& unextended_input2_shape, const float* input2_data,
const RuntimeShape& unextended_output_shape, float* output_data) {
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
unextended_input2_shape, &desc1, &desc2);
for (int b = 0; b < output_shape.Dims(0); ++b) {
for (int y = 0; y < output_shape.Dims(1); ++y) {
for (int x = 0; x < output_shape.Dims(2); ++x) {
for (int c = 0; c < output_shape.Dims(3); ++c) {
auto out_idx = Offset(output_shape, b, y, x, c);
auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
auto in1_val = input1_data[in1_idx];
auto in2_val = input2_data[in2_idx];
output_data[out_idx] = in1_val >= 0.0 ? in1_val : in1_val * in2_val;
}
}
}
}
}
TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
const TfLiteTensor* alpha = GetInput(context, node, 1);
TfLiteTensor* output = GetOutput(context, node, 0);
int32_t output_multiplier = 0;
int output_shift = 0;
if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) {
double real_multiplier =
input->params.scale * alpha->params.scale / output->params.scale;
QuantizeMultiplierSmallerThanOneExp(real_multiplier, &output_multiplier,
&output_shift);
}
switch (input->type) {
case kTfLiteFloat32: {
BroadcastPrelu4DSlowFloat(
GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(alpha), GetTensorData<float>(alpha),
GetTensorShape(output), GetTensorData<float>(output));
return kTfLiteOk;
} break;
case kTfLiteUInt8: {
PreluParams op_params;
op_params.input_offset = -input->params.zero_point;
op_params.alpha_offset = -alpha->params.zero_point;
op_params.output_offset = output->params.zero_point;
op_params.output_multiplier = output_multiplier;
op_params.output_shift = output_shift;
reference_ops::BroadcastPrelu4DSlow(
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
GetTensorShape(alpha), GetTensorData<uint8_t>(alpha),
GetTensorShape(output), GetTensorData<uint8_t>(output));
return kTfLiteOk;
} break;
default:
context->ReportError(
context, "Only float32 and uint8 are supported currently, got %d.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}
} // namespace activations
TfLiteRegistration* Register_PRELU() {
static TfLiteRegistration r = {nullptr, nullptr, activations::PreluPrepare,
activations::PreluEval};
return &r;
}
} // namespace micro
} // namespace ops
} // namespace tflite

View File

@ -0,0 +1,204 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
#include "tensorflow/lite/experimental/micro/simple_tensor_allocator.h"
#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
#include "tensorflow/lite/experimental/micro/testing/test_utils.h"
namespace tflite {
namespace testing {
namespace {
void TestPreluFloat(std::initializer_list<int> input_dims_data,
std::initializer_list<float> input_data,
std::initializer_list<int> alpha_dims_data,
std::initializer_list<float> alpha_data,
std::initializer_list<float> expected_output_data,
std::initializer_list<int> output_dims_data,
float* output_data) {
TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
TfLiteIntArray* alpha_dims = IntArrayFromInitializer(alpha_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
const int output_dims_count = ElementCount(*output_dims);
constexpr int inputs_size = 2;
constexpr int outputs_size = 1;
constexpr int tensors_size = inputs_size + outputs_size;
TfLiteTensor tensors[tensors_size] = {
CreateFloatTensor(input_data, input_dims, "input_tensor"),
CreateFloatTensor(alpha_data, alpha_dims, "alpha_tensor"),
CreateFloatTensor(output_data, output_dims, "output_tensor"),
};
TfLiteContext context;
PopulateContext(tensors, tensors_size, &context);
::tflite::ops::micro::AllOpsResolver resolver;
const TfLiteRegistration* registration =
resolver.FindOp(tflite::BuiltinOperator_PRELU, 1);
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
size_t init_data_size = 0;
void* user_data = nullptr;
if (registration->init) {
user_data = registration->init(&context, nullptr, init_data_size);
}
int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
if (registration->free) {
registration->free(&context, user_data);
}
for (int i = 0; i < output_dims_count; ++i) {
TF_LITE_MICRO_EXPECT_NEAR(expected_output_data.begin()[i], output_data[i],
1e-5f);
}
}
void TestPreluQuantized(std::initializer_list<int> input_dims_data,
std::initializer_list<uint8_t> input_data,
float input_min, float input_max,
std::initializer_list<int> alpha_dims_data,
std::initializer_list<uint8_t> alpha_data,
float alpha_min, float alpha_max,
std::initializer_list<uint8_t> expected_output_data,
std::initializer_list<int> output_dims_data,
float output_min, float output_max,
uint8_t* output_data) {
TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
TfLiteIntArray* alpha_dims = IntArrayFromInitializer(alpha_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
const int output_dims_count = ElementCount(*output_dims);
constexpr int inputs_size = 2;
constexpr int outputs_size = 1;
constexpr int tensors_size = inputs_size + outputs_size;
TfLiteTensor tensors[tensors_size] = {
CreateQuantizedTensor(input_data, input_dims, "input_tensor", input_min,
input_max),
CreateQuantizedTensor(alpha_data, alpha_dims, "alpha_tensor", alpha_min,
alpha_max),
CreateQuantizedTensor(output_data, output_dims, "output_tensor",
output_min, output_max),
};
TfLiteContext context;
PopulateContext(tensors, tensors_size, &context);
::tflite::ops::micro::AllOpsResolver resolver;
const TfLiteRegistration* registration =
resolver.FindOp(tflite::BuiltinOperator_PRELU, 1);
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
size_t init_data_size = 0;
void* user_data = nullptr;
if (registration->init) {
user_data = registration->init(&context, nullptr, init_data_size);
}
int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
if (registration->free) {
registration->free(&context, user_data);
}
for (int i = 0; i < output_dims_count; ++i) {
TF_LITE_MICRO_EXPECT_EQ(expected_output_data.begin()[i], output_data[i]);
}
}
} // namespace
} // namespace testing
} // namespace tflite
TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(FloatPreluActivationsOpTest) {
const int output_dims_count = 12;
float output_data[output_dims_count];
tflite::testing::TestPreluFloat({1, 2, 2, 3}, // input shape
{
0.0f, 0.0f, 0.0f, // Row 1, Column 1
1.0f, 1.0f, 1.0f, // Row 1, Column 2
-1.0f, -1.0f, -1.0f, // Row 2, Column 1
-2.0f, -2.0f, -2.0f, // Row 1, Column 2
},
{1, 1, 3}, // alpha shape
{0.0f, 1.0f, 2.0f}, // alpha values
{
0.0f, 0.0f, 0.0f, // Row 1, Column 1
1.0f, 1.0f, 1.0f, // Row 1, Column 2
0.0f, -1.0f, -2.0f, // Row 2, Column 1
0.0f, -2.0f, -4.0f, // Row 1, Column 2
},
{1, 2, 2, 3}, // output shape
output_data);
}
TF_LITE_MICRO_TEST(QuantizedPreluActivationsOpTest) {
using tflite::testing::F2Q;
const float kMin = -1;
const float kMax = 127.f / 128.f;
const float kAlphaMin = -0.5f;
const float kAlphaMax = 0.5f;
const int output_dims_count = 12;
uint8_t output_data[output_dims_count];
tflite::testing::TestPreluQuantized(
{1, 2, 2, 3}, // input shape
{F2Q(0.0f, kMin, kMax), F2Q(0.0f, kMin, kMax), F2Q(0.0f, kMin, kMax),
F2Q(0.5f, kMin, kMax), F2Q(0.5f, kMin, kMax), F2Q(0.5f, kMin, kMax),
F2Q(-1.0f, kMin, kMax), F2Q(-1.0f, kMin, kMax), F2Q(-1.0f, kMin, kMax),
F2Q(-0.25f, kMin, kMax), F2Q(-0.25f, kMin, kMax),
F2Q(-0.25f, kMin, kMax)},
kMin, kMax, {1, 1, 3}, // alpha shape
{F2Q(0.0f, kMin, kMax), F2Q(0.5f, kMin, kMax), F2Q(-0.5f, kMin, kMax)},
kMin, kMax,
{F2Q(0.0f, kMin, kMax), F2Q(0.0f, kMin, kMax), F2Q(0.0f, kMin, kMax),
F2Q(0.5f, kMin, kMax), F2Q(0.5f, kMin, kMax), F2Q(0.5f, kMin, kMax),
F2Q(0.0f, kMin, kMax), F2Q(-0.5f, kMin, kMax), F2Q(0.5f, kMin, kMax),
F2Q(0.0f, kMin, kMax), F2Q(-0.125f, kMin, kMax),
F2Q(0.125f, kMin, kMax)},
{1, 2, 2, 3}, // output shape
kMin, kMax, output_data);
}
TF_LITE_MICRO_TESTS_END

View File

@ -112,6 +112,7 @@ tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h \
tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h \
tensorflow/lite/kernels/internal/reference/fully_connected.h \
tensorflow/lite/kernels/internal/reference/pooling.h \
tensorflow/lite/kernels/internal/reference/prelu.h \
tensorflow/lite/kernels/internal/reference/softmax.h \
tensorflow/lite/kernels/internal/round.h \
tensorflow/lite/kernels/internal/tensor_ctypes.h \

View File

@ -365,6 +365,7 @@ cc_library(
"reference/integer_ops/softmax.h",
"reference/integer_ops/tanh.h",
"reference/pooling.h",
"reference/prelu.h",
"reference/reference_ops.h",
"reference/softmax.h",
"reference/strided_slice.h",
@ -405,6 +406,7 @@ cc_library(
"reference/fully_connected.h",
"reference/legacy_reference_ops.h",
"reference/pooling.h",
"reference/prelu.h",
"reference/reference_ops.h",
"reference/softmax.h",
"reference/strided_slice.h",

View File

@ -0,0 +1,77 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PRELU_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PRELU_H_
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
namespace reference_ops {
// Broadcast prelu to output_shape for quantized uint8 data.
inline void BroadcastPrelu4DSlow(const PreluParams& params,
const RuntimeShape& input_shape,
const uint8* input_data,
const RuntimeShape& alpha_shape,
const uint8* alpha_data,
const RuntimeShape& output_shape,
uint8* output_data) {
TFLITE_DCHECK_LE(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(alpha_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(output_shape.DimensionsCount(), 4);
const RuntimeShape extended_output_shape =
RuntimeShape::ExtendedShape(4, output_shape);
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input_shape, alpha_shape, &desc1, &desc2);
for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
int output_index = Offset(extended_output_shape, b, y, x, c);
int input_index = SubscriptToIndex(desc1, b, y, x, c);
const int32 input_value =
params.input_offset + input_data[input_index];
if (input_value >= 0) {
output_data[output_index] = input_data[input_index];
} else {
auto alpha_index = SubscriptToIndex(desc2, b, y, x, c);
const int32 alpha_value =
params.alpha_offset + alpha_data[alpha_index];
const int32 unclamped_output =
params.output_offset +
MultiplyByQuantizedMultiplierSmallerThanOneExp(
input_value * alpha_value, params.output_multiplier,
params.output_shift);
const int32 quantized_min = std::numeric_limits<uint8_t>::min();
const int32 quantized_max = std::numeric_limits<uint8_t>::max();
const int32 clamped_output = std::min(
quantized_max, std::max(quantized_min, unclamped_output));
output_data[output_index] = static_cast<uint8>(clamped_output);
}
}
}
}
}
}
} // namespace reference_ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PRELU_H_

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/reference/conv.h"
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
#include "tensorflow/lite/kernels/internal/reference/pooling.h"
#include "tensorflow/lite/kernels/internal/reference/prelu.h"
#include "tensorflow/lite/kernels/internal/reference/softmax.h"
#include "tensorflow/lite/kernels/internal/reference/strided_slice.h"
#include "tensorflow/lite/kernels/internal/round.h"
@ -4403,53 +4404,6 @@ inline void ResizeNearestNeighbor(
}
}
inline void BroadcastPrelu4DSlow(const PreluParams& params,
const RuntimeShape& input_shape,
const uint8* input_data,
const RuntimeShape& alpha_shape,
const uint8* alpha_data,
const RuntimeShape& output_shape,
uint8* output_data) {
TFLITE_DCHECK_LE(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(alpha_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(output_shape.DimensionsCount(), 4);
const RuntimeShape extended_output_shape =
RuntimeShape::ExtendedShape(4, output_shape);
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input_shape, alpha_shape, &desc1, &desc2);
for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
int output_index = Offset(extended_output_shape, b, y, x, c);
int input_index = SubscriptToIndex(desc1, b, y, x, c);
const int32 input_value =
params.input_offset + input_data[input_index];
if (input_value >= 0) {
output_data[output_index] = input_data[input_index];
} else {
auto alpha_index = SubscriptToIndex(desc2, b, y, x, c);
const int32 alpha_value =
params.alpha_offset + alpha_data[alpha_index];
const int32 unclamped_output =
params.output_offset +
MultiplyByQuantizedMultiplierSmallerThanOneExp(
input_value * alpha_value, params.output_multiplier,
params.output_shift);
const int32 quantized_min = std::numeric_limits<uint8_t>::min();
const int32 quantized_max = std::numeric_limits<uint8_t>::max();
const int32 clamped_output = std::min(
quantized_max, std::max(quantized_min, unclamped_output));
output_data[output_index] = static_cast<uint8>(clamped_output);
}
}
}
}
}
}
template <typename T>
void Fill(const RuntimeShape& value_shape, const T* value_data,
const RuntimeShape& output_shape, T* output_data) {