added RuntimeShape check for tanh kernel
parent 4e20264533
commit 72a30f0017
@@ -298,7 +298,6 @@ void HardSwishFree(TfLiteContext* context, void* buffer) {
   delete static_cast<HardSwishData*>(buffer);
 }
 
-
 TfLiteStatus HardSwishPrepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_STATUS(GenericPrepare(context, node));
   TfLiteTensor* output = GetOutput(context, node, 0);
@@ -865,12 +864,10 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
       TanhParams params;
       params.input_left_shift = data->input_left_shift;
       if (kernel_type == kReference || (data->input_multiplier > 0)) {
-        const int size =
-            MatchingFlatSize(GetTensorShape(input), GetTensorShape(output));
-
         reference_integer_ops::Tanh(
-            data->input_multiplier, data->input_left_shift, size,
-            GetTensorData<int16_t>(input), GetTensorData<int16_t>(output));
+            data->input_multiplier, data->input_left_shift,
+            GetTensorShape(input), GetTensorData<int16_t>(input),
+            GetTensorShape(output), GetTensorData<int16_t>(output));
       } else {
         optimized_ops::Tanh(
             params, GetTensorShape(input), GetTensorData<int16_t>(input),
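The old int16 path computed `size` at the call site and handed the kernel a bare element count, so the kernel itself never saw the output shape. Moving the MatchingFlatSize call inside the kernel validates both shapes where the data is actually written. A minimal sketch of what a MatchingFlatSize-style helper does (simplified stand-in types, not the TFLite implementation):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Simplified stand-in for tflite::RuntimeShape, for illustration only.
    struct Shape {
      std::vector<int32_t> dims;
      int64_t FlatSize() const {
        int64_t n = 1;
        for (int32_t d : dims) n *= d;
        return n;
      }
    };

    // Returns the flat element count shared by both tensors, asserting
    // that the shapes agree instead of trusting a caller-supplied size.
    inline int64_t MatchingFlatSizeSketch(const Shape& a, const Shape& b) {
      assert(a.dims == b.dims && "input/output shapes must match");
      return a.FlatSize();
    }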
@@ -25,8 +25,8 @@ namespace reference_integer_ops {
 
 inline void Tanh(int32_t input_zero_point, int32_t input_range_radius,
                  int32_t input_multiplier, int32_t input_shift,
-                 int32_t input_size, const int8_t* input_data,
-                 int8_t* output_data) {
+                 const RuntimeShape& input_shape, const int8_t* input_data,
+                 const RuntimeShape& output_shape, int8_t* output_data) {
   // Integer bits must be in sync with Prepare() function.
   static constexpr int32_t kInputIntegerBits = 4;
   static constexpr int32_t kOutputScale = 7;
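With the new signature, a direct call to the reference int8 kernel looks like the sketch below. The quantization parameters are illustrative placeholders; in practice they are computed in the op's Prepare() step from the tensors' scales and zero points (include paths assumed as in the TensorFlow tree at the time of this commit):

    #include "tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h"
    #include "tensorflow/lite/kernels/internal/types.h"

    void TanhInt8Example() {
      const tflite::RuntimeShape shape({1, 8});  // same shape for input/output
      int8_t input[8] = {-128, -64, -32, -8, 0, 8, 64, 127};
      int8_t output[8] = {};
      // Placeholder quantization parameters; real values come from Prepare().
      tflite::reference_integer_ops::Tanh(
          /*input_zero_point=*/0, /*input_range_radius=*/64,
          /*input_multiplier=*/1 << 30, /*input_shift=*/1,
          shape, input, shape, output);
    }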
@@ -34,7 +34,9 @@ inline void Tanh(int32_t input_zero_point, int32_t input_range_radius,
   static constexpr int32_t kMaxInt8 = std::numeric_limits<int8_t>::max();
   using F4 = gemmlowp::FixedPoint<int32_t, kInputIntegerBits>;
 
-  for (int i = 0; i < input_size; ++i) {
+  const int flat_size = MatchingFlatSize(input_shape, output_shape);
+
+  for (int i = 0; i < flat_size; ++i) {
     const int32_t input =
         static_cast<int32_t>(input_data[i]) - input_zero_point;
     if (input <= -input_range_radius) {
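The fixed-point loop above approximates the following per-element computation: center by the input zero point, scale back to real values, apply tanh, then requantize with output scale 1/2^kOutputScale = 1/128 and zero point 0. A plain-float restatement for intuition (illustrative helper, not TFLite code):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Float model of the int8 kernel's per-element math.
    inline int8_t TanhInt8Model(int8_t q, int32_t zero_point,
                                float input_scale) {
      const float x = input_scale * (static_cast<float>(q) - zero_point);
      const float y = std::tanh(x);                      // in [-1, 1]
      const int32_t out =
          static_cast<int32_t>(std::round(y * 128.0f));  // Q0.7 output
      return static_cast<int8_t>(
          std::min<int32_t>(127, std::max<int32_t>(-128, out)));
    }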
@@ -58,14 +60,16 @@ inline void Tanh(int32_t input_zero_point, int32_t input_range_radius,
 }
 
 inline void Tanh(int32_t input_multiplier, int32_t input_left_shift,
-                 int32_t input_size, const int16_t* ptr_input_data,
-                 int16_t* ptr_output_data) {
+                 const RuntimeShape& input_shape, const int16_t* ptr_input_data,
+                 const RuntimeShape& output_shape, int16_t* ptr_output_data) {
   // We use the LUT for sigmoid and take into account, that
   // tanh(x) = 2*sigmoid(2*x) - 1
 
   int32_t input_data_mul = (input_multiplier > 0) ? input_multiplier : 1;
 
-  for (int i = 0; i < input_size; ++i, ptr_input_data++, ptr_output_data++) {
+  int flat_size = MatchingFlatSize(input_shape, output_shape);
+
+  for (int i = 0; i < flat_size; ++i, ptr_input_data++, ptr_output_data++) {
     int32_t input_data = (*ptr_input_data) * input_data_mul;
 
     if (input_left_shift == 1) {
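The comment in the int16 kernel relies on the identity tanh(x) = 2*sigmoid(2x) - 1, which lets a single sigmoid lookup table serve both ops. A quick self-contained float check of that identity:

    #include <cassert>
    #include <cmath>

    int main() {
      for (float x = -4.0f; x <= 4.0f; x += 0.25f) {
        const float sigmoid_2x = 1.0f / (1.0f + std::exp(-2.0f * x));
        assert(std::fabs(std::tanh(x) - (2.0f * sigmoid_2x - 1.0f)) < 1e-6f);
      }
      return 0;
    }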
@@ -124,8 +124,9 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
     case kTfLiteInt8: {
       reference_integer_ops::Tanh(
           data.input_zero_point, data.input_range_radius, data.input_multiplier,
-          data.input_left_shift, NumElements(input->dims),
+          data.input_left_shift, tflite::micro::GetTensorShape(input),
           tflite::micro::GetTensorData<int8_t>(input),
+          tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<int8_t>(output));
       return kTfLiteOk;
     } break;
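In the micro kernel the loop bound previously came from NumElements(input->dims) alone, so an output tensor of a different size was never checked before being written. A minimal sketch of the difference (hypothetical helper names, not TFLite code):

    #include <cassert>
    #include <cstdint>

    // Old style: the bound trusts the input tensor alone; a smaller
    // output buffer would be overrun without any diagnostic.
    inline int32_t FlatSizeFromInputOnly(int32_t input_elems) {
      return input_elems;
    }

    // New style: both sizes participate, mirroring MatchingFlatSize().
    inline int32_t FlatSizeChecked(int32_t input_elems,
                                   int32_t output_elems) {
      assert(input_elems == output_elems &&
             "tanh: tensor sizes must match");
      return input_elems;
    }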