diff --git a/tensorflow/lite/kernels/basic_rnn.cc b/tensorflow/lite/kernels/basic_rnn.cc index a2c38b3b7d8..894316b2049 100644 --- a/tensorflow/lite/kernels/basic_rnn.cc +++ b/tensorflow/lite/kernels/basic_rnn.cc @@ -95,9 +95,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output, output_size_array)); - bool is_hybrid = - input->type == kTfLiteFloat32 && (input_weights->type == kTfLiteUInt8 || - input_weights->type == kTfLiteInt8); + const bool is_hybrid = IsHybridOp(input, input_weights); // Allocate temporary tensors to store quantized values of input and // hidden_state tensors. diff --git a/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc index 31c6e3f44c8..57746868b1e 100644 --- a/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc +++ b/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc @@ -509,8 +509,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { context->ResizeTensor(context, fw_output, fw_output_size)); // The weights are of consistent type, so it suffices to check one. - const bool is_hybrid_op = (fw_input_to_output_weights->type == kTfLiteUInt8 || - fw_input_to_output_weights->type == kTfLiteInt8); + const bool is_hybrid_op = IsHybridOp(input, fw_input_to_output_weights); TfLiteIntArrayFree(node->temporaries); if (is_hybrid_op) { diff --git a/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc index 0adf574bb06..75dbdd3fe16 100644 --- a/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc +++ b/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc @@ -168,11 +168,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { bw_aux_input_weights->dims->data[1]); } - const bool is_hybrid_op = ((fw_input_weights->type == kTfLiteUInt8 || - fw_input_weights->type == kTfLiteInt8) && - input->type == kTfLiteFloat32); - - if (is_hybrid_op) { + if (IsHybridOp(input, fw_input_weights)) { int* scratch_tensor_index = reinterpret_cast(node->user_data); TfLiteIntArrayFree(node->temporaries); diff --git a/tensorflow/lite/kernels/kernel_util.h b/tensorflow/lite/kernels/kernel_util.h index 423832c047c..94c9842b474 100644 --- a/tensorflow/lite/kernels/kernel_util.h +++ b/tensorflow/lite/kernels/kernel_util.h @@ -84,6 +84,13 @@ inline void SetTensorToDynamic(TfLiteTensor* tensor) { } } +// Determines whether it is a hybrid op - one that has float inputs and +// quantized weights. +inline bool IsHybridOp(const TfLiteTensor* input, const TfLiteTensor* weight) { + return ((weight->type == kTfLiteUInt8 || weight->type == kTfLiteInt8) && + input->type == kTfLiteFloat32); +} + // Check dimensionality match and populate OpData for Conv and DepthwiseConv. TfLiteStatus PopulateConvolutionQuantizationParams( TfLiteContext* context, const TfLiteTensor* input, diff --git a/tensorflow/lite/kernels/lstm.cc b/tensorflow/lite/kernels/lstm.cc index ea22ed56941..a0f8ac9455e 100644 --- a/tensorflow/lite/kernels/lstm.cc +++ b/tensorflow/lite/kernels/lstm.cc @@ -381,10 +381,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { context->ResizeTensor(context, output, output_size)); // The weights are of consistent type, so it suffices to check one. - // TODO(mirkov): create a utility/macro for this check, so all Ops can use it. - const bool is_hybrid_op = ((input_to_output_weights->type == kTfLiteUInt8 || - input_to_output_weights->type == kTfLiteInt8) && - input->type == kTfLiteFloat32); + const bool is_hybrid_op = IsHybridOp(input, input_to_output_weights); TfLiteIntArrayFree(node->temporaries); if (is_hybrid_op) { diff --git a/tensorflow/lite/kernels/svdf.cc b/tensorflow/lite/kernels/svdf.cc index d8fc7ce1cea..ae04c96967c 100644 --- a/tensorflow/lite/kernels/svdf.cc +++ b/tensorflow/lite/kernels/svdf.cc @@ -176,9 +176,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { context->ResizeTensor(context, output, output_size_array)); // The weights are of consistent type, so it suffices to check one. - const bool is_hybrid_op = (input->type == kTfLiteFloat32 && - (weights_feature->type == kTfLiteUInt8 || - weights_feature->type == kTfLiteInt8)); + const bool is_hybrid_op = IsHybridOp(input, weights_feature); // Resize scratch. TfLiteIntArrayFree(node->temporaries); diff --git a/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc index 8c2d0d57c7b..f1793c13a72 100644 --- a/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc +++ b/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc @@ -304,14 +304,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output, output_size)); - // The weights are of consistent type, so it suffices to check one. - // TODO(mirkov): create a utility/macro for this check, so all Ops can use it. - const bool is_hybrid_op = ((input_to_output_weights->type == kTfLiteUInt8 || - input_to_output_weights->type == kTfLiteInt8) && - input->type == kTfLiteFloat32); - TfLiteIntArrayFree(node->temporaries); - if (is_hybrid_op) { + if (IsHybridOp(input, input_to_output_weights)) { node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors); } else { node->temporaries = TfLiteIntArrayCreate(1); @@ -338,7 +332,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, scratch_buffer_size)); - if (is_hybrid_op) { + if (IsHybridOp(input, input_to_output_weights)) { // Allocate temporary tensors to store quantized values of input, // activation_state and cell_state tensors. node->temporaries->data[kInputQuantized] = diff --git a/tensorflow/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/lite/kernels/unidirectional_sequence_rnn.cc index 3854695d0bf..3000c3cd42f 100644 --- a/tensorflow/lite/kernels/unidirectional_sequence_rnn.cc +++ b/tensorflow/lite/kernels/unidirectional_sequence_rnn.cc @@ -96,9 +96,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output, output_size_array)); - const bool is_hybrid = - input->type == kTfLiteFloat32 && (input_weights->type == kTfLiteUInt8 || - input_weights->type == kTfLiteInt8); + const bool is_hybrid = IsHybridOp(input, input_weights); // Allocate temporary tensors to store quantized values of input and // hidden_state tensors.