diff --git a/tensorflow/lite/c/builtin_op_data.h b/tensorflow/lite/c/builtin_op_data.h
index d7848f9bcc0..1bd7959c4db 100644
--- a/tensorflow/lite/c/builtin_op_data.h
+++ b/tensorflow/lite/c/builtin_op_data.h
@@ -124,21 +124,33 @@ typedef struct {
 typedef struct {
   int rank;
   TfLiteFusedActivation activation;
+
+  // Parameter for SVDF version 4.
+  bool asymmetric_quantize_inputs;
 } TfLiteSVDFParams;

 typedef struct {
   TfLiteFusedActivation activation;
+
+  // Parameter for RNN version 3.
+  bool asymmetric_quantize_inputs;
 } TfLiteRNNParams;

 typedef struct {
   bool time_major;
   TfLiteFusedActivation activation;
+
+  // Parameter for Sequence RNN version 3.
+  bool asymmetric_quantize_inputs;
 } TfLiteSequenceRNNParams;

 typedef struct {
   bool time_major;
   TfLiteFusedActivation activation;
   bool merge_outputs;
+
+  // Parameter for Bidirectional RNN version 3.
+  bool asymmetric_quantize_inputs;
 } TfLiteBidirectionalSequenceRNNParams;

 typedef enum {
@@ -158,6 +170,11 @@ typedef struct {
   // tensors are the same. Furthermore, all but the last dimension of the input
   // and output shapes will be equal.
   bool keep_num_dims;
+
+  // Parameters for FullyConnected version 7 or above.
+  // If set to true and the weights are quantized, then non-constant inputs
+  // are quantized at evaluation time with asymmetric quantization.
+  bool asymmetric_quantize_inputs;
 } TfLiteFullyConnectedParams;

 typedef enum {
@@ -228,6 +245,9 @@ typedef struct {
   // Parameters for LSTM version 2.
   // kTfLiteLSTMBasicKernel is only supported in version 2 or above.
   TfLiteLSTMKernelType kernel_type;
+
+  // Parameters for LSTM version 4.
+  bool asymmetric_quantize_inputs;
 } TfLiteLSTMParams;

 typedef struct {
@@ -238,6 +258,9 @@ typedef struct {
   // If set to true then the first dimension is time, otherwise batch.
   bool time_major;
+
+  // Parameter for unidirectional sequence LSTM version 3.
+  bool asymmetric_quantize_inputs;
 } TfLiteUnidirectionalSequenceLSTMParams;

 typedef struct {
@@ -253,6 +276,10 @@ typedef struct {
   // Parameters supported by version 2:
   // If set to true then the first dimension is time, otherwise batch.
   bool time_major;
+
+  // Parameters supported by version 4:
+  // If set to true, then hybrid ops use asymmetric quantization for inputs.
+ bool asymmetric_quantize_inputs; } TfLiteBidirectionalSequenceLSTMParams; typedef struct { diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index 83b4159cce0..61a946feecb 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -269,6 +269,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, params->rank = svdf_params->rank(); params->activation = parse_activation(svdf_params->fused_activation_function()); + params->asymmetric_quantize_inputs = + svdf_params->asymmetric_quantize_inputs(); } *builtin_data = reinterpret_cast(params.release()); break; @@ -280,6 +282,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, params->activation = parse_activation(sequence_rnn_params->fused_activation_function()); params->time_major = sequence_rnn_params->time_major(); + params->asymmetric_quantize_inputs = + sequence_rnn_params->asymmetric_quantize_inputs(); } *builtin_data = reinterpret_cast(params.release()); break; @@ -293,6 +297,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, bidi_sequence_rnn_params->fused_activation_function()); params->time_major = bidi_sequence_rnn_params->time_major(); params->merge_outputs = bidi_sequence_rnn_params->merge_outputs(); + params->asymmetric_quantize_inputs = + bidi_sequence_rnn_params->asymmetric_quantize_inputs(); } *builtin_data = reinterpret_cast(params.release()); break; @@ -302,6 +308,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, if (const auto* rnn_params = op->builtin_options_as_RNNOptions()) { params->activation = parse_activation(rnn_params->fused_activation_function()); + params->asymmetric_quantize_inputs = + rnn_params->asymmetric_quantize_inputs(); } *builtin_data = reinterpret_cast(params.release()); break; @@ -323,6 +331,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, params->activation = parse_activation( fully_connected_params->fused_activation_function()); params->keep_num_dims = fully_connected_params->keep_num_dims(); + params->asymmetric_quantize_inputs = + fully_connected_params->asymmetric_quantize_inputs(); switch (fully_connected_params->weights_format()) { case FullyConnectedOptionsWeightsFormat_DEFAULT: params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault; @@ -440,6 +450,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, lstm_params->kernel_type()); return kTfLiteError; } + params->asymmetric_quantize_inputs = + lstm_params->asymmetric_quantize_inputs(); } else { TF_LITE_REPORT_ERROR(error_reporter, "No valid LSTM builtin options exist"); @@ -458,6 +470,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, params->cell_clip = seq_lstm_params->cell_clip(); params->proj_clip = seq_lstm_params->proj_clip(); params->time_major = seq_lstm_params->time_major(); + params->asymmetric_quantize_inputs = + seq_lstm_params->asymmetric_quantize_inputs(); } *builtin_data = reinterpret_cast(params.release()); break; @@ -473,6 +487,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, params->proj_clip = bidi_lstm_params->proj_clip(); params->merge_outputs = bidi_lstm_params->merge_outputs(); params->time_major = bidi_lstm_params->time_major(); + params->asymmetric_quantize_inputs = + bidi_lstm_params->asymmetric_quantize_inputs(); } *builtin_data = reinterpret_cast(params.release()); break; diff --git 
a/tensorflow/lite/kernels/basic_rnn.cc b/tensorflow/lite/kernels/basic_rnn.cc index f21b8a910dd..920e8cd223a 100644 --- a/tensorflow/lite/kernels/basic_rnn.cc +++ b/tensorflow/lite/kernels/basic_rnn.cc @@ -26,6 +26,15 @@ namespace ops { namespace builtin { namespace rnn { +namespace { + +struct OpData { + int scratch_tensor_index; + bool compute_row_sums = false; +}; + +} // namespace + constexpr int kInputTensor = 0; constexpr int kWeightsTensor = 1; constexpr int kRecurrentWeightsTensor = 2; @@ -36,13 +45,14 @@ constexpr int kHiddenStateTensor = 4; constexpr int kOutputTensor = 0; void* Init(TfLiteContext* context, const char* buffer, size_t length) { - auto* scratch_tensor_index = new int; - context->AddTensors(context, /*tensors_to_add=*/3, scratch_tensor_index); - return scratch_tensor_index; + auto* op_data = new OpData(); + context->AddTensors(context, /*tensors_to_add=*/6, + &op_data->scratch_tensor_index); + return op_data; } void Free(TfLiteContext* context, void* buffer) { - delete reinterpret_cast(buffer); + delete reinterpret_cast(buffer); } TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { @@ -89,10 +99,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Allocate temporary tensors to store quantized values of input and // hidden_state tensors. if (is_hybrid) { - int* scratch_tensor_index = reinterpret_cast(node->user_data); + auto* op_data = reinterpret_cast(node->user_data); + op_data->compute_row_sums = true; TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(3); - node->temporaries->data[0] = *scratch_tensor_index; + node->temporaries = TfLiteIntArrayCreate(6); + node->temporaries->data[0] = op_data->scratch_tensor_index; TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0); input_quantized->type = input_weights->type; input_quantized->allocation_type = kTfLiteArenaRw; @@ -101,7 +112,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, input_quantized_size)); } - node->temporaries->data[1] = *scratch_tensor_index + 1; + node->temporaries->data[1] = op_data->scratch_tensor_index + 1; TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, /*index=*/1); hidden_state_quantized->type = input_weights->type; @@ -114,7 +125,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { context->ResizeTensor(context, hidden_state_quantized, hidden_state_quantized_size)); } - node->temporaries->data[2] = *scratch_tensor_index + 2; + node->temporaries->data[2] = op_data->scratch_tensor_index + 2; TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2); scaling_factors->type = kTfLiteFloat32; scaling_factors->allocation_type = kTfLiteArenaRw; @@ -125,8 +136,43 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, scaling_factors_size)); } + node->temporaries->data[3] = op_data->scratch_tensor_index + 3; + TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/3); + accum_scratch->type = kTfLiteInt32; + accum_scratch->allocation_type = kTfLiteArenaRw; + int accum_scratch_dims[2] = {num_units, batch_size}; + if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2, + accum_scratch_dims)) { + TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2); + accum_scratch_size->data[0] = accum_scratch_dims[0]; + accum_scratch_size->data[1] = accum_scratch_dims[1]; + 
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch, + accum_scratch_size)); + } + node->temporaries->data[4] = op_data->scratch_tensor_index + 4; + TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4); + zero_points->type = kTfLiteInt32; + zero_points->allocation_type = kTfLiteArenaRw; + int zero_points_dims[1] = {batch_size}; + if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) { + TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1); + zero_points_size->data[0] = batch_size; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points, + zero_points_size)); + } + node->temporaries->data[5] = op_data->scratch_tensor_index + 5; + TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5); + row_sums->type = kTfLiteInt32; + row_sums->allocation_type = kTfLiteArenaRwPersistent; + int row_sums_dims[2] = {2, num_units}; + if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) { + TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2); + row_sums_size->data[0] = row_sums_dims[0]; + row_sums_size->data[1] = row_sums_dims[1]; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, row_sums, row_sums_size)); + } } - return kTfLiteOk; } @@ -165,7 +211,9 @@ TfLiteStatus EvalHybrid(const TfLiteTensor* input, TfLiteTensor* input_scratch, TfLiteTensor* hidden_state_scratch, TfLiteTensor* scaling_factors, - TfLiteTensor* hidden_state, TfLiteTensor* output) { + TfLiteTensor* hidden_state, TfLiteTensor* output, + TfLiteTensor* zero_points, TfLiteTensor* accum_scratch, + TfLiteTensor* row_sums, bool* compute_row_sums) { const int batch_size = input->dims->data[0]; const int num_units = input_weights->dims->data[0]; const int input_size = input->dims->data[1]; @@ -190,26 +238,34 @@ TfLiteStatus EvalHybrid(const TfLiteTensor* input, int8_t* quantized_hidden_state_ptr = GetTensorData(hidden_state_scratch); float* scaling_factors_ptr = GetTensorData(scaling_factors); - + int32_t* accum_scratch_ptr = GetTensorData(accum_scratch); + int32_t* zero_points_ptr = nullptr; + int32_t* row_sums_ptr = nullptr; + if (params->asymmetric_quantize_inputs) { + zero_points_ptr = GetTensorData(zero_points); + row_sums_ptr = GetTensorData(row_sums); + } kernel_utils::RnnBatchStep( input_ptr_batch, input_weights_ptr, input_weights_scale, recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size, num_units, batch_size, output_batch_leading_dim, params->activation, quantized_input_ptr, quantized_hidden_state_ptr, scaling_factors_ptr, - hidden_state_ptr_batch, output_ptr_batch); + hidden_state_ptr_batch, output_ptr_batch, + params->asymmetric_quantize_inputs, zero_points_ptr, accum_scratch_ptr, + row_sums_ptr, compute_row_sums); return kTfLiteOk; } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); - + auto* op_data = reinterpret_cast(node->user_data); const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor); const TfLiteTensor* recurrent_weights = GetInput(context, node, kRecurrentWeightsTensor); const TfLiteTensor* bias = GetInput(context, node, kBiasTensor); TfLiteTensor* hidden_state = - GetVariableInput(context, node, kHiddenStateTensor); + &context->tensors[node->inputs->data[kHiddenStateTensor]]; TfLiteTensor* output = GetOutput(context, node, kOutputTensor); // We already checked that weight types are consistent, so branch on one. 
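Note on the hybrid RNN changes above: row_sums is allocated as a persistent arena tensor and Prepare sets compute_row_sums back to true, so the per-row weight sums are rebuilt once after (re)allocation and then reused on later invocations while the weights stay constant. The reason caching works: with an input quantized asymmetrically as x[j] ~= scale * (q[j] - zero_point), each output row is scale * (dot(w_row, q) - zero_point * sum(w_row)), so only the integer dot product changes between calls; the new zero_points and row_sums arguments to kernel_utils::RnnBatchStep carry exactly those cached pieces. A minimal standalone sketch of that algebra (the function name, layout, and signature below are invented for illustration, not kernel_utils or tensor_utils APIs):

// Illustrative sketch only; not part of the patch.
#include <cstdint>
#include <vector>

std::vector<float> HybridMatVec(const std::vector<int8_t>& weights,      // rows x cols, row-major
                                const std::vector<int8_t>& quant_input,  // cols, asymmetric int8
                                float combined_scale, int32_t zero_point,
                                const std::vector<int32_t>& row_sums,    // precomputed per weight row
                                int rows, int cols) {
  std::vector<float> output(rows, 0.0f);
  for (int i = 0; i < rows; ++i) {
    int32_t acc = 0;
    for (int j = 0; j < cols; ++j) {
      acc += static_cast<int32_t>(weights[i * cols + j]) * quant_input[j];
    }
    // The zero-point correction reuses the cached row sum instead of
    // re-summing the weight row on every call.
    acc -= zero_point * row_sums[i];
    output[i] = combined_scale * acc;
  }
  return output;
}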
@@ -223,9 +279,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* input_quantized = GetTemporary(context, node, 0); TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1); TfLiteTensor* scaling_factors = GetTemporary(context, node, 2); + TfLiteTensor* accum_scratch = GetTemporary(context, node, 3); + TfLiteTensor* zero_points = GetTemporary(context, node, 4); + TfLiteTensor* row_sums = GetTemporary(context, node, 5); return EvalHybrid(input, input_weights, recurrent_weights, bias, params, input_quantized, hidden_state_quantized, - scaling_factors, hidden_state, output); + scaling_factors, hidden_state, output, zero_points, + accum_scratch, row_sums, &op_data->compute_row_sums); } default: context->ReportError(context, "Type %d not currently supported.", diff --git a/tensorflow/lite/kernels/basic_rnn_test.cc b/tensorflow/lite/kernels/basic_rnn_test.cc index b9c251ce044..f7cbaa5a814 100644 --- a/tensorflow/lite/kernels/basic_rnn_test.cc +++ b/tensorflow/lite/kernels/basic_rnn_test.cc @@ -175,7 +175,8 @@ class RNNOpModel : public SingleOpModel { public: RNNOpModel(int batches, int units, int size, const TensorType& weights = TensorType_FLOAT32, - const TensorType& recurrent_weights = TensorType_FLOAT32) + const TensorType& recurrent_weights = TensorType_FLOAT32, + bool asymmetric_quantize_inputs = false) : batches_(batches), units_(units), input_size_(size) { input_ = AddInput(TensorType_FLOAT32); weights_ = AddInput(weights); @@ -183,9 +184,10 @@ class RNNOpModel : public SingleOpModel { bias_ = AddInput(TensorType_FLOAT32); hidden_state_ = AddInput(TensorType_FLOAT32, true); output_ = AddOutput(TensorType_FLOAT32); - SetBuiltinOp( - BuiltinOperator_RNN, BuiltinOptions_RNNOptions, - CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union()); + SetBuiltinOp(BuiltinOperator_RNN, BuiltinOptions_RNNOptions, + CreateRNNOptions(builder_, ActivationFunctionType_RELU, + asymmetric_quantize_inputs) + .Union()); BuildInterpreter({{batches_, input_size_}, // input tensor {units_, input_size_}, // weights tensor {units_, units_}, // recurrent weights tensor @@ -233,8 +235,10 @@ class RNNOpModel : public SingleOpModel { // The hybrid model has quantized weights and recurrent_weights. 
class HybridRNNOpModel : public RNNOpModel { public: - HybridRNNOpModel(int batches, int units, int size, TensorType tensor_type) - : RNNOpModel(batches, units, size, tensor_type, tensor_type) { + HybridRNNOpModel(int batches, int units, int size, TensorType tensor_type, + bool asymmetric_quantize_inputs) + : RNNOpModel(batches, units, size, tensor_type, tensor_type, + asymmetric_quantize_inputs) { tensor_type_ = tensor_type; } @@ -282,8 +286,10 @@ TEST(RnnOpTest, BlackBoxTest) { } } -TEST(HybridRnnOpTest, BlackBoxTestUint8) { - HybridRNNOpModel rnn(2, 16, 8, TensorType_UINT8); +class HybridRnnOpTest : public ::testing::TestWithParam {}; + +TEST_P(HybridRnnOpTest, BlackBoxTestUint8) { + HybridRNNOpModel rnn(2, 16, 8, TensorType_UINT8, GetParam()); rnn.SetWeights(rnn_weights); rnn.SetBias(rnn_bias); rnn.SetRecurrentWeights(rnn_recurrent_weights); @@ -310,8 +316,8 @@ TEST(HybridRnnOpTest, BlackBoxTestUint8) { } } -TEST(HybridRnnOpTest, BlackBoxTestInt8) { - HybridRNNOpModel rnn(2, 16, 8, TensorType_INT8); +TEST_P(HybridRnnOpTest, BlackBoxTestInt8) { + HybridRNNOpModel rnn(2, 16, 8, TensorType_INT8, GetParam()); rnn.SetWeights(rnn_weights); rnn.SetBias(rnn_bias); rnn.SetRecurrentWeights(rnn_recurrent_weights); @@ -338,5 +344,8 @@ TEST(HybridRnnOpTest, BlackBoxTestInt8) { } } +INSTANTIATE_TEST_SUITE_P(HybridRnnOpTest, HybridRnnOpTest, + ::testing::ValuesIn({false, true})); + } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc index 33c43aacbc7..3a780eed0a0 100644 --- a/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc +++ b/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc @@ -139,18 +139,28 @@ enum TemporaryTensor { kProductScalingFactors = 8, kRecoveredCellWeights = 9, kAccumScratchBuffer = 10, - kAuxInputQuantized = 11, // Optional, quantized tensor for auxiliary input. - kNumTemporaryTensors + kZeroPoints = 11, + kFwRowSums = 12, + kBwRowSums = 13, + kAuxInputQuantized = 14, // Optional, quantized tensor for auxiliary input. + kNumTemporaryTensors = 15 +}; + +struct OpData { + int scratch_tensor_index; + bool compute_fw_row_sums = false; + bool compute_bw_row_sums = false; }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { - auto* scratch_tensor_index = new int; - context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index); - return scratch_tensor_index; + auto* op_data = new OpData(); + context->AddTensors(context, kNumTemporaryTensors, + &op_data->scratch_tensor_index); + return op_data; } void Free(TfLiteContext* context, void* buffer) { - delete reinterpret_cast(buffer); + delete reinterpret_cast(buffer); } // Check that input tensor dimensions matches with each other. @@ -385,7 +395,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, // Resize the output and scratch tensors based on the sizes of the input // tensors. Also check that the size of the input tensors match each other. TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - int* scratch_tensor_index = reinterpret_cast(node->user_data); + auto* op_data = reinterpret_cast(node->user_data); const auto* params = reinterpret_cast( node->builtin_data); @@ -522,7 +532,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { node->temporaries = TfLiteIntArrayCreate(2); // the two scratch buffers. } // Create a scratch buffer tensor. 
- node->temporaries->data[kFwScratchBuffer] = *scratch_tensor_index; + node->temporaries->data[kFwScratchBuffer] = op_data->scratch_tensor_index; TfLiteTensor* fw_scratch_buffer = GetTemporary(context, node, kFwScratchBuffer); fw_scratch_buffer->type = input->type; @@ -581,7 +591,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Create a scratch buffer tensor. node->temporaries->data[kBwScratchBuffer] = - *(scratch_tensor_index) + kBwScratchBuffer; + op_data->scratch_tensor_index + kBwScratchBuffer; TfLiteTensor* bw_scratch_buffer = GetTemporary(context, node, kBwScratchBuffer); bw_scratch_buffer->type = input->type; @@ -606,10 +616,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_scratch_buffer, bw_scratch_buffer_size)); if (is_hybrid_op) { + // Compute the row sums for cached zero_point offset calculation. + op_data->compute_fw_row_sums = true; + op_data->compute_bw_row_sums = true; // Allocate temporary tensors to store quantized values of input, aux_input // (if present), activation_state and cell_state tensors. node->temporaries->data[kInputQuantized] = - *scratch_tensor_index + kInputQuantized; + op_data->scratch_tensor_index + kInputQuantized; TfLiteTensor* input_quantized = GetTemporary(context, node, kInputQuantized); input_quantized->type = fw_input_to_output_weights->type; @@ -621,7 +634,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } node->temporaries->data[kFwActivationStateQuantized] = - *scratch_tensor_index + kFwActivationStateQuantized; + op_data->scratch_tensor_index + kFwActivationStateQuantized; TfLiteTensor* fw_activation_state_quantized = GetTemporary(context, node, kFwActivationStateQuantized); fw_activation_state_quantized->type = fw_input_to_output_weights->type; @@ -635,7 +648,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { fw_activation_state_quantized_size)); } node->temporaries->data[kBwActivationStateQuantized] = - *scratch_tensor_index + kBwActivationStateQuantized; + op_data->scratch_tensor_index + kBwActivationStateQuantized; TfLiteTensor* bw_activation_state_quantized = GetTemporary(context, node, kBwActivationStateQuantized); bw_activation_state_quantized->type = fw_input_to_output_weights->type; @@ -649,7 +662,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { bw_activation_state_quantized_size)); } node->temporaries->data[kFwCellStateQuantized] = - *scratch_tensor_index + kFwCellStateQuantized; + op_data->scratch_tensor_index + kFwCellStateQuantized; TfLiteTensor* fw_cell_state_quantized = GetTemporary(context, node, kFwCellStateQuantized); fw_cell_state_quantized->type = fw_input_to_output_weights->type; @@ -663,7 +676,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { fw_cell_state_quantized_size)); } node->temporaries->data[kBwCellStateQuantized] = - *scratch_tensor_index + kBwCellStateQuantized; + op_data->scratch_tensor_index + kBwCellStateQuantized; TfLiteTensor* bw_cell_state_quantized = GetTemporary(context, node, kBwCellStateQuantized); bw_cell_state_quantized->type = fw_input_to_output_weights->type; @@ -683,7 +696,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // different matrices (which requires multiplying the scaling factors with // the scaling factor of the matrix). 
node->temporaries->data[kScalingFactors] = - *scratch_tensor_index + kScalingFactors; + op_data->scratch_tensor_index + kScalingFactors; TfLiteTensor* scaling_factors = GetTemporary(context, node, kScalingFactors); scaling_factors->type = kTfLiteFloat32; @@ -696,7 +709,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { scaling_factors_size)); } node->temporaries->data[kProductScalingFactors] = - *scratch_tensor_index + kProductScalingFactors; + op_data->scratch_tensor_index + kProductScalingFactors; TfLiteTensor* prod_scaling_factors = GetTemporary(context, node, kProductScalingFactors); prod_scaling_factors->type = kTfLiteFloat32; @@ -713,7 +726,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Allocate a temporary tensor to store the recovered cell weights. Since // this is used for diagonal matrices, only need to store n_cell values. node->temporaries->data[kRecoveredCellWeights] = - *scratch_tensor_index + kRecoveredCellWeights; + op_data->scratch_tensor_index + kRecoveredCellWeights; TfLiteTensor* recovered_cell_weights = GetTemporary(context, node, kRecoveredCellWeights); recovered_cell_weights->type = kTfLiteFloat32; @@ -730,7 +743,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Allocate a temporary tensor to store the accumulated int32 values. node->temporaries->data[kAccumScratchBuffer] = - *scratch_tensor_index + kAccumScratchBuffer; + op_data->scratch_tensor_index + kAccumScratchBuffer; TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratchBuffer); accum_scratch->type = kTfLiteInt32; @@ -750,11 +763,72 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { context, context->ResizeTensor(context, accum_scratch, accum_size)); } + // Allocate temporary tensors for storing zero-points. + node->temporaries->data[kZeroPoints] = + op_data->scratch_tensor_index + kZeroPoints; + TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints); + zero_points->type = kTfLiteFloat32; + zero_points->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, scaling_dims)) { + TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1); + zero_points_size->data[0] = n_batch; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points, + zero_points_size)); + } + + // Allocate temporary tensors for caching row sums for hybrid zero-point + // calculations. + int fw_row_sums_rows = fw_use_cifg ? 6 : 8; + if (has_aux_input) { + fw_row_sums_rows += fw_use_cifg ? 3 : 4; + } + const TfLiteTensor* fw_projection_weights = + GetOptionalInputTensor(context, node, kFwProjectionWeightsTensor); + if (fw_projection_weights != nullptr) { + fw_row_sums_rows += ceil(n_fw_output / n_fw_cell); + } + node->temporaries->data[kFwRowSums] = + op_data->scratch_tensor_index + kFwRowSums; + TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums); + fw_row_sums->type = kTfLiteInt32; + fw_row_sums->allocation_type = kTfLiteArenaRwPersistent; + int fw_row_sums_dims[2] = {fw_row_sums_rows, n_fw_cell}; + if (!TfLiteIntArrayEqualsArray(fw_row_sums->dims, 2, fw_row_sums_dims)) { + TfLiteIntArray* fw_hybrid_scratch_size = TfLiteIntArrayCreate(2); + fw_hybrid_scratch_size->data[0] = fw_row_sums_dims[0]; + fw_hybrid_scratch_size->data[1] = fw_row_sums_dims[1]; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_row_sums, + fw_hybrid_scratch_size)); + } + + int bw_row_sums_rows = bw_use_cifg ? 6 : 8; + if (has_aux_input) { + bw_row_sums_rows += bw_use_cifg ? 
3 : 4; + } + const TfLiteTensor* bw_projection_weights = + GetOptionalInputTensor(context, node, kBwProjectionWeightsTensor); + if (bw_projection_weights != nullptr) { + bw_row_sums_rows += ceil(n_bw_output / n_bw_cell); + } + node->temporaries->data[kBwRowSums] = + op_data->scratch_tensor_index + kBwRowSums; + TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums); + bw_row_sums->type = kTfLiteInt32; + bw_row_sums->allocation_type = kTfLiteArenaRwPersistent; + int bw_row_sums_dims[2] = {bw_row_sums_rows, n_bw_cell}; + if (!TfLiteIntArrayEqualsArray(bw_row_sums->dims, 2, bw_row_sums_dims)) { + TfLiteIntArray* bw_row_sums_size = TfLiteIntArrayCreate(2); + bw_row_sums_size->data[0] = bw_row_sums_dims[0]; + bw_row_sums_size->data[1] = bw_row_sums_dims[1]; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_row_sums, + bw_row_sums_size)); + } + // Only allocate a temporary tensor for quantized auxiliary input if we are // actually going to use it. if (has_aux_input) { node->temporaries->data[kAuxInputQuantized] = - *scratch_tensor_index + kAuxInputQuantized; + op_data->scratch_tensor_index + kAuxInputQuantized; TfLiteTensor* aux_input_quantized = GetTemporary(context, node, kAuxInputQuantized); aux_input_quantized->type = fw_input_to_output_weights->type; @@ -775,7 +849,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const auto* params = reinterpret_cast( node->builtin_data); - + auto* op_data = reinterpret_cast(node->user_data); // Input tensor. const TfLiteTensor* input = GetInput(context, node, kInputTensor); @@ -909,7 +983,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // Populate a TfLiteLSTMParams struct for the evaluation functions. TfLiteLSTMParams lstm_params = {params->activation, params->cell_clip, - params->proj_clip, kTfLiteLSTMFullKernel}; + params->proj_clip, kTfLiteLSTMFullKernel, + params->asymmetric_quantize_inputs}; const int bw_output_offset = params->merge_outputs ? 
fw_recurrent_to_output_weights->dims->data[1] : 0; @@ -1003,7 +1078,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { : nullptr; TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratchBuffer); - + TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints); + TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums); + TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums); + const int fw_row_sums_size = fw_row_sums->dims->data[0]; + const int bw_row_sums_size = bw_row_sums->dims->data[0]; TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid( input, fw_input_to_input_weights, fw_input_to_forget_weights, fw_input_to_cell_weights, fw_input_to_output_weights, @@ -1025,6 +1104,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { recovered_cell_weights, input_quantized, aux_input_quantized, fw_activation_state_quantized, fw_cell_state_quantized, fw_activation_state, fw_cell_state, accum_scratch, fw_output, + zero_points, fw_row_sums, fw_row_sums_size, + &op_data->compute_fw_row_sums, CpuBackendContext::GetFromContext(context)); TF_LITE_ENSURE_OK(context, fw_pass_status); @@ -1049,6 +1130,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { recovered_cell_weights, input_quantized, aux_input_quantized, bw_activation_state_quantized, bw_cell_state_quantized, bw_activation_state, bw_cell_state, accum_scratch, actual_bw_output, + zero_points, bw_row_sums, bw_row_sums_size, + &op_data->compute_bw_row_sums, CpuBackendContext::GetFromContext(context)); TF_LITE_ENSURE_OK(context, bw_pass_status); return kTfLiteOk; diff --git a/tensorflow/lite/kernels/bidirectional_sequence_lstm_test.cc b/tensorflow/lite/kernels/bidirectional_sequence_lstm_test.cc index 12b33c9661d..c468c4c09fb 100644 --- a/tensorflow/lite/kernels/bidirectional_sequence_lstm_test.cc +++ b/tensorflow/lite/kernels/bidirectional_sequence_lstm_test.cc @@ -40,7 +40,8 @@ class BidirectionalLSTMOpModel : public SingleOpModel { bool use_projection_bias, bool merge_outputs, bool use_aux_input, float cell_clip, float proj_clip, bool quantize_weights, bool time_major, - const std::vector>& input_shapes) + const std::vector>& input_shapes, + bool asymmetric_quantize_inputs = false) : n_batch_(n_batch), n_input_(n_input), n_fw_cell_(n_cell), @@ -207,12 +208,13 @@ class BidirectionalLSTMOpModel : public SingleOpModel { bw_aux_input_to_output_weights_ = AddNullInput(); } - SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, - BuiltinOptions_BidirectionalSequenceLSTMOptions, - CreateBidirectionalSequenceLSTMOptions( - builder_, ActivationFunctionType_TANH, cell_clip, - proj_clip, merge_outputs, time_major) - .Union()); + SetBuiltinOp( + BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, + BuiltinOptions_BidirectionalSequenceLSTMOptions, + CreateBidirectionalSequenceLSTMOptions( + builder_, ActivationFunctionType_TANH, cell_clip, proj_clip, + merge_outputs, time_major, asymmetric_quantize_inputs) + .Union()); BuildInterpreter(input_shapes); } @@ -424,11 +426,14 @@ class BidirectionalLSTMOpModel : public SingleOpModel { bool quantize_weights_; }; -// Declare LSTMOpTest as a parameterized test, where the parameter is a boolean -// indicating whether to use quantization or not. -class LSTMOpTest : public ::testing::TestWithParam {}; +// Declare LSTMOpTest as a parameterized test. 
+class LSTMOpTest + : public ::testing::TestWithParam<::testing::tuple> {}; -INSTANTIATE_TEST_SUITE_P(QuantizationOrNot, LSTMOpTest, ::testing::Bool()); +INSTANTIATE_TEST_SUITE_P(QuantizationOrNot, LSTMOpTest, + ::testing::Combine( + /*quantize_weights*/ ::testing::Bool(), + /*asymmetric_quantize_inputs*/ ::testing::Bool())); TEST_P(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { const int n_batch = 1; @@ -437,7 +442,9 @@ TEST_P(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { const int n_cell = 4; const int n_output = 4; const int sequence_length = 3; - const bool quantize_weights = GetParam(); + auto params = GetParam(); + const bool quantize_weights = std::get<0>(params); + const bool asymmetric_quantize_inputs = std::get<1>(params); BidirectionalLSTMOpModel lstm( n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, @@ -509,7 +516,8 @@ TEST_P(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { {0}, // aux_bw_input_to_forget tensor {0}, // aux_bw_input_to_cell tensor {0}, // aux_bw_input_to_output tensor - }); + }, + asymmetric_quantize_inputs); lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, -0.34550029, 0.04266912, -0.15680569, @@ -600,7 +608,9 @@ TEST_P(LSTMOpTest, BlackBoxTestMergedOutput) { const int n_cell = 4; const int n_output = 4; const int sequence_length = 3; - const bool quantize_weights = GetParam(); + auto params = GetParam(); + const bool quantize_weights = std::get<0>(params); + const bool asymmetric_quantize_inputs = std::get<1>(params); BidirectionalLSTMOpModel lstm( n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, @@ -672,7 +682,8 @@ TEST_P(LSTMOpTest, BlackBoxTestMergedOutput) { {0}, // aux_bw_input_to_forget tensor {0}, // aux_bw_input_to_cell tensor {0}, // aux_bw_input_to_output tensor - }); + }, + asymmetric_quantize_inputs); lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, -0.34550029, 0.04266912, -0.15680569, @@ -2631,7 +2642,9 @@ TEST_P(LSTMOpTest, BlackBoxTestWithAuxInputZeroAuxWeight) { const int n_cell = 4; const int n_output = 4; const int sequence_length = 3; - const bool quantize_weights = GetParam(); + auto params = GetParam(); + const bool quantize_weights = std::get<0>(params); + const bool asymmetric_quantize_inputs = std::get<1>(params); BidirectionalLSTMOpModel lstm( n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, @@ -2703,7 +2716,8 @@ TEST_P(LSTMOpTest, BlackBoxTestWithAuxInputZeroAuxWeight) { {n_cell, n_input}, // aux_bw_input_to_forget tensor {n_cell, n_input}, // aux_bw_input_to_cell tensor {n_cell, n_input}, // aux_bw_input_to_output tensor - }); + }, + asymmetric_quantize_inputs); lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, -0.34550029, 0.04266912, -0.15680569, @@ -2802,7 +2816,9 @@ TEST_P(LSTMOpTest, BlackBoxTestWithAuxInput) { const int n_cell = 4; const int n_output = 4; const int sequence_length = 3; - const bool quantize_weights = GetParam(); + auto params = GetParam(); + const bool quantize_weights = std::get<0>(params); + const bool asymmetric_quantize_inputs = std::get<1>(params); BidirectionalLSTMOpModel lstm( n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, @@ -2874,7 +2890,8 @@ TEST_P(LSTMOpTest, BlackBoxTestWithAuxInput) { {n_cell, n_input}, // aux_bw_input_to_forget tensor {n_cell, n_input}, // aux_bw_input_to_cell tensor {n_cell, n_input}, // aux_bw_input_to_output tensor - }); + }, + asymmetric_quantize_inputs); 
lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, -0.34550029, 0.04266912, -0.15680569, diff --git a/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc index db456d539b9..58a2ef9c1ea 100644 --- a/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc +++ b/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc @@ -27,6 +27,16 @@ namespace ops { namespace builtin { namespace bidirectional_sequence_rnn { +namespace { + +struct OpData { + int scratch_tensor_index; + bool fw_compute_row_sums = false; + bool bw_compute_row_sums = false; +}; + +} // namespace + // LINT.IfChange constexpr int kInputTensor = 0; @@ -58,18 +68,23 @@ enum TemporaryTensor { kFwHiddenStateQuantized = 1, kBwHiddenStateQuantized = 2, kScalingFactors = 3, - kAuxInputQuantized = 4, - kNumTemporaryTensors = 5 + kAccumScratch = 4, + kZeroPoints = 5, + kFwRowSums = 6, + kBwRowSums = 7, + kAuxInputQuantized = 8, + kNumTemporaryTensors = 9 }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { - auto* scratch_tensor_index = new int; - context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index); - return scratch_tensor_index; + auto* op_data = new OpData(); + context->AddTensors(context, kNumTemporaryTensors, + &op_data->scratch_tensor_index); + return op_data; } void Free(TfLiteContext* context, void* buffer) { - delete reinterpret_cast(buffer); + delete reinterpret_cast(buffer); } TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { @@ -157,8 +172,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } if (IsHybridOp(input, fw_input_weights)) { - int* scratch_tensor_index = reinterpret_cast(node->user_data); - + OpData* op_data = reinterpret_cast(node->user_data); + op_data->fw_compute_row_sums = true; + op_data->bw_compute_row_sums = true; TfLiteIntArrayFree(node->temporaries); if (has_aux_input) { node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors); @@ -168,7 +184,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } node->temporaries->data[kInputQuantized] = - *scratch_tensor_index + kInputQuantized; + op_data->scratch_tensor_index + kInputQuantized; TfLiteTensor* input_quantized = GetTemporary(context, node, kInputQuantized); input_quantized->type = fw_input_weights->type; @@ -180,7 +196,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } node->temporaries->data[kFwHiddenStateQuantized] = - *scratch_tensor_index + kFwHiddenStateQuantized; + op_data->scratch_tensor_index + kFwHiddenStateQuantized; TfLiteTensor* fw_hidden_state_quantized = GetTemporary(context, node, kFwHiddenStateQuantized); fw_hidden_state_quantized->type = fw_input_weights->type; @@ -195,7 +211,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } node->temporaries->data[kBwHiddenStateQuantized] = - *scratch_tensor_index + kBwHiddenStateQuantized; + op_data->scratch_tensor_index + kBwHiddenStateQuantized; TfLiteTensor* bw_hidden_state_quantized = GetTemporary(context, node, kBwHiddenStateQuantized); bw_hidden_state_quantized->type = fw_input_weights->type; @@ -211,7 +227,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Allocate temporary tensors to store scaling factors of quantization. 
node->temporaries->data[kScalingFactors] = - *scratch_tensor_index + kScalingFactors; + op_data->scratch_tensor_index + kScalingFactors; TfLiteTensor* scaling_factors = GetTemporary(context, node, kScalingFactors); scaling_factors->type = kTfLiteFloat32; @@ -223,10 +239,66 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, scaling_factors_size)); } - + node->temporaries->data[kAccumScratch] = + op_data->scratch_tensor_index + kAccumScratch; + TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratch); + accum_scratch->type = kTfLiteInt32; + accum_scratch->allocation_type = kTfLiteArenaRw; + int accum_scratch_dims[2] = {std::max(fw_num_units, bw_num_units), + batch_size}; + if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2, + accum_scratch_dims)) { + TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2); + accum_scratch_size->data[0] = accum_scratch_dims[0]; + accum_scratch_size->data[1] = accum_scratch_dims[1]; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch, + accum_scratch_size)); + } + node->temporaries->data[kZeroPoints] = + op_data->scratch_tensor_index + kZeroPoints; + TfLiteTensor* zero_points = + GetTemporary(context, node, /*index=*/kZeroPoints); + zero_points->type = kTfLiteInt32; + zero_points->allocation_type = kTfLiteArenaRw; + int zero_points_dims[1] = {batch_size}; + if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) { + TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1); + zero_points_size->data[0] = batch_size; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points, + zero_points_size)); + } + const int num_row_sums = has_aux_input ? 3 : 2; + node->temporaries->data[kFwRowSums] = + op_data->scratch_tensor_index + kFwRowSums; + TfLiteTensor* fw_row_sums = + GetTemporary(context, node, /*index=*/kFwRowSums); + fw_row_sums->type = kTfLiteInt32; + fw_row_sums->allocation_type = kTfLiteArenaRwPersistent; + int fw_row_sums_dims[2] = {num_row_sums, fw_num_units}; + if (!TfLiteIntArrayEqualsArray(fw_row_sums->dims, 2, fw_row_sums_dims)) { + TfLiteIntArray* fw_row_sums_size = TfLiteIntArrayCreate(2); + fw_row_sums_size->data[0] = fw_row_sums_dims[0]; + fw_row_sums_size->data[1] = fw_row_sums_dims[1]; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_row_sums, + fw_row_sums_size)); + } + node->temporaries->data[kBwRowSums] = + op_data->scratch_tensor_index + kBwRowSums; + TfLiteTensor* bw_row_sums = GetTemporary(context, node, + /*index=*/kBwRowSums); + bw_row_sums->type = kTfLiteInt32; + bw_row_sums->allocation_type = kTfLiteArenaRwPersistent; + int bw_row_sums_dims[2] = {num_row_sums, bw_num_units}; + if (!TfLiteIntArrayEqualsArray(bw_row_sums->dims, 2, bw_row_sums_dims)) { + TfLiteIntArray* bw_row_sums_size = TfLiteIntArrayCreate(2); + bw_row_sums_size->data[0] = bw_row_sums_dims[0]; + bw_row_sums_size->data[1] = bw_row_sums_dims[1]; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_row_sums, + bw_row_sums_size)); + } if (has_aux_input) { node->temporaries->data[kAuxInputQuantized] = - *scratch_tensor_index + kAuxInputQuantized; + op_data->scratch_tensor_index + kAuxInputQuantized; TfLiteTensor* aux_input_quantized = GetTemporary(context, node, kAuxInputQuantized); aux_input_quantized->type = fw_input_weights->type; @@ -418,7 +490,10 @@ TfLiteStatus EvalHybrid( TfLiteTensor* aux_input_quantized, TfLiteTensor* fw_hidden_state_quantized, TfLiteTensor* fw_hidden_state, 
TfLiteTensor* fw_output, TfLiteTensor* bw_hidden_state_quantized, TfLiteTensor* bw_hidden_state, - TfLiteTensor* bw_output) { + TfLiteTensor* bw_output, TfLiteTensor* zero_points, + TfLiteTensor* accum_scratch, TfLiteTensor* fw_row_sums, + TfLiteTensor* bw_row_sums, bool* fw_compute_row_sums, + bool* bw_compute_row_sums) { const bool time_major = params->time_major; const int batch_size = (time_major) ? input->dims->data[1] : input->dims->data[0]; @@ -464,11 +539,20 @@ TfLiteStatus EvalHybrid( int8_t* bw_quantized_hidden_state_ptr = GetTensorData(bw_hidden_state_quantized); float* scaling_factors_ptr = GetTensorData(scaling_factors); - + int32_t* accum_scratch_ptr = GetTensorData(accum_scratch); + int32_t* zero_points_ptr = nullptr; + int32_t* fw_row_sums_ptr = nullptr; + int32_t* bw_row_sums_ptr = nullptr; + if (params->asymmetric_quantize_inputs) { + zero_points_ptr = GetTensorData(zero_points); + fw_row_sums_ptr = GetTensorData(fw_row_sums); + bw_row_sums_ptr = GetTensorData(bw_row_sums); + } const int fw_output_step = params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units; const int bw_output_step = params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units; + if (time_major) { for (int t = 0; t < max_time; t++) { // Forward cell. @@ -491,7 +575,9 @@ TfLiteStatus EvalHybrid( fw_num_units, batch_size, fw_output_step, params->activation, quantized_input_ptr, aux_quantized_input_ptr, fw_quantized_hidden_state_ptr, scaling_factors_ptr, - fw_hidden_state_ptr_batch, output_ptr_batch); + fw_hidden_state_ptr_batch, output_ptr_batch, + params->asymmetric_quantize_inputs, zero_points_ptr, + accum_scratch_ptr, fw_row_sums_ptr, fw_compute_row_sums); } // Backward cell. float* bw_hidden_state_ptr_batch = GetTensorData(bw_hidden_state); @@ -516,7 +602,9 @@ TfLiteStatus EvalHybrid( bw_num_units, batch_size, bw_output_step, params->activation, quantized_input_ptr, aux_quantized_input_ptr, bw_quantized_hidden_state_ptr, scaling_factors_ptr, - bw_hidden_state_ptr_batch, output_ptr_batch); + bw_hidden_state_ptr_batch, output_ptr_batch, + params->asymmetric_quantize_inputs, zero_points_ptr, + accum_scratch_ptr, bw_row_sums_ptr, bw_compute_row_sums); } } } else { @@ -545,7 +633,9 @@ TfLiteStatus EvalHybrid( fw_num_units, /*batch_size=*/1, fw_output_step, params->activation, quantized_input_ptr, aux_quantized_input_ptr, fw_quantized_hidden_state_ptr, scaling_factors_ptr, - fw_hidden_state_ptr_batch, output_ptr_batch); + fw_hidden_state_ptr_batch, output_ptr_batch, + params->asymmetric_quantize_inputs, zero_points_ptr, + accum_scratch_ptr, fw_row_sums_ptr, fw_compute_row_sums); } // Backward cell. 
float* bw_hidden_state_ptr_batch = @@ -574,7 +664,9 @@ TfLiteStatus EvalHybrid( bw_num_units, /*batch_size=*/1, bw_output_step, params->activation, quantized_input_ptr, aux_quantized_input_ptr, bw_quantized_hidden_state_ptr, scaling_factors_ptr, - bw_hidden_state_ptr_batch, output_ptr_batch); + bw_hidden_state_ptr_batch, output_ptr_batch, + params->asymmetric_quantize_inputs, zero_points_ptr, + accum_scratch_ptr, bw_row_sums_ptr, bw_compute_row_sums); } } } @@ -656,17 +748,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTemporary(context, node, kBwHiddenStateQuantized); TfLiteTensor* scaling_factors = GetTemporary(context, node, kScalingFactors); + TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints); + TfLiteTensor* accum_scratch = GetTemporary(context, node, kAccumScratch); + TfLiteTensor* fw_row_sums = GetTemporary(context, node, kFwRowSums); + TfLiteTensor* bw_row_sums = GetTemporary(context, node, kBwRowSums); TfLiteTensor* aux_input_quantized = use_aux_input ? GetTemporary(context, node, kAuxInputQuantized) : nullptr; - - return EvalHybrid(input, bw_input, fw_input_weights, fw_recurrent_weights, - fw_bias, bw_input_weights, bw_recurrent_weights, - bw_bias, real_aux_input, fw_aux_input_weights, - bw_aux_input_weights, params, scaling_factors, - input_quantized, aux_input_quantized, - fw_hidden_state_quantized, fw_hidden_state, fw_output, - bw_hidden_state_quantized, bw_hidden_state, bw_output); + auto* op_data = reinterpret_cast(node->user_data); + return EvalHybrid( + input, bw_input, fw_input_weights, fw_recurrent_weights, fw_bias, + bw_input_weights, bw_recurrent_weights, bw_bias, real_aux_input, + fw_aux_input_weights, bw_aux_input_weights, params, scaling_factors, + input_quantized, aux_input_quantized, fw_hidden_state_quantized, + fw_hidden_state, fw_output, bw_hidden_state_quantized, + bw_hidden_state, bw_output, zero_points, accum_scratch, fw_row_sums, + bw_row_sums, &op_data->fw_compute_row_sums, + &op_data->bw_compute_row_sums); } default: context->ReportError(context, "Type not currently supported."); diff --git a/tensorflow/lite/kernels/bidirectional_sequence_rnn_test.cc b/tensorflow/lite/kernels/bidirectional_sequence_rnn_test.cc index 34441e2b300..4a7cc9a016d 100644 --- a/tensorflow/lite/kernels/bidirectional_sequence_rnn_test.cc +++ b/tensorflow/lite/kernels/bidirectional_sequence_rnn_test.cc @@ -662,20 +662,24 @@ class BidirectionalRNNOpModel : public SingleOpModel { BidirectionalRNNOpModel(int batches, int sequence_len, int fw_units, int bw_units, int input_size, int aux_input_size, AuxInputMode aux_input_mode, bool time_major, - bool merge_outputs) + bool merge_outputs, bool quantize_weights = false, + bool asymmetric_quantize_weights = false) : batches_(batches), sequence_len_(sequence_len), fw_units_(fw_units), bw_units_(bw_units), input_size_(input_size), - aux_input_size_(aux_input_size) { + aux_input_size_(aux_input_size), + quantize_weights_(quantize_weights) { + const TensorType tensor_type = + quantize_weights ? 
TensorType_UINT8 : TensorType_FLOAT32; input_ = AddInput(TensorType_FLOAT32); - fw_weights_ = AddInput(TensorType_FLOAT32); - fw_recurrent_weights_ = AddInput(TensorType_FLOAT32); + fw_weights_ = AddInput(tensor_type); + fw_recurrent_weights_ = AddInput(tensor_type); fw_bias_ = AddInput(TensorType_FLOAT32); fw_hidden_state_ = AddInput(TensorType_FLOAT32, true); - bw_weights_ = AddInput(TensorType_FLOAT32); - bw_recurrent_weights_ = AddInput(TensorType_FLOAT32); + bw_weights_ = AddInput(tensor_type); + bw_recurrent_weights_ = AddInput(tensor_type); bw_bias_ = AddInput(TensorType_FLOAT32); bw_hidden_state_ = AddInput(TensorType_FLOAT32, true); @@ -697,8 +701,8 @@ class BidirectionalRNNOpModel : public SingleOpModel { } if (aux_input_mode == AuxInputMode::kCrossLinking) { - aux_fw_weights_ = AddInput(TensorType_FLOAT32); - aux_bw_weights_ = AddInput(TensorType_FLOAT32); + aux_fw_weights_ = AddInput(tensor_type); + aux_bw_weights_ = AddInput(tensor_type); aux_fw_weights_shape = {fw_units, aux_input_size_}; aux_bw_weights_shape = {bw_units, aux_input_size_}; @@ -712,12 +716,12 @@ class BidirectionalRNNOpModel : public SingleOpModel { bw_output_ = AddOutput(TensorType_FLOAT32); } - SetBuiltinOp( - BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, - BuiltinOptions_BidirectionalSequenceRNNOptions, - CreateBidirectionalSequenceRNNOptions( - builder_, time_major, ActivationFunctionType_RELU, merge_outputs) - .Union()); + SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, + BuiltinOptions_BidirectionalSequenceRNNOptions, + CreateBidirectionalSequenceRNNOptions( + builder_, time_major, ActivationFunctionType_RELU, + merge_outputs, asymmetric_quantize_weights) + .Union()); BuildInterpreter({ input_shape, // input @@ -744,19 +748,35 @@ class BidirectionalRNNOpModel : public SingleOpModel { } void SetFwWeights(const std::vector& f) { - PopulateTensor(fw_weights_, f); + if (quantize_weights_) { + SymmetricQuantizeAndPopulate(fw_weights_, f); + } else { + PopulateTensor(fw_weights_, f); + } } void SetBwWeights(const std::vector& f) { - PopulateTensor(bw_weights_, f); + if (quantize_weights_) { + SymmetricQuantizeAndPopulate(bw_weights_, f); + } else { + PopulateTensor(bw_weights_, f); + } } void SetFwRecurrentWeights(const std::vector& f) { - PopulateTensor(fw_recurrent_weights_, f); + if (quantize_weights_) { + SymmetricQuantizeAndPopulate(fw_recurrent_weights_, f); + } else { + PopulateTensor(fw_recurrent_weights_, f); + } } void SetBwRecurrentWeights(const std::vector& f) { - PopulateTensor(bw_recurrent_weights_, f); + if (quantize_weights_) { + SymmetricQuantizeAndPopulate(bw_recurrent_weights_, f); + } else { + PopulateTensor(bw_recurrent_weights_, f); + } } void SetInput(std::initializer_list data) { @@ -772,11 +792,19 @@ class BidirectionalRNNOpModel : public SingleOpModel { } void SetAuxFwWeights(const std::vector& f) { - PopulateTensor(aux_fw_weights_, f); + if (quantize_weights_) { + SymmetricQuantizeAndPopulate(aux_fw_weights_, f); + } else { + PopulateTensor(aux_fw_weights_, f); + } } void SetAuxBwWeights(const std::vector& f) { - PopulateTensor(aux_bw_weights_, f); + if (quantize_weights_) { + SymmetricQuantizeAndPopulate(aux_bw_weights_, f); + } else { + PopulateTensor(aux_bw_weights_, f); + } } std::vector GetFwOutput() { return ExtractVector(fw_output_); } @@ -811,17 +839,31 @@ class BidirectionalRNNOpModel : public SingleOpModel { int bw_units_; int input_size_; int aux_input_size_; + bool quantize_weights_; }; +// Declare LSTMOpTest as a parameterized test. 
+class BidirectionalRNNOpTest + : public ::testing::TestWithParam<::testing::tuple> {}; + +INSTANTIATE_TEST_SUITE_P(QuantizationOrNot, BidirectionalRNNOpTest, + ::testing::Combine( + /*quantize_weights*/ ::testing::Bool(), + /*asymmetric_quantize_inputs*/ ::testing::Bool())); + // TODO(mirkov): add another test which directly compares to TF once TOCO // supports the conversion from dynamic_rnn with BasicRNNCell. -TEST(BidirectionalRNNOpTest, BlackBoxTest) { +TEST_P(BidirectionalRNNOpTest, BlackBoxTest) { + auto params = GetParam(); + const bool quantize_weights = std::get<0>(params); + const bool asymmetric_quantize_inputs = std::get<1>(params); BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, /*fw_units=*/16, /*bw_units=*/16, /*input_size=*/8, /*aux_input_size=*/0, /*aux_input_mode=*/AuxInputMode::kNoAuxInput, /*time_major=*/false, - /*merge_outputs=*/false); + /*merge_outputs=*/false, quantize_weights, + asymmetric_quantize_inputs); rnn.SetFwWeights(weights); rnn.SetBwWeights(weights); rnn.SetFwBias(biases); @@ -843,7 +885,9 @@ TEST(BidirectionalRNNOpTest, BlackBoxTest) { std::vector fw_expected; fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end); fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end); - EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected))); + EXPECT_THAT(rnn.GetFwOutput(), + ElementsAreArray(ArrayFloatNear( + fw_expected, quantize_weights ? 1.42e-2 : 1e-5))); float* golden_bw_start = rnn_golden_bw_output; float* golden_bw_end = @@ -851,17 +895,23 @@ TEST(BidirectionalRNNOpTest, BlackBoxTest) { std::vector bw_expected; bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end); bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end); - EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected))); + EXPECT_THAT(rnn.GetBwOutput(), + ElementsAreArray(ArrayFloatNear( + bw_expected, quantize_weights ? 1.42e-2 : 1e-5))); } // Same as BlackBox test, but input is reshuffled to time_major format. -TEST(BidirectionalRNNOpTest, BlackBoxTestTimeMajor) { +TEST_P(BidirectionalRNNOpTest, BlackBoxTestTimeMajor) { + auto params = GetParam(); + const bool quantize_weights = std::get<0>(params); + const bool asymmetric_quantize_inputs = std::get<1>(params); BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, /*fw_units=*/16, /*bw_units=*/16, /*input_size=*/8, /*aux_input_size=*/0, /*aux_input_mode=*/AuxInputMode::kNoAuxInput, /*time_major=*/true, - /*merge_outputs=*/false); + /*merge_outputs=*/false, quantize_weights, + asymmetric_quantize_inputs); rnn.SetFwWeights(weights); rnn.SetBwWeights(weights); rnn.SetFwBias(biases); @@ -889,17 +939,26 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestTimeMajor) { fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end); fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end); } - EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected))); + constexpr float kHybridTolerance = 3.57e-1; + constexpr float kFloatTolerance = 1e-5; + EXPECT_THAT( + rnn.GetFwOutput(), + ElementsAreArray(ArrayFloatNear( + fw_expected, quantize_weights ? kHybridTolerance : kFloatTolerance))); } // Same as BlackBox test, yet with merged outputs. 
-TEST(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) { +TEST_P(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) { + auto params = GetParam(); + const bool quantize_weights = std::get<0>(params); + const bool asymmetric_quantize_inputs = std::get<1>(params); BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, /*fw_units=*/16, /*bw_units=*/16, /*input_size=*/8, /*aux_input_size=*/0, /*aux_input_mode=*/AuxInputMode::kNoAuxInput, /*time_major=*/false, - /*merge_outputs=*/true); + /*merge_outputs=*/true, quantize_weights, + asymmetric_quantize_inputs); rnn.SetFwWeights(weights); rnn.SetBwWeights(weights); rnn.SetFwBias(biases); @@ -929,7 +988,8 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) { } } EXPECT_THAT(rnn.GetFwOutput(), - ElementsAreArray(ArrayFloatNear(merged_expected))); + ElementsAreArray(ArrayFloatNear( + merged_expected, quantize_weights ? 1.42e-2 : 1e-5))); } // Same as BlackBox test, but input is reshuffled to time_major format. diff --git a/tensorflow/lite/kernels/fully_connected.cc b/tensorflow/lite/kernels/fully_connected.cc index fc6f1991fd3..5faf13303d8 100644 --- a/tensorflow/lite/kernels/fully_connected.cc +++ b/tensorflow/lite/kernels/fully_connected.cc @@ -71,6 +71,7 @@ struct OpData { int32_t output_activation_max; // The index of the temporary tensor where the quantized inputs are cached. int scratch_tensor_index; + bool compute_row_sums = false; }; constexpr int kInputTensor = 0; @@ -131,7 +132,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { // Instead, we allocate a new object to carry information from Prepare() to // Eval(). auto* op_data = new OpData(); - context->AddTensors(context, /*tensors_to_add=*/3, + context->AddTensors(context, /*tensors_to_add=*/5, &op_data->scratch_tensor_index); return op_data; } @@ -144,7 +145,6 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); OpData* data = reinterpret_cast(node->user_data); - // Check we have all the inputs and outputs we need. TF_LITE_ENSURE(context, node->inputs->size == 2 || node->inputs->size == 3); // Shuffled formats need a workspace to store the shuffled input activations. 
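In the fully_connected.cc changes that follow, EvalHybrid chooses per batch row between tensor_utils::SymmetricQuantizeFloats and tensor_utils::AsymmetricQuantizeFloats based on params->asymmetric_quantize_inputs, storing the resulting per-row zero points in the new input_offsets temporary. As a rough idea of what an asymmetric float-to-int8 quantization of one row computes, here is a simplified sketch under assumed behavior, not the actual tensor_utils routine:

// Illustrative sketch only; not part of the patch.
#include <algorithm>
#include <cmath>
#include <cstdint>

void AsymmetricQuantizeRow(const float* values, int size, int8_t* quantized,
                           float* scale, int32_t* zero_point) {
  const auto minmax = std::minmax_element(values, values + size);
  // The representable range must include 0 so that zero is exactly encodable.
  const float rmin = std::min(0.0f, *minmax.first);
  const float rmax = std::max(0.0f, *minmax.second);
  constexpr int32_t kQmin = -128;
  constexpr int32_t kQmax = 127;
  if (rmin == rmax) {
    *scale = 1.0f;
    *zero_point = 0;
    std::fill(quantized, quantized + size, 0);
    return;
  }
  *scale = (rmax - rmin) / (kQmax - kQmin);
  *zero_point = static_cast<int32_t>(std::round(kQmin - rmin / *scale));
  for (int i = 0; i < size; ++i) {
    const int32_t q =
        static_cast<int32_t>(std::round(values[i] / *scale)) + *zero_point;
    quantized[i] = static_cast<int8_t>(std::min(kQmax, std::max(kQmin, q)));
  }
}

The production routine presumably handles edge cases (zero-point nudging, degenerate ranges) more carefully; only the basic scale and offset derivation is shown here.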
@@ -208,7 +208,8 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) { if (input->type == kTfLiteFloat32 && (filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8)) { TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(3); + data->compute_row_sums = true; + node->temporaries = TfLiteIntArrayCreate(5); node->temporaries->data[0] = data->scratch_tensor_index; TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0); @@ -245,6 +246,28 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK( context, context->ResizeTensor(context, accum_scratch, accum_size)); } + + node->temporaries->data[3] = data->scratch_tensor_index + 3; + TfLiteTensor* input_offsets = GetTemporary(context, node, /*index=*/3); + input_offsets->type = kTfLiteInt32; + input_offsets->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqualsArray(input_offsets->dims, 1, scaling_dims)) { + TfLiteIntArray* input_offsets_size = TfLiteIntArrayCreate(1); + input_offsets_size->data[0] = batch_size; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_offsets, + input_offsets_size)); + } + node->temporaries->data[4] = data->scratch_tensor_index + 4; + TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/4); + row_sums->type = kTfLiteInt32; + row_sums->allocation_type = kTfLiteArenaRwPersistent; + int row_sums_dims[1] = {num_units}; + if (!TfLiteIntArrayEqualsArray(row_sums->dims, 1, row_sums_dims)) { + TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(1); + row_sums_size->data[0] = row_sums_dims[0]; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, row_sums, row_sums_size)); + } } // Resize output. @@ -332,7 +355,9 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, TfLiteFullyConnectedParams* params, OpData* data, const TfLiteTensor* input, const TfLiteTensor* filter, const TfLiteTensor* bias, TfLiteTensor* input_quantized, - TfLiteTensor* scaling_factors, TfLiteTensor* output) { + TfLiteTensor* scaling_factors, + TfLiteTensor* accum_scratch, TfLiteTensor* row_sums, + TfLiteTensor* input_offsets, TfLiteTensor* output) { int total_input_size = 1; for (int i = 0; i < input->dims->size; i++) { total_input_size *= input->dims->data[i]; @@ -363,32 +388,39 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, // Quantize input from float to uint8 + quantization params (scaling factor). float unused_min, unused_max; float* scaling_factors_ptr = GetTensorData(scaling_factors); + int32_t* input_offset_ptr = nullptr; + int32_t* row_sums_ptr = nullptr; + if (params->asymmetric_quantize_inputs) { + input_offset_ptr = GetTensorData(input_offsets); + row_sums_ptr = GetTensorData(row_sums); + } int8_t* quant_data = GetTensorData(input_quantized); const int8_t* filter_data = GetTensorData(filter); - + const float* input_ptr = GetTensorData(input); // Quantize each batch independently. 
for (int b = 0; b < batch_size; ++b) { const int offset = b * input_size; - tensor_utils::SymmetricQuantizeFloats( - GetTensorData(input) + offset, input_size, quant_data + offset, - &unused_min, &unused_max, &scaling_factors_ptr[b]); + if (params->asymmetric_quantize_inputs) { + tensor_utils::AsymmetricQuantizeFloats( + input_ptr + offset, input_size, quant_data + offset, + &scaling_factors_ptr[b], &input_offset_ptr[b]); + } else { + tensor_utils::SymmetricQuantizeFloats( + input_ptr + offset, input_size, quant_data + offset, &unused_min, + &unused_max, &scaling_factors_ptr[b]); + } // Incorporate scaling of the filter. scaling_factors_ptr[b] *= filter->params.scale; } // Compute output += weight * quantized_input -#ifdef TFLITE_WITH_RUY_GEMV - TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/2); int32_t* scratch = GetTensorData(accum_scratch); tensor_utils::MatrixBatchVectorMultiplyAccumulate( filter_data, num_units, input_size, quant_data, scaling_factors_ptr, - batch_size, scratch, GetTensorData(output), + batch_size, GetTensorData(output), /*per_channel_scale=*/nullptr, + input_offset_ptr, scratch, row_sums_ptr, &data->compute_row_sums, CpuBackendContext::GetFromContext(context)); -#else - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - filter_data, num_units, input_size, quant_data, scaling_factors_ptr, - batch_size, GetTensorData(output)); -#endif + // Apply activation function to floats. tensor_utils::ApplyActivationToVector( GetTensorData(output), batch_size * num_units, params->activation, @@ -461,8 +493,12 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, if (input->type == kTfLiteFloat32) { TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0); TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1); + TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/2); + TfLiteTensor* input_offsets = GetTemporary(context, node, /*index=*/3); + TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/4); return EvalHybrid(context, node, params, data, input, filter, bias, - input_quantized, scaling_factors, output); + input_quantized, scaling_factors, accum_scratch, row_sums, + input_offsets, output); } else { FullyConnectedParams op_params; op_params.input_offset = input_offset; @@ -590,7 +626,6 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, FullyConnectedParams op_params; op_params.float_activation_min = output_activation_min; op_params.float_activation_max = output_activation_max; - reference_ops::FullyConnected( op_params, GetTensorShape(input), GetTensorData(input), GetTensorShape(filter), GetTensorData(filter), diff --git a/tensorflow/lite/kernels/fully_connected_test.cc b/tensorflow/lite/kernels/fully_connected_test.cc index 1f671cae0fc..fbc02dd741d 100644 --- a/tensorflow/lite/kernels/fully_connected_test.cc +++ b/tensorflow/lite/kernels/fully_connected_test.cc @@ -286,7 +286,8 @@ class HybridFullyConnectedOpModel : public SingleOpModel { public: HybridFullyConnectedOpModel(int units, int batches, const TensorData& input, const TensorData& weights, - const TensorData& output = {TensorType_FLOAT32}) + const TensorData& output = {TensorType_FLOAT32}, + bool asymmetric_inputs = false) : batches_(batches), units_(units) { int total_input_size = 1; for (size_t i = 0; i < input.shape.size(); ++i) { @@ -302,10 +303,13 @@ class HybridFullyConnectedOpModel : public SingleOpModel { output_ = AddOutput(output); - SetBuiltinOp( - BuiltinOperator_FULLY_CONNECTED, 
BuiltinOptions_FullyConnectedOptions, - CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU) - .Union()); + auto options = CreateFullyConnectedOptions( + builder_, ActivationFunctionType_RELU, + tflite::FullyConnectedOptionsWeightsFormat_DEFAULT, + false, asymmetric_inputs) + .Union(); + SetBuiltinOp(BuiltinOperator_FULLY_CONNECTED, + BuiltinOptions_FullyConnectedOptions, options); resolver_ = absl::make_unique( BuiltinOperator_FULLY_CONNECTED, ops::builtin::Register_FULLY_CONNECTED_PIE()); @@ -867,6 +871,66 @@ TEST(HybridFullyConnectedOpTest, SimpleTestQuantizedInt8) { /*max_abs_error=*/1.3f))); } +TEST(HybridAsymmetricInputFullyConnectedOpTest, SimpleTestQuantizedUint8) { + HybridFullyConnectedOpModel m( + /*units=*/3, /*batches=*/2, + /*input=*/{TensorType_FLOAT32, {2, 10}}, + /*weights=*/ + {TensorType_UINT8, {3, 10}, 0, 0, 10.0 / 127.0, 0}, {TensorType_FLOAT32}, + /*asymmetric_quantize_input*/ true); // Hybrid asymmetric + + m.SetWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( + { + 24, 25, 26, // + 58, 59, 60, // + }, + /*max_abs_error=*/0.64f))); +} + +TEST(HybridAsymmetricInputFullyConnectedOpTest, SimpleTestQuantizedInt8) { + HybridFullyConnectedOpModel m( + /*units=*/3, /*batches=*/2, + /*input=*/{TensorType_FLOAT32, {2, 10}}, + /*weights=*/{TensorType_INT8, {3, 10}, 0, 0, 10.0 / 127.0, 0}, + {TensorType_FLOAT32}, + /*asymmetric_quantize_input*/ true); + + m.SetSignedWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( + { + 24, 25, 26, // + 58, 59, 60, // + }, + /*max_abs_error=*/1.3f))); +} + TEST_P(FloatFullyConnectedOpTest, SimpleTest4DInput) { // Note that it is not required that the first dimension be the number of // batches. 
All we care is that the input can be evenly distributed in diff --git a/tensorflow/lite/kernels/internal/kernel_utils.cc b/tensorflow/lite/kernels/internal/kernel_utils.cc index 21c058c394b..f34cee02f4d 100644 --- a/tensorflow/lite/kernels/internal/kernel_utils.cc +++ b/tensorflow/lite/kernels/internal/kernel_utils.cc @@ -123,7 +123,9 @@ void RnnBatchStep( int num_units, int batch_size, int output_batch_leading_dim, TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch, int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors, - float* hidden_state_ptr_batch, float* output_ptr_batch) { + float* hidden_state_ptr_batch, float* output_ptr_batch, + bool asymmetric_quantize_inputs, int32_t* zero_points, + int32_t* accum_scratch, int32_t* row_sums, bool* compute_row_sums) { RnnBatchStep(input_ptr_batch, input_weights_ptr, input_weights_scale, /*aux_input_ptr_batch=*/nullptr, /*aux_input_weights_ptr=*/nullptr, @@ -133,7 +135,29 @@ void RnnBatchStep( output_batch_leading_dim, activation, quantized_input_ptr_batch, /*aux_quantized_input_ptr_batch=*/nullptr, quantized_hidden_state_ptr_batch, scaling_factors, - hidden_state_ptr_batch, output_ptr_batch); + hidden_state_ptr_batch, output_ptr_batch, + asymmetric_quantize_inputs, zero_points, accum_scratch, row_sums, + compute_row_sums); +} + +void ComputeMatrixSums(int32_t* input_row_sums, int32_t* aux_input_row_sums, + int32_t* recurrent_row_sums, int32_t* row_sums, + const float* aux_input_ptr_batch, int num_units, + int input_size, int aux_input_size, + const int8_t* input_weights_ptr, + const int8_t* aux_input_weights_ptr, + const int8_t* recurrent_weights_ptr) { + memset(input_row_sums, 0, sizeof(int32_t) * num_units); + tensor_utils::ReductionSumVector(input_weights_ptr, input_row_sums, num_units, + input_size); + if (aux_input_ptr_batch) { + memset(aux_input_row_sums, 0, sizeof(int32_t) * num_units); + tensor_utils::ReductionSumVector(aux_input_weights_ptr, aux_input_row_sums, + num_units, aux_input_size); + } + memset(recurrent_row_sums, 0, sizeof(int32_t) * num_units); + tensor_utils::ReductionSumVector(recurrent_weights_ptr, recurrent_row_sums, + num_units, num_units); } void RnnBatchStep( @@ -146,9 +170,31 @@ void RnnBatchStep( TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch, int8_t* aux_quantized_input_ptr_batch, int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors, - float* hidden_state_ptr_batch, float* output_ptr_batch) { + float* hidden_state_ptr_batch, float* output_ptr_batch, + bool asymmetric_quantize_inputs, int32_t* zero_points, + int32_t* accum_scratch, int32_t* row_sums, bool* compute_row_sums) { // Since the output batch rows may not be contiguous (output_batch_leading_dim // != n_output), we unroll the batched operations where this is the case. 
+ + int32_t* input_row_sums = nullptr; + int32_t* aux_input_row_sums = nullptr; + int32_t* recurrent_row_sums = nullptr; + if (asymmetric_quantize_inputs) { + input_row_sums = row_sums; + aux_input_row_sums = row_sums; + if (aux_input_ptr_batch) { + aux_input_row_sums += num_units; + } + recurrent_row_sums = aux_input_row_sums + num_units; + if (*compute_row_sums) { + ComputeMatrixSums(input_row_sums, aux_input_row_sums, recurrent_row_sums, + row_sums, aux_input_ptr_batch, num_units, input_size, + aux_input_size, input_weights_ptr, + aux_input_weights_ptr, recurrent_weights_ptr); + *compute_row_sums = false; + } + } + if (output_batch_leading_dim == num_units) { // Output = bias tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size, @@ -163,17 +209,25 @@ void RnnBatchStep( // whichever is faster. for (int b = 0; b < batch_size; ++b) { const int offset = b * input_size; - tensor_utils::SymmetricQuantizeFloats( - input_ptr_batch + offset, input_size, - quantized_input_ptr_batch + offset, &unused_min, &unused_max, - &scaling_factors[b]); + if (asymmetric_quantize_inputs) { + tensor_utils::AsymmetricQuantizeFloats( + input_ptr_batch + offset, input_size, + quantized_input_ptr_batch + offset, &scaling_factors[b], + &zero_points[b]); + } else { + tensor_utils::SymmetricQuantizeFloats( + input_ptr_batch + offset, input_size, + quantized_input_ptr_batch + offset, &unused_min, &unused_max, + &scaling_factors[b]); + } scaling_factors[b] *= input_weights_scale; } - // Output += input * input_weights tensor_utils::MatrixBatchVectorMultiplyAccumulate( input_weights_ptr, num_units, input_size, quantized_input_ptr_batch, - scaling_factors, batch_size, output_ptr_batch); + scaling_factors, batch_size, output_ptr_batch, + /*per_channel_scale=*/nullptr, zero_points, accum_scratch, + input_row_sums, compute_row_sums, /*context=*/nullptr); } if (aux_input_ptr_batch && @@ -182,10 +236,17 @@ void RnnBatchStep( float unused_min, unused_max; for (int b = 0; b < batch_size; ++b) { const int offset = b * aux_input_size; - tensor_utils::SymmetricQuantizeFloats( - aux_input_ptr_batch + offset, aux_input_size, - aux_quantized_input_ptr_batch + offset, &unused_min, &unused_max, - &scaling_factors[b]); + if (asymmetric_quantize_inputs) { + tensor_utils::AsymmetricQuantizeFloats( + aux_input_ptr_batch + offset, aux_input_size, + aux_quantized_input_ptr_batch + offset, &scaling_factors[b], + &zero_points[b]); + } else { + tensor_utils::SymmetricQuantizeFloats( + aux_input_ptr_batch + offset, aux_input_size, + aux_quantized_input_ptr_batch + offset, &unused_min, &unused_max, + &scaling_factors[b]); + } scaling_factors[b] *= aux_input_weights_scale; } @@ -193,7 +254,9 @@ void RnnBatchStep( tensor_utils::MatrixBatchVectorMultiplyAccumulate( aux_input_weights_ptr, num_units, aux_input_size, aux_quantized_input_ptr_batch, scaling_factors, batch_size, - output_ptr_batch); + output_ptr_batch, /*per_channel_scale=*/nullptr, zero_points, + accum_scratch, aux_input_row_sums, compute_row_sums, + /*context=*/nullptr); } // Save quantization and matmul computation for all zero input. 
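Editor's note: RnnBatchStep carves the input, optional aux-input and recurrent row sums out of one int32 scratch buffer and fills them only while *compute_row_sums is still true, so the constant weight matrices are reduced exactly once (via the new ComputeMatrixSums helper) and reused on every step. A small sketch of that layout and lazy fill, with illustrative names:

// Sketch of the buffer-slicing convention used above. When there is no aux
// input, the aux section aliases the input section, matching the pointer
// arithmetic in RnnBatchStep.
#include <cstdint>

struct RowSumViews {
  int32_t* input;
  int32_t* aux_input;
  int32_t* recurrent;
};

RowSumViews SliceRowSums(int32_t* row_sums, int num_units, bool has_aux_input) {
  RowSumViews v;
  v.input = row_sums;
  v.aux_input = has_aux_input ? row_sums + num_units : row_sums;
  v.recurrent = v.aux_input + num_units;
  return v;
}

// Reduce one row-major int8 weight matrix to per-row sums; called once per
// section while *compute_row_sums is true, then the flag is cleared.
void FillRowSums(const int8_t* weights, int rows, int cols, int32_t* out) {
  for (int r = 0; r < rows; ++r) {
    int32_t sum = 0;
    for (int c = 0; c < cols; ++c) sum += weights[r * cols + c];
    out[r] = sum;
  }
}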
@@ -203,10 +266,17 @@ void RnnBatchStep( float unused_min, unused_max; for (int b = 0; b < batch_size; ++b) { const int offset = b * num_units; - tensor_utils::SymmetricQuantizeFloats( - hidden_state_ptr_batch + offset, num_units, - quantized_hidden_state_ptr_batch + offset, &unused_min, &unused_max, - &scaling_factors[b]); + if (asymmetric_quantize_inputs) { + tensor_utils::AsymmetricQuantizeFloats( + hidden_state_ptr_batch + offset, num_units, + quantized_hidden_state_ptr_batch + offset, &scaling_factors[b], + &zero_points[b]); + } else { + tensor_utils::SymmetricQuantizeFloats( + hidden_state_ptr_batch + offset, num_units, + quantized_hidden_state_ptr_batch + offset, &unused_min, + &unused_max, &scaling_factors[b]); + } scaling_factors[b] *= recurrent_weights_scale; } @@ -214,7 +284,9 @@ void RnnBatchStep( tensor_utils::MatrixBatchVectorMultiplyAccumulate( recurrent_weights_ptr, num_units, num_units, quantized_hidden_state_ptr_batch, scaling_factors, batch_size, - output_ptr_batch); + output_ptr_batch, /*per_channel_scale=*/nullptr, zero_points, + accum_scratch, recurrent_row_sums, compute_row_sums, + /*context=*/nullptr); } // Output = activation(Output) and update hidden_state @@ -238,10 +310,17 @@ void RnnBatchStep( // whichever is faster. for (int b = 0; b < batch_size; ++b) { const int offset = b * input_size; - tensor_utils::SymmetricQuantizeFloats( - input_ptr_batch + offset, input_size, - quantized_input_ptr_batch + offset, &unused_min, &unused_max, - &scaling_factors[b]); + if (asymmetric_quantize_inputs) { + tensor_utils::AsymmetricQuantizeFloats( + input_ptr_batch + offset, input_size, + quantized_input_ptr_batch + offset, &scaling_factors[b], + &zero_points[b]); + } else { + tensor_utils::SymmetricQuantizeFloats( + input_ptr_batch + offset, input_size, + quantized_input_ptr_batch + offset, &unused_min, &unused_max, + &scaling_factors[b]); + } scaling_factors[b] *= input_weights_scale; } @@ -250,7 +329,9 @@ void RnnBatchStep( tensor_utils::MatrixBatchVectorMultiplyAccumulate( input_weights_ptr, num_units, input_size, quantized_input_ptr_batch + k * input_size, &scaling_factors[k], - /*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim); + /*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim, + /*per_channel_scale=*/nullptr, zero_points + k, accum_scratch, + input_row_sums, compute_row_sums, /*context=*/nullptr); } } @@ -260,10 +341,17 @@ void RnnBatchStep( float unused_min, unused_max; for (int b = 0; b < batch_size; ++b) { const int offset = b * aux_input_size; - tensor_utils::SymmetricQuantizeFloats( - aux_input_ptr_batch + offset, aux_input_size, - aux_quantized_input_ptr_batch + offset, &unused_min, &unused_max, - &scaling_factors[b]); + if (asymmetric_quantize_inputs) { + tensor_utils::AsymmetricQuantizeFloats( + aux_input_ptr_batch + offset, aux_input_size, + aux_quantized_input_ptr_batch + offset, &scaling_factors[b], + &zero_points[b]); + } else { + tensor_utils::SymmetricQuantizeFloats( + aux_input_ptr_batch + offset, aux_input_size, + aux_quantized_input_ptr_batch + offset, &unused_min, &unused_max, + &scaling_factors[b]); + } scaling_factors[b] *= aux_input_weights_scale; } @@ -273,7 +361,9 @@ void RnnBatchStep( aux_input_weights_ptr, num_units, aux_input_size, aux_quantized_input_ptr_batch + k * aux_input_size, &scaling_factors[k], - /*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim); + /*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim, + /*per_channel_scale=*/nullptr, zero_points + k, accum_scratch, + 
aux_input_row_sums, compute_row_sums, /*context=*/nullptr); } } @@ -284,10 +374,17 @@ void RnnBatchStep( float unused_min, unused_max; for (int b = 0; b < batch_size; ++b) { const int offset = b * num_units; - tensor_utils::SymmetricQuantizeFloats( - hidden_state_ptr_batch + offset, num_units, - quantized_hidden_state_ptr_batch + offset, &unused_min, &unused_max, - &scaling_factors[b]); + if (asymmetric_quantize_inputs) { + tensor_utils::AsymmetricQuantizeFloats( + hidden_state_ptr_batch + offset, num_units, + quantized_hidden_state_ptr_batch + offset, &scaling_factors[b], + &zero_points[b]); + } else { + tensor_utils::SymmetricQuantizeFloats( + hidden_state_ptr_batch + offset, num_units, + quantized_hidden_state_ptr_batch + offset, &unused_min, + &unused_max, &scaling_factors[b]); + } scaling_factors[b] *= recurrent_weights_scale; } @@ -296,8 +393,10 @@ void RnnBatchStep( tensor_utils::MatrixBatchVectorMultiplyAccumulate( recurrent_weights_ptr, num_units, num_units, quantized_hidden_state_ptr_batch + k * num_units, - &scaling_factors[k], - /*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim); + &scaling_factors[k], /*n_batch=*/1, + output_ptr_batch + k * output_batch_leading_dim, + /*per_channel_scale=*/nullptr, zero_points + k, accum_scratch, + recurrent_row_sums, compute_row_sums, /*context=*/nullptr); } } diff --git a/tensorflow/lite/kernels/internal/kernel_utils.h b/tensorflow/lite/kernels/internal/kernel_utils.h index ebb91678fec..2f551570e17 100644 --- a/tensorflow/lite/kernels/internal/kernel_utils.h +++ b/tensorflow/lite/kernels/internal/kernel_utils.h @@ -70,7 +70,9 @@ void RnnBatchStep( int num_units, int batch_size, int output_batch_leading_dim, TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch, int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors, - float* hidden_state_ptr_batch, float* output_ptr_batch); + float* hidden_state_ptr_batch, float* output_ptr_batch, + bool asymmetric_quantize_inputs, int32_t* zero_points, + int32_t* accum_scratch, int32_t* row_sums, bool* compute_row_sums); void RnnBatchStep( const float* input_ptr_batch, const int8_t* input_weights_ptr, @@ -82,7 +84,9 @@ void RnnBatchStep( TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch, int8_t* aux_quantized_input_ptr_batch, int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors, - float* hidden_state_ptr_batch, float* output_ptr_batch); + float* hidden_state_ptr_batch, float* output_ptr_batch, + bool asymmetric_quantize_inputs, int32_t* zero_points, + int32_t* accum_scratch, int32_t* row_sums, bool* compute_row_sums); } // namespace kernel_utils } // namespace tflite diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc index 86e2f9fa96a..af9ffba2c7c 100644 --- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -1310,6 +1310,13 @@ void NeonMatrixBatchVectorMultiplyAccumulateImpl( const int postamble_half_start = m_cols & ~(kWeightsPerNeonLane - 1); const int postamble_start = m_cols & ~((kWeightsPerNeonLane >> 1) - 1); + int32_t* row_sums_ptr = row_sums; + if (row_sums == nullptr) { + row_sums_ptr = static_cast(malloc(sizeof(int32_t) * m_rows)); + memset(row_sums_ptr, 0, sizeof(int32_t) * m_rows); + NeonReductionSumVector(matrix, row_sums_ptr, m_rows, m_cols); + } + for (int batch = 0; batch < n_batch; ++batch) { const float batch_scaling_factor = 
scaling_factors[batch]; const int batch_input_offset = input_offset[batch]; @@ -1327,10 +1334,6 @@ void NeonMatrixBatchVectorMultiplyAccumulateImpl( // Initialize the dot product sum for the row to 0. int32x4_t dotprod_32x4 = vmovq_n_s32(0); - int32x4_t row_sum_32x4; - if (row_sums == nullptr) { - row_sum_32x4 = vmovq_n_s32(0); - } // Prefetch the row to cache. __builtin_prefetch(row_ptr, 0 /* prefetch for read */, 3 /* temporal locality */); @@ -1358,10 +1361,6 @@ void NeonMatrixBatchVectorMultiplyAccumulateImpl( prod_16x8 = vmlal_s8(prod_16x8, vget_high_s8(s1_8x16), vget_high_s8(s2_8x16)); dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8); - if (row_sums == nullptr) { - const int16x8_t row_sum_16x8 = vpaddlq_s8(s2_8x16); - row_sum_32x4 = vpadalq_s16(row_sum_32x4, row_sum_16x8); - } } // for col // Half iteration dealing only 8 elements @@ -1375,29 +1374,24 @@ void NeonMatrixBatchVectorMultiplyAccumulateImpl( const int8x8_t s2_8x8 = vld1_s8((const int8_t*)(row_ptr + col)); const int16x8_t prod_16x8 = vmull_s8(s1_8x8, s2_8x8); dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8); - if (row_sums == nullptr) { - const int16x8_t row_sum_16x8 = vmovl_s8(s2_8x8); - row_sum_32x4 = vpadalq_s16(row_sum_32x4, row_sum_16x8); - } col += (kWeightsPerNeonLane >> 1); } int32_t dotprod = AccumulateNeonLane(dotprod_32x4); - int32_t row_sum = row_sums == nullptr ? AccumulateNeonLane(row_sum_32x4) - : row_sums[row]; // Postamble loop. for (; col < m_cols; ++col) { dotprod += row_ptr[col] * aligned_vec[col]; - if (row_sums == nullptr) { - row_sum += row_ptr[col]; - } } // for col - dotprod -= row_sum * batch_input_offset; + dotprod -= row_sums_ptr[row] * batch_input_offset; *result += dotprod * scale; ++result; } // for row } // for batch + + if (row_sums == nullptr) { + free(row_sums_ptr); + } if (unaligned) { free(aligned_row_free); } @@ -1410,6 +1404,20 @@ void NeonMatrixBatchVectorMultiplyAccumulate( int n_batch, float* __restrict__ result, const float* per_channel_scale, const int32_t* input_offset, int32_t* scratch, int32_t* row_sums, bool* compute_row_sums, CpuBackendContext* context) { + if (input_offset == nullptr) { +#ifdef TFLITE_WITH_RUY_GEMV + if (context) { + NeonMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors, + scaling_factors, n_batch, scratch, + result, context); + return; + } +#endif + NeonMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors, + scaling_factors, n_batch, result); + return; + } + if (compute_row_sums == nullptr || *compute_row_sums) { memset(row_sums, 0, sizeof(int32_t) * m_rows); NeonReductionSumVector(matrix, row_sums, m_rows, m_cols); @@ -1419,7 +1427,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate( } #ifdef TFLITE_WITH_RUY_GEMV - if (m_rows % 4 == 0) { + if (context != nullptr && m_rows % 4 == 0) { const int32_t* bias = static_cast(nullptr); NeonCpuBackendGemm(vectors, bias, matrix, n_batch, m_cols, m_rows, 0, scratch, context); @@ -1463,9 +1471,9 @@ void NeonMatrixBatchVectorMultiplyAccumulate( for (; i < total_size; i++) { const float batch_scaling_factor = scaling_factors[i / m_rows]; const int32_t zero_point = input_offset[i / m_rows]; - int32_t x = *(scratch_ptr++); - x -= row_sums[i % m_rows] * zero_point; - *result += x * batch_scaling_factor; + int32_t dotprod = *(scratch_ptr++); + dotprod -= row_sums[i % m_rows] * zero_point; + *result += dotprod * batch_scaling_factor; ++result; } return; diff --git a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.cc 
b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.cc index fe970dd8b39..7fb69e7b4f4 100644 --- a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.cc @@ -167,6 +167,11 @@ void SseMatrixBatchVectorMultiplyAccumulate( const float* __restrict__ scaling_factors, int n_batch, float* __restrict__ result, const float* __restrict__ per_channel_scale, const int32_t* __restrict__ input_offset) { + if (input_offset == nullptr) { + SseMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors, + scaling_factors, n_batch, result); + return; + } static constexpr std::intptr_t kBlockSize = 16; for (std::intptr_t batch = 0; batch < n_batch; ++batch) { const float batch_scaling_factor = scaling_factors[batch]; diff --git a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h index fa6f2c7a8db..1d0d2273e93 100644 --- a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h +++ b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h @@ -59,9 +59,10 @@ void MatrixBatchVectorMultiplyAccumulate( int n_batch, float* __restrict__ result, const float* per_channel_scale, const int32_t* input_offset, int32_t* scratch, int32_t* row_sums, bool* compute_row_sums, CpuBackendContext* context) { - NEON_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols, - vectors, scaling_factors, n_batch, result, per_channel_scale, - input_offset, scratch, row_sums, compute_row_sums, context); + PortableMatrixBatchVectorMultiplyAccumulate( + matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result, + per_channel_scale, input_offset, scratch, row_sums, compute_row_sums, + context); } void MatrixBatchVectorMultiplyAccumulate( diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc index 9c58415d6dc..19c74973aeb 100644 --- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc @@ -196,6 +196,11 @@ void PortableMatrixBatchVectorMultiplyAccumulate( int n_batch, float* __restrict__ result, const float* per_channel_scale, const int32_t* input_offset, int32_t* scratch, int32_t* row_sums, bool* compute_row_sums, CpuBackendContext* context) { + if (input_offset == nullptr) { + PortableMatrixBatchVectorMultiplyAccumulate( + matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result); + return; + } if (!compute_row_sums || *compute_row_sums) { memset(row_sums, 0, sizeof(int32_t) * m_rows); PortableReductionSumVector(matrix, row_sums, m_rows, m_cols); diff --git a/tensorflow/lite/kernels/internal/reference/svdf.h b/tensorflow/lite/kernels/internal/reference/svdf.h index 10c2e2cd849..18e4e079293 100644 --- a/tensorflow/lite/kernels/internal/reference/svdf.h +++ b/tensorflow/lite/kernels/internal/reference/svdf.h @@ -223,7 +223,8 @@ inline void EvalHybridSVDF( const TfLiteTensor* weights_feature, const TfLiteTensor* weights_time, const TfLiteTensor* bias, const TfLiteSVDFParams* params, TfLiteTensor* scratch, TfLiteTensor* scaling_factors, - TfLiteTensor* input_quantized, TfLiteTensor* state, TfLiteTensor* output) { + TfLiteTensor* input_quantized, TfLiteTensor* state, TfLiteTensor* output, + TfLiteTensor* zero_points, TfLiteTensor* row_sums, bool* compute_row_sums) { const int rank = params->rank; const int batch_size = input->dims->data[0]; const 
int input_size = input->dims->data[1]; @@ -244,6 +245,13 @@ inline void EvalHybridSVDF( float* output_ptr = GetTensorData(output); + int32_t* zero_points_ptr = nullptr; + int32_t* row_sums_ptr = nullptr; + if (params->asymmetric_quantize_inputs && row_sums != nullptr) { + zero_points_ptr = GetTensorData(zero_points); + row_sums_ptr = GetTensorData(row_sums); + } + // Initialize the weights scale. const float weights_feature_scale = weights_feature->params.scale; @@ -258,21 +266,30 @@ inline void EvalHybridSVDF( if (!tensor_utils::IsZeroVector(input_ptr, batch_size * input_size)) { // Quantize input from float to int8. - float unused_min, unused_max; for (int b = 0; b < batch_size; ++b) { const int offset = b * input_size; - tensor_utils::SymmetricQuantizeFloats( - input_ptr + offset, input_size, quantized_input_ptr + offset, - &unused_min, &unused_max, &scaling_factors_ptr[b]); + if (params->asymmetric_quantize_inputs) { + tensor_utils::AsymmetricQuantizeFloats( + input_ptr + offset, input_size, quantized_input_ptr + offset, + &scaling_factors_ptr[b], &zero_points_ptr[b]); + } else { + // Quantize input from float to int8. + float unused_min, unused_max; + tensor_utils::SymmetricQuantizeFloats( + input_ptr + offset, input_size, quantized_input_ptr + offset, + &unused_min, &unused_max, &scaling_factors_ptr[b]); + } scaling_factors_ptr[b] *= weights_feature_scale; } // Compute conv1d(inputs, weights_feature). tensor_utils::MatrixBatchVectorMultiplyAccumulate( weights_feature_ptr, num_filters, input_size, quantized_input_ptr, - scaling_factors_ptr, batch_size, scratch_ptr); + scaling_factors_ptr, batch_size, scratch_ptr, + /*per_channel_scale=*/nullptr, zero_points_ptr, + reinterpret_cast(scratch_ptr), row_sums_ptr, compute_row_sums, + /*context=*/nullptr); } - // Copy the latest activation from scratch into activation_state: // The last, i.e. (memory_size-1)th entry for each batch, and filter. for (int i = 0; i < batch_size * num_filters; ++i) { diff --git a/tensorflow/lite/kernels/lstm.cc b/tensorflow/lite/kernels/lstm.cc index bbda9257651..4eafc215b6f 100644 --- a/tensorflow/lite/kernels/lstm.cc +++ b/tensorflow/lite/kernels/lstm.cc @@ -55,6 +55,7 @@ struct OpData { // These fields are only used by full kernel. int scratch_tensor_index; lstm_eval::IntegerLstmParameter integer_lstm_param; + bool compute_row_sums; }; namespace full { @@ -727,7 +728,7 @@ TfLiteStatus PopulateQuantizedLstmParams8x8_8( void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* op_data = new OpData(); op_data->kernel_type = kTfLiteLSTMFullKernel; - context->AddTensors(context, /*tensors_to_add=*/8, + context->AddTensors(context, /*tensors_to_add=*/10, &op_data->scratch_tensor_index); return op_data; } @@ -1236,7 +1237,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteIntArrayFree(node->temporaries); if (is_hybrid_op) { - node->temporaries = TfLiteIntArrayCreate(8); + node->temporaries = TfLiteIntArrayCreate(10); } else if (is_integer) { if (is_8x8_16) { node->temporaries = TfLiteIntArrayCreate(6); @@ -1273,6 +1274,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } if (is_hybrid_op) { + op_data->compute_row_sums = true; // Allocate temporary tensors to store quantized values of input, // activation_state and cell_state tensors. 
node->temporaries->data[1] = op_data->scratch_tensor_index + 1; @@ -1370,6 +1372,41 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK( context, context->ResizeTensor(context, accum_scratch, accum_size)); } + + node->temporaries->data[8] = op_data->scratch_tensor_index + 8; + TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/8); + zero_points->type = kTfLiteFloat32; + zero_points->allocation_type = kTfLiteArenaRw; + int zero_points_dims[1] = {n_batch}; + if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) { + TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1); + zero_points_size->data[0] = n_batch; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points, + zero_points_size)); + } + + node->temporaries->data[9] = op_data->scratch_tensor_index + 9; + const TfLiteTensor* input_to_input_weights = + GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); + const bool use_cifg = (input_to_input_weights == nullptr); + int row_sums_rows = use_cifg ? 6 : 8; + const TfLiteTensor* projection_weights = + GetOptionalInputTensor(context, node, kProjectionWeightsTensor); + if (projection_weights != nullptr) { + row_sums_rows += ceil(n_output / n_cell); + } + + TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/9); + row_sums->type = kTfLiteInt32; + row_sums->allocation_type = kTfLiteArenaRwPersistent; + const int row_sums_dims[2] = {row_sums_rows, n_cell}; + if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) { + TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2); + row_sums_size->data[0] = row_sums_dims[0]; + row_sums_size->data[1] = row_sums_dims[1]; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, row_sums, row_sums_size)); + } } if (is_integer) { @@ -1556,6 +1593,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTemporary(context, node, /*index=*/6); TfLiteTensor* output_scratch_buffer = GetTemporary(context, node, /*index=*/7); + TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/8); + TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/9); + const int row_sums_size = row_sums->dims->data[0]; return lstm_eval::EvalHybrid( input, input_to_input_weights, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, @@ -1577,7 +1617,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { input_quantized, /*aux_input_quantized=*/nullptr, activation_state_quantized, cell_state_quantized, activation_state, cell_state, - output_scratch_buffer, output, + output_scratch_buffer, output, zero_points, row_sums, row_sums_size, + &op_data->compute_row_sums, CpuBackendContext::GetFromContext(context)); } else { const int num_intermediate_tensors = node->intermediates->size; diff --git a/tensorflow/lite/kernels/lstm_eval.cc b/tensorflow/lite/kernels/lstm_eval.cc index 9cc146ae8bd..1db812b251f 100644 --- a/tensorflow/lite/kernels/lstm_eval.cc +++ b/tensorflow/lite/kernels/lstm_eval.cc @@ -33,24 +33,93 @@ namespace builtin { namespace lstm_eval { namespace { -inline float GetTensorScale(const TfLiteTensor* tensor) { - return tensor == nullptr ? 
1.0f : tensor->params.scale; +void ComputeRowSums( + int32_t* input_to_input_row_sums, int32_t* input_to_forget_row_sums, + int32_t* input_to_cell_row_sums, int32_t* input_to_output_row_sums, + int32_t* aux_input_to_input_row_sums, int32_t* aux_input_to_forget_row_sums, + int32_t* aux_input_to_cell_row_sums, int32_t* aux_input_to_output_row_sums, + int32_t* recurrent_to_input_row_sums, int32_t* recurrent_to_forget_row_sums, + int32_t* recurrent_to_cell_row_sums, int32_t* recurrent_to_output_row_sums, + int32_t* projection_weights_row_sums, int32_t* row_sums, int n_cell, + int n_input, int n_aux_input, int n_output, + const int8_t* input_to_input_weights_ptr, + const int8_t* input_to_forget_weights_ptr, + const int8_t* input_to_cell_weights_ptr, + const int8_t* input_to_output_weights_ptr, + const int8_t* aux_input_to_input_weights_ptr, + const int8_t* aux_input_to_forget_weights_ptr, + const int8_t* aux_input_to_cell_weights_ptr, + const int8_t* aux_input_to_output_weights_ptr, + const int8_t* recurrent_to_input_weights_ptr, + const int8_t* recurrent_to_forget_weights_ptr, + const int8_t* recurrent_to_cell_weights_ptr, + const int8_t* recurrent_to_output_weights_ptr, + const int8_t* projection_weights_ptr, bool use_cifg, + const float* aux_input_ptr) { + // Compute the row sums for dequantization + if (!use_cifg) { + memset(input_to_input_row_sums, 0, sizeof(int32_t) * n_cell); + tensor_utils::ReductionSumVector(input_to_input_weights_ptr, + input_to_input_row_sums, n_cell, n_input); + } + memset(input_to_forget_row_sums, 0, sizeof(int32_t) * n_cell); + tensor_utils::ReductionSumVector(input_to_forget_weights_ptr, + input_to_forget_row_sums, n_cell, n_input); + memset(input_to_cell_row_sums, 0, sizeof(int32_t) * n_cell); + tensor_utils::ReductionSumVector(input_to_cell_weights_ptr, + input_to_cell_row_sums, n_cell, n_input); + memset(input_to_output_row_sums, 0, sizeof(int32_t) * n_cell); + tensor_utils::ReductionSumVector(input_to_output_weights_ptr, + input_to_output_row_sums, n_cell, n_input); + + if (aux_input_ptr) { + if (!use_cifg) { + memset(aux_input_to_input_row_sums, 0, sizeof(int32_t) * n_cell); + tensor_utils::ReductionSumVector(aux_input_to_input_weights_ptr, + aux_input_to_input_row_sums, n_cell, + n_aux_input); + } + memset(aux_input_to_forget_row_sums, 0, sizeof(int32_t) * n_cell); + tensor_utils::ReductionSumVector(aux_input_to_forget_weights_ptr, + aux_input_to_forget_row_sums, n_cell, + n_aux_input); + memset(aux_input_to_cell_row_sums, 0, sizeof(int32_t) * n_cell); + tensor_utils::ReductionSumVector(aux_input_to_cell_weights_ptr, + aux_input_to_cell_row_sums, n_cell, + n_aux_input); + memset(aux_input_to_output_row_sums, 0, sizeof(int32_t) * n_cell); + tensor_utils::ReductionSumVector(aux_input_to_output_weights_ptr, + aux_input_to_output_row_sums, n_cell, + n_aux_input); + } + if (!use_cifg) { + memset(recurrent_to_input_row_sums, 0, sizeof(int32_t) * n_cell); + tensor_utils::ReductionSumVector(recurrent_to_input_weights_ptr, + recurrent_to_input_row_sums, n_cell, + n_output); + } + memset(recurrent_to_forget_row_sums, 0, sizeof(int32_t) * n_cell); + tensor_utils::ReductionSumVector(recurrent_to_forget_weights_ptr, + recurrent_to_forget_row_sums, n_cell, + n_output); + memset(recurrent_to_cell_row_sums, 0, sizeof(int32_t) * n_cell); + tensor_utils::ReductionSumVector(recurrent_to_cell_weights_ptr, + recurrent_to_cell_row_sums, n_cell, + n_output); + memset(recurrent_to_output_row_sums, 0, sizeof(int32_t) * n_cell); + 
tensor_utils::ReductionSumVector(recurrent_to_output_weights_ptr, + recurrent_to_output_row_sums, n_cell, + n_output); + + if (projection_weights_ptr != nullptr) { + memset(projection_weights_row_sums, 0, sizeof(int32_t) * n_output); + tensor_utils::ReductionSumVector( + projection_weights_ptr, projection_weights_row_sums, n_output, n_cell); + } } -inline void MatrixBatchVectorMultiplyAccumulate( - const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, - const int8_t* __restrict__ vectors, const float* scaling_factors, - int n_batch, int32_t* scratch, float* __restrict__ result, - CpuBackendContext* context) { -// TODO(b/148289189) Remove when Ruy GEMV is the default. -#ifdef TFLITE_WITH_RUY_GEMV - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, scratch, - result, context); -#else - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result); -#endif +inline float GetTensorScale(const TfLiteTensor* tensor) { + return tensor == nullptr ? 1.0f : tensor->params.scale; } // Performs an LSTM batch inference step for input specified by input_ptr. @@ -473,6 +542,8 @@ inline void LstmStepHybrid( int8_t* quantized_aux_input_ptr, int8_t* quantized_output_state_ptr, int8_t* quantized_cell_state_ptr, float* output_state_ptr, float* cell_state_ptr, int32_t* accum_scratch_ptr, float* output_ptr, + int32_t* zero_points, int32_t* row_sums, int row_sums_size, + bool* compute_row_sums, bool asymmetric_quantize_inputs, CpuBackendContext* context) { ruy::profiler::ScopeLabel label("LstmStepHybrid"); // Since we have already checked that weights are all there or none, we @@ -503,53 +574,131 @@ inline void LstmStepHybrid( output_gate_scratch); } - // For each batch and cell: compute input_weight * input. - // Skip if input is all zeros. + int32_t* input_to_input_row_sums = nullptr; + int32_t* input_to_forget_row_sums = nullptr; + int32_t* input_to_cell_row_sums = nullptr; + int32_t* input_to_output_row_sums = nullptr; + int32_t* aux_input_to_input_row_sums = nullptr; + int32_t* aux_input_to_forget_row_sums = nullptr; + int32_t* aux_input_to_cell_row_sums = nullptr; + int32_t* aux_input_to_output_row_sums = nullptr; + int32_t* recurrent_to_input_row_sums = nullptr; + int32_t* recurrent_to_forget_row_sums = nullptr; + int32_t* recurrent_to_cell_row_sums = nullptr; + int32_t* recurrent_to_output_row_sums = nullptr; + int32_t* projection_weights_row_sums = nullptr; + + if (asymmetric_quantize_inputs) { + int num_row_sums = use_cifg ? 6 : 8; + if (aux_input_ptr != nullptr) { + num_row_sums += use_cifg ? 3 : 4; + } + if (projection_weights_ptr != nullptr) { + num_row_sums += ceil(n_output / n_cell); + } + TF_LITE_ASSERT(row_sums_size == num_row_sums); + input_to_input_row_sums = row_sums; + input_to_forget_row_sums = + use_cifg ? input_to_input_row_sums : input_to_input_row_sums + n_cell; + input_to_cell_row_sums = input_to_forget_row_sums + n_cell; + input_to_output_row_sums = input_to_cell_row_sums + n_cell; + if (aux_input_ptr != nullptr) { + aux_input_to_input_row_sums = input_to_output_row_sums + n_cell; + aux_input_to_forget_row_sums = use_cifg + ? aux_input_to_input_row_sums + : aux_input_to_input_row_sums + n_cell; + aux_input_to_cell_row_sums = aux_input_to_forget_row_sums + n_cell; + aux_input_to_output_row_sums = aux_input_to_cell_row_sums + n_cell; + } + recurrent_to_input_row_sums = aux_input_ptr + ? 
aux_input_to_output_row_sums + n_cell + : input_to_output_row_sums + n_cell; + recurrent_to_forget_row_sums = use_cifg + ? recurrent_to_input_row_sums + : recurrent_to_input_row_sums + n_cell; + recurrent_to_cell_row_sums = recurrent_to_forget_row_sums + n_cell; + recurrent_to_output_row_sums = recurrent_to_cell_row_sums + n_cell; + if (projection_weights_ptr != nullptr) { + projection_weights_row_sums = recurrent_to_output_row_sums + n_cell; + } + if (*compute_row_sums) { + ComputeRowSums( + input_to_input_row_sums, input_to_forget_row_sums, + input_to_cell_row_sums, input_to_output_row_sums, + aux_input_to_input_row_sums, aux_input_to_forget_row_sums, + aux_input_to_cell_row_sums, aux_input_to_output_row_sums, + recurrent_to_input_row_sums, recurrent_to_forget_row_sums, + recurrent_to_cell_row_sums, recurrent_to_output_row_sums, + projection_weights_row_sums, row_sums, n_cell, n_input, n_aux_input, + n_output, input_to_input_weights_ptr, input_to_forget_weights_ptr, + input_to_cell_weights_ptr, input_to_output_weights_ptr, + aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr, + aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr, + recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr, + recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr, + projection_weights_ptr, use_cifg, aux_input_ptr); + *compute_row_sums = false; + } + } + if (!tensor_utils::IsZeroVector(input_ptr, n_batch * n_input)) { for (int b = 0; b < n_batch; ++b) { const int offset = b * n_input; - float unused_min, unused_max; - tensor_utils::SymmetricQuantizeFloats( - input_ptr + offset, n_input, quantized_input_ptr + offset, - &unused_min, &unused_max, &scaling_factors[b]); + if (asymmetric_quantize_inputs) { + tensor_utils::AsymmetricQuantizeFloats( + input_ptr + offset, n_input, quantized_input_ptr + offset, + &scaling_factors[b], &zero_points[b]); + } else { + float unused_min, unused_max; + tensor_utils::SymmetricQuantizeFloats( + input_ptr + offset, n_input, quantized_input_ptr + offset, + &unused_min, &unused_max, &scaling_factors[b]); + } } if (!use_cifg) { for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * input_to_input_weights_scale; } - MatrixBatchVectorMultiplyAccumulate( + tensor_utils::MatrixBatchVectorMultiplyAccumulate( input_to_input_weights_ptr, n_cell, n_input, quantized_input_ptr, - product_scaling_factors, n_batch, accum_scratch_ptr, - input_gate_scratch, context); + product_scaling_factors, n_batch, input_gate_scratch, + /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, + input_to_input_row_sums, compute_row_sums, context); } for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * input_to_forget_weights_scale; } - MatrixBatchVectorMultiplyAccumulate( + + tensor_utils::MatrixBatchVectorMultiplyAccumulate( input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr, - product_scaling_factors, n_batch, accum_scratch_ptr, - forget_gate_scratch, context); + product_scaling_factors, n_batch, forget_gate_scratch, + /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, + input_to_forget_row_sums, compute_row_sums, context); for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * input_to_cell_weights_scale; } - MatrixBatchVectorMultiplyAccumulate( + + tensor_utils::MatrixBatchVectorMultiplyAccumulate( input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr, - product_scaling_factors, n_batch, accum_scratch_ptr, cell_scratch, - 
context); + product_scaling_factors, n_batch, cell_scratch, + /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, + input_to_cell_row_sums, compute_row_sums, context); for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * input_to_output_weights_scale; } - MatrixBatchVectorMultiplyAccumulate( + + tensor_utils::MatrixBatchVectorMultiplyAccumulate( input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr, - product_scaling_factors, n_batch, accum_scratch_ptr, - output_gate_scratch, context); + product_scaling_factors, n_batch, output_gate_scratch, + /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, + input_to_output_row_sums, compute_row_sums, context); } // For each batch and cell: compute aux_input_weight * aux_input. @@ -558,59 +707,84 @@ inline void LstmStepHybrid( !tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input)) { for (int b = 0; b < n_batch; ++b) { const int offset = b * n_aux_input; - float unused_min, unused_max; - tensor_utils::SymmetricQuantizeFloats( - aux_input_ptr + offset, n_aux_input, quantized_aux_input_ptr + offset, - &unused_min, &unused_max, &scaling_factors[b]); + if (asymmetric_quantize_inputs) { + tensor_utils::AsymmetricQuantizeFloats( + aux_input_ptr + offset, n_aux_input, + quantized_aux_input_ptr + offset, &scaling_factors[b], + &zero_points[b]); + } else { + float unused_min, unused_max; + tensor_utils::SymmetricQuantizeFloats( + aux_input_ptr + offset, n_aux_input, + quantized_aux_input_ptr + offset, &unused_min, &unused_max, + &scaling_factors[b]); + } } + if (!use_cifg) { for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * aux_input_to_input_weights_scale; } - MatrixBatchVectorMultiplyAccumulate( + tensor_utils::MatrixBatchVectorMultiplyAccumulate( aux_input_to_input_weights_ptr, n_cell, n_aux_input, quantized_aux_input_ptr, product_scaling_factors, n_batch, - accum_scratch_ptr, input_gate_scratch, context); + input_gate_scratch, /*per_channel_scale=*/nullptr, zero_points, + accum_scratch_ptr, aux_input_to_input_row_sums, compute_row_sums, + context); } for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * aux_input_to_forget_weights_scale; } - MatrixBatchVectorMultiplyAccumulate( + tensor_utils::MatrixBatchVectorMultiplyAccumulate( aux_input_to_forget_weights_ptr, n_cell, n_aux_input, quantized_aux_input_ptr, product_scaling_factors, n_batch, - accum_scratch_ptr, forget_gate_scratch, context); + forget_gate_scratch, /*per_channel_scale=*/nullptr, zero_points, + accum_scratch_ptr, aux_input_to_forget_row_sums, compute_row_sums, + context); + row_sums += n_cell; for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * aux_input_to_cell_weights_scale; } - MatrixBatchVectorMultiplyAccumulate( + tensor_utils::MatrixBatchVectorMultiplyAccumulate( aux_input_to_cell_weights_ptr, n_cell, n_aux_input, - quantized_aux_input_ptr, product_scaling_factors, n_batch, - accum_scratch_ptr, cell_scratch, context); + quantized_aux_input_ptr, product_scaling_factors, n_batch, cell_scratch, + /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr, + aux_input_to_cell_row_sums, compute_row_sums, context); for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * aux_input_to_output_weights_scale; } - MatrixBatchVectorMultiplyAccumulate( + + tensor_utils::MatrixBatchVectorMultiplyAccumulate( aux_input_to_output_weights_ptr, n_cell, n_aux_input, quantized_aux_input_ptr, 
product_scaling_factors, n_batch, - accum_scratch_ptr, output_gate_scratch, context); + output_gate_scratch, /*per_channel_scale=*/nullptr, zero_points, + accum_scratch_ptr, aux_input_to_output_row_sums, compute_row_sums, + context); } if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) { // Save quantization and matmul computation for all zero input. for (int b = 0; b < n_batch; ++b) { const int offset = b * n_output; - float unused_min, unused_max; - tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output, - quantized_output_state_ptr + offset, - &unused_min, &unused_max, - &scaling_factors[b]); + if (asymmetric_quantize_inputs) { + tensor_utils::AsymmetricQuantizeFloats( + output_state_ptr + offset, n_output, + quantized_output_state_ptr + offset, &scaling_factors[b], + &zero_points[b]); + } else { + float unused_min, unused_max; + tensor_utils::SymmetricQuantizeFloats( + output_state_ptr + offset, n_output, + quantized_output_state_ptr + offset, &unused_min, &unused_max, + &scaling_factors[b]); + } } // For each batch and cell: compute recurrent_weight * output_state. if (!use_cifg) { @@ -618,38 +792,46 @@ inline void LstmStepHybrid( product_scaling_factors[b] = scaling_factors[b] * recurrent_to_input_weights_scale; } - MatrixBatchVectorMultiplyAccumulate( + tensor_utils::MatrixBatchVectorMultiplyAccumulate( recurrent_to_input_weights_ptr, n_cell, n_output, quantized_output_state_ptr, product_scaling_factors, n_batch, - accum_scratch_ptr, input_gate_scratch, context); + input_gate_scratch, /*per_channel_scale=*/nullptr, zero_points, + accum_scratch_ptr, recurrent_to_input_row_sums, compute_row_sums, + context); } for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * recurrent_to_forget_weights_scale; } - MatrixBatchVectorMultiplyAccumulate( + tensor_utils::MatrixBatchVectorMultiplyAccumulate( recurrent_to_forget_weights_ptr, n_cell, n_output, quantized_output_state_ptr, product_scaling_factors, n_batch, - accum_scratch_ptr, forget_gate_scratch, context); + forget_gate_scratch, /*per_channel_scale=*/nullptr, zero_points, + accum_scratch_ptr, recurrent_to_forget_row_sums, compute_row_sums, + context); for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * recurrent_to_cell_weights_scale; } - MatrixBatchVectorMultiplyAccumulate( + tensor_utils::MatrixBatchVectorMultiplyAccumulate( recurrent_to_cell_weights_ptr, n_cell, n_output, quantized_output_state_ptr, product_scaling_factors, n_batch, - accum_scratch_ptr, cell_scratch, context); + cell_scratch, /*per_channel_scale=*/nullptr, zero_points, + accum_scratch_ptr, recurrent_to_cell_row_sums, compute_row_sums, + context); for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * recurrent_to_output_weights_scale; } - MatrixBatchVectorMultiplyAccumulate( + tensor_utils::MatrixBatchVectorMultiplyAccumulate( recurrent_to_output_weights_ptr, n_cell, n_output, quantized_output_state_ptr, product_scaling_factors, n_batch, - accum_scratch_ptr, output_gate_scratch, context); + output_gate_scratch, /*per_channel_scale=*/nullptr, zero_points, + accum_scratch_ptr, recurrent_to_output_row_sums, compute_row_sums, + context); } // For each batch and cell: update input gate. @@ -770,22 +952,32 @@ inline void LstmStepHybrid( // Save quantization and matmul computation for all zero input. 
for (int b = 0; b < n_batch; ++b) { const int offset = b * n_cell; - float unused_min, unused_max; - tensor_utils::SymmetricQuantizeFloats( - output_gate_scratch + offset, n_cell, - quantized_cell_state_ptr + offset, &unused_min, &unused_max, - &scaling_factors[b]); + if (asymmetric_quantize_inputs) { + tensor_utils::AsymmetricQuantizeFloats( + output_gate_scratch + offset, n_cell, + quantized_cell_state_ptr + offset, &scaling_factors[b], + &zero_points[b]); + } else { + float unused_min, unused_max; + tensor_utils::SymmetricQuantizeFloats( + output_gate_scratch + offset, n_cell, + quantized_cell_state_ptr + offset, &unused_min, &unused_max, + &scaling_factors[b]); + } } for (int b = 0; b < n_batch; ++b) { product_scaling_factors[b] = scaling_factors[b] * projection_weights_scale; } for (int b = 0; b < n_batch; b++) { - MatrixBatchVectorMultiplyAccumulate( + tensor_utils::MatrixBatchVectorMultiplyAccumulate( projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr + b * n_cell, &product_scaling_factors[b], - /*n_batch=*/1, accum_scratch_ptr, - output_ptr + b * output_batch_leading_dim, context); + /*n_batch=*/1, output_ptr + b * output_batch_leading_dim, + /*per_channel_scale=*/nullptr, + asymmetric_quantize_inputs ? &zero_points[b] : nullptr, + accum_scratch_ptr, projection_weights_row_sums, compute_row_sums, + context); } } if (params->proj_clip > 0.0) { @@ -1615,7 +1807,8 @@ TfLiteStatus EvalHybrid( TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output_scratch_buffer, - TfLiteTensor* output, CpuBackendContext* context) { + TfLiteTensor* output, TfLiteTensor* zero_points, TfLiteTensor* row_sums, + int row_sums_size, bool* compute_row_sums, CpuBackendContext* context) { TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3); const int n_input = input->dims->data[input->dims->size - 1]; int max_time, n_batch; @@ -1654,6 +1847,14 @@ TfLiteStatus EvalHybrid( const int output_batch_leading_dim = output->dims->data[output->dims->size - 1]; + + int32_t* zero_points_ptr = nullptr; + int32_t* row_sums_ptr = nullptr; + if (params->asymmetric_quantize_inputs) { + zero_points_ptr = GetTensorData(zero_points); + row_sums_ptr = GetTensorData(row_sums); + } + if (time_major) { // Feed the sequence into the LSTM step-by-step. 
const int input_step = n_batch * n_input; @@ -1721,7 +1922,9 @@ TfLiteStatus EvalHybrid( GetTensorData(output_state_quantized), GetTensorData(cell_state_quantized), GetTensorData(output_state), GetTensorData(cell_state), - GetTensorData(output_scratch_buffer), output_ptr, context); + GetTensorData(output_scratch_buffer), output_ptr, + zero_points_ptr, row_sums_ptr, row_sums_size, compute_row_sums, + params->asymmetric_quantize_inputs, context); } } else { for (int b = 0; b < n_batch; b++) { @@ -1806,7 +2009,8 @@ TfLiteStatus EvalHybrid( GetTensorData(output_state_quantized), GetTensorData(cell_state_quantized), output_state_ptr, cell_state_ptr, GetTensorData(output_scratch_buffer), - output_ptr, context); + output_ptr, zero_points_ptr, row_sums_ptr, row_sums_size, + compute_row_sums, params->asymmetric_quantize_inputs, context); } } } diff --git a/tensorflow/lite/kernels/lstm_eval.h b/tensorflow/lite/kernels/lstm_eval.h index ca3f96391aa..877cfd70a89 100644 --- a/tensorflow/lite/kernels/lstm_eval.h +++ b/tensorflow/lite/kernels/lstm_eval.h @@ -156,7 +156,8 @@ TfLiteStatus EvalHybrid( TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output_scratch_buffer, - TfLiteTensor* output, CpuBackendContext* context); + TfLiteTensor* output, TfLiteTensor* zero_points, TfLiteTensor* row_sums, + int row_sums_size, bool* compute_row_sums, CpuBackendContext* context); TfLiteStatus EvalInteger8x8_16( const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, diff --git a/tensorflow/lite/kernels/lstm_test.cc b/tensorflow/lite/kernels/lstm_test.cc index f426ffae0e0..2bd31eae8db 100644 --- a/tensorflow/lite/kernels/lstm_test.cc +++ b/tensorflow/lite/kernels/lstm_test.cc @@ -38,7 +38,8 @@ class LSTMOpModel : public SingleOpModel { bool use_peephole, bool use_projection_weights, bool use_projection_bias, float cell_clip, float proj_clip, const std::vector>& input_shapes, - const TensorType weight_type, bool is_layer_norm) + const TensorType weight_type, bool is_layer_norm, + bool asymmetric_quantize_inputs = false) : n_batch_(n_batch), n_input_(n_input), n_cell_(n_cell), @@ -129,10 +130,12 @@ class LSTMOpModel : public SingleOpModel { output_ = AddOutput(TensorType_FLOAT32); - SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions, - CreateLSTMOptions(builder_, ActivationFunctionType_TANH, - cell_clip, proj_clip) - .Union()); + SetBuiltinOp( + BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions, + CreateLSTMOptions(builder_, ActivationFunctionType_TANH, cell_clip, + proj_clip, ::tflite::LSTMKernelType_FULL, + asymmetric_quantize_inputs) + .Union()); // Do not apply delegate yet since tensor values are not known (and more // specifically scales in quantized tensors are not known). @@ -315,7 +318,7 @@ class LSTMOpModel : public SingleOpModel { const TensorType weight_type_; }; -class BaseLstmTest : public ::testing::Test { +class BaseLstmTest : public ::testing::TestWithParam { protected: // Weights of the LSTM model. Some are optional. 
std::vector input_to_input_weights_; @@ -565,8 +568,11 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingOmittedLayerNormLstmTest, VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); } -TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, +TEST_P(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTestUint8) { + if (SingleOpModel::GetForceUseNnapi() && GetParam()) { + return; + } const int n_batch = 1; const int n_input = 2; // n_cell and n_output have the same size when there is no projection. @@ -604,14 +610,20 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, {0}, // projection_bias tensor }, /*weight_type=*/TensorType_UINT8, - /*is_layer_norm=*/false); + /*is_layer_norm=*/false, GetParam()); VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.0157651); } -TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, +class NoCifgNoPeepholeNoProjectionNoClippingLstmInt8Test + : public NoCifgNoPeepholeNoProjectionNoClippingLstmTest {}; + +TEST_P(NoCifgNoPeepholeNoProjectionNoClippingLstmInt8Test, HybridLstmBlackBoxTestInt8) { + if (SingleOpModel::GetForceUseNnapi() && GetParam()) { + return; + } const int n_batch = 1; const int n_input = 2; // n_cell and n_output have the same size when there is no projection. @@ -649,7 +661,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, {0}, // projection_bias tensor }, /*weight_type=*/TensorType_INT8, - /*is_layer_norm=*/false); + /*is_layer_norm=*/false, GetParam()); VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.0157651); @@ -745,8 +757,11 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); } -TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, +TEST_P(CifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTestUint8) { + if (SingleOpModel::GetForceUseNnapi() && GetParam()) { + return; + } const int n_batch = 1; const int n_input = 2; // n_cell and n_output have the same size when there is no projection. @@ -784,13 +799,18 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, {0}, // projection_bias tensor }, /*weight_type=*/TensorType_UINT8, - /*is_layer_norm=*/false); + /*is_layer_norm=*/false, GetParam()); VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573); } +class CifgNoPeepholeNoProjectionNoClippingLstmInt8Test + : public CifgNoPeepholeNoProjectionNoClippingLstmTest {}; -TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, +TEST_P(CifgNoPeepholeNoProjectionNoClippingLstmInt8Test, HybridLstmBlackBoxTestInt8) { + if (SingleOpModel::GetForceUseNnapi() && GetParam()) { + return; + } const int n_batch = 1; const int n_input = 2; // n_cell and n_output have the same size when there is no projection. 
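Editor's note: the LSTM tests switch from TEST_F to TEST_P so every hybrid black-box test runs both with and without asymmetric input quantization (the bool parameter), and they return early when NNAPI is forced because the delegate does not take the new attribute. The INSTANTIATE_* calls are not part of this excerpt; a typical instantiation over both parameter values would look like the following sketch, reusing the suite names shown above and assuming only ::testing::Bool():

// Sketch: run each parameterized hybrid LSTM suite with the asymmetric
// quantization flag set to false and to true.
#include <gtest/gtest.h>

INSTANTIATE_TEST_SUITE_P(Parameterized,
                         NoCifgNoPeepholeNoProjectionNoClippingLstmTest,
                         ::testing::Bool());
INSTANTIATE_TEST_SUITE_P(Parameterized,
                         NoCifgNoPeepholeNoProjectionNoClippingLstmInt8Test,
                         ::testing::Bool());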
@@ -828,7 +848,7 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, {0}, // projection_bias tensor }, /*weight_type=*/TensorType_INT8, - /*is_layer_norm=*/false); + /*is_layer_norm=*/false, GetParam()); VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573); } @@ -1474,50 +1494,11 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, LstmBlackBoxTest) { VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); } -TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTestInt8) { - const int n_batch = 2; - const int n_input = 5; - const int n_cell = 20; - const int n_output = 16; - - LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, - /*use_cifg=*/false, /*use_peephole=*/true, - /*use_projection_weights=*/true, - /*use_projection_bias=*/false, - /*cell_clip=*/0.0, /*proj_clip=*/0.0, - { - {n_batch, n_input}, // input tensor - - {n_cell, n_input}, // input_to_input_weight tensor - {n_cell, n_input}, // input_to_forget_weight tensor - {n_cell, n_input}, // input_to_cell_weight tensor - {n_cell, n_input}, // input_to_output_weight tensor - - {n_cell, n_output}, // recurrent_to_input_weight tensor - {n_cell, n_output}, // recurrent_to_forget_weight tensor - {n_cell, n_output}, // recurrent_to_cell_weight tensor - {n_cell, n_output}, // recurrent_to_output_weight tensor - - {n_cell}, // cell_to_input_weight tensor - {n_cell}, // cell_to_forget_weight tensor - {n_cell}, // cell_to_output_weight tensor - - {n_cell}, // input_gate_bias tensor - {n_cell}, // forget_gate_bias tensor - {n_cell}, // cell_bias tensor - {n_cell}, // output_gate_bias tensor - - {n_output, n_cell}, // projection_weight tensor - {0}, // projection_bias tensor - }, - /*weight_type=*/TensorType_INT8, - /*is_layer_norm=*/false); - - VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467); -} - -TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, +TEST_P(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTestUint8) { + if (SingleOpModel::GetForceUseNnapi() && GetParam()) { + return; + } const int n_batch = 2; const int n_input = 5; const int n_cell = 20; @@ -1554,11 +1535,60 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, {0}, // projection_bias tensor }, /*weight_type=*/TensorType_UINT8, - /*is_layer_norm=*/false); + /*is_layer_norm=*/false, GetParam()); VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467); } +class NoCifgPeepholeProjectionNoClippingLstmInt8Test + : public NoCifgPeepholeProjectionNoClippingLstmTest {}; + +TEST_P(NoCifgPeepholeProjectionNoClippingLstmInt8Test, + HybridLstmBlackBoxTestInt8) { + if (SingleOpModel::GetForceUseNnapi() && GetParam()) { + return; + } + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 20; + const int n_output = 16; + + LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {n_cell}, // cell_to_input_weight tensor + {n_cell}, 
// cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + }, + /*weight_type=*/TensorType_INT8, + /*is_layer_norm=*/false, GetParam()); + + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.0015); +} + class NoCifgPeepholeProjectionNoClippingLayerNormLstmTest : public BaseLstmTest { void SetUp() override { @@ -1693,8 +1723,11 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest, VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm); } -TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest, +TEST_P(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest, HybridLayerNormLstmBlackBoxTestUint8) { + if (SingleOpModel::GetForceUseNnapi() && GetParam()) { + return; + } const int n_batch = 2; const int n_input = 5; const int n_cell = 4; @@ -1741,7 +1774,7 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest, {n_cell}, // output_layer_norm_coefficient tensor }, /*weight_type=*/TensorType_UINT8, - /*is_layer_norm=*/true); + /*is_layer_norm=*/true, GetParam()); lstm_golden_output_ = {{ // Batch0: 3 (input_sequence_size) * 3 (n_output) @@ -1760,8 +1793,14 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest, /*tolerance=*/0.0010907); } -TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest, +class NoCifgPeepholeProjectionNoClippingLayerNormLstmInt8Test + : public NoCifgPeepholeProjectionNoClippingLayerNormLstmTest {}; + +TEST_P(NoCifgPeepholeProjectionNoClippingLayerNormLstmInt8Test, HybridLayerNormLstmBlackBoxTestInt8) { + if (SingleOpModel::GetForceUseNnapi() && GetParam()) { + return; + } const int n_batch = 2; const int n_input = 5; const int n_cell = 4; @@ -1808,22 +1847,24 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest, {n_cell}, // output_layer_norm_coefficient tensor }, /*weight_type=*/TensorType_INT8, - /*is_layer_norm=*/true); + /*is_layer_norm=*/true, GetParam()); + // Goldens are calculated from weight_type=TensorType_FLOAT32. 
lstm_golden_output_ = {{ // Batch0: 3 (input_sequence_size) * 3 (n_output) - 0.0244576, 0.127847, -0.00181765, // seq 0 - 0.0137518, 0.140892, 0.0402234, // seq 1 - -0.0048839, 0.155096, 0.0840309, // seq 2 + 0.0244077, 0.128027, -0.00170918, // seq 0 + 0.0137642, 0.140751, 0.0395835, // seq 1 + -0.00459233, 0.155278, 0.0837378, // seq 2 }, { // Batch1: 3 (input_sequence_size) * 3 (n_output) - -0.00728636, 0.0843957, 0.0634786, // seq 0 - -0.00448382, 0.139278, 0.0737372, // seq 1 - 0.00734616, 0.161793, 0.0560238, // seq 2 + -0.00692428, 0.0848741, 0.063445, // seq 0 + -0.00403911, 0.139963, 0.072681, // seq 1 + 0.00752708, 0.161903, 0.0561371, // seq 2 }}; - VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm); + VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm, + /*tolerance=*/1.06e-3); } class CifgPeepholeProjectionNoClippingLayerNormLstmTest : public BaseLstmTest { @@ -1940,8 +1981,11 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest, VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm); } -TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest, +TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmTest, HybridLayerNormLstmBlackBoxTestUint8) { + if (SingleOpModel::GetForceUseNnapi() && GetParam()) { + return; + } const int n_batch = 2; const int n_input = 5; const int n_cell = 4; @@ -1988,7 +2032,7 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest, {n_cell}, // output_layer_norm_coefficient tensor }, /*weight_type=*/TensorType_UINT8, - /*is_layer_norm=*/true); + /*is_layer_norm=*/true, GetParam()); // Verify the final output. lstm_golden_output_ = { @@ -2009,7 +2053,10 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest, /*tolerance=*/0.000902065); } -TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest, +class CifgPeepholeProjectionNoClippingLayerNormLstmInt8Test + : public CifgPeepholeProjectionNoClippingLayerNormLstmTest {}; + +TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmInt8Test, HybridLayerNormLstmBlackBoxTestInt8) { const int n_batch = 2; const int n_input = 5; @@ -2057,24 +2104,24 @@ TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest, {n_cell}, // output_layer_norm_coefficient tensor }, /*weight_type=*/TensorType_INT8, - /*is_layer_norm=*/true); + /*is_layer_norm=*/true, GetParam()); - // Verify the final output. - lstm_golden_output_ = { - { - // Batch0: 3 (input_sequence_size) * 3 (n_output) - 0.0212250091, 0.140474007, 0.0115012666, // seq 0 - 0.0130806509, 0.152660668, 0.0347516984, // seq 1 - -0.0124010444, 0.166042402, 0.0898982584, // seq 2 - }, - { - // Batch1: 3 (input_sequence_size) * 3 (n_output) - -0.0228835996, 0.0917588323, 0.0778886303, // seq 0 - -0.0275101066, 0.148769245, 0.0938384682, // seq 1 - -0.0103605557, 0.172605693, 0.0728750974, // seq 2 - }}; + // Goldens are results using FLOAT32 inference. 
+ lstm_golden_output_ = {{ + // Batch0: 3 (input_sequence_size) * 3 (n_output) + 0.0212971, 0.140816, 0.0112733, // seq 0 + 0.0132302, 0.152308, 0.0346313, // seq 1 + -0.0123688, 0.16579, 0.0893078, // seq 2 + }, + { + // Batch1: 3 (input_sequence_size) * 3 (n_output) + -0.0226351, 0.0916948, 0.0769176, // seq 0 + -0.0269967, 0.149708, 0.0941492, // seq 1 + -0.0103429, 0.173016, 0.0720509, // seq 2 + }}; - VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm); + VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm, + /*tolerance=*/1e-3); } class LSTMIntegerOpModel : public SingleOpModel { @@ -3311,5 +3358,22 @@ TEST(LSTMOpModel, InvalidTypeTest) { ""); } #endif + +#define QUANTIZE_PARAMETER_TEST(test) \ + INSTANTIATE_TEST_SUITE_P(test, test, ::testing::Bool()) + +QUANTIZE_PARAMETER_TEST(NoCifgNoPeepholeNoProjectionNoClippingLstmTest); +QUANTIZE_PARAMETER_TEST(NoCifgNoPeepholeNoProjectionNoClippingLstmInt8Test); +QUANTIZE_PARAMETER_TEST(CifgNoPeepholeNoProjectionNoClippingLstmTest); +QUANTIZE_PARAMETER_TEST(CifgNoPeepholeNoProjectionNoClippingLstmInt8Test); +QUANTIZE_PARAMETER_TEST(NoCifgPeepholeProjectionNoClippingLstmTest); +QUANTIZE_PARAMETER_TEST(NoCifgPeepholeProjectionNoClippingLstmInt8Test); +QUANTIZE_PARAMETER_TEST(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest); +QUANTIZE_PARAMETER_TEST( + NoCifgPeepholeProjectionNoClippingLayerNormLstmInt8Test); +QUANTIZE_PARAMETER_TEST(CifgPeepholeProjectionNoClippingLayerNormLstmTest); +QUANTIZE_PARAMETER_TEST(CifgPeepholeProjectionNoClippingLayerNormLstmInt8Test); +#undef QUANTIZE_PARAMETER_TEST + } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/svdf.cc b/tensorflow/lite/kernels/svdf.cc index bcbd06e8a67..82b7b7e4ee5 100644 --- a/tensorflow/lite/kernels/svdf.cc +++ b/tensorflow/lite/kernels/svdf.cc @@ -43,6 +43,7 @@ struct OpData { int effective_scale_1_b; int32 effective_scale_2_a; int effective_scale_2_b; + bool compute_row_sums = false; }; } // namespace @@ -61,8 +62,8 @@ constexpr int kOutputTensor = 0; void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* op_data = new OpData(); op_data->float_weights_time_initialized = false; - // Note: only needs 4 scratch tensors when is_hybrid_op, only 1 otherwise. - context->AddTensors(context, /*tensors_to_add=*/4, + // Note: only needs 6 scratch tensors when is_hybrid_op, only 1 otherwise. + context->AddTensors(context, /*tensors_to_add=*/6, &op_data->scratch_tensor_index); return op_data; } @@ -130,7 +131,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Resize scratch. TfLiteIntArrayFree(node->temporaries); if (is_hybrid_op) { - node->temporaries = TfLiteIntArrayCreate(4); + node->temporaries = TfLiteIntArrayCreate(6); } else if (is_full_integer) { node->temporaries = TfLiteIntArrayCreate(2); } else { @@ -156,6 +157,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { scratch_size_array)); if (is_hybrid_op) { + op_data->compute_row_sums = true; // Tell interpreter to allocate temporary tensors to store quantized values // of input tensors. 
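The svdf.cc hunk here adds a compute_row_sums flag to OpData and grows the hybrid path to six temporaries, two of which (zero_points and a persistent row_sums) exist only to support asymmetric input quantization. The row sums can be cached because the asymmetric correction term depends only on the constant weights. A rough sketch of that arithmetic, not the actual kernel code (names are illustrative):

#include <cstdint>

// For one output row i of a weights-only-quantized matmul with input
// zero point z:
//   dot_i = sum_j w[i][j] * (q[j] - z)
//         = sum_j w[i][j] * q[j]  -  z * row_sum_i
// row_sum_i depends only on the constant weights, so it is computed once
// and reused across invocations.
void ComputeRowSums(const int8_t* weights, int rows, int cols,
                    int32_t* row_sums) {
  for (int i = 0; i < rows; ++i) {
    int32_t sum = 0;
    for (int j = 0; j < cols; ++j) sum += weights[i * cols + j];
    row_sums[i] = sum;
  }
}

int32_t CorrectedAccumulator(const int8_t* weight_row, const int8_t* input,
                             int cols, int32_t input_zero_point,
                             int32_t row_sum) {
  int32_t acc = 0;
  for (int j = 0; j < cols; ++j) acc += weight_row[j] * input[j];
  return acc - input_zero_point * row_sum;  // subtract the zero-point term
}

Since the weights never change between invocations, Prepare re-arms compute_row_sums and the persistent row_sums scratch can be filled once and reused on later calls.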
node->temporaries->data[1] = scratch_tensor_index + 1; @@ -195,6 +197,30 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { context->ResizeTensor(context, float_weights_time, float_weights_time_size)); } + + node->temporaries->data[4] = scratch_tensor_index + 4; + TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4); + zero_points->type = kTfLiteFloat32; + zero_points->allocation_type = kTfLiteArenaRw; + int zero_points_dims[1] = {batch_size}; + if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) { + TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1); + zero_points_size->data[0] = zero_points_dims[0]; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points, + zero_points_size)); + } + + node->temporaries->data[5] = scratch_tensor_index + 5; + TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5); + row_sums->type = kTfLiteFloat32; + row_sums->allocation_type = kTfLiteArenaRwPersistent; + int row_sums_dims[1] = {num_filters}; + if (!TfLiteIntArrayEqualsArray(row_sums->dims, 1, row_sums_dims)) { + TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(1); + row_sums_size->data[0] = row_sums_dims[0]; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, row_sums, row_sums_size)); + } } if (is_full_integer) { // Allocated one extra tensor. @@ -267,7 +293,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTemporary(context, node, /*index=*/2); TfLiteTensor* float_weights_time = GetTemporary(context, node, /*index=*/3); - + TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4); + TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5); // Dequantize weights time. // TODO(alanchiao): this dequantization initialization only needs to // happen once per model and should theoretically be placed in either @@ -285,10 +312,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } op_data->float_weights_time_initialized = true; } - reference_ops::EvalHybridSVDF(context, node, input, weights_feature, - float_weights_time, bias, params, scratch, - scaling_factors, input_quantized, - activation_state, output); + + reference_ops::EvalHybridSVDF( + context, node, input, weights_feature, float_weights_time, bias, + params, scratch, scaling_factors, input_quantized, activation_state, + output, zero_points, row_sums, &op_data->compute_row_sums); return kTfLiteOk; } else { auto* input_params = reinterpret_cast( diff --git a/tensorflow/lite/kernels/svdf_test.cc b/tensorflow/lite/kernels/svdf_test.cc index 1f5cfb040e7..68963b784f4 100644 --- a/tensorflow/lite/kernels/svdf_test.cc +++ b/tensorflow/lite/kernels/svdf_test.cc @@ -131,7 +131,8 @@ class BaseSVDFOpModel : public SingleOpModel { BaseSVDFOpModel(int batches, int units, int input_size, int memory_size, int rank, TensorType weights_feature_type = TensorType_FLOAT32, - TensorType weights_time_type = TensorType_FLOAT32) + TensorType weights_time_type = TensorType_FLOAT32, + bool asymmetric_quantize_inputs = false) : batches_(batches), units_(units), input_size_(input_size), @@ -146,9 +147,10 @@ class BaseSVDFOpModel : public SingleOpModel { TensorData{TensorType_FLOAT32, {batches, memory_size * num_filters}}, /*is_variable=*/true); output_ = AddOutput(TensorType_FLOAT32); - SetBuiltinOp( - BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions, - CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE).Union()); + SetBuiltinOp(BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions, + CreateSVDFOptions(builder_, rank, 
ActivationFunctionType_NONE, + asymmetric_quantize_inputs) + .Union()); BuildInterpreter({ {batches_, input_size_}, // input tensor {units_ * rank, input_size_}, // weights_feature tensor @@ -203,9 +205,10 @@ class SVDFOpModel : public BaseSVDFOpModel { class HybridSVDFOpModel : public BaseSVDFOpModel { public: HybridSVDFOpModel(int batches, int units, int input_size, int memory_size, - int rank, TensorType tensor_type) + int rank, TensorType tensor_type, + bool asymmetric_quantize_inputs) : BaseSVDFOpModel(batches, units, input_size, memory_size, rank, - tensor_type, tensor_type) { + tensor_type, tensor_type, asymmetric_quantize_inputs) { tensor_type_ = tensor_type; } @@ -229,7 +232,7 @@ class HybridSVDFOpModel : public BaseSVDFOpModel { TensorType tensor_type_; }; -class SVDFOpTest : public ::testing::Test { +class SVDFOpTest : public ::testing::TestWithParam { protected: void VerifyGoldens(float golden_input[], float golden_output[], int golden_size, BaseSVDFOpModel* svdf, @@ -262,6 +265,9 @@ class SVDFOpTest : public ::testing::Test { } }; +INSTANTIATE_TEST_SUITE_P(SVDFOpTest, SVDFOpTest, + ::testing::ValuesIn({false, true})); + TEST_F(SVDFOpTest, BlackBoxTestRank1) { SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, /*memory_size=*/10, /*rank=*/1); @@ -325,9 +331,10 @@ TEST_F(SVDFOpTest, BlackBoxTestRank2) { &svdf); } -TEST_F(SVDFOpTest, BlackBoxTestHybridRank1Uint8) { +TEST_P(SVDFOpTest, BlackBoxTestHybridRank1Uint8) { HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, - /*memory_size=*/10, /*rank=*/1, TensorType_UINT8); + /*memory_size=*/10, /*rank=*/1, TensorType_UINT8, + GetParam()); svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347, 0.22197971, 0.12416199, 0.27901134, 0.27557442, 0.3905206, -0.36137494, -0.06634006, -0.10640851}); @@ -347,12 +354,13 @@ TEST_F(SVDFOpTest, BlackBoxTestHybridRank1Uint8) { VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input), &svdf, - /*tolerance=*/0.002945); + /*tolerance=*/0.004285); } -TEST_F(SVDFOpTest, BlackBoxTestHybridRank2Uint8) { +TEST_P(SVDFOpTest, BlackBoxTestHybridRank2Uint8) { HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, - /*memory_size=*/10, /*rank=*/2, TensorType_UINT8); + /*memory_size=*/10, /*rank=*/2, TensorType_UINT8, + GetParam()); svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347, 0.12416199, 0.15785322, 0.27901134, 0.3905206, 0.21931258, -0.36137494, -0.10640851, 0.31053296, @@ -387,12 +395,13 @@ TEST_F(SVDFOpTest, BlackBoxTestHybridRank2Uint8) { VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input), &svdf, - /*tolerance=*/0.00625109); + /*tolerance=*/0.007175); } -TEST_F(SVDFOpTest, BlackBoxTestHybridRank1Int8) { +TEST_P(SVDFOpTest, BlackBoxTestHybridRank1Int8) { HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, - /*memory_size=*/10, /*rank=*/1, TensorType_INT8); + /*memory_size=*/10, /*rank=*/1, TensorType_INT8, + GetParam()); svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347, 0.22197971, 0.12416199, 0.27901134, 0.27557442, 0.3905206, -0.36137494, -0.06634006, -0.10640851}); @@ -412,12 +421,13 @@ TEST_F(SVDFOpTest, BlackBoxTestHybridRank1Int8) { VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input), &svdf, - /*tolerance=*/0.002945); + /*tolerance=*/0.004285); } -TEST_F(SVDFOpTest, BlackBoxTestHybridRank2Int8) { +TEST_P(SVDFOpTest, BlackBoxTestHybridRank2Int8) { HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, - 
/*memory_size=*/10, /*rank=*/2, TensorType_INT8); + /*memory_size=*/10, /*rank=*/2, TensorType_INT8, + GetParam()); svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347, 0.12416199, 0.15785322, 0.27901134, 0.3905206, 0.21931258, -0.36137494, -0.10640851, 0.31053296, @@ -452,7 +462,7 @@ TEST_F(SVDFOpTest, BlackBoxTestHybridRank2Int8) { VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input), &svdf, - /*tolerance=*/0.00625109); + /*tolerance=*/0.007175); } // Test case for full integer quantization of SVDF. diff --git a/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc index b49974da2e0..73b0535fc46 100644 --- a/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc +++ b/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc @@ -33,6 +33,7 @@ struct OpData { bool is_layer_norm_lstm; // The scratch tensor index. int scratch_tensor_index; + bool compute_row_sums = false; }; // Input Tensors of size {max_time, n_batch, n_input} @@ -92,7 +93,9 @@ enum TemporaryTensor { kProductScalingFactors = 5, kRecoveredCellWeights = 6, kAccumScratch = 7, - kNumTemporaryTensors + kZeroPoints = 8, + kRowSums = 9, + kNumTemporaryTensors = 10 }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { @@ -408,6 +411,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { scratch_buffer_size)); if (IsHybridOp(input, input_to_output_weights)) { + op_data->compute_row_sums = true; // Allocate temporary tensors to store quantized values of input, // activation_state and cell_state tensors. node->temporaries->data[kInputQuantized] = @@ -515,6 +519,34 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK( context, context->ResizeTensor(context, accum_scratch, accum_size)); } + node->temporaries->data[kZeroPoints] = scratch_tensor_index + kZeroPoints; + TfLiteTensor* zero_points = GetTemporary(context, node, kZeroPoints); + zero_points->type = kTfLiteFloat32; + zero_points->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, scaling_dims)) { + TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1); + zero_points_size->data[0] = n_batch; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points, + zero_points_size)); + } + node->temporaries->data[kRowSums] = scratch_tensor_index + kRowSums; + TfLiteTensor* row_sums = GetTemporary(context, node, kRowSums); + row_sums->type = kTfLiteInt32; + row_sums->allocation_type = kTfLiteArenaRwPersistent; + int row_sums_rows = use_cifg ? 
6 : 8; + const TfLiteTensor* projection_weights = + GetOptionalInputTensor(context, node, kProjectionWeightsTensor); + if (projection_weights != nullptr) { + row_sums_rows += ceil(n_output / n_cell); + } + int row_sums_dims[2] = {row_sums_rows, n_cell}; + if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) { + TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2); + row_sums_size->data[0] = row_sums_dims[0]; + row_sums_size->data[1] = row_sums_dims[1]; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, row_sums, row_sums_size)); + } } return kTfLiteOk; } @@ -600,6 +632,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { lstm_params.activation = params->activation; lstm_params.cell_clip = params->cell_clip; lstm_params.proj_clip = params->proj_clip; + lstm_params.asymmetric_quantize_inputs = params->asymmetric_quantize_inputs; switch (input_to_output_weights->type) { case kTfLiteFloat32: { @@ -623,6 +656,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } case kTfLiteUInt8: case kTfLiteInt8: { + OpData* op_data = reinterpret_cast(node->user_data); TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); TfLiteTensor* activation_state_quantized = GetTemporary(context, node, /*index=*/2); @@ -635,6 +669,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTemporary(context, node, /*index=*/6); TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/kAccumScratch); + TfLiteTensor* zero_points = + GetTemporary(context, node, /*index=*/kZeroPoints); + TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/kRowSums); + const int row_sums_size = row_sums->dims->data[0]; return lstm_eval::EvalHybrid( input, input_to_input_weights, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, @@ -654,7 +692,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { prod_scaling_factors, recovered_cell_weights, input_quantized, /*aux_input_quantized=*/nullptr, activation_state_quantized, cell_state_quantized, activation_state, cell_state, accum_scratch, - output, CpuBackendContext::GetFromContext(context)); + output, zero_points, row_sums, row_sums_size, + &op_data->compute_row_sums, + CpuBackendContext::GetFromContext(context)); } default: context->ReportError(context, "Type %d is not currently supported.", diff --git a/tensorflow/lite/kernels/unidirectional_sequence_lstm_test.cc b/tensorflow/lite/kernels/unidirectional_sequence_lstm_test.cc index e89949e279e..4ea018c0cab 100644 --- a/tensorflow/lite/kernels/unidirectional_sequence_lstm_test.cc +++ b/tensorflow/lite/kernels/unidirectional_sequence_lstm_test.cc @@ -38,7 +38,8 @@ class UnidirectionalLSTMOpModel : public SingleOpModel { float proj_clip, const std::vector>& input_shapes, const TensorType& weights_type = TensorType_FLOAT32, - bool is_layer_norm = false) + bool is_layer_norm = false, + bool asymmetric_quantize_inputs = false) : n_batch_(n_batch), n_input_(n_input), n_cell_(n_cell), @@ -131,7 +132,7 @@ class UnidirectionalLSTMOpModel : public SingleOpModel { BuiltinOptions_UnidirectionalSequenceLSTMOptions, CreateUnidirectionalSequenceLSTMOptions( builder_, ActivationFunctionType_TANH, cell_clip, - proj_clip, time_major) + proj_clip, time_major, asymmetric_quantize_inputs) .Union()); BuildInterpreter(input_shapes); } @@ -292,11 +293,12 @@ class HybridUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel { bool time_major, bool use_cifg, bool use_peephole, bool use_projection_weights, bool 
use_projection_bias, float cell_clip, float proj_clip, const std::vector>& input_shapes, - TensorType tensor_type) + TensorType tensor_type, bool asymmetric_quantize_inputs) : UnidirectionalLSTMOpModel( n_batch, n_input, n_cell, n_output, sequence_length, time_major, use_cifg, use_peephole, use_projection_weights, use_projection_bias, - cell_clip, proj_clip, input_shapes, tensor_type) { + cell_clip, proj_clip, input_shapes, tensor_type, false, + asymmetric_quantize_inputs) { tensor_type_ = tensor_type; } @@ -360,7 +362,7 @@ class HybridUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel { TensorType tensor_type_; }; -class BaseUnidirectionalLstmTest : public ::testing::Test { +class BaseUnidirectionalLstmTest : public ::testing::TestWithParam { protected: // Weights of the LSTM model. Some are optional. std::vector input_to_input_weights_; @@ -626,7 +628,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest, /*time_major=*/false); } -TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest, +TEST_P(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest, HybridLstmBlackBoxTestUint8) { const int n_batch = 1; const int n_input = 2; @@ -668,7 +670,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest, {n_batch, n_output}, // activation_state tensor {n_batch, n_cell}, // cell_state tensor }, - TensorType_UINT8); + TensorType_UINT8, GetParam()); lstm.SetInputToInputWeights(input_to_input_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_); @@ -689,7 +691,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest, /*tolerance=*/0.0157651); } -TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest, +TEST_P(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest, HybridLstmBlackBoxTestInt8) { const int n_batch = 1; const int n_input = 2; @@ -731,7 +733,7 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest, {n_batch, n_output}, // activation_state tensor {n_batch, n_cell}, // cell_state tensor }, - TensorType_INT8); + TensorType_INT8, GetParam()); lstm.SetInputToInputWeights(input_to_input_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_); @@ -862,7 +864,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest, VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); } -TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest, +TEST_P(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest, HybridLstmBlackBoxTestUint8) { const int n_batch = 1; const int n_input = 2; @@ -880,11 +882,10 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest, { {sequence_length, n_batch, n_input}, // input tensor - {0, 0}, // input_to_input_weight tensor - {n_cell, n_input}, // input_to_forget_weight tensor - {n_cell, n_input}, // input_to_cell_weight tensor - {n_cell, n_input}, // input_to_output_weight tensor - + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor {0, 0}, // recurrent_to_input_weight tensor {n_cell, n_output}, // recurrent_to_forget_weight tensor {n_cell, n_output}, // recurrent_to_cell_weight tensor @@ -905,7 +906,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest, {n_batch, n_output}, // activation_state tensor {n_batch, n_cell}, // cell_state tensor }, - TensorType_UINT8); + TensorType_UINT8, GetParam()); lstm.SetInputToCellWeights(input_to_cell_weights_); 
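HybridUnidirectionalLSTMOpModel above forwards the flag into the generated CreateUnidirectionalSequenceLSTMOptions helper as the trailing argument; judging by the generated code for the other options tables in this diff, that argument presumably defaults to false so existing callers are unaffected. A minimal sketch of that call in isolation; the wrapper function below is illustrative only:

#include "flatbuffers/flatbuffers.h"
#include "tensorflow/lite/schema/schema_generated.h"

// Builds UnidirectionalSequenceLSTMOptions with the new flag set explicitly,
// mirroring how the test model constructor passes it through.
flatbuffers::Offset<tflite::UnidirectionalSequenceLSTMOptions>
BuildSequenceLstmOptions(flatbuffers::FlatBufferBuilder& builder,
                         bool asymmetric_quantize_inputs) {
  return tflite::CreateUnidirectionalSequenceLSTMOptions(
      builder, tflite::ActivationFunctionType_TANH,
      /*cell_clip=*/0.0f, /*proj_clip=*/0.0f,
      /*time_major=*/true, asymmetric_quantize_inputs);
}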
lstm.SetInputToForgetWeights(input_to_forget_weights_); @@ -925,7 +926,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest, VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573); } -TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest, +TEST_P(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest, HybridLstmBlackBoxTestInt8) { const int n_batch = 1; const int n_input = 2; @@ -968,7 +969,7 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest, {n_batch, n_output}, // activation_state tensor {n_batch, n_cell}, // cell_state tensor }, - TensorType_INT8); + TensorType_INT8, GetParam()); lstm.SetInputToCellWeights(input_to_cell_weights_); lstm.SetInputToForgetWeights(input_to_forget_weights_); @@ -1655,14 +1656,16 @@ TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest, VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); } -TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest, +TEST_P(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest, HybridLstmBlackBoxTestUint8) { const int n_batch = 2; const int n_input = 5; const int n_cell = 20; const int n_output = 16; const int sequence_length = 4; - + if (GetParam()) { + return; + } HybridUnidirectionalLSTMOpModel lstm( n_batch, n_input, n_cell, n_output, sequence_length, /*time_major=*/true, /*use_cifg=*/false, /*use_peephole=*/true, @@ -1697,7 +1700,7 @@ TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest, {n_batch, n_output}, // activation_state tensor {n_batch, n_cell}, // cell_state tensor }, - TensorType_UINT8); + TensorType_UINT8, GetParam()); lstm.SetInputToInputWeights(input_to_input_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_); @@ -1723,8 +1726,11 @@ TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest, VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467); } -TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest, +TEST_P(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest, HybridLstmBlackBoxTestInt8) { + if (GetParam()) { + return; + } const int n_batch = 2; const int n_input = 5; const int n_cell = 20; @@ -1765,7 +1771,7 @@ TEST_F(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest, {n_batch, n_output}, // activation_state tensor {n_batch, n_cell}, // cell_state tensor }, - TensorType_INT8); + TensorType_INT8, GetParam()); lstm.SetInputToInputWeights(input_to_input_weights_); lstm.SetInputToCellWeights(input_to_cell_weights_); @@ -2737,5 +2743,14 @@ TEST_F(CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest, VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); } +#define QUANTIZE_PARAMETER_TEST(test) \ + INSTANTIATE_TEST_SUITE_P(test, test, ::testing::ValuesIn({false, true})); + +QUANTIZE_PARAMETER_TEST( + CifgPeepholeNoProjectionNoClippingUnidirectionalLstmTest); +QUANTIZE_PARAMETER_TEST( + NoCifgNoPeepholeNoProjectionNoClippingUnidirectionalLstmTest); +QUANTIZE_PARAMETER_TEST(NoCifgPeepholeProjectionClippingUnidirectionalLstmTest); +#undef QUANTIZE_PARAMETER_TEST } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/lite/kernels/unidirectional_sequence_rnn.cc index 47c778185d4..7ed67c1614d 100644 --- a/tensorflow/lite/kernels/unidirectional_sequence_rnn.cc +++ b/tensorflow/lite/kernels/unidirectional_sequence_rnn.cc @@ -26,6 +26,15 @@ namespace ops { namespace builtin { namespace unidirectional_sequence_rnn { +namespace { + +struct OpData { + int scratch_tensor_index; + bool compute_row_sums = false; +}; 
+ +} // namespace + // Input tensors. constexpr int kInputTensor = 0; constexpr int kWeightsTensor = 1; @@ -37,13 +46,14 @@ constexpr int kHiddenStateTensor = 4; constexpr int kOutputTensor = 0; void* Init(TfLiteContext* context, const char* buffer, size_t length) { - auto* scratch_tensor_index = new int; - context->AddTensors(context, /*tensors_to_add=*/3, scratch_tensor_index); - return scratch_tensor_index; + auto* op_data = new OpData(); + context->AddTensors(context, /*tensors_to_add=*/6, + &op_data->scratch_tensor_index); + return op_data; } void Free(TfLiteContext* context, void* buffer) { - delete reinterpret_cast(buffer); + delete reinterpret_cast(buffer); } TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { @@ -96,10 +106,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Allocate temporary tensors to store quantized values of input and // hidden_state tensors. if (is_hybrid) { - int* scratch_tensor_index = reinterpret_cast(node->user_data); + auto* op_data = reinterpret_cast(node->user_data); + op_data->compute_row_sums = true; TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(3); - node->temporaries->data[0] = *scratch_tensor_index; + node->temporaries = TfLiteIntArrayCreate(6); + node->temporaries->data[0] = op_data->scratch_tensor_index; TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0); input_quantized->type = input_weights->type; input_quantized->allocation_type = kTfLiteArenaRw; @@ -108,7 +119,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, input_quantized_size)); } - node->temporaries->data[1] = *scratch_tensor_index + 1; + node->temporaries->data[1] = op_data->scratch_tensor_index + 1; TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, /*index=*/1); hidden_state_quantized->type = input_weights->type; @@ -121,7 +132,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { context->ResizeTensor(context, hidden_state_quantized, hidden_state_quantized_size)); } - node->temporaries->data[2] = *scratch_tensor_index + 2; + node->temporaries->data[2] = op_data->scratch_tensor_index + 2; TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2); scaling_factors->type = kTfLiteFloat32; scaling_factors->allocation_type = kTfLiteArenaRw; @@ -132,6 +143,42 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, scaling_factors_size)); } + node->temporaries->data[3] = op_data->scratch_tensor_index + 3; + TfLiteTensor* accum_scratch = GetTemporary(context, node, /*index=*/3); + accum_scratch->type = kTfLiteInt32; + accum_scratch->allocation_type = kTfLiteArenaRw; + int accum_scratch_dims[2] = {num_units, batch_size}; + if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2, + accum_scratch_dims)) { + TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2); + accum_scratch_size->data[0] = accum_scratch_dims[0]; + accum_scratch_size->data[1] = accum_scratch_dims[1]; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch, + accum_scratch_size)); + } + node->temporaries->data[4] = op_data->scratch_tensor_index + 4; + TfLiteTensor* zero_points = GetTemporary(context, node, /*index=*/4); + zero_points->type = kTfLiteInt32; + zero_points->allocation_type = kTfLiteArenaRw; + int zero_points_dims[1] = {batch_size}; + if 
(!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) { + TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1); + zero_points_size->data[0] = batch_size; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points, + zero_points_size)); + } + node->temporaries->data[5] = op_data->scratch_tensor_index + 5; + TfLiteTensor* row_sums = GetTemporary(context, node, /*index=*/5); + row_sums->type = kTfLiteInt32; + row_sums->allocation_type = kTfLiteArenaRwPersistent; + int row_sums_dims[2] = {2, num_units}; + if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) { + TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2); + row_sums_size->data[0] = row_sums_dims[0]; + row_sums_size->data[1] = row_sums_dims[1]; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, row_sums, row_sums_size)); + } } return kTfLiteOk; } @@ -202,7 +249,9 @@ TfLiteStatus EvalHybrid( const TfLiteTensor* recurrent_weights, const TfLiteTensor* bias, const TfLiteSequenceRNNParams* params, TfLiteTensor* input_scratch, TfLiteTensor* hidden_state_scratch, TfLiteTensor* scaling_factors, - TfLiteTensor* hidden_state, TfLiteTensor* output) { + TfLiteTensor* hidden_state, TfLiteTensor* output, TfLiteTensor* zero_points, + TfLiteTensor* accum_scratch, TfLiteTensor* row_sums, + bool* compute_row_sums) { const bool time_major = params->time_major; const int batch_size = (time_major) ? input->dims->data[1] : input->dims->data[0]; @@ -227,6 +276,14 @@ TfLiteStatus EvalHybrid( float input_weights_scale = input_weights->params.scale; float recurrent_weights_scale = recurrent_weights->params.scale; float* scaling_factors_ptr = GetTensorData(scaling_factors); + int32_t* accum_scratch_ptr = GetTensorData(accum_scratch); + int32_t* zero_points_ptr = nullptr; + int32_t* row_sums_ptr = nullptr; + + if (params->asymmetric_quantize_inputs) { + zero_points_ptr = GetTensorData(zero_points); + row_sums_ptr = GetTensorData(row_sums); + } if (time_major) { // Initialize the pointer to hidden state. 
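EvalHybrid in unidirectional_sequence_rnn.cc only fills zero_points_ptr and row_sums_ptr when params->asymmetric_quantize_inputs is set; in that mode each input row is quantized with its own scale and zero point rather than a symmetric scale alone. A hedged sketch of what that per-row asymmetric quantization amounts to (the real kernels rely on shared tensor_utils helpers; this standalone function is illustrative only):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Quantize one row of floats to int8 with an asymmetric (scale, zero_point)
// pair chosen from the row's own min/max range, extended to include 0.
void AsymmetricQuantizeRow(const float* values, int size, int8_t* quantized,
                           float* scaling_factor, int32_t* zero_point) {
  const auto minmax = std::minmax_element(values, values + size);
  const float rmin = std::min(0.0f, *minmax.first);
  const float rmax = std::max(0.0f, *minmax.second);
  if (rmin == rmax) {  // all-zero row: nothing to scale
    std::fill(quantized, quantized + size, 0);
    *scaling_factor = 1.0f;
    *zero_point = 0;
    return;
  }
  const float qmin = -128.0f;
  const float qmax = 127.0f;
  const float scale = (rmax - rmin) / (qmax - qmin);
  const int32_t zp = static_cast<int32_t>(std::round(qmin - rmin / scale));
  for (int i = 0; i < size; ++i) {
    const int32_t q =
        static_cast<int32_t>(std::round(values[i] / scale)) + zp;
    quantized[i] = static_cast<int8_t>(std::min(127, std::max(-128, q)));
  }
  *scaling_factor = scale;
  *zero_point = zp;
}

When the flag is false, zero_points_ptr and row_sums_ptr stay null and the existing symmetric hybrid path is unchanged.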
@@ -244,7 +301,9 @@ TfLiteStatus EvalHybrid( recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size, num_units, batch_size, num_units, params->activation, quantized_input_ptr, quantized_hidden_state_ptr, scaling_factors_ptr, - hidden_state_ptr_batch, output_ptr_batch); + hidden_state_ptr_batch, output_ptr_batch, + params->asymmetric_quantize_inputs, zero_points_ptr, + accum_scratch_ptr, row_sums_ptr, compute_row_sums); } } else { // For each batch @@ -259,13 +318,14 @@ TfLiteStatus EvalHybrid( s * input_size; float* output_ptr_batch = GetTensorData(output) + b * num_units * max_time + s * num_units; - kernel_utils::RnnBatchStep( input_ptr_batch, input_weights_ptr, input_weights_scale, recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size, num_units, /*batch_size=*/1, num_units, params->activation, quantized_input_ptr, quantized_hidden_state_ptr, - scaling_factors_ptr, hidden_state_ptr_batch, output_ptr_batch); + scaling_factors_ptr, hidden_state_ptr_batch, output_ptr_batch, + params->asymmetric_quantize_inputs, zero_points_ptr, + accum_scratch_ptr, row_sums_ptr, compute_row_sums); } } } @@ -274,7 +334,6 @@ TfLiteStatus EvalHybrid( TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); - const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor); const TfLiteTensor* recurrent_weights = @@ -292,12 +351,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteUInt8: case kTfLiteInt8: { // TODO(mirkov): implement eval with quantized inputs as well. + auto* op_data = reinterpret_cast(node->user_data); TfLiteTensor* input_quantized = GetTemporary(context, node, 0); TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1); TfLiteTensor* scaling_factors = GetTemporary(context, node, 2); + TfLiteTensor* accum_scratch = GetTemporary(context, node, 3); + TfLiteTensor* zero_points = GetTemporary(context, node, 4); + TfLiteTensor* row_sums = GetTemporary(context, node, 5); return EvalHybrid(input, input_weights, recurrent_weights, bias, params, input_quantized, hidden_state_quantized, - scaling_factors, hidden_state, output); + scaling_factors, hidden_state, output, zero_points, + accum_scratch, row_sums, &op_data->compute_row_sums); } default: context->ReportError(context, "Type %d not currently supported.", diff --git a/tensorflow/lite/kernels/unidirectional_sequence_rnn_test.cc b/tensorflow/lite/kernels/unidirectional_sequence_rnn_test.cc index 7e520ee9739..8b6f102acdb 100644 --- a/tensorflow/lite/kernels/unidirectional_sequence_rnn_test.cc +++ b/tensorflow/lite/kernels/unidirectional_sequence_rnn_test.cc @@ -174,7 +174,8 @@ class UnidirectionalRNNOpModel : public SingleOpModel { UnidirectionalRNNOpModel( int batches, int sequence_len, int units, int size, bool time_major, const TensorType& weights = TensorType_FLOAT32, - const TensorType& recurrent_weights = TensorType_FLOAT32) + const TensorType& recurrent_weights = TensorType_FLOAT32, + bool asymmetric_quantize_inputs = false) : batches_(batches), sequence_len_(sequence_len), units_(units), @@ -188,7 +189,8 @@ class UnidirectionalRNNOpModel : public SingleOpModel { SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, BuiltinOptions_SequenceRNNOptions, CreateSequenceRNNOptions(builder_, time_major, - ActivationFunctionType_RELU) + ActivationFunctionType_RELU, + asymmetric_quantize_inputs) .Union()); if (time_major) { 
BuildInterpreter({{sequence_len_, batches_, input_size_}, @@ -249,9 +251,11 @@ class HybridUnidirectionalRNNOpModel : public UnidirectionalRNNOpModel { public: HybridUnidirectionalRNNOpModel(int batches, int sequence_len, int units, int size, bool time_major, - TensorType tensor_type) + TensorType tensor_type, + bool asymmetric_quantize_inputs) : UnidirectionalRNNOpModel(batches, sequence_len, units, size, time_major, - tensor_type, tensor_type) { + tensor_type, tensor_type, + asymmetric_quantize_inputs) { tensor_type_ = tensor_type; } @@ -297,10 +301,14 @@ TEST(UnidirectionalRNNOpTest, BlackBoxTest) { EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); } -TEST(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTestUint8) { +class HybridUnidirectionalRNNOpModelOpTest + : public ::testing::TestWithParam {}; + +TEST_P(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTestUint8) { HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, /*units=*/16, /*size=*/8, - /*time_major=*/false, TensorType_UINT8); + /*time_major=*/false, TensorType_UINT8, + GetParam()); rnn.SetWeights(rnn_weights); rnn.SetBias(rnn_bias); rnn.SetRecurrentWeights(rnn_recurrent_weights); @@ -323,10 +331,11 @@ TEST(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTestUint8) { expected, /*max_abs_error=*/0.013))); } -TEST(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTestInt8) { +TEST_P(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTestInt8) { HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, /*units=*/16, /*size=*/8, - /*time_major=*/false, TensorType_INT8); + /*time_major=*/false, TensorType_INT8, + GetParam()); rnn.SetWeights(rnn_weights); rnn.SetBias(rnn_bias); rnn.SetRecurrentWeights(rnn_recurrent_weights); @@ -378,10 +387,11 @@ TEST(UnidirectionalRNNOpTest, TimeMajorBlackBoxTest) { EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); } -TEST(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestUint8) { +TEST_P(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestUint8) { HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, /*units=*/16, /*size=*/8, - /*time_major=*/true, TensorType_UINT8); + /*time_major=*/true, TensorType_UINT8, + GetParam()); rnn.SetWeights(rnn_weights); rnn.SetBias(rnn_bias); rnn.SetRecurrentWeights(rnn_recurrent_weights); @@ -408,10 +418,11 @@ TEST(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestUint8) { expected, /*max_abs_error=*/0.013))); } -TEST(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestInt8) { +TEST_P(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestInt8) { HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, /*units=*/16, /*size=*/8, - /*time_major=*/true, TensorType_INT8); + /*time_major=*/true, TensorType_INT8, + GetParam()); rnn.SetWeights(rnn_weights); rnn.SetBias(rnn_bias); rnn.SetRecurrentWeights(rnn_recurrent_weights); @@ -438,5 +449,9 @@ TEST(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTestInt8) { expected, /*max_abs_error=*/0.013))); } +INSTANTIATE_TEST_SUITE_P(HybridUnidirectionalRNNOpModelOpTest, + HybridUnidirectionalRNNOpModelOpTest, + ::testing::ValuesIn({true, false})); + } // namespace } // namespace tflite diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs index 24cd73eef7a..7f2a5cbf73a 100644 --- a/tensorflow/lite/schema/schema.fbs +++ b/tensorflow/lite/schema/schema.fbs @@ -519,17 +519,22 @@ table LSHProjectionOptions { table SVDFOptions { rank:int; 
fused_activation_function:ActivationFunctionType; + // For weights-only quantization, use asymmetric quantization for non + // constant inputs at evaluation time. + asymmetric_quantize_inputs:bool; } // An implementation of TensorFlow RNNCell. table RNNOptions { fused_activation_function:ActivationFunctionType; + asymmetric_quantize_inputs:bool; } // An implementation of TensorFlow dynamic_rnn with RNNCell. table SequenceRNNOptions { time_major:bool; fused_activation_function:ActivationFunctionType; + asymmetric_quantize_inputs:bool; } // An implementation of TensorFlow bidrectional_dynamic_rnn with RNNCell. @@ -537,6 +542,7 @@ table BidirectionalSequenceRNNOptions { time_major:bool; fused_activation_function:ActivationFunctionType; merge_outputs: bool; + asymmetric_quantize_inputs:bool; } enum FullyConnectedOptionsWeightsFormat: byte { @@ -556,6 +562,11 @@ table FullyConnectedOptions { // If set to true, then the number of dimension is preserved. Furthermore, // all but the last dimension of the input and output shapes will be equal. keep_num_dims: bool; + + // Parameters for FullyConnected version 7 or above. + // If set to true, then weights-only op will use asymmetric quantization for + // inputs. + asymmetric_quantize_inputs: bool; } table SoftmaxOptions { @@ -604,6 +615,9 @@ table LSTMOptions { // Parameters for LSTM version 2 or above. // Basic kernel is only supported in version 2 or above. kernel_type: LSTMKernelType = FULL; + + // Parameters for LSTM version 4 or above. + asymmetric_quantize_inputs: bool; } // An implementation of TensorFlow dynamic_rnn with LSTMCell. @@ -614,6 +628,9 @@ table UnidirectionalSequenceLSTMOptions { // If true then first dimension is sequence, otherwise batch. time_major:bool; + + // Parameter for Unidirectional Sequence LSTM version 4. + asymmetric_quantize_inputs:bool; } table BidirectionalSequenceLSTMOptions { @@ -630,6 +647,9 @@ table BidirectionalSequenceLSTMOptions { // Version 1 implementations assumed time_major to be true, so this default // value should never change. time_major: bool = true; + + // Parameters for version 3 or above. 
+ asymmetric_quantize_inputs:bool; } table ResizeBilinearOptions { diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h index 609eac198fb..ea3d4f61718 100755 --- a/tensorflow/lite/schema/schema_generated.h +++ b/tensorflow/lite/schema/schema_generated.h @@ -4216,9 +4216,11 @@ struct SVDFOptionsT : public flatbuffers::NativeTable { typedef SVDFOptions TableType; int32_t rank; tflite::ActivationFunctionType fused_activation_function; + bool asymmetric_quantize_inputs; SVDFOptionsT() : rank(0), - fused_activation_function(tflite::ActivationFunctionType_NONE) { + fused_activation_function(tflite::ActivationFunctionType_NONE), + asymmetric_quantize_inputs(false) { } }; @@ -4226,7 +4228,8 @@ struct SVDFOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef SVDFOptionsT NativeTableType; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_RANK = 4, - VT_FUSED_ACTIVATION_FUNCTION = 6 + VT_FUSED_ACTIVATION_FUNCTION = 6, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 8 }; int32_t rank() const { return GetField(VT_RANK, 0); @@ -4234,10 +4237,14 @@ struct SVDFOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { tflite::ActivationFunctionType fused_activation_function() const { return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } + bool asymmetric_quantize_inputs() const { + return GetField(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_RANK) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + VerifyField(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) && verifier.EndTable(); } SVDFOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -4254,6 +4261,9 @@ struct SVDFOptionsBuilder { void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { fbb_.AddElement(SVDFOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) { + fbb_.AddElement(SVDFOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast(asymmetric_quantize_inputs), 0); + } explicit SVDFOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -4269,9 +4279,11 @@ struct SVDFOptionsBuilder { inline flatbuffers::Offset CreateSVDFOptions( flatbuffers::FlatBufferBuilder &_fbb, int32_t rank = 0, - tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) { + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, + bool asymmetric_quantize_inputs = false) { SVDFOptionsBuilder builder_(_fbb); builder_.add_rank(rank); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); } @@ -4281,22 +4293,29 @@ flatbuffers::Offset CreateSVDFOptions(flatbuffers::FlatBufferBuilde struct RNNOptionsT : public flatbuffers::NativeTable { typedef RNNOptions TableType; tflite::ActivationFunctionType fused_activation_function; + bool asymmetric_quantize_inputs; RNNOptionsT() - : fused_activation_function(tflite::ActivationFunctionType_NONE) { + : fused_activation_function(tflite::ActivationFunctionType_NONE), + asymmetric_quantize_inputs(false) { } }; struct RNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef RNNOptionsT NativeTableType; enum FlatBuffersVTableOffset 
FLATBUFFERS_VTABLE_UNDERLYING_TYPE { - VT_FUSED_ACTIVATION_FUNCTION = 4 + VT_FUSED_ACTIVATION_FUNCTION = 4, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 6 }; tflite::ActivationFunctionType fused_activation_function() const { return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } + bool asymmetric_quantize_inputs() const { + return GetField(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + VerifyField(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) && verifier.EndTable(); } RNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -4310,6 +4329,9 @@ struct RNNOptionsBuilder { void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { fbb_.AddElement(RNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) { + fbb_.AddElement(RNNOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast(asymmetric_quantize_inputs), 0); + } explicit RNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -4324,8 +4346,10 @@ struct RNNOptionsBuilder { inline flatbuffers::Offset CreateRNNOptions( flatbuffers::FlatBufferBuilder &_fbb, - tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) { + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, + bool asymmetric_quantize_inputs = false) { RNNOptionsBuilder builder_(_fbb); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); } @@ -4336,9 +4360,11 @@ struct SequenceRNNOptionsT : public flatbuffers::NativeTable { typedef SequenceRNNOptions TableType; bool time_major; tflite::ActivationFunctionType fused_activation_function; + bool asymmetric_quantize_inputs; SequenceRNNOptionsT() : time_major(false), - fused_activation_function(tflite::ActivationFunctionType_NONE) { + fused_activation_function(tflite::ActivationFunctionType_NONE), + asymmetric_quantize_inputs(false) { } }; @@ -4346,7 +4372,8 @@ struct SequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef SequenceRNNOptionsT NativeTableType; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_TIME_MAJOR = 4, - VT_FUSED_ACTIVATION_FUNCTION = 6 + VT_FUSED_ACTIVATION_FUNCTION = 6, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 8 }; bool time_major() const { return GetField(VT_TIME_MAJOR, 0) != 0; @@ -4354,10 +4381,14 @@ struct SequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { tflite::ActivationFunctionType fused_activation_function() const { return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } + bool asymmetric_quantize_inputs() const { + return GetField(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_TIME_MAJOR) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + VerifyField(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) && verifier.EndTable(); } SequenceRNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -4374,6 +4405,9 @@ struct SequenceRNNOptionsBuilder { void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { 
fbb_.AddElement(SequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) { + fbb_.AddElement(SequenceRNNOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast(asymmetric_quantize_inputs), 0); + } explicit SequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -4389,8 +4423,10 @@ struct SequenceRNNOptionsBuilder { inline flatbuffers::Offset CreateSequenceRNNOptions( flatbuffers::FlatBufferBuilder &_fbb, bool time_major = false, - tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) { + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, + bool asymmetric_quantize_inputs = false) { SequenceRNNOptionsBuilder builder_(_fbb); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); builder_.add_fused_activation_function(fused_activation_function); builder_.add_time_major(time_major); return builder_.Finish(); @@ -4403,10 +4439,12 @@ struct BidirectionalSequenceRNNOptionsT : public flatbuffers::NativeTable { bool time_major; tflite::ActivationFunctionType fused_activation_function; bool merge_outputs; + bool asymmetric_quantize_inputs; BidirectionalSequenceRNNOptionsT() : time_major(false), fused_activation_function(tflite::ActivationFunctionType_NONE), - merge_outputs(false) { + merge_outputs(false), + asymmetric_quantize_inputs(false) { } }; @@ -4415,7 +4453,8 @@ struct BidirectionalSequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuf enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_TIME_MAJOR = 4, VT_FUSED_ACTIVATION_FUNCTION = 6, - VT_MERGE_OUTPUTS = 8 + VT_MERGE_OUTPUTS = 8, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 10 }; bool time_major() const { return GetField(VT_TIME_MAJOR, 0) != 0; @@ -4426,11 +4465,15 @@ struct BidirectionalSequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuf bool merge_outputs() const { return GetField(VT_MERGE_OUTPUTS, 0) != 0; } + bool asymmetric_quantize_inputs() const { + return GetField(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_TIME_MAJOR) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField(verifier, VT_MERGE_OUTPUTS) && + VerifyField(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) && verifier.EndTable(); } BidirectionalSequenceRNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -4450,6 +4493,9 @@ struct BidirectionalSequenceRNNOptionsBuilder { void add_merge_outputs(bool merge_outputs) { fbb_.AddElement(BidirectionalSequenceRNNOptions::VT_MERGE_OUTPUTS, static_cast(merge_outputs), 0); } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) { + fbb_.AddElement(BidirectionalSequenceRNNOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast(asymmetric_quantize_inputs), 0); + } explicit BidirectionalSequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -4466,8 +4512,10 @@ inline flatbuffers::Offset CreateBidirectionalS flatbuffers::FlatBufferBuilder &_fbb, bool time_major = false, tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, - bool merge_outputs = false) { + bool merge_outputs = false, + bool asymmetric_quantize_inputs = false) { BidirectionalSequenceRNNOptionsBuilder builder_(_fbb); + 
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); builder_.add_merge_outputs(merge_outputs); builder_.add_fused_activation_function(fused_activation_function); builder_.add_time_major(time_major); @@ -4481,10 +4529,12 @@ struct FullyConnectedOptionsT : public flatbuffers::NativeTable { tflite::ActivationFunctionType fused_activation_function; tflite::FullyConnectedOptionsWeightsFormat weights_format; bool keep_num_dims; + bool asymmetric_quantize_inputs; FullyConnectedOptionsT() : fused_activation_function(tflite::ActivationFunctionType_NONE), weights_format(tflite::FullyConnectedOptionsWeightsFormat_DEFAULT), - keep_num_dims(false) { + keep_num_dims(false), + asymmetric_quantize_inputs(false) { } }; @@ -4493,7 +4543,8 @@ struct FullyConnectedOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tabl enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_FUSED_ACTIVATION_FUNCTION = 4, VT_WEIGHTS_FORMAT = 6, - VT_KEEP_NUM_DIMS = 8 + VT_KEEP_NUM_DIMS = 8, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 10 }; tflite::ActivationFunctionType fused_activation_function() const { return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); @@ -4504,11 +4555,15 @@ struct FullyConnectedOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tabl bool keep_num_dims() const { return GetField(VT_KEEP_NUM_DIMS, 0) != 0; } + bool asymmetric_quantize_inputs() const { + return GetField(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField(verifier, VT_WEIGHTS_FORMAT) && VerifyField(verifier, VT_KEEP_NUM_DIMS) && + VerifyField(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) && verifier.EndTable(); } FullyConnectedOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -4528,6 +4583,9 @@ struct FullyConnectedOptionsBuilder { void add_keep_num_dims(bool keep_num_dims) { fbb_.AddElement(FullyConnectedOptions::VT_KEEP_NUM_DIMS, static_cast(keep_num_dims), 0); } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) { + fbb_.AddElement(FullyConnectedOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast(asymmetric_quantize_inputs), 0); + } explicit FullyConnectedOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -4544,8 +4602,10 @@ inline flatbuffers::Offset CreateFullyConnectedOptions( flatbuffers::FlatBufferBuilder &_fbb, tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, tflite::FullyConnectedOptionsWeightsFormat weights_format = tflite::FullyConnectedOptionsWeightsFormat_DEFAULT, - bool keep_num_dims = false) { + bool keep_num_dims = false, + bool asymmetric_quantize_inputs = false) { FullyConnectedOptionsBuilder builder_(_fbb); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); builder_.add_keep_num_dims(keep_num_dims); builder_.add_weights_format(weights_format); builder_.add_fused_activation_function(fused_activation_function); @@ -4932,11 +4992,13 @@ struct LSTMOptionsT : public flatbuffers::NativeTable { float cell_clip; float proj_clip; tflite::LSTMKernelType kernel_type; + bool asymmetric_quantize_inputs; LSTMOptionsT() : fused_activation_function(tflite::ActivationFunctionType_NONE), cell_clip(0.0f), proj_clip(0.0f), - kernel_type(tflite::LSTMKernelType_FULL) { + kernel_type(tflite::LSTMKernelType_FULL), + asymmetric_quantize_inputs(false) { } }; @@ -4946,7 +5008,8 @@ struct LSTMOptions 
FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_FUSED_ACTIVATION_FUNCTION = 4, VT_CELL_CLIP = 6, VT_PROJ_CLIP = 8, - VT_KERNEL_TYPE = 10 + VT_KERNEL_TYPE = 10, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 12 }; tflite::ActivationFunctionType fused_activation_function() const { return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); @@ -4960,12 +5023,16 @@ struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { tflite::LSTMKernelType kernel_type() const { return static_cast(GetField(VT_KERNEL_TYPE, 0)); } + bool asymmetric_quantize_inputs() const { + return GetField(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField(verifier, VT_CELL_CLIP) && VerifyField(verifier, VT_PROJ_CLIP) && VerifyField(verifier, VT_KERNEL_TYPE) && + VerifyField(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) && verifier.EndTable(); } LSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -4988,6 +5055,9 @@ struct LSTMOptionsBuilder { void add_kernel_type(tflite::LSTMKernelType kernel_type) { fbb_.AddElement(LSTMOptions::VT_KERNEL_TYPE, static_cast(kernel_type), 0); } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) { + fbb_.AddElement(LSTMOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast(asymmetric_quantize_inputs), 0); + } explicit LSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -5005,10 +5075,12 @@ inline flatbuffers::Offset CreateLSTMOptions( tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, float cell_clip = 0.0f, float proj_clip = 0.0f, - tflite::LSTMKernelType kernel_type = tflite::LSTMKernelType_FULL) { + tflite::LSTMKernelType kernel_type = tflite::LSTMKernelType_FULL, + bool asymmetric_quantize_inputs = false) { LSTMOptionsBuilder builder_(_fbb); builder_.add_proj_clip(proj_clip); builder_.add_cell_clip(cell_clip); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); builder_.add_kernel_type(kernel_type); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); @@ -5022,11 +5094,13 @@ struct UnidirectionalSequenceLSTMOptionsT : public flatbuffers::NativeTable { float cell_clip; float proj_clip; bool time_major; + bool asymmetric_quantize_inputs; UnidirectionalSequenceLSTMOptionsT() : fused_activation_function(tflite::ActivationFunctionType_NONE), cell_clip(0.0f), proj_clip(0.0f), - time_major(false) { + time_major(false), + asymmetric_quantize_inputs(false) { } }; @@ -5036,7 +5110,8 @@ struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatb VT_FUSED_ACTIVATION_FUNCTION = 4, VT_CELL_CLIP = 6, VT_PROJ_CLIP = 8, - VT_TIME_MAJOR = 10 + VT_TIME_MAJOR = 10, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 12 }; tflite::ActivationFunctionType fused_activation_function() const { return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); @@ -5050,12 +5125,16 @@ struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatb bool time_major() const { return GetField(VT_TIME_MAJOR, 0) != 0; } + bool asymmetric_quantize_inputs() const { + return GetField(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField(verifier, VT_CELL_CLIP) && VerifyField(verifier, VT_PROJ_CLIP) 
&& VerifyField(verifier, VT_TIME_MAJOR) && + VerifyField(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) && verifier.EndTable(); } UnidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -5078,6 +5157,9 @@ struct UnidirectionalSequenceLSTMOptionsBuilder { void add_time_major(bool time_major) { fbb_.AddElement(UnidirectionalSequenceLSTMOptions::VT_TIME_MAJOR, static_cast(time_major), 0); } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) { + fbb_.AddElement(UnidirectionalSequenceLSTMOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast(asymmetric_quantize_inputs), 0); + } explicit UnidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -5095,10 +5177,12 @@ inline flatbuffers::Offset CreateUnidirection tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, float cell_clip = 0.0f, float proj_clip = 0.0f, - bool time_major = false) { + bool time_major = false, + bool asymmetric_quantize_inputs = false) { UnidirectionalSequenceLSTMOptionsBuilder builder_(_fbb); builder_.add_proj_clip(proj_clip); builder_.add_cell_clip(cell_clip); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); builder_.add_time_major(time_major); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); @@ -5113,12 +5197,14 @@ struct BidirectionalSequenceLSTMOptionsT : public flatbuffers::NativeTable { float proj_clip; bool merge_outputs; bool time_major; + bool asymmetric_quantize_inputs; BidirectionalSequenceLSTMOptionsT() : fused_activation_function(tflite::ActivationFunctionType_NONE), cell_clip(0.0f), proj_clip(0.0f), merge_outputs(false), - time_major(true) { + time_major(true), + asymmetric_quantize_inputs(false) { } }; @@ -5129,7 +5215,8 @@ struct BidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbu VT_CELL_CLIP = 6, VT_PROJ_CLIP = 8, VT_MERGE_OUTPUTS = 10, - VT_TIME_MAJOR = 12 + VT_TIME_MAJOR = 12, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 14 }; tflite::ActivationFunctionType fused_activation_function() const { return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); @@ -5146,6 +5233,9 @@ struct BidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbu bool time_major() const { return GetField(VT_TIME_MAJOR, 1) != 0; } + bool asymmetric_quantize_inputs() const { + return GetField(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && @@ -5153,6 +5243,7 @@ struct BidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbu VerifyField(verifier, VT_PROJ_CLIP) && VerifyField(verifier, VT_MERGE_OUTPUTS) && VerifyField(verifier, VT_TIME_MAJOR) && + VerifyField(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS) && verifier.EndTable(); } BidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -5178,6 +5269,9 @@ struct BidirectionalSequenceLSTMOptionsBuilder { void add_time_major(bool time_major) { fbb_.AddElement(BidirectionalSequenceLSTMOptions::VT_TIME_MAJOR, static_cast(time_major), 1); } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) { + fbb_.AddElement(BidirectionalSequenceLSTMOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast(asymmetric_quantize_inputs), 0); + } explicit 
BidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -5196,10 +5290,12 @@ inline flatbuffers::Offset CreateBidirectional float cell_clip = 0.0f, float proj_clip = 0.0f, bool merge_outputs = false, - bool time_major = true) { + bool time_major = true, + bool asymmetric_quantize_inputs = false) { BidirectionalSequenceLSTMOptionsBuilder builder_(_fbb); builder_.add_proj_clip(proj_clip); builder_.add_cell_clip(cell_clip); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); builder_.add_time_major(time_major); builder_.add_merge_outputs(merge_outputs); builder_.add_fused_activation_function(fused_activation_function); @@ -11034,6 +11130,7 @@ inline void SVDFOptions::UnPackTo(SVDFOptionsT *_o, const flatbuffers::resolver_ (void)_resolver; { auto _e = rank(); _o->rank = _e; } { auto _e = fused_activation_function(); _o->fused_activation_function = _e; } + { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; } } inline flatbuffers::Offset SVDFOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -11046,10 +11143,12 @@ inline flatbuffers::Offset CreateSVDFOptions(flatbuffers::FlatBuffe struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SVDFOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _rank = _o->rank; auto _fused_activation_function = _o->fused_activation_function; + auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs; return tflite::CreateSVDFOptions( _fbb, _rank, - _fused_activation_function); + _fused_activation_function, + _asymmetric_quantize_inputs); } inline RNNOptionsT *RNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -11062,6 +11161,7 @@ inline void RNNOptions::UnPackTo(RNNOptionsT *_o, const flatbuffers::resolver_fu (void)_o; (void)_resolver; { auto _e = fused_activation_function(); _o->fused_activation_function = _e; } + { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; } } inline flatbuffers::Offset RNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -11073,9 +11173,11 @@ inline flatbuffers::Offset CreateRNNOptions(flatbuffers::FlatBufferB (void)_o; struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const RNNOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _fused_activation_function = _o->fused_activation_function; + auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs; return tflite::CreateRNNOptions( _fbb, - _fused_activation_function); + _fused_activation_function, + _asymmetric_quantize_inputs); } inline SequenceRNNOptionsT *SequenceRNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -11089,6 +11191,7 @@ inline void SequenceRNNOptions::UnPackTo(SequenceRNNOptionsT *_o, const flatbuff (void)_resolver; { auto _e = time_major(); _o->time_major = _e; } { auto _e = fused_activation_function(); _o->fused_activation_function = _e; } + { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; } } inline flatbuffers::Offset SequenceRNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -11101,10 +11204,12 @@ inline flatbuffers::Offset 
CreateSequenceRNNOptions(flatbuff struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SequenceRNNOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _time_major = _o->time_major; auto _fused_activation_function = _o->fused_activation_function; + auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs; return tflite::CreateSequenceRNNOptions( _fbb, _time_major, - _fused_activation_function); + _fused_activation_function, + _asymmetric_quantize_inputs); } inline BidirectionalSequenceRNNOptionsT *BidirectionalSequenceRNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -11119,6 +11224,7 @@ inline void BidirectionalSequenceRNNOptions::UnPackTo(BidirectionalSequenceRNNOp { auto _e = time_major(); _o->time_major = _e; } { auto _e = fused_activation_function(); _o->fused_activation_function = _e; } { auto _e = merge_outputs(); _o->merge_outputs = _e; } + { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; } } inline flatbuffers::Offset BidirectionalSequenceRNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceRNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -11132,11 +11238,13 @@ inline flatbuffers::Offset CreateBidirectionalS auto _time_major = _o->time_major; auto _fused_activation_function = _o->fused_activation_function; auto _merge_outputs = _o->merge_outputs; + auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs; return tflite::CreateBidirectionalSequenceRNNOptions( _fbb, _time_major, _fused_activation_function, - _merge_outputs); + _merge_outputs, + _asymmetric_quantize_inputs); } inline FullyConnectedOptionsT *FullyConnectedOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -11151,6 +11259,7 @@ inline void FullyConnectedOptions::UnPackTo(FullyConnectedOptionsT *_o, const fl { auto _e = fused_activation_function(); _o->fused_activation_function = _e; } { auto _e = weights_format(); _o->weights_format = _e; } { auto _e = keep_num_dims(); _o->keep_num_dims = _e; } + { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; } } inline flatbuffers::Offset FullyConnectedOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -11164,11 +11273,13 @@ inline flatbuffers::Offset CreateFullyConnectedOptions(fl auto _fused_activation_function = _o->fused_activation_function; auto _weights_format = _o->weights_format; auto _keep_num_dims = _o->keep_num_dims; + auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs; return tflite::CreateFullyConnectedOptions( _fbb, _fused_activation_function, _weights_format, - _keep_num_dims); + _keep_num_dims, + _asymmetric_quantize_inputs); } inline SoftmaxOptionsT *SoftmaxOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -11352,6 +11463,7 @@ inline void LSTMOptions::UnPackTo(LSTMOptionsT *_o, const flatbuffers::resolver_ { auto _e = cell_clip(); _o->cell_clip = _e; } { auto _e = proj_clip(); _o->proj_clip = _e; } { auto _e = kernel_type(); _o->kernel_type = _e; } + { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; } } inline flatbuffers::Offset LSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -11366,12 +11478,14 @@ inline flatbuffers::Offset 
CreateLSTMOptions(flatbuffers::FlatBuffe auto _cell_clip = _o->cell_clip; auto _proj_clip = _o->proj_clip; auto _kernel_type = _o->kernel_type; + auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs; return tflite::CreateLSTMOptions( _fbb, _fused_activation_function, _cell_clip, _proj_clip, - _kernel_type); + _kernel_type, + _asymmetric_quantize_inputs); } inline UnidirectionalSequenceLSTMOptionsT *UnidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -11387,6 +11501,7 @@ inline void UnidirectionalSequenceLSTMOptions::UnPackTo(UnidirectionalSequenceLS { auto _e = cell_clip(); _o->cell_clip = _e; } { auto _e = proj_clip(); _o->proj_clip = _e; } { auto _e = time_major(); _o->time_major = _e; } + { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; } } inline flatbuffers::Offset UnidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -11401,12 +11516,14 @@ inline flatbuffers::Offset CreateUnidirection auto _cell_clip = _o->cell_clip; auto _proj_clip = _o->proj_clip; auto _time_major = _o->time_major; + auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs; return tflite::CreateUnidirectionalSequenceLSTMOptions( _fbb, _fused_activation_function, _cell_clip, _proj_clip, - _time_major); + _time_major, + _asymmetric_quantize_inputs); } inline BidirectionalSequenceLSTMOptionsT *BidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -11423,6 +11540,7 @@ inline void BidirectionalSequenceLSTMOptions::UnPackTo(BidirectionalSequenceLSTM { auto _e = proj_clip(); _o->proj_clip = _e; } { auto _e = merge_outputs(); _o->merge_outputs = _e; } { auto _e = time_major(); _o->time_major = _e; } + { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; } } inline flatbuffers::Offset BidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -11438,13 +11556,15 @@ inline flatbuffers::Offset CreateBidirectional auto _proj_clip = _o->proj_clip; auto _merge_outputs = _o->merge_outputs; auto _time_major = _o->time_major; + auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs; return tflite::CreateBidirectionalSequenceLSTMOptions( _fbb, _fused_activation_function, _cell_clip, _proj_clip, _merge_outputs, - _time_major); + _time_major, + _asymmetric_quantize_inputs); } inline ResizeBilinearOptionsT *ResizeBilinearOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
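
The regenerated Create* helpers above all gain `asymmetric_quantize_inputs` as a trailing parameter defaulting to `false`, so existing call sites keep compiling unchanged. A rough sketch of the write path follows (a sketch only: the helper name `BuildHybridFcOptions` and the choice of activation are illustrative, and a real `.tflite` model nests the options table inside an Operator rather than finishing it as the buffer root).

#include "flatbuffers/flatbuffers.h"
#include "tensorflow/lite/schema/schema_generated.h"

// Hypothetical helper: serialize FullyConnectedOptions that opt the hybrid
// kernel into asymmetric quantization of its non-constant inputs.
flatbuffers::Offset<tflite::FullyConnectedOptions> BuildHybridFcOptions(
    flatbuffers::FlatBufferBuilder& fbb) {
  return tflite::CreateFullyConnectedOptions(
      fbb, tflite::ActivationFunctionType_RELU,
      tflite::FullyConnectedOptionsWeightsFormat_DEFAULT,
      /*keep_num_dims=*/false,
      /*asymmetric_quantize_inputs=*/true);
}

int main() {
  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(BuildHybridFcOptions(fbb));  // finished as root only for the demo
  const auto* opts = flatbuffers::GetRoot<tflite::FullyConnectedOptions>(
      fbb.GetBufferPointer());
  return opts->asymmetric_quantize_inputs() ? 0 : 1;  // reads back true
}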
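
Because each new field is appended at the end of its table's vtable with a default of 0, buffers written before this schema change still verify, and the accessor simply returns `false`; older models therefore need no version check at read time. A small sketch of that read path, under the same simplification of finishing the options table as the root:

#include "flatbuffers/flatbuffers.h"
#include "tensorflow/lite/schema/schema_generated.h"

int main() {
  flatbuffers::FlatBufferBuilder fbb;
  // An "old" writer: the trailing asymmetric_quantize_inputs argument is
  // omitted, so the field is not written at all.
  fbb.Finish(tflite::CreateRNNOptions(fbb, tflite::ActivationFunctionType_NONE));

  const auto* opts =
      flatbuffers::GetRoot<tflite::RNNOptions>(fbb.GetBufferPointer());
  flatbuffers::Verifier verifier(fbb.GetBufferPointer(), fbb.GetSize());
  const bool still_valid = opts->Verify(verifier);       // passes
  const bool asym = opts->asymmetric_quantize_inputs();  // false by default
  return (still_valid && !asym) ? 0 : 1;
}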
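
The object-API plumbing (the UnPackTo/Pack changes above) carries the flag through the *T native tables as well. A minimal round-trip sketch, assuming the same regenerated header:

#include <memory>

#include "flatbuffers/flatbuffers.h"
#include "tensorflow/lite/schema/schema_generated.h"

int main() {
  tflite::LSTMOptionsT native;
  native.fused_activation_function = tflite::ActivationFunctionType_TANH;
  native.kernel_type = tflite::LSTMKernelType_FULL;
  native.asymmetric_quantize_inputs = true;

  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(tflite::LSTMOptions::Pack(fbb, &native));

  // UnPack allocates a fresh LSTMOptionsT; the new flag survives the trip.
  std::unique_ptr<tflite::LSTMOptionsT> round_trip(
      flatbuffers::GetRoot<tflite::LSTMOptions>(fbb.GetBufferPointer())
          ->UnPack());
  return round_trip->asymmetric_quantize_inputs ? 0 : 1;
}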