added runtimeshape check for tanh kernel

This commit is contained in:
Basit Ayantunde 2020-09-09 16:15:11 +01:00
parent 4e20264533
commit 72a30f0017
3 changed files with 15 additions and 13 deletions

View File

@ -298,7 +298,6 @@ void HardSwishFree(TfLiteContext* context, void* buffer) {
delete static_cast<HardSwishData*>(buffer);
}
TfLiteStatus HardSwishPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_STATUS(GenericPrepare(context, node));
TfLiteTensor* output = GetOutput(context, node, 0);
@ -865,12 +864,10 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
TanhParams params;
params.input_left_shift = data->input_left_shift;
if (kernel_type == kReference || (data->input_multiplier > 0)) {
const int size =
MatchingFlatSize(GetTensorShape(input), GetTensorShape(output));
reference_integer_ops::Tanh(
data->input_multiplier, data->input_left_shift, size,
GetTensorData<int16_t>(input), GetTensorData<int16_t>(output));
data->input_multiplier, data->input_left_shift,
GetTensorShape(input), GetTensorData<int16_t>(input),
GetTensorShape(output), GetTensorData<int16_t>(output));
} else {
optimized_ops::Tanh(
params, GetTensorShape(input), GetTensorData<int16_t>(input),

View File

@ -25,8 +25,8 @@ namespace reference_integer_ops {
inline void Tanh(int32_t input_zero_point, int32_t input_range_radius,
int32_t input_multiplier, int32_t input_shift,
int32_t input_size, const int8_t* input_data,
int8_t* output_data) {
const RuntimeShape& input_shape, const int8_t* input_data,
const RuntimeShape& output_shape, int8_t* output_data) {
// Integer bits must be in sync with Prepare() function.
static constexpr int32_t kInputIntegerBits = 4;
static constexpr int32_t kOutputScale = 7;
@ -34,7 +34,9 @@ inline void Tanh(int32_t input_zero_point, int32_t input_range_radius,
static constexpr int32_t kMaxInt8 = std::numeric_limits<int8_t>::max();
using F4 = gemmlowp::FixedPoint<int32_t, kInputIntegerBits>;
for (int i = 0; i < input_size; ++i) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
const int32_t input =
static_cast<int32_t>(input_data[i]) - input_zero_point;
if (input <= -input_range_radius) {
@ -58,14 +60,16 @@ inline void Tanh(int32_t input_zero_point, int32_t input_range_radius,
}
inline void Tanh(int32_t input_multiplier, int32_t input_left_shift,
int32_t input_size, const int16_t* ptr_input_data,
int16_t* ptr_output_data) {
const RuntimeShape& input_shape, const int16_t* ptr_input_data,
const RuntimeShape& output_shape, int16_t* ptr_output_data) {
// We use the LUT for sigmoid and take into account, that
// tanh(x) = 2*sigmoid(2*x) - 1
int32_t input_data_mul = (input_multiplier > 0) ? input_multiplier : 1;
for (int i = 0; i < input_size; ++i, ptr_input_data++, ptr_output_data++) {
int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; ++i, ptr_input_data++, ptr_output_data++) {
int32_t input_data = (*ptr_input_data) * input_data_mul;
if (input_left_shift == 1) {

View File

@ -124,8 +124,9 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt8: {
reference_integer_ops::Tanh(
data.input_zero_point, data.input_range_radius, data.input_multiplier,
data.input_left_shift, NumElements(input->dims),
data.input_left_shift, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
return kTfLiteOk;
} break;