Add builtin layer norm lstm.

PiperOrigin-RevId: 226205306
This commit is contained in:
Jian Li 2018-12-19 11:30:28 -08:00 committed by TensorFlower Gardener
parent c93b568c00
commit d2e68e7287
6 changed files with 1038 additions and 74 deletions

View File

@ -886,11 +886,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights,
fw_cell_to_input_weights, fw_cell_to_forget_weights,
fw_cell_to_output_weights, aux_input, fw_aux_input_to_input_weights,
fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights,
fw_aux_input_to_output_weights, fw_input_gate_bias,
fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias,
fw_projection_weights, fw_projection_bias, &lstm_params,
fw_cell_to_output_weights,
/*input_layer_norm_coefficients=*/nullptr,
/*forget_layer_norm_coefficients=*/nullptr,
/*cell_layer_norm_coefficients=*/nullptr,
/*output_layer_norm_coefficients=*/nullptr, aux_input,
fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights,
fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights,
fw_input_gate_bias, fw_forget_gate_bias, fw_cell_bias,
fw_output_gate_bias, fw_projection_weights, fw_projection_bias,
&lstm_params,
/*forward_sequence=*/true, time_major, /*output_offset=*/0,
fw_scratch_buffer, fw_activation_state, fw_cell_state, fw_output);
TF_LITE_ENSURE_OK(context, fw_pass_status);
@ -901,11 +906,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
bw_cell_to_input_weights, bw_cell_to_forget_weights,
bw_cell_to_output_weights, aux_input, bw_aux_input_to_input_weights,
bw_aux_input_to_forget_weights, bw_aux_input_to_cell_weights,
bw_aux_input_to_output_weights, bw_input_gate_bias,
bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
bw_projection_weights, bw_projection_bias, &lstm_params,
bw_cell_to_output_weights,
/*input_layer_norm_coefficients=*/nullptr,
/*forget_layer_norm_coefficients=*/nullptr,
/*cell_layer_norm_coefficients=*/nullptr,
/*output_layer_norm_coefficients=*/nullptr, aux_input,
bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights,
bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights,
bw_input_gate_bias, bw_forget_gate_bias, bw_cell_bias,
bw_output_gate_bias, bw_projection_weights, bw_projection_bias,
&lstm_params,
/*forward_sequence=*/false, time_major, bw_output_offset,
bw_scratch_buffer, bw_activation_state, bw_cell_state,
actual_bw_output);
@ -940,11 +950,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights,
fw_cell_to_input_weights, fw_cell_to_forget_weights,
fw_cell_to_output_weights, aux_input, fw_aux_input_to_input_weights,
fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights,
fw_aux_input_to_output_weights, fw_input_gate_bias,
fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias,
fw_projection_weights, fw_projection_bias, &lstm_params,
fw_cell_to_output_weights,
/*input_layer_norm_coefficients=*/nullptr,
/*forget_layer_norm_coefficients=*/nullptr,
/*cell_layer_norm_coefficients=*/nullptr,
/*output_layer_norm_coefficients=*/nullptr, aux_input,
fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights,
fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights,
fw_input_gate_bias, fw_forget_gate_bias, fw_cell_bias,
fw_output_gate_bias, fw_projection_weights, fw_projection_bias,
&lstm_params,
/*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
fw_scratch_buffer, scaling_factors, prod_scaling_factors,
recovered_cell_weights, input_quantized, aux_input_quantized,
@ -958,11 +973,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
bw_cell_to_input_weights, bw_cell_to_forget_weights,
bw_cell_to_output_weights, aux_input, bw_aux_input_to_input_weights,
bw_aux_input_to_forget_weights, bw_aux_input_to_cell_weights,
bw_aux_input_to_output_weights, bw_input_gate_bias,
bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
bw_projection_weights, bw_projection_bias, &lstm_params,
bw_cell_to_output_weights,
/*input_layer_norm_coefficients=*/nullptr,
/*forget_layer_norm_coefficients=*/nullptr,
/*cell_layer_norm_coefficients=*/nullptr,
/*output_layer_norm_coefficients=*/nullptr, aux_input,
bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights,
bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights,
bw_input_gate_bias, bw_forget_gate_bias, bw_cell_bias,
bw_output_gate_bias, bw_projection_weights, bw_projection_bias,
&lstm_params,
/*forward_sequence=*/false, /*time_major=*/true, bw_output_offset,
bw_scratch_buffer, scaling_factors, prod_scaling_factors,
recovered_cell_weights, input_quantized, aux_input_quantized,

View File

@ -38,17 +38,24 @@ namespace builtin {
namespace lstm {
struct OpData {
// Which kernel type to use. Full kernel (20 inputs) or basic kernel
// (5 inputs).
// Which kernel type to use. Full kernel (24 inputs) or basic kernel (5
// inputs).
// Please note the 20-input full kernel is deprecated and only kept
// here for backward compatibility.
TfLiteLSTMKernelType kernel_type;
// If the lstm is layer norm.
bool is_layer_norm_lstm;
// These fields are only used by full kernel.
int activation_state_tensor_index;
int cell_state_tensor_index;
int scratch_tensor_index;
};
// For full inputs kernel (20-inputs).
// For full inputs kernel (24-inputs).
// Please note the 20-input full kernel is deprecated and only kept
// here for backward compatibility.
namespace full {
// Input Tensors of size {n_batch, n_input}
@ -87,6 +94,13 @@ constexpr int kProjectionBiasTensor = 17; // Optional
constexpr int kInputActivationStateTensor = 18;
constexpr int kInputCellStateTensor = 19;
// Layer norm coefficient tensors of size {n_cell}, representing a diagonal
// matrix.
constexpr int kInputLayerNormCoefficientsTensor = 20; // Optional
constexpr int kForgetLayerNormCoefficientsTensor = 21; // Optional
constexpr int kCellLayerNormCoefficientsTensor = 22; // Optional
constexpr int kOutputLayerNormCoefficientsTensor = 23; // Optional
// Output tensors.
constexpr int kOutputTensor = 0;
@ -101,7 +115,8 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// Check that input tensor dimensions matches with each other.
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TfLiteNode* node, int n_input,
int n_output, int n_cell) {
int n_output, int n_cell,
bool is_layer_norm_lstm) {
const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
// Making sure clipping parameters have valid values.
@ -112,7 +127,8 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
const TfLiteTensor* input_to_input_weights =
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
if (input_to_input_weights != nullptr) {
const bool use_cifg = (input_to_input_weights == nullptr);
if (!use_cifg) {
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
@ -186,7 +202,6 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
}
// Making sure the peephole weights are there all or none.
const bool use_cifg = (input_to_input_weights == nullptr);
const bool peephole_weights_all_or_none =
((cell_to_input_weights != nullptr || use_cifg) &&
(cell_to_forget_weights != nullptr) &&
@ -244,6 +259,40 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
((projection_weights != nullptr) || (projection_bias == nullptr));
TF_LITE_ENSURE(context, projection_tensors_consistent == true);
if (is_layer_norm_lstm) {
const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
context, node, kInputLayerNormCoefficientsTensor);
if (use_cifg) {
TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients, nullptr);
} else {
TF_LITE_ENSURE(context, input_layer_norm_coefficients != nullptr);
TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->size, 1);
TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->data[0],
n_cell);
}
const TfLiteTensor* forget_layer_norm_coefficients =
GetInput(context, node, kForgetLayerNormCoefficientsTensor);
TF_LITE_ENSURE(context, forget_layer_norm_coefficients != nullptr);
TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->size, 1);
TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->data[0],
n_cell);
const TfLiteTensor* cell_layer_norm_coefficients =
GetInput(context, node, kCellLayerNormCoefficientsTensor);
TF_LITE_ENSURE(context, cell_layer_norm_coefficients != nullptr);
TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->size, 1);
TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->data[0],
n_cell);
const TfLiteTensor* output_layer_norm_coefficients =
GetInput(context, node, kOutputLayerNormCoefficientsTensor);
TF_LITE_ENSURE(context, output_layer_norm_coefficients != nullptr);
TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->size, 1);
TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->data[0],
n_cell);
}
return kTfLiteOk;
}
@ -254,8 +303,32 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
TF_LITE_ENSURE_EQ(context, node->inputs->size, 20);
// Logic for determining regular lstm and layer norm lstm:
// input_size, forget_gate_layer_norm_tensor (20) null? is_layer_norm?
// 20, N/A, No.
// 24, null, No.
// 24, not null, Yes.
// 20-input LSTMs are deprecated and are only kept here for backward
// compatibility.
if (node->inputs->size == 24) {
const TfLiteTensor* forget_layer_norm_coefficients =
GetInput(context, node, kForgetLayerNormCoefficientsTensor);
if (forget_layer_norm_coefficients == nullptr) {
op_data->is_layer_norm_lstm = false;
} else {
op_data->is_layer_norm_lstm = true;
}
} else if (node->inputs->size == 20) {
// This is deprecated and is only kept here for backward compatibility.
op_data->is_layer_norm_lstm = false;
} else {
context->ReportError(
context, "The LSTM Full kernel expects 20 or 24 inputs. Got %d inputs",
node->inputs->size);
return kTfLiteError;
}
const bool is_layer_norm_lstm = op_data->is_layer_norm_lstm;
op_data->activation_state_tensor_index =
node->inputs->data[kInputActivationStateTensor];
op_data->cell_state_tensor_index = node->inputs->data[kInputCellStateTensor];
@ -282,8 +355,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const int n_output = recurrent_to_output_weights->dims->data[1];
// Check that input tensor dimensions matches with each other.
TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input,
n_output, n_cell));
TF_LITE_ENSURE_OK(context,
CheckInputTensorDimensions(context, node, n_input, n_output,
n_cell, is_layer_norm_lstm));
// Get the pointer to output, activation_state and cell_state tensors.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
@ -430,6 +504,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
const bool is_layer_norm_lstm = op_data->is_layer_norm_lstm;
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
@ -458,6 +533,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* cell_to_output_weights =
GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
const TfLiteTensor* input_layer_norm_coefficients =
is_layer_norm_lstm ? GetOptionalInputTensor(
context, node, kInputLayerNormCoefficientsTensor)
: nullptr;
const TfLiteTensor* forget_layer_norm_coefficients =
is_layer_norm_lstm
? GetInput(context, node, kForgetLayerNormCoefficientsTensor)
: nullptr;
const TfLiteTensor* cell_layer_norm_coefficients =
is_layer_norm_lstm
? GetInput(context, node, kCellLayerNormCoefficientsTensor)
: nullptr;
const TfLiteTensor* output_layer_norm_coefficients =
is_layer_norm_lstm
? GetInput(context, node, kOutputLayerNormCoefficientsTensor)
: nullptr;
const TfLiteTensor* input_gate_bias =
GetOptionalInputTensor(context, node, kInputGateBiasTensor);
const TfLiteTensor* forget_gate_bias =
@ -490,6 +582,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
recurrent_to_input_weights, recurrent_to_forget_weights,
recurrent_to_cell_weights, recurrent_to_output_weights,
cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
input_layer_norm_coefficients, forget_layer_norm_coefficients,
cell_layer_norm_coefficients, output_layer_norm_coefficients,
/*aux_input=*/nullptr,
/*aux_input_to_input_weights=*/nullptr,
/*aux_input_to_forget_weights=*/nullptr,
@ -518,6 +612,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
recurrent_to_input_weights, recurrent_to_forget_weights,
recurrent_to_cell_weights, recurrent_to_output_weights,
cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
input_layer_norm_coefficients, forget_layer_norm_coefficients,
cell_layer_norm_coefficients, output_layer_norm_coefficients,
/*aux_input=*/nullptr,
/*aux_input_to_input_weights=*/nullptr,
/*aux_input_to_forget_weights=*/nullptr,

View File

@ -27,6 +27,10 @@ namespace lstm_eval {
namespace {
// Small float to avoid divergence during calculation of deviation for layer
// norm lstm.
const float kLayerNormEpsilon = 1e-8;
// Performs an LSTM batch inference step for input specified by input_ptr_batch.
// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
// biases (*_bias_ptr), and buffers (*_scratch), along with additional
@ -65,30 +69,47 @@ inline void LstmStepWithAuxInput(
const float* recurrent_to_output_weights_ptr,
const float* cell_to_input_weights_ptr,
const float* cell_to_forget_weights_ptr,
const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
const float* output_gate_bias_ptr, const float* projection_weights_ptr,
const float* projection_bias_ptr, const TfLiteLSTMParams* params,
int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
int output_batch_leading_dim, float* output_state_ptr,
float* cell_state_ptr, float* input_gate_scratch,
const float* cell_to_output_weights_ptr,
const float* input_layer_norm_coefficients_ptr,
const float* forget_layer_norm_coefficients_ptr,
const float* cell_layer_norm_coefficients_ptr,
const float* output_layer_norm_coefficients_ptr,
const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
const float* cell_bias_ptr, const float* output_gate_bias_ptr,
const float* projection_weights_ptr, const float* projection_bias_ptr,
const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
int n_aux_input, int n_output, int output_batch_leading_dim,
float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
float* output_ptr_batch) {
// Since we have already checked that weights are all there or none, we can
// check the existence of only one to get the condition.
const bool use_cifg = (input_to_input_weights_ptr == nullptr);
const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
// Initialize scratch buffers with bias.
if (!use_cifg) {
tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch,
input_gate_scratch);
const bool is_layer_norm_lstm =
(forget_layer_norm_coefficients_ptr != nullptr);
// Initialize scratch buffers with bias for regular lstm or initialize with
// zero for layer norm lstm.
if (is_layer_norm_lstm) {
if (!use_cifg) {
tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch);
}
tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch);
tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch);
tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch);
} else {
if (!use_cifg) {
tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell,
n_batch, input_gate_scratch);
}
tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
forget_gate_scratch);
tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
cell_scratch);
tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
output_gate_scratch);
}
tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
forget_gate_scratch);
tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
cell_scratch);
tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
output_gate_scratch);
// For each batch and cell: compute input_weight * input.
if (!use_cifg) {
@ -152,6 +173,16 @@ inline void LstmStepWithAuxInput(
cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
input_gate_scratch);
}
if (is_layer_norm_lstm) {
tensor_utils::MeanStddevNormalization(input_gate_scratch,
input_gate_scratch, n_cell, n_batch,
kLayerNormEpsilon);
tensor_utils::VectorBatchVectorCwiseProduct(
input_layer_norm_coefficients_ptr, n_cell, input_gate_scratch,
n_batch, input_gate_scratch);
tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
input_gate_scratch);
}
tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
input_gate_scratch);
}
@ -162,12 +193,31 @@ inline void LstmStepWithAuxInput(
cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
forget_gate_scratch);
}
if (is_layer_norm_lstm) {
tensor_utils::MeanStddevNormalization(forget_gate_scratch,
forget_gate_scratch, n_cell, n_batch,
kLayerNormEpsilon);
tensor_utils::VectorBatchVectorCwiseProduct(
forget_layer_norm_coefficients_ptr, n_cell, forget_gate_scratch,
n_batch, forget_gate_scratch);
tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
forget_gate_scratch);
}
tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
forget_gate_scratch);
// For each batch and cell: update the cell.
tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
n_batch * n_cell, cell_state_ptr);
if (is_layer_norm_lstm) {
tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
n_batch, kLayerNormEpsilon);
tensor_utils::VectorBatchVectorCwiseProduct(
cell_layer_norm_coefficients_ptr, n_cell, cell_scratch, n_batch,
cell_scratch);
tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
cell_scratch);
}
tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
params->activation, cell_scratch);
if (use_cifg) {
@ -190,6 +240,16 @@ inline void LstmStepWithAuxInput(
cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
output_gate_scratch);
}
if (is_layer_norm_lstm) {
tensor_utils::MeanStddevNormalization(output_gate_scratch,
output_gate_scratch, n_cell, n_batch,
kLayerNormEpsilon);
tensor_utils::VectorBatchVectorCwiseProduct(
output_layer_norm_coefficients_ptr, n_cell, output_gate_scratch,
n_batch, output_gate_scratch);
tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
output_gate_scratch);
}
tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
output_gate_scratch);
tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
@ -344,33 +404,50 @@ inline void LstmStepWithAuxInput(
const int8_t* cell_to_forget_weights_ptr,
float cell_to_forget_weights_scale,
const int8_t* cell_to_output_weights_ptr,
float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
float projection_weights_scale, const float* projection_bias_ptr,
const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
int n_aux_input, int n_output, int output_batch_leading_dim,
float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch,
float* output_gate_scratch, float* scaling_factors,
float* product_scaling_factors, float* recovered_cell_weights,
int8_t* quantized_input_ptr_batch, int8_t* quantized_aux_input_ptr_batch,
int8_t* quantized_output_state_ptr, int8_t* quantized_cell_state_ptr,
float* output_state_ptr, float* cell_state_ptr, float* output_ptr_batch) {
float cell_to_output_weights_scale,
const float* input_layer_norm_coefficients_ptr,
const float* forget_layer_norm_coefficients_ptr,
const float* cell_layer_norm_coefficients_ptr,
const float* output_layer_norm_coefficients_ptr,
const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
const float* cell_bias_ptr, const float* output_gate_bias_ptr,
const int8_t* projection_weights_ptr, float projection_weights_scale,
const float* projection_bias_ptr, const TfLiteLSTMParams* params,
int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
int output_batch_leading_dim, float* input_gate_scratch,
float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
float* scaling_factors, float* product_scaling_factors,
float* recovered_cell_weights, int8_t* quantized_input_ptr_batch,
int8_t* quantized_aux_input_ptr_batch, int8_t* quantized_output_state_ptr,
int8_t* quantized_cell_state_ptr, float* output_state_ptr,
float* cell_state_ptr, float* output_ptr_batch) {
// Since we have already checked that weights are all there or none, we
// can check the existence of only one to get the condition.
const bool use_cifg = (input_to_input_weights_ptr == nullptr);
const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
const bool is_layer_norm_lstm =
(forget_layer_norm_coefficients_ptr != nullptr);
// Initialize scratch buffers with bias.
if (!use_cifg) {
tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch,
input_gate_scratch);
if (is_layer_norm_lstm) {
if (!use_cifg) {
tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch);
}
tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch);
tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch);
tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch);
} else {
if (!use_cifg) {
tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell,
n_batch, input_gate_scratch);
}
tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
forget_gate_scratch);
tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
cell_scratch);
tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
output_gate_scratch);
}
tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
forget_gate_scratch);
tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
cell_scratch);
tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
output_gate_scratch);
if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
// Save quantization and matmul computation for all zero input.
@ -535,6 +612,16 @@ inline void LstmStepWithAuxInput(
recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
input_gate_scratch);
}
if (is_layer_norm_lstm) {
tensor_utils::MeanStddevNormalization(input_gate_scratch,
input_gate_scratch, n_cell, n_batch,
kLayerNormEpsilon);
tensor_utils::VectorBatchVectorCwiseProduct(
input_layer_norm_coefficients_ptr, n_cell, input_gate_scratch,
n_batch, input_gate_scratch);
tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
input_gate_scratch);
}
tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
input_gate_scratch);
}
@ -548,12 +635,31 @@ inline void LstmStepWithAuxInput(
recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
forget_gate_scratch);
}
if (is_layer_norm_lstm) {
tensor_utils::MeanStddevNormalization(forget_gate_scratch,
forget_gate_scratch, n_cell, n_batch,
kLayerNormEpsilon);
tensor_utils::VectorBatchVectorCwiseProduct(
forget_layer_norm_coefficients_ptr, n_cell, forget_gate_scratch,
n_batch, forget_gate_scratch);
tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
forget_gate_scratch);
}
tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
forget_gate_scratch);
// For each batch and cell: update the cell.
tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
n_batch * n_cell, cell_state_ptr);
if (is_layer_norm_lstm) {
tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
n_batch, kLayerNormEpsilon);
tensor_utils::VectorBatchVectorCwiseProduct(
cell_layer_norm_coefficients_ptr, n_cell, cell_scratch, n_batch,
cell_scratch);
tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
cell_scratch);
}
tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
params->activation, cell_scratch);
if (use_cifg) {
@ -581,6 +687,16 @@ inline void LstmStepWithAuxInput(
recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
output_gate_scratch);
}
if (is_layer_norm_lstm) {
tensor_utils::MeanStddevNormalization(output_gate_scratch,
output_gate_scratch, n_cell, n_batch,
kLayerNormEpsilon);
tensor_utils::VectorBatchVectorCwiseProduct(
output_layer_norm_coefficients_ptr, n_cell, output_gate_scratch,
n_batch, output_gate_scratch);
tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
output_gate_scratch);
}
tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
output_gate_scratch);
tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
@ -702,7 +818,12 @@ TfLiteStatus EvalFloat(
const TfLiteTensor* recurrent_to_output_weights,
const TfLiteTensor* cell_to_input_weights,
const TfLiteTensor* cell_to_forget_weights,
const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
const TfLiteTensor* cell_to_output_weights,
const TfLiteTensor* input_layer_norm_coefficients,
const TfLiteTensor* forget_layer_norm_coefficients,
const TfLiteTensor* cell_layer_norm_coefficients,
const TfLiteTensor* output_layer_norm_coefficients,
const TfLiteTensor* aux_input,
const TfLiteTensor* aux_input_to_input_weights,
const TfLiteTensor* aux_input_to_forget_weights,
const TfLiteTensor* aux_input_to_cell_weights,
@ -735,6 +856,7 @@ TfLiteStatus EvalFloat(
// check the existence of only one to get the condition.
const bool use_cifg = (input_to_input_weights == nullptr);
const bool use_peephole = (cell_to_output_weights != nullptr);
const bool is_layer_norm_lstm = (forget_layer_norm_coefficients != nullptr);
// Index the scratch buffers pointers to the global scratch buffer.
float* input_gate_scratch = nullptr;
@ -765,6 +887,15 @@ TfLiteStatus EvalFloat(
(use_peephole) ? cell_to_forget_weights->data.f : nullptr;
const float* cell_to_output_weights_ptr =
(use_peephole) ? cell_to_output_weights->data.f : nullptr;
const float* input_layer_norm_coefficients_ptr =
(is_layer_norm_lstm && !use_cifg) ? input_layer_norm_coefficients->data.f
: nullptr;
const float* forget_layer_norm_coefficients_ptr =
is_layer_norm_lstm ? forget_layer_norm_coefficients->data.f : nullptr;
const float* cell_layer_norm_coefficients_ptr =
is_layer_norm_lstm ? cell_layer_norm_coefficients->data.f : nullptr;
const float* output_layer_norm_coefficients_ptr =
is_layer_norm_lstm ? output_layer_norm_coefficients->data.f : nullptr;
const float* projection_weights_ptr =
(projection_weights == nullptr) ? nullptr : projection_weights->data.f;
const float* projection_bias_ptr =
@ -811,6 +942,8 @@ TfLiteStatus EvalFloat(
recurrent_to_cell_weights->data.f,
recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
input_layer_norm_coefficients_ptr, forget_layer_norm_coefficients_ptr,
cell_layer_norm_coefficients_ptr, output_layer_norm_coefficients_ptr,
input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
params, n_batch, n_cell, n_input, aux_input_size, n_output,
@ -855,7 +988,11 @@ TfLiteStatus EvalFloat(
recurrent_to_cell_weights->data.f,
recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
input_layer_norm_coefficients_ptr,
forget_layer_norm_coefficients_ptr,
cell_layer_norm_coefficients_ptr,
output_layer_norm_coefficients_ptr, input_gate_bias_ptr,
forget_gate_bias->data.f, cell_bias->data.f,
output_gate_bias->data.f, projection_weights_ptr,
projection_bias_ptr, params, /*n_batch=*/1, n_cell, n_input,
aux_input_size, n_output, output_batch_leading_dim,
@ -879,7 +1016,12 @@ TfLiteStatus EvalHybrid(
const TfLiteTensor* recurrent_to_output_weights,
const TfLiteTensor* cell_to_input_weights,
const TfLiteTensor* cell_to_forget_weights,
const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
const TfLiteTensor* cell_to_output_weights,
const TfLiteTensor* input_layer_norm_coefficients,
const TfLiteTensor* forget_layer_norm_coefficients,
const TfLiteTensor* cell_layer_norm_coefficients,
const TfLiteTensor* output_layer_norm_coefficients,
const TfLiteTensor* aux_input,
const TfLiteTensor* aux_input_to_input_weights,
const TfLiteTensor* aux_input_to_forget_weights,
const TfLiteTensor* aux_input_to_cell_weights,
@ -914,6 +1056,7 @@ TfLiteStatus EvalHybrid(
// check the existence of only one to get the condition.
const bool use_cifg = (input_to_input_weights == nullptr);
const bool use_peephole = (cell_to_output_weights != nullptr);
const bool is_layer_norm_lstm = (forget_layer_norm_coefficients != nullptr);
float* input_gate_scratch = nullptr;
float* cell_scratch = nullptr;
@ -966,6 +1109,16 @@ TfLiteStatus EvalHybrid(
cell_to_output_weights_scale = cell_to_output_weights->params.scale;
}
const float* input_layer_norm_coefficients_ptr =
(is_layer_norm_lstm && !use_cifg) ? input_layer_norm_coefficients->data.f
: nullptr;
const float* forget_layer_norm_coefficients_ptr =
is_layer_norm_lstm ? forget_layer_norm_coefficients->data.f : nullptr;
const float* cell_layer_norm_coefficients_ptr =
is_layer_norm_lstm ? cell_layer_norm_coefficients->data.f : nullptr;
const float* output_layer_norm_coefficients_ptr =
is_layer_norm_lstm ? output_layer_norm_coefficients->data.f : nullptr;
const int8_t* projection_weights_ptr =
(projection_weights == nullptr)
? nullptr
@ -1084,6 +1237,8 @@ TfLiteStatus EvalHybrid(
cell_to_input_weights_ptr, cell_to_input_weights_scale,
cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
cell_to_output_weights_ptr, cell_to_output_weights_scale,
input_layer_norm_coefficients_ptr, forget_layer_norm_coefficients_ptr,
cell_layer_norm_coefficients_ptr, output_layer_norm_coefficients_ptr,
input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
output_gate_bias_ptr, projection_weights_ptr,
projection_weights_scale, projection_bias_ptr, params, n_batch,
@ -1138,7 +1293,10 @@ TfLiteStatus EvalHybrid(
recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
cell_to_input_weights_scale, cell_to_forget_weights_ptr,
cell_to_forget_weights_scale, cell_to_output_weights_ptr,
cell_to_output_weights_scale, input_gate_bias_ptr,
cell_to_output_weights_scale, input_layer_norm_coefficients_ptr,
forget_layer_norm_coefficients_ptr,
cell_layer_norm_coefficients_ptr,
output_layer_norm_coefficients_ptr, input_gate_bias_ptr,
forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr,
projection_weights_ptr, projection_weights_scale,
projection_bias_ptr, params,

View File

@ -34,7 +34,12 @@ TfLiteStatus EvalFloat(
const TfLiteTensor* recurrent_to_output_weights,
const TfLiteTensor* cell_to_input_weights,
const TfLiteTensor* cell_to_forget_weights,
const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
const TfLiteTensor* cell_to_output_weights,
const TfLiteTensor* input_layer_norm_coefficients,
const TfLiteTensor* forget_layer_norm_coefficients,
const TfLiteTensor* cell_layer_norm_coefficients,
const TfLiteTensor* output_layer_norm_coefficients,
const TfLiteTensor* aux_input,
const TfLiteTensor* aux_input_to_input_weights,
const TfLiteTensor* aux_input_to_forget_weights,
const TfLiteTensor* aux_input_to_cell_weights,
@ -58,7 +63,12 @@ TfLiteStatus EvalHybrid(
const TfLiteTensor* recurrent_to_output_weights,
const TfLiteTensor* cell_to_input_weights,
const TfLiteTensor* cell_to_forget_weights,
const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
const TfLiteTensor* cell_to_output_weights,
const TfLiteTensor* input_layer_norm_coefficients,
const TfLiteTensor* forget_layer_norm_coefficients,
const TfLiteTensor* cell_layer_norm_coefficients,
const TfLiteTensor* output_layer_norm_coefficients,
const TfLiteTensor* aux_input,
const TfLiteTensor* aux_input_to_input_weights,
const TfLiteTensor* aux_input_to_forget_weights,
const TfLiteTensor* aux_input_to_cell_weights,

View File

@ -38,7 +38,8 @@ class LSTMOpModel : public SingleOpModel {
bool use_peephole, bool use_projection_weights,
bool use_projection_bias, float cell_clip, float proj_clip,
const std::vector<std::vector<int>>& input_shapes,
const TensorType& weight_type = TensorType_FLOAT32)
const TensorType& weight_type = TensorType_FLOAT32,
bool is_layer_norm = false)
: n_batch_(n_batch),
n_input_(n_input),
n_cell_(n_cell),
@ -106,6 +107,18 @@ class LSTMOpModel : public SingleOpModel {
input_cell_state_ =
AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true);
// Layer norm weights.
if (is_layer_norm) {
if (use_cifg) {
input_layer_norm_coefficients_ = AddNullInput();
} else {
input_layer_norm_coefficients_ = AddInput(TensorType_FLOAT32);
}
forget_layer_norm_coefficients_ = AddInput(TensorType_FLOAT32);
cell_layer_norm_coefficients_ = AddInput(TensorType_FLOAT32);
output_layer_norm_coefficients_ = AddInput(TensorType_FLOAT32);
}
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
@ -160,6 +173,22 @@ class LSTMOpModel : public SingleOpModel {
PopulateTensor(cell_to_output_weights_, f);
}
// Populates the input-gate layer-norm coefficient tensor (float).
void SetInputLayerNormCoefficients(std::vector<float> f) {
  PopulateTensor(input_layer_norm_coefficients_, f);
}
// Populates the forget-gate layer-norm coefficient tensor (float).
void SetForgetLayerNormCoefficients(std::vector<float> f) {
  PopulateTensor(forget_layer_norm_coefficients_, f);
}
// Populates the cell-gate layer-norm coefficient tensor (float).
void SetCellLayerNormCoefficients(std::vector<float> f) {
  PopulateTensor(cell_layer_norm_coefficients_, f);
}
// Populates the output-gate layer-norm coefficient tensor (float).
void SetOutputLayerNormCoefficients(std::vector<float> f) {
  PopulateTensor(output_layer_norm_coefficients_, f);
}
// Populates the input-gate bias tensor (float).
void SetInputGateBias(std::vector<float> f) {
  PopulateTensor(input_gate_bias_, f);
}
@ -210,6 +239,11 @@ class LSTMOpModel : public SingleOpModel {
int cell_to_forget_weights_;
int cell_to_output_weights_;
int input_layer_norm_coefficients_;
int forget_layer_norm_coefficients_;
int cell_layer_norm_coefficients_;
int output_layer_norm_coefficients_;
int input_gate_bias_;
int forget_gate_bias_;
int cell_bias_;
@ -1392,6 +1426,644 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
}
// Float LSTM model with layer normalization enabled on every gate.
// Thin wrapper over LSTMOpModel that passes is_layer_norm=true, which makes
// the base class append the four (input/forget/cell/output) layer-norm
// coefficient tensors as additional op inputs.
class LayerNormLSTMOpModel : public LSTMOpModel {
 public:
  LayerNormLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
                       bool use_cifg, bool use_peephole,
                       bool use_projection_weights, bool use_projection_bias,
                       float cell_clip, float proj_clip,
                       const std::vector<std::vector<int>>& input_shapes,
                       const TensorType& weight_type = TensorType_FLOAT32)
      : LSTMOpModel(n_batch, n_input, n_cell, n_output, use_cifg, use_peephole,
                    use_projection_weights, use_projection_bias, cell_clip,
                    proj_clip, input_shapes, weight_type,
                    /*is_layer_norm=*/true) {}
};
class HybridLayerNormLSTMOpModel : public LayerNormLSTMOpModel {
public:
HybridLayerNormLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
bool use_cifg, bool use_peephole,
bool use_projection_weights,
bool use_projection_bias, float cell_clip,
float proj_clip,
const std::vector<std::vector<int>>& input_shapes)
: LayerNormLSTMOpModel(n_batch, n_input, n_cell, n_output, use_cifg,
use_peephole, use_projection_weights,
use_projection_bias, cell_clip, proj_clip,
input_shapes, TensorType_UINT8) {}
void SetInputToInputWeights(std::vector<float> f) {
SymmetricQuantizeAndPopulate(input_to_input_weights_, f);
}
void SetInputToForgetWeights(std::vector<float> f) {
SymmetricQuantizeAndPopulate(input_to_forget_weights_, f);
}
void SetInputToCellWeights(std::vector<float> f) {
SymmetricQuantizeAndPopulate(input_to_cell_weights_, f);
}
void SetInputToOutputWeights(std::vector<float> f) {
SymmetricQuantizeAndPopulate(input_to_output_weights_, f);
}
void SetRecurrentToInputWeights(std::vector<float> f) {
SymmetricQuantizeAndPopulate(recurrent_to_input_weights_, f);
}
void SetRecurrentToForgetWeights(std::vector<float> f) {
SymmetricQuantizeAndPopulate(recurrent_to_forget_weights_, f);
}
void SetRecurrentToCellWeights(std::vector<float> f) {
SymmetricQuantizeAndPopulate(recurrent_to_cell_weights_, f);
}
void SetRecurrentToOutputWeights(std::vector<float> f) {
SymmetricQuantizeAndPopulate(recurrent_to_output_weights_, f);
}
void SetCellToInputWeights(std::vector<float> f) {
SymmetricQuantizeAndPopulate(cell_to_input_weights_, f);
}
void SetCellToForgetWeights(std::vector<float> f) {
SymmetricQuantizeAndPopulate(cell_to_forget_weights_, f);
}
void SetCellToOutputWeights(std::vector<float> f) {
SymmetricQuantizeAndPopulate(cell_to_output_weights_, f);
}
void SetInputLayerNormCoefficients(std::vector<float> f) {
PopulateTensor(input_layer_norm_coefficients_, f);
}
void SetForgetLayerNormCoefficients(std::vector<float> f) {
PopulateTensor(forget_layer_norm_coefficients_, f);
}
void SetCellLayerNormCoefficients(std::vector<float> f) {
PopulateTensor(cell_layer_norm_coefficients_, f);
}
void SetOutputLayerNormCoefficients(std::vector<float> f) {
PopulateTensor(output_layer_norm_coefficients_, f);
}
void SetProjectionWeights(std::vector<float> f) {
SymmetricQuantizeAndPopulate(projection_weights_, f);
}
};
// Shared fixture for the layer-norm LSTM black-box tests. Subclasses fill in
// whichever weight vectors their configuration uses in SetUp().
class BaseLayerNormLstmTest : public ::testing::Test {
 protected:
  // Weights of the Layer Norm LSTM model. Some are optional.
  std::vector<float> input_to_input_weights_;
  std::vector<float> input_to_cell_weights_;
  std::vector<float> input_to_forget_weights_;
  std::vector<float> input_to_output_weights_;
  std::vector<float> input_gate_bias_;
  std::vector<float> cell_gate_bias_;
  std::vector<float> forget_gate_bias_;
  std::vector<float> output_gate_bias_;
  std::vector<float> recurrent_to_input_weights_;
  std::vector<float> recurrent_to_cell_weights_;
  std::vector<float> recurrent_to_forget_weights_;
  std::vector<float> recurrent_to_output_weights_;
  std::vector<float> cell_to_input_weights_;
  std::vector<float> cell_to_forget_weights_;
  std::vector<float> cell_to_output_weights_;
  std::vector<float> projection_weights_;
  std::vector<float> input_layer_norm_coefficients_;
  std::vector<float> forget_layer_norm_coefficients_;
  std::vector<float> cell_layer_norm_coefficients_;
  std::vector<float> output_layer_norm_coefficients_;

  // Layer Norm LSTM input is stored as num_batch x num_inputs vector.
  std::vector<std::vector<float>> layer_norm_lstm_input_;

  // Feeds `input` into the model one time step at a time (all batches per
  // step), invokes the op, and checks each step's output against the golden
  // values in `output` to within `tolerance`.
  void VerifyGoldens(const std::vector<std::vector<float>>& input,
                     const std::vector<std::vector<float>>& output,
                     LayerNormLSTMOpModel* layer_norm_lstm,
                     float tolerance = 1e-5) {
    const int num_batches = input.size();
    EXPECT_GT(num_batches, 0);
    const int num_inputs = layer_norm_lstm->num_inputs();
    EXPECT_GT(num_inputs, 0);
    const int num_outputs = layer_norm_lstm->num_outputs();
    const int input_sequence_size = input[0].size() / num_inputs;
    EXPECT_GT(input_sequence_size, 0);
    for (int step = 0; step < input_sequence_size; ++step) {
      // Stage this time step's slice for every batch, then run the op once.
      for (int batch = 0; batch < num_batches; ++batch) {
        const float* step_begin = input[batch].data() + step * num_inputs;
        layer_norm_lstm->SetInput(batch * num_inputs, step_begin,
                                  step_begin + num_inputs);
      }
      layer_norm_lstm->Invoke();

      // Concatenate the golden outputs for this step across batches.
      std::vector<float> expected;
      for (int batch = 0; batch < num_batches; ++batch) {
        const float* golden_begin = output[batch].data() + step * num_outputs;
        expected.insert(expected.end(), golden_begin,
                        golden_begin + num_outputs);
      }
      EXPECT_THAT(layer_norm_lstm->GetOutput(),
                  ElementsAreArray(ArrayFloatNear(expected, tolerance)));
    }
  }
};
// Fixture data for the no-CIFG / peephole / projection / no-clipping
// layer-norm LSTM configuration (n_input=5, n_cell=4, n_output=3, n_batch=2).
class NoCifgPeepholeProjectionNoClippingLayerNormLstmTest
    : public BaseLayerNormLstmTest {
  void SetUp() override {
    // Input-to-gate weights: each is n_cell x n_input = 4 x 5.
    input_to_input_weights_ = {0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2,
                               0.3, -0.4, 0.5, -0.8, 0.7, -0.6, 0.5,
                               -0.4, -0.5, -0.4, -0.3, -0.2, -0.1};
    input_to_forget_weights_ = {-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2,
                                -0.4, 0.3, -0.8, -0.4, 0.3, -0.5, -0.4,
                                -0.6, 0.3, -0.4, -0.6, -0.5, -0.5};
    input_to_cell_weights_ = {-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2,
                              -0.3, -0.2, -0.6, 0.6, -0.1, -0.4, -0.3,
                              -0.7, 0.7, -0.9, -0.5, 0.8, 0.6};
    input_to_output_weights_ = {-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3,
                                -0.3, -0.8, -0.2, 0.6, -0.2, 0.4, -0.7,
                                -0.3, -0.5, 0.1, 0.5, -0.6, -0.4};

    // Gate biases: each is n_cell = 4.
    input_gate_bias_ = {0.03, 0.15, 0.22, 0.38};
    forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1};
    cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08};
    output_gate_bias_ = {0.05, -0.01, 0.2, 0.1};

    // Recurrent weights: each is n_cell x n_output = 4 x 3.
    recurrent_to_input_weights_ = {-0.2, -0.3, 0.4, 0.1, -0.5, 0.9,
                                   -0.2, -0.3, -0.7, 0.05, -0.2, -0.6};
    recurrent_to_cell_weights_ = {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08,
                                  -0.2, 0.3, 0.8, -0.6, -0.1, 0.2};
    recurrent_to_forget_weights_ = {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4,
                                    0.9, 0.3, -0.1, 0.2, 0.5, 0.2};
    recurrent_to_output_weights_ = {0.3, -0.1, 0.1, -0.2, -0.5, -0.7,
                                    -0.2, -0.6, -0.1, -0.4, -0.7, -0.2};

    // Peephole (cell-to-gate) weights: each is n_cell = 4.
    cell_to_input_weights_ = {0.05, 0.1, 0.25, 0.15};
    cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03};
    cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05};

    // Layer-norm coefficients: each is n_cell = 4.
    input_layer_norm_coefficients_ = {0.1, 0.2, 0.3, 0.5};
    forget_layer_norm_coefficients_ = {0.2, 0.2, 0.4, 0.3};
    cell_layer_norm_coefficients_ = {0.7, 0.2, 0.3, 0.8};
    output_layer_norm_coefficients_ = {0.6, 0.2, 0.2, 0.5};

    // Projection: n_output x n_cell = 3 x 4 (no projection bias).
    projection_weights_ = {-0.1, 0.2, 0.01, -0.2, 0.1, 0.5,
                           0.3, 0.08, 0.07, 0.2, -0.4, 0.2};

    layer_norm_lstm_input_ = {
        {// Batch0: 3 (input_sequence_size) * 5 (n_input)
         0.7, 0.8, 0.1, 0.2, 0.3,   // seq 0
         0.8, 0.1, 0.2, 0.4, 0.5,   // seq 1
         0.2, 0.7, 0.7, 0.1, 0.7},  // seq 2
        {// Batch1: 3 (input_sequence_size) * 5 (n_input)
         0.3, 0.2, 0.9, 0.8, 0.1,   // seq 0
         0.1, 0.5, 0.2, 0.4, 0.2,   // seq 1
         0.6, 0.9, 0.2, 0.5, 0.7},  // seq 2
    };
  }
};
// Black-box test of the float layer-norm LSTM kernel for the no-CIFG /
// peephole / projection / no-clipping configuration.
// Fix: the local clip constant was misspelled `ceil_clip`; the constructor
// parameter it feeds is `cell_clip`.
TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
       LayerNormLstmBlackBoxTest) {
  const int n_batch = 2;
  const int n_input = 5;
  const int n_cell = 4;
  const int n_output = 3;
  // 0.0 disables both cell-state and projection clipping.
  const float cell_clip = 0.0;
  const float proj_clip = 0.0;

  LayerNormLSTMOpModel layer_norm_lstm(
      n_batch, n_input, n_cell, n_output,
      /*use_cifg=*/false, /*use_peephole=*/true,
      /*use_projection_weights=*/true,
      /*use_projection_bias=*/false, cell_clip, proj_clip,
      {
          {n_batch, n_input},  // input tensor

          {n_cell, n_input},  // input_to_input_weight tensor
          {n_cell, n_input},  // input_to_forget_weight tensor
          {n_cell, n_input},  // input_to_cell_weight tensor
          {n_cell, n_input},  // input_to_output_weight tensor

          {n_cell, n_output},  // recurrent_to_input_weight tensor
          {n_cell, n_output},  // recurrent_to_forget_weight tensor
          {n_cell, n_output},  // recurrent_to_cell_weight tensor
          {n_cell, n_output},  // recurrent_to_output_weight tensor

          {n_cell},  // cell_to_input_weight tensor
          {n_cell},  // cell_to_forget_weight tensor
          {n_cell},  // cell_to_output_weight tensor

          {n_cell},  // input_gate_bias tensor
          {n_cell},  // forget_gate_bias tensor
          {n_cell},  // cell_bias tensor
          {n_cell},  // output_gate_bias tensor

          {n_output, n_cell},  // projection_weight tensor
          {0},                 // projection_bias tensor

          {n_batch, n_output},  // activation_state tensor
          {n_batch, n_cell},    // cell_state tensor

          {n_cell},  // input_layer_norm_coefficient tensor
          {n_cell},  // forget_layer_norm_coefficient tensor
          {n_cell},  // cell_layer_norm_coefficient tensor
          {n_cell},  // output_layer_norm_coefficient tensor
      });

  layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_);
  layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
  layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
  layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);

  layer_norm_lstm.SetInputGateBias(input_gate_bias_);
  layer_norm_lstm.SetCellBias(cell_gate_bias_);
  layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
  layer_norm_lstm.SetOutputGateBias(output_gate_bias_);

  layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
  layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);

  layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
  layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
  layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);

  layer_norm_lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients_);
  layer_norm_lstm.SetForgetLayerNormCoefficients(
      forget_layer_norm_coefficients_);
  layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
  layer_norm_lstm.SetOutputLayerNormCoefficients(
      output_layer_norm_coefficients_);

  layer_norm_lstm.SetProjectionWeights(projection_weights_);

  // Verify the final output.
  const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
      {
          // Batch0: 3 (input_sequence_size) * 3 (n_output)
          0.0244077, 0.128027, -0.00170918,    // seq 0
          0.0137642, 0.140751, 0.0395835,      // seq 1
          -0.00459231, 0.155278, 0.0837377,    // seq 2
      },
      {
          // Batch1: 3 (input_sequence_size) * 3 (n_output)
          -0.00692428, 0.0848741, 0.063445,    // seq 0
          -0.00403912, 0.139963, 0.072681,     // seq 1
          0.00752706, 0.161903, 0.0561371,     // seq 2
      }};

  VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
                &layer_norm_lstm);
}
// Black-box test of the hybrid (uint8-weight) layer-norm LSTM kernel for the
// no-CIFG / peephole / projection / no-clipping configuration. Golden values
// differ slightly from the float test due to weight quantization.
// Fix: the local clip constant was misspelled `ceil_clip`; the constructor
// parameter it feeds is `cell_clip`.
TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
       HybridLayerNormLstmBlackBoxTest) {
  const int n_batch = 2;
  const int n_input = 5;
  const int n_cell = 4;
  const int n_output = 3;
  // 0.0 disables both cell-state and projection clipping.
  const float cell_clip = 0.0;
  const float proj_clip = 0.0;

  HybridLayerNormLSTMOpModel layer_norm_lstm(
      n_batch, n_input, n_cell, n_output,
      /*use_cifg=*/false, /*use_peephole=*/true,
      /*use_projection_weights=*/true,
      /*use_projection_bias=*/false, cell_clip, proj_clip,
      {
          {n_batch, n_input},  // input tensor

          {n_cell, n_input},  // input_to_input_weight tensor
          {n_cell, n_input},  // input_to_forget_weight tensor
          {n_cell, n_input},  // input_to_cell_weight tensor
          {n_cell, n_input},  // input_to_output_weight tensor

          {n_cell, n_output},  // recurrent_to_input_weight tensor
          {n_cell, n_output},  // recurrent_to_forget_weight tensor
          {n_cell, n_output},  // recurrent_to_cell_weight tensor
          {n_cell, n_output},  // recurrent_to_output_weight tensor

          {n_cell},  // cell_to_input_weight tensor
          {n_cell},  // cell_to_forget_weight tensor
          {n_cell},  // cell_to_output_weight tensor

          {n_cell},  // input_gate_bias tensor
          {n_cell},  // forget_gate_bias tensor
          {n_cell},  // cell_bias tensor
          {n_cell},  // output_gate_bias tensor

          {n_output, n_cell},  // projection_weight tensor
          {0},                 // projection_bias tensor

          {n_batch, n_output},  // activation_state tensor
          {n_batch, n_cell},    // cell_state tensor

          {n_cell},  // input_layer_norm_coefficient tensor
          {n_cell},  // forget_layer_norm_coefficient tensor
          {n_cell},  // cell_layer_norm_coefficient tensor
          {n_cell},  // output_layer_norm_coefficient tensor
      });

  layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_);
  layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
  layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
  layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);

  layer_norm_lstm.SetInputGateBias(input_gate_bias_);
  layer_norm_lstm.SetCellBias(cell_gate_bias_);
  layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
  layer_norm_lstm.SetOutputGateBias(output_gate_bias_);

  layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
  layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);

  layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
  layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
  layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);

  layer_norm_lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients_);
  layer_norm_lstm.SetForgetLayerNormCoefficients(
      forget_layer_norm_coefficients_);
  layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
  layer_norm_lstm.SetOutputLayerNormCoefficients(
      output_layer_norm_coefficients_);

  layer_norm_lstm.SetProjectionWeights(projection_weights_);

  const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
      {
          // Batch0: 3 (input_sequence_size) * 3 (n_output)
          0.0244576, 0.127847, -0.00181765,    // seq 0
          0.0137518, 0.140892, 0.0402234,      // seq 1
          -0.0048839, 0.155096, 0.0840309,     // seq 2
      },
      {
          // Batch1: 3 (input_sequence_size) * 3 (n_output)
          -0.00728636, 0.0843957, 0.0634786,   // seq 0
          -0.00448382, 0.139278, 0.0737372,    // seq 1
          0.00734616, 0.161793, 0.0560238,     // seq 2
      }};

  VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
                &layer_norm_lstm);
}
// Fixture data for the CIFG / peephole / projection / no-clipping layer-norm
// LSTM configuration (n_input=5, n_cell=4, n_output=3, n_batch=2). With CIFG
// the input gate is coupled to the forget gate, so all input-gate tensors
// (weights, bias, peephole, layer-norm coefficients) are intentionally absent.
class CifgPeepholeProjectionNoClippingLayerNormLstmTest
    : public BaseLayerNormLstmTest {
  void SetUp() override {
    // Input-to-gate weights: each is n_cell x n_input = 4 x 5.
    input_to_forget_weights_ = {-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2,
                                -0.4, 0.3, -0.8, -0.4, 0.3, -0.5, -0.4,
                                -0.6, 0.3, -0.4, -0.6, -0.5, -0.5};
    input_to_cell_weights_ = {-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2,
                              -0.3, -0.2, -0.6, 0.6, -0.1, -0.4, -0.3,
                              -0.7, 0.7, -0.9, -0.5, 0.8, 0.6};
    input_to_output_weights_ = {-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3,
                                -0.3, -0.8, -0.2, 0.6, -0.2, 0.4, -0.7,
                                -0.3, -0.5, 0.1, 0.5, -0.6, -0.4};

    // Gate biases: each is n_cell = 4.
    forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1};
    cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08};
    output_gate_bias_ = {0.05, -0.01, 0.2, 0.1};

    // Recurrent weights: each is n_cell x n_output = 4 x 3.
    recurrent_to_cell_weights_ = {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08,
                                  -0.2, 0.3, 0.8, -0.6, -0.1, 0.2};
    recurrent_to_forget_weights_ = {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4,
                                    0.9, 0.3, -0.1, 0.2, 0.5, 0.2};
    recurrent_to_output_weights_ = {0.3, -0.1, 0.1, -0.2, -0.5, -0.7,
                                    -0.2, -0.6, -0.1, -0.4, -0.7, -0.2};

    // Peephole (cell-to-gate) weights: each is n_cell = 4.
    cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03};
    cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05};

    // Layer-norm coefficients: each is n_cell = 4.
    forget_layer_norm_coefficients_ = {0.2, 0.2, 0.4, 0.3};
    cell_layer_norm_coefficients_ = {0.7, 0.2, 0.3, 0.8};
    output_layer_norm_coefficients_ = {0.6, 0.2, 0.2, 0.5};

    // Projection: n_output x n_cell = 3 x 4 (no projection bias).
    projection_weights_ = {-0.1, 0.2, 0.01, -0.2, 0.1, 0.5,
                           0.3, 0.08, 0.07, 0.2, -0.4, 0.2};

    layer_norm_lstm_input_ = {
        {// Batch0: 3 (input_sequence_size) * 5 (n_input)
         0.7, 0.8, 0.1, 0.2, 0.3,   // seq 0
         0.8, 0.1, 0.2, 0.4, 0.5,   // seq 1
         0.2, 0.7, 0.7, 0.1, 0.7},  // seq 2
        {// Batch1: 3 (input_sequence_size) * 5 (n_input)
         0.3, 0.2, 0.9, 0.8, 0.1,   // seq 0
         0.1, 0.5, 0.2, 0.4, 0.2,   // seq 1
         0.6, 0.9, 0.2, 0.5, 0.7},  // seq 2
    };
  }
};
// Black-box test of the float layer-norm LSTM kernel for the CIFG / peephole /
// projection / no-clipping configuration. Input-gate tensors get empty shapes
// ({0} / {0, 0}) because CIFG couples the input gate to the forget gate.
// Fix: the local clip constant was misspelled `ceil_clip`; the constructor
// parameter it feeds is `cell_clip`.
TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
       LayerNormLstmBlackBoxTest) {
  const int n_batch = 2;
  const int n_input = 5;
  const int n_cell = 4;
  const int n_output = 3;
  // 0.0 disables both cell-state and projection clipping.
  const float cell_clip = 0.0;
  const float proj_clip = 0.0;

  LayerNormLSTMOpModel layer_norm_lstm(
      n_batch, n_input, n_cell, n_output,
      /*use_cifg=*/true, /*use_peephole=*/true,
      /*use_projection_weights=*/true,
      /*use_projection_bias=*/false, cell_clip, proj_clip,
      {
          {n_batch, n_input},  // input tensor

          {0, 0},             // input_to_input_weight tensor
          {n_cell, n_input},  // input_to_forget_weight tensor
          {n_cell, n_input},  // input_to_cell_weight tensor
          {n_cell, n_input},  // input_to_output_weight tensor

          {0, 0},              // recurrent_to_input_weight tensor
          {n_cell, n_output},  // recurrent_to_forget_weight tensor
          {n_cell, n_output},  // recurrent_to_cell_weight tensor
          {n_cell, n_output},  // recurrent_to_output_weight tensor

          {0},       // cell_to_input_weight tensor
          {n_cell},  // cell_to_forget_weight tensor
          {n_cell},  // cell_to_output_weight tensor

          {0},       // input_gate_bias tensor
          {n_cell},  // forget_gate_bias tensor
          {n_cell},  // cell_bias tensor
          {n_cell},  // output_gate_bias tensor

          {n_output, n_cell},  // projection_weight tensor
          {0},                 // projection_bias tensor

          {n_batch, n_output},  // activation_state tensor
          {n_batch, n_cell},    // cell_state tensor

          {0},       // input_layer_norm_coefficient tensor
          {n_cell},  // forget_layer_norm_coefficient tensor
          {n_cell},  // cell_layer_norm_coefficient tensor
          {n_cell},  // output_layer_norm_coefficient tensor
      });

  layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
  layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
  layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);

  layer_norm_lstm.SetCellBias(cell_gate_bias_);
  layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
  layer_norm_lstm.SetOutputGateBias(output_gate_bias_);

  layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);

  layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
  layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);

  layer_norm_lstm.SetForgetLayerNormCoefficients(
      forget_layer_norm_coefficients_);
  layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
  layer_norm_lstm.SetOutputLayerNormCoefficients(
      output_layer_norm_coefficients_);

  layer_norm_lstm.SetProjectionWeights(projection_weights_);

  // Verify the final output.
  const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
      {
          // Batch0: 3 (input_sequence_size) * 3 (n_output)
          0.02129706, 0.140816242, 0.0112733059,     // seq 0
          0.0132302344, 0.152308047, 0.0346313119,   // seq 1
          -0.0123688057, 0.165790111, 0.0893077999,  // seq 2
      },
      {
          // Batch1: 3 (input_sequence_size) * 3 (n_output)
          -0.0226350538, 0.0916948169, 0.0769175813,  // seq 0
          -0.0269966982, 0.149707705, 0.094149217,    // seq 1
          -0.0103429332, 0.173016444, 0.0720508844,   // seq 2
      }};

  VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
                &layer_norm_lstm);
}
// Black-box test of the hybrid (uint8-weight) layer-norm LSTM kernel for the
// CIFG / peephole / projection / no-clipping configuration. Golden values
// differ slightly from the float test due to weight quantization.
// Fix: the local clip constant was misspelled `ceil_clip`; the constructor
// parameter it feeds is `cell_clip`.
TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
       HybridLayerNormLstmBlackBoxTest) {
  const int n_batch = 2;
  const int n_input = 5;
  const int n_cell = 4;
  const int n_output = 3;
  // 0.0 disables both cell-state and projection clipping.
  const float cell_clip = 0.0;
  const float proj_clip = 0.0;

  HybridLayerNormLSTMOpModel layer_norm_lstm(
      n_batch, n_input, n_cell, n_output,
      /*use_cifg=*/true, /*use_peephole=*/true,
      /*use_projection_weights=*/true,
      /*use_projection_bias=*/false, cell_clip, proj_clip,
      {
          {n_batch, n_input},  // input tensor

          {0, 0},             // input_to_input_weight tensor
          {n_cell, n_input},  // input_to_forget_weight tensor
          {n_cell, n_input},  // input_to_cell_weight tensor
          {n_cell, n_input},  // input_to_output_weight tensor

          {0, 0},              // recurrent_to_input_weight tensor
          {n_cell, n_output},  // recurrent_to_forget_weight tensor
          {n_cell, n_output},  // recurrent_to_cell_weight tensor
          {n_cell, n_output},  // recurrent_to_output_weight tensor

          {0},       // cell_to_input_weight tensor
          {n_cell},  // cell_to_forget_weight tensor
          {n_cell},  // cell_to_output_weight tensor

          {0},       // input_gate_bias tensor
          {n_cell},  // forget_gate_bias tensor
          {n_cell},  // cell_bias tensor
          {n_cell},  // output_gate_bias tensor

          {n_output, n_cell},  // projection_weight tensor
          {0},                 // projection_bias tensor

          {n_batch, n_output},  // activation_state tensor
          {n_batch, n_cell},    // cell_state tensor

          {0},       // input_layer_norm_coefficient tensor
          {n_cell},  // forget_layer_norm_coefficient tensor
          {n_cell},  // cell_layer_norm_coefficient tensor
          {n_cell},  // output_layer_norm_coefficient tensor
      });

  layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
  layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
  layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);

  layer_norm_lstm.SetCellBias(cell_gate_bias_);
  layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
  layer_norm_lstm.SetOutputGateBias(output_gate_bias_);

  layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);

  layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
  layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);

  layer_norm_lstm.SetForgetLayerNormCoefficients(
      forget_layer_norm_coefficients_);
  layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
  layer_norm_lstm.SetOutputLayerNormCoefficients(
      output_layer_norm_coefficients_);

  layer_norm_lstm.SetProjectionWeights(projection_weights_);

  // Verify the final output.
  const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
      {
          // Batch0: 3 (input_sequence_size) * 3 (n_output)
          0.0212250091, 0.140474007, 0.0115012666,   // seq 0
          0.0130806509, 0.152660668, 0.0347516984,   // seq 1
          -0.0124010444, 0.166042402, 0.0898982584,  // seq 2
      },
      {
          // Batch1: 3 (input_sequence_size) * 3 (n_output)
          -0.0228835996, 0.0917588323, 0.0778886303,  // seq 0
          -0.0275101066, 0.148769245, 0.0938384682,   // seq 1
          -0.0103605557, 0.172605693, 0.0728750974,   // seq 2
      }};

  VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
                &layer_norm_lstm);
}
} // namespace
} // namespace tflite

View File

@ -502,6 +502,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
recurrent_to_input_weights, recurrent_to_forget_weights,
recurrent_to_cell_weights, recurrent_to_output_weights,
cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
/*input_layer_norm_coefficients=*/nullptr,
/*forget_layer_norm_coefficients=*/nullptr,
/*cell_layer_norm_coefficients=*/nullptr,
/*output_layer_norm_coefficients=*/nullptr,
/*aux_input=*/nullptr,
/*aux_input_to_input_weights=*/nullptr,
/*aux_input_to_forget_weights=*/nullptr,
@ -529,6 +533,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
recurrent_to_input_weights, recurrent_to_forget_weights,
recurrent_to_cell_weights, recurrent_to_output_weights,
cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
/*input_layer_norm_coefficients=*/nullptr,
/*forget_layer_norm_coefficients=*/nullptr,
/*cell_layer_norm_coefficients=*/nullptr,
/*output_layer_norm_coefficients=*/nullptr,
/*aux_input=*/nullptr,
/*aux_input_to_input_weights=*/nullptr,
/*aux_input_to_forget_weights=*/nullptr,