From 3b94d75a9e10ef8ef33760d0ef6aad326e1353ba Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 5 Oct 2018 01:22:02 -0700 Subject: [PATCH] Merge the different LSTM EvalFloat/EvalHybrid calls into a single file. PiperOrigin-RevId: 215870962 --- tensorflow/contrib/lite/kernels/BUILD | 13 +- .../kernels/bidirectional_sequence_lstm.cc | 333 +------ .../lite/kernels/internal/kernel_utils.cc | 598 ------------ .../lite/kernels/internal/kernel_utils.h | 184 ---- tensorflow/contrib/lite/kernels/lstm.cc | 300 +----- tensorflow/contrib/lite/kernels/lstm_eval.cc | 909 ++++++++++++++++++ tensorflow/contrib/lite/kernels/lstm_eval.h | 79 ++ .../kernels/unidirectional_sequence_lstm.cc | 310 +----- 8 files changed, 1061 insertions(+), 1665 deletions(-) create mode 100644 tensorflow/contrib/lite/kernels/lstm_eval.cc create mode 100644 tensorflow/contrib/lite/kernels/lstm_eval.h diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index 95e387814d5..68636fb070c 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -234,11 +234,11 @@ cc_library( ":activation_functor", ":eigen_support", ":kernel_util", + ":lstm_eval", ":op_macros", ":padding", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:string_util", - "//tensorflow/contrib/lite:util", "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:gemm_support", "//tensorflow/contrib/lite/kernels/internal:audio_utils", @@ -254,6 +254,17 @@ cc_library( ], ) +cc_library( + name = "lstm_eval", + srcs = ["lstm_eval.cc"], + hdrs = ["lstm_eval.h"], + deps = [ + "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/contrib/lite/kernels/internal:kernel_utils", + "//tensorflow/contrib/lite/kernels/internal:tensor_utils", + ], +) + cc_library( name = "builtin_ops", srcs = ["register.cc"], diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc index 0532528f521..a326827b1e4 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/lstm_eval.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" namespace tflite { @@ -694,330 +695,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -TfLiteStatus EvalFloat( - const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, - const TfLiteTensor* input_to_forget_weights, - const TfLiteTensor* input_to_cell_weights, - const TfLiteTensor* input_to_output_weights, - const TfLiteTensor* recurrent_to_input_weights, - const TfLiteTensor* recurrent_to_forget_weights, - const TfLiteTensor* recurrent_to_cell_weights, - const TfLiteTensor* recurrent_to_output_weights, - const TfLiteTensor* cell_to_input_weights, - const TfLiteTensor* cell_to_forget_weights, - const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input, - const TfLiteTensor* aux_input_to_input_weights, - const TfLiteTensor* aux_input_to_forget_weights, - const TfLiteTensor* aux_input_to_cell_weights, - const TfLiteTensor* aux_input_to_output_weights, - const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, - const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, - const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, - const TfLiteLSTMParams* params, bool forward_sequence, int output_offset, - TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state, - TfLiteTensor* cell_state, TfLiteTensor* output) { - const int max_time = input->dims->data[0]; - const int n_batch = input->dims->data[1]; - const int n_input = input->dims->data[2]; - const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0; - - // n_cell and n_output will be the same size when there is no projection. - const int n_cell = input_to_output_weights->dims->data[0]; - const int n_output = recurrent_to_output_weights->dims->data[1]; - - // Since we have already checked that weights are all there or none, we can - // check the existense of only one to the get the condition. - const bool use_cifg = (input_to_input_weights == nullptr); - const bool use_peephole = (cell_to_output_weights != nullptr); - - // Index the scratch buffers pointers to the global scratch buffer. - float* input_gate_scratch = nullptr; - float* cell_scratch = nullptr; - float* forget_gate_scratch = nullptr; - float* output_gate_scratch = nullptr; - if (use_cifg) { - cell_scratch = scratch_buffer->data.f; - forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch; - output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; - } else { - input_gate_scratch = scratch_buffer->data.f; - cell_scratch = scratch_buffer->data.f + n_cell * n_batch; - forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; - output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; - } - - // Check optional tensors, the respective pointers can be null. - const float* input_to_input_weights_ptr = - (use_cifg) ? nullptr : input_to_input_weights->data.f; - const float* recurrent_to_input_weights_ptr = - (use_cifg) ? nullptr : recurrent_to_input_weights->data.f; - const float* input_gate_bias_ptr = - (use_cifg) ? nullptr : input_gate_bias->data.f; - const float* cell_to_input_weights_ptr = - (use_peephole && !use_cifg) ? 
cell_to_input_weights->data.f : nullptr; - const float* cell_to_forget_weights_ptr = - (use_peephole) ? cell_to_forget_weights->data.f : nullptr; - const float* cell_to_output_weights_ptr = - (use_peephole) ? cell_to_output_weights->data.f : nullptr; - const float* projection_weights_ptr = - (projection_weights == nullptr) ? nullptr : projection_weights->data.f; - const float* projection_bias_ptr = - (projection_bias == nullptr) ? nullptr : projection_bias->data.f; - - float* aux_input_ptr = nullptr; - float* aux_input_to_input_weights_ptr = nullptr; - float* aux_input_to_forget_weights_ptr = nullptr; - float* aux_input_to_cell_weights_ptr = nullptr; - float* aux_input_to_output_weights_ptr = nullptr; - if (aux_input_size > 0) { - aux_input_ptr = aux_input->data.f; - aux_input_to_input_weights_ptr = aux_input_to_input_weights->data.f; - aux_input_to_forget_weights_ptr = aux_input_to_forget_weights->data.f; - aux_input_to_cell_weights_ptr = aux_input_to_cell_weights->data.f; - aux_input_to_output_weights_ptr = aux_input_to_output_weights->data.f; - } - - // Loop through the sequence. - const int input_step = n_batch * n_input; - const int output_step = n_batch * output->dims->data[2]; - for (int t = 0; t < max_time; t++) { - // If this is the forward_sequence, step forward, otherwise step backwards. - const int t_rel = forward_sequence ? t : max_time - t - 1; - const float* input_ptr = input->data.f + t_rel * input_step; - float* output_ptr_time = - output->data.f + t_rel * output_step + output_offset; - - kernel_utils::LstmStepWithAuxInput( - input_ptr, input_to_input_weights_ptr, input_to_forget_weights->data.f, - input_to_cell_weights->data.f, input_to_output_weights->data.f, - aux_input_ptr, aux_input_to_input_weights_ptr, - aux_input_to_forget_weights_ptr, aux_input_to_cell_weights_ptr, - aux_input_to_output_weights_ptr, recurrent_to_input_weights_ptr, - recurrent_to_forget_weights->data.f, recurrent_to_cell_weights->data.f, - recurrent_to_output_weights->data.f, cell_to_input_weights_ptr, - cell_to_forget_weights_ptr, cell_to_output_weights_ptr, - input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f, - output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr, - params, n_batch, n_cell, n_input, aux_input_size, n_output, - activation_state->data.f, cell_state->data.f, input_gate_scratch, - forget_gate_scratch, cell_scratch, output_gate_scratch, - output_ptr_time); - } - return kTfLiteOk; -} - -TfLiteStatus EvalHybrid( - const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, - const TfLiteTensor* input_to_forget_weights, - const TfLiteTensor* input_to_cell_weights, - const TfLiteTensor* input_to_output_weights, - const TfLiteTensor* recurrent_to_input_weights, - const TfLiteTensor* recurrent_to_forget_weights, - const TfLiteTensor* recurrent_to_cell_weights, - const TfLiteTensor* recurrent_to_output_weights, - const TfLiteTensor* cell_to_input_weights, - const TfLiteTensor* cell_to_forget_weights, - const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input, - const TfLiteTensor* aux_input_to_input_weights, - const TfLiteTensor* aux_input_to_forget_weights, - const TfLiteTensor* aux_input_to_cell_weights, - const TfLiteTensor* aux_input_to_output_weights, - const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, - const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, - const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, - const TfLiteLSTMParams* params, bool 
forward_sequence, int output_offset, - TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors, - TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights, - TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized, - TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized, - TfLiteTensor* output_state, TfLiteTensor* cell_state, - TfLiteTensor* output) { - const int max_time = input->dims->data[0]; - const int n_batch = input->dims->data[1]; - const int n_input = input->dims->data[2]; - const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0; - // n_cell and n_output will be the same size when there is no projection. - const int n_cell = input_to_output_weights->dims->data[0]; - const int n_output = recurrent_to_output_weights->dims->data[1]; - - // Since we have already checked that weights are all there or none, we can - // check the existence of only one to get the condition. - const bool use_cifg = (input_to_input_weights == nullptr); - const bool use_peephole = (cell_to_output_weights != nullptr); - - float* input_gate_scratch = nullptr; - float* cell_scratch = nullptr; - float* forget_gate_scratch = nullptr; - float* output_gate_scratch = nullptr; - if (use_cifg) { - cell_scratch = scratch_buffer->data.f; - forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch; - output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; - } else { - input_gate_scratch = scratch_buffer->data.f; - cell_scratch = scratch_buffer->data.f + n_cell * n_batch; - forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; - output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; - } - - // Check optional tensors, the respective pointers can be null. - int8_t* input_to_input_weights_ptr = nullptr; - float input_to_input_weights_scale = 1.0f; - int8_t* recurrent_to_input_weights_ptr = nullptr; - float recurrent_to_input_weights_scale = 1.0f; - float* input_gate_bias_ptr = nullptr; - if (!use_cifg) { - input_to_input_weights_ptr = - reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8); - recurrent_to_input_weights_ptr = - reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8); - input_gate_bias_ptr = input_gate_bias->data.f; - input_to_input_weights_scale = input_to_input_weights->params.scale; - recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale; - } - - int8_t* cell_to_input_weights_ptr = nullptr; - int8_t* cell_to_forget_weights_ptr = nullptr; - int8_t* cell_to_output_weights_ptr = nullptr; - float cell_to_input_weights_scale = 1.0f; - float cell_to_forget_weights_scale = 1.0f; - float cell_to_output_weights_scale = 1.0f; - if (use_peephole) { - if (!use_cifg) { - cell_to_input_weights_ptr = - reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8); - cell_to_input_weights_scale = cell_to_input_weights->params.scale; - } - cell_to_forget_weights_ptr = - reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8); - cell_to_output_weights_ptr = - reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8); - cell_to_forget_weights_scale = cell_to_forget_weights->params.scale; - cell_to_output_weights_scale = cell_to_output_weights->params.scale; - } - - const int8_t* projection_weights_ptr = - (projection_weights == nullptr) - ? nullptr - : reinterpret_cast<int8_t*>(projection_weights->data.uint8); - const float projection_weights_scale = - (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale; - const float* projection_bias_ptr = - (projection_bias == nullptr) ?
nullptr : projection_bias->data.f; - - // Required tensors, pointers are non-null. - const int8_t* input_to_forget_weights_ptr = - reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8); - const float input_to_forget_weights_scale = - input_to_forget_weights->params.scale; - const int8_t* input_to_cell_weights_ptr = - reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8); - const float input_to_cell_weights_scale = input_to_cell_weights->params.scale; - const int8_t* input_to_output_weights_ptr = - reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8); - const float input_to_output_weights_scale = - input_to_output_weights->params.scale; - const int8_t* recurrent_to_forget_weights_ptr = - reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8); - const float recurrent_to_forget_weights_scale = - recurrent_to_forget_weights->params.scale; - const int8_t* recurrent_to_cell_weights_ptr = - reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8); - const float recurrent_to_cell_weights_scale = - recurrent_to_cell_weights->params.scale; - const int8_t* recurrent_to_output_weights_ptr = - reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8); - const float recurrent_to_output_weights_scale = - recurrent_to_output_weights->params.scale; - const float* forget_gate_bias_ptr = forget_gate_bias->data.f; - const float* cell_bias_ptr = cell_bias->data.f; - const float* output_gate_bias_ptr = output_gate_bias->data.f; - - float* output_state_ptr = output_state->data.f; - float* cell_state_ptr = cell_state->data.f; - - // Temporary storage for quantized values and scaling factors. - int8_t* quantized_input_ptr = - reinterpret_cast<int8_t*>(input_quantized->data.uint8); - int8_t* quantized_aux_input_ptr = - (aux_input_quantized == nullptr) - ? nullptr - : reinterpret_cast<int8_t*>(aux_input_quantized->data.uint8); - int8_t* quantized_output_state_ptr = - reinterpret_cast<int8_t*>(output_state_quantized->data.uint8); - int8_t* quantized_cell_state_ptr = - reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8); - float* scaling_factors_ptr = scaling_factors->data.f; - float* prod_scaling_factors_ptr = prod_scaling_factors->data.f; - float* recovered_cell_weights_ptr = recovered_cell_weights->data.f; - - // Auxiliary input and weights. - float* aux_input_ptr = nullptr; - int8_t* aux_input_to_input_weights_ptr = nullptr; - int8_t* aux_input_to_forget_weights_ptr = nullptr; - int8_t* aux_input_to_cell_weights_ptr = nullptr; - int8_t* aux_input_to_output_weights_ptr = nullptr; - float aux_input_to_input_weights_scale = 0.0f; - float aux_input_to_forget_weights_scale = 0.0f; - float aux_input_to_cell_weights_scale = 0.0f; - float aux_input_to_output_weights_scale = 0.0f; - if (aux_input_size > 0) { - aux_input_ptr = aux_input->data.f; - aux_input_to_input_weights_ptr = - reinterpret_cast<int8_t*>(aux_input_to_input_weights->data.uint8); - aux_input_to_forget_weights_ptr = - reinterpret_cast<int8_t*>(aux_input_to_forget_weights->data.uint8); - aux_input_to_cell_weights_ptr = - reinterpret_cast<int8_t*>(aux_input_to_cell_weights->data.uint8); - aux_input_to_output_weights_ptr = - reinterpret_cast<int8_t*>(aux_input_to_output_weights->data.uint8); - aux_input_to_input_weights_scale = aux_input_to_input_weights->params.scale; - aux_input_to_forget_weights_scale = - aux_input_to_forget_weights->params.scale; - aux_input_to_cell_weights_scale = aux_input_to_cell_weights->params.scale; - aux_input_to_output_weights_scale = - aux_input_to_output_weights->params.scale; - } - - // Feed the sequence into the LSTM step-by-step.
- const int input_step = n_batch * n_input; - const int output_step = n_batch * output->dims->data[2]; - for (int t = 0; t < max_time; t++) { - // If this is the forward_sequence, step forward, otherwise step backwards. - const int t_rel = forward_sequence ? t : max_time - t - 1; - const float* input_ptr = input->data.f + t_rel * input_step; - float* output_ptr = output->data.f + t_rel * output_step + output_offset; - - kernel_utils::LstmStepWithAuxInput( - input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale, - input_to_forget_weights_ptr, input_to_forget_weights_scale, - input_to_cell_weights_ptr, input_to_cell_weights_scale, - input_to_output_weights_ptr, input_to_output_weights_scale, - aux_input_ptr, aux_input_to_input_weights_ptr, - aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr, - aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr, - aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr, - aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr, - recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr, - recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr, - recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr, - recurrent_to_output_weights_scale, cell_to_input_weights_ptr, - cell_to_input_weights_scale, cell_to_forget_weights_ptr, - cell_to_forget_weights_scale, cell_to_output_weights_ptr, - cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr, - cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr, - projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell, - n_input, aux_input_size, n_output, input_gate_scratch, - forget_gate_scratch, cell_scratch, output_gate_scratch, - scaling_factors_ptr, prod_scaling_factors_ptr, - recovered_cell_weights_ptr, quantized_input_ptr, - quantized_aux_input_ptr, quantized_output_state_ptr, - quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, output_ptr); - } - - return kTfLiteOk; -} - // The LSTM Op engine. 
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const auto* params = reinterpret_cast( @@ -1157,7 +834,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { switch (fw_input_to_output_weights->type) { case kTfLiteFloat32: { - TfLiteStatus fw_pass_status = EvalFloat( + TfLiteStatus fw_pass_status = lstm_eval::EvalFloat( input, fw_input_to_input_weights, fw_input_to_forget_weights, fw_input_to_cell_weights, fw_input_to_output_weights, fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights, @@ -1172,7 +849,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { fw_activation_state, fw_cell_state, fw_output); TF_LITE_ENSURE_OK(context, fw_pass_status); - TfLiteStatus bw_pass_status = EvalFloat( + TfLiteStatus bw_pass_status = lstm_eval::EvalFloat( input, bw_input_to_input_weights, bw_input_to_forget_weights, bw_input_to_cell_weights, bw_input_to_output_weights, bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights, @@ -1208,7 +885,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* recovered_cell_weights = GetTemporary(context, node, kRecoveredCellWeights); - TfLiteStatus fw_pass_status = EvalHybrid( + TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid( input, fw_input_to_input_weights, fw_input_to_forget_weights, fw_input_to_cell_weights, fw_input_to_output_weights, fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights, @@ -1226,7 +903,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { fw_output); TF_LITE_ENSURE_OK(context, fw_pass_status); - TfLiteStatus bw_pass_status = EvalHybrid( + TfLiteStatus bw_pass_status = lstm_eval::EvalHybrid( input, bw_input_to_input_weights, bw_input_to_forget_weights, bw_input_to_cell_weights, bw_input_to_output_weights, bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights, diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc index 56e93678781..083e5839bd5 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc @@ -169,603 +169,5 @@ void RnnBatchStep( hidden_state_ptr_batch); } -void LstmStep( - const float* input_ptr_batch, const float* input_to_input_weights_ptr, - const float* input_to_forget_weights_ptr, - const float* input_to_cell_weights_ptr, - const float* input_to_output_weights_ptr, - const float* recurrent_to_input_weights_ptr, - const float* recurrent_to_forget_weights_ptr, - const float* recurrent_to_cell_weights_ptr, - const float* recurrent_to_output_weights_ptr, - const float* cell_to_input_weights_ptr, - const float* cell_to_forget_weights_ptr, - const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr, - const float* forget_gate_bias_ptr, const float* cell_bias_ptr, - const float* output_gate_bias_ptr, const float* projection_weights_ptr, - const float* projection_bias_ptr, const TfLiteLSTMParams* params, - int n_batch, int n_cell, int n_input, int n_output, float* output_state_ptr, - float* cell_state_ptr, float* input_gate_scratch, - float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, - float* output_ptr_batch) { - LstmStepWithAuxInput( - input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr, - input_to_cell_weights_ptr, input_to_output_weights_ptr, - /*aux_input_ptr_batch=*/nullptr, - /*aux_input_to_input_weights_ptr=*/nullptr, - /*aux_input_to_forget_weights_ptr=*/nullptr, - /*aux_input_to_cell_weights_ptr=*/nullptr, 
- /*aux_input_to_output_weights_ptr=*/nullptr, - recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr, - recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr, - cell_to_input_weights_ptr, cell_to_forget_weights_ptr, - cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr, - cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr, - projection_bias_ptr, params, n_batch, n_cell, n_input, /*n_aux_input=*/0, - n_output, output_state_ptr, cell_state_ptr, input_gate_scratch, - forget_gate_scratch, cell_scratch, output_gate_scratch, output_ptr_batch); -} - -void LstmStepWithAuxInput( - const float* input_ptr_batch, const float* input_to_input_weights_ptr, - const float* input_to_forget_weights_ptr, - const float* input_to_cell_weights_ptr, - const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch, - const float* aux_input_to_input_weights_ptr, - const float* aux_input_to_forget_weights_ptr, - const float* aux_input_to_cell_weights_ptr, - const float* aux_input_to_output_weights_ptr, - const float* recurrent_to_input_weights_ptr, - const float* recurrent_to_forget_weights_ptr, - const float* recurrent_to_cell_weights_ptr, - const float* recurrent_to_output_weights_ptr, - const float* cell_to_input_weights_ptr, - const float* cell_to_forget_weights_ptr, - const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr, - const float* forget_gate_bias_ptr, const float* cell_bias_ptr, - const float* output_gate_bias_ptr, const float* projection_weights_ptr, - const float* projection_bias_ptr, const TfLiteLSTMParams* params, - int n_batch, int n_cell, int n_input, int n_aux_input, int n_output, - float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch, - float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, - float* output_ptr_batch) { - // Since we have already checked that weights are all there or none, we can - // check the existense of only one to the get the condition. - const bool use_cifg = (input_to_input_weights_ptr == nullptr); - const bool use_peephole = (cell_to_output_weights_ptr != nullptr); - // Initialize scratch buffers with bias. - if (!use_cifg) { - tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch, - input_gate_scratch); - } - tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch, - forget_gate_scratch); - tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch, - cell_scratch); - tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch, - output_gate_scratch); - - // For each batch and cell: compute input_weight * input. 
- if (!use_cifg) { - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, - input_gate_scratch, /*result_stride=*/1); - } - - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, - forget_gate_scratch, /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, - cell_scratch, /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, - output_gate_scratch, /*result_stride=*/1); - - // If auxiliary input is available then compute aux_input_weight * aux_input - if (aux_input_ptr_batch != nullptr) { - if (!use_cifg) { - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_input_weights_ptr, n_cell, n_aux_input, - aux_input_ptr_batch, n_batch, input_gate_scratch, - /*result_stride=*/1); - } - - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_forget_weights_ptr, n_cell, n_aux_input, - aux_input_ptr_batch, n_batch, forget_gate_scratch, /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr_batch, - n_batch, cell_scratch, /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_output_weights_ptr, n_cell, n_aux_input, - aux_input_ptr_batch, n_batch, output_gate_scratch, /*result_stride=*/1); - } - - // For each batch and cell: compute recurrent_weight * output_state. - if (!use_cifg) { - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr, - n_batch, input_gate_scratch, /*result_stride=*/1); - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr, - n_batch, forget_gate_scratch, - /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr, - n_batch, cell_scratch, /*result_stride=*/1); - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr, - n_batch, output_gate_scratch, - /*result_stride=*/1); - - // For each batch and cell: update input gate. - if (!use_cifg) { - if (use_peephole) { - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch, - input_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, - input_gate_scratch); - } - - // For each batch and cell: update forget gate. - if (use_peephole) { - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch, - forget_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, - forget_gate_scratch); - - // For each batch and cell: update the cell. 
- tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr, - n_batch * n_cell, cell_state_ptr); - tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, - params->activation, cell_scratch); - if (use_cifg) { - tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, - forget_gate_scratch); - tensor_utils::VectorVectorCwiseProductAccumulate( - cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr); - } else { - tensor_utils::VectorVectorCwiseProductAccumulate( - cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr); - } - if (params->cell_clip > 0.0) { - tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, - params->cell_clip, cell_state_ptr); - } - - // For each batch and cell: update the output gate. - if (use_peephole) { - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch, - output_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, - output_gate_scratch); - tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell, - params->activation, cell_scratch); - tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, - n_batch * n_cell, output_gate_scratch); - - // For each batch: update the projection and output_state. - const bool use_projection_weight = (projection_weights_ptr != nullptr); - const bool use_projection_bias = (projection_bias_ptr != nullptr); - if (use_projection_weight) { - if (use_projection_bias) { - tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output, - n_batch, output_ptr_batch); - } else { - tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output); - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch, - output_ptr_batch, /*result_stride=*/1); - if (params->proj_clip > 0.0) { - tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, - params->proj_clip, output_ptr_batch); - } - } else { - tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, - output_ptr_batch); - } - tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output, - output_state_ptr); -} - -void LstmStep( - const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr, - float input_to_input_weights_scale, - const int8_t* input_to_forget_weights_ptr, - float input_to_forget_weights_scale, - const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale, - const int8_t* input_to_output_weights_ptr, - float input_to_output_weights_scale, - const int8_t* recurrent_to_input_weights_ptr, - float recurrent_to_input_weights_scale, - const int8_t* recurrent_to_forget_weights_ptr, - float recurrent_to_forget_weights_scale, - const int8_t* recurrent_to_cell_weights_ptr, - float recurrent_to_cell_weights_scale, - const int8_t* recurrent_to_output_weights_ptr, - float recurrent_to_output_weights_scale, - const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale, - const int8_t* cell_to_forget_weights_ptr, - float cell_to_forget_weights_scale, - const int8_t* cell_to_output_weights_ptr, - float cell_to_output_weights_scale, const float* input_gate_bias_ptr, - const float* forget_gate_bias_ptr, const float* cell_bias_ptr, - const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr, - float projection_weights_scale, const float* projection_bias_ptr, - const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, - int n_output, float* 
input_gate_scratch, float* forget_gate_scratch, - float* cell_scratch, float* output_gate_scratch, float* scaling_factors, - float* product_scaling_factors, float* recovered_cell_weights, - int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr, - int8_t* quantized_cell_state_ptr, float* output_state_ptr, - float* cell_state_ptr, float* output_ptr_batch) { - LstmStepWithAuxInput( - input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale, - input_to_forget_weights_ptr, input_to_forget_weights_scale, - input_to_cell_weights_ptr, input_to_cell_weights_scale, - input_to_output_weights_ptr, input_to_output_weights_scale, - /*aux_input_ptr_batch=*/nullptr, - /*aux_input_to_input_weights_ptr=*/nullptr, - /*aux_input_to_input_weights_scale=*/0.0f, - /*aux_input_to_forget_weights_ptr=*/nullptr, - /*aux_input_to_forget_weights_scale=*/0.0f, - /*aux_input_to_cell_weights_ptr=*/nullptr, - /*aux_input_to_cell_weights_scale=*/0.0f, - /*aux_input_to_output_weights_ptr=*/nullptr, - /*aux_input_to_output_weights_scale=*/0.0f, - recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale, - recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale, - recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale, - recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale, - cell_to_input_weights_ptr, cell_to_input_weights_scale, - cell_to_forget_weights_ptr, cell_to_forget_weights_scale, - cell_to_output_weights_ptr, cell_to_output_weights_scale, - input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, - output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale, - projection_bias_ptr, params, n_batch, n_cell, n_input, - /*n_aux_input=*/0, n_output, input_gate_scratch, forget_gate_scratch, - cell_scratch, output_gate_scratch, scaling_factors, - product_scaling_factors, recovered_cell_weights, - quantized_input_ptr_batch, - /*quantized_aux_input_ptr_batch=*/nullptr, quantized_output_state_ptr, - quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, - output_ptr_batch); - } - - void LstmStepWithAuxInput( - const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr, - float input_to_input_weights_scale, - const int8_t* input_to_forget_weights_ptr, - float input_to_forget_weights_scale, - const int8_t* input_to_cell_weights_ptr, - float input_to_cell_weights_scale, - const int8_t* input_to_output_weights_ptr, - float input_to_output_weights_scale, const float* aux_input_ptr_batch, - const int8_t* aux_input_to_input_weights_ptr, - float aux_input_to_input_weights_scale, - const int8_t* aux_input_to_forget_weights_ptr, - float aux_input_to_forget_weights_scale, - const int8_t* aux_input_to_cell_weights_ptr, - float aux_input_to_cell_weights_scale, - const int8_t* aux_input_to_output_weights_ptr, - float aux_input_to_output_weights_scale, - const int8_t* recurrent_to_input_weights_ptr, - float recurrent_to_input_weights_scale, - const int8_t* recurrent_to_forget_weights_ptr, - float recurrent_to_forget_weights_scale, - const int8_t* recurrent_to_cell_weights_ptr, - float recurrent_to_cell_weights_scale, - const int8_t* recurrent_to_output_weights_ptr, - float recurrent_to_output_weights_scale, - const int8_t* cell_to_input_weights_ptr, - float cell_to_input_weights_scale, - const int8_t* cell_to_forget_weights_ptr, - float cell_to_forget_weights_scale, - const int8_t* cell_to_output_weights_ptr, - float cell_to_output_weights_scale, const float* input_gate_bias_ptr, - const float* forget_gate_bias_ptr, const float* 
cell_bias_ptr, - const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr, - float projection_weights_scale, const float* projection_bias_ptr, - const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, - int n_aux_input, int n_output, float* input_gate_scratch, - float* forget_gate_scratch, float* cell_scratch, - float* output_gate_scratch, float* scaling_factors, - float* product_scaling_factors, float* recovered_cell_weights, - int8_t* quantized_input_ptr_batch, - int8_t* quantized_aux_input_ptr_batch, - int8_t* quantized_output_state_ptr, int8_t* quantized_cell_state_ptr, - float* output_state_ptr, float* cell_state_ptr, - float* output_ptr_batch) { - // Since we have already checked that weights are all there or none, we - // can check the existense of only one to the get the condition. - const bool use_cifg = (input_to_input_weights_ptr == nullptr); - const bool use_peephole = (cell_to_output_weights_ptr != nullptr); - // Initialize scratch buffers with bias. - if (!use_cifg) { - tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, - n_batch, input_gate_scratch); - } - tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, - n_batch, forget_gate_scratch); - tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch, - cell_scratch); - tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, - n_batch, output_gate_scratch); - - if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) { - // Save quantization and matmul computation for all zero input. - float unused_min, unused_max; - for (int b = 0; b < n_batch; ++b) { - const int offset = b * n_input; - tensor_utils::SymmetricQuantizeFloats( - input_ptr_batch + offset, n_input, - quantized_input_ptr_batch + offset, &unused_min, &unused_max, - &scaling_factors[b]); - } - // For each batch and cell: compute input_weight * input. - if (!use_cifg) { - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * input_to_input_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_input_weights_ptr, n_cell, n_input, - quantized_input_ptr_batch, product_scaling_factors, n_batch, - input_gate_scratch, /*result_stride=*/1); - } - - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * input_to_forget_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_forget_weights_ptr, n_cell, n_input, - quantized_input_ptr_batch, product_scaling_factors, n_batch, - forget_gate_scratch, - /*result_stride=*/1); - - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * input_to_cell_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_cell_weights_ptr, n_cell, n_input, - quantized_input_ptr_batch, product_scaling_factors, n_batch, - cell_scratch, /*result_stride=*/1); - - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * input_to_output_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_output_weights_ptr, n_cell, n_input, - quantized_input_ptr_batch, product_scaling_factors, n_batch, - output_gate_scratch, - /*result_stride=*/1); - } - - if (aux_input_ptr_batch != nullptr && - !tensor_utils::IsZeroVector(aux_input_ptr_batch, n_batch * n_input)) { - // Save quantization and matmul computation for all zero input. 
- float unused_min, unused_max; - for (int b = 0; b < n_batch; ++b) { - const int offset = b * n_input; - tensor_utils::SymmetricQuantizeFloats( - aux_input_ptr_batch + offset, n_input, - quantized_aux_input_ptr_batch + offset, &unused_min, &unused_max, - &scaling_factors[b]); - } - // For each batch and cell: compute input_weight * input. - if (!use_cifg) { - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * aux_input_to_input_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_input_weights_ptr, n_cell, n_input, - quantized_aux_input_ptr_batch, product_scaling_factors, n_batch, - input_gate_scratch, /*result_stride=*/1); - } - - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * aux_input_to_forget_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_forget_weights_ptr, n_cell, n_input, - quantized_aux_input_ptr_batch, product_scaling_factors, n_batch, - forget_gate_scratch, /*result_stride=*/1); - - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * aux_input_to_cell_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_cell_weights_ptr, n_cell, n_input, - quantized_aux_input_ptr_batch, product_scaling_factors, n_batch, - cell_scratch, /*result_stride=*/1); - - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * aux_input_to_output_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_output_weights_ptr, n_cell, n_input, - quantized_aux_input_ptr_batch, product_scaling_factors, n_batch, - output_gate_scratch, /*result_stride=*/1); - } - - if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) { - // Save quantization and matmul computation for all zero input. - float unused_min, unused_max; - for (int b = 0; b < n_batch; ++b) { - const int offset = b * n_output; - tensor_utils::SymmetricQuantizeFloats( - output_state_ptr + offset, n_output, - quantized_output_state_ptr + offset, &unused_min, &unused_max, - &scaling_factors[b]); - } - // For each batch and cell: compute recurrent_weight * output_state. 
- if (!use_cifg) { - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * recurrent_to_input_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_input_weights_ptr, n_cell, n_output, - quantized_output_state_ptr, product_scaling_factors, n_batch, - input_gate_scratch, /*result_stride=*/1); - } - - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * recurrent_to_forget_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_forget_weights_ptr, n_cell, n_output, - quantized_output_state_ptr, product_scaling_factors, n_batch, - forget_gate_scratch, /*result_stride=*/1); - - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * recurrent_to_cell_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_cell_weights_ptr, n_cell, n_output, - quantized_output_state_ptr, product_scaling_factors, n_batch, - cell_scratch, /*result_stride=*/1); - - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * recurrent_to_output_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_output_weights_ptr, n_cell, n_output, - quantized_output_state_ptr, product_scaling_factors, n_batch, - output_gate_scratch, /*result_stride=*/1); - } - - // Save quantization and matmul computation for all zero input. - bool is_cell_state_all_zeros = - tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell); - - // For each batch and cell: update input gate. - if (!use_cifg) { - if (use_peephole && !is_cell_state_all_zeros) { - tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell, - cell_to_input_weights_scale, - recovered_cell_weights); - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - recovered_cell_weights, n_cell, cell_state_ptr, n_batch, - input_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, - input_gate_scratch); - } - - // For each batch and cell: update forget gate. - if (use_peephole && !is_cell_state_all_zeros) { - tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell, - cell_to_forget_weights_scale, - recovered_cell_weights); - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - recovered_cell_weights, n_cell, cell_state_ptr, n_batch, - forget_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, - forget_gate_scratch); - - // For each batch and cell: update the cell. - tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, - cell_state_ptr, n_batch * n_cell, - cell_state_ptr); - tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, - params->activation, cell_scratch); - if (use_cifg) { - tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, - forget_gate_scratch); - tensor_utils::VectorVectorCwiseProductAccumulate( - cell_scratch, forget_gate_scratch, n_batch * n_cell, - cell_state_ptr); - } else { - tensor_utils::VectorVectorCwiseProductAccumulate( - cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr); - } - if (params->cell_clip > 0.0) { - tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, - params->cell_clip, cell_state_ptr); - } - - is_cell_state_all_zeros = - tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell); - // For each batch and cell: update the output gate. 
- if (use_peephole && !is_cell_state_all_zeros) { - tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell, - cell_to_output_weights_scale, - recovered_cell_weights); - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - recovered_cell_weights, n_cell, cell_state_ptr, n_batch, - output_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, - output_gate_scratch); - tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell, - params->activation, cell_scratch); - tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, - n_batch * n_cell, - output_gate_scratch); - - // For each batch: update the projection and output_state. - const bool use_projection_weight = (projection_weights_ptr != nullptr); - const bool use_projection_bias = (projection_bias_ptr != nullptr); - if (use_projection_weight) { - if (use_projection_bias) { - tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output, - n_batch, output_ptr_batch); - } else { - tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output); - } - if (!tensor_utils::IsZeroVector(output_gate_scratch, - n_batch * n_cell)) { - // Save quantization and matmul computation for all zero input. - float unused_min, unused_max; - for (int b = 0; b < n_batch; ++b) { - const int offset = b * n_cell; - tensor_utils::SymmetricQuantizeFloats( - output_gate_scratch + offset, n_cell, - quantized_cell_state_ptr + offset, &unused_min, &unused_max, - &scaling_factors[b]); - } - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * projection_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - projection_weights_ptr, n_output, n_cell, - quantized_cell_state_ptr, product_scaling_factors, n_batch, - output_ptr_batch, - /*result_stride=*/1); - } - if (params->proj_clip > 0.0) { - tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, - params->proj_clip, output_ptr_batch); - } - } else { - tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, - output_ptr_batch); - } - tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output, - output_state_ptr); - } - } // namespace kernel_utils } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h index b5558cce55a..74e0a4a53de 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h @@ -76,190 +76,6 @@ void RnnBatchStep( int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors, float* hidden_state_ptr_batch, float* output_ptr_batch); -// Performs an LSTM batch inference step for input specified by input_ptr_batch. -// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and -// biases (*_bias_ptr), and buffers (*_scratch), along with additional -// parameters: -// - params: various LSTM params including activation, clipping, etc., -// - n_batch: size of batch, -// - n_cell: number of cells (or units), -// - n_input: the input size, -// - n_output: the output size. -// -// The pointers to the cell and output state and the output are updated. -// -// The pointers with the suffix "_batch" point to data aligned in batch_major -// order, and each step processes batch_size many inputs from input_ptr_batch, -// and updates batch_size many cell and output states. 
-void LstmStep( - const float* input_ptr_batch, const float* input_to_input_weights_ptr, - const float* input_to_forget_weights_ptr, - const float* input_to_cell_weights_ptr, - const float* input_to_output_weights_ptr, - const float* recurrent_to_input_weights_ptr, - const float* recurrent_to_forget_weights_ptr, - const float* recurrent_to_cell_weights_ptr, - const float* recurrent_to_output_weights_ptr, - const float* cell_to_input_weights_ptr, - const float* cell_to_forget_weights_ptr, - const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr, - const float* forget_gate_bias_ptr, const float* cell_bias_ptr, - const float* output_gate_bias_ptr, const float* projection_weights_ptr, - const float* projection_bias_ptr, const TfLiteLSTMParams* params, - int n_batch, int n_cell, int n_input, int n_output, float* output_state_ptr, - float* cell_state_ptr, float* input_gate_scratch, - float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, - float* output_ptr_batch); - -// Same as above but includes an auxiliary input with the corresponding weights. -void LstmStepWithAuxInput( - const float* input_ptr_batch, const float* input_to_input_weights_ptr, - const float* input_to_forget_weights_ptr, - const float* input_to_cell_weights_ptr, - const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch, - const float* aux_input_to_input_weights_ptr, - const float* aux_input_to_forget_weights_ptr, - const float* aux_input_to_cell_weights_ptr, - const float* aux_input_to_output_weights_ptr, - const float* recurrent_to_input_weights_ptr, - const float* recurrent_to_forget_weights_ptr, - const float* recurrent_to_cell_weights_ptr, - const float* recurrent_to_output_weights_ptr, - const float* cell_to_input_weights_ptr, - const float* cell_to_forget_weights_ptr, - const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr, - const float* forget_gate_bias_ptr, const float* cell_bias_ptr, - const float* output_gate_bias_ptr, const float* projection_weights_ptr, - const float* projection_bias_ptr, const TfLiteLSTMParams* params, - int n_batch, int n_cell, int n_input, int n_aux_input, int n_output, - float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch, - float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, - float* output_ptr_batch); - -// Same as above but with quantized weight matrices. In detail: -// Input of size 'n_batch * n_input': -// input_ptr_batch -// -// LSTM weights: -// Quantized input weights of size 'n_cell * n_input': -// input_to_input_weights - optional (can be nullptr) -// input_to_forget_weights -// input_to_cell_weights -// input_to_input_weights -// Quantized recurrent weights of size 'n_cell * n_output': -// recurrent_to_input_weights - optional -// recurrent_to_forget_weights -// recurrent_to_cell_weights -// recurrent_to_input_weights -// Quantized peephole weights of size 'n_cell', representing diagonal matrices. -// cell_to_input_weights - optional -// cell_to_cell_weights - optional -// cell_to_output_weights - optional -// Quantized projection weights of size 'n_output * n_cell' -// projection_weights_ptr - optional -// Weight scales (scalars) for each of the weights above. 
-// input_to_input_weights_scale - optional -// input_to_forget_weights_scale -// input_to_cell_weights_scale -// input_to_output_weights_scale -// recurrent_to_input_weights_scale - optional -// recurrent_to_forget_weights_scale -// recurrent_to_cell_weights_scale -// recurrent_to_output_weights_scale -// cell_to_input_weights_scale, -// cell_to_forget_weights_scale, -// cell_to_output_weights_scale, -// projection_weights_scale - optional -// Gate biases of size 'n_cell': -// input_gate_bias_ptr - optional -// forget_gate_bias_ptr -// cell_gate_bias_ptr -// output_gate_bias_ptr -// -// Temporary pre-allocated storage for quantized values: -// quantized_input_ptr_batch (same size as input_ptr_batch) -// quantized_output_state_ptr (same size as output_state_ptr) -// quantized_cell_state_ptr (same size as cell_state_ptr) -// Temporary pre-allocated storage for recovered values: -// recovered_cell_weights (same size as cell_to_*_weights) -// -// Outputs: -// output_state_ptr - size 'n_batch * n_output' -// cell_state_ptr - size 'n_batch * n_cell' -// output_ptr_batch - size 'n_batch * n_output' -void LstmStep( - const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr, - float input_to_input_weights_scale, - const int8_t* input_to_forget_weights_ptr, - float input_to_forget_weights_scale, - const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale, - const int8_t* input_to_output_weights_ptr, - float input_to_output_weights_scale, - const int8_t* recurrent_to_input_weights_ptr, - float recurrent_to_input_weights_scale, - const int8_t* recurrent_to_forget_weights_ptr, - float recurrent_to_forget_weights_scale, - const int8_t* recurrent_to_cell_weights_ptr, - float recurrent_to_cell_weights_scale, - const int8_t* recurrent_to_output_weights_ptr, - float recurrent_to_output_weights_scale, - const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale, - const int8_t* cell_to_forget_weights_ptr, - float cell_to_forget_weights_scale, - const int8_t* cell_to_output_weights_ptr, - float cell_to_output_weights_scale, const float* input_gate_bias_ptr, - const float* forget_gate_bias_ptr, const float* cell_bias_ptr, - const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr, - float projection_weights_scale, const float* projection_bias_ptr, - const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, - int n_output, float* input_gate_scratch, float* forget_gate_scratch, - float* cell_scratch, float* output_gate_scratch, float* scaling_factors, - float* product_scaling_factors, float* recovered_cell_weights, - int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr, - int8_t* quantized_cell_state_ptr, float* output_state_ptr, - float* cell_state_ptr, float* output_ptr_batch); - -void LstmStepWithAuxInput( - const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr, - float input_to_input_weights_scale, - const int8_t* input_to_forget_weights_ptr, - float input_to_forget_weights_scale, - const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale, - const int8_t* input_to_output_weights_ptr, - float input_to_output_weights_scale, const float* aux_input_ptr_batch, - const int8_t* aux_input_to_input_weights_ptr, - float aux_input_to_input_weights_scale, - const int8_t* aux_input_to_forget_weights_ptr, - float aux_input_to_forget_weights_scale, - const int8_t* aux_input_to_cell_weights_ptr, - float aux_input_to_cell_weights_scale, - const int8_t* aux_input_to_output_weights_ptr, - float 
aux_input_to_output_weights_scale, - const int8_t* recurrent_to_input_weights_ptr, - float recurrent_to_input_weights_scale, - const int8_t* recurrent_to_forget_weights_ptr, - float recurrent_to_forget_weights_scale, - const int8_t* recurrent_to_cell_weights_ptr, - float recurrent_to_cell_weights_scale, - const int8_t* recurrent_to_output_weights_ptr, - float recurrent_to_output_weights_scale, - const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale, - const int8_t* cell_to_forget_weights_ptr, - float cell_to_forget_weights_scale, - const int8_t* cell_to_output_weights_ptr, - float cell_to_output_weights_scale, const float* input_gate_bias_ptr, - const float* forget_gate_bias_ptr, const float* cell_bias_ptr, - const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr, - float projection_weights_scale, const float* projection_bias_ptr, - const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, - int n_aux_input, int n_output, float* input_gate_scratch, - float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, - float* scaling_factors, float* product_scaling_factors, - float* recovered_cell_weights, int8_t* quantized_input_ptr_batch, - int8_t* quantized_aux_input_ptr_batch, int8_t* quantized_output_state_ptr, - int8_t* quantized_cell_state_ptr, float* output_state_ptr, - float* cell_state_ptr, float* output_ptr_batch); - } // namespace kernel_utils } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc index 5b996d00bcb..16d67a1a938 100644 --- a/tensorflow/contrib/lite/kernels/lstm.cc +++ b/tensorflow/contrib/lite/kernels/lstm.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/lstm_eval.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" namespace tflite { @@ -424,263 +425,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -// The LSTM Op engine. -TfLiteStatus EvalFloat( - const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, - const TfLiteTensor* input_to_forget_weights, - const TfLiteTensor* input_to_cell_weights, - const TfLiteTensor* input_to_output_weights, - const TfLiteTensor* recurrent_to_input_weights, - const TfLiteTensor* recurrent_to_forget_weights, - const TfLiteTensor* recurrent_to_cell_weights, - const TfLiteTensor* recurrent_to_output_weights, - const TfLiteTensor* cell_to_input_weights, - const TfLiteTensor* cell_to_forget_weights, - const TfLiteTensor* cell_to_output_weights, - const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, - const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, - const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, - const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer, - TfLiteTensor* activation_state, TfLiteTensor* cell_state, - TfLiteTensor* output) { - const int n_batch = input->dims->data[0]; - const int n_input = input->dims->data[1]; - // n_cell and n_output will be the same size when there is no projection. 
- const int n_cell = input_to_output_weights->dims->data[0]; - const int n_output = recurrent_to_output_weights->dims->data[1]; - - // Since we have already checked that weights are all there or none, we can - // check the existence of only one to get the condition. - const bool use_cifg = (input_to_input_weights == nullptr); - const bool use_peephole = (cell_to_output_weights != nullptr); - - float* input_gate_scratch = nullptr; - float* cell_scratch = nullptr; - float* forget_gate_scratch = nullptr; - float* output_gate_scratch = nullptr; - if (use_cifg) { - cell_scratch = scratch_buffer->data.f; - forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch; - output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; - } else { - input_gate_scratch = scratch_buffer->data.f; - cell_scratch = scratch_buffer->data.f + n_cell * n_batch; - forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; - output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; - } - - // Check optional tensors, the respective pointers can be null. - const float* input_to_input_weights_ptr = - (use_cifg) ? nullptr : input_to_input_weights->data.f; - const float* recurrent_to_input_weights_ptr = - (use_cifg) ? nullptr : recurrent_to_input_weights->data.f; - const float* input_gate_bias_ptr = - (use_cifg) ? nullptr : input_gate_bias->data.f; - const float* cell_to_input_weights_ptr = - (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr; - const float* cell_to_forget_weights_ptr = - (use_peephole) ? cell_to_forget_weights->data.f : nullptr; - const float* cell_to_output_weights_ptr = - (use_peephole) ? cell_to_output_weights->data.f : nullptr; - const float* projection_weights_ptr = - (projection_weights == nullptr) ? nullptr : projection_weights->data.f; - const float* projection_bias_ptr = - (projection_bias == nullptr) ? nullptr : projection_bias->data.f; - - // Required tensors, pointers are non-null. 
- const float* input_ptr_batch = input->data.f; - const float* input_to_forget_weights_ptr = input_to_forget_weights->data.f; - const float* input_to_cell_weights_ptr = input_to_cell_weights->data.f; - const float* input_to_output_weights_ptr = input_to_output_weights->data.f; - const float* recurrent_to_forget_weights_ptr = - recurrent_to_forget_weights->data.f; - const float* recurrent_to_cell_weights_ptr = - recurrent_to_cell_weights->data.f; - const float* recurrent_to_output_weights_ptr = - recurrent_to_output_weights->data.f; - const float* forget_gate_bias_ptr = forget_gate_bias->data.f; - const float* cell_bias_ptr = cell_bias->data.f; - const float* output_gate_bias_ptr = output_gate_bias->data.f; - - float* activation_state_ptr = activation_state->data.f; - float* cell_state_ptr = cell_state->data.f; - float* output_ptr_batch = output->data.f; - - kernel_utils::LstmStep( - input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr, - input_to_cell_weights_ptr, input_to_output_weights_ptr, - recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr, - recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr, - cell_to_input_weights_ptr, cell_to_forget_weights_ptr, - cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr, - cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr, - projection_bias_ptr, params, n_batch, n_cell, n_input, n_output, - activation_state_ptr, cell_state_ptr, input_gate_scratch, - forget_gate_scratch, cell_scratch, output_gate_scratch, output_ptr_batch); - - return kTfLiteOk; -} - -TfLiteStatus EvalHybrid( - const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, - const TfLiteTensor* input_to_forget_weights, - const TfLiteTensor* input_to_cell_weights, - const TfLiteTensor* input_to_output_weights, - const TfLiteTensor* recurrent_to_input_weights, - const TfLiteTensor* recurrent_to_forget_weights, - const TfLiteTensor* recurrent_to_cell_weights, - const TfLiteTensor* recurrent_to_output_weights, - const TfLiteTensor* cell_to_input_weights, - const TfLiteTensor* cell_to_forget_weights, - const TfLiteTensor* cell_to_output_weights, - const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, - const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, - const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, - const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer, - TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors, - TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized, - TfLiteTensor* activation_state_quantized, - TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state, - TfLiteTensor* cell_state, TfLiteTensor* output) { - const int n_batch = input->dims->data[0]; - const int n_input = input->dims->data[1]; - // n_cell and n_output will be the same size when there is no projection. - const int n_cell = input_to_output_weights->dims->data[0]; - const int n_output = recurrent_to_output_weights->dims->data[1]; - - // Since we have already checked that weights are all there or none, we can - // check the existence of only one to get the condition. 
- const bool use_cifg = (input_to_input_weights == nullptr); - const bool use_peephole = (cell_to_output_weights != nullptr); - - float* input_gate_scratch = nullptr; - float* cell_scratch = nullptr; - float* forget_gate_scratch = nullptr; - float* output_gate_scratch = nullptr; - if (use_cifg) { - cell_scratch = scratch_buffer->data.f; - forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch; - output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; - } else { - input_gate_scratch = scratch_buffer->data.f; - cell_scratch = scratch_buffer->data.f + n_cell * n_batch; - forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; - output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; - } - - // Check optional tensors, the respective pointers can be null. - int8_t* input_to_input_weights_ptr = nullptr; - float input_to_input_weights_scale = 1.0f; - int8_t* recurrent_to_input_weights_ptr = nullptr; - float recurrent_to_input_weights_scale = 1.0f; - float* input_gate_bias_ptr = nullptr; - if (!use_cifg) { - input_to_input_weights_ptr = - reinterpret_cast(input_to_input_weights->data.uint8); - recurrent_to_input_weights_ptr = - reinterpret_cast(recurrent_to_input_weights->data.uint8); - input_gate_bias_ptr = input_gate_bias->data.f; - input_to_input_weights_scale = input_to_input_weights->params.scale; - recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale; - } - - int8_t* cell_to_input_weights_ptr = nullptr; - int8_t* cell_to_forget_weights_ptr = nullptr; - int8_t* cell_to_output_weights_ptr = nullptr; - float cell_to_input_weights_scale = 1.0f; - float cell_to_forget_weights_scale = 1.0f; - float cell_to_output_weights_scale = 1.0f; - if (use_peephole) { - if (!use_cifg) { - cell_to_input_weights_ptr = - reinterpret_cast(cell_to_input_weights->data.uint8); - cell_to_input_weights_scale = cell_to_input_weights->params.scale; - } - cell_to_forget_weights_ptr = - reinterpret_cast(cell_to_forget_weights->data.uint8); - cell_to_output_weights_ptr = - reinterpret_cast(cell_to_output_weights->data.uint8); - cell_to_forget_weights_scale = cell_to_forget_weights->params.scale; - cell_to_output_weights_scale = cell_to_output_weights->params.scale; - } - - const int8_t* projection_weights_ptr = - (projection_weights == nullptr) - ? nullptr - : reinterpret_cast(projection_weights->data.uint8); - const float projection_weights_scale = - (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale; - const float* projection_bias_ptr = - (projection_bias == nullptr) ? nullptr : projection_bias->data.f; - - // Required tensors, pointers are non-null. 
- const float* input_ptr_batch = input->data.f; - const int8_t* input_to_forget_weights_ptr = - reinterpret_cast(input_to_forget_weights->data.uint8); - const float input_to_forget_weights_scale = - input_to_forget_weights->params.scale; - const int8_t* input_to_cell_weights_ptr = - reinterpret_cast(input_to_cell_weights->data.uint8); - const float input_to_cell_weights_scale = input_to_cell_weights->params.scale; - const int8_t* input_to_output_weights_ptr = - reinterpret_cast(input_to_output_weights->data.uint8); - const float input_to_output_weights_scale = - input_to_output_weights->params.scale; - const int8_t* recurrent_to_forget_weights_ptr = - reinterpret_cast(recurrent_to_forget_weights->data.uint8); - const float recurrent_to_forget_weights_scale = - recurrent_to_forget_weights->params.scale; - const int8_t* recurrent_to_cell_weights_ptr = - reinterpret_cast(recurrent_to_cell_weights->data.uint8); - const float recurrent_to_cell_weights_scale = - recurrent_to_cell_weights->params.scale; - const int8_t* recurrent_to_output_weights_ptr = - reinterpret_cast(recurrent_to_output_weights->data.uint8); - const float recurrent_to_output_weights_scale = - recurrent_to_output_weights->params.scale; - const float* forget_gate_bias_ptr = forget_gate_bias->data.f; - const float* cell_bias_ptr = cell_bias->data.f; - const float* output_gate_bias_ptr = output_gate_bias->data.f; - - float* activation_state_ptr = activation_state->data.f; - float* cell_state_ptr = cell_state->data.f; - float* output_ptr_batch = output->data.f; - - // Temporary storage for quantized values and scaling factors. - int8_t* quantized_input_ptr = - reinterpret_cast(input_quantized->data.uint8); - int8_t* quantized_activation_state_ptr = - reinterpret_cast(activation_state_quantized->data.uint8); - int8_t* quantized_cell_state_ptr = - reinterpret_cast(cell_state_quantized->data.uint8); - float* scaling_factors_ptr = scaling_factors->data.f; - float* prod_scaling_factors_ptr = prod_scaling_factors->data.f; - float* recovered_cell_weights_ptr = recovered_cell_weights->data.f; - - kernel_utils::LstmStep( - input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale, - input_to_forget_weights_ptr, input_to_forget_weights_scale, - input_to_cell_weights_ptr, input_to_cell_weights_scale, - input_to_output_weights_ptr, input_to_output_weights_scale, - recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale, - recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale, - recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale, - recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale, - cell_to_input_weights_ptr, cell_to_input_weights_scale, - cell_to_forget_weights_ptr, cell_to_forget_weights_scale, - cell_to_output_weights_ptr, cell_to_output_weights_scale, - input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, - output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale, - projection_bias_ptr, params, n_batch, n_cell, n_input, n_output, - input_gate_scratch, forget_gate_scratch, cell_scratch, - output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr, - recovered_cell_weights_ptr, quantized_input_ptr, - quantized_activation_state_ptr, quantized_cell_state_ptr, - activation_state_ptr, cell_state_ptr, output_ptr_batch); - - return kTfLiteOk; -} - TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const auto* params = reinterpret_cast(node->builtin_data); OpData* op_data = reinterpret_cast(node->user_data); @@ -738,15 +482,21 @@ TfLiteStatus 
Eval(TfLiteContext* context, TfLiteNode* node) { // TODO(mirkov): add a check that weights are all uint8s or all floats. switch (input_to_output_weights->type) { case kTfLiteFloat32: { - return EvalFloat(input, input_to_input_weights, input_to_forget_weights, - input_to_cell_weights, input_to_output_weights, - recurrent_to_input_weights, recurrent_to_forget_weights, - recurrent_to_cell_weights, recurrent_to_output_weights, - cell_to_input_weights, cell_to_forget_weights, - cell_to_output_weights, input_gate_bias, - forget_gate_bias, cell_bias, output_gate_bias, - projection_weights, projection_bias, params, - scratch_buffer, activation_state, cell_state, output); + return lstm_eval::EvalFloat( + input, input_to_input_weights, input_to_forget_weights, + input_to_cell_weights, input_to_output_weights, + recurrent_to_input_weights, recurrent_to_forget_weights, + recurrent_to_cell_weights, recurrent_to_output_weights, + cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights, + /*aux_input=*/nullptr, + /*aux_input_to_input_weights=*/nullptr, + /*aux_input_to_forget_weights=*/nullptr, + /*aux_input_to_cell_weights=*/nullptr, + /*aux_input_to_output_weights=*/nullptr, input_gate_bias, + forget_gate_bias, cell_bias, output_gate_bias, projection_weights, + projection_bias, params, /*forward_sequence=*/true, + /*output_offset=*/0, scratch_buffer, activation_state, cell_state, + output); } case kTfLiteUInt8: { TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); @@ -759,17 +509,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTemporary(context, node, /*index=*/5); TfLiteTensor* recovered_cell_weights = GetTemporary(context, node, /*index=*/6); - return EvalHybrid( + return lstm_eval::EvalHybrid( input, input_to_input_weights, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_input_weights, recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights, cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights, - input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias, - projection_weights, projection_bias, params, scratch_buffer, - scaling_factors, prod_scaling_factors, recovered_cell_weights, - input_quantized, activation_state_quantized, cell_state_quantized, - activation_state, cell_state, output); + /*aux_input=*/nullptr, + /*aux_input_to_input_weights=*/nullptr, + /*aux_input_to_forget_weights=*/nullptr, + /*aux_input_to_cell_weights=*/nullptr, + /*aux_input_to_output_weights=*/nullptr, input_gate_bias, + forget_gate_bias, cell_bias, output_gate_bias, projection_weights, + projection_bias, params, /*forward_sequence=*/true, + /*output_offset=*/0, scratch_buffer, scaling_factors, + prod_scaling_factors, recovered_cell_weights, input_quantized, + /*aux_input_quantized=*/nullptr, activation_state_quantized, + cell_state_quantized, activation_state, cell_state, output); } default: context->ReportError(context, "Type %d is not currently supported.", diff --git a/tensorflow/contrib/lite/kernels/lstm_eval.cc b/tensorflow/contrib/lite/kernels/lstm_eval.cc new file mode 100644 index 00000000000..c6c21eb0850 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/lstm_eval.cc @@ -0,0 +1,909 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/lstm_eval.h" + +#include + +#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace lstm_eval { + +namespace { + +// Performs an LSTM batch inference step for input specified by input_ptr_batch. +// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and +// biases (*_bias_ptr), and buffers (*_scratch), along with additional +// parameters: +// - params: various LSTM params including activation, clipping, etc., +// - n_batch: size of batch, +// - n_cell: number of cells (or units), +// - n_input: the input size, +// - n_output: the output size. +// +// The pointers to the cell and output state and the output are updated. +// +// The pointers with the suffix "_batch" point to data aligned in batch_major +// order, and each step processes batch_size many inputs from input_ptr_batch, +// and updates batch_size many cell and output states. +inline void LstmStepWithAuxInput( + const float* input_ptr_batch, const float* input_to_input_weights_ptr, + const float* input_to_forget_weights_ptr, + const float* input_to_cell_weights_ptr, + const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch, + const float* aux_input_to_input_weights_ptr, + const float* aux_input_to_forget_weights_ptr, + const float* aux_input_to_cell_weights_ptr, + const float* aux_input_to_output_weights_ptr, + const float* recurrent_to_input_weights_ptr, + const float* recurrent_to_forget_weights_ptr, + const float* recurrent_to_cell_weights_ptr, + const float* recurrent_to_output_weights_ptr, + const float* cell_to_input_weights_ptr, + const float* cell_to_forget_weights_ptr, + const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr, + const float* forget_gate_bias_ptr, const float* cell_bias_ptr, + const float* output_gate_bias_ptr, const float* projection_weights_ptr, + const float* projection_bias_ptr, const TfLiteLSTMParams* params, + int n_batch, int n_cell, int n_input, int n_aux_input, int n_output, + float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch, + float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, + float* output_ptr_batch) { + // Since we have already checked that weights are all there or none, we can + // check the existense of only one to the get the condition. + const bool use_cifg = (input_to_input_weights_ptr == nullptr); + const bool use_peephole = (cell_to_output_weights_ptr != nullptr); + // Initialize scratch buffers with bias. 
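+  // Each *_gate_scratch buffer accumulates one gate's pre-activation for the
+  // whole batch: it starts at the gate bias, and the matrix-vector products
+  // below add W_x * input, optionally W_aux * aux_input, and W_h * output_state
+  // before the nonlinearity is applied.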
+ if (!use_cifg) { + tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch, + input_gate_scratch); + } + tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch, + forget_gate_scratch); + tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch, + cell_scratch); + tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch, + output_gate_scratch); + + // For each batch and cell: compute input_weight * input. + if (!use_cifg) { + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, + input_gate_scratch, /*result_stride=*/1); + } + + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, + forget_gate_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, + cell_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, + output_gate_scratch, /*result_stride=*/1); + + // If auxiliary input is available then compute aux_input_weight * aux_input + if (aux_input_ptr_batch != nullptr) { + if (!use_cifg) { + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + aux_input_to_input_weights_ptr, n_cell, n_aux_input, + aux_input_ptr_batch, n_batch, input_gate_scratch, + /*result_stride=*/1); + } + + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + aux_input_to_forget_weights_ptr, n_cell, n_aux_input, + aux_input_ptr_batch, n_batch, forget_gate_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr_batch, + n_batch, cell_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + aux_input_to_output_weights_ptr, n_cell, n_aux_input, + aux_input_ptr_batch, n_batch, output_gate_scratch, /*result_stride=*/1); + } + + // For each batch and cell: compute recurrent_weight * output_state. + if (!use_cifg) { + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr, + n_batch, input_gate_scratch, /*result_stride=*/1); + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr, + n_batch, forget_gate_scratch, + /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr, + n_batch, cell_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr, + n_batch, output_gate_scratch, + /*result_stride=*/1); + + // For each batch and cell: update input gate. + if (!use_cifg) { + if (use_peephole) { + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch, + input_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, + input_gate_scratch); + } + + // For each batch and cell: update forget gate. 
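+  // The forget gate optionally adds a peephole term (cell_to_forget_weights
+  // element-wise with the previous cell state) before the logistic activation
+  // turns the accumulated pre-activation into the gate value.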
+  if (use_peephole) {
+    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+        cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
+        forget_gate_scratch);
+  }
+  tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
+                                     forget_gate_scratch);
+
+  // For each batch and cell: update the cell.
+  tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
+                                         n_batch * n_cell, cell_state_ptr);
+  tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
+                                        params->activation, cell_scratch);
+  if (use_cifg) {
+    tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
+                             forget_gate_scratch);
+    tensor_utils::VectorVectorCwiseProductAccumulate(
+        cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
+  } else {
+    tensor_utils::VectorVectorCwiseProductAccumulate(
+        cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
+  }
+  if (params->cell_clip > 0.0) {
+    tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
+                             params->cell_clip, cell_state_ptr);
+  }
+
+  // For each batch and cell: update the output gate.
+  if (use_peephole) {
+    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+        cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
+        output_gate_scratch);
+  }
+  tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
+                                     output_gate_scratch);
+  tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
+                                        params->activation, cell_scratch);
+  tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+                                         n_batch * n_cell, output_gate_scratch);
+
+  // For each batch: update the projection and output_state.
+  const bool use_projection_weight = (projection_weights_ptr != nullptr);
+  const bool use_projection_bias = (projection_bias_ptr != nullptr);
+  if (use_projection_weight) {
+    if (use_projection_bias) {
+      tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
+                                            n_batch, output_ptr_batch);
+    } else {
+      tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
+    }
+    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+        projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch,
+        output_ptr_batch, /*result_stride=*/1);
+    if (params->proj_clip > 0.0) {
+      tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
+                               params->proj_clip, output_ptr_batch);
+    }
+  } else {
+    tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
+                             output_ptr_batch);
+  }
+  tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
+                           output_state_ptr);
+}
+
+// Same as above but with quantized weight matrices. In detail:
+// Input of size 'n_batch * n_input':
+//   input_ptr_batch
+//
+// LSTM weights:
+// Quantized input weights of size 'n_cell * n_input':
+//   input_to_input_weights - optional (can be nullptr)
+//   input_to_forget_weights
+//   input_to_cell_weights
+//   input_to_output_weights
+// Quantized recurrent weights of size 'n_cell * n_output':
+//   recurrent_to_input_weights - optional
+//   recurrent_to_forget_weights
+//   recurrent_to_cell_weights
+//   recurrent_to_output_weights
+// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
+//   cell_to_input_weights - optional
+//   cell_to_forget_weights - optional
+//   cell_to_output_weights - optional
+// Quantized projection weights of size 'n_output * n_cell'
+//   projection_weights_ptr - optional
+// Weight scales (scalars) for each of the weights above.
+// input_to_input_weights_scale - optional +// input_to_forget_weights_scale +// input_to_cell_weights_scale +// input_to_output_weights_scale +// recurrent_to_input_weights_scale - optional +// recurrent_to_forget_weights_scale +// recurrent_to_cell_weights_scale +// recurrent_to_output_weights_scale +// cell_to_input_weights_scale, +// cell_to_forget_weights_scale, +// cell_to_output_weights_scale, +// projection_weights_scale - optional +// Gate biases of size 'n_cell': +// input_gate_bias_ptr - optional +// forget_gate_bias_ptr +// cell_gate_bias_ptr +// output_gate_bias_ptr +// +// Temporary pre-allocated storage for quantized values: +// quantized_input_ptr_batch (same size as input_ptr_batch) +// quantized_output_state_ptr (same size as output_state_ptr) +// quantized_cell_state_ptr (same size as cell_state_ptr) +// Temporary pre-allocated storage for recovered values: +// recovered_cell_weights (same size as cell_to_*_weights) +// +// Outputs: +// output_state_ptr - size 'n_batch * n_output' +// cell_state_ptr - size 'n_batch * n_cell' +// output_ptr_batch - size 'n_batch * n_output' +inline void LstmStepWithAuxInput( + const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr, + float input_to_input_weights_scale, + const int8_t* input_to_forget_weights_ptr, + float input_to_forget_weights_scale, + const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale, + const int8_t* input_to_output_weights_ptr, + float input_to_output_weights_scale, const float* aux_input_ptr_batch, + const int8_t* aux_input_to_input_weights_ptr, + float aux_input_to_input_weights_scale, + const int8_t* aux_input_to_forget_weights_ptr, + float aux_input_to_forget_weights_scale, + const int8_t* aux_input_to_cell_weights_ptr, + float aux_input_to_cell_weights_scale, + const int8_t* aux_input_to_output_weights_ptr, + float aux_input_to_output_weights_scale, + const int8_t* recurrent_to_input_weights_ptr, + float recurrent_to_input_weights_scale, + const int8_t* recurrent_to_forget_weights_ptr, + float recurrent_to_forget_weights_scale, + const int8_t* recurrent_to_cell_weights_ptr, + float recurrent_to_cell_weights_scale, + const int8_t* recurrent_to_output_weights_ptr, + float recurrent_to_output_weights_scale, + const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale, + const int8_t* cell_to_forget_weights_ptr, + float cell_to_forget_weights_scale, + const int8_t* cell_to_output_weights_ptr, + float cell_to_output_weights_scale, const float* input_gate_bias_ptr, + const float* forget_gate_bias_ptr, const float* cell_bias_ptr, + const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr, + float projection_weights_scale, const float* projection_bias_ptr, + const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, + int n_aux_input, int n_output, float* input_gate_scratch, + float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, + float* scaling_factors, float* product_scaling_factors, + float* recovered_cell_weights, int8_t* quantized_input_ptr_batch, + int8_t* quantized_aux_input_ptr_batch, int8_t* quantized_output_state_ptr, + int8_t* quantized_cell_state_ptr, float* output_state_ptr, + float* cell_state_ptr, float* output_ptr_batch) { + // Since we have already checked that weights are all there or none, we + // can check the existense of only one to the get the condition. 
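+  // Hybrid path: activations stay in float while the weight matrices are
+  // int8. Each float vector is symmetrically quantized per batch row, the
+  // matrix-vector product is accumulated in the quantized domain, and the
+  // result is scaled back by scaling_factors[b] times the weight tensor's
+  // scale when it is added into the float scratch buffers. Peephole
+  // (diagonal) weights are instead dequantized into recovered_cell_weights
+  // before use.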
+ const bool use_cifg = (input_to_input_weights_ptr == nullptr); + const bool use_peephole = (cell_to_output_weights_ptr != nullptr); + // Initialize scratch buffers with bias. + if (!use_cifg) { + tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch, + input_gate_scratch); + } + tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch, + forget_gate_scratch); + tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch, + cell_scratch); + tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch, + output_gate_scratch); + + if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) { + // Save quantization and matmul computation for all zero input. + float unused_min, unused_max; + for (int b = 0; b < n_batch; ++b) { + const int offset = b * n_input; + tensor_utils::SymmetricQuantizeFloats( + input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset, + &unused_min, &unused_max, &scaling_factors[b]); + } + // For each batch and cell: compute input_weight * input. + if (!use_cifg) { + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_input_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_input_weights_ptr, n_cell, n_input, + quantized_input_ptr_batch, product_scaling_factors, n_batch, + input_gate_scratch, /*result_stride=*/1); + } + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_forget_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, + product_scaling_factors, n_batch, forget_gate_scratch, + /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_cell_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, + product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_output_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, + product_scaling_factors, n_batch, output_gate_scratch, + /*result_stride=*/1); + } + + if (aux_input_ptr_batch != nullptr && + !tensor_utils::IsZeroVector(aux_input_ptr_batch, n_batch * n_input)) { + // Save quantization and matmul computation for all zero input. + float unused_min, unused_max; + for (int b = 0; b < n_batch; ++b) { + const int offset = b * n_input; + tensor_utils::SymmetricQuantizeFloats( + aux_input_ptr_batch + offset, n_input, + quantized_aux_input_ptr_batch + offset, &unused_min, &unused_max, + &scaling_factors[b]); + } + // For each batch and cell: compute input_weight * input. 
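+    // (Here the "input" is the auxiliary input: the same pattern as above,
+    // using the aux_input_to_* weights and their scales.)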
+ if (!use_cifg) { + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * aux_input_to_input_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + aux_input_to_input_weights_ptr, n_cell, n_input, + quantized_aux_input_ptr_batch, product_scaling_factors, n_batch, + input_gate_scratch, /*result_stride=*/1); + } + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * aux_input_to_forget_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + aux_input_to_forget_weights_ptr, n_cell, n_input, + quantized_aux_input_ptr_batch, product_scaling_factors, n_batch, + forget_gate_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * aux_input_to_cell_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + aux_input_to_cell_weights_ptr, n_cell, n_input, + quantized_aux_input_ptr_batch, product_scaling_factors, n_batch, + cell_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * aux_input_to_output_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + aux_input_to_output_weights_ptr, n_cell, n_input, + quantized_aux_input_ptr_batch, product_scaling_factors, n_batch, + output_gate_scratch, /*result_stride=*/1); + } + + if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) { + // Save quantization and matmul computation for all zero input. + float unused_min, unused_max; + for (int b = 0; b < n_batch; ++b) { + const int offset = b * n_output; + tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output, + quantized_output_state_ptr + offset, + &unused_min, &unused_max, + &scaling_factors[b]); + } + // For each batch and cell: compute recurrent_weight * output_state. + if (!use_cifg) { + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_input_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_input_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + input_gate_scratch, /*result_stride=*/1); + } + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_forget_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_forget_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + forget_gate_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_cell_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_cell_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + cell_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_output_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_output_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + output_gate_scratch, /*result_stride=*/1); + } + + // Save quantization and matmul computation for all zero input. + bool is_cell_state_all_zeros = + tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell); + + // For each batch and cell: update input gate. 
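+  // The peephole contribution is skipped when the cell state is all zeros,
+  // since dequantizing cell_to_input_weights and accumulating it would only
+  // add zeros; the sigmoid then turns the accumulated pre-activation into the
+  // gate value.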
+ if (!use_cifg) { + if (use_peephole && !is_cell_state_all_zeros) { + tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell, + cell_to_input_weights_scale, + recovered_cell_weights); + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + recovered_cell_weights, n_cell, cell_state_ptr, n_batch, + input_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, + input_gate_scratch); + } + + // For each batch and cell: update forget gate. + if (use_peephole && !is_cell_state_all_zeros) { + tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell, + cell_to_forget_weights_scale, + recovered_cell_weights); + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + recovered_cell_weights, n_cell, cell_state_ptr, n_batch, + forget_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, + forget_gate_scratch); + + // For each batch and cell: update the cell. + tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr, + n_batch * n_cell, cell_state_ptr); + tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, + params->activation, cell_scratch); + if (use_cifg) { + tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, + forget_gate_scratch); + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr); + } else { + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr); + } + if (params->cell_clip > 0.0) { + tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, + params->cell_clip, cell_state_ptr); + } + + is_cell_state_all_zeros = + tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell); + // For each batch and cell: update the output gate. + if (use_peephole && !is_cell_state_all_zeros) { + tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell, + cell_to_output_weights_scale, + recovered_cell_weights); + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + recovered_cell_weights, n_cell, cell_state_ptr, n_batch, + output_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, + output_gate_scratch); + tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell, + params->activation, cell_scratch); + tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, + n_batch * n_cell, output_gate_scratch); + + // For each batch: update the projection and output_state. + const bool use_projection_weight = (projection_weights_ptr != nullptr); + const bool use_projection_bias = (projection_bias_ptr != nullptr); + if (use_projection_weight) { + if (use_projection_bias) { + tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output, + n_batch, output_ptr_batch); + } else { + tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output); + } + if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) { + // Save quantization and matmul computation for all zero input. 
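+      // At this point output_gate_scratch holds the pre-projection output
+      // o_t * activation(c_t); it is quantized per batch row (reusing
+      // quantized_cell_state_ptr as temporary storage) and multiplied by the
+      // projection weights with scaling_factors[b] * projection_weights_scale.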
+ float unused_min, unused_max; + for (int b = 0; b < n_batch; ++b) { + const int offset = b * n_cell; + tensor_utils::SymmetricQuantizeFloats( + output_gate_scratch + offset, n_cell, + quantized_cell_state_ptr + offset, &unused_min, &unused_max, + &scaling_factors[b]); + } + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * projection_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr, + product_scaling_factors, n_batch, output_ptr_batch, + /*result_stride=*/1); + } + if (params->proj_clip > 0.0) { + tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, + params->proj_clip, output_ptr_batch); + } + } else { + tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, + output_ptr_batch); + } + tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output, + output_state_ptr); +} +} // namespace + +TfLiteStatus EvalFloat( + const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, + const TfLiteTensor* input_to_forget_weights, + const TfLiteTensor* input_to_cell_weights, + const TfLiteTensor* input_to_output_weights, + const TfLiteTensor* recurrent_to_input_weights, + const TfLiteTensor* recurrent_to_forget_weights, + const TfLiteTensor* recurrent_to_cell_weights, + const TfLiteTensor* recurrent_to_output_weights, + const TfLiteTensor* cell_to_input_weights, + const TfLiteTensor* cell_to_forget_weights, + const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input, + const TfLiteTensor* aux_input_to_input_weights, + const TfLiteTensor* aux_input_to_forget_weights, + const TfLiteTensor* aux_input_to_cell_weights, + const TfLiteTensor* aux_input_to_output_weights, + const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, + const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, + const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, + const TfLiteLSTMParams* params, bool forward_sequence, int output_offset, + TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state, + TfLiteTensor* cell_state, TfLiteTensor* output) { + const int max_time = (input->dims->size == 2) ? 1 : input->dims->data[0]; + const int n_batch = input->dims->data[input->dims->size - 2]; + const int n_input = input->dims->data[input->dims->size - 1]; + const int aux_input_size = + (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0; + + // n_cell and n_output will be the same size when there is no projection. + const int n_cell = input_to_output_weights->dims->data[0]; + const int n_output = recurrent_to_output_weights->dims->data[1]; + + // Since we have already checked that weights are all there or none, we can + // check the existense of only one to the get the condition. + const bool use_cifg = (input_to_input_weights == nullptr); + const bool use_peephole = (cell_to_output_weights != nullptr); + + // Index the scratch buffers pointers to the global scratch buffer. 
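+  // Layout of scratch_buffer, in slabs of n_cell * n_batch floats:
+  //   with CIFG:    [cell | forget gate | output gate]
+  //   without CIFG: [input gate | cell | forget gate | output gate]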
+ float* input_gate_scratch = nullptr; + float* cell_scratch = nullptr; + float* forget_gate_scratch = nullptr; + float* output_gate_scratch = nullptr; + if (use_cifg) { + cell_scratch = scratch_buffer->data.f; + forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + } else { + input_gate_scratch = scratch_buffer->data.f; + cell_scratch = scratch_buffer->data.f + n_cell * n_batch; + forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; + } + + // Check optional tensors, the respective pointers can be null. + const float* input_to_input_weights_ptr = + (use_cifg) ? nullptr : input_to_input_weights->data.f; + const float* recurrent_to_input_weights_ptr = + (use_cifg) ? nullptr : recurrent_to_input_weights->data.f; + const float* input_gate_bias_ptr = + (use_cifg) ? nullptr : input_gate_bias->data.f; + const float* cell_to_input_weights_ptr = + (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr; + const float* cell_to_forget_weights_ptr = + (use_peephole) ? cell_to_forget_weights->data.f : nullptr; + const float* cell_to_output_weights_ptr = + (use_peephole) ? cell_to_output_weights->data.f : nullptr; + const float* projection_weights_ptr = + (projection_weights == nullptr) ? nullptr : projection_weights->data.f; + const float* projection_bias_ptr = + (projection_bias == nullptr) ? nullptr : projection_bias->data.f; + + float* aux_input_ptr = nullptr; + float* aux_input_to_input_weights_ptr = nullptr; + float* aux_input_to_forget_weights_ptr = nullptr; + float* aux_input_to_cell_weights_ptr = nullptr; + float* aux_input_to_output_weights_ptr = nullptr; + if (aux_input_size > 0) { + aux_input_ptr = aux_input->data.f; + aux_input_to_input_weights_ptr = aux_input_to_input_weights->data.f; + aux_input_to_forget_weights_ptr = aux_input_to_forget_weights->data.f; + aux_input_to_cell_weights_ptr = aux_input_to_cell_weights->data.f; + aux_input_to_output_weights_ptr = aux_input_to_output_weights->data.f; + } + + // Loop through the sequence. + const int input_step = n_batch * n_input; + const int output_step = n_batch * output->dims->data[output->dims->size - 1]; + for (int t = 0; t < max_time; t++) { + // If this is the forward_sequence, step forward, otherwise step backwards. + const int t_rel = forward_sequence ? 
t : max_time - t - 1; + const float* input_ptr = input->data.f + t_rel * input_step; + float* output_ptr_time = + output->data.f + t_rel * output_step + output_offset; + + LstmStepWithAuxInput( + input_ptr, input_to_input_weights_ptr, input_to_forget_weights->data.f, + input_to_cell_weights->data.f, input_to_output_weights->data.f, + aux_input_ptr, aux_input_to_input_weights_ptr, + aux_input_to_forget_weights_ptr, aux_input_to_cell_weights_ptr, + aux_input_to_output_weights_ptr, recurrent_to_input_weights_ptr, + recurrent_to_forget_weights->data.f, recurrent_to_cell_weights->data.f, + recurrent_to_output_weights->data.f, cell_to_input_weights_ptr, + cell_to_forget_weights_ptr, cell_to_output_weights_ptr, + input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f, + output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr, + params, n_batch, n_cell, n_input, aux_input_size, n_output, + activation_state->data.f, cell_state->data.f, input_gate_scratch, + forget_gate_scratch, cell_scratch, output_gate_scratch, + output_ptr_time); + } + return kTfLiteOk; +} + +TfLiteStatus EvalHybrid( + const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, + const TfLiteTensor* input_to_forget_weights, + const TfLiteTensor* input_to_cell_weights, + const TfLiteTensor* input_to_output_weights, + const TfLiteTensor* recurrent_to_input_weights, + const TfLiteTensor* recurrent_to_forget_weights, + const TfLiteTensor* recurrent_to_cell_weights, + const TfLiteTensor* recurrent_to_output_weights, + const TfLiteTensor* cell_to_input_weights, + const TfLiteTensor* cell_to_forget_weights, + const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input, + const TfLiteTensor* aux_input_to_input_weights, + const TfLiteTensor* aux_input_to_forget_weights, + const TfLiteTensor* aux_input_to_cell_weights, + const TfLiteTensor* aux_input_to_output_weights, + const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, + const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, + const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, + const TfLiteLSTMParams* params, bool forward_sequence, int output_offset, + TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors, + TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights, + TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized, + TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized, + TfLiteTensor* output_state, TfLiteTensor* cell_state, + TfLiteTensor* output) { + const int max_time = (input->dims->size == 2) ? 1 : input->dims->data[0]; + const int n_batch = input->dims->data[input->dims->size - 2]; + const int n_input = input->dims->data[input->dims->size - 1]; + const int aux_input_size = + (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0; + // n_cell and n_output will be the same size when there is no projection. + const int n_cell = input_to_output_weights->dims->data[0]; + const int n_output = recurrent_to_output_weights->dims->data[1]; + + // Since we have already checked that weights are all there or none, we can + // check the existence of only one to get the condition. 
+  const bool use_cifg = (input_to_input_weights == nullptr);
+  const bool use_peephole = (cell_to_output_weights != nullptr);
+
+  float* input_gate_scratch = nullptr;
+  float* cell_scratch = nullptr;
+  float* forget_gate_scratch = nullptr;
+  float* output_gate_scratch = nullptr;
+  if (use_cifg) {
+    cell_scratch = scratch_buffer->data.f;
+    forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+    output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+  } else {
+    input_gate_scratch = scratch_buffer->data.f;
+    cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+    forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+    output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+  }
+
+  // Check optional tensors, the respective pointers can be null.
+  int8_t* input_to_input_weights_ptr = nullptr;
+  float input_to_input_weights_scale = 1.0f;
+  int8_t* recurrent_to_input_weights_ptr = nullptr;
+  float recurrent_to_input_weights_scale = 1.0f;
+  float* input_gate_bias_ptr = nullptr;
+  if (!use_cifg) {
+    input_to_input_weights_ptr =
+        reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
+    recurrent_to_input_weights_ptr =
+        reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
+    input_gate_bias_ptr = input_gate_bias->data.f;
+    input_to_input_weights_scale = input_to_input_weights->params.scale;
+    recurrent_to_input_weights_scale =
+        recurrent_to_input_weights->params.scale;
+  }
+
+  int8_t* cell_to_input_weights_ptr = nullptr;
+  int8_t* cell_to_forget_weights_ptr = nullptr;
+  int8_t* cell_to_output_weights_ptr = nullptr;
+  float cell_to_input_weights_scale = 1.0f;
+  float cell_to_forget_weights_scale = 1.0f;
+  float cell_to_output_weights_scale = 1.0f;
+  if (use_peephole) {
+    if (!use_cifg) {
+      cell_to_input_weights_ptr =
+          reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
+      cell_to_input_weights_scale = cell_to_input_weights->params.scale;
+    }
+    cell_to_forget_weights_ptr =
+        reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
+    cell_to_output_weights_ptr =
+        reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
+    cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
+    cell_to_output_weights_scale = cell_to_output_weights->params.scale;
+  }
+
+  const int8_t* projection_weights_ptr =
+      (projection_weights == nullptr)
+          ? nullptr
+          : reinterpret_cast<const int8_t*>(projection_weights->data.uint8);
+  const float projection_weights_scale =
+      (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
+  const float* projection_bias_ptr =
+      (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+  // Required tensors, pointers are non-null.
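+  // The quantized weight tensors are stored as uint8 buffers reinterpreted as
+  // int8; only their per-tensor scale (params.scale) is used here, no zero
+  // point is applied.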
+  const int8_t* input_to_forget_weights_ptr =
+      reinterpret_cast<const int8_t*>(input_to_forget_weights->data.uint8);
+  const float input_to_forget_weights_scale =
+      input_to_forget_weights->params.scale;
+  const int8_t* input_to_cell_weights_ptr =
+      reinterpret_cast<const int8_t*>(input_to_cell_weights->data.uint8);
+  const float input_to_cell_weights_scale =
+      input_to_cell_weights->params.scale;
+  const int8_t* input_to_output_weights_ptr =
+      reinterpret_cast<const int8_t*>(input_to_output_weights->data.uint8);
+  const float input_to_output_weights_scale =
+      input_to_output_weights->params.scale;
+  const int8_t* recurrent_to_forget_weights_ptr =
+      reinterpret_cast<const int8_t*>(recurrent_to_forget_weights->data.uint8);
+  const float recurrent_to_forget_weights_scale =
+      recurrent_to_forget_weights->params.scale;
+  const int8_t* recurrent_to_cell_weights_ptr =
+      reinterpret_cast<const int8_t*>(recurrent_to_cell_weights->data.uint8);
+  const float recurrent_to_cell_weights_scale =
+      recurrent_to_cell_weights->params.scale;
+  const int8_t* recurrent_to_output_weights_ptr =
+      reinterpret_cast<const int8_t*>(recurrent_to_output_weights->data.uint8);
+  const float recurrent_to_output_weights_scale =
+      recurrent_to_output_weights->params.scale;
+  const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
+  const float* cell_bias_ptr = cell_bias->data.f;
+  const float* output_gate_bias_ptr = output_gate_bias->data.f;
+
+  float* output_state_ptr = output_state->data.f;
+  float* cell_state_ptr = cell_state->data.f;
+
+  // Temporary storage for quantized values and scaling factors.
+  int8_t* quantized_input_ptr =
+      reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+  int8_t* quantized_aux_input_ptr =
+      (aux_input_quantized == nullptr)
+          ? nullptr
+          : reinterpret_cast<int8_t*>(aux_input_quantized->data.uint8);
+  int8_t* quantized_output_state_ptr =
+      reinterpret_cast<int8_t*>(output_state_quantized->data.uint8);
+  int8_t* quantized_cell_state_ptr =
+      reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
+  float* scaling_factors_ptr = scaling_factors->data.f;
+  float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
+  float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;
+
+  // Auxiliary input and weights.
+  float* aux_input_ptr = nullptr;
+  int8_t* aux_input_to_input_weights_ptr = nullptr;
+  int8_t* aux_input_to_forget_weights_ptr = nullptr;
+  int8_t* aux_input_to_cell_weights_ptr = nullptr;
+  int8_t* aux_input_to_output_weights_ptr = nullptr;
+  float aux_input_to_input_weights_scale = 0.0f;
+  float aux_input_to_forget_weights_scale = 0.0f;
+  float aux_input_to_cell_weights_scale = 0.0f;
+  float aux_input_to_output_weights_scale = 0.0f;
+  if (aux_input_size > 0) {
+    aux_input_ptr = aux_input->data.f;
+    aux_input_to_input_weights_ptr =
+        reinterpret_cast<int8_t*>(aux_input_to_input_weights->data.uint8);
+    aux_input_to_forget_weights_ptr =
+        reinterpret_cast<int8_t*>(aux_input_to_forget_weights->data.uint8);
+    aux_input_to_cell_weights_ptr =
+        reinterpret_cast<int8_t*>(aux_input_to_cell_weights->data.uint8);
+    aux_input_to_output_weights_ptr =
+        reinterpret_cast<int8_t*>(aux_input_to_output_weights->data.uint8);
+    aux_input_to_input_weights_scale =
+        aux_input_to_input_weights->params.scale;
+    aux_input_to_forget_weights_scale =
+        aux_input_to_forget_weights->params.scale;
+    aux_input_to_cell_weights_scale = aux_input_to_cell_weights->params.scale;
+    aux_input_to_output_weights_scale =
+        aux_input_to_output_weights->params.scale;
+  }
+
+  // Feed the sequence into the LSTM step-by-step.
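+  // t_rel walks the sequence forward when forward_sequence is true and in
+  // reverse otherwise; output_offset shifts where each step's result is
+  // written within the output row (the plain LSTM kernel passes 0).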
+ const int input_step = n_batch * n_input; + const int output_step = n_batch * output->dims->data[output->dims->size - 1]; + for (int t = 0; t < max_time; t++) { + // If this is the forward_sequence, step forward, otherwise step backwards. + const int t_rel = forward_sequence ? t : max_time - t - 1; + const float* input_ptr = input->data.f + t_rel * input_step; + float* output_ptr = output->data.f + t_rel * output_step + output_offset; + + LstmStepWithAuxInput( + input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale, + input_to_forget_weights_ptr, input_to_forget_weights_scale, + input_to_cell_weights_ptr, input_to_cell_weights_scale, + input_to_output_weights_ptr, input_to_output_weights_scale, + aux_input_ptr, aux_input_to_input_weights_ptr, + aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr, + aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr, + aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr, + aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr, + recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr, + recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr, + recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr, + recurrent_to_output_weights_scale, cell_to_input_weights_ptr, + cell_to_input_weights_scale, cell_to_forget_weights_ptr, + cell_to_forget_weights_scale, cell_to_output_weights_ptr, + cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr, + cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr, + projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell, + n_input, aux_input_size, n_output, input_gate_scratch, + forget_gate_scratch, cell_scratch, output_gate_scratch, + scaling_factors_ptr, prod_scaling_factors_ptr, + recovered_cell_weights_ptr, quantized_input_ptr, + quantized_aux_input_ptr, quantized_output_state_ptr, + quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, output_ptr); + } + + return kTfLiteOk; +} + +} // namespace lstm_eval +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/lstm_eval.h b/tensorflow/contrib/lite/kernels/lstm_eval.h new file mode 100644 index 00000000000..adf8cf0f645 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/lstm_eval.h @@ -0,0 +1,79 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_LSTM_EVAL_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_LSTM_EVAL_H_ + +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace lstm_eval { + +TfLiteStatus EvalFloat( + const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, + const TfLiteTensor* input_to_forget_weights, + const TfLiteTensor* input_to_cell_weights, + const TfLiteTensor* input_to_output_weights, + const TfLiteTensor* recurrent_to_input_weights, + const TfLiteTensor* recurrent_to_forget_weights, + const TfLiteTensor* recurrent_to_cell_weights, + const TfLiteTensor* recurrent_to_output_weights, + const TfLiteTensor* cell_to_input_weights, + const TfLiteTensor* cell_to_forget_weights, + const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input, + const TfLiteTensor* aux_input_to_input_weights, + const TfLiteTensor* aux_input_to_forget_weights, + const TfLiteTensor* aux_input_to_cell_weights, + const TfLiteTensor* aux_input_to_output_weights, + const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, + const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, + const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, + const TfLiteLSTMParams* params, bool forward_sequence, int output_offset, + TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state, + TfLiteTensor* cell_state, TfLiteTensor* output); + +TfLiteStatus EvalHybrid( + const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, + const TfLiteTensor* input_to_forget_weights, + const TfLiteTensor* input_to_cell_weights, + const TfLiteTensor* input_to_output_weights, + const TfLiteTensor* recurrent_to_input_weights, + const TfLiteTensor* recurrent_to_forget_weights, + const TfLiteTensor* recurrent_to_cell_weights, + const TfLiteTensor* recurrent_to_output_weights, + const TfLiteTensor* cell_to_input_weights, + const TfLiteTensor* cell_to_forget_weights, + const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input, + const TfLiteTensor* aux_input_to_input_weights, + const TfLiteTensor* aux_input_to_forget_weights, + const TfLiteTensor* aux_input_to_cell_weights, + const TfLiteTensor* aux_input_to_output_weights, + const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, + const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, + const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, + const TfLiteLSTMParams* params, bool forward_sequence, int output_offset, + TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors, + TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights, + TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized, + TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized, + TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output); + +} // namespace lstm_eval +} // namespace builtin +} // namespace ops +} // namespace tflite +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_LSTM_EVAL_H_ diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc index 63817bd8865..ec9cf38b831 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc +++ 
b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/lstm_eval.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" namespace tflite { @@ -429,273 +430,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -// The LSTM Op engine. -TfLiteStatus EvalFloat( - const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, - const TfLiteTensor* input_to_forget_weights, - const TfLiteTensor* input_to_cell_weights, - const TfLiteTensor* input_to_output_weights, - const TfLiteTensor* recurrent_to_input_weights, - const TfLiteTensor* recurrent_to_forget_weights, - const TfLiteTensor* recurrent_to_cell_weights, - const TfLiteTensor* recurrent_to_output_weights, - const TfLiteTensor* cell_to_input_weights, - const TfLiteTensor* cell_to_forget_weights, - const TfLiteTensor* cell_to_output_weights, - const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, - const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, - const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, - const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer, - TfLiteTensor* activation_state, TfLiteTensor* cell_state, - TfLiteTensor* output) { - const int max_time = input->dims->data[0]; - const int n_batch = input->dims->data[1]; - const int n_input = input->dims->data[2]; - // n_cell and n_output will be the same size when there is no projection. - const int n_cell = input_to_output_weights->dims->data[0]; - const int n_output = recurrent_to_output_weights->dims->data[1]; - - // Since we have already checked that weights are all there or none, we can - // check the existence of only one to get the condition. - const bool use_cifg = (input_to_input_weights == nullptr); - const bool use_peephole = (cell_to_output_weights != nullptr); - - float* input_gate_scratch = nullptr; - float* cell_scratch = nullptr; - float* forget_gate_scratch = nullptr; - float* output_gate_scratch = nullptr; - if (use_cifg) { - cell_scratch = scratch_buffer->data.f; - forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch; - output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; - } else { - input_gate_scratch = scratch_buffer->data.f; - cell_scratch = scratch_buffer->data.f + n_cell * n_batch; - forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; - output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; - } - - // Check optional tensors, the respective pointers can be null. - const float* input_to_input_weights_ptr = - (use_cifg) ? nullptr : input_to_input_weights->data.f; - const float* recurrent_to_input_weights_ptr = - (use_cifg) ? nullptr : recurrent_to_input_weights->data.f; - const float* input_gate_bias_ptr = - (use_cifg) ? nullptr : input_gate_bias->data.f; - const float* cell_to_input_weights_ptr = - (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr; - const float* cell_to_forget_weights_ptr = - (use_peephole) ? cell_to_forget_weights->data.f : nullptr; - const float* cell_to_output_weights_ptr = - (use_peephole) ? cell_to_output_weights->data.f : nullptr; - const float* projection_weights_ptr = - (projection_weights == nullptr) ? 
nullptr : projection_weights->data.f; - const float* projection_bias_ptr = - (projection_bias == nullptr) ? nullptr : projection_bias->data.f; - - // Required tensors, pointers are non-null. - const float* input_to_forget_weights_ptr = input_to_forget_weights->data.f; - const float* input_to_cell_weights_ptr = input_to_cell_weights->data.f; - const float* input_to_output_weights_ptr = input_to_output_weights->data.f; - const float* recurrent_to_forget_weights_ptr = - recurrent_to_forget_weights->data.f; - const float* recurrent_to_cell_weights_ptr = - recurrent_to_cell_weights->data.f; - const float* recurrent_to_output_weights_ptr = - recurrent_to_output_weights->data.f; - const float* forget_gate_bias_ptr = forget_gate_bias->data.f; - const float* cell_bias_ptr = cell_bias->data.f; - const float* output_gate_bias_ptr = output_gate_bias->data.f; - - float* activation_state_ptr = activation_state->data.f; - float* cell_state_ptr = cell_state->data.f; - - // Feed the sequence into the LSTM step-by-step. - for (int t = 0; t < max_time; t++) { - const float* input_ptr_batch = input->data.f + t * n_batch * n_input; - float* output_ptr_batch = output->data.f + t * n_batch * n_output; - - kernel_utils::LstmStep( - input_ptr_batch, input_to_input_weights_ptr, - input_to_forget_weights_ptr, input_to_cell_weights_ptr, - input_to_output_weights_ptr, recurrent_to_input_weights_ptr, - recurrent_to_forget_weights_ptr, recurrent_to_cell_weights_ptr, - recurrent_to_output_weights_ptr, cell_to_input_weights_ptr, - cell_to_forget_weights_ptr, cell_to_output_weights_ptr, - input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, - output_gate_bias_ptr, projection_weights_ptr, projection_bias_ptr, - params, n_batch, n_cell, n_input, n_output, activation_state_ptr, - cell_state_ptr, input_gate_scratch, forget_gate_scratch, cell_scratch, - output_gate_scratch, output_ptr_batch); - } - return kTfLiteOk; -} - -TfLiteStatus EvalHybrid( - const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, - const TfLiteTensor* input_to_forget_weights, - const TfLiteTensor* input_to_cell_weights, - const TfLiteTensor* input_to_output_weights, - const TfLiteTensor* recurrent_to_input_weights, - const TfLiteTensor* recurrent_to_forget_weights, - const TfLiteTensor* recurrent_to_cell_weights, - const TfLiteTensor* recurrent_to_output_weights, - const TfLiteTensor* cell_to_input_weights, - const TfLiteTensor* cell_to_forget_weights, - const TfLiteTensor* cell_to_output_weights, - const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, - const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, - const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, - const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer, - TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors, - TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized, - TfLiteTensor* activation_state_quantized, - TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state, - TfLiteTensor* cell_state, TfLiteTensor* output) { - const int max_time = input->dims->data[0]; - const int n_batch = input->dims->data[1]; - const int n_input = input->dims->data[2]; - // n_cell and n_output will be the same size when there is no projection. 
- const int n_cell = input_to_output_weights->dims->data[0]; - const int n_output = recurrent_to_output_weights->dims->data[1]; - - // Since we have already checked that weights are all there or none, we can - // check the existence of only one to get the condition. - const bool use_cifg = (input_to_input_weights == nullptr); - const bool use_peephole = (cell_to_output_weights != nullptr); - - float* input_gate_scratch = nullptr; - float* cell_scratch = nullptr; - float* forget_gate_scratch = nullptr; - float* output_gate_scratch = nullptr; - if (use_cifg) { - cell_scratch = scratch_buffer->data.f; - forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch; - output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; - } else { - input_gate_scratch = scratch_buffer->data.f; - cell_scratch = scratch_buffer->data.f + n_cell * n_batch; - forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; - output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; - } - - // Check optional tensors, the respective pointers can be null. - int8_t* input_to_input_weights_ptr = nullptr; - float input_to_input_weights_scale = 1.0f; - int8_t* recurrent_to_input_weights_ptr = nullptr; - float recurrent_to_input_weights_scale = 1.0f; - float* input_gate_bias_ptr = nullptr; - if (!use_cifg) { - input_to_input_weights_ptr = - reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8); - recurrent_to_input_weights_ptr = - reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8); - input_gate_bias_ptr = input_gate_bias->data.f; - input_to_input_weights_scale = input_to_input_weights->params.scale; - recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale; - } - - int8_t* cell_to_input_weights_ptr = nullptr; - int8_t* cell_to_forget_weights_ptr = nullptr; - int8_t* cell_to_output_weights_ptr = nullptr; - float cell_to_input_weights_scale = 1.0f; - float cell_to_forget_weights_scale = 1.0f; - float cell_to_output_weights_scale = 1.0f; - if (use_peephole) { - if (!use_cifg) { - cell_to_input_weights_ptr = - reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8); - cell_to_input_weights_scale = cell_to_input_weights->params.scale; - } - cell_to_forget_weights_ptr = - reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8); - cell_to_output_weights_ptr = - reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8); - cell_to_forget_weights_scale = cell_to_forget_weights->params.scale; - cell_to_output_weights_scale = cell_to_output_weights->params.scale; - } - - const int8_t* projection_weights_ptr = - (projection_weights == nullptr) - ? nullptr - : reinterpret_cast<const int8_t*>(projection_weights->data.uint8); - float projection_weights_scale = - (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale; - const float* projection_bias_ptr = - (projection_bias == nullptr) ? nullptr : projection_bias->data.f; - - // Required tensors, pointers are non-null.
- const int8_t* input_to_forget_weights_ptr = - reinterpret_cast<const int8_t*>(input_to_forget_weights->data.uint8); - const float input_to_forget_weights_scale = - input_to_forget_weights->params.scale; - const int8_t* input_to_cell_weights_ptr = - reinterpret_cast<const int8_t*>(input_to_cell_weights->data.uint8); - const float input_to_cell_weights_scale = input_to_cell_weights->params.scale; - const int8_t* input_to_output_weights_ptr = - reinterpret_cast<const int8_t*>(input_to_output_weights->data.uint8); - const float input_to_output_weights_scale = - input_to_output_weights->params.scale; - const int8_t* recurrent_to_forget_weights_ptr = - reinterpret_cast<const int8_t*>(recurrent_to_forget_weights->data.uint8); - const float recurrent_to_forget_weights_scale = - recurrent_to_forget_weights->params.scale; - const int8_t* recurrent_to_cell_weights_ptr = - reinterpret_cast<const int8_t*>(recurrent_to_cell_weights->data.uint8); - const float recurrent_to_cell_weights_scale = - recurrent_to_cell_weights->params.scale; - const int8_t* recurrent_to_output_weights_ptr = - reinterpret_cast<const int8_t*>(recurrent_to_output_weights->data.uint8); - const float recurrent_to_output_weights_scale = - recurrent_to_output_weights->params.scale; - const float* forget_gate_bias_ptr = forget_gate_bias->data.f; - const float* cell_bias_ptr = cell_bias->data.f; - const float* output_gate_bias_ptr = output_gate_bias->data.f; - - float* activation_state_ptr = activation_state->data.f; - float* cell_state_ptr = cell_state->data.f; - - // Temporary storage for quantized values and scaling factors. - int8_t* quantized_input_ptr = - reinterpret_cast<int8_t*>(input_quantized->data.uint8); - int8_t* quantized_activation_state_ptr = - reinterpret_cast<int8_t*>(activation_state_quantized->data.uint8); - int8_t* quantized_cell_state_ptr = - reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8); - float* scaling_factors_ptr = scaling_factors->data.f; - float* prod_scaling_factors_ptr = prod_scaling_factors->data.f; - float* recovered_cell_weights_ptr = recovered_cell_weights->data.f; - - // Feed the sequence into the LSTM step-by-step.
- for (int t = 0; t < max_time; t++) { - const float* input_ptr_batch = input->data.f + t * n_batch * n_input; - float* output_ptr_batch = output->data.f + t * n_batch * n_output; - - kernel_utils::LstmStep( - input_ptr_batch, input_to_input_weights_ptr, - input_to_input_weights_scale, input_to_forget_weights_ptr, - input_to_forget_weights_scale, input_to_cell_weights_ptr, - input_to_cell_weights_scale, input_to_output_weights_ptr, - input_to_output_weights_scale, recurrent_to_input_weights_ptr, - recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr, - recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr, - recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr, - recurrent_to_output_weights_scale, cell_to_input_weights_ptr, - cell_to_input_weights_scale, cell_to_forget_weights_ptr, - cell_to_forget_weights_scale, cell_to_output_weights_ptr, - cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr, - cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr, - projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell, - n_input, n_output, input_gate_scratch, forget_gate_scratch, - cell_scratch, output_gate_scratch, scaling_factors_ptr, - prod_scaling_factors_ptr, recovered_cell_weights_ptr, - quantized_input_ptr, quantized_activation_state_ptr, - quantized_cell_state_ptr, activation_state_ptr, cell_state_ptr, - output_ptr_batch); - } - return kTfLiteOk; -} - TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data); const TfLiteTensor* input = GetInput(context, node, kInputTensor); @@ -750,15 +484,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { switch (input_to_output_weights->type) { case kTfLiteFloat32: { - return EvalFloat(input, input_to_input_weights, input_to_forget_weights, - input_to_cell_weights, input_to_output_weights, - recurrent_to_input_weights, recurrent_to_forget_weights, - recurrent_to_cell_weights, recurrent_to_output_weights, - cell_to_input_weights, cell_to_forget_weights, - cell_to_output_weights, input_gate_bias, - forget_gate_bias, cell_bias, output_gate_bias, - projection_weights, projection_bias, params, - scratch_buffer, activation_state, cell_state, output); + return lstm_eval::EvalFloat( + input, input_to_input_weights, input_to_forget_weights, + input_to_cell_weights, input_to_output_weights, + recurrent_to_input_weights, recurrent_to_forget_weights, + recurrent_to_cell_weights, recurrent_to_output_weights, + cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights, + /*aux_input=*/nullptr, + /*aux_input_to_input_weights=*/nullptr, + /*aux_input_to_forget_weights=*/nullptr, + /*aux_input_to_cell_weights=*/nullptr, + /*aux_input_to_output_weights=*/nullptr, input_gate_bias, + forget_gate_bias, cell_bias, output_gate_bias, projection_weights, + projection_bias, params, /*forward_sequence=*/true, + /*output_offset=*/0, scratch_buffer, activation_state, cell_state, + output); } case kTfLiteUInt8: { TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); @@ -771,17 +511,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTemporary(context, node, /*index=*/5); TfLiteTensor* recovered_cell_weights = GetTemporary(context, node, /*index=*/6); - return EvalHybrid( + return lstm_eval::EvalHybrid( input, input_to_input_weights, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_input_weights, recurrent_to_forget_weights, recurrent_to_cell_weights,
recurrent_to_output_weights, cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights, - input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias, - projection_weights, projection_bias, params, scratch_buffer, - scaling_factors, prod_scaling_factors, recovered_cell_weights, - input_quantized, activation_state_quantized, cell_state_quantized, - activation_state, cell_state, output); + /*aux_input=*/nullptr, + /*aux_input_to_input_weights=*/nullptr, + /*aux_input_to_forget_weights=*/nullptr, + /*aux_input_to_cell_weights=*/nullptr, + /*aux_input_to_output_weights=*/nullptr, input_gate_bias, + forget_gate_bias, cell_bias, output_gate_bias, projection_weights, + projection_bias, params, /*forward_sequence=*/true, + /*output_offset=*/0, scratch_buffer, scaling_factors, + prod_scaling_factors, recovered_cell_weights, input_quantized, + /*aux_input_quantized=*/nullptr, activation_state_quantized, + cell_state_quantized, activation_state, cell_state, output); } default: context->ReportError(context, "Type %d is not currently supported.",