Merge the different LSTM EvalFloat/EvalHybrid calls into a single file.
PiperOrigin-RevId: 215870962
parent 57d31aa599
commit 3b94d75a9e
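The refactor moves the EvalFloat/EvalHybrid helpers out of the bidirectional-sequence LSTM kernel into a shared lstm_eval library (see the new cc_library rule and the lstm_eval:: call sites in the hunks below). For orientation, here is a minimal sketch of what the consolidated lstm_eval.h plausibly declares; the parameter lists are abridged to their leading and trailing arguments (the full signatures appear verbatim in the diff), and the namespace nesting and include path are assumptions, not taken from this commit.

// lstm_eval.h -- sketch only, abridged; exact contents are an assumption.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_LSTM_EVAL_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_LSTM_EVAL_H_

#include "tensorflow/contrib/lite/c/c_api_internal.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace lstm_eval {

// Float path: runs one sequential LSTM pass over `input` and writes `output`.
TfLiteStatus EvalFloat(
    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
    /* ... remaining weight, bias, and auxiliary-input tensors ... */
    const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
    TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
    TfLiteTensor* cell_state, TfLiteTensor* output);

// Hybrid path: float activations with symmetrically quantized int8 weights.
TfLiteStatus EvalHybrid(
    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
    /* ... remaining weight, bias, scratch, and quantization temporaries ... */
    TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output);

}  // namespace lstm_eval
}  // namespace builtin
}  // namespace ops
}  // namespace tflite

#endif  // TENSORFLOW_CONTRIB_LITE_KERNELS_LSTM_EVAL_H_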
@@ -234,11 +234,11 @@ cc_library(
":activation_functor",
":eigen_support",
":kernel_util",
":lstm_eval",
":op_macros",
":padding",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:string_util",
"//tensorflow/contrib/lite:util",
"//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:gemm_support",
"//tensorflow/contrib/lite/kernels/internal:audio_utils",
@@ -254,6 +254,17 @@ cc_library(
],
)

cc_library(
name = "lstm_eval",
srcs = ["lstm_eval.cc"],
hdrs = ["lstm_eval.h"],
deps = [
"//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels/internal:kernel_utils",
"//tensorflow/contrib/lite/kernels/internal:tensor_utils",
],
)

cc_library(
name = "builtin_ops",
srcs = ["register.cc"],
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/lstm_eval.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"

namespace tflite {
@@ -694,330 +695,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}

TfLiteStatus EvalFloat(
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
const TfLiteTensor* input_to_forget_weights,
const TfLiteTensor* input_to_cell_weights,
const TfLiteTensor* input_to_output_weights,
const TfLiteTensor* recurrent_to_input_weights,
const TfLiteTensor* recurrent_to_forget_weights,
const TfLiteTensor* recurrent_to_cell_weights,
const TfLiteTensor* recurrent_to_output_weights,
const TfLiteTensor* cell_to_input_weights,
const TfLiteTensor* cell_to_forget_weights,
const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
const TfLiteTensor* aux_input_to_input_weights,
const TfLiteTensor* aux_input_to_forget_weights,
const TfLiteTensor* aux_input_to_cell_weights,
const TfLiteTensor* aux_input_to_output_weights,
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
TfLiteTensor* cell_state, TfLiteTensor* output) {
const int max_time = input->dims->data[0];
const int n_batch = input->dims->data[1];
const int n_input = input->dims->data[2];
const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;

// n_cell and n_output will be the same size when there is no projection.
const int n_cell = input_to_output_weights->dims->data[0];
const int n_output = recurrent_to_output_weights->dims->data[1];

// Since we have already checked that weights are all there or none, we can
// check the existence of only one to get the condition.
const bool use_cifg = (input_to_input_weights == nullptr);
const bool use_peephole = (cell_to_output_weights != nullptr);

// Index the scratch buffers pointers to the global scratch buffer.
float* input_gate_scratch = nullptr;
float* cell_scratch = nullptr;
float* forget_gate_scratch = nullptr;
float* output_gate_scratch = nullptr;
if (use_cifg) {
cell_scratch = scratch_buffer->data.f;
forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
} else {
input_gate_scratch = scratch_buffer->data.f;
cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
}

// Check optional tensors, the respective pointers can be null.
const float* input_to_input_weights_ptr =
(use_cifg) ? nullptr : input_to_input_weights->data.f;
const float* recurrent_to_input_weights_ptr =
(use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
const float* input_gate_bias_ptr =
(use_cifg) ? nullptr : input_gate_bias->data.f;
const float* cell_to_input_weights_ptr =
(use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
const float* cell_to_forget_weights_ptr =
(use_peephole) ? cell_to_forget_weights->data.f : nullptr;
const float* cell_to_output_weights_ptr =
(use_peephole) ? cell_to_output_weights->data.f : nullptr;
const float* projection_weights_ptr =
(projection_weights == nullptr) ? nullptr : projection_weights->data.f;
const float* projection_bias_ptr =
(projection_bias == nullptr) ? nullptr : projection_bias->data.f;

float* aux_input_ptr = nullptr;
float* aux_input_to_input_weights_ptr = nullptr;
float* aux_input_to_forget_weights_ptr = nullptr;
float* aux_input_to_cell_weights_ptr = nullptr;
float* aux_input_to_output_weights_ptr = nullptr;
if (aux_input_size > 0) {
aux_input_ptr = aux_input->data.f;
aux_input_to_input_weights_ptr = aux_input_to_input_weights->data.f;
aux_input_to_forget_weights_ptr = aux_input_to_forget_weights->data.f;
aux_input_to_cell_weights_ptr = aux_input_to_cell_weights->data.f;
aux_input_to_output_weights_ptr = aux_input_to_output_weights->data.f;
}

// Loop through the sequence.
const int input_step = n_batch * n_input;
const int output_step = n_batch * output->dims->data[2];
for (int t = 0; t < max_time; t++) {
// If this is the forward_sequence, step forward, otherwise step backwards.
const int t_rel = forward_sequence ? t : max_time - t - 1;
const float* input_ptr = input->data.f + t_rel * input_step;
float* output_ptr_time =
output->data.f + t_rel * output_step + output_offset;

kernel_utils::LstmStepWithAuxInput(
input_ptr, input_to_input_weights_ptr, input_to_forget_weights->data.f,
input_to_cell_weights->data.f, input_to_output_weights->data.f,
aux_input_ptr, aux_input_to_input_weights_ptr,
aux_input_to_forget_weights_ptr, aux_input_to_cell_weights_ptr,
aux_input_to_output_weights_ptr, recurrent_to_input_weights_ptr,
recurrent_to_forget_weights->data.f, recurrent_to_cell_weights->data.f,
recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
params, n_batch, n_cell, n_input, aux_input_size, n_output,
activation_state->data.f, cell_state->data.f, input_gate_scratch,
forget_gate_scratch, cell_scratch, output_gate_scratch,
output_ptr_time);
}
return kTfLiteOk;
}

TfLiteStatus EvalHybrid(
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
const TfLiteTensor* input_to_forget_weights,
const TfLiteTensor* input_to_cell_weights,
const TfLiteTensor* input_to_output_weights,
const TfLiteTensor* recurrent_to_input_weights,
const TfLiteTensor* recurrent_to_forget_weights,
const TfLiteTensor* recurrent_to_cell_weights,
const TfLiteTensor* recurrent_to_output_weights,
const TfLiteTensor* cell_to_input_weights,
const TfLiteTensor* cell_to_forget_weights,
const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
const TfLiteTensor* aux_input_to_input_weights,
const TfLiteTensor* aux_input_to_forget_weights,
const TfLiteTensor* aux_input_to_cell_weights,
const TfLiteTensor* aux_input_to_output_weights,
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
TfLiteTensor* output_state, TfLiteTensor* cell_state,
TfLiteTensor* output) {
const int max_time = input->dims->data[0];
const int n_batch = input->dims->data[1];
const int n_input = input->dims->data[2];
const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;
// n_cell and n_output will be the same size when there is no projection.
const int n_cell = input_to_output_weights->dims->data[0];
const int n_output = recurrent_to_output_weights->dims->data[1];

// Since we have already checked that weights are all there or none, we can
// check the existence of only one to get the condition.
const bool use_cifg = (input_to_input_weights == nullptr);
const bool use_peephole = (cell_to_output_weights != nullptr);

float* input_gate_scratch = nullptr;
float* cell_scratch = nullptr;
float* forget_gate_scratch = nullptr;
float* output_gate_scratch = nullptr;
if (use_cifg) {
cell_scratch = scratch_buffer->data.f;
forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
} else {
input_gate_scratch = scratch_buffer->data.f;
cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
}

// Check optional tensors, the respective pointers can be null.
int8_t* input_to_input_weights_ptr = nullptr;
float input_to_input_weights_scale = 1.0f;
int8_t* recurrent_to_input_weights_ptr = nullptr;
float recurrent_to_input_weights_scale = 1.0f;
float* input_gate_bias_ptr = nullptr;
if (!use_cifg) {
input_to_input_weights_ptr =
reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
recurrent_to_input_weights_ptr =
reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
input_gate_bias_ptr = input_gate_bias->data.f;
input_to_input_weights_scale = input_to_input_weights->params.scale;
recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
}

int8_t* cell_to_input_weights_ptr = nullptr;
int8_t* cell_to_forget_weights_ptr = nullptr;
int8_t* cell_to_output_weights_ptr = nullptr;
float cell_to_input_weights_scale = 1.0f;
float cell_to_forget_weights_scale = 1.0f;
float cell_to_output_weights_scale = 1.0f;
if (use_peephole) {
if (!use_cifg) {
cell_to_input_weights_ptr =
reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
cell_to_input_weights_scale = cell_to_input_weights->params.scale;
}
cell_to_forget_weights_ptr =
reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
cell_to_output_weights_ptr =
reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
cell_to_output_weights_scale = cell_to_output_weights->params.scale;
}

const int8_t* projection_weights_ptr =
(projection_weights == nullptr)
? nullptr
: reinterpret_cast<int8_t*>(projection_weights->data.uint8);
const float projection_weights_scale =
(projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
const float* projection_bias_ptr =
(projection_bias == nullptr) ? nullptr : projection_bias->data.f;

// Required tensors, pointers are non-null.
const int8_t* input_to_forget_weights_ptr =
reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
const float input_to_forget_weights_scale =
input_to_forget_weights->params.scale;
const int8_t* input_to_cell_weights_ptr =
reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
const int8_t* input_to_output_weights_ptr =
reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
const float input_to_output_weights_scale =
input_to_output_weights->params.scale;
const int8_t* recurrent_to_forget_weights_ptr =
reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
const float recurrent_to_forget_weights_scale =
recurrent_to_forget_weights->params.scale;
const int8_t* recurrent_to_cell_weights_ptr =
reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
const float recurrent_to_cell_weights_scale =
recurrent_to_cell_weights->params.scale;
const int8_t* recurrent_to_output_weights_ptr =
reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
const float recurrent_to_output_weights_scale =
recurrent_to_output_weights->params.scale;
const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
const float* cell_bias_ptr = cell_bias->data.f;
const float* output_gate_bias_ptr = output_gate_bias->data.f;

float* output_state_ptr = output_state->data.f;
float* cell_state_ptr = cell_state->data.f;

// Temporary storage for quantized values and scaling factors.
int8_t* quantized_input_ptr =
reinterpret_cast<int8_t*>(input_quantized->data.uint8);
int8_t* quantized_aux_input_ptr =
(aux_input_quantized == nullptr)
? nullptr
: reinterpret_cast<int8_t*>(aux_input_quantized->data.uint8);
int8_t* quantized_output_state_ptr =
reinterpret_cast<int8_t*>(output_state_quantized->data.uint8);
int8_t* quantized_cell_state_ptr =
reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
float* scaling_factors_ptr = scaling_factors->data.f;
float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;

// Auxiliary input and weights.
float* aux_input_ptr = nullptr;
int8_t* aux_input_to_input_weights_ptr = nullptr;
int8_t* aux_input_to_forget_weights_ptr = nullptr;
int8_t* aux_input_to_cell_weights_ptr = nullptr;
int8_t* aux_input_to_output_weights_ptr = nullptr;
float aux_input_to_input_weights_scale = 0.0f;
float aux_input_to_forget_weights_scale = 0.0f;
float aux_input_to_cell_weights_scale = 0.0f;
float aux_input_to_output_weights_scale = 0.0f;
if (aux_input_size > 0) {
aux_input_ptr = aux_input->data.f;
aux_input_to_input_weights_ptr =
reinterpret_cast<int8_t*>(aux_input_to_input_weights->data.uint8);
aux_input_to_forget_weights_ptr =
reinterpret_cast<int8_t*>(aux_input_to_forget_weights->data.uint8);
aux_input_to_cell_weights_ptr =
reinterpret_cast<int8_t*>(aux_input_to_cell_weights->data.uint8);
aux_input_to_output_weights_ptr =
reinterpret_cast<int8_t*>(aux_input_to_output_weights->data.uint8);
aux_input_to_input_weights_scale = aux_input_to_input_weights->params.scale;
aux_input_to_forget_weights_scale =
aux_input_to_forget_weights->params.scale;
aux_input_to_cell_weights_scale = aux_input_to_cell_weights->params.scale;
aux_input_to_output_weights_scale =
aux_input_to_output_weights->params.scale;
}

// Feed the sequence into the LSTM step-by-step.
const int input_step = n_batch * n_input;
const int output_step = n_batch * output->dims->data[2];
for (int t = 0; t < max_time; t++) {
// If this is the forward_sequence, step forward, otherwise step backwards.
const int t_rel = forward_sequence ? t : max_time - t - 1;
const float* input_ptr = input->data.f + t_rel * input_step;
float* output_ptr = output->data.f + t_rel * output_step + output_offset;

kernel_utils::LstmStepWithAuxInput(
input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
input_to_forget_weights_ptr, input_to_forget_weights_scale,
input_to_cell_weights_ptr, input_to_cell_weights_scale,
input_to_output_weights_ptr, input_to_output_weights_scale,
aux_input_ptr, aux_input_to_input_weights_ptr,
aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
cell_to_input_weights_scale, cell_to_forget_weights_ptr,
cell_to_forget_weights_scale, cell_to_output_weights_ptr,
cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr,
cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell,
n_input, aux_input_size, n_output, input_gate_scratch,
forget_gate_scratch, cell_scratch, output_gate_scratch,
scaling_factors_ptr, prod_scaling_factors_ptr,
recovered_cell_weights_ptr, quantized_input_ptr,
quantized_aux_input_ptr, quantized_output_state_ptr,
quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, output_ptr);
}

return kTfLiteOk;
}

// The LSTM Op engine.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
@@ -1157,7 +834,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {

switch (fw_input_to_output_weights->type) {
case kTfLiteFloat32: {
TfLiteStatus fw_pass_status = EvalFloat(
TfLiteStatus fw_pass_status = lstm_eval::EvalFloat(
input, fw_input_to_input_weights, fw_input_to_forget_weights,
fw_input_to_cell_weights, fw_input_to_output_weights,
fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
@@ -1172,7 +849,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
fw_activation_state, fw_cell_state, fw_output);
TF_LITE_ENSURE_OK(context, fw_pass_status);

TfLiteStatus bw_pass_status = EvalFloat(
TfLiteStatus bw_pass_status = lstm_eval::EvalFloat(
input, bw_input_to_input_weights, bw_input_to_forget_weights,
bw_input_to_cell_weights, bw_input_to_output_weights,
bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
@@ -1208,7 +885,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* recovered_cell_weights =
GetTemporary(context, node, kRecoveredCellWeights);

TfLiteStatus fw_pass_status = EvalHybrid(
TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid(
input, fw_input_to_input_weights, fw_input_to_forget_weights,
fw_input_to_cell_weights, fw_input_to_output_weights,
fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
@@ -1226,7 +903,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
fw_output);
TF_LITE_ENSURE_OK(context, fw_pass_status);

TfLiteStatus bw_pass_status = EvalHybrid(
TfLiteStatus bw_pass_status = lstm_eval::EvalHybrid(
input, bw_input_to_input_weights, bw_input_to_forget_weights,
bw_input_to_cell_weights, bw_input_to_output_weights,
bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
@@ -169,603 +169,5 @@ void RnnBatchStep(
hidden_state_ptr_batch);
}

void LstmStep(
const float* input_ptr_batch, const float* input_to_input_weights_ptr,
const float* input_to_forget_weights_ptr,
const float* input_to_cell_weights_ptr,
const float* input_to_output_weights_ptr,
const float* recurrent_to_input_weights_ptr,
const float* recurrent_to_forget_weights_ptr,
const float* recurrent_to_cell_weights_ptr,
const float* recurrent_to_output_weights_ptr,
const float* cell_to_input_weights_ptr,
const float* cell_to_forget_weights_ptr,
const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
const float* output_gate_bias_ptr, const float* projection_weights_ptr,
const float* projection_bias_ptr, const TfLiteLSTMParams* params,
int n_batch, int n_cell, int n_input, int n_output, float* output_state_ptr,
float* cell_state_ptr, float* input_gate_scratch,
float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
float* output_ptr_batch) {
LstmStepWithAuxInput(
input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr,
input_to_cell_weights_ptr, input_to_output_weights_ptr,
/*aux_input_ptr_batch=*/nullptr,
/*aux_input_to_input_weights_ptr=*/nullptr,
/*aux_input_to_forget_weights_ptr=*/nullptr,
/*aux_input_to_cell_weights_ptr=*/nullptr,
/*aux_input_to_output_weights_ptr=*/nullptr,
recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
cell_to_input_weights_ptr, cell_to_forget_weights_ptr,
cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr,
cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
projection_bias_ptr, params, n_batch, n_cell, n_input, /*n_aux_input=*/0,
n_output, output_state_ptr, cell_state_ptr, input_gate_scratch,
forget_gate_scratch, cell_scratch, output_gate_scratch, output_ptr_batch);
}

void LstmStepWithAuxInput(
const float* input_ptr_batch, const float* input_to_input_weights_ptr,
const float* input_to_forget_weights_ptr,
const float* input_to_cell_weights_ptr,
const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch,
const float* aux_input_to_input_weights_ptr,
const float* aux_input_to_forget_weights_ptr,
const float* aux_input_to_cell_weights_ptr,
const float* aux_input_to_output_weights_ptr,
const float* recurrent_to_input_weights_ptr,
const float* recurrent_to_forget_weights_ptr,
const float* recurrent_to_cell_weights_ptr,
const float* recurrent_to_output_weights_ptr,
const float* cell_to_input_weights_ptr,
const float* cell_to_forget_weights_ptr,
const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
const float* output_gate_bias_ptr, const float* projection_weights_ptr,
const float* projection_bias_ptr, const TfLiteLSTMParams* params,
int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
float* output_ptr_batch) {
// Since we have already checked that weights are all there or none, we can
// check the existence of only one to get the condition.
const bool use_cifg = (input_to_input_weights_ptr == nullptr);
const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
// Initialize scratch buffers with bias.
if (!use_cifg) {
tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch,
input_gate_scratch);
}
tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
forget_gate_scratch);
tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
cell_scratch);
tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
output_gate_scratch);

// For each batch and cell: compute input_weight * input.
if (!use_cifg) {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
input_gate_scratch, /*result_stride=*/1);
}

tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
forget_gate_scratch, /*result_stride=*/1);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
cell_scratch, /*result_stride=*/1);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
output_gate_scratch, /*result_stride=*/1);

// If auxiliary input is available then compute aux_input_weight * aux_input
if (aux_input_ptr_batch != nullptr) {
if (!use_cifg) {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_input_weights_ptr, n_cell, n_aux_input,
aux_input_ptr_batch, n_batch, input_gate_scratch,
/*result_stride=*/1);
}

tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
aux_input_ptr_batch, n_batch, forget_gate_scratch, /*result_stride=*/1);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr_batch,
n_batch, cell_scratch, /*result_stride=*/1);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_output_weights_ptr, n_cell, n_aux_input,
aux_input_ptr_batch, n_batch, output_gate_scratch, /*result_stride=*/1);
}

// For each batch and cell: compute recurrent_weight * output_state.
if (!use_cifg) {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
n_batch, input_gate_scratch, /*result_stride=*/1);
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
n_batch, forget_gate_scratch,
/*result_stride=*/1);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
n_batch, cell_scratch, /*result_stride=*/1);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
n_batch, output_gate_scratch,
/*result_stride=*/1);

// For each batch and cell: update input gate.
if (!use_cifg) {
if (use_peephole) {
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
input_gate_scratch);
}
tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
input_gate_scratch);
}

// For each batch and cell: update forget gate.
if (use_peephole) {
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
forget_gate_scratch);
}
tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
forget_gate_scratch);

// For each batch and cell: update the cell.
tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
n_batch * n_cell, cell_state_ptr);
tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
params->activation, cell_scratch);
if (use_cifg) {
tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
forget_gate_scratch);
tensor_utils::VectorVectorCwiseProductAccumulate(
cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
} else {
tensor_utils::VectorVectorCwiseProductAccumulate(
cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
}
if (params->cell_clip > 0.0) {
tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
params->cell_clip, cell_state_ptr);
}

// For each batch and cell: update the output gate.
if (use_peephole) {
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
output_gate_scratch);
}
tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
output_gate_scratch);
tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
params->activation, cell_scratch);
tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
n_batch * n_cell, output_gate_scratch);

// For each batch: update the projection and output_state.
const bool use_projection_weight = (projection_weights_ptr != nullptr);
const bool use_projection_bias = (projection_bias_ptr != nullptr);
if (use_projection_weight) {
if (use_projection_bias) {
tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
n_batch, output_ptr_batch);
} else {
tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch,
output_ptr_batch, /*result_stride=*/1);
if (params->proj_clip > 0.0) {
tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
params->proj_clip, output_ptr_batch);
}
} else {
tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
output_ptr_batch);
}
tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
output_state_ptr);
}

void LstmStep(
const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
float input_to_input_weights_scale,
const int8_t* input_to_forget_weights_ptr,
float input_to_forget_weights_scale,
const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
const int8_t* input_to_output_weights_ptr,
float input_to_output_weights_scale,
const int8_t* recurrent_to_input_weights_ptr,
float recurrent_to_input_weights_scale,
const int8_t* recurrent_to_forget_weights_ptr,
float recurrent_to_forget_weights_scale,
const int8_t* recurrent_to_cell_weights_ptr,
float recurrent_to_cell_weights_scale,
const int8_t* recurrent_to_output_weights_ptr,
float recurrent_to_output_weights_scale,
const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
const int8_t* cell_to_forget_weights_ptr,
float cell_to_forget_weights_scale,
const int8_t* cell_to_output_weights_ptr,
float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
float projection_weights_scale, const float* projection_bias_ptr,
const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
int n_output, float* input_gate_scratch, float* forget_gate_scratch,
float* cell_scratch, float* output_gate_scratch, float* scaling_factors,
float* product_scaling_factors, float* recovered_cell_weights,
int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr,
int8_t* quantized_cell_state_ptr, float* output_state_ptr,
float* cell_state_ptr, float* output_ptr_batch) {
LstmStepWithAuxInput(
input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale,
input_to_forget_weights_ptr, input_to_forget_weights_scale,
input_to_cell_weights_ptr, input_to_cell_weights_scale,
input_to_output_weights_ptr, input_to_output_weights_scale,
/*aux_input_ptr_batch=*/nullptr,
/*aux_input_to_input_weights_ptr=*/nullptr,
/*aux_input_to_input_weights_scale=*/0.0f,
/*aux_input_to_forget_weights_ptr=*/nullptr,
/*aux_input_to_forget_weights_scale=*/0.0f,
/*aux_input_to_cell_weights_ptr=*/nullptr,
/*aux_input_to_cell_weights_scale=*/0.0f,
/*aux_input_to_output_weights_ptr=*/nullptr,
/*aux_input_to_output_weights_scale=*/0.0f,
recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale,
recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale,
recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale,
recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale,
cell_to_input_weights_ptr, cell_to_input_weights_scale,
cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
cell_to_output_weights_ptr, cell_to_output_weights_scale,
input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale,
projection_bias_ptr, params, n_batch, n_cell, n_input,
/*n_aux_input=*/0, n_output, input_gate_scratch, forget_gate_scratch,
cell_scratch, output_gate_scratch, scaling_factors,
product_scaling_factors, recovered_cell_weights,
quantized_input_ptr_batch,
/*quantized_aux_input_ptr_batch=*/nullptr, quantized_output_state_ptr,
quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
output_ptr_batch);
}

void LstmStepWithAuxInput(
const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
float input_to_input_weights_scale,
const int8_t* input_to_forget_weights_ptr,
float input_to_forget_weights_scale,
const int8_t* input_to_cell_weights_ptr,
float input_to_cell_weights_scale,
const int8_t* input_to_output_weights_ptr,
float input_to_output_weights_scale, const float* aux_input_ptr_batch,
const int8_t* aux_input_to_input_weights_ptr,
float aux_input_to_input_weights_scale,
const int8_t* aux_input_to_forget_weights_ptr,
float aux_input_to_forget_weights_scale,
const int8_t* aux_input_to_cell_weights_ptr,
float aux_input_to_cell_weights_scale,
const int8_t* aux_input_to_output_weights_ptr,
float aux_input_to_output_weights_scale,
const int8_t* recurrent_to_input_weights_ptr,
float recurrent_to_input_weights_scale,
const int8_t* recurrent_to_forget_weights_ptr,
float recurrent_to_forget_weights_scale,
const int8_t* recurrent_to_cell_weights_ptr,
float recurrent_to_cell_weights_scale,
const int8_t* recurrent_to_output_weights_ptr,
float recurrent_to_output_weights_scale,
const int8_t* cell_to_input_weights_ptr,
float cell_to_input_weights_scale,
const int8_t* cell_to_forget_weights_ptr,
float cell_to_forget_weights_scale,
const int8_t* cell_to_output_weights_ptr,
float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
float projection_weights_scale, const float* projection_bias_ptr,
const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
int n_aux_input, int n_output, float* input_gate_scratch,
float* forget_gate_scratch, float* cell_scratch,
float* output_gate_scratch, float* scaling_factors,
float* product_scaling_factors, float* recovered_cell_weights,
int8_t* quantized_input_ptr_batch,
int8_t* quantized_aux_input_ptr_batch,
int8_t* quantized_output_state_ptr, int8_t* quantized_cell_state_ptr,
float* output_state_ptr, float* cell_state_ptr,
float* output_ptr_batch) {
// Since we have already checked that weights are all there or none, we
// can check the existence of only one to get the condition.
const bool use_cifg = (input_to_input_weights_ptr == nullptr);
const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
// Initialize scratch buffers with bias.
if (!use_cifg) {
tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell,
n_batch, input_gate_scratch);
}
tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell,
n_batch, forget_gate_scratch);
tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
cell_scratch);
tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell,
n_batch, output_gate_scratch);

if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
// Save quantization and matmul computation for all zero input.
float unused_min, unused_max;
for (int b = 0; b < n_batch; ++b) {
const int offset = b * n_input;
tensor_utils::SymmetricQuantizeFloats(
input_ptr_batch + offset, n_input,
quantized_input_ptr_batch + offset, &unused_min, &unused_max,
&scaling_factors[b]);
}
// For each batch and cell: compute input_weight * input.
if (!use_cifg) {
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * input_to_input_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_input_weights_ptr, n_cell, n_input,
quantized_input_ptr_batch, product_scaling_factors, n_batch,
input_gate_scratch, /*result_stride=*/1);
}

for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * input_to_forget_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_forget_weights_ptr, n_cell, n_input,
quantized_input_ptr_batch, product_scaling_factors, n_batch,
forget_gate_scratch,
/*result_stride=*/1);

for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * input_to_cell_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_cell_weights_ptr, n_cell, n_input,
quantized_input_ptr_batch, product_scaling_factors, n_batch,
cell_scratch, /*result_stride=*/1);

for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * input_to_output_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_output_weights_ptr, n_cell, n_input,
quantized_input_ptr_batch, product_scaling_factors, n_batch,
output_gate_scratch,
/*result_stride=*/1);
}

if (aux_input_ptr_batch != nullptr &&
!tensor_utils::IsZeroVector(aux_input_ptr_batch, n_batch * n_input)) {
// Save quantization and matmul computation for all zero input.
float unused_min, unused_max;
for (int b = 0; b < n_batch; ++b) {
const int offset = b * n_input;
tensor_utils::SymmetricQuantizeFloats(
aux_input_ptr_batch + offset, n_input,
quantized_aux_input_ptr_batch + offset, &unused_min, &unused_max,
&scaling_factors[b]);
}
// For each batch and cell: compute input_weight * input.
if (!use_cifg) {
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * aux_input_to_input_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_input_weights_ptr, n_cell, n_input,
quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
input_gate_scratch, /*result_stride=*/1);
}

for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * aux_input_to_forget_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_forget_weights_ptr, n_cell, n_input,
quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
forget_gate_scratch, /*result_stride=*/1);

for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * aux_input_to_cell_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_cell_weights_ptr, n_cell, n_input,
quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
cell_scratch, /*result_stride=*/1);

for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * aux_input_to_output_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_output_weights_ptr, n_cell, n_input,
quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
output_gate_scratch, /*result_stride=*/1);
}

if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
// Save quantization and matmul computation for all zero input.
float unused_min, unused_max;
for (int b = 0; b < n_batch; ++b) {
const int offset = b * n_output;
tensor_utils::SymmetricQuantizeFloats(
output_state_ptr + offset, n_output,
quantized_output_state_ptr + offset, &unused_min, &unused_max,
&scaling_factors[b]);
}
// For each batch and cell: compute recurrent_weight * output_state.
if (!use_cifg) {
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * recurrent_to_input_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_input_weights_ptr, n_cell, n_output,
quantized_output_state_ptr, product_scaling_factors, n_batch,
input_gate_scratch, /*result_stride=*/1);
}

for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * recurrent_to_forget_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_forget_weights_ptr, n_cell, n_output,
quantized_output_state_ptr, product_scaling_factors, n_batch,
forget_gate_scratch, /*result_stride=*/1);

for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * recurrent_to_cell_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_cell_weights_ptr, n_cell, n_output,
quantized_output_state_ptr, product_scaling_factors, n_batch,
cell_scratch, /*result_stride=*/1);

for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * recurrent_to_output_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_output_weights_ptr, n_cell, n_output,
quantized_output_state_ptr, product_scaling_factors, n_batch,
output_gate_scratch, /*result_stride=*/1);
}

// Save quantization and matmul computation for all zero input.
bool is_cell_state_all_zeros =
tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);

// For each batch and cell: update input gate.
if (!use_cifg) {
if (use_peephole && !is_cell_state_all_zeros) {
tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell,
cell_to_input_weights_scale,
recovered_cell_weights);
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
input_gate_scratch);
}
tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
input_gate_scratch);
}

// For each batch and cell: update forget gate.
if (use_peephole && !is_cell_state_all_zeros) {
tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell,
cell_to_forget_weights_scale,
recovered_cell_weights);
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
forget_gate_scratch);
}
tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
forget_gate_scratch);

// For each batch and cell: update the cell.
tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch,
cell_state_ptr, n_batch * n_cell,
cell_state_ptr);
tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
params->activation, cell_scratch);
if (use_cifg) {
tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
forget_gate_scratch);
tensor_utils::VectorVectorCwiseProductAccumulate(
cell_scratch, forget_gate_scratch, n_batch * n_cell,
cell_state_ptr);
} else {
tensor_utils::VectorVectorCwiseProductAccumulate(
cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
}
if (params->cell_clip > 0.0) {
tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
params->cell_clip, cell_state_ptr);
}

is_cell_state_all_zeros =
tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
// For each batch and cell: update the output gate.
if (use_peephole && !is_cell_state_all_zeros) {
tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell,
cell_to_output_weights_scale,
recovered_cell_weights);
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
output_gate_scratch);
}
tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
output_gate_scratch);
tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
params->activation, cell_scratch);
tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
n_batch * n_cell,
output_gate_scratch);

// For each batch: update the projection and output_state.
const bool use_projection_weight = (projection_weights_ptr != nullptr);
const bool use_projection_bias = (projection_bias_ptr != nullptr);
if (use_projection_weight) {
if (use_projection_bias) {
tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
n_batch, output_ptr_batch);
} else {
tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
}
if (!tensor_utils::IsZeroVector(output_gate_scratch,
n_batch * n_cell)) {
// Save quantization and matmul computation for all zero input.
float unused_min, unused_max;
for (int b = 0; b < n_batch; ++b) {
const int offset = b * n_cell;
tensor_utils::SymmetricQuantizeFloats(
output_gate_scratch + offset, n_cell,
quantized_cell_state_ptr + offset, &unused_min, &unused_max,
&scaling_factors[b]);
}
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * projection_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
projection_weights_ptr, n_output, n_cell,
quantized_cell_state_ptr, product_scaling_factors, n_batch,
output_ptr_batch,
/*result_stride=*/1);
}
if (params->proj_clip > 0.0) {
tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
params->proj_clip, output_ptr_batch);
}
} else {
tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
output_ptr_batch);
}
tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
output_state_ptr);
}

} // namespace kernel_utils
} // namespace tflite
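The hybrid LstmStepWithAuxInput above leans on SymmetricQuantizeFloats and on per-batch scaling factors: each row of the float input (or output state) is quantized to int8 with its own scale, the matrix multiply accumulates in integers, and product_scaling_factors (input scale times weight scale) converts the accumulated result back to float. What follows is a self-contained illustration of that idea, not the TF Lite implementation; the function names here are hypothetical.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Symmetrically quantize one batch row: q[i] = round(x[i] / scale),
// with scale = max|x| / 127 so values land in [-127, 127].
// Returns the scale, i.e. the per-row "scaling factor".
float SymmetricQuantizeRow(const float* x, int n, int8_t* q) {
  float max_abs = 0.0f;
  for (int i = 0; i < n; ++i) max_abs = std::max(max_abs, std::fabs(x[i]));
  const float scale = (max_abs > 0.0f) ? max_abs / 127.0f : 1.0f;
  for (int i = 0; i < n; ++i) {
    const float v = std::max(-127.0f, std::min(127.0f, x[i] / scale));
    q[i] = static_cast<int8_t>(std::lround(v));
  }
  return scale;
}

// Integer dot product of a quantized weight row with a quantized input row,
// rescaled by weight_scale * input_scale (the "product scaling factor") to
// recover an approximation of the float dot product.
float QuantizedDot(const int8_t* w_row, const int8_t* q_row, int n,
                   float weight_scale, float input_scale) {
  int32_t acc = 0;
  for (int i = 0; i < n; ++i) {
    acc += static_cast<int32_t>(w_row[i]) * static_cast<int32_t>(q_row[i]);
  }
  return static_cast<float>(acc) * (weight_scale * input_scale);
}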
@@ -76,190 +76,6 @@ void RnnBatchStep(
int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
float* hidden_state_ptr_batch, float* output_ptr_batch);

// Performs an LSTM batch inference step for input specified by input_ptr_batch.
// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
// biases (*_bias_ptr), and buffers (*_scratch), along with additional
// parameters:
// - params: various LSTM params including activation, clipping, etc.,
// - n_batch: size of batch,
// - n_cell: number of cells (or units),
// - n_input: the input size,
// - n_output: the output size.
//
// The pointers to the cell and output state and the output are updated.
//
// The pointers with the suffix "_batch" point to data aligned in batch_major
// order, and each step processes batch_size many inputs from input_ptr_batch,
// and updates batch_size many cell and output states.
void LstmStep(
const float* input_ptr_batch, const float* input_to_input_weights_ptr,
const float* input_to_forget_weights_ptr,
const float* input_to_cell_weights_ptr,
const float* input_to_output_weights_ptr,
const float* recurrent_to_input_weights_ptr,
const float* recurrent_to_forget_weights_ptr,
const float* recurrent_to_cell_weights_ptr,
const float* recurrent_to_output_weights_ptr,
const float* cell_to_input_weights_ptr,
const float* cell_to_forget_weights_ptr,
const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
const float* output_gate_bias_ptr, const float* projection_weights_ptr,
const float* projection_bias_ptr, const TfLiteLSTMParams* params,
int n_batch, int n_cell, int n_input, int n_output, float* output_state_ptr,
float* cell_state_ptr, float* input_gate_scratch,
float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
float* output_ptr_batch);

// Same as above but includes an auxiliary input with the corresponding weights.
void LstmStepWithAuxInput(
const float* input_ptr_batch, const float* input_to_input_weights_ptr,
const float* input_to_forget_weights_ptr,
const float* input_to_cell_weights_ptr,
const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch,
const float* aux_input_to_input_weights_ptr,
const float* aux_input_to_forget_weights_ptr,
const float* aux_input_to_cell_weights_ptr,
const float* aux_input_to_output_weights_ptr,
const float* recurrent_to_input_weights_ptr,
const float* recurrent_to_forget_weights_ptr,
const float* recurrent_to_cell_weights_ptr,
const float* recurrent_to_output_weights_ptr,
const float* cell_to_input_weights_ptr,
const float* cell_to_forget_weights_ptr,
const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
const float* output_gate_bias_ptr, const float* projection_weights_ptr,
const float* projection_bias_ptr, const TfLiteLSTMParams* params,
int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
float* output_ptr_batch);

// Same as above but with quantized weight matrices. In detail:
// Input of size 'n_batch * n_input':
// input_ptr_batch
//
// LSTM weights:
// Quantized input weights of size 'n_cell * n_input':
// input_to_input_weights - optional (can be nullptr)
// input_to_forget_weights
// input_to_cell_weights
// input_to_output_weights
// Quantized recurrent weights of size 'n_cell * n_output':
// recurrent_to_input_weights - optional
// recurrent_to_forget_weights
// recurrent_to_cell_weights
// recurrent_to_output_weights
// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
// cell_to_input_weights - optional
// cell_to_forget_weights - optional
// cell_to_output_weights - optional
|
||||
// Quantized projection weights of size 'n_output * n_cell'
|
||||
// projection_weights_ptr - optional
|
||||
// Weight scales (scalars) for each of the weights above.
|
||||
// input_to_input_weights_scale - optional
|
||||
// input_to_forget_weights_scale
|
||||
// input_to_cell_weights_scale
|
||||
// input_to_output_weights_scale
|
||||
// recurrent_to_input_weights_scale - optional
|
||||
// recurrent_to_forget_weights_scale
|
||||
// recurrent_to_cell_weights_scale
|
||||
// recurrent_to_output_weights_scale
|
||||
// cell_to_input_weights_scale,
|
||||
// cell_to_forget_weights_scale,
|
||||
// cell_to_output_weights_scale,
|
||||
// projection_weights_scale - optional
|
||||
// Gate biases of size 'n_cell':
|
||||
// input_gate_bias_ptr - optional
|
||||
// forget_gate_bias_ptr
|
||||
// cell_gate_bias_ptr
|
||||
// output_gate_bias_ptr
|
||||
//
|
||||
// Temporary pre-allocated storage for quantized values:
|
||||
// quantized_input_ptr_batch (same size as input_ptr_batch)
|
||||
// quantized_output_state_ptr (same size as output_state_ptr)
|
||||
// quantized_cell_state_ptr (same size as cell_state_ptr)
|
||||
// Temporary pre-allocated storage for recovered values:
|
||||
// recovered_cell_weights (same size as cell_to_*_weights)
|
||||
//
|
||||
// Outputs:
|
||||
// output_state_ptr - size 'n_batch * n_output'
|
||||
// cell_state_ptr - size 'n_batch * n_cell'
|
||||
// output_ptr_batch - size 'n_batch * n_output'
|
||||
void LstmStep(
|
||||
const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
|
||||
float input_to_input_weights_scale,
|
||||
const int8_t* input_to_forget_weights_ptr,
|
||||
float input_to_forget_weights_scale,
|
||||
const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
|
||||
const int8_t* input_to_output_weights_ptr,
|
||||
float input_to_output_weights_scale,
|
||||
const int8_t* recurrent_to_input_weights_ptr,
|
||||
float recurrent_to_input_weights_scale,
|
||||
const int8_t* recurrent_to_forget_weights_ptr,
|
||||
float recurrent_to_forget_weights_scale,
|
||||
const int8_t* recurrent_to_cell_weights_ptr,
|
||||
float recurrent_to_cell_weights_scale,
|
||||
const int8_t* recurrent_to_output_weights_ptr,
|
||||
float recurrent_to_output_weights_scale,
|
||||
const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
|
||||
const int8_t* cell_to_forget_weights_ptr,
|
||||
float cell_to_forget_weights_scale,
|
||||
const int8_t* cell_to_output_weights_ptr,
|
||||
float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
|
||||
const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
|
||||
const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
|
||||
float projection_weights_scale, const float* projection_bias_ptr,
|
||||
const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
|
||||
int n_output, float* input_gate_scratch, float* forget_gate_scratch,
|
||||
float* cell_scratch, float* output_gate_scratch, float* scaling_factors,
|
||||
float* product_scaling_factors, float* recovered_cell_weights,
|
||||
int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr,
|
||||
int8_t* quantized_cell_state_ptr, float* output_state_ptr,
|
||||
float* cell_state_ptr, float* output_ptr_batch);
|
||||
|
||||
void LstmStepWithAuxInput(
|
||||
const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
|
||||
float input_to_input_weights_scale,
|
||||
const int8_t* input_to_forget_weights_ptr,
|
||||
float input_to_forget_weights_scale,
|
||||
const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
|
||||
const int8_t* input_to_output_weights_ptr,
|
||||
float input_to_output_weights_scale, const float* aux_input_ptr_batch,
|
||||
const int8_t* aux_input_to_input_weights_ptr,
|
||||
float aux_input_to_input_weights_scale,
|
||||
const int8_t* aux_input_to_forget_weights_ptr,
|
||||
float aux_input_to_forget_weights_scale,
|
||||
const int8_t* aux_input_to_cell_weights_ptr,
|
||||
float aux_input_to_cell_weights_scale,
|
||||
const int8_t* aux_input_to_output_weights_ptr,
|
||||
float aux_input_to_output_weights_scale,
|
||||
const int8_t* recurrent_to_input_weights_ptr,
|
||||
float recurrent_to_input_weights_scale,
|
||||
const int8_t* recurrent_to_forget_weights_ptr,
|
||||
float recurrent_to_forget_weights_scale,
|
||||
const int8_t* recurrent_to_cell_weights_ptr,
|
||||
float recurrent_to_cell_weights_scale,
|
||||
const int8_t* recurrent_to_output_weights_ptr,
|
||||
float recurrent_to_output_weights_scale,
|
||||
const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
|
||||
const int8_t* cell_to_forget_weights_ptr,
|
||||
float cell_to_forget_weights_scale,
|
||||
const int8_t* cell_to_output_weights_ptr,
|
||||
float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
|
||||
const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
|
||||
const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
|
||||
float projection_weights_scale, const float* projection_bias_ptr,
|
||||
const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
|
||||
int n_aux_input, int n_output, float* input_gate_scratch,
|
||||
float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
|
||||
float* scaling_factors, float* product_scaling_factors,
|
||||
float* recovered_cell_weights, int8_t* quantized_input_ptr_batch,
|
||||
int8_t* quantized_aux_input_ptr_batch, int8_t* quantized_output_state_ptr,
|
||||
int8_t* quantized_cell_state_ptr, float* output_state_ptr,
|
||||
float* cell_state_ptr, float* output_ptr_batch);
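// Note: the *WithAuxInput overloads generalize the plain LstmStep entry points
// above, which can simply forward to them with null auxiliary-input pointers
// and n_aux_input = 0; the new lstm_eval path likewise invokes the aux-input
// step with null aux pointers when the ordinary LSTM op has no auxiliary
// input.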
|
||||
|
||||
} // namespace kernel_utils
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_

@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/lstm_eval.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"

namespace tflite {
@ -424,263 +425,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// The LSTM Op engine.
|
||||
TfLiteStatus EvalFloat(
|
||||
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
|
||||
const TfLiteTensor* input_to_forget_weights,
|
||||
const TfLiteTensor* input_to_cell_weights,
|
||||
const TfLiteTensor* input_to_output_weights,
|
||||
const TfLiteTensor* recurrent_to_input_weights,
|
||||
const TfLiteTensor* recurrent_to_forget_weights,
|
||||
const TfLiteTensor* recurrent_to_cell_weights,
|
||||
const TfLiteTensor* recurrent_to_output_weights,
|
||||
const TfLiteTensor* cell_to_input_weights,
|
||||
const TfLiteTensor* cell_to_forget_weights,
|
||||
const TfLiteTensor* cell_to_output_weights,
|
||||
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
|
||||
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
|
||||
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
|
||||
const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
|
||||
TfLiteTensor* activation_state, TfLiteTensor* cell_state,
|
||||
TfLiteTensor* output) {
|
||||
const int n_batch = input->dims->data[0];
|
||||
const int n_input = input->dims->data[1];
|
||||
// n_cell and n_output will be the same size when there is no projection.
|
||||
const int n_cell = input_to_output_weights->dims->data[0];
|
||||
const int n_output = recurrent_to_output_weights->dims->data[1];
|
||||
|
||||
// Since we have already checked that weights are all there or none, we can
|
||||
// check the existence of only one to get the condition.
|
||||
const bool use_cifg = (input_to_input_weights == nullptr);
|
||||
const bool use_peephole = (cell_to_output_weights != nullptr);
|
||||
|
||||
float* input_gate_scratch = nullptr;
|
||||
float* cell_scratch = nullptr;
|
||||
float* forget_gate_scratch = nullptr;
|
||||
float* output_gate_scratch = nullptr;
|
||||
if (use_cifg) {
|
||||
cell_scratch = scratch_buffer->data.f;
|
||||
forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
|
||||
output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
|
||||
} else {
|
||||
input_gate_scratch = scratch_buffer->data.f;
|
||||
cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
|
||||
forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
|
||||
output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
|
||||
}
|
||||
|
||||
// Check optional tensors, the respective pointers can be null.
|
||||
const float* input_to_input_weights_ptr =
|
||||
(use_cifg) ? nullptr : input_to_input_weights->data.f;
|
||||
const float* recurrent_to_input_weights_ptr =
|
||||
(use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
|
||||
const float* input_gate_bias_ptr =
|
||||
(use_cifg) ? nullptr : input_gate_bias->data.f;
|
||||
const float* cell_to_input_weights_ptr =
|
||||
(use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
|
||||
const float* cell_to_forget_weights_ptr =
|
||||
(use_peephole) ? cell_to_forget_weights->data.f : nullptr;
|
||||
const float* cell_to_output_weights_ptr =
|
||||
(use_peephole) ? cell_to_output_weights->data.f : nullptr;
|
||||
const float* projection_weights_ptr =
|
||||
(projection_weights == nullptr) ? nullptr : projection_weights->data.f;
|
||||
const float* projection_bias_ptr =
|
||||
(projection_bias == nullptr) ? nullptr : projection_bias->data.f;
|
||||
|
||||
// Required tensors, pointers are non-null.
|
||||
const float* input_ptr_batch = input->data.f;
|
||||
const float* input_to_forget_weights_ptr = input_to_forget_weights->data.f;
|
||||
const float* input_to_cell_weights_ptr = input_to_cell_weights->data.f;
|
||||
const float* input_to_output_weights_ptr = input_to_output_weights->data.f;
|
||||
const float* recurrent_to_forget_weights_ptr =
|
||||
recurrent_to_forget_weights->data.f;
|
||||
const float* recurrent_to_cell_weights_ptr =
|
||||
recurrent_to_cell_weights->data.f;
|
||||
const float* recurrent_to_output_weights_ptr =
|
||||
recurrent_to_output_weights->data.f;
|
||||
const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
|
||||
const float* cell_bias_ptr = cell_bias->data.f;
|
||||
const float* output_gate_bias_ptr = output_gate_bias->data.f;
|
||||
|
||||
float* activation_state_ptr = activation_state->data.f;
|
||||
float* cell_state_ptr = cell_state->data.f;
|
||||
float* output_ptr_batch = output->data.f;
|
||||
|
||||
kernel_utils::LstmStep(
|
||||
input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr,
|
||||
input_to_cell_weights_ptr, input_to_output_weights_ptr,
|
||||
recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
|
||||
recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
|
||||
cell_to_input_weights_ptr, cell_to_forget_weights_ptr,
|
||||
cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr,
|
||||
cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
|
||||
projection_bias_ptr, params, n_batch, n_cell, n_input, n_output,
|
||||
activation_state_ptr, cell_state_ptr, input_gate_scratch,
|
||||
forget_gate_scratch, cell_scratch, output_gate_scratch, output_ptr_batch);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus EvalHybrid(
|
||||
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
|
||||
const TfLiteTensor* input_to_forget_weights,
|
||||
const TfLiteTensor* input_to_cell_weights,
|
||||
const TfLiteTensor* input_to_output_weights,
|
||||
const TfLiteTensor* recurrent_to_input_weights,
|
||||
const TfLiteTensor* recurrent_to_forget_weights,
|
||||
const TfLiteTensor* recurrent_to_cell_weights,
|
||||
const TfLiteTensor* recurrent_to_output_weights,
|
||||
const TfLiteTensor* cell_to_input_weights,
|
||||
const TfLiteTensor* cell_to_forget_weights,
|
||||
const TfLiteTensor* cell_to_output_weights,
|
||||
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
|
||||
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
|
||||
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
|
||||
const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
|
||||
TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
|
||||
TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
|
||||
TfLiteTensor* activation_state_quantized,
|
||||
TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state,
|
||||
TfLiteTensor* cell_state, TfLiteTensor* output) {
|
||||
const int n_batch = input->dims->data[0];
|
||||
const int n_input = input->dims->data[1];
|
||||
// n_cell and n_output will be the same size when there is no projection.
|
||||
const int n_cell = input_to_output_weights->dims->data[0];
|
||||
const int n_output = recurrent_to_output_weights->dims->data[1];
|
||||
|
||||
// Since we have already checked that weights are all there or none, we can
|
||||
// check the existence of only one to get the condition.
|
||||
const bool use_cifg = (input_to_input_weights == nullptr);
|
||||
const bool use_peephole = (cell_to_output_weights != nullptr);
|
||||
|
||||
float* input_gate_scratch = nullptr;
|
||||
float* cell_scratch = nullptr;
|
||||
float* forget_gate_scratch = nullptr;
|
||||
float* output_gate_scratch = nullptr;
|
||||
if (use_cifg) {
|
||||
cell_scratch = scratch_buffer->data.f;
|
||||
forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
|
||||
output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
|
||||
} else {
|
||||
input_gate_scratch = scratch_buffer->data.f;
|
||||
cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
|
||||
forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
|
||||
output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
|
||||
}
|
||||
|
||||
// Check optional tensors, the respective pointers can be null.
|
||||
int8_t* input_to_input_weights_ptr = nullptr;
|
||||
float input_to_input_weights_scale = 1.0f;
|
||||
int8_t* recurrent_to_input_weights_ptr = nullptr;
|
||||
float recurrent_to_input_weights_scale = 1.0f;
|
||||
float* input_gate_bias_ptr = nullptr;
|
||||
if (!use_cifg) {
|
||||
input_to_input_weights_ptr =
|
||||
reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
|
||||
recurrent_to_input_weights_ptr =
|
||||
reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
|
||||
input_gate_bias_ptr = input_gate_bias->data.f;
|
||||
input_to_input_weights_scale = input_to_input_weights->params.scale;
|
||||
recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
|
||||
}
|
||||
|
||||
int8_t* cell_to_input_weights_ptr = nullptr;
|
||||
int8_t* cell_to_forget_weights_ptr = nullptr;
|
||||
int8_t* cell_to_output_weights_ptr = nullptr;
|
||||
float cell_to_input_weights_scale = 1.0f;
|
||||
float cell_to_forget_weights_scale = 1.0f;
|
||||
float cell_to_output_weights_scale = 1.0f;
|
||||
if (use_peephole) {
|
||||
if (!use_cifg) {
|
||||
cell_to_input_weights_ptr =
|
||||
reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
|
||||
cell_to_input_weights_scale = cell_to_input_weights->params.scale;
|
||||
}
|
||||
cell_to_forget_weights_ptr =
|
||||
reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
|
||||
cell_to_output_weights_ptr =
|
||||
reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
|
||||
cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
|
||||
cell_to_output_weights_scale = cell_to_output_weights->params.scale;
|
||||
}
|
||||
|
||||
const int8_t* projection_weights_ptr =
|
||||
(projection_weights == nullptr)
|
||||
? nullptr
|
||||
: reinterpret_cast<int8_t*>(projection_weights->data.uint8);
|
||||
const float projection_weights_scale =
|
||||
(projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
|
||||
const float* projection_bias_ptr =
|
||||
(projection_bias == nullptr) ? nullptr : projection_bias->data.f;
|
||||
|
||||
// Required tensors, pointers are non-null.
|
||||
const float* input_ptr_batch = input->data.f;
|
||||
const int8_t* input_to_forget_weights_ptr =
|
||||
reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
|
||||
const float input_to_forget_weights_scale =
|
||||
input_to_forget_weights->params.scale;
|
||||
const int8_t* input_to_cell_weights_ptr =
|
||||
reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
|
||||
const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
|
||||
const int8_t* input_to_output_weights_ptr =
|
||||
reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
|
||||
const float input_to_output_weights_scale =
|
||||
input_to_output_weights->params.scale;
|
||||
const int8_t* recurrent_to_forget_weights_ptr =
|
||||
reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
|
||||
const float recurrent_to_forget_weights_scale =
|
||||
recurrent_to_forget_weights->params.scale;
|
||||
const int8_t* recurrent_to_cell_weights_ptr =
|
||||
reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
|
||||
const float recurrent_to_cell_weights_scale =
|
||||
recurrent_to_cell_weights->params.scale;
|
||||
const int8_t* recurrent_to_output_weights_ptr =
|
||||
reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
|
||||
const float recurrent_to_output_weights_scale =
|
||||
recurrent_to_output_weights->params.scale;
|
||||
const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
|
||||
const float* cell_bias_ptr = cell_bias->data.f;
|
||||
const float* output_gate_bias_ptr = output_gate_bias->data.f;
|
||||
|
||||
float* activation_state_ptr = activation_state->data.f;
|
||||
float* cell_state_ptr = cell_state->data.f;
|
||||
float* output_ptr_batch = output->data.f;
|
||||
|
||||
// Temporary storage for quantized values and scaling factors.
|
||||
int8_t* quantized_input_ptr =
|
||||
reinterpret_cast<int8_t*>(input_quantized->data.uint8);
|
||||
int8_t* quantized_activation_state_ptr =
|
||||
reinterpret_cast<int8_t*>(activation_state_quantized->data.uint8);
|
||||
int8_t* quantized_cell_state_ptr =
|
||||
reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
|
||||
float* scaling_factors_ptr = scaling_factors->data.f;
|
||||
float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
|
||||
float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;
|
||||
|
||||
kernel_utils::LstmStep(
|
||||
input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale,
|
||||
input_to_forget_weights_ptr, input_to_forget_weights_scale,
|
||||
input_to_cell_weights_ptr, input_to_cell_weights_scale,
|
||||
input_to_output_weights_ptr, input_to_output_weights_scale,
|
||||
recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale,
|
||||
recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale,
|
||||
recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale,
|
||||
recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale,
|
||||
cell_to_input_weights_ptr, cell_to_input_weights_scale,
|
||||
cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
|
||||
cell_to_output_weights_ptr, cell_to_output_weights_scale,
|
||||
input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
|
||||
output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale,
|
||||
projection_bias_ptr, params, n_batch, n_cell, n_input, n_output,
|
||||
input_gate_scratch, forget_gate_scratch, cell_scratch,
|
||||
output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
|
||||
recovered_cell_weights_ptr, quantized_input_ptr,
|
||||
quantized_activation_state_ptr, quantized_cell_state_ptr,
|
||||
activation_state_ptr, cell_state_ptr, output_ptr_batch);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
|
||||
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||
@ -738,15 +482,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
// TODO(mirkov): add a check that weights are all uint8s or all floats.
|
||||
switch (input_to_output_weights->type) {
|
||||
case kTfLiteFloat32: {
|
||||
return EvalFloat(input, input_to_input_weights, input_to_forget_weights,
|
||||
input_to_cell_weights, input_to_output_weights,
|
||||
recurrent_to_input_weights, recurrent_to_forget_weights,
|
||||
recurrent_to_cell_weights, recurrent_to_output_weights,
|
||||
cell_to_input_weights, cell_to_forget_weights,
|
||||
cell_to_output_weights, input_gate_bias,
|
||||
forget_gate_bias, cell_bias, output_gate_bias,
|
||||
projection_weights, projection_bias, params,
|
||||
scratch_buffer, activation_state, cell_state, output);
|
||||
return lstm_eval::EvalFloat(
|
||||
input, input_to_input_weights, input_to_forget_weights,
|
||||
input_to_cell_weights, input_to_output_weights,
|
||||
recurrent_to_input_weights, recurrent_to_forget_weights,
|
||||
recurrent_to_cell_weights, recurrent_to_output_weights,
|
||||
cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
|
||||
/*aux_input=*/nullptr,
|
||||
/*aux_input_to_input_weights=*/nullptr,
|
||||
/*aux_input_to_forget_weights=*/nullptr,
|
||||
/*aux_input_to_cell_weights=*/nullptr,
|
||||
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
|
||||
forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
|
||||
projection_bias, params, /*forward_sequence=*/true,
|
||||
/*output_offset=*/0, scratch_buffer, activation_state, cell_state,
|
||||
output);
|
||||
}
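// In the float case above and the hybrid case below, the auxiliary-input
// tensors are passed as nullptr and forward_sequence / output_offset take
// their trivial values (true and 0): this op runs the forward direction only,
// and the extra parameters appear to exist so the shared lstm_eval routines
// can also back the sequence and bidirectional LSTM kernels that this change
// consolidates.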
|
||||
case kTfLiteUInt8: {
|
||||
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
|
||||
@ -759,17 +509,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
GetTemporary(context, node, /*index=*/5);
|
||||
TfLiteTensor* recovered_cell_weights =
|
||||
GetTemporary(context, node, /*index=*/6);
|
||||
return EvalHybrid(
|
||||
return lstm_eval::EvalHybrid(
|
||||
input, input_to_input_weights, input_to_forget_weights,
|
||||
input_to_cell_weights, input_to_output_weights,
|
||||
recurrent_to_input_weights, recurrent_to_forget_weights,
|
||||
recurrent_to_cell_weights, recurrent_to_output_weights,
|
||||
cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
|
||||
input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias,
|
||||
projection_weights, projection_bias, params, scratch_buffer,
|
||||
scaling_factors, prod_scaling_factors, recovered_cell_weights,
|
||||
input_quantized, activation_state_quantized, cell_state_quantized,
|
||||
activation_state, cell_state, output);
|
||||
/*aux_input=*/nullptr,
|
||||
/*aux_input_to_input_weights=*/nullptr,
|
||||
/*aux_input_to_forget_weights=*/nullptr,
|
||||
/*aux_input_to_cell_weights=*/nullptr,
|
||||
/*aux_input_to_output_weights=*/nullptr, input_gate_bias,
|
||||
forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
|
||||
projection_bias, params, /*forward_sequence=*/true,
|
||||
/*output_offset=*/0, scratch_buffer, scaling_factors,
|
||||
prod_scaling_factors, recovered_cell_weights, input_quantized,
|
||||
/*aux_input_quantized=*/nullptr, activation_state_quantized,
|
||||
cell_state_quantized, activation_state, cell_state, output);
|
||||
}
|
||||
default:
|
||||
context->ReportError(context, "Type %d is not currently supported.",
|
||||
|
909
tensorflow/contrib/lite/kernels/lstm_eval.cc
Normal file
@ -0,0 +1,909 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/contrib/lite/kernels/lstm_eval.h"
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
|
||||
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace builtin {
|
||||
namespace lstm_eval {
|
||||
|
||||
namespace {
|
||||
|
||||
// Performs an LSTM batch inference step for input specified by input_ptr_batch.
|
||||
// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
|
||||
// biases (*_bias_ptr), and buffers (*_scratch), along with additional
|
||||
// parameters:
|
||||
// - params: various LSTM params including activation, clipping, etc.,
|
||||
// - n_batch: size of batch,
|
||||
// - n_cell: number of cells (or units),
|
||||
// - n_input: the input size,
|
||||
// - n_output: the output size.
|
||||
//
|
||||
// The pointers to the cell and output state and the output are updated.
|
||||
//
|
||||
// The pointers with the suffix "_batch" point to data aligned in batch_major
|
||||
// order, and each step processes batch_size many inputs from input_ptr_batch,
|
||||
// and updates batch_size many cell and output states.
|
||||
inline void LstmStepWithAuxInput(
|
||||
const float* input_ptr_batch, const float* input_to_input_weights_ptr,
|
||||
const float* input_to_forget_weights_ptr,
|
||||
const float* input_to_cell_weights_ptr,
|
||||
const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch,
|
||||
const float* aux_input_to_input_weights_ptr,
|
||||
const float* aux_input_to_forget_weights_ptr,
|
||||
const float* aux_input_to_cell_weights_ptr,
|
||||
const float* aux_input_to_output_weights_ptr,
|
||||
const float* recurrent_to_input_weights_ptr,
|
||||
const float* recurrent_to_forget_weights_ptr,
|
||||
const float* recurrent_to_cell_weights_ptr,
|
||||
const float* recurrent_to_output_weights_ptr,
|
||||
const float* cell_to_input_weights_ptr,
|
||||
const float* cell_to_forget_weights_ptr,
|
||||
const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
|
||||
const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
|
||||
const float* output_gate_bias_ptr, const float* projection_weights_ptr,
|
||||
const float* projection_bias_ptr, const TfLiteLSTMParams* params,
|
||||
int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
|
||||
float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
|
||||
float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
|
||||
float* output_ptr_batch) {
|
||||
// Since we have already checked that weights are all there or none, we can
|
||||
// check the existence of only one to get the condition.
|
||||
const bool use_cifg = (input_to_input_weights_ptr == nullptr);
|
||||
const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
|
||||
// Initialize scratch buffers with bias.
|
||||
if (!use_cifg) {
|
||||
tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch,
|
||||
input_gate_scratch);
|
||||
}
|
||||
tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
|
||||
forget_gate_scratch);
|
||||
tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
|
||||
cell_scratch);
|
||||
tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
|
||||
output_gate_scratch);
|
||||
|
||||
// For each batch and cell: compute input_weight * input.
|
||||
if (!use_cifg) {
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
|
||||
input_gate_scratch, /*result_stride=*/1);
|
||||
}
|
||||
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
|
||||
forget_gate_scratch, /*result_stride=*/1);
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
|
||||
cell_scratch, /*result_stride=*/1);
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
|
||||
output_gate_scratch, /*result_stride=*/1);
|
||||
|
||||
// If auxiliary input is available then compute aux_input_weight * aux_input
|
||||
if (aux_input_ptr_batch != nullptr) {
|
||||
if (!use_cifg) {
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
aux_input_to_input_weights_ptr, n_cell, n_aux_input,
|
||||
aux_input_ptr_batch, n_batch, input_gate_scratch,
|
||||
/*result_stride=*/1);
|
||||
}
|
||||
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
|
||||
aux_input_ptr_batch, n_batch, forget_gate_scratch, /*result_stride=*/1);
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr_batch,
|
||||
n_batch, cell_scratch, /*result_stride=*/1);
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
aux_input_to_output_weights_ptr, n_cell, n_aux_input,
|
||||
aux_input_ptr_batch, n_batch, output_gate_scratch, /*result_stride=*/1);
|
||||
}
|
||||
|
||||
// For each batch and cell: compute recurrent_weight * output_state.
|
||||
if (!use_cifg) {
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
|
||||
n_batch, input_gate_scratch, /*result_stride=*/1);
|
||||
}
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
|
||||
n_batch, forget_gate_scratch,
|
||||
/*result_stride=*/1);
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
|
||||
n_batch, cell_scratch, /*result_stride=*/1);
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
|
||||
n_batch, output_gate_scratch,
|
||||
/*result_stride=*/1);
|
||||
|
||||
// For each batch and cell: update input gate.
|
||||
if (!use_cifg) {
|
||||
if (use_peephole) {
|
||||
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
|
||||
cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
|
||||
input_gate_scratch);
|
||||
}
|
||||
tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
|
||||
input_gate_scratch);
|
||||
}
|
||||
|
||||
// For each batch and cell: update forget gate.
|
||||
if (use_peephole) {
|
||||
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
|
||||
cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
|
||||
forget_gate_scratch);
|
||||
}
|
||||
tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
|
||||
forget_gate_scratch);
|
||||
|
||||
// For each batch and cell: update the cell.
|
||||
tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
|
||||
n_batch * n_cell, cell_state_ptr);
|
||||
tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
|
||||
params->activation, cell_scratch);
|
||||
if (use_cifg) {
|
||||
tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
|
||||
forget_gate_scratch);
|
||||
tensor_utils::VectorVectorCwiseProductAccumulate(
|
||||
cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
|
||||
} else {
|
||||
tensor_utils::VectorVectorCwiseProductAccumulate(
|
||||
cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
|
||||
}
|
||||
if (params->cell_clip > 0.0) {
|
||||
tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
|
||||
params->cell_clip, cell_state_ptr);
|
||||
}
|
||||
|
||||
// For each batch and cell: update the output gate.
|
||||
if (use_peephole) {
|
||||
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
|
||||
cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
|
||||
output_gate_scratch);
|
||||
}
|
||||
tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
|
||||
output_gate_scratch);
|
||||
tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
|
||||
params->activation, cell_scratch);
|
||||
tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
|
||||
n_batch * n_cell, output_gate_scratch);
|
||||
|
||||
// For each batch: update the projection and output_state.
|
||||
const bool use_projection_weight = (projection_weights_ptr != nullptr);
|
||||
const bool use_projection_bias = (projection_bias_ptr != nullptr);
|
||||
if (use_projection_weight) {
|
||||
if (use_projection_bias) {
|
||||
tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
|
||||
n_batch, output_ptr_batch);
|
||||
} else {
|
||||
tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
|
||||
}
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch,
|
||||
output_ptr_batch, /*result_stride=*/1);
|
||||
if (params->proj_clip > 0.0) {
|
||||
tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
|
||||
params->proj_clip, output_ptr_batch);
|
||||
}
|
||||
} else {
|
||||
tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
|
||||
output_ptr_batch);
|
||||
}
|
||||
tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
|
||||
output_state_ptr);
|
||||
}
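// For reference, the float step above computes, per batch element (aux-input
// terms only when an auxiliary input is provided, peephole/projection/CIFG
// behavior only when the corresponding weights are present):
//   f = sigmoid(W_f x + A_f aux + U_f h_prev + p_f . c_prev + b_f)
//   i = sigmoid(W_i x + A_i aux + U_i h_prev + p_i . c_prev + b_i)   (CIFG: i = 1 - f)
//   c = clip(f . c_prev + i . g(W_c x + A_c aux + U_c h_prev + b_c), cell_clip)
//   o = sigmoid(W_o x + A_o aux + U_o h_prev + p_o . c + b_o)
//   h = o . g(c); with projection, output = output_state = clip(P h + b_proj,
//   proj_clip); otherwise output = output_state = h.
// Here g is params->activation and "." denotes the elementwise product.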
|
||||
|
||||
// Same as above but with quantized weight matrices. In detail:
|
||||
// Input of size 'n_batch * n_input':
|
||||
// input_ptr_batch
|
||||
//
|
||||
// LSTM weights:
|
||||
// Quantized input weights of size 'n_cell * n_input':
|
||||
// input_to_input_weights - optional (can be nullptr)
|
||||
// input_to_forget_weights
|
||||
// input_to_cell_weights
|
||||
// input_to_output_weights
|
||||
// Quantized recurrent weights of size 'n_cell * n_output':
|
||||
// recurrent_to_input_weights - optional
|
||||
// recurrent_to_forget_weights
|
||||
// recurrent_to_cell_weights
|
||||
// recurrent_to_output_weights
|
||||
// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
|
||||
// cell_to_input_weights - optional
|
||||
// cell_to_forget_weights - optional
|
||||
// cell_to_output_weights - optional
|
||||
// Quantized projection weights of size 'n_output * n_cell'
|
||||
// projection_weights_ptr - optional
|
||||
// Weight scales (scalars) for each of the weights above.
|
||||
// input_to_input_weights_scale - optional
|
||||
// input_to_forget_weights_scale
|
||||
// input_to_cell_weights_scale
|
||||
// input_to_output_weights_scale
|
||||
// recurrent_to_input_weights_scale - optional
|
||||
// recurrent_to_forget_weights_scale
|
||||
// recurrent_to_cell_weights_scale
|
||||
// recurrent_to_output_weights_scale
|
||||
// cell_to_input_weights_scale,
|
||||
// cell_to_forget_weights_scale,
|
||||
// cell_to_output_weights_scale,
|
||||
// projection_weights_scale - optional
|
||||
// Gate biases of size 'n_cell':
|
||||
// input_gate_bias_ptr - optional
|
||||
// forget_gate_bias_ptr
|
||||
// cell_gate_bias_ptr
|
||||
// output_gate_bias_ptr
|
||||
//
|
||||
// Temporary pre-allocated storage for quantized values:
|
||||
// quantized_input_ptr_batch (same size as input_ptr_batch)
|
||||
// quantized_output_state_ptr (same size as output_state_ptr)
|
||||
// quantized_cell_state_ptr (same size as cell_state_ptr)
|
||||
// Temporary pre-allocated storage for recovered values:
|
||||
// recovered_cell_weights (same size as cell_to_*_weights)
|
||||
//
|
||||
// Outputs:
|
||||
// output_state_ptr - size 'n_batch * n_output'
|
||||
// cell_state_ptr - size 'n_batch * n_cell'
|
||||
// output_ptr_batch - size 'n_batch * n_output'
|
||||
inline void LstmStepWithAuxInput(
|
||||
const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
|
||||
float input_to_input_weights_scale,
|
||||
const int8_t* input_to_forget_weights_ptr,
|
||||
float input_to_forget_weights_scale,
|
||||
const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
|
||||
const int8_t* input_to_output_weights_ptr,
|
||||
float input_to_output_weights_scale, const float* aux_input_ptr_batch,
|
||||
const int8_t* aux_input_to_input_weights_ptr,
|
||||
float aux_input_to_input_weights_scale,
|
||||
const int8_t* aux_input_to_forget_weights_ptr,
|
||||
float aux_input_to_forget_weights_scale,
|
||||
const int8_t* aux_input_to_cell_weights_ptr,
|
||||
float aux_input_to_cell_weights_scale,
|
||||
const int8_t* aux_input_to_output_weights_ptr,
|
||||
float aux_input_to_output_weights_scale,
|
||||
const int8_t* recurrent_to_input_weights_ptr,
|
||||
float recurrent_to_input_weights_scale,
|
||||
const int8_t* recurrent_to_forget_weights_ptr,
|
||||
float recurrent_to_forget_weights_scale,
|
||||
const int8_t* recurrent_to_cell_weights_ptr,
|
||||
float recurrent_to_cell_weights_scale,
|
||||
const int8_t* recurrent_to_output_weights_ptr,
|
||||
float recurrent_to_output_weights_scale,
|
||||
const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
|
||||
const int8_t* cell_to_forget_weights_ptr,
|
||||
float cell_to_forget_weights_scale,
|
||||
const int8_t* cell_to_output_weights_ptr,
|
||||
float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
|
||||
const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
|
||||
const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
|
||||
float projection_weights_scale, const float* projection_bias_ptr,
|
||||
const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
|
||||
int n_aux_input, int n_output, float* input_gate_scratch,
|
||||
float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
|
||||
float* scaling_factors, float* product_scaling_factors,
|
||||
float* recovered_cell_weights, int8_t* quantized_input_ptr_batch,
|
||||
int8_t* quantized_aux_input_ptr_batch, int8_t* quantized_output_state_ptr,
|
||||
int8_t* quantized_cell_state_ptr, float* output_state_ptr,
|
||||
float* cell_state_ptr, float* output_ptr_batch) {
|
||||
// Since we have already checked that weights are all there or none, we
|
||||
// can check the existence of only one to get the condition.
|
||||
const bool use_cifg = (input_to_input_weights_ptr == nullptr);
|
||||
const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
|
||||
// Initialize scratch buffers with bias.
|
||||
if (!use_cifg) {
|
||||
tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch,
|
||||
input_gate_scratch);
|
||||
}
|
||||
tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
|
||||
forget_gate_scratch);
|
||||
tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
|
||||
cell_scratch);
|
||||
tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
|
||||
output_gate_scratch);
|
||||
|
||||
if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
|
||||
// Save quantization and matmul computation for all zero input.
|
||||
float unused_min, unused_max;
|
||||
for (int b = 0; b < n_batch; ++b) {
|
||||
const int offset = b * n_input;
|
||||
tensor_utils::SymmetricQuantizeFloats(
|
||||
input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset,
|
||||
&unused_min, &unused_max, &scaling_factors[b]);
|
||||
}
|
||||
// For each batch and cell: compute input_weight * input.
|
||||
if (!use_cifg) {
|
||||
for (int b = 0; b < n_batch; ++b) {
|
||||
product_scaling_factors[b] =
|
||||
scaling_factors[b] * input_to_input_weights_scale;
|
||||
}
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
input_to_input_weights_ptr, n_cell, n_input,
|
||||
quantized_input_ptr_batch, product_scaling_factors, n_batch,
|
||||
input_gate_scratch, /*result_stride=*/1);
|
||||
}
|
||||
|
||||
for (int b = 0; b < n_batch; ++b) {
|
||||
product_scaling_factors[b] =
|
||||
scaling_factors[b] * input_to_forget_weights_scale;
|
||||
}
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
|
||||
product_scaling_factors, n_batch, forget_gate_scratch,
|
||||
/*result_stride=*/1);
|
||||
|
||||
for (int b = 0; b < n_batch; ++b) {
|
||||
product_scaling_factors[b] =
|
||||
scaling_factors[b] * input_to_cell_weights_scale;
|
||||
}
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
|
||||
product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1);
|
||||
|
||||
for (int b = 0; b < n_batch; ++b) {
|
||||
product_scaling_factors[b] =
|
||||
scaling_factors[b] * input_to_output_weights_scale;
|
||||
}
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
|
||||
product_scaling_factors, n_batch, output_gate_scratch,
|
||||
/*result_stride=*/1);
|
||||
}
|
||||
|
||||
if (aux_input_ptr_batch != nullptr &&
|
||||
!tensor_utils::IsZeroVector(aux_input_ptr_batch, n_batch * n_input)) {
|
||||
// Save quantization and matmul computation for all zero input.
|
||||
float unused_min, unused_max;
|
||||
for (int b = 0; b < n_batch; ++b) {
|
||||
const int offset = b * n_input;
|
||||
tensor_utils::SymmetricQuantizeFloats(
|
||||
aux_input_ptr_batch + offset, n_input,
|
||||
quantized_aux_input_ptr_batch + offset, &unused_min, &unused_max,
|
||||
&scaling_factors[b]);
|
||||
}
|
||||
// For each batch and cell: compute input_weight * input.
|
||||
if (!use_cifg) {
|
||||
for (int b = 0; b < n_batch; ++b) {
|
||||
product_scaling_factors[b] =
|
||||
scaling_factors[b] * aux_input_to_input_weights_scale;
|
||||
}
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
aux_input_to_input_weights_ptr, n_cell, n_input,
|
||||
quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
|
||||
input_gate_scratch, /*result_stride=*/1);
|
||||
}
|
||||
|
||||
for (int b = 0; b < n_batch; ++b) {
|
||||
product_scaling_factors[b] =
|
||||
scaling_factors[b] * aux_input_to_forget_weights_scale;
|
||||
}
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
aux_input_to_forget_weights_ptr, n_cell, n_input,
|
||||
quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
|
||||
forget_gate_scratch, /*result_stride=*/1);
|
||||
|
||||
for (int b = 0; b < n_batch; ++b) {
|
||||
product_scaling_factors[b] =
|
||||
scaling_factors[b] * aux_input_to_cell_weights_scale;
|
||||
}
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
aux_input_to_cell_weights_ptr, n_cell, n_input,
|
||||
quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
|
||||
cell_scratch, /*result_stride=*/1);
|
||||
|
||||
for (int b = 0; b < n_batch; ++b) {
|
||||
product_scaling_factors[b] =
|
||||
scaling_factors[b] * aux_input_to_output_weights_scale;
|
||||
}
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
aux_input_to_output_weights_ptr, n_cell, n_input,
|
||||
quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
|
||||
output_gate_scratch, /*result_stride=*/1);
|
||||
}
|
||||
|
||||
if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
|
||||
// Save quantization and matmul computation for all zero input.
|
||||
float unused_min, unused_max;
|
||||
for (int b = 0; b < n_batch; ++b) {
|
||||
const int offset = b * n_output;
|
||||
tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output,
|
||||
quantized_output_state_ptr + offset,
|
||||
&unused_min, &unused_max,
|
||||
&scaling_factors[b]);
|
||||
}
|
||||
// For each batch and cell: compute recurrent_weight * output_state.
|
||||
if (!use_cifg) {
|
||||
for (int b = 0; b < n_batch; ++b) {
|
||||
product_scaling_factors[b] =
|
||||
scaling_factors[b] * recurrent_to_input_weights_scale;
|
||||
}
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
recurrent_to_input_weights_ptr, n_cell, n_output,
|
||||
quantized_output_state_ptr, product_scaling_factors, n_batch,
|
||||
input_gate_scratch, /*result_stride=*/1);
|
||||
}
|
||||
|
||||
for (int b = 0; b < n_batch; ++b) {
|
||||
product_scaling_factors[b] =
|
||||
scaling_factors[b] * recurrent_to_forget_weights_scale;
|
||||
}
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
recurrent_to_forget_weights_ptr, n_cell, n_output,
|
||||
quantized_output_state_ptr, product_scaling_factors, n_batch,
|
||||
forget_gate_scratch, /*result_stride=*/1);
|
||||
|
||||
for (int b = 0; b < n_batch; ++b) {
|
||||
product_scaling_factors[b] =
|
||||
scaling_factors[b] * recurrent_to_cell_weights_scale;
|
||||
}
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
recurrent_to_cell_weights_ptr, n_cell, n_output,
|
||||
quantized_output_state_ptr, product_scaling_factors, n_batch,
|
||||
cell_scratch, /*result_stride=*/1);
|
||||
|
||||
for (int b = 0; b < n_batch; ++b) {
|
||||
product_scaling_factors[b] =
|
||||
scaling_factors[b] * recurrent_to_output_weights_scale;
|
||||
}
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
recurrent_to_output_weights_ptr, n_cell, n_output,
|
||||
quantized_output_state_ptr, product_scaling_factors, n_batch,
|
||||
output_gate_scratch, /*result_stride=*/1);
|
||||
}
|
||||
|
||||
// Save quantization and matmul computation for all zero input.
|
||||
bool is_cell_state_all_zeros =
|
||||
tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
|
||||
|
||||
// For each batch and cell: update input gate.
|
||||
if (!use_cifg) {
|
||||
if (use_peephole && !is_cell_state_all_zeros) {
|
||||
tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell,
|
||||
cell_to_input_weights_scale,
|
||||
recovered_cell_weights);
|
||||
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
|
||||
recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
|
||||
input_gate_scratch);
|
||||
}
|
||||
tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
|
||||
input_gate_scratch);
|
||||
}
|
||||
|
||||
// For each batch and cell: update forget gate.
|
||||
if (use_peephole && !is_cell_state_all_zeros) {
|
||||
tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell,
|
||||
cell_to_forget_weights_scale,
|
||||
recovered_cell_weights);
|
||||
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
|
||||
recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
|
||||
forget_gate_scratch);
|
||||
}
|
||||
tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
|
||||
forget_gate_scratch);
|
||||
|
||||
// For each batch and cell: update the cell.
|
||||
tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
|
||||
n_batch * n_cell, cell_state_ptr);
|
||||
tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
|
||||
params->activation, cell_scratch);
|
||||
if (use_cifg) {
|
||||
tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
|
||||
forget_gate_scratch);
|
||||
tensor_utils::VectorVectorCwiseProductAccumulate(
|
||||
cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
|
||||
} else {
|
||||
tensor_utils::VectorVectorCwiseProductAccumulate(
|
||||
cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
|
||||
}
|
||||
if (params->cell_clip > 0.0) {
|
||||
tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
|
||||
params->cell_clip, cell_state_ptr);
|
||||
}
|
||||
|
||||
is_cell_state_all_zeros =
|
||||
tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
|
||||
// For each batch and cell: update the output gate.
|
||||
if (use_peephole && !is_cell_state_all_zeros) {
|
||||
tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell,
|
||||
cell_to_output_weights_scale,
|
||||
recovered_cell_weights);
|
||||
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
|
||||
recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
|
||||
output_gate_scratch);
|
||||
}
|
||||
tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
|
||||
output_gate_scratch);
|
||||
tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
|
||||
params->activation, cell_scratch);
|
||||
tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
|
||||
n_batch * n_cell, output_gate_scratch);
|
||||
|
||||
// For each batch: update the projection and output_state.
|
||||
const bool use_projection_weight = (projection_weights_ptr != nullptr);
|
||||
const bool use_projection_bias = (projection_bias_ptr != nullptr);
|
||||
if (use_projection_weight) {
|
||||
if (use_projection_bias) {
|
||||
tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
|
||||
n_batch, output_ptr_batch);
|
||||
} else {
|
||||
tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
|
||||
}
|
||||
if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) {
|
||||
// Save quantization and matmul computation for all zero input.
|
||||
float unused_min, unused_max;
|
||||
for (int b = 0; b < n_batch; ++b) {
|
||||
const int offset = b * n_cell;
|
||||
tensor_utils::SymmetricQuantizeFloats(
|
||||
output_gate_scratch + offset, n_cell,
|
||||
quantized_cell_state_ptr + offset, &unused_min, &unused_max,
|
||||
&scaling_factors[b]);
|
||||
}
|
||||
for (int b = 0; b < n_batch; ++b) {
|
||||
product_scaling_factors[b] =
|
||||
scaling_factors[b] * projection_weights_scale;
|
||||
}
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr,
|
||||
product_scaling_factors, n_batch, output_ptr_batch,
|
||||
/*result_stride=*/1);
|
||||
}
|
||||
if (params->proj_clip > 0.0) {
|
||||
tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
|
||||
params->proj_clip, output_ptr_batch);
|
||||
}
|
||||
} else {
|
||||
tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
|
||||
output_ptr_batch);
|
||||
}
|
||||
tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
|
||||
output_state_ptr);
|
||||
}
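// The hybrid step above keeps activations in float and only the weights in
// int8: each batch row is symmetrically quantized, the int8 matmul result is
// rescaled by scaling_factors[b] * weight_scale (the product_scaling_factors
// computed above) and accumulated into the float gate scratch buffers. The
// helper below is an illustrative sketch of that per-row quantization only;
// it is not called by the kernels, and the real SymmetricQuantizeFloats may
// differ in detail. Floats are recovered as q[i] * scale; e.g. with
// scale = 1.0/127 and a weight scale of 0.02, a raw accumulator value of 3175
// contributes 3175 * (1.0/127) * 0.02 = 0.5 to a gate value.
inline float SymmetricQuantizeRowSketch(const float* row, int n, int8_t* out) {
  // Find the largest magnitude in the row; it maps to +/-127.
  float max_abs = 0.f;
  for (int i = 0; i < n; ++i) {
    const float v = row[i] < 0.f ? -row[i] : row[i];
    if (v > max_abs) max_abs = v;
  }
  const float scale = (max_abs == 0.f) ? 1.f : max_abs / 127.f;
  // Round each element to the nearest int8 step of size 'scale'.
  for (int i = 0; i < n; ++i) {
    const float scaled = row[i] / scale;
    out[i] = static_cast<int8_t>(scaled >= 0.f ? scaled + 0.5f : scaled - 0.5f);
  }
  return scale;
}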
|
||||
} // namespace
|
||||
|
||||
TfLiteStatus EvalFloat(
|
||||
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
|
||||
const TfLiteTensor* input_to_forget_weights,
|
||||
const TfLiteTensor* input_to_cell_weights,
|
||||
const TfLiteTensor* input_to_output_weights,
|
||||
const TfLiteTensor* recurrent_to_input_weights,
|
||||
const TfLiteTensor* recurrent_to_forget_weights,
|
||||
const TfLiteTensor* recurrent_to_cell_weights,
|
||||
const TfLiteTensor* recurrent_to_output_weights,
|
||||
const TfLiteTensor* cell_to_input_weights,
|
||||
const TfLiteTensor* cell_to_forget_weights,
|
||||
const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
|
||||
const TfLiteTensor* aux_input_to_input_weights,
|
||||
const TfLiteTensor* aux_input_to_forget_weights,
|
||||
const TfLiteTensor* aux_input_to_cell_weights,
|
||||
const TfLiteTensor* aux_input_to_output_weights,
|
||||
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
|
||||
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
|
||||
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
|
||||
const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
|
||||
TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
|
||||
TfLiteTensor* cell_state, TfLiteTensor* output) {
|
||||
const int max_time = (input->dims->size == 2) ? 1 : input->dims->data[0];
|
||||
const int n_batch = input->dims->data[input->dims->size - 2];
|
||||
const int n_input = input->dims->data[input->dims->size - 1];
|
||||
const int aux_input_size =
|
||||
(aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
|
||||
|
||||
// n_cell and n_output will be the same size when there is no projection.
|
||||
const int n_cell = input_to_output_weights->dims->data[0];
|
||||
const int n_output = recurrent_to_output_weights->dims->data[1];
|
||||
|
||||
// Since we have already checked that weights are all there or none, we can
|
||||
// check the existence of only one to get the condition.
|
||||
const bool use_cifg = (input_to_input_weights == nullptr);
|
||||
const bool use_peephole = (cell_to_output_weights != nullptr);
|
||||
|
||||
// Index the scratch buffers pointers to the global scratch buffer.
|
||||
float* input_gate_scratch = nullptr;
|
||||
float* cell_scratch = nullptr;
|
||||
float* forget_gate_scratch = nullptr;
|
||||
float* output_gate_scratch = nullptr;
|
||||
if (use_cifg) {
|
||||
cell_scratch = scratch_buffer->data.f;
|
||||
forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
|
||||
output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
|
||||
} else {
|
||||
input_gate_scratch = scratch_buffer->data.f;
|
||||
cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
|
||||
forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
|
||||
output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
|
||||
}
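// The scratch buffer is thus partitioned into consecutive n_cell * n_batch
// blocks: [cell | forget | output] gates when CIFG drops the input gate,
// [input | cell | forget | output] gates otherwise.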
|
||||
|
||||
// Check optional tensors, the respective pointers can be null.
|
||||
const float* input_to_input_weights_ptr =
|
||||
(use_cifg) ? nullptr : input_to_input_weights->data.f;
|
||||
const float* recurrent_to_input_weights_ptr =
|
||||
(use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
|
||||
const float* input_gate_bias_ptr =
|
||||
(use_cifg) ? nullptr : input_gate_bias->data.f;
|
||||
const float* cell_to_input_weights_ptr =
|
||||
(use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
|
||||
const float* cell_to_forget_weights_ptr =
|
||||
(use_peephole) ? cell_to_forget_weights->data.f : nullptr;
|
||||
const float* cell_to_output_weights_ptr =
|
||||
(use_peephole) ? cell_to_output_weights->data.f : nullptr;
|
||||
const float* projection_weights_ptr =
|
||||
(projection_weights == nullptr) ? nullptr : projection_weights->data.f;
|
||||
const float* projection_bias_ptr =
|
||||
(projection_bias == nullptr) ? nullptr : projection_bias->data.f;
|
||||
|
||||
float* aux_input_ptr = nullptr;
|
||||
float* aux_input_to_input_weights_ptr = nullptr;
|
||||
float* aux_input_to_forget_weights_ptr = nullptr;
|
||||
float* aux_input_to_cell_weights_ptr = nullptr;
|
||||
float* aux_input_to_output_weights_ptr = nullptr;
|
||||
if (aux_input_size > 0) {
|
||||
aux_input_ptr = aux_input->data.f;
|
||||
aux_input_to_input_weights_ptr = aux_input_to_input_weights->data.f;
|
||||
aux_input_to_forget_weights_ptr = aux_input_to_forget_weights->data.f;
|
||||
aux_input_to_cell_weights_ptr = aux_input_to_cell_weights->data.f;
|
||||
aux_input_to_output_weights_ptr = aux_input_to_output_weights->data.f;
|
||||
}
|
||||
|
||||
// Loop through the sequence.
|
||||
const int input_step = n_batch * n_input;
|
||||
const int output_step = n_batch * output->dims->data[output->dims->size - 1];
|
||||
for (int t = 0; t < max_time; t++) {
|
||||
// If this is the forward_sequence, step forward, otherwise step backwards.
|
||||
const int t_rel = forward_sequence ? t : max_time - t - 1;
|
||||
const float* input_ptr = input->data.f + t_rel * input_step;
|
||||
float* output_ptr_time =
|
||||
output->data.f + t_rel * output_step + output_offset;
|
||||
|
||||
LstmStepWithAuxInput(
|
||||
input_ptr, input_to_input_weights_ptr, input_to_forget_weights->data.f,
|
||||
input_to_cell_weights->data.f, input_to_output_weights->data.f,
|
||||
aux_input_ptr, aux_input_to_input_weights_ptr,
|
||||
aux_input_to_forget_weights_ptr, aux_input_to_cell_weights_ptr,
|
||||
aux_input_to_output_weights_ptr, recurrent_to_input_weights_ptr,
|
||||
recurrent_to_forget_weights->data.f, recurrent_to_cell_weights->data.f,
|
||||
recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
|
||||
cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
|
||||
input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
|
||||
output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
|
||||
params, n_batch, n_cell, n_input, aux_input_size, n_output,
|
||||
activation_state->data.f, cell_state->data.f, input_gate_scratch,
|
||||
forget_gate_scratch, cell_scratch, output_gate_scratch,
|
||||
output_ptr_time);
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
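// When forward_sequence is false the loop above visits t_rel = max_time - 1,
// ..., 0, reading and writing each step at the same t_rel offset; a caller
// evaluating the backward half of a bidirectional LSTM can therefore reuse
// this routine and place its results next to the forward pass via a nonzero
// output_offset into a shared output tensor.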
|
||||
|
||||
TfLiteStatus EvalHybrid(
|
||||
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
|
||||
const TfLiteTensor* input_to_forget_weights,
|
||||
const TfLiteTensor* input_to_cell_weights,
|
||||
const TfLiteTensor* input_to_output_weights,
|
||||
const TfLiteTensor* recurrent_to_input_weights,
|
||||
const TfLiteTensor* recurrent_to_forget_weights,
|
||||
const TfLiteTensor* recurrent_to_cell_weights,
|
||||
const TfLiteTensor* recurrent_to_output_weights,
|
||||
const TfLiteTensor* cell_to_input_weights,
|
||||
const TfLiteTensor* cell_to_forget_weights,
|
||||
const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
|
||||
const TfLiteTensor* aux_input_to_input_weights,
|
||||
const TfLiteTensor* aux_input_to_forget_weights,
|
||||
const TfLiteTensor* aux_input_to_cell_weights,
|
||||
const TfLiteTensor* aux_input_to_output_weights,
|
||||
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
|
||||
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
|
||||
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
|
||||
const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
|
||||
TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
|
||||
TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
|
||||
TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
|
||||
TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
|
||||
TfLiteTensor* output_state, TfLiteTensor* cell_state,
|
||||
TfLiteTensor* output) {
|
||||
const int max_time = (input->dims->size == 2) ? 1 : input->dims->data[0];
|
||||
const int n_batch = input->dims->data[input->dims->size - 2];
|
||||
const int n_input = input->dims->data[input->dims->size - 1];
|
||||
const int aux_input_size =
|
||||
(aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
|
||||
// n_cell and n_output will be the same size when there is no projection.
|
||||
const int n_cell = input_to_output_weights->dims->data[0];
|
||||
const int n_output = recurrent_to_output_weights->dims->data[1];
|
||||
|
||||
// Since we have already checked that weights are all there or none, we can
|
||||
// check the existence of only one to get the condition.
|
||||
const bool use_cifg = (input_to_input_weights == nullptr);
|
||||
const bool use_peephole = (cell_to_output_weights != nullptr);
|
||||
|
||||
float* input_gate_scratch = nullptr;
|
||||
float* cell_scratch = nullptr;
|
||||
float* forget_gate_scratch = nullptr;
|
||||
float* output_gate_scratch = nullptr;
|
||||
if (use_cifg) {
|
||||
cell_scratch = scratch_buffer->data.f;
|
||||
forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
|
||||
output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
|
||||
} else {
|
||||
input_gate_scratch = scratch_buffer->data.f;
|
||||
cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
|
||||
forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
|
||||
output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
|
||||
}
|
||||
|
||||
// Check optional tensors, the respective pointers can be null.
|
||||
int8_t* input_to_input_weights_ptr = nullptr;
|
||||
float input_to_input_weights_scale = 1.0f;
|
||||
int8_t* recurrent_to_input_weights_ptr = nullptr;
|
||||
float recurrent_to_input_weights_scale = 1.0f;
|
||||
float* input_gate_bias_ptr = nullptr;
|
||||
if (!use_cifg) {
|
||||
input_to_input_weights_ptr =
|
||||
reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
|
||||
recurrent_to_input_weights_ptr =
|
||||
reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
|
||||
input_gate_bias_ptr = input_gate_bias->data.f;
|
||||
input_to_input_weights_scale = input_to_input_weights->params.scale;
|
||||
recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
|
||||
}
|
||||
|
||||
int8_t* cell_to_input_weights_ptr = nullptr;
|
||||
int8_t* cell_to_forget_weights_ptr = nullptr;
|
||||
int8_t* cell_to_output_weights_ptr = nullptr;
|
||||
float cell_to_input_weights_scale = 1.0f;
|
||||
float cell_to_forget_weights_scale = 1.0f;
|
||||
float cell_to_output_weights_scale = 1.0f;
|
||||
if (use_peephole) {
|
||||
if (!use_cifg) {
|
||||
cell_to_input_weights_ptr =
|
||||
reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
|
||||
cell_to_input_weights_scale = cell_to_input_weights->params.scale;
|
||||
}
|
||||
cell_to_forget_weights_ptr =
|
||||
reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
|
||||
cell_to_output_weights_ptr =
|
||||
reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
|
||||
cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
|
||||
cell_to_output_weights_scale = cell_to_output_weights->params.scale;
|
||||
}
|
||||
|
||||
const int8_t* projection_weights_ptr =
|
||||
(projection_weights == nullptr)
|
||||
? nullptr
|
||||
: reinterpret_cast<int8_t*>(projection_weights->data.uint8);
|
||||
const float projection_weights_scale =
|
||||
(projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
|
||||
const float* projection_bias_ptr =
|
||||
(projection_bias == nullptr) ? nullptr : projection_bias->data.f;
|
||||
|
||||
// Required tensors, pointers are non-null.
|
||||
const int8_t* input_to_forget_weights_ptr =
|
||||
reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
|
||||
const float input_to_forget_weights_scale =
|
||||
input_to_forget_weights->params.scale;
|
||||
const int8_t* input_to_cell_weights_ptr =
|
||||
reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
|
||||
const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
|
||||
const int8_t* input_to_output_weights_ptr =
|
||||
reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
|
||||
const float input_to_output_weights_scale =
|
||||
input_to_output_weights->params.scale;
|
||||
const int8_t* recurrent_to_forget_weights_ptr =
|
||||
reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
|
||||
const float recurrent_to_forget_weights_scale =
|
||||
recurrent_to_forget_weights->params.scale;
|
||||
const int8_t* recurrent_to_cell_weights_ptr =
|
||||
reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
|
||||
const float recurrent_to_cell_weights_scale =
|
||||
recurrent_to_cell_weights->params.scale;
|
||||
const int8_t* recurrent_to_output_weights_ptr =
|
||||
reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
|
||||
const float recurrent_to_output_weights_scale =
|
||||
recurrent_to_output_weights->params.scale;
|
||||
const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
|
||||
const float* cell_bias_ptr = cell_bias->data.f;
|
||||
const float* output_gate_bias_ptr = output_gate_bias->data.f;
|
||||
|
||||
float* output_state_ptr = output_state->data.f;
|
||||
float* cell_state_ptr = cell_state->data.f;
|
||||
|
||||
// Temporary storage for quantized values and scaling factors.
|
||||
int8_t* quantized_input_ptr =
|
||||
reinterpret_cast<int8_t*>(input_quantized->data.uint8);
|
||||
int8_t* quantized_aux_input_ptr =
|
||||
(aux_input_quantized == nullptr)
|
||||
? nullptr
|
||||
: reinterpret_cast<int8_t*>(aux_input_quantized->data.uint8);
|
||||
int8_t* quantized_output_state_ptr =
|
||||
reinterpret_cast<int8_t*>(output_state_quantized->data.uint8);
|
||||
int8_t* quantized_cell_state_ptr =
|
||||
reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
|
||||
float* scaling_factors_ptr = scaling_factors->data.f;
|
||||
float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
|
||||
float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;
|
||||
|
||||
// Auxiliary input and weights.
|
||||
float* aux_input_ptr = nullptr;
|
||||
int8_t* aux_input_to_input_weights_ptr = nullptr;
|
||||
int8_t* aux_input_to_forget_weights_ptr = nullptr;
|
||||
int8_t* aux_input_to_cell_weights_ptr = nullptr;
|
||||
int8_t* aux_input_to_output_weights_ptr = nullptr;
|
||||
float aux_input_to_input_weights_scale = 0.0f;
|
||||
float aux_input_to_forget_weights_scale = 0.0f;
|
||||
float aux_input_to_cell_weights_scale = 0.0f;
|
||||
float aux_input_to_output_weights_scale = 0.0f;
|
||||
if (aux_input_size > 0) {
|
||||
aux_input_ptr = aux_input->data.f;
|
||||
aux_input_to_input_weights_ptr =
|
||||
reinterpret_cast<int8_t*>(aux_input_to_input_weights->data.uint8);
|
||||
aux_input_to_forget_weights_ptr =
|
||||
reinterpret_cast<int8_t*>(aux_input_to_forget_weights->data.uint8);
|
||||
aux_input_to_cell_weights_ptr =
|
||||
reinterpret_cast<int8_t*>(aux_input_to_cell_weights->data.uint8);
|
||||
aux_input_to_output_weights_ptr =
|
||||
reinterpret_cast<int8_t*>(aux_input_to_output_weights->data.uint8);
|
||||
aux_input_to_input_weights_scale = aux_input_to_input_weights->params.scale;
|
||||
aux_input_to_forget_weights_scale =
|
||||
aux_input_to_forget_weights->params.scale;
|
||||
aux_input_to_cell_weights_scale = aux_input_to_cell_weights->params.scale;
|
||||
aux_input_to_output_weights_scale =
|
||||
aux_input_to_output_weights->params.scale;
|
||||
}
|
||||
|
||||
// Feed the sequence into the LSTM step-by-step.
|
||||
const int input_step = n_batch * n_input;
|
||||
const int output_step = n_batch * output->dims->data[output->dims->size - 1];
|
||||
for (int t = 0; t < max_time; t++) {
|
||||
// If this is the forward_sequence, step forward, otherwise step backwards.
|
||||
const int t_rel = forward_sequence ? t : max_time - t - 1;
|
||||
const float* input_ptr = input->data.f + t_rel * input_step;
|
||||
float* output_ptr = output->data.f + t_rel * output_step + output_offset;
|
||||
|
||||
LstmStepWithAuxInput(
|
||||
input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
|
||||
input_to_forget_weights_ptr, input_to_forget_weights_scale,
|
||||
input_to_cell_weights_ptr, input_to_cell_weights_scale,
|
||||
input_to_output_weights_ptr, input_to_output_weights_scale,
|
||||
aux_input_ptr, aux_input_to_input_weights_ptr,
|
||||
aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
|
||||
aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
|
||||
aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
|
||||
aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
|
||||
recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
|
||||
recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
|
||||
recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
|
||||
recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
|
||||
cell_to_input_weights_scale, cell_to_forget_weights_ptr,
|
||||
cell_to_forget_weights_scale, cell_to_output_weights_ptr,
|
||||
cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr,
|
||||
cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
|
||||
projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell,
|
||||
n_input, aux_input_size, n_output, input_gate_scratch,
|
||||
forget_gate_scratch, cell_scratch, output_gate_scratch,
|
||||
scaling_factors_ptr, prod_scaling_factors_ptr,
|
||||
recovered_cell_weights_ptr, quantized_input_ptr,
|
||||
quantized_aux_input_ptr, quantized_output_state_ptr,
|
||||
quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, output_ptr);
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace lstm_eval
|
||||
} // namespace builtin
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
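For illustration, the shared time loop above is what lets one evaluation routine serve either traversal direction: forward_sequence flips the order via t_rel, and output_offset shifts where each step writes, so a forward pass and a backward pass can interleave their results in one output buffer. A minimal standalone sketch of just that indexing, using plain arrays instead of TfLiteTensor and made-up sizes (fw_output, bw_output are assumptions, not values from this change):

#include <cstdio>
#include <vector>

// Sketch of the t_rel / output_offset indexing used by the evaluation loops
// above. Each "step" only tags its slot so the write pattern is visible; the
// real code calls LstmStepWithAuxInput at this point.
int main() {
  const int max_time = 4;
  const int n_batch = 1;
  const int fw_output = 3;  // forward cells (illustrative)
  const int bw_output = 2;  // backward cells (illustrative)
  const int output_step = n_batch * (fw_output + bw_output);
  std::vector<float> output(max_time * output_step, 0.0f);

  auto run_pass = [&](bool forward_sequence, int output_offset, int n_output,
                      float tag) {
    for (int t = 0; t < max_time; t++) {
      // If this is the forward_sequence, step forward, otherwise step backwards.
      const int t_rel = forward_sequence ? t : max_time - t - 1;
      float* output_ptr = output.data() + t_rel * output_step + output_offset;
      for (int i = 0; i < n_batch * n_output; i++) output_ptr[i] = tag + t_rel;
    }
  };

  run_pass(/*forward_sequence=*/true, /*output_offset=*/0, fw_output, 100.0f);
  run_pass(/*forward_sequence=*/false, /*output_offset=*/n_batch * fw_output,
           bw_output, 200.0f);

  for (int t = 0; t < max_time; t++) {
    for (int i = 0; i < output_step; i++) printf("%6.0f ", output[t * output_step + i]);
    printf("\n");
  }
  return 0;
}

Each printed row then holds the forward result followed by the backward result for that timestep, which is the interleaved layout a bidirectional wrapper would consume.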
tensorflow/contrib/lite/kernels/lstm_eval.h (new file, 79 lines)
@ -0,0 +1,79 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_LSTM_EVAL_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_LSTM_EVAL_H_

#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/c/c_api_internal.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace lstm_eval {

TfLiteStatus EvalFloat(
    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
    const TfLiteTensor* input_to_forget_weights,
    const TfLiteTensor* input_to_cell_weights,
    const TfLiteTensor* input_to_output_weights,
    const TfLiteTensor* recurrent_to_input_weights,
    const TfLiteTensor* recurrent_to_forget_weights,
    const TfLiteTensor* recurrent_to_cell_weights,
    const TfLiteTensor* recurrent_to_output_weights,
    const TfLiteTensor* cell_to_input_weights,
    const TfLiteTensor* cell_to_forget_weights,
    const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
    const TfLiteTensor* aux_input_to_input_weights,
    const TfLiteTensor* aux_input_to_forget_weights,
    const TfLiteTensor* aux_input_to_cell_weights,
    const TfLiteTensor* aux_input_to_output_weights,
    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
    const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
    const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
    TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
    TfLiteTensor* cell_state, TfLiteTensor* output);

TfLiteStatus EvalHybrid(
    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
    const TfLiteTensor* input_to_forget_weights,
    const TfLiteTensor* input_to_cell_weights,
    const TfLiteTensor* input_to_output_weights,
    const TfLiteTensor* recurrent_to_input_weights,
    const TfLiteTensor* recurrent_to_forget_weights,
    const TfLiteTensor* recurrent_to_cell_weights,
    const TfLiteTensor* recurrent_to_output_weights,
    const TfLiteTensor* cell_to_input_weights,
    const TfLiteTensor* cell_to_forget_weights,
    const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
    const TfLiteTensor* aux_input_to_input_weights,
    const TfLiteTensor* aux_input_to_forget_weights,
    const TfLiteTensor* aux_input_to_cell_weights,
    const TfLiteTensor* aux_input_to_output_weights,
    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
    const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
    const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
    TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
    TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
    TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
    TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
    TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output);

}  // namespace lstm_eval
}  // namespace builtin
}  // namespace ops
}  // namespace tflite
#endif  // TENSORFLOW_CONTRIB_LITE_KERNELS_LSTM_EVAL_H_
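For illustration, every quantized weight in the EvalHybrid declaration above travels with a float scale, and the scaling_factors / prod_scaling_factors scratch tensors exist because the hybrid path keeps activations in float while the weights stay int8: the input is symmetrically quantized per step, the matrix products accumulate in integers, and the accumulator is rescaled by the product of the two scales. A simplified, self-contained sketch of that scheme with made-up values (it is not the tensor_utils implementation):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Simplified hybrid (float activations, int8 weights) dot product: quantize
// both operands symmetrically, accumulate in int32, rescale back to float.
int main() {
  const std::vector<float> input = {0.25f, -0.5f, 0.75f};
  const std::vector<float> weights_float = {0.1f, -0.2f, 0.3f};

  // Symmetric quantization of the weights to int8 with a single scale.
  float w_max = 0.0f;
  for (float w : weights_float) w_max = std::max(w_max, std::fabs(w));
  const float weights_scale = w_max / 127.0f;
  std::vector<int8_t> weights_int8;
  for (float w : weights_float)
    weights_int8.push_back(static_cast<int8_t>(std::round(w / weights_scale)));

  // Per-step quantization of the float input; this scale is what the
  // scaling_factors scratch holds, one entry per batch.
  float in_max = 0.0f;
  for (float v : input) in_max = std::max(in_max, std::fabs(v));
  const float input_scale = in_max / 127.0f;
  std::vector<int8_t> input_int8;
  for (float v : input)
    input_int8.push_back(static_cast<int8_t>(std::round(v / input_scale)));

  // Integer accumulation, then rescale by the product of the two scales.
  int32_t acc = 0;
  for (size_t i = 0; i < input.size(); ++i)
    acc += static_cast<int32_t>(input_int8[i]) * weights_int8[i];
  const float approx = acc * input_scale * weights_scale;

  float exact = 0.0f;
  for (size_t i = 0; i < input.size(); ++i) exact += input[i] * weights_float[i];
  printf("exact %.6f approx %.6f\n", exact, approx);
  return 0;
}

The recovered_cell_weights scratch appears to serve a related purpose for the peephole weights, holding their dequantized float form before use.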
@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/lstm_eval.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"

namespace tflite {
@ -429,273 +430,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  return kTfLiteOk;
}

// The LSTM Op engine.
TfLiteStatus EvalFloat(
    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
    const TfLiteTensor* input_to_forget_weights,
    const TfLiteTensor* input_to_cell_weights,
    const TfLiteTensor* input_to_output_weights,
    const TfLiteTensor* recurrent_to_input_weights,
    const TfLiteTensor* recurrent_to_forget_weights,
    const TfLiteTensor* recurrent_to_cell_weights,
    const TfLiteTensor* recurrent_to_output_weights,
    const TfLiteTensor* cell_to_input_weights,
    const TfLiteTensor* cell_to_forget_weights,
    const TfLiteTensor* cell_to_output_weights,
    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
    const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
    const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
    TfLiteTensor* activation_state, TfLiteTensor* cell_state,
    TfLiteTensor* output) {
  const int max_time = input->dims->data[0];
  const int n_batch = input->dims->data[1];
  const int n_input = input->dims->data[2];
  // n_cell and n_output will be the same size when there is no projection.
  const int n_cell = input_to_output_weights->dims->data[0];
  const int n_output = recurrent_to_output_weights->dims->data[1];

  // Since we have already checked that weights are all there or none, we can
  // check the existence of only one to get the condition.
  const bool use_cifg = (input_to_input_weights == nullptr);
  const bool use_peephole = (cell_to_output_weights != nullptr);

  float* input_gate_scratch = nullptr;
  float* cell_scratch = nullptr;
  float* forget_gate_scratch = nullptr;
  float* output_gate_scratch = nullptr;
  if (use_cifg) {
    cell_scratch = scratch_buffer->data.f;
    forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
    output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
  } else {
    input_gate_scratch = scratch_buffer->data.f;
    cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
    forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
    output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
  }

  // Check optional tensors, the respective pointers can be null.
  const float* input_to_input_weights_ptr =
      (use_cifg) ? nullptr : input_to_input_weights->data.f;
  const float* recurrent_to_input_weights_ptr =
      (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
  const float* input_gate_bias_ptr =
      (use_cifg) ? nullptr : input_gate_bias->data.f;
  const float* cell_to_input_weights_ptr =
      (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
  const float* cell_to_forget_weights_ptr =
      (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
  const float* cell_to_output_weights_ptr =
      (use_peephole) ? cell_to_output_weights->data.f : nullptr;
  const float* projection_weights_ptr =
      (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
  const float* projection_bias_ptr =
      (projection_bias == nullptr) ? nullptr : projection_bias->data.f;

  // Required tensors, pointers are non-null.
  const float* input_to_forget_weights_ptr = input_to_forget_weights->data.f;
  const float* input_to_cell_weights_ptr = input_to_cell_weights->data.f;
  const float* input_to_output_weights_ptr = input_to_output_weights->data.f;
  const float* recurrent_to_forget_weights_ptr =
      recurrent_to_forget_weights->data.f;
  const float* recurrent_to_cell_weights_ptr =
      recurrent_to_cell_weights->data.f;
  const float* recurrent_to_output_weights_ptr =
      recurrent_to_output_weights->data.f;
  const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
  const float* cell_bias_ptr = cell_bias->data.f;
  const float* output_gate_bias_ptr = output_gate_bias->data.f;

  float* activation_state_ptr = activation_state->data.f;
  float* cell_state_ptr = cell_state->data.f;

  // Feed the sequence into the LSTM step-by-step.
  for (int t = 0; t < max_time; t++) {
    const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
    float* output_ptr_batch = output->data.f + t * n_batch * n_output;

    kernel_utils::LstmStep(
        input_ptr_batch, input_to_input_weights_ptr,
        input_to_forget_weights_ptr, input_to_cell_weights_ptr,
        input_to_output_weights_ptr, recurrent_to_input_weights_ptr,
        recurrent_to_forget_weights_ptr, recurrent_to_cell_weights_ptr,
        recurrent_to_output_weights_ptr, cell_to_input_weights_ptr,
        cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
        input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
        output_gate_bias_ptr, projection_weights_ptr, projection_bias_ptr,
        params, n_batch, n_cell, n_input, n_output, activation_state_ptr,
        cell_state_ptr, input_gate_scratch, forget_gate_scratch, cell_scratch,
        output_gate_scratch, output_ptr_batch);
  }
  return kTfLiteOk;
}

TfLiteStatus EvalHybrid(
    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
    const TfLiteTensor* input_to_forget_weights,
    const TfLiteTensor* input_to_cell_weights,
    const TfLiteTensor* input_to_output_weights,
    const TfLiteTensor* recurrent_to_input_weights,
    const TfLiteTensor* recurrent_to_forget_weights,
    const TfLiteTensor* recurrent_to_cell_weights,
    const TfLiteTensor* recurrent_to_output_weights,
    const TfLiteTensor* cell_to_input_weights,
    const TfLiteTensor* cell_to_forget_weights,
    const TfLiteTensor* cell_to_output_weights,
    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
    const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
    const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
    TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
    TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
    TfLiteTensor* activation_state_quantized,
    TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state,
    TfLiteTensor* cell_state, TfLiteTensor* output) {
  const int max_time = input->dims->data[0];
  const int n_batch = input->dims->data[1];
  const int n_input = input->dims->data[2];
  // n_cell and n_output will be the same size when there is no projection.
  const int n_cell = input_to_output_weights->dims->data[0];
  const int n_output = recurrent_to_output_weights->dims->data[1];

  // Since we have already checked that weights are all there or none, we can
  // check the existence of only one to get the condition.
  const bool use_cifg = (input_to_input_weights == nullptr);
  const bool use_peephole = (cell_to_output_weights != nullptr);

  float* input_gate_scratch = nullptr;
  float* cell_scratch = nullptr;
  float* forget_gate_scratch = nullptr;
  float* output_gate_scratch = nullptr;
  if (use_cifg) {
    cell_scratch = scratch_buffer->data.f;
    forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
    output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
  } else {
    input_gate_scratch = scratch_buffer->data.f;
    cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
    forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
    output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
  }

  // Check optional tensors, the respective pointers can be null.
  int8_t* input_to_input_weights_ptr = nullptr;
  float input_to_input_weights_scale = 1.0f;
  int8_t* recurrent_to_input_weights_ptr = nullptr;
  float recurrent_to_input_weights_scale = 1.0f;
  float* input_gate_bias_ptr = nullptr;
  if (!use_cifg) {
    input_to_input_weights_ptr =
        reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
    recurrent_to_input_weights_ptr =
        reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
    input_gate_bias_ptr = input_gate_bias->data.f;
    input_to_input_weights_scale = input_to_input_weights->params.scale;
    recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
  }

  int8_t* cell_to_input_weights_ptr = nullptr;
  int8_t* cell_to_forget_weights_ptr = nullptr;
  int8_t* cell_to_output_weights_ptr = nullptr;
  float cell_to_input_weights_scale = 1.0f;
  float cell_to_forget_weights_scale = 1.0f;
  float cell_to_output_weights_scale = 1.0f;
  if (use_peephole) {
    if (!use_cifg) {
      cell_to_input_weights_ptr =
          reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
      cell_to_input_weights_scale = cell_to_input_weights->params.scale;
    }
    cell_to_forget_weights_ptr =
        reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
    cell_to_output_weights_ptr =
        reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
    cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
    cell_to_output_weights_scale = cell_to_output_weights->params.scale;
  }

  const int8_t* projection_weights_ptr =
      (projection_weights == nullptr)
          ? nullptr
          : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
  float projection_weights_scale =
      (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
  const float* projection_bias_ptr =
      (projection_bias == nullptr) ? nullptr : projection_bias->data.f;

  // Required tensors, pointers are non-null.
  const int8_t* input_to_forget_weights_ptr =
      reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
  const float input_to_forget_weights_scale =
      input_to_forget_weights->params.scale;
  const int8_t* input_to_cell_weights_ptr =
      reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
  const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
  const int8_t* input_to_output_weights_ptr =
      reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
  const float input_to_output_weights_scale =
      input_to_output_weights->params.scale;
  const int8_t* recurrent_to_forget_weights_ptr =
      reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
  const float recurrent_to_forget_weights_scale =
      recurrent_to_forget_weights->params.scale;
  const int8_t* recurrent_to_cell_weights_ptr =
      reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
  const float recurrent_to_cell_weights_scale =
      recurrent_to_cell_weights->params.scale;
  const int8_t* recurrent_to_output_weights_ptr =
      reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
  const float recurrent_to_output_weights_scale =
      recurrent_to_output_weights->params.scale;
  const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
  const float* cell_bias_ptr = cell_bias->data.f;
  const float* output_gate_bias_ptr = output_gate_bias->data.f;

  float* activation_state_ptr = activation_state->data.f;
  float* cell_state_ptr = cell_state->data.f;

  // Temporary storage for quantized values and scaling factors.
  int8_t* quantized_input_ptr =
      reinterpret_cast<int8_t*>(input_quantized->data.uint8);
  int8_t* quantized_activation_state_ptr =
      reinterpret_cast<int8_t*>(activation_state_quantized->data.uint8);
  int8_t* quantized_cell_state_ptr =
      reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
  float* scaling_factors_ptr = scaling_factors->data.f;
  float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
  float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;

  // Feed the sequence into the LSTM step-by-step.
  for (int t = 0; t < max_time; t++) {
    const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
    float* output_ptr_batch = output->data.f + t * n_batch * n_output;

    kernel_utils::LstmStep(
        input_ptr_batch, input_to_input_weights_ptr,
        input_to_input_weights_scale, input_to_forget_weights_ptr,
        input_to_forget_weights_scale, input_to_cell_weights_ptr,
        input_to_cell_weights_scale, input_to_output_weights_ptr,
        input_to_output_weights_scale, recurrent_to_input_weights_ptr,
        recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
        recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
        recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
        recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
        cell_to_input_weights_scale, cell_to_forget_weights_ptr,
        cell_to_forget_weights_scale, cell_to_output_weights_ptr,
        cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr,
        cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
        projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell,
        n_input, n_output, input_gate_scratch, forget_gate_scratch,
        cell_scratch, output_gate_scratch, scaling_factors_ptr,
        prod_scaling_factors_ptr, recovered_cell_weights_ptr,
        quantized_input_ptr, quantized_activation_state_ptr,
        quantized_cell_state_ptr, activation_state_ptr, cell_state_ptr,
        output_ptr_batch);
  }
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
@ -750,15 +484,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {

  switch (input_to_output_weights->type) {
    case kTfLiteFloat32: {
      return EvalFloat(input, input_to_input_weights, input_to_forget_weights,
                       input_to_cell_weights, input_to_output_weights,
                       recurrent_to_input_weights, recurrent_to_forget_weights,
                       recurrent_to_cell_weights, recurrent_to_output_weights,
                       cell_to_input_weights, cell_to_forget_weights,
                       cell_to_output_weights, input_gate_bias,
                       forget_gate_bias, cell_bias, output_gate_bias,
                       projection_weights, projection_bias, params,
                       scratch_buffer, activation_state, cell_state, output);
      return lstm_eval::EvalFloat(
          input, input_to_input_weights, input_to_forget_weights,
          input_to_cell_weights, input_to_output_weights,
          recurrent_to_input_weights, recurrent_to_forget_weights,
          recurrent_to_cell_weights, recurrent_to_output_weights,
          cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
          /*aux_input=*/nullptr,
          /*aux_input_to_input_weights=*/nullptr,
          /*aux_input_to_forget_weights=*/nullptr,
          /*aux_input_to_cell_weights=*/nullptr,
          /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
          forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
          projection_bias, params, /*forward_sequence=*/true,
          /*output_offset=*/0, scratch_buffer, activation_state, cell_state,
          output);
    }
    case kTfLiteUInt8: {
      TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
@ -771,17 +511,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
          GetTemporary(context, node, /*index=*/5);
      TfLiteTensor* recovered_cell_weights =
          GetTemporary(context, node, /*index=*/6);
      return EvalHybrid(
      return lstm_eval::EvalHybrid(
          input, input_to_input_weights, input_to_forget_weights,
          input_to_cell_weights, input_to_output_weights,
          recurrent_to_input_weights, recurrent_to_forget_weights,
          recurrent_to_cell_weights, recurrent_to_output_weights,
          cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
          input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias,
          projection_weights, projection_bias, params, scratch_buffer,
          scaling_factors, prod_scaling_factors, recovered_cell_weights,
          input_quantized, activation_state_quantized, cell_state_quantized,
          activation_state, cell_state, output);
          /*aux_input=*/nullptr,
          /*aux_input_to_input_weights=*/nullptr,
          /*aux_input_to_forget_weights=*/nullptr,
          /*aux_input_to_cell_weights=*/nullptr,
          /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
          forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
          projection_bias, params, /*forward_sequence=*/true,
          /*output_offset=*/0, scratch_buffer, scaling_factors,
          prod_scaling_factors, recovered_cell_weights, input_quantized,
          /*aux_input_quantized=*/nullptr, activation_state_quantized,
          cell_state_quantized, activation_state, cell_state, output);
    }
    default:
      context->ReportError(context, "Type %d is not currently supported.",