diff --git a/tensorflow/lite/kernels/activations.cc b/tensorflow/lite/kernels/activations.cc index 654ccbc27ec..b8f9a97c551 100644 --- a/tensorflow/lite/kernels/activations.cc +++ b/tensorflow/lite/kernels/activations.cc @@ -298,7 +298,6 @@ void HardSwishFree(TfLiteContext* context, void* buffer) { delete static_cast<HardSwishData*>(buffer); } - TfLiteStatus HardSwishPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_STATUS(GenericPrepare(context, node)); TfLiteTensor* output = GetOutput(context, node, 0); @@ -865,12 +864,10 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) { TanhParams params; params.input_left_shift = data->input_left_shift; if (kernel_type == kReference || (data->input_multiplier > 0)) { - const int size = - MatchingFlatSize(GetTensorShape(input), GetTensorShape(output)); - reference_integer_ops::Tanh( - data->input_multiplier, data->input_left_shift, size, - GetTensorData<int16_t>(input), GetTensorData<int16_t>(output)); + data->input_multiplier, data->input_left_shift, + GetTensorShape(input), GetTensorData<int16_t>(input), + GetTensorShape(output), GetTensorData<int16_t>(output)); } else { optimized_ops::Tanh( params, GetTensorShape(input), GetTensorData<int16_t>(input), diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h b/tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h index baae65ab30e..81ff34fef63 100644 --- a/tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h @@ -25,8 +25,8 @@ namespace reference_integer_ops { inline void Tanh(int32_t input_zero_point, int32_t input_range_radius, int32_t input_multiplier, int32_t input_shift, - int32_t input_size, const int8_t* input_data, - int8_t* output_data) { + const RuntimeShape& input_shape, const int8_t* input_data, + const RuntimeShape& output_shape, int8_t* output_data) { // Integer bits must be in sync with Prepare() function. 
static constexpr int32_t kInputIntegerBits = 4; static constexpr int32_t kOutputScale = 7; @@ -34,7 +34,9 @@ inline void Tanh(int32_t input_zero_point, int32_t input_range_radius, static constexpr int32_t kMaxInt8 = std::numeric_limits<int8_t>::max(); using F4 = gemmlowp::FixedPoint<int32_t, kInputIntegerBits>; - for (int i = 0; i < input_size; ++i) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); + + for (int i = 0; i < flat_size; ++i) { const int32_t input = static_cast<int32_t>(input_data[i]) - input_zero_point; if (input <= -input_range_radius) { @@ -58,14 +60,16 @@ } inline void Tanh(int32_t input_multiplier, int32_t input_left_shift, - int32_t input_size, const int16_t* ptr_input_data, - int16_t* ptr_output_data) { + const RuntimeShape& input_shape, const int16_t* ptr_input_data, + const RuntimeShape& output_shape, int16_t* ptr_output_data) { // We use the LUT for sigmoid and take into account, that // tanh(x) = 2*sigmoid(2*x) - 1 int32_t input_data_mul = (input_multiplier > 0) ? 
input_multiplier : 1; - for (int i = 0; i < input_size; ++i, ptr_input_data++, ptr_output_data++) { + int flat_size = MatchingFlatSize(input_shape, output_shape); + + for (int i = 0; i < flat_size; ++i, ptr_input_data++, ptr_output_data++) { int32_t input_data = (*ptr_input_data) * input_data_mul; if (input_left_shift == 1) { diff --git a/tensorflow/lite/micro/kernels/tanh.cc b/tensorflow/lite/micro/kernels/tanh.cc index 5fa32f8f7ce..0f257dfe56b 100644 --- a/tensorflow/lite/micro/kernels/tanh.cc +++ b/tensorflow/lite/micro/kernels/tanh.cc @@ -124,8 +124,9 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt8: { reference_integer_ops::Tanh( data.input_zero_point, data.input_range_radius, data.input_multiplier, - data.input_left_shift, NumElements(input->dims), + data.input_left_shift, tflite::micro::GetTensorShape(input), tflite::micro::GetTensorData<int8_t>(input), + tflite::micro::GetTensorShape(output), tflite::micro::GetTensorData<int8_t>(output)); return kTfLiteOk; } break;