diff --git a/tensorflow/lite/kernels/internal/tensor_utils.h b/tensorflow/lite/kernels/internal/tensor_utils.h
index bb701ca87f5..e2af88d50e3 100644
--- a/tensorflow/lite/kernels/internal/tensor_utils.h
+++ b/tensorflow/lite/kernels/internal/tensor_utils.h
@@ -162,6 +162,27 @@ void MatrixBatchVectorMultiplyAccumulate(
     const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
     bool* compute_row_sums, CpuBackendContext* context);
 
+// Same as the function above, but provides a separate scaling factor for the
+// matrix and the vectors. The per-batch products of the two scaling factors
+// are stored in the scaling_factor_scratch buffer.
+inline void MatrixBatchVectorMultiplyAccumulate(
+    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
+    const int8_t* __restrict__ vectors, const float matrix_scaling_factor,
+    const float* vector_scaling_factors, int n_batch,
+    float* __restrict__ result, const float* per_channel_scale,
+    const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
+    bool* compute_row_sums, float* scaling_factor_scratch,
+    CpuBackendContext* context) {
+  for (int b = 0; b < n_batch; ++b) {
+    scaling_factor_scratch[b] =
+        vector_scaling_factors[b] * matrix_scaling_factor;
+  }
+  MatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
+                                      scaling_factor_scratch, n_batch, result,
+                                      per_channel_scale, input_offset, scratch,
+                                      row_sums, compute_row_sums, context);
+}
+
 // Same as the function above, but the matrix is stored in block compressed
 // sparse row format with block pattern 1x16 which consists of two arrays:
 //   1. A matrix array stores non-zero blocks of the matrix in row major.
diff --git a/tensorflow/lite/kernels/lstm_eval.cc b/tensorflow/lite/kernels/lstm_eval.cc
index bf4dc05ea8d..0a2c381ebf1 100644
--- a/tensorflow/lite/kernels/lstm_eval.cc
+++ b/tensorflow/lite/kernels/lstm_eval.cc
@@ -537,7 +537,7 @@ inline void LstmStepHybrid(
     int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
     int output_batch_leading_dim, float* input_gate_scratch,
     float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
-    float* scaling_factors, float* product_scaling_factors,
+    float* scaling_factors, float* scaling_factors_scratch,
     float* recovered_cell_weights, int8_t* quantized_input_ptr,
     int8_t* quantized_aux_input_ptr, int8_t* quantized_output_state_ptr,
     int8_t* quantized_cell_state_ptr, float* output_state_ptr,
@@ -646,49 +646,34 @@ inline void LstmStepHybrid(
                                       quantized_input_ptr, scaling_factors,
                                       zero_points, asymmetric_quantize_inputs);
     if (!use_cifg) {
-      for (int b = 0; b < n_batch; ++b) {
-        product_scaling_factors[b] =
-            scaling_factors[b] * input_to_input_weights_scale;
-      }
       tensor_utils::MatrixBatchVectorMultiplyAccumulate(
           input_to_input_weights_ptr, n_cell, n_input, quantized_input_ptr,
-          product_scaling_factors, n_batch, input_gate_scratch,
-          /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
-          input_to_input_row_sums, compute_row_sums, context);
-    }
-
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * input_to_forget_weights_scale;
+          input_to_input_weights_scale, scaling_factors, n_batch,
+          input_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
+          accum_scratch_ptr, input_to_input_row_sums, compute_row_sums,
+          scaling_factors_scratch, context);
     }
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr,
-        product_scaling_factors, n_batch, forget_gate_scratch,
-        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
-        input_to_forget_row_sums, compute_row_sums, context);
-
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * input_to_cell_weights_scale;
-    }
+        input_to_forget_weights_scale, scaling_factors, n_batch,
+        forget_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
+        accum_scratch_ptr, input_to_forget_row_sums, compute_row_sums,
+        scaling_factors_scratch, context);
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr,
-        product_scaling_factors, n_batch, cell_scratch,
+        input_to_cell_weights_scale, scaling_factors, n_batch, cell_scratch,
         /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
-        input_to_cell_row_sums, compute_row_sums, context);
-
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * input_to_output_weights_scale;
-    }
+        input_to_cell_row_sums, compute_row_sums, scaling_factors_scratch,
+        context);
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr,
-        product_scaling_factors, n_batch, output_gate_scratch,
-        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
-        input_to_output_row_sums, compute_row_sums, context);
+        input_to_output_weights_scale, scaling_factors, n_batch,
+        output_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
+        accum_scratch_ptr, input_to_output_row_sums, compute_row_sums,
+        scaling_factors_scratch, context);
   }
 
   // For each batch and cell: compute aux_input_weight * aux_input.
@@ -700,49 +685,36 @@ inline void LstmStepHybrid(
         zero_points, asymmetric_quantize_inputs);
     if (!use_cifg) {
-      for (int b = 0; b < n_batch; ++b) {
-        product_scaling_factors[b] =
-            scaling_factors[b] * aux_input_to_input_weights_scale;
-      }
       tensor_utils::MatrixBatchVectorMultiplyAccumulate(
           aux_input_to_input_weights_ptr, n_cell, n_aux_input,
-          quantized_aux_input_ptr, product_scaling_factors, n_batch,
-          input_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
-          accum_scratch_ptr, aux_input_to_input_row_sums, compute_row_sums,
-          context);
+          quantized_aux_input_ptr, aux_input_to_input_weights_scale,
+          scaling_factors, n_batch, input_gate_scratch,
+          /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+          aux_input_to_input_row_sums, compute_row_sums,
+          scaling_factors_scratch, context);
     }
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * aux_input_to_forget_weights_scale;
-    }
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
-        quantized_aux_input_ptr, product_scaling_factors, n_batch,
-        forget_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
-        accum_scratch_ptr, aux_input_to_forget_row_sums, compute_row_sums,
-        context);
+        quantized_aux_input_ptr, aux_input_to_forget_weights_scale,
+        scaling_factors, n_batch, forget_gate_scratch,
+        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+        aux_input_to_forget_row_sums, compute_row_sums,
+        scaling_factors_scratch, context);
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * aux_input_to_cell_weights_scale;
-    }
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         aux_input_to_cell_weights_ptr, n_cell, n_aux_input,
-        quantized_aux_input_ptr, product_scaling_factors, n_batch, cell_scratch,
-        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
-        aux_input_to_cell_row_sums, compute_row_sums, context);
-
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * aux_input_to_output_weights_scale;
-    }
+        quantized_aux_input_ptr, aux_input_to_cell_weights_scale,
+        scaling_factors, n_batch, cell_scratch, /*per_channel_scale=*/nullptr,
+        zero_points, accum_scratch_ptr, aux_input_to_cell_row_sums,
+        compute_row_sums, scaling_factors_scratch, context);
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         aux_input_to_output_weights_ptr, n_cell, n_aux_input,
-        quantized_aux_input_ptr, product_scaling_factors, n_batch,
-        output_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
-        accum_scratch_ptr, aux_input_to_output_row_sums, compute_row_sums,
-        context);
+        quantized_aux_input_ptr, aux_input_to_output_weights_scale,
+        scaling_factors, n_batch, output_gate_scratch,
+        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+        aux_input_to_output_row_sums, compute_row_sums,
+        scaling_factors_scratch, context);
   }
@@ -753,49 +725,36 @@ inline void LstmStepHybrid(
         scaling_factors, zero_points, asymmetric_quantize_inputs);
     // For each batch and cell: compute recurrent_weight * output_state.
     if (!use_cifg) {
-      for (int b = 0; b < n_batch; ++b) {
-        product_scaling_factors[b] =
-            scaling_factors[b] * recurrent_to_input_weights_scale;
-      }
       tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          recurrent_to_input_weights_ptr, n_cell, n_output,
-          quantized_output_state_ptr, product_scaling_factors, n_batch,
-          input_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
-          accum_scratch_ptr, recurrent_to_input_row_sums, compute_row_sums,
-          context);
+          quantized_output_state_ptr, recurrent_to_input_weights_scale,
+          scaling_factors, n_batch, input_gate_scratch,
+          /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+          recurrent_to_input_row_sums, compute_row_sums,
+          scaling_factors_scratch, context);
     }
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * recurrent_to_forget_weights_scale;
-    }
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         recurrent_to_forget_weights_ptr, n_cell, n_output,
-        quantized_output_state_ptr, product_scaling_factors, n_batch,
-        forget_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
-        accum_scratch_ptr, recurrent_to_forget_row_sums, compute_row_sums,
-        context);
+        quantized_output_state_ptr, recurrent_to_forget_weights_scale,
+        scaling_factors, n_batch, forget_gate_scratch,
+        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+        recurrent_to_forget_row_sums, compute_row_sums,
+        scaling_factors_scratch, context);
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * recurrent_to_cell_weights_scale;
-    }
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         recurrent_to_cell_weights_ptr, n_cell, n_output,
-        quantized_output_state_ptr, product_scaling_factors, n_batch,
-        cell_scratch, /*per_channel_scale=*/nullptr, zero_points,
-        accum_scratch_ptr, recurrent_to_cell_row_sums, compute_row_sums,
-        context);
+        quantized_output_state_ptr, recurrent_to_cell_weights_scale,
+        scaling_factors, n_batch, cell_scratch, /*per_channel_scale=*/nullptr,
+        zero_points, accum_scratch_ptr, recurrent_to_cell_row_sums,
+        compute_row_sums, scaling_factors_scratch, context);
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * recurrent_to_output_weights_scale;
-    }
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         recurrent_to_output_weights_ptr, n_cell, n_output,
-        quantized_output_state_ptr, product_scaling_factors, n_batch,
-        output_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
-        accum_scratch_ptr, recurrent_to_output_row_sums, compute_row_sums,
-        context);
+        quantized_output_state_ptr, recurrent_to_output_weights_scale,
+        scaling_factors, n_batch, output_gate_scratch,
+        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+        recurrent_to_output_row_sums, compute_row_sums,
+        scaling_factors_scratch, context);
   }
@@ -919,13 +878,13 @@ inline void LstmStepHybrid(
         output_gate_scratch, n_batch, n_cell, quantized_cell_state_ptr,
         scaling_factors, zero_points, asymmetric_quantize_inputs);
     for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
+      scaling_factors_scratch[b] =
           scaling_factors[b] * projection_weights_scale;
     }
     for (int b = 0; b < n_batch; b++) {
       tensor_utils::MatrixBatchVectorMultiplyAccumulate(
           projection_weights_ptr, n_output, n_cell,
-          quantized_cell_state_ptr + b * n_cell, &product_scaling_factors[b],
+          quantized_cell_state_ptr + b * n_cell, &scaling_factors_scratch[b],
          /*n_batch=*/1, output_ptr + b * output_batch_leading_dim,
          /*per_channel_scale=*/nullptr,
          asymmetric_quantize_inputs ? &zero_points[b] : nullptr,
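Usage sketch (not part of the patch): the snippet below shows how a hybrid kernel can drive the new overload added in tensor_utils.h. The wrapper name RunOneGate, the kNum* sizes, and the local buffers are illustrative assumptions only; the call pattern itself mirrors what LstmStepHybrid does above for each gate.

// Illustrative sketch; RunOneGate and the kNum* constants are not part of the
// patch. `weights` is kNumRows x kNumCols int8 data with one float scale,
// `input` holds kNumBatches rows of kNumCols floats, and `result`
// (kNumBatches x kNumRows floats) must be pre-initialized by the caller,
// e.g. with the gate biases, because the kernel accumulates into it.
#include <cstdint>

#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"

void RunOneGate(const int8_t* weights, float weights_scale, const float* input,
                float* result, tflite::CpuBackendContext* context) {
  constexpr int kNumBatches = 2, kNumRows = 8, kNumCols = 16;
  int8_t quantized_input[kNumBatches * kNumCols];
  float scaling_factors[kNumBatches];          // per-batch input scales
  float scaling_factors_scratch[kNumBatches];  // receives the scale products
  int32_t zero_points[kNumBatches];
  int32_t accum_scratch[kNumBatches * kNumRows];
  int32_t row_sums[kNumRows];
  bool compute_row_sums = true;

  // Quantize the float input per batch, as LstmStepHybrid does.
  tflite::tensor_utils::BatchQuantizeFloats(
      input, kNumBatches, kNumCols, quantized_input, scaling_factors,
      zero_points, /*do_asymmetric=*/true);

  // New overload: the matrix scale and the per-batch vector scales are passed
  // separately; their products are written into scaling_factors_scratch and
  // forwarded to the existing per-batch-scale kernel.
  tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      weights, kNumRows, kNumCols, quantized_input, weights_scale,
      scaling_factors, kNumBatches, result,
      /*per_channel_scale=*/nullptr, zero_points, accum_scratch, row_sums,
      &compute_row_sums, scaling_factors_scratch, context);
}

With the per-batch multiply folded into the overload, the hand-written product_scaling_factors loops that preceded each weight-matrix call in LstmStepHybrid go away, while the underlying per-batch-scale kernel stays unchanged; scaling_factors_scratch simply replaces product_scaling_factors as the scratch buffer that holds the products.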