Add prelu op for micro
PiperOrigin-RevId: 259473219
parent 1de23834be
commit 8281648f9c
@@ -19,6 +19,7 @@ cc_library(
        "elementwise.cc",
        "fully_connected.cc",
        "pooling.cc",
        "prelu.cc",
        "softmax.cc",
    ],
    hdrs = [
@@ -59,6 +60,7 @@ cc_library(
        "fully_connected.cc",
        "pooling.cc",
        "portable_optimized/depthwise_conv.cc",
        "prelu.cc",
        "softmax.cc",
    ],
    hdrs = [
@@ -179,3 +181,16 @@ tflite_micro_cc_test(
        "//tensorflow/lite/experimental/micro/testing:micro_test",
    ],
)

tflite_micro_cc_test(
    name = "prelu_test",
    srcs = [
        "prelu_test.cc",
    ],
    deps = [
        ":all_ops_resolver",
        "//tensorflow/lite/c:c_api_internal",
        "//tensorflow/lite/experimental/micro:micro_framework",
        "//tensorflow/lite/experimental/micro/testing:micro_test",
    ],
)
@@ -23,6 +23,7 @@ TfLiteRegistration* Register_CONV_2D();
TfLiteRegistration* Register_AVERAGE_POOL_2D();
TfLiteRegistration* Register_MAX_POOL_2D();
TfLiteRegistration* Register_ABS();
TfLiteRegistration* Register_PRELU();

AllOpsResolver::AllOpsResolver() {
  AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D());
@@ -34,6 +35,7 @@ AllOpsResolver::AllOpsResolver() {
  AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D());
  AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, Register_AVERAGE_POOL_2D());
  AddBuiltin(BuiltinOperator_ABS, Register_ABS());
  AddBuiltin(BuiltinOperator_PRELU, Register_PRELU());
}

}  // namespace micro
tensorflow/lite/experimental/micro/kernels/prelu.cc (new file, 114 lines)
@@ -0,0 +1,114 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/reference/prelu.h"

#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace micro {
namespace activations {

TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
  return kTfLiteOk;
}

inline void BroadcastPrelu4DSlowFloat(
    const RuntimeShape& unextended_input1_shape, const float* input1_data,
    const RuntimeShape& unextended_input2_shape, const float* input2_data,
    const RuntimeShape& unextended_output_shape, float* output_data) {
  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
  const RuntimeShape output_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_shape);

  NdArrayDesc<4> desc1;
  NdArrayDesc<4> desc2;
  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
                                      unextended_input2_shape, &desc1, &desc2);

  for (int b = 0; b < output_shape.Dims(0); ++b) {
    for (int y = 0; y < output_shape.Dims(1); ++y) {
      for (int x = 0; x < output_shape.Dims(2); ++x) {
        for (int c = 0; c < output_shape.Dims(3); ++c) {
          auto out_idx = Offset(output_shape, b, y, x, c);
          auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
          auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
          auto in1_val = input1_data[in1_idx];
          auto in2_val = input2_data[in2_idx];
          output_data[out_idx] = in1_val >= 0.0 ? in1_val : in1_val * in2_val;
        }
      }
    }
  }
}

TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input = GetInput(context, node, 0);
  const TfLiteTensor* alpha = GetInput(context, node, 1);
  TfLiteTensor* output = GetOutput(context, node, 0);
  int32_t output_multiplier = 0;
  int output_shift = 0;
  if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) {
    double real_multiplier =
        input->params.scale * alpha->params.scale / output->params.scale;
    QuantizeMultiplierSmallerThanOneExp(real_multiplier, &output_multiplier,
                                        &output_shift);
  }
  switch (input->type) {
    case kTfLiteFloat32: {
      BroadcastPrelu4DSlowFloat(
          GetTensorShape(input), GetTensorData<float>(input),
          GetTensorShape(alpha), GetTensorData<float>(alpha),
          GetTensorShape(output), GetTensorData<float>(output));
      return kTfLiteOk;
    } break;
    case kTfLiteUInt8: {
      PreluParams op_params;
      op_params.input_offset = -input->params.zero_point;
      op_params.alpha_offset = -alpha->params.zero_point;
      op_params.output_offset = output->params.zero_point;
      op_params.output_multiplier = output_multiplier;
      op_params.output_shift = output_shift;
      reference_ops::BroadcastPrelu4DSlow(
          op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
          GetTensorShape(alpha), GetTensorData<uint8_t>(alpha),
          GetTensorShape(output), GetTensorData<uint8_t>(output));
      return kTfLiteOk;
    } break;
    default:
      context->ReportError(
          context, "Only float32 and uint8 are supported currently, got %s.",
          TfLiteTypeGetName(input->type));
      return kTfLiteError;
  }
}

}  // namespace activations

TfLiteRegistration* Register_PRELU() {
  static TfLiteRegistration r = {nullptr, nullptr, activations::PreluPrepare,
                                 activations::PreluEval};
  return &r;
}

}  // namespace micro
}  // namespace ops
}  // namespace tflite
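The float path above reduces to the PReLU rule out = x for x >= 0 and out = alpha * x otherwise, with alpha broadcast against the input shape. The standalone sketch below mirrors that behaviour for the common case of a per-channel alpha over a flat NHWC buffer; it is an illustrative example only, not TFLite Micro code, and the ApplyPreluPerChannel helper and its flat-array layout are assumptions made for the demo.

```cpp
#include <cstdio>

// Minimal sketch: PReLU with a per-channel alpha, matching the
// "x >= 0 ? x : alpha * x" rule used by BroadcastPrelu4DSlowFloat above.
// The flat NHWC layout and the helper name are illustrative assumptions.
void ApplyPreluPerChannel(const float* input, int num_elements,
                          const float* alpha, int channels, float* output) {
  for (int i = 0; i < num_elements; ++i) {
    const float x = input[i];
    const float a = alpha[i % channels];  // broadcast alpha over the last dim
    output[i] = x >= 0.0f ? x : x * a;
  }
}

int main() {
  // Same values as the float unit test: a 2x2x3 input and a 3-element alpha.
  const float input[12] = {0.0f,  0.0f,  0.0f,  1.0f,  1.0f,  1.0f,
                           -1.0f, -1.0f, -1.0f, -2.0f, -2.0f, -2.0f};
  const float alpha[3] = {0.0f, 1.0f, 2.0f};
  float output[12];
  ApplyPreluPerChannel(input, 12, alpha, 3, output);
  for (int i = 0; i < 12; ++i) {
    // Expected (up to the sign of zero): 0 0 0 1 1 1 0 -1 -2 0 -2 -4
    std::printf("%g ", output[i]);
  }
  std::printf("\n");
  return 0;
}
```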
tensorflow/lite/experimental/micro/kernels/prelu_test.cc (new file, 204 lines)
@@ -0,0 +1,204 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
#include "tensorflow/lite/experimental/micro/simple_tensor_allocator.h"
#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
#include "tensorflow/lite/experimental/micro/testing/test_utils.h"

namespace tflite {
namespace testing {
namespace {

void TestPreluFloat(std::initializer_list<int> input_dims_data,
                    std::initializer_list<float> input_data,
                    std::initializer_list<int> alpha_dims_data,
                    std::initializer_list<float> alpha_data,
                    std::initializer_list<float> expected_output_data,
                    std::initializer_list<int> output_dims_data,
                    float* output_data) {
  TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
  TfLiteIntArray* alpha_dims = IntArrayFromInitializer(alpha_dims_data);
  TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
  const int output_dims_count = ElementCount(*output_dims);
  constexpr int inputs_size = 2;
  constexpr int outputs_size = 1;
  constexpr int tensors_size = inputs_size + outputs_size;
  TfLiteTensor tensors[tensors_size] = {
      CreateFloatTensor(input_data, input_dims, "input_tensor"),
      CreateFloatTensor(alpha_data, alpha_dims, "alpha_tensor"),
      CreateFloatTensor(output_data, output_dims, "output_tensor"),
  };
  TfLiteContext context;
  PopulateContext(tensors, tensors_size, &context);
  ::tflite::ops::micro::AllOpsResolver resolver;
  const TfLiteRegistration* registration =
      resolver.FindOp(tflite::BuiltinOperator_PRELU, 1);
  TF_LITE_MICRO_EXPECT_NE(nullptr, registration);

  size_t init_data_size = 0;
  void* user_data = nullptr;
  if (registration->init) {
    user_data = registration->init(&context, nullptr, init_data_size);
  }
  int inputs_array_data[] = {2, 0, 1};
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 2};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
  TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = temporaries_array;
  node.user_data = user_data;
  node.builtin_data = nullptr;
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;
  if (registration->prepare) {
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
  }
  TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
  if (registration->free) {
    registration->free(&context, user_data);
  }
  for (int i = 0; i < output_dims_count; ++i) {
    TF_LITE_MICRO_EXPECT_NEAR(expected_output_data.begin()[i], output_data[i],
                              1e-5f);
  }
}

void TestPreluQuantized(std::initializer_list<int> input_dims_data,
                        std::initializer_list<uint8_t> input_data,
                        float input_min, float input_max,
                        std::initializer_list<int> alpha_dims_data,
                        std::initializer_list<uint8_t> alpha_data,
                        float alpha_min, float alpha_max,
                        std::initializer_list<uint8_t> expected_output_data,
                        std::initializer_list<int> output_dims_data,
                        float output_min, float output_max,
                        uint8_t* output_data) {
  TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
  TfLiteIntArray* alpha_dims = IntArrayFromInitializer(alpha_dims_data);
  TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
  const int output_dims_count = ElementCount(*output_dims);
  constexpr int inputs_size = 2;
  constexpr int outputs_size = 1;
  constexpr int tensors_size = inputs_size + outputs_size;
  TfLiteTensor tensors[tensors_size] = {
      CreateQuantizedTensor(input_data, input_dims, "input_tensor", input_min,
                            input_max),
      CreateQuantizedTensor(alpha_data, alpha_dims, "alpha_tensor", alpha_min,
                            alpha_max),
      CreateQuantizedTensor(output_data, output_dims, "output_tensor",
                            output_min, output_max),
  };
  TfLiteContext context;
  PopulateContext(tensors, tensors_size, &context);
  ::tflite::ops::micro::AllOpsResolver resolver;
  const TfLiteRegistration* registration =
      resolver.FindOp(tflite::BuiltinOperator_PRELU, 1);
  TF_LITE_MICRO_EXPECT_NE(nullptr, registration);

  size_t init_data_size = 0;
  void* user_data = nullptr;
  if (registration->init) {
    user_data = registration->init(&context, nullptr, init_data_size);
  }
  int inputs_array_data[] = {2, 0, 1};
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 2};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
  TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = temporaries_array;
  node.user_data = user_data;
  node.builtin_data = nullptr;
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;
  if (registration->prepare) {
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
  }
  TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
  if (registration->free) {
    registration->free(&context, user_data);
  }
  for (int i = 0; i < output_dims_count; ++i) {
    TF_LITE_MICRO_EXPECT_EQ(expected_output_data.begin()[i], output_data[i]);
  }
}
}  // namespace
}  // namespace testing
}  // namespace tflite

TF_LITE_MICRO_TESTS_BEGIN

TF_LITE_MICRO_TEST(FloatPreluActivationsOpTest) {
  const int output_dims_count = 12;
  float output_data[output_dims_count];
  tflite::testing::TestPreluFloat({1, 2, 2, 3},  // input shape
                                  {
                                      0.0f, 0.0f, 0.0f,     // Row 1, Column 1
                                      1.0f, 1.0f, 1.0f,     // Row 1, Column 2
                                      -1.0f, -1.0f, -1.0f,  // Row 2, Column 1
                                      -2.0f, -2.0f, -2.0f,  // Row 2, Column 2
                                  },
                                  {1, 1, 3},           // alpha shape
                                  {0.0f, 1.0f, 2.0f},  // alpha values
                                  {
                                      0.0f, 0.0f, 0.0f,    // Row 1, Column 1
                                      1.0f, 1.0f, 1.0f,    // Row 1, Column 2
                                      0.0f, -1.0f, -2.0f,  // Row 2, Column 1
                                      0.0f, -2.0f, -4.0f,  // Row 2, Column 2
                                  },
                                  {1, 2, 2, 3},  // output shape
                                  output_data);
}

TF_LITE_MICRO_TEST(QuantizedPreluActivationsOpTest) {
  using tflite::testing::F2Q;
  const float kMin = -1;
  const float kMax = 127.f / 128.f;
  const float kAlphaMin = -0.5f;
  const float kAlphaMax = 0.5f;
  const int output_dims_count = 12;
  uint8_t output_data[output_dims_count];
  tflite::testing::TestPreluQuantized(
      {1, 2, 2, 3},  // input shape
      {F2Q(0.0f, kMin, kMax), F2Q(0.0f, kMin, kMax), F2Q(0.0f, kMin, kMax),
       F2Q(0.5f, kMin, kMax), F2Q(0.5f, kMin, kMax), F2Q(0.5f, kMin, kMax),
       F2Q(-1.0f, kMin, kMax), F2Q(-1.0f, kMin, kMax), F2Q(-1.0f, kMin, kMax),
       F2Q(-0.25f, kMin, kMax), F2Q(-0.25f, kMin, kMax),
       F2Q(-0.25f, kMin, kMax)},
      kMin, kMax, {1, 1, 3},  // alpha shape
      {F2Q(0.0f, kMin, kMax), F2Q(0.5f, kMin, kMax), F2Q(-0.5f, kMin, kMax)},
      kMin, kMax,
      {F2Q(0.0f, kMin, kMax), F2Q(0.0f, kMin, kMax), F2Q(0.0f, kMin, kMax),
       F2Q(0.5f, kMin, kMax), F2Q(0.5f, kMin, kMax), F2Q(0.5f, kMin, kMax),
       F2Q(0.0f, kMin, kMax), F2Q(-0.5f, kMin, kMax), F2Q(0.5f, kMin, kMax),
       F2Q(0.0f, kMin, kMax), F2Q(-0.125f, kMin, kMax),
       F2Q(0.125f, kMin, kMax)},
      {1, 2, 2, 3},  // output shape
      kMin, kMax, output_data);
}

TF_LITE_MICRO_TESTS_END
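A quick way to sanity-check the quantized expectations above by hand, assuming the F2Q test helper performs the usual affine quantization q = round(v / S) + Z with S = (max − min) / 255 and Z chosen so that min maps to 0 (an assumption about test_utils, not something shown in this diff):

```latex
S = \frac{k_{\max} - k_{\min}}{255} = \frac{127/128 - (-1)}{255} = \frac{1}{128},
\qquad
Z = \operatorname{round}\!\left(-\frac{k_{\min}}{S}\right) = 128
```

Under that assumption, the input −1.0 paired with alpha 0.5 dequantizes to PReLU(−1.0) = −0.5 and re-quantizes to round(−0.5 · 128) + 128 = 64, which is what the expected entry F2Q(-0.5f, kMin, kMax) encodes.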
@@ -112,6 +112,7 @@ tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h \
tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h \
tensorflow/lite/kernels/internal/reference/fully_connected.h \
tensorflow/lite/kernels/internal/reference/pooling.h \
tensorflow/lite/kernels/internal/reference/prelu.h \
tensorflow/lite/kernels/internal/reference/softmax.h \
tensorflow/lite/kernels/internal/round.h \
tensorflow/lite/kernels/internal/tensor_ctypes.h \
@@ -365,6 +365,7 @@ cc_library(
        "reference/integer_ops/softmax.h",
        "reference/integer_ops/tanh.h",
        "reference/pooling.h",
        "reference/prelu.h",
        "reference/reference_ops.h",
        "reference/softmax.h",
        "reference/strided_slice.h",
@@ -405,6 +406,7 @@ cc_library(
        "reference/fully_connected.h",
        "reference/legacy_reference_ops.h",
        "reference/pooling.h",
        "reference/prelu.h",
        "reference/reference_ops.h",
        "reference/softmax.h",
        "reference/strided_slice.h",
tensorflow/lite/kernels/internal/reference/prelu.h (new file, 77 lines)
@@ -0,0 +1,77 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PRELU_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PRELU_H_

#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {

namespace reference_ops {

// Broadcast prelu to output_shape for quantized uint8 data.
inline void BroadcastPrelu4DSlow(const PreluParams& params,
                                 const RuntimeShape& input_shape,
                                 const uint8* input_data,
                                 const RuntimeShape& alpha_shape,
                                 const uint8* alpha_data,
                                 const RuntimeShape& output_shape,
                                 uint8* output_data) {
  TFLITE_DCHECK_LE(input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(alpha_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(output_shape.DimensionsCount(), 4);
  const RuntimeShape extended_output_shape =
      RuntimeShape::ExtendedShape(4, output_shape);
  NdArrayDesc<4> desc1;
  NdArrayDesc<4> desc2;
  NdArrayDescsForElementwiseBroadcast(input_shape, alpha_shape, &desc1, &desc2);

  for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
    for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
      for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
        for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
          int output_index = Offset(extended_output_shape, b, y, x, c);
          int input_index = SubscriptToIndex(desc1, b, y, x, c);
          const int32 input_value =
              params.input_offset + input_data[input_index];
          if (input_value >= 0) {
            output_data[output_index] = input_data[input_index];
          } else {
            auto alpha_index = SubscriptToIndex(desc2, b, y, x, c);
            const int32 alpha_value =
                params.alpha_offset + alpha_data[alpha_index];
            const int32 unclamped_output =
                params.output_offset +
                MultiplyByQuantizedMultiplierSmallerThanOneExp(
                    input_value * alpha_value, params.output_multiplier,
                    params.output_shift);
            const int32 quantized_min = std::numeric_limits<uint8_t>::min();
            const int32 quantized_max = std::numeric_limits<uint8_t>::max();
            const int32 clamped_output = std::min(
                quantized_max, std::max(quantized_min, unclamped_output));
            output_data[output_index] = static_cast<uint8>(clamped_output);
          }
        }
      }
    }
  }
}

}  // namespace reference_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PRELU_H_
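In symbols (notation mine, not from the header): with each tensor quantized as v ≈ S · (q − Z), the kernel above copies non-negative inputs straight through, which implicitly assumes the input and output share the same scale and zero point, and for negative inputs computes, up to the fixed-point rounding of MultiplyByQuantizedMultiplierSmallerThanOneExp,

```latex
M = \frac{S_x \, S_\alpha}{S_y},
\qquad
y_q \approx \operatorname{clamp}\!\Big(Z_y + M\,(x_q - Z_x)\,(\alpha_q - Z_\alpha),\; 0,\; 255\Big)
```

Here M is the real_multiplier that PreluEval in prelu.cc prepares with QuantizeMultiplierSmallerThanOneExp and passes in through PreluParams.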
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/reference/conv.h"
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
#include "tensorflow/lite/kernels/internal/reference/pooling.h"
#include "tensorflow/lite/kernels/internal/reference/prelu.h"
#include "tensorflow/lite/kernels/internal/reference/softmax.h"
#include "tensorflow/lite/kernels/internal/reference/strided_slice.h"
#include "tensorflow/lite/kernels/internal/round.h"
@@ -4403,53 +4404,6 @@ inline void ResizeNearestNeighbor(
  }
}

inline void BroadcastPrelu4DSlow(const PreluParams& params,
                                 const RuntimeShape& input_shape,
                                 const uint8* input_data,
                                 const RuntimeShape& alpha_shape,
                                 const uint8* alpha_data,
                                 const RuntimeShape& output_shape,
                                 uint8* output_data) {
  TFLITE_DCHECK_LE(input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(alpha_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(output_shape.DimensionsCount(), 4);
  const RuntimeShape extended_output_shape =
      RuntimeShape::ExtendedShape(4, output_shape);
  NdArrayDesc<4> desc1;
  NdArrayDesc<4> desc2;
  NdArrayDescsForElementwiseBroadcast(input_shape, alpha_shape, &desc1, &desc2);

  for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
    for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
      for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
        for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
          int output_index = Offset(extended_output_shape, b, y, x, c);
          int input_index = SubscriptToIndex(desc1, b, y, x, c);
          const int32 input_value =
              params.input_offset + input_data[input_index];
          if (input_value >= 0) {
            output_data[output_index] = input_data[input_index];
          } else {
            auto alpha_index = SubscriptToIndex(desc2, b, y, x, c);
            const int32 alpha_value =
                params.alpha_offset + alpha_data[alpha_index];
            const int32 unclamped_output =
                params.output_offset +
                MultiplyByQuantizedMultiplierSmallerThanOneExp(
                    input_value * alpha_value, params.output_multiplier,
                    params.output_shift);
            const int32 quantized_min = std::numeric_limits<uint8_t>::min();
            const int32 quantized_max = std::numeric_limits<uint8_t>::max();
            const int32 clamped_output = std::min(
                quantized_max, std::max(quantized_min, unclamped_output));
            output_data[output_index] = static_cast<uint8>(clamped_output);
          }
        }
      }
    }
  }
}

template <typename T>
void Fill(const RuntimeShape& value_shape, const T* value_data,
          const RuntimeShape& output_shape, T* output_data) {