Unidirectional sequence LSTM: Use constants for tensor indices in lstm_shared.h instead of redefining them.

PiperOrigin-RevId: 316573775
Change-Id: I2de5264e4ab3375344dfe3132b6cde7b6cb0226c
This commit is contained in:
Robert David 2020-06-15 16:58:43 -07:00 committed by TensorFlower Gardener
parent 5bc13320c9
commit 6f0425cd06
1 changed file with 76 additions and 116 deletions

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/lstm_eval.h"
#include "tensorflow/lite/kernels/lstm_shared.h"
namespace tflite {
namespace ops {
@ -36,53 +37,6 @@ struct OpData {
bool compute_row_sums = false;
};
// Input Tensors of size {max_time, n_batch, n_input}
constexpr int kInputTensor = 0;
// Input weight tensors of size: {n_cell, n_input}
constexpr int kInputToInputWeightsTensor = 1; // Optional
constexpr int kInputToForgetWeightsTensor = 2;
constexpr int kInputToCellWeightsTensor = 3;
constexpr int kInputToOutputWeightsTensor = 4;
// Recurrent weight tensors of size {n_cell, n_output}
constexpr int kRecurrentToInputWeightsTensor = 5; // Optional
constexpr int kRecurrentToForgetWeightsTensor = 6;
constexpr int kRecurrentToCellWeightsTensor = 7;
constexpr int kRecurrentToOutputWeightsTensor = 8;
// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
constexpr int kCellToInputWeightsTensor = 9; // Optional
constexpr int kCellToForgetWeightsTensor = 10; // Optional
constexpr int kCellToOutputWeightsTensor = 11; // Optional
// Gates bias tensors of size {n_cell}
constexpr int kInputGateBiasTensor = 12; // Optional
constexpr int kForgetGateBiasTensor = 13;
constexpr int kCellGateBiasTensor = 14;
constexpr int kOutputGateBiasTensor = 15;
// Projection weight tensor of size {n_output, n_cell}
constexpr int kProjectionWeightsTensor = 16; // Optional
// Projection bias tensor of size {n_output}
constexpr int kProjectionBiasTensor = 17; // Optional
// Stateful input tensors that are variables and will be modified by the Op.
// Activation state tensor of size {n_batch, n_output}
constexpr int kInputActivationStateTensor = 18;
// Cell state tensor of size {n_batch, n_cell}
constexpr int kInputCellStateTensor = 19;
// Layer norm coefficient tensors of size {n_cell}, representing a diagonal
// matrix.
constexpr int kInputLayerNormCoefficientsTensor = 20; // Optional
constexpr int kForgetLayerNormCoefficientsTensor = 21; // Optional
constexpr int kCellLayerNormCoefficientsTensor = 22; // Optional
constexpr int kOutputLayerNormCoefficientsTensor = 23; // Optional
// Output tensors.
constexpr int kOutputTensor = 0;
// Temporary tensors
enum TemporaryTensor {
kScratchBuffer = 0,
@ -122,8 +76,8 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TF_LITE_ENSURE(context, params->cell_clip >= 0);
TF_LITE_ENSURE(context, params->proj_clip >= 0);
const TfLiteTensor* input_to_input_weights =
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
context, node, lstm::full::kInputToInputWeightsTensor);
if (input_to_input_weights != nullptr) {
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
@ -131,19 +85,19 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
}
const TfLiteTensor* input_to_forget_weights =
GetInput(context, node, kInputToForgetWeightsTensor);
GetInput(context, node, lstm::full::kInputToForgetWeightsTensor);
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
const TfLiteTensor* input_to_cell_weights =
GetInput(context, node, kInputToCellWeightsTensor);
GetInput(context, node, lstm::full::kInputToCellWeightsTensor);
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
const TfLiteTensor* recurrent_to_input_weights =
GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
context, node, lstm::full::kRecurrentToInputWeightsTensor);
if (recurrent_to_input_weights != nullptr) {
TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
@ -153,7 +107,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
}
const TfLiteTensor* recurrent_to_forget_weights =
GetInput(context, node, kRecurrentToForgetWeightsTensor);
GetInput(context, node, lstm::full::kRecurrentToForgetWeightsTensor);
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
n_cell);
@ -161,7 +115,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
n_output);
const TfLiteTensor* recurrent_to_cell_weights =
GetInput(context, node, kRecurrentToCellWeightsTensor);
GetInput(context, node, lstm::full::kRecurrentToCellWeightsTensor);
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
@ -176,22 +130,22 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
(recurrent_to_input_weights == nullptr));
TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);
const TfLiteTensor* cell_to_input_weights =
GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
context, node, lstm::full::kCellToInputWeightsTensor);
if (cell_to_input_weights != nullptr) {
TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
}
const TfLiteTensor* cell_to_forget_weights =
GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(
context, node, lstm::full::kCellToForgetWeightsTensor);
if (cell_to_forget_weights != nullptr) {
TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
}
const TfLiteTensor* cell_to_output_weights =
GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(
context, node, lstm::full::kCellToOutputWeightsTensor);
if (cell_to_output_weights != nullptr) {
TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
@ -210,7 +164,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
// Make sure the input gate bias is present only when not a CIFG-LSTM.
const TfLiteTensor* input_gate_bias =
GetOptionalInputTensor(context, node, kInputGateBiasTensor);
GetOptionalInputTensor(context, node, lstm::full::kInputGateBiasTensor);
if (use_cifg) {
TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
} else {
@ -219,21 +173,22 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
}
const TfLiteTensor* forget_gate_bias =
GetInput(context, node, kForgetGateBiasTensor);
GetInput(context, node, lstm::full::kForgetGateBiasTensor);
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
const TfLiteTensor* cell_bias =
GetInput(context, node, lstm::full::kCellGateBiasTensor);
TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell);
const TfLiteTensor* output_gate_bias =
GetInput(context, node, kOutputGateBiasTensor);
GetInput(context, node, lstm::full::kOutputGateBiasTensor);
TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
const TfLiteTensor* projection_weights =
GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
const TfLiteTensor* projection_weights = GetOptionalInputTensor(
context, node, lstm::full::kProjectionWeightsTensor);
if (projection_weights != nullptr) {
TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
@ -241,7 +196,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
}
const TfLiteTensor* projection_bias =
GetOptionalInputTensor(context, node, kProjectionBiasTensor);
GetOptionalInputTensor(context, node, lstm::full::kProjectionBiasTensor);
if (projection_bias != nullptr) {
TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
@ -258,7 +213,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
if (is_layer_norm_lstm) {
const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
context, node, kInputLayerNormCoefficientsTensor);
context, node, lstm::full::kInputLayerNormCoefficientsTensor);
if (use_cifg) {
TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients, nullptr);
} else {
@ -271,7 +226,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
}
const TfLiteTensor* forget_layer_norm_coefficients =
GetInput(context, node, kForgetLayerNormCoefficientsTensor);
GetInput(context, node, lstm::full::kForgetLayerNormCoefficientsTensor);
TF_LITE_ENSURE(context, forget_layer_norm_coefficients != nullptr);
TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->size, 1);
TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->data[0],
@ -280,7 +235,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
kTfLiteFloat32);
const TfLiteTensor* cell_layer_norm_coefficients =
GetInput(context, node, kCellLayerNormCoefficientsTensor);
GetInput(context, node, lstm::full::kCellLayerNormCoefficientsTensor);
TF_LITE_ENSURE(context, cell_layer_norm_coefficients != nullptr);
TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->data[0],
@ -289,7 +244,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
kTfLiteFloat32);
const TfLiteTensor* output_layer_norm_coefficients =
GetInput(context, node, kOutputLayerNormCoefficientsTensor);
GetInput(context, node, lstm::full::kOutputLayerNormCoefficientsTensor);
TF_LITE_ENSURE(context, output_layer_norm_coefficients != nullptr);
TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->size, 1);
TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->data[0],
@ -312,7 +267,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
bool is_layer_norm_lstm = false;
if (node->inputs->size == 24) {
const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
context, node, kForgetLayerNormCoefficientsTensor);
context, node, lstm::full::kForgetLayerNormCoefficientsTensor);
if (forget_layer_norm_coefficients == nullptr) {
is_layer_norm_lstm = false;
} else {
@ -332,7 +287,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input = GetInput(context, node, lstm::full::kInputTensor);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE(context, input->dims->size > 1);
const auto* params =
@ -343,13 +298,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const int n_input = input->dims->data[2];
const TfLiteTensor* input_to_output_weights =
GetInput(context, node, kInputToOutputWeightsTensor);
GetInput(context, node, lstm::full::kInputToOutputWeightsTensor);
const int n_cell = input_to_output_weights->dims->data[0];
TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
const TfLiteTensor* recurrent_to_output_weights =
GetInput(context, node, kRecurrentToOutputWeightsTensor);
GetInput(context, node, lstm::full::kRecurrentToOutputWeightsTensor);
TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
n_cell);
@ -361,13 +316,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
n_cell, is_layer_norm_lstm));
// Get the pointer to output, activation_state and cell_state buffer tensors.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output = GetOutput(context, node, lstm::full::kOutputTensor);
TfLiteTensor* activation_state =
GetVariableInput(context, node, kInputActivationStateTensor);
GetVariableInput(context, node, lstm::full::kInputActivationStateTensor);
TF_LITE_ENSURE(context, activation_state != nullptr);
TfLiteTensor* cell_state =
GetVariableInput(context, node, kInputCellStateTensor);
GetVariableInput(context, node, lstm::full::kInputCellStateTensor);
TF_LITE_ENSURE(context, cell_state != nullptr);
// Check the shape of input state tensors.
@ -395,8 +350,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
scratch_buffer->type = input->type;
scratch_buffer->allocation_type = kTfLiteArenaRw;
const TfLiteTensor* input_to_input_weights =
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
context, node, lstm::full::kInputToInputWeightsTensor);
const bool use_cifg = (input_to_input_weights == nullptr);
TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
scratch_buffer_size->data[0] = n_batch;
@ -534,8 +489,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
row_sums->type = kTfLiteInt32;
row_sums->allocation_type = kTfLiteArenaRwPersistent;
int row_sums_rows = use_cifg ? 6 : 8;
const TfLiteTensor* projection_weights =
GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
const TfLiteTensor* projection_weights = GetOptionalInputTensor(
context, node, lstm::full::kProjectionWeightsTensor);
if (projection_weights != nullptr) {
row_sums_rows += ceil(static_cast<float>(n_output) / n_cell);
}
@ -558,74 +513,79 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
const bool is_layer_norm_lstm = op_data->is_layer_norm_lstm;
const bool time_major = params->time_major;
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input = GetInput(context, node, lstm::full::kInputTensor);
const TfLiteTensor* input_to_input_weights =
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
context, node, lstm::full::kInputToInputWeightsTensor);
const TfLiteTensor* input_to_forget_weights =
GetInput(context, node, kInputToForgetWeightsTensor);
GetInput(context, node, lstm::full::kInputToForgetWeightsTensor);
const TfLiteTensor* input_to_cell_weights =
GetInput(context, node, kInputToCellWeightsTensor);
GetInput(context, node, lstm::full::kInputToCellWeightsTensor);
const TfLiteTensor* input_to_output_weights =
GetInput(context, node, kInputToOutputWeightsTensor);
GetInput(context, node, lstm::full::kInputToOutputWeightsTensor);
const TfLiteTensor* recurrent_to_input_weights =
GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
context, node, lstm::full::kRecurrentToInputWeightsTensor);
const TfLiteTensor* recurrent_to_forget_weights =
GetInput(context, node, kRecurrentToForgetWeightsTensor);
GetInput(context, node, lstm::full::kRecurrentToForgetWeightsTensor);
const TfLiteTensor* recurrent_to_cell_weights =
GetInput(context, node, kRecurrentToCellWeightsTensor);
GetInput(context, node, lstm::full::kRecurrentToCellWeightsTensor);
const TfLiteTensor* recurrent_to_output_weights =
GetInput(context, node, kRecurrentToOutputWeightsTensor);
GetInput(context, node, lstm::full::kRecurrentToOutputWeightsTensor);
const TfLiteTensor* cell_to_input_weights =
GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
const TfLiteTensor* cell_to_forget_weights =
GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
const TfLiteTensor* cell_to_output_weights =
GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
context, node, lstm::full::kCellToInputWeightsTensor);
const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(
context, node, lstm::full::kCellToForgetWeightsTensor);
const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(
context, node, lstm::full::kCellToOutputWeightsTensor);
const TfLiteTensor* input_gate_bias =
GetOptionalInputTensor(context, node, kInputGateBiasTensor);
GetOptionalInputTensor(context, node, lstm::full::kInputGateBiasTensor);
const TfLiteTensor* forget_gate_bias =
GetInput(context, node, kForgetGateBiasTensor);
const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
GetInput(context, node, lstm::full::kForgetGateBiasTensor);
const TfLiteTensor* cell_bias =
GetInput(context, node, lstm::full::kCellGateBiasTensor);
const TfLiteTensor* output_gate_bias =
GetInput(context, node, kOutputGateBiasTensor);
GetInput(context, node, lstm::full::kOutputGateBiasTensor);
const TfLiteTensor* projection_weights =
GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
const TfLiteTensor* projection_weights = GetOptionalInputTensor(
context, node, lstm::full::kProjectionWeightsTensor);
const TfLiteTensor* projection_bias =
GetOptionalInputTensor(context, node, kProjectionBiasTensor);
GetOptionalInputTensor(context, node, lstm::full::kProjectionBiasTensor);
// Index the scratch buffers pointers to the global scratch buffer.
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* activation_state =
GetVariableInput(context, node, kInputActivationStateTensor);
GetVariableInput(context, node, lstm::full::kInputActivationStateTensor);
TF_LITE_ENSURE(context, activation_state != nullptr);
TfLiteTensor* cell_state =
GetVariableInput(context, node, kInputCellStateTensor);
GetVariableInput(context, node, lstm::full::kInputCellStateTensor);
TF_LITE_ENSURE(context, cell_state != nullptr);
const TfLiteTensor* input_layer_norm_coefficients =
is_layer_norm_lstm ? GetOptionalInputTensor(
context, node, kInputLayerNormCoefficientsTensor)
: nullptr;
is_layer_norm_lstm
? GetOptionalInputTensor(
context, node, lstm::full::kInputLayerNormCoefficientsTensor)
: nullptr;
const TfLiteTensor* forget_layer_norm_coefficients =
is_layer_norm_lstm
? GetInput(context, node, kForgetLayerNormCoefficientsTensor)
? GetInput(context, node,
lstm::full::kForgetLayerNormCoefficientsTensor)
: nullptr;
const TfLiteTensor* cell_layer_norm_coefficients =
is_layer_norm_lstm
? GetInput(context, node, kCellLayerNormCoefficientsTensor)
? GetInput(context, node,
lstm::full::kCellLayerNormCoefficientsTensor)
: nullptr;
const TfLiteTensor* output_layer_norm_coefficients =
is_layer_norm_lstm
? GetInput(context, node, kOutputLayerNormCoefficientsTensor)
? GetInput(context, node,
lstm::full::kOutputLayerNormCoefficientsTensor)
: nullptr;
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* output = GetOutput(context, node, lstm::full::kOutputTensor);
// Copy out the LSTM specific params so they can be passed in the function.
TfLiteLSTMParams lstm_params;