Unidirectional sequence LSTM: Use constants for tensor indices in lstm_shared.h instead of redefining them.
PiperOrigin-RevId: 316573775 Change-Id: I2de5264e4ab3375344dfe3132b6cde7b6cb0226c
This commit is contained in:
parent
5bc13320c9
commit
6f0425cd06
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||||
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
|
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
|
||||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||||
#include "tensorflow/lite/kernels/lstm_eval.h"
|
#include "tensorflow/lite/kernels/lstm_eval.h"
|
||||||
|
#include "tensorflow/lite/kernels/lstm_shared.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -36,53 +37,6 @@ struct OpData {
|
||||||
bool compute_row_sums = false;
|
bool compute_row_sums = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Input Tensors of size {max_time, n_batch, n_input}
|
|
||||||
constexpr int kInputTensor = 0;
|
|
||||||
|
|
||||||
// Input weight tensors of size: {n_cell, n_input}
|
|
||||||
constexpr int kInputToInputWeightsTensor = 1; // Optional
|
|
||||||
constexpr int kInputToForgetWeightsTensor = 2;
|
|
||||||
constexpr int kInputToCellWeightsTensor = 3;
|
|
||||||
constexpr int kInputToOutputWeightsTensor = 4;
|
|
||||||
|
|
||||||
// Recurrent weight tensors of size {n_cell, n_output}
|
|
||||||
constexpr int kRecurrentToInputWeightsTensor = 5; // Optional
|
|
||||||
constexpr int kRecurrentToForgetWeightsTensor = 6;
|
|
||||||
constexpr int kRecurrentToCellWeightsTensor = 7;
|
|
||||||
constexpr int kRecurrentToOutputWeightsTensor = 8;
|
|
||||||
|
|
||||||
// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
|
|
||||||
constexpr int kCellToInputWeightsTensor = 9; // Optional
|
|
||||||
constexpr int kCellToForgetWeightsTensor = 10; // Optional
|
|
||||||
constexpr int kCellToOutputWeightsTensor = 11; // Optional
|
|
||||||
|
|
||||||
// Gates bias tensors of size {n_cell}
|
|
||||||
constexpr int kInputGateBiasTensor = 12; // Optional
|
|
||||||
constexpr int kForgetGateBiasTensor = 13;
|
|
||||||
constexpr int kCellGateBiasTensor = 14;
|
|
||||||
constexpr int kOutputGateBiasTensor = 15;
|
|
||||||
|
|
||||||
// Projection weight tensor of size {n_output, n_cell}
|
|
||||||
constexpr int kProjectionWeightsTensor = 16; // Optional
|
|
||||||
// Projection bias tensor of size {n_output}
|
|
||||||
constexpr int kProjectionBiasTensor = 17; // Optional
|
|
||||||
|
|
||||||
// Stateful input tensors that are variables and will be modified by the Op.
|
|
||||||
// Activation state tensor of size {n_batch, n_output}
|
|
||||||
constexpr int kInputActivationStateTensor = 18;
|
|
||||||
// Cell state tensor of size {n_batch, n_cell}
|
|
||||||
constexpr int kInputCellStateTensor = 19;
|
|
||||||
|
|
||||||
// Layer norm coefficient tensors of size {n_cell}, representing a diagonal
|
|
||||||
// matrix.
|
|
||||||
constexpr int kInputLayerNormCoefficientsTensor = 20; // Optional
|
|
||||||
constexpr int kForgetLayerNormCoefficientsTensor = 21; // Optional
|
|
||||||
constexpr int kCellLayerNormCoefficientsTensor = 22; // Optional
|
|
||||||
constexpr int kOutputLayerNormCoefficientsTensor = 23; // Optional
|
|
||||||
|
|
||||||
// Output tensors.
|
|
||||||
constexpr int kOutputTensor = 0;
|
|
||||||
|
|
||||||
// Temporary tensors
|
// Temporary tensors
|
||||||
enum TemporaryTensor {
|
enum TemporaryTensor {
|
||||||
kScratchBuffer = 0,
|
kScratchBuffer = 0,
|
||||||
|
@ -122,8 +76,8 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
|
||||||
TF_LITE_ENSURE(context, params->cell_clip >= 0);
|
TF_LITE_ENSURE(context, params->cell_clip >= 0);
|
||||||
TF_LITE_ENSURE(context, params->proj_clip >= 0);
|
TF_LITE_ENSURE(context, params->proj_clip >= 0);
|
||||||
|
|
||||||
const TfLiteTensor* input_to_input_weights =
|
const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
|
||||||
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
|
context, node, lstm::full::kInputToInputWeightsTensor);
|
||||||
if (input_to_input_weights != nullptr) {
|
if (input_to_input_weights != nullptr) {
|
||||||
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
|
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
|
||||||
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
|
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
|
||||||
|
@ -131,19 +85,19 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
|
||||||
}
|
}
|
||||||
|
|
||||||
const TfLiteTensor* input_to_forget_weights =
|
const TfLiteTensor* input_to_forget_weights =
|
||||||
GetInput(context, node, kInputToForgetWeightsTensor);
|
GetInput(context, node, lstm::full::kInputToForgetWeightsTensor);
|
||||||
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
|
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
|
||||||
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
|
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
|
||||||
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
|
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
|
||||||
|
|
||||||
const TfLiteTensor* input_to_cell_weights =
|
const TfLiteTensor* input_to_cell_weights =
|
||||||
GetInput(context, node, kInputToCellWeightsTensor);
|
GetInput(context, node, lstm::full::kInputToCellWeightsTensor);
|
||||||
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
|
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
|
||||||
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
|
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
|
||||||
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
|
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
|
||||||
|
|
||||||
const TfLiteTensor* recurrent_to_input_weights =
|
const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
|
||||||
GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
|
context, node, lstm::full::kRecurrentToInputWeightsTensor);
|
||||||
if (recurrent_to_input_weights != nullptr) {
|
if (recurrent_to_input_weights != nullptr) {
|
||||||
TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
|
TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
|
||||||
TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
|
TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
|
||||||
|
@ -153,7 +107,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
|
||||||
}
|
}
|
||||||
|
|
||||||
const TfLiteTensor* recurrent_to_forget_weights =
|
const TfLiteTensor* recurrent_to_forget_weights =
|
||||||
GetInput(context, node, kRecurrentToForgetWeightsTensor);
|
GetInput(context, node, lstm::full::kRecurrentToForgetWeightsTensor);
|
||||||
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
|
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
|
||||||
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
|
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
|
||||||
n_cell);
|
n_cell);
|
||||||
|
@ -161,7 +115,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
|
||||||
n_output);
|
n_output);
|
||||||
|
|
||||||
const TfLiteTensor* recurrent_to_cell_weights =
|
const TfLiteTensor* recurrent_to_cell_weights =
|
||||||
GetInput(context, node, kRecurrentToCellWeightsTensor);
|
GetInput(context, node, lstm::full::kRecurrentToCellWeightsTensor);
|
||||||
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
|
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
|
||||||
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
|
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
|
||||||
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
|
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
|
||||||
|
@ -176,22 +130,22 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
|
||||||
(recurrent_to_input_weights == nullptr));
|
(recurrent_to_input_weights == nullptr));
|
||||||
TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);
|
TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);
|
||||||
|
|
||||||
const TfLiteTensor* cell_to_input_weights =
|
const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
|
||||||
GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
|
context, node, lstm::full::kCellToInputWeightsTensor);
|
||||||
if (cell_to_input_weights != nullptr) {
|
if (cell_to_input_weights != nullptr) {
|
||||||
TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
|
TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
|
||||||
TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
|
TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
|
||||||
}
|
}
|
||||||
|
|
||||||
const TfLiteTensor* cell_to_forget_weights =
|
const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(
|
||||||
GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
|
context, node, lstm::full::kCellToForgetWeightsTensor);
|
||||||
if (cell_to_forget_weights != nullptr) {
|
if (cell_to_forget_weights != nullptr) {
|
||||||
TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
|
TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
|
||||||
TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
|
TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
|
||||||
}
|
}
|
||||||
|
|
||||||
const TfLiteTensor* cell_to_output_weights =
|
const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(
|
||||||
GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
|
context, node, lstm::full::kCellToOutputWeightsTensor);
|
||||||
if (cell_to_output_weights != nullptr) {
|
if (cell_to_output_weights != nullptr) {
|
||||||
TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
|
TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
|
||||||
TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
|
TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
|
||||||
|
@ -210,7 +164,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
|
||||||
|
|
||||||
// Make sure the input gate bias is present only when not a CIFG-LSTM.
|
// Make sure the input gate bias is present only when not a CIFG-LSTM.
|
||||||
const TfLiteTensor* input_gate_bias =
|
const TfLiteTensor* input_gate_bias =
|
||||||
GetOptionalInputTensor(context, node, kInputGateBiasTensor);
|
GetOptionalInputTensor(context, node, lstm::full::kInputGateBiasTensor);
|
||||||
if (use_cifg) {
|
if (use_cifg) {
|
||||||
TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
|
TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
|
||||||
} else {
|
} else {
|
||||||
|
@ -219,21 +173,22 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
|
||||||
}
|
}
|
||||||
|
|
||||||
const TfLiteTensor* forget_gate_bias =
|
const TfLiteTensor* forget_gate_bias =
|
||||||
GetInput(context, node, kForgetGateBiasTensor);
|
GetInput(context, node, lstm::full::kForgetGateBiasTensor);
|
||||||
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
|
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
|
||||||
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
|
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
|
||||||
|
|
||||||
const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
|
const TfLiteTensor* cell_bias =
|
||||||
|
GetInput(context, node, lstm::full::kCellGateBiasTensor);
|
||||||
TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1);
|
TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1);
|
||||||
TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell);
|
TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell);
|
||||||
|
|
||||||
const TfLiteTensor* output_gate_bias =
|
const TfLiteTensor* output_gate_bias =
|
||||||
GetInput(context, node, kOutputGateBiasTensor);
|
GetInput(context, node, lstm::full::kOutputGateBiasTensor);
|
||||||
TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
|
TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
|
||||||
TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
|
TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
|
||||||
|
|
||||||
const TfLiteTensor* projection_weights =
|
const TfLiteTensor* projection_weights = GetOptionalInputTensor(
|
||||||
GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
|
context, node, lstm::full::kProjectionWeightsTensor);
|
||||||
if (projection_weights != nullptr) {
|
if (projection_weights != nullptr) {
|
||||||
TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
|
TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
|
||||||
TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
|
TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
|
||||||
|
@ -241,7 +196,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
|
||||||
}
|
}
|
||||||
|
|
||||||
const TfLiteTensor* projection_bias =
|
const TfLiteTensor* projection_bias =
|
||||||
GetOptionalInputTensor(context, node, kProjectionBiasTensor);
|
GetOptionalInputTensor(context, node, lstm::full::kProjectionBiasTensor);
|
||||||
if (projection_bias != nullptr) {
|
if (projection_bias != nullptr) {
|
||||||
TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
|
TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
|
||||||
TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
|
TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
|
||||||
|
@ -258,7 +213,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
|
||||||
|
|
||||||
if (is_layer_norm_lstm) {
|
if (is_layer_norm_lstm) {
|
||||||
const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
|
const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
|
||||||
context, node, kInputLayerNormCoefficientsTensor);
|
context, node, lstm::full::kInputLayerNormCoefficientsTensor);
|
||||||
if (use_cifg) {
|
if (use_cifg) {
|
||||||
TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients, nullptr);
|
TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients, nullptr);
|
||||||
} else {
|
} else {
|
||||||
|
@ -271,7 +226,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
|
||||||
}
|
}
|
||||||
|
|
||||||
const TfLiteTensor* forget_layer_norm_coefficients =
|
const TfLiteTensor* forget_layer_norm_coefficients =
|
||||||
GetInput(context, node, kForgetLayerNormCoefficientsTensor);
|
GetInput(context, node, lstm::full::kForgetLayerNormCoefficientsTensor);
|
||||||
TF_LITE_ENSURE(context, forget_layer_norm_coefficients != nullptr);
|
TF_LITE_ENSURE(context, forget_layer_norm_coefficients != nullptr);
|
||||||
TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->size, 1);
|
TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->size, 1);
|
||||||
TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->data[0],
|
TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->data[0],
|
||||||
|
@ -280,7 +235,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
|
||||||
kTfLiteFloat32);
|
kTfLiteFloat32);
|
||||||
|
|
||||||
const TfLiteTensor* cell_layer_norm_coefficients =
|
const TfLiteTensor* cell_layer_norm_coefficients =
|
||||||
GetInput(context, node, kCellLayerNormCoefficientsTensor);
|
GetInput(context, node, lstm::full::kCellLayerNormCoefficientsTensor);
|
||||||
TF_LITE_ENSURE(context, cell_layer_norm_coefficients != nullptr);
|
TF_LITE_ENSURE(context, cell_layer_norm_coefficients != nullptr);
|
||||||
TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->size, 1);
|
TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->size, 1);
|
||||||
TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->data[0],
|
TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->data[0],
|
||||||
|
@ -289,7 +244,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
|
||||||
kTfLiteFloat32);
|
kTfLiteFloat32);
|
||||||
|
|
||||||
const TfLiteTensor* output_layer_norm_coefficients =
|
const TfLiteTensor* output_layer_norm_coefficients =
|
||||||
GetInput(context, node, kOutputLayerNormCoefficientsTensor);
|
GetInput(context, node, lstm::full::kOutputLayerNormCoefficientsTensor);
|
||||||
TF_LITE_ENSURE(context, output_layer_norm_coefficients != nullptr);
|
TF_LITE_ENSURE(context, output_layer_norm_coefficients != nullptr);
|
||||||
TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->size, 1);
|
TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->size, 1);
|
||||||
TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->data[0],
|
TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->data[0],
|
||||||
|
@ -312,7 +267,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
bool is_layer_norm_lstm = false;
|
bool is_layer_norm_lstm = false;
|
||||||
if (node->inputs->size == 24) {
|
if (node->inputs->size == 24) {
|
||||||
const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
|
const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
|
||||||
context, node, kForgetLayerNormCoefficientsTensor);
|
context, node, lstm::full::kForgetLayerNormCoefficientsTensor);
|
||||||
if (forget_layer_norm_coefficients == nullptr) {
|
if (forget_layer_norm_coefficients == nullptr) {
|
||||||
is_layer_norm_lstm = false;
|
is_layer_norm_lstm = false;
|
||||||
} else {
|
} else {
|
||||||
|
@ -332,7 +287,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
|
||||||
// Inferring batch size, number of outputs and sequence length and
|
// Inferring batch size, number of outputs and sequence length and
|
||||||
// number of cells from the input tensors.
|
// number of cells from the input tensors.
|
||||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
const TfLiteTensor* input = GetInput(context, node, lstm::full::kInputTensor);
|
||||||
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
|
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
|
||||||
TF_LITE_ENSURE(context, input->dims->size > 1);
|
TF_LITE_ENSURE(context, input->dims->size > 1);
|
||||||
const auto* params =
|
const auto* params =
|
||||||
|
@ -343,13 +298,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
const int n_input = input->dims->data[2];
|
const int n_input = input->dims->data[2];
|
||||||
|
|
||||||
const TfLiteTensor* input_to_output_weights =
|
const TfLiteTensor* input_to_output_weights =
|
||||||
GetInput(context, node, kInputToOutputWeightsTensor);
|
GetInput(context, node, lstm::full::kInputToOutputWeightsTensor);
|
||||||
const int n_cell = input_to_output_weights->dims->data[0];
|
const int n_cell = input_to_output_weights->dims->data[0];
|
||||||
TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
|
TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
|
||||||
TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
|
TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
|
||||||
|
|
||||||
const TfLiteTensor* recurrent_to_output_weights =
|
const TfLiteTensor* recurrent_to_output_weights =
|
||||||
GetInput(context, node, kRecurrentToOutputWeightsTensor);
|
GetInput(context, node, lstm::full::kRecurrentToOutputWeightsTensor);
|
||||||
TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
|
TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
|
||||||
TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
|
TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
|
||||||
n_cell);
|
n_cell);
|
||||||
|
@ -361,13 +316,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
n_cell, is_layer_norm_lstm));
|
n_cell, is_layer_norm_lstm));
|
||||||
|
|
||||||
// Get the pointer to output, activation_state and cell_state buffer tensors.
|
// Get the pointer to output, activation_state and cell_state buffer tensors.
|
||||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
TfLiteTensor* output = GetOutput(context, node, lstm::full::kOutputTensor);
|
||||||
|
|
||||||
TfLiteTensor* activation_state =
|
TfLiteTensor* activation_state =
|
||||||
GetVariableInput(context, node, kInputActivationStateTensor);
|
GetVariableInput(context, node, lstm::full::kInputActivationStateTensor);
|
||||||
TF_LITE_ENSURE(context, activation_state != nullptr);
|
TF_LITE_ENSURE(context, activation_state != nullptr);
|
||||||
TfLiteTensor* cell_state =
|
TfLiteTensor* cell_state =
|
||||||
GetVariableInput(context, node, kInputCellStateTensor);
|
GetVariableInput(context, node, lstm::full::kInputCellStateTensor);
|
||||||
TF_LITE_ENSURE(context, cell_state != nullptr);
|
TF_LITE_ENSURE(context, cell_state != nullptr);
|
||||||
|
|
||||||
// Check the shape of input state tensors.
|
// Check the shape of input state tensors.
|
||||||
|
@ -395,8 +350,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
scratch_buffer->type = input->type;
|
scratch_buffer->type = input->type;
|
||||||
scratch_buffer->allocation_type = kTfLiteArenaRw;
|
scratch_buffer->allocation_type = kTfLiteArenaRw;
|
||||||
|
|
||||||
const TfLiteTensor* input_to_input_weights =
|
const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
|
||||||
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
|
context, node, lstm::full::kInputToInputWeightsTensor);
|
||||||
const bool use_cifg = (input_to_input_weights == nullptr);
|
const bool use_cifg = (input_to_input_weights == nullptr);
|
||||||
TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
|
TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
|
||||||
scratch_buffer_size->data[0] = n_batch;
|
scratch_buffer_size->data[0] = n_batch;
|
||||||
|
@ -534,8 +489,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
row_sums->type = kTfLiteInt32;
|
row_sums->type = kTfLiteInt32;
|
||||||
row_sums->allocation_type = kTfLiteArenaRwPersistent;
|
row_sums->allocation_type = kTfLiteArenaRwPersistent;
|
||||||
int row_sums_rows = use_cifg ? 6 : 8;
|
int row_sums_rows = use_cifg ? 6 : 8;
|
||||||
const TfLiteTensor* projection_weights =
|
const TfLiteTensor* projection_weights = GetOptionalInputTensor(
|
||||||
GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
|
context, node, lstm::full::kProjectionWeightsTensor);
|
||||||
if (projection_weights != nullptr) {
|
if (projection_weights != nullptr) {
|
||||||
row_sums_rows += ceil(static_cast<float>(n_output) / n_cell);
|
row_sums_rows += ceil(static_cast<float>(n_output) / n_cell);
|
||||||
}
|
}
|
||||||
|
@ -558,74 +513,79 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
|
const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||||
const bool is_layer_norm_lstm = op_data->is_layer_norm_lstm;
|
const bool is_layer_norm_lstm = op_data->is_layer_norm_lstm;
|
||||||
const bool time_major = params->time_major;
|
const bool time_major = params->time_major;
|
||||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
const TfLiteTensor* input = GetInput(context, node, lstm::full::kInputTensor);
|
||||||
|
|
||||||
const TfLiteTensor* input_to_input_weights =
|
const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
|
||||||
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
|
context, node, lstm::full::kInputToInputWeightsTensor);
|
||||||
const TfLiteTensor* input_to_forget_weights =
|
const TfLiteTensor* input_to_forget_weights =
|
||||||
GetInput(context, node, kInputToForgetWeightsTensor);
|
GetInput(context, node, lstm::full::kInputToForgetWeightsTensor);
|
||||||
const TfLiteTensor* input_to_cell_weights =
|
const TfLiteTensor* input_to_cell_weights =
|
||||||
GetInput(context, node, kInputToCellWeightsTensor);
|
GetInput(context, node, lstm::full::kInputToCellWeightsTensor);
|
||||||
const TfLiteTensor* input_to_output_weights =
|
const TfLiteTensor* input_to_output_weights =
|
||||||
GetInput(context, node, kInputToOutputWeightsTensor);
|
GetInput(context, node, lstm::full::kInputToOutputWeightsTensor);
|
||||||
|
|
||||||
const TfLiteTensor* recurrent_to_input_weights =
|
const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
|
||||||
GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
|
context, node, lstm::full::kRecurrentToInputWeightsTensor);
|
||||||
const TfLiteTensor* recurrent_to_forget_weights =
|
const TfLiteTensor* recurrent_to_forget_weights =
|
||||||
GetInput(context, node, kRecurrentToForgetWeightsTensor);
|
GetInput(context, node, lstm::full::kRecurrentToForgetWeightsTensor);
|
||||||
const TfLiteTensor* recurrent_to_cell_weights =
|
const TfLiteTensor* recurrent_to_cell_weights =
|
||||||
GetInput(context, node, kRecurrentToCellWeightsTensor);
|
GetInput(context, node, lstm::full::kRecurrentToCellWeightsTensor);
|
||||||
const TfLiteTensor* recurrent_to_output_weights =
|
const TfLiteTensor* recurrent_to_output_weights =
|
||||||
GetInput(context, node, kRecurrentToOutputWeightsTensor);
|
GetInput(context, node, lstm::full::kRecurrentToOutputWeightsTensor);
|
||||||
|
|
||||||
const TfLiteTensor* cell_to_input_weights =
|
const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
|
||||||
GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
|
context, node, lstm::full::kCellToInputWeightsTensor);
|
||||||
const TfLiteTensor* cell_to_forget_weights =
|
const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(
|
||||||
GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
|
context, node, lstm::full::kCellToForgetWeightsTensor);
|
||||||
const TfLiteTensor* cell_to_output_weights =
|
const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(
|
||||||
GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
|
context, node, lstm::full::kCellToOutputWeightsTensor);
|
||||||
|
|
||||||
const TfLiteTensor* input_gate_bias =
|
const TfLiteTensor* input_gate_bias =
|
||||||
GetOptionalInputTensor(context, node, kInputGateBiasTensor);
|
GetOptionalInputTensor(context, node, lstm::full::kInputGateBiasTensor);
|
||||||
const TfLiteTensor* forget_gate_bias =
|
const TfLiteTensor* forget_gate_bias =
|
||||||
GetInput(context, node, kForgetGateBiasTensor);
|
GetInput(context, node, lstm::full::kForgetGateBiasTensor);
|
||||||
const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
|
const TfLiteTensor* cell_bias =
|
||||||
|
GetInput(context, node, lstm::full::kCellGateBiasTensor);
|
||||||
const TfLiteTensor* output_gate_bias =
|
const TfLiteTensor* output_gate_bias =
|
||||||
GetInput(context, node, kOutputGateBiasTensor);
|
GetInput(context, node, lstm::full::kOutputGateBiasTensor);
|
||||||
|
|
||||||
const TfLiteTensor* projection_weights =
|
const TfLiteTensor* projection_weights = GetOptionalInputTensor(
|
||||||
GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
|
context, node, lstm::full::kProjectionWeightsTensor);
|
||||||
const TfLiteTensor* projection_bias =
|
const TfLiteTensor* projection_bias =
|
||||||
GetOptionalInputTensor(context, node, kProjectionBiasTensor);
|
GetOptionalInputTensor(context, node, lstm::full::kProjectionBiasTensor);
|
||||||
|
|
||||||
// Index the scratch buffers pointers to the global scratch buffer.
|
// Index the scratch buffers pointers to the global scratch buffer.
|
||||||
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
|
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
|
||||||
|
|
||||||
TfLiteTensor* activation_state =
|
TfLiteTensor* activation_state =
|
||||||
GetVariableInput(context, node, kInputActivationStateTensor);
|
GetVariableInput(context, node, lstm::full::kInputActivationStateTensor);
|
||||||
TF_LITE_ENSURE(context, activation_state != nullptr);
|
TF_LITE_ENSURE(context, activation_state != nullptr);
|
||||||
TfLiteTensor* cell_state =
|
TfLiteTensor* cell_state =
|
||||||
GetVariableInput(context, node, kInputCellStateTensor);
|
GetVariableInput(context, node, lstm::full::kInputCellStateTensor);
|
||||||
TF_LITE_ENSURE(context, cell_state != nullptr);
|
TF_LITE_ENSURE(context, cell_state != nullptr);
|
||||||
|
|
||||||
const TfLiteTensor* input_layer_norm_coefficients =
|
const TfLiteTensor* input_layer_norm_coefficients =
|
||||||
is_layer_norm_lstm ? GetOptionalInputTensor(
|
is_layer_norm_lstm
|
||||||
context, node, kInputLayerNormCoefficientsTensor)
|
? GetOptionalInputTensor(
|
||||||
: nullptr;
|
context, node, lstm::full::kInputLayerNormCoefficientsTensor)
|
||||||
|
: nullptr;
|
||||||
const TfLiteTensor* forget_layer_norm_coefficients =
|
const TfLiteTensor* forget_layer_norm_coefficients =
|
||||||
is_layer_norm_lstm
|
is_layer_norm_lstm
|
||||||
? GetInput(context, node, kForgetLayerNormCoefficientsTensor)
|
? GetInput(context, node,
|
||||||
|
lstm::full::kForgetLayerNormCoefficientsTensor)
|
||||||
: nullptr;
|
: nullptr;
|
||||||
const TfLiteTensor* cell_layer_norm_coefficients =
|
const TfLiteTensor* cell_layer_norm_coefficients =
|
||||||
is_layer_norm_lstm
|
is_layer_norm_lstm
|
||||||
? GetInput(context, node, kCellLayerNormCoefficientsTensor)
|
? GetInput(context, node,
|
||||||
|
lstm::full::kCellLayerNormCoefficientsTensor)
|
||||||
: nullptr;
|
: nullptr;
|
||||||
const TfLiteTensor* output_layer_norm_coefficients =
|
const TfLiteTensor* output_layer_norm_coefficients =
|
||||||
is_layer_norm_lstm
|
is_layer_norm_lstm
|
||||||
? GetInput(context, node, kOutputLayerNormCoefficientsTensor)
|
? GetInput(context, node,
|
||||||
|
lstm::full::kOutputLayerNormCoefficientsTensor)
|
||||||
: nullptr;
|
: nullptr;
|
||||||
|
|
||||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
TfLiteTensor* output = GetOutput(context, node, lstm::full::kOutputTensor);
|
||||||
|
|
||||||
// Copy out the LSTM specific params so they can be passed in the function.
|
// Copy out the LSTM specific params so they can be passed in the function.
|
||||||
TfLiteLSTMParams lstm_params;
|
TfLiteLSTMParams lstm_params;
|
||||||
|
|
Loading…
Reference in New Issue