Added RuntimeShape check for tanh kernel
parent 4e20264533
commit 72a30f0017
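The int8 and int16 reference Tanh kernels previously took a flat element count; they now take the input and output RuntimeShape and derive the loop bound with MatchingFlatSize(input_shape, output_shape), so the two tensors' shapes are checked against each other inside the kernel. The TanhEval call sites (including the micro kernel, which passed NumElements(input->dims)) are updated to pass both shapes.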
@@ -298,7 +298,6 @@ void HardSwishFree(TfLiteContext* context, void* buffer) {
   delete static_cast<HardSwishData*>(buffer);
 }
 
-
 TfLiteStatus HardSwishPrepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_STATUS(GenericPrepare(context, node));
   TfLiteTensor* output = GetOutput(context, node, 0);
@@ -865,12 +864,10 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
       TanhParams params;
       params.input_left_shift = data->input_left_shift;
       if (kernel_type == kReference || (data->input_multiplier > 0)) {
-        const int size =
-            MatchingFlatSize(GetTensorShape(input), GetTensorShape(output));
-
         reference_integer_ops::Tanh(
-            data->input_multiplier, data->input_left_shift, size,
-            GetTensorData<int16_t>(input), GetTensorData<int16_t>(output));
+            data->input_multiplier, data->input_left_shift,
+            GetTensorShape(input), GetTensorData<int16_t>(input),
+            GetTensorShape(output), GetTensorData<int16_t>(output));
       } else {
         optimized_ops::Tanh(
             params, GetTensorShape(input), GetTensorData<int16_t>(input),
@@ -25,8 +25,8 @@ namespace reference_integer_ops {
 
 inline void Tanh(int32_t input_zero_point, int32_t input_range_radius,
                  int32_t input_multiplier, int32_t input_shift,
-                 int32_t input_size, const int8_t* input_data,
-                 int8_t* output_data) {
+                 const RuntimeShape& input_shape, const int8_t* input_data,
+                 const RuntimeShape& output_shape, int8_t* output_data) {
   // Integer bits must be in sync with Prepare() function.
   static constexpr int32_t kInputIntegerBits = 4;
   static constexpr int32_t kOutputScale = 7;
@@ -34,7 +34,9 @@ inline void Tanh(int32_t input_zero_point, int32_t input_range_radius,
   static constexpr int32_t kMaxInt8 = std::numeric_limits<int8_t>::max();
   using F4 = gemmlowp::FixedPoint<int32_t, kInputIntegerBits>;
 
-  for (int i = 0; i < input_size; ++i) {
+  const int flat_size = MatchingFlatSize(input_shape, output_shape);
+
+  for (int i = 0; i < flat_size; ++i) {
     const int32_t input =
         static_cast<int32_t>(input_data[i]) - input_zero_point;
     if (input <= -input_range_radius) {
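For readers unfamiliar with the helper: MatchingFlatSize returns the flat element count of shapes that are asserted to match. Below is a minimal sketch of that contract, assuming the real implementation in tensorflow/lite/kernels/internal/types.h relies on TFLITE_DCHECK-style checks; the function name is made up for illustration and is not the actual helper.

#include "tensorflow/lite/kernels/internal/compatibility.h"  // TFLITE_DCHECK_EQ
#include "tensorflow/lite/kernels/internal/types.h"          // RuntimeShape

// Sketch only: mirrors the contract the kernels above rely on.
inline int MatchingFlatSizeSketch(const tflite::RuntimeShape& a,
                                  const tflite::RuntimeShape& b) {
  TFLITE_DCHECK_EQ(a.DimensionsCount(), b.DimensionsCount());
  int flat_size = 1;
  for (int i = 0; i < a.DimensionsCount(); ++i) {
    TFLITE_DCHECK_EQ(a.Dims(i), b.Dims(i));  // every dimension must agree
    flat_size *= a.Dims(i);
  }
  return flat_size;
}

Passing both shapes is a strictly stronger interface than passing a precomputed size: the consistency check travels with the kernel instead of being left to each caller.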
@@ -58,14 +60,16 @@ inline void Tanh(int32_t input_zero_point, int32_t input_range_radius,
 }
 
 inline void Tanh(int32_t input_multiplier, int32_t input_left_shift,
-                 int32_t input_size, const int16_t* ptr_input_data,
-                 int16_t* ptr_output_data) {
+                 const RuntimeShape& input_shape, const int16_t* ptr_input_data,
+                 const RuntimeShape& output_shape, int16_t* ptr_output_data) {
   // We use the LUT for sigmoid and take into account, that
   // tanh(x) = 2*sigmoid(2*x) - 1
 
   int32_t input_data_mul = (input_multiplier > 0) ? input_multiplier : 1;
 
-  for (int i = 0; i < input_size; ++i, ptr_input_data++, ptr_output_data++) {
+  int flat_size = MatchingFlatSize(input_shape, output_shape);
+
+  for (int i = 0; i < flat_size; ++i, ptr_input_data++, ptr_output_data++) {
     int32_t input_data = (*ptr_input_data) * input_data_mul;
 
     if (input_left_shift == 1) {
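A hypothetical standalone call of the new int16 overload could look as follows; the buffers and quantization parameters are invented for illustration and are not part of this commit.

#include "tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h"

void TanhInt16Example() {
  const int16_t input[4] = {0, 8192, -8192, 32767};
  int16_t output[4];
  const tflite::RuntimeShape shape({1, 4});  // shared by input and output
  // input_multiplier == 0 falls back to a multiplier of 1 (input_data_mul
  // above); input_left_shift == 1 takes the branch shown in the hunk.
  tflite::reference_integer_ops::Tanh(/*input_multiplier=*/0,
                                      /*input_left_shift=*/1, shape, input,
                                      shape, output);
}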
@@ -124,8 +124,9 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
     case kTfLiteInt8: {
       reference_integer_ops::Tanh(
           data.input_zero_point, data.input_range_radius, data.input_multiplier,
-          data.input_left_shift, NumElements(input->dims),
+          data.input_left_shift, tflite::micro::GetTensorShape(input),
           tflite::micro::GetTensorData<int8_t>(input),
+          tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<int8_t>(output));
       return kTfLiteOk;
     } break;
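Net effect for the micro kernel: the old int8 path iterated NumElements(input->dims) and never consulted the output tensor, so a mis-sized output buffer went unnoticed. With both shapes passed down, a mismatch now fails inside MatchingFlatSize (in debug builds, assuming the DCHECK-based contract sketched earlier) before any data is written.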