Add builtin layer norm lstm.
PiperOrigin-RevId: 226205306
This commit is contained in:
parent
c93b568c00
commit
d2e68e7287
@ -886,11 +886,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
|
||||
fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights,
|
||||
fw_cell_to_input_weights, fw_cell_to_forget_weights,
|
||||
fw_cell_to_output_weights, aux_input, fw_aux_input_to_input_weights,
|
||||
fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights,
|
||||
fw_aux_input_to_output_weights, fw_input_gate_bias,
|
||||
fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias,
|
||||
fw_projection_weights, fw_projection_bias, &lstm_params,
|
||||
fw_cell_to_output_weights,
|
||||
/*input_layer_norm_coefficients=*/nullptr,
|
||||
/*forget_layer_norm_coefficients=*/nullptr,
|
||||
/*cell_layer_norm_coefficients=*/nullptr,
|
||||
/*output_layer_norm_coefficients=*/nullptr, aux_input,
|
||||
fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights,
|
||||
fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights,
|
||||
fw_input_gate_bias, fw_forget_gate_bias, fw_cell_bias,
|
||||
fw_output_gate_bias, fw_projection_weights, fw_projection_bias,
|
||||
&lstm_params,
|
||||
/*forward_sequence=*/true, time_major, /*output_offset=*/0,
|
||||
fw_scratch_buffer, fw_activation_state, fw_cell_state, fw_output);
|
||||
TF_LITE_ENSURE_OK(context, fw_pass_status);
|
||||
@ -901,11 +906,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
|
||||
bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
|
||||
bw_cell_to_input_weights, bw_cell_to_forget_weights,
|
||||
bw_cell_to_output_weights, aux_input, bw_aux_input_to_input_weights,
|
||||
bw_aux_input_to_forget_weights, bw_aux_input_to_cell_weights,
|
||||
bw_aux_input_to_output_weights, bw_input_gate_bias,
|
||||
bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
|
||||
bw_projection_weights, bw_projection_bias, &lstm_params,
|
||||
bw_cell_to_output_weights,
|
||||
/*input_layer_norm_coefficients=*/nullptr,
|
||||
/*forget_layer_norm_coefficients=*/nullptr,
|
||||
/*cell_layer_norm_coefficients=*/nullptr,
|
||||
/*output_layer_norm_coefficients=*/nullptr, aux_input,
|
||||
bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights,
|
||||
bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights,
|
||||
bw_input_gate_bias, bw_forget_gate_bias, bw_cell_bias,
|
||||
bw_output_gate_bias, bw_projection_weights, bw_projection_bias,
|
||||
&lstm_params,
|
||||
/*forward_sequence=*/false, time_major, bw_output_offset,
|
||||
bw_scratch_buffer, bw_activation_state, bw_cell_state,
|
||||
actual_bw_output);
|
||||
@ -940,11 +950,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
|
||||
fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights,
|
||||
fw_cell_to_input_weights, fw_cell_to_forget_weights,
|
||||
fw_cell_to_output_weights, aux_input, fw_aux_input_to_input_weights,
|
||||
fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights,
|
||||
fw_aux_input_to_output_weights, fw_input_gate_bias,
|
||||
fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias,
|
||||
fw_projection_weights, fw_projection_bias, &lstm_params,
|
||||
fw_cell_to_output_weights,
|
||||
/*input_layer_norm_coefficients=*/nullptr,
|
||||
/*forget_layer_norm_coefficients=*/nullptr,
|
||||
/*cell_layer_norm_coefficients=*/nullptr,
|
||||
/*output_layer_norm_coefficients=*/nullptr, aux_input,
|
||||
fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights,
|
||||
fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights,
|
||||
fw_input_gate_bias, fw_forget_gate_bias, fw_cell_bias,
|
||||
fw_output_gate_bias, fw_projection_weights, fw_projection_bias,
|
||||
&lstm_params,
|
||||
/*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
|
||||
fw_scratch_buffer, scaling_factors, prod_scaling_factors,
|
||||
recovered_cell_weights, input_quantized, aux_input_quantized,
|
||||
@ -958,11 +973,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
|
||||
bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
|
||||
bw_cell_to_input_weights, bw_cell_to_forget_weights,
|
||||
bw_cell_to_output_weights, aux_input, bw_aux_input_to_input_weights,
|
||||
bw_aux_input_to_forget_weights, bw_aux_input_to_cell_weights,
|
||||
bw_aux_input_to_output_weights, bw_input_gate_bias,
|
||||
bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
|
||||
bw_projection_weights, bw_projection_bias, &lstm_params,
|
||||
bw_cell_to_output_weights,
|
||||
/*input_layer_norm_coefficients=*/nullptr,
|
||||
/*forget_layer_norm_coefficients=*/nullptr,
|
||||
/*cell_layer_norm_coefficients=*/nullptr,
|
||||
/*output_layer_norm_coefficients=*/nullptr, aux_input,
|
||||
bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights,
|
||||
bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights,
|
||||
bw_input_gate_bias, bw_forget_gate_bias, bw_cell_bias,
|
||||
bw_output_gate_bias, bw_projection_weights, bw_projection_bias,
|
||||
&lstm_params,
|
||||
/*forward_sequence=*/false, /*time_major=*/true, bw_output_offset,
|
||||
bw_scratch_buffer, scaling_factors, prod_scaling_factors,
|
||||
recovered_cell_weights, input_quantized, aux_input_quantized,
|
||||
|
@ -38,17 +38,24 @@ namespace builtin {
|
||||
namespace lstm {
|
||||
|
||||
struct OpData {
|
||||
// Which kernel type to use. Full kernel (20 inputs) or basic kernel
|
||||
// (5 inputs).
|
||||
// Which kernel type to use. Full kernel (24 inputs) or basic kernel (5
|
||||
// inputs).
|
||||
// Please note the 20-input full kernel is deprecated and only kept
|
||||
// here for backward compatibility.
|
||||
TfLiteLSTMKernelType kernel_type;
|
||||
|
||||
// If the lstm is layer norm.
|
||||
bool is_layer_norm_lstm;
|
||||
|
||||
// These fields are only used by full kernel.
|
||||
int activation_state_tensor_index;
|
||||
int cell_state_tensor_index;
|
||||
int scratch_tensor_index;
|
||||
};
|
||||
|
||||
// For full inputs kernel (20-inputs).
|
||||
// For full inputs kernel (24-inputs).
|
||||
// Please note the 20-input full kernel is deprecated and only kept
|
||||
// here for backward compatibility.
|
||||
namespace full {
|
||||
|
||||
// Input Tensors of size {n_batch, n_input}
|
||||
@ -87,6 +94,13 @@ constexpr int kProjectionBiasTensor = 17; // Optional
|
||||
constexpr int kInputActivationStateTensor = 18;
|
||||
constexpr int kInputCellStateTensor = 19;
|
||||
|
||||
// Layer norm coefficient tensors of size {n_cell}, representing a diagonal
|
||||
// matrix.
|
||||
constexpr int kInputLayerNormCoefficientsTensor = 20; // Optional
|
||||
constexpr int kForgetLayerNormCoefficientsTensor = 21; // Optional
|
||||
constexpr int kCellLayerNormCoefficientsTensor = 22; // Optional
|
||||
constexpr int kOutputLayerNormCoefficientsTensor = 23; // Optional
|
||||
|
||||
// Output tensors.
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
@ -101,7 +115,8 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
// Check that input tensor dimensions matches with each other.
|
||||
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
|
||||
TfLiteNode* node, int n_input,
|
||||
int n_output, int n_cell) {
|
||||
int n_output, int n_cell,
|
||||
bool is_layer_norm_lstm) {
|
||||
const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
|
||||
|
||||
// Making sure clipping parameters have valid values.
|
||||
@ -112,7 +127,8 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
|
||||
|
||||
const TfLiteTensor* input_to_input_weights =
|
||||
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
|
||||
if (input_to_input_weights != nullptr) {
|
||||
const bool use_cifg = (input_to_input_weights == nullptr);
|
||||
if (!use_cifg) {
|
||||
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
|
||||
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
|
||||
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
|
||||
@ -186,7 +202,6 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
|
||||
}
|
||||
|
||||
// Making sure the peephole weights are there all or none.
|
||||
const bool use_cifg = (input_to_input_weights == nullptr);
|
||||
const bool peephole_weights_all_or_none =
|
||||
((cell_to_input_weights != nullptr || use_cifg) &&
|
||||
(cell_to_forget_weights != nullptr) &&
|
||||
@ -244,6 +259,40 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
|
||||
((projection_weights != nullptr) || (projection_bias == nullptr));
|
||||
TF_LITE_ENSURE(context, projection_tensors_consistent == true);
|
||||
|
||||
if (is_layer_norm_lstm) {
|
||||
const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
|
||||
context, node, kInputLayerNormCoefficientsTensor);
|
||||
if (use_cifg) {
|
||||
TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients, nullptr);
|
||||
} else {
|
||||
TF_LITE_ENSURE(context, input_layer_norm_coefficients != nullptr);
|
||||
TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->size, 1);
|
||||
TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->data[0],
|
||||
n_cell);
|
||||
}
|
||||
|
||||
const TfLiteTensor* forget_layer_norm_coefficients =
|
||||
GetInput(context, node, kForgetLayerNormCoefficientsTensor);
|
||||
TF_LITE_ENSURE(context, forget_layer_norm_coefficients != nullptr);
|
||||
TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->size, 1);
|
||||
TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->data[0],
|
||||
n_cell);
|
||||
|
||||
const TfLiteTensor* cell_layer_norm_coefficients =
|
||||
GetInput(context, node, kCellLayerNormCoefficientsTensor);
|
||||
TF_LITE_ENSURE(context, cell_layer_norm_coefficients != nullptr);
|
||||
TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->size, 1);
|
||||
TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->data[0],
|
||||
n_cell);
|
||||
|
||||
const TfLiteTensor* output_layer_norm_coefficients =
|
||||
GetInput(context, node, kOutputLayerNormCoefficientsTensor);
|
||||
TF_LITE_ENSURE(context, output_layer_norm_coefficients != nullptr);
|
||||
TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->size, 1);
|
||||
TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->data[0],
|
||||
n_cell);
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
@ -254,8 +303,32 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
|
||||
TF_LITE_ENSURE_EQ(context, node->inputs->size, 20);
|
||||
// Logic for determining regular lstm and layer norm lstm:
|
||||
// input_size, forget_gate_layer_norm_tensor (20) null? is_layer_norm?
|
||||
// 20, N/A, No.
|
||||
// 24, null, No.
|
||||
// 24, not null, Yes.
|
||||
// 20-inputs lstm are deprecated and is only kept here for backward
|
||||
// compatibility.
|
||||
if (node->inputs->size == 24) {
|
||||
const TfLiteTensor* forget_layer_norm_coefficients =
|
||||
GetInput(context, node, kForgetLayerNormCoefficientsTensor);
|
||||
if (forget_layer_norm_coefficients == nullptr) {
|
||||
op_data->is_layer_norm_lstm = false;
|
||||
} else {
|
||||
op_data->is_layer_norm_lstm = true;
|
||||
}
|
||||
} else if (node->inputs->size == 20) {
|
||||
// This is deprecated and is only kept here for backward compatibility.
|
||||
op_data->is_layer_norm_lstm = false;
|
||||
} else {
|
||||
context->ReportError(
|
||||
context, "The LSTM Full kernel expects 20 or 24 inputs. Got %d inputs",
|
||||
node->inputs->size);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
const bool is_layer_norm_lstm = op_data->is_layer_norm_lstm;
|
||||
op_data->activation_state_tensor_index =
|
||||
node->inputs->data[kInputActivationStateTensor];
|
||||
op_data->cell_state_tensor_index = node->inputs->data[kInputCellStateTensor];
|
||||
@ -282,8 +355,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
const int n_output = recurrent_to_output_weights->dims->data[1];
|
||||
|
||||
// Check that input tensor dimensions matches with each other.
|
||||
TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input,
|
||||
n_output, n_cell));
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
CheckInputTensorDimensions(context, node, n_input, n_output,
|
||||
n_cell, is_layer_norm_lstm));
|
||||
|
||||
// Get the pointer to output, activation_state and cell_state tensors.
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
@ -430,6 +504,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
|
||||
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
|
||||
const bool is_layer_norm_lstm = op_data->is_layer_norm_lstm;
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
|
||||
@ -458,6 +533,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* cell_to_output_weights =
|
||||
GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
|
||||
|
||||
const TfLiteTensor* input_layer_norm_coefficients =
|
||||
is_layer_norm_lstm ? GetOptionalInputTensor(
|
||||
context, node, kInputLayerNormCoefficientsTensor)
|
||||
: nullptr;
|
||||
const TfLiteTensor* forget_layer_norm_coefficients =
|
||||
is_layer_norm_lstm
|
||||
? GetInput(context, node, kForgetLayerNormCoefficientsTensor)
|
||||
: nullptr;
|
||||
const TfLiteTensor* cell_layer_norm_coefficients =
|
||||
is_layer_norm_lstm
|
||||
? GetInput(context, node, kCellLayerNormCoefficientsTensor)
|
||||
: nullptr;
|
||||
const TfLiteTensor* output_layer_norm_coefficients =
|
||||
is_layer_norm_lstm
|
||||
? GetInput(context, node, kOutputLayerNormCoefficientsTensor)
|
||||
: nullptr;
|
||||
|
||||
const TfLiteTensor* input_gate_bias =
|
||||
GetOptionalInputTensor(context, node, kInputGateBiasTensor);
|
||||
const TfLiteTensor* forget_gate_bias =
|
||||
@ -490,6 +582,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
recurrent_to_input_weights, recurrent_to_forget_weights,
|
||||
recurrent_to_cell_weights, recurrent_to_output_weights,
|
||||
cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
|
||||
input_layer_norm_coefficients, forget_layer_norm_coefficients,
|
||||
cell_layer_norm_coefficients, output_layer_norm_coefficients,
|
||||
/*aux_input=*/nullptr,
|
||||
/*aux_input_to_input_weights=*/nullptr,
|
||||
/*aux_input_to_forget_weights=*/nullptr,
|
||||
@ -518,6 +612,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
recurrent_to_input_weights, recurrent_to_forget_weights,
|
||||
recurrent_to_cell_weights, recurrent_to_output_weights,
|
||||
cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
|
||||
input_layer_norm_coefficients, forget_layer_norm_coefficients,
|
||||
cell_layer_norm_coefficients, output_layer_norm_coefficients,
|
||||
/*aux_input=*/nullptr,
|
||||
/*aux_input_to_input_weights=*/nullptr,
|
||||
/*aux_input_to_forget_weights=*/nullptr,
|
||||
|
@ -27,6 +27,10 @@ namespace lstm_eval {
|
||||
|
||||
namespace {
|
||||
|
||||
// Small float to avoid divergence during calculation of deviation for layer
|
||||
// norm lstm.
|
||||
const float kLayerNormEpsilon = 1e-8;
|
||||
|
||||
// Performs an LSTM batch inference step for input specified by input_ptr_batch.
|
||||
// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
|
||||
// biases (*_bias_ptr), and buffers (*_scratch), along with additional
|
||||
@ -65,30 +69,47 @@ inline void LstmStepWithAuxInput(
|
||||
const float* recurrent_to_output_weights_ptr,
|
||||
const float* cell_to_input_weights_ptr,
|
||||
const float* cell_to_forget_weights_ptr,
|
||||
const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
|
||||
const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
|
||||
const float* output_gate_bias_ptr, const float* projection_weights_ptr,
|
||||
const float* projection_bias_ptr, const TfLiteLSTMParams* params,
|
||||
int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
|
||||
int output_batch_leading_dim, float* output_state_ptr,
|
||||
float* cell_state_ptr, float* input_gate_scratch,
|
||||
const float* cell_to_output_weights_ptr,
|
||||
const float* input_layer_norm_coefficients_ptr,
|
||||
const float* forget_layer_norm_coefficients_ptr,
|
||||
const float* cell_layer_norm_coefficients_ptr,
|
||||
const float* output_layer_norm_coefficients_ptr,
|
||||
const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
|
||||
const float* cell_bias_ptr, const float* output_gate_bias_ptr,
|
||||
const float* projection_weights_ptr, const float* projection_bias_ptr,
|
||||
const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
|
||||
int n_aux_input, int n_output, int output_batch_leading_dim,
|
||||
float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
|
||||
float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
|
||||
float* output_ptr_batch) {
|
||||
// Since we have already checked that weights are all there or none, we can
|
||||
// check the existense of only one to the get the condition.
|
||||
const bool use_cifg = (input_to_input_weights_ptr == nullptr);
|
||||
const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
|
||||
// Initialize scratch buffers with bias.
|
||||
if (!use_cifg) {
|
||||
tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch,
|
||||
input_gate_scratch);
|
||||
const bool is_layer_norm_lstm =
|
||||
(forget_layer_norm_coefficients_ptr != nullptr);
|
||||
|
||||
// Initialize scratch buffers with bias for regular lstm or initialize with
|
||||
// zero for layer norm lstm.
|
||||
if (is_layer_norm_lstm) {
|
||||
if (!use_cifg) {
|
||||
tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch);
|
||||
}
|
||||
tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch);
|
||||
tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch);
|
||||
tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch);
|
||||
} else {
|
||||
if (!use_cifg) {
|
||||
tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell,
|
||||
n_batch, input_gate_scratch);
|
||||
}
|
||||
tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
|
||||
forget_gate_scratch);
|
||||
tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
|
||||
cell_scratch);
|
||||
tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
|
||||
output_gate_scratch);
|
||||
}
|
||||
tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
|
||||
forget_gate_scratch);
|
||||
tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
|
||||
cell_scratch);
|
||||
tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
|
||||
output_gate_scratch);
|
||||
|
||||
// For each batch and cell: compute input_weight * input.
|
||||
if (!use_cifg) {
|
||||
@ -152,6 +173,16 @@ inline void LstmStepWithAuxInput(
|
||||
cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
|
||||
input_gate_scratch);
|
||||
}
|
||||
if (is_layer_norm_lstm) {
|
||||
tensor_utils::MeanStddevNormalization(input_gate_scratch,
|
||||
input_gate_scratch, n_cell, n_batch,
|
||||
kLayerNormEpsilon);
|
||||
tensor_utils::VectorBatchVectorCwiseProduct(
|
||||
input_layer_norm_coefficients_ptr, n_cell, input_gate_scratch,
|
||||
n_batch, input_gate_scratch);
|
||||
tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
|
||||
input_gate_scratch);
|
||||
}
|
||||
tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
|
||||
input_gate_scratch);
|
||||
}
|
||||
@ -162,12 +193,31 @@ inline void LstmStepWithAuxInput(
|
||||
cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
|
||||
forget_gate_scratch);
|
||||
}
|
||||
if (is_layer_norm_lstm) {
|
||||
tensor_utils::MeanStddevNormalization(forget_gate_scratch,
|
||||
forget_gate_scratch, n_cell, n_batch,
|
||||
kLayerNormEpsilon);
|
||||
tensor_utils::VectorBatchVectorCwiseProduct(
|
||||
forget_layer_norm_coefficients_ptr, n_cell, forget_gate_scratch,
|
||||
n_batch, forget_gate_scratch);
|
||||
tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
|
||||
forget_gate_scratch);
|
||||
}
|
||||
tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
|
||||
forget_gate_scratch);
|
||||
|
||||
// For each batch and cell: update the cell.
|
||||
tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
|
||||
n_batch * n_cell, cell_state_ptr);
|
||||
if (is_layer_norm_lstm) {
|
||||
tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
|
||||
n_batch, kLayerNormEpsilon);
|
||||
tensor_utils::VectorBatchVectorCwiseProduct(
|
||||
cell_layer_norm_coefficients_ptr, n_cell, cell_scratch, n_batch,
|
||||
cell_scratch);
|
||||
tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
|
||||
cell_scratch);
|
||||
}
|
||||
tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
|
||||
params->activation, cell_scratch);
|
||||
if (use_cifg) {
|
||||
@ -190,6 +240,16 @@ inline void LstmStepWithAuxInput(
|
||||
cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
|
||||
output_gate_scratch);
|
||||
}
|
||||
if (is_layer_norm_lstm) {
|
||||
tensor_utils::MeanStddevNormalization(output_gate_scratch,
|
||||
output_gate_scratch, n_cell, n_batch,
|
||||
kLayerNormEpsilon);
|
||||
tensor_utils::VectorBatchVectorCwiseProduct(
|
||||
output_layer_norm_coefficients_ptr, n_cell, output_gate_scratch,
|
||||
n_batch, output_gate_scratch);
|
||||
tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
|
||||
output_gate_scratch);
|
||||
}
|
||||
tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
|
||||
output_gate_scratch);
|
||||
tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
|
||||
@ -344,33 +404,50 @@ inline void LstmStepWithAuxInput(
|
||||
const int8_t* cell_to_forget_weights_ptr,
|
||||
float cell_to_forget_weights_scale,
|
||||
const int8_t* cell_to_output_weights_ptr,
|
||||
float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
|
||||
const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
|
||||
const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
|
||||
float projection_weights_scale, const float* projection_bias_ptr,
|
||||
const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
|
||||
int n_aux_input, int n_output, int output_batch_leading_dim,
|
||||
float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch,
|
||||
float* output_gate_scratch, float* scaling_factors,
|
||||
float* product_scaling_factors, float* recovered_cell_weights,
|
||||
int8_t* quantized_input_ptr_batch, int8_t* quantized_aux_input_ptr_batch,
|
||||
int8_t* quantized_output_state_ptr, int8_t* quantized_cell_state_ptr,
|
||||
float* output_state_ptr, float* cell_state_ptr, float* output_ptr_batch) {
|
||||
float cell_to_output_weights_scale,
|
||||
const float* input_layer_norm_coefficients_ptr,
|
||||
const float* forget_layer_norm_coefficients_ptr,
|
||||
const float* cell_layer_norm_coefficients_ptr,
|
||||
const float* output_layer_norm_coefficients_ptr,
|
||||
const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
|
||||
const float* cell_bias_ptr, const float* output_gate_bias_ptr,
|
||||
const int8_t* projection_weights_ptr, float projection_weights_scale,
|
||||
const float* projection_bias_ptr, const TfLiteLSTMParams* params,
|
||||
int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
|
||||
int output_batch_leading_dim, float* input_gate_scratch,
|
||||
float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
|
||||
float* scaling_factors, float* product_scaling_factors,
|
||||
float* recovered_cell_weights, int8_t* quantized_input_ptr_batch,
|
||||
int8_t* quantized_aux_input_ptr_batch, int8_t* quantized_output_state_ptr,
|
||||
int8_t* quantized_cell_state_ptr, float* output_state_ptr,
|
||||
float* cell_state_ptr, float* output_ptr_batch) {
|
||||
// Since we have already checked that weights are all there or none, we
|
||||
// can check the existense of only one to the get the condition.
|
||||
const bool use_cifg = (input_to_input_weights_ptr == nullptr);
|
||||
const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
|
||||
const bool is_layer_norm_lstm =
|
||||
(forget_layer_norm_coefficients_ptr != nullptr);
|
||||
|
||||
// Initialize scratch buffers with bias.
|
||||
if (!use_cifg) {
|
||||
tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch,
|
||||
input_gate_scratch);
|
||||
if (is_layer_norm_lstm) {
|
||||
if (!use_cifg) {
|
||||
tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch);
|
||||
}
|
||||
tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch);
|
||||
tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch);
|
||||
tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch);
|
||||
} else {
|
||||
if (!use_cifg) {
|
||||
tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell,
|
||||
n_batch, input_gate_scratch);
|
||||
}
|
||||
tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
|
||||
forget_gate_scratch);
|
||||
tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
|
||||
cell_scratch);
|
||||
tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
|
||||
output_gate_scratch);
|
||||
}
|
||||
tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
|
||||
forget_gate_scratch);
|
||||
tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
|
||||
cell_scratch);
|
||||
tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
|
||||
output_gate_scratch);
|
||||
|
||||
if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
|
||||
// Save quantization and matmul computation for all zero input.
|
||||
@ -535,6 +612,16 @@ inline void LstmStepWithAuxInput(
|
||||
recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
|
||||
input_gate_scratch);
|
||||
}
|
||||
if (is_layer_norm_lstm) {
|
||||
tensor_utils::MeanStddevNormalization(input_gate_scratch,
|
||||
input_gate_scratch, n_cell, n_batch,
|
||||
kLayerNormEpsilon);
|
||||
tensor_utils::VectorBatchVectorCwiseProduct(
|
||||
input_layer_norm_coefficients_ptr, n_cell, input_gate_scratch,
|
||||
n_batch, input_gate_scratch);
|
||||
tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
|
||||
input_gate_scratch);
|
||||
}
|
||||
tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
|
||||
input_gate_scratch);
|
||||
}
|
||||
@ -548,12 +635,31 @@ inline void LstmStepWithAuxInput(
|
||||
recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
|
||||
forget_gate_scratch);
|
||||
}
|
||||
if (is_layer_norm_lstm) {
|
||||
tensor_utils::MeanStddevNormalization(forget_gate_scratch,
|
||||
forget_gate_scratch, n_cell, n_batch,
|
||||
kLayerNormEpsilon);
|
||||
tensor_utils::VectorBatchVectorCwiseProduct(
|
||||
forget_layer_norm_coefficients_ptr, n_cell, forget_gate_scratch,
|
||||
n_batch, forget_gate_scratch);
|
||||
tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
|
||||
forget_gate_scratch);
|
||||
}
|
||||
tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
|
||||
forget_gate_scratch);
|
||||
|
||||
// For each batch and cell: update the cell.
|
||||
tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
|
||||
n_batch * n_cell, cell_state_ptr);
|
||||
if (is_layer_norm_lstm) {
|
||||
tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
|
||||
n_batch, kLayerNormEpsilon);
|
||||
tensor_utils::VectorBatchVectorCwiseProduct(
|
||||
cell_layer_norm_coefficients_ptr, n_cell, cell_scratch, n_batch,
|
||||
cell_scratch);
|
||||
tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
|
||||
cell_scratch);
|
||||
}
|
||||
tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
|
||||
params->activation, cell_scratch);
|
||||
if (use_cifg) {
|
||||
@ -581,6 +687,16 @@ inline void LstmStepWithAuxInput(
|
||||
recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
|
||||
output_gate_scratch);
|
||||
}
|
||||
if (is_layer_norm_lstm) {
|
||||
tensor_utils::MeanStddevNormalization(output_gate_scratch,
|
||||
output_gate_scratch, n_cell, n_batch,
|
||||
kLayerNormEpsilon);
|
||||
tensor_utils::VectorBatchVectorCwiseProduct(
|
||||
output_layer_norm_coefficients_ptr, n_cell, output_gate_scratch,
|
||||
n_batch, output_gate_scratch);
|
||||
tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
|
||||
output_gate_scratch);
|
||||
}
|
||||
tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
|
||||
output_gate_scratch);
|
||||
tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
|
||||
@ -702,7 +818,12 @@ TfLiteStatus EvalFloat(
|
||||
const TfLiteTensor* recurrent_to_output_weights,
|
||||
const TfLiteTensor* cell_to_input_weights,
|
||||
const TfLiteTensor* cell_to_forget_weights,
|
||||
const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
|
||||
const TfLiteTensor* cell_to_output_weights,
|
||||
const TfLiteTensor* input_layer_norm_coefficients,
|
||||
const TfLiteTensor* forget_layer_norm_coefficients,
|
||||
const TfLiteTensor* cell_layer_norm_coefficients,
|
||||
const TfLiteTensor* output_layer_norm_coefficients,
|
||||
const TfLiteTensor* aux_input,
|
||||
const TfLiteTensor* aux_input_to_input_weights,
|
||||
const TfLiteTensor* aux_input_to_forget_weights,
|
||||
const TfLiteTensor* aux_input_to_cell_weights,
|
||||
@ -735,6 +856,7 @@ TfLiteStatus EvalFloat(
|
||||
// check the existense of only one to the get the condition.
|
||||
const bool use_cifg = (input_to_input_weights == nullptr);
|
||||
const bool use_peephole = (cell_to_output_weights != nullptr);
|
||||
const bool is_layer_norm_lstm = (forget_layer_norm_coefficients != nullptr);
|
||||
|
||||
// Index the scratch buffers pointers to the global scratch buffer.
|
||||
float* input_gate_scratch = nullptr;
|
||||
@ -765,6 +887,15 @@ TfLiteStatus EvalFloat(
|
||||
(use_peephole) ? cell_to_forget_weights->data.f : nullptr;
|
||||
const float* cell_to_output_weights_ptr =
|
||||
(use_peephole) ? cell_to_output_weights->data.f : nullptr;
|
||||
const float* input_layer_norm_coefficients_ptr =
|
||||
(is_layer_norm_lstm && !use_cifg) ? input_layer_norm_coefficients->data.f
|
||||
: nullptr;
|
||||
const float* forget_layer_norm_coefficients_ptr =
|
||||
is_layer_norm_lstm ? forget_layer_norm_coefficients->data.f : nullptr;
|
||||
const float* cell_layer_norm_coefficients_ptr =
|
||||
is_layer_norm_lstm ? cell_layer_norm_coefficients->data.f : nullptr;
|
||||
const float* output_layer_norm_coefficients_ptr =
|
||||
is_layer_norm_lstm ? output_layer_norm_coefficients->data.f : nullptr;
|
||||
const float* projection_weights_ptr =
|
||||
(projection_weights == nullptr) ? nullptr : projection_weights->data.f;
|
||||
const float* projection_bias_ptr =
|
||||
@ -811,6 +942,8 @@ TfLiteStatus EvalFloat(
|
||||
recurrent_to_cell_weights->data.f,
|
||||
recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
|
||||
cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
|
||||
input_layer_norm_coefficients_ptr, forget_layer_norm_coefficients_ptr,
|
||||
cell_layer_norm_coefficients_ptr, output_layer_norm_coefficients_ptr,
|
||||
input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
|
||||
output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
|
||||
params, n_batch, n_cell, n_input, aux_input_size, n_output,
|
||||
@ -855,7 +988,11 @@ TfLiteStatus EvalFloat(
|
||||
recurrent_to_cell_weights->data.f,
|
||||
recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
|
||||
cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
|
||||
input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
|
||||
input_layer_norm_coefficients_ptr,
|
||||
forget_layer_norm_coefficients_ptr,
|
||||
cell_layer_norm_coefficients_ptr,
|
||||
output_layer_norm_coefficients_ptr, input_gate_bias_ptr,
|
||||
forget_gate_bias->data.f, cell_bias->data.f,
|
||||
output_gate_bias->data.f, projection_weights_ptr,
|
||||
projection_bias_ptr, params, /*n_batch=*/1, n_cell, n_input,
|
||||
aux_input_size, n_output, output_batch_leading_dim,
|
||||
@ -879,7 +1016,12 @@ TfLiteStatus EvalHybrid(
|
||||
const TfLiteTensor* recurrent_to_output_weights,
|
||||
const TfLiteTensor* cell_to_input_weights,
|
||||
const TfLiteTensor* cell_to_forget_weights,
|
||||
const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
|
||||
const TfLiteTensor* cell_to_output_weights,
|
||||
const TfLiteTensor* input_layer_norm_coefficients,
|
||||
const TfLiteTensor* forget_layer_norm_coefficients,
|
||||
const TfLiteTensor* cell_layer_norm_coefficients,
|
||||
const TfLiteTensor* output_layer_norm_coefficients,
|
||||
const TfLiteTensor* aux_input,
|
||||
const TfLiteTensor* aux_input_to_input_weights,
|
||||
const TfLiteTensor* aux_input_to_forget_weights,
|
||||
const TfLiteTensor* aux_input_to_cell_weights,
|
||||
@ -914,6 +1056,7 @@ TfLiteStatus EvalHybrid(
|
||||
// check the existence of only one to get the condition.
|
||||
const bool use_cifg = (input_to_input_weights == nullptr);
|
||||
const bool use_peephole = (cell_to_output_weights != nullptr);
|
||||
const bool is_layer_norm_lstm = (forget_layer_norm_coefficients != nullptr);
|
||||
|
||||
float* input_gate_scratch = nullptr;
|
||||
float* cell_scratch = nullptr;
|
||||
@ -966,6 +1109,16 @@ TfLiteStatus EvalHybrid(
|
||||
cell_to_output_weights_scale = cell_to_output_weights->params.scale;
|
||||
}
|
||||
|
||||
const float* input_layer_norm_coefficients_ptr =
|
||||
(is_layer_norm_lstm && !use_cifg) ? input_layer_norm_coefficients->data.f
|
||||
: nullptr;
|
||||
const float* forget_layer_norm_coefficients_ptr =
|
||||
is_layer_norm_lstm ? forget_layer_norm_coefficients->data.f : nullptr;
|
||||
const float* cell_layer_norm_coefficients_ptr =
|
||||
is_layer_norm_lstm ? cell_layer_norm_coefficients->data.f : nullptr;
|
||||
const float* output_layer_norm_coefficients_ptr =
|
||||
is_layer_norm_lstm ? output_layer_norm_coefficients->data.f : nullptr;
|
||||
|
||||
const int8_t* projection_weights_ptr =
|
||||
(projection_weights == nullptr)
|
||||
? nullptr
|
||||
@ -1084,6 +1237,8 @@ TfLiteStatus EvalHybrid(
|
||||
cell_to_input_weights_ptr, cell_to_input_weights_scale,
|
||||
cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
|
||||
cell_to_output_weights_ptr, cell_to_output_weights_scale,
|
||||
input_layer_norm_coefficients_ptr, forget_layer_norm_coefficients_ptr,
|
||||
cell_layer_norm_coefficients_ptr, output_layer_norm_coefficients_ptr,
|
||||
input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
|
||||
output_gate_bias_ptr, projection_weights_ptr,
|
||||
projection_weights_scale, projection_bias_ptr, params, n_batch,
|
||||
@ -1138,7 +1293,10 @@ TfLiteStatus EvalHybrid(
|
||||
recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
|
||||
cell_to_input_weights_scale, cell_to_forget_weights_ptr,
|
||||
cell_to_forget_weights_scale, cell_to_output_weights_ptr,
|
||||
cell_to_output_weights_scale, input_gate_bias_ptr,
|
||||
cell_to_output_weights_scale, input_layer_norm_coefficients_ptr,
|
||||
forget_layer_norm_coefficients_ptr,
|
||||
cell_layer_norm_coefficients_ptr,
|
||||
output_layer_norm_coefficients_ptr, input_gate_bias_ptr,
|
||||
forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr,
|
||||
projection_weights_ptr, projection_weights_scale,
|
||||
projection_bias_ptr, params,
|
||||
|
@ -34,7 +34,12 @@ TfLiteStatus EvalFloat(
|
||||
const TfLiteTensor* recurrent_to_output_weights,
|
||||
const TfLiteTensor* cell_to_input_weights,
|
||||
const TfLiteTensor* cell_to_forget_weights,
|
||||
const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
|
||||
const TfLiteTensor* cell_to_output_weights,
|
||||
const TfLiteTensor* input_layer_norm_coefficients,
|
||||
const TfLiteTensor* forget_layer_norm_coefficients,
|
||||
const TfLiteTensor* cell_layer_norm_coefficients,
|
||||
const TfLiteTensor* output_layer_norm_coefficients,
|
||||
const TfLiteTensor* aux_input,
|
||||
const TfLiteTensor* aux_input_to_input_weights,
|
||||
const TfLiteTensor* aux_input_to_forget_weights,
|
||||
const TfLiteTensor* aux_input_to_cell_weights,
|
||||
@ -58,7 +63,12 @@ TfLiteStatus EvalHybrid(
|
||||
const TfLiteTensor* recurrent_to_output_weights,
|
||||
const TfLiteTensor* cell_to_input_weights,
|
||||
const TfLiteTensor* cell_to_forget_weights,
|
||||
const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
|
||||
const TfLiteTensor* cell_to_output_weights,
|
||||
const TfLiteTensor* input_layer_norm_coefficients,
|
||||
const TfLiteTensor* forget_layer_norm_coefficients,
|
||||
const TfLiteTensor* cell_layer_norm_coefficients,
|
||||
const TfLiteTensor* output_layer_norm_coefficients,
|
||||
const TfLiteTensor* aux_input,
|
||||
const TfLiteTensor* aux_input_to_input_weights,
|
||||
const TfLiteTensor* aux_input_to_forget_weights,
|
||||
const TfLiteTensor* aux_input_to_cell_weights,
|
||||
|
@ -38,7 +38,8 @@ class LSTMOpModel : public SingleOpModel {
|
||||
bool use_peephole, bool use_projection_weights,
|
||||
bool use_projection_bias, float cell_clip, float proj_clip,
|
||||
const std::vector<std::vector<int>>& input_shapes,
|
||||
const TensorType& weight_type = TensorType_FLOAT32)
|
||||
const TensorType& weight_type = TensorType_FLOAT32,
|
||||
bool is_layer_norm = false)
|
||||
: n_batch_(n_batch),
|
||||
n_input_(n_input),
|
||||
n_cell_(n_cell),
|
||||
@ -106,6 +107,18 @@ class LSTMOpModel : public SingleOpModel {
|
||||
input_cell_state_ =
|
||||
AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true);
|
||||
|
||||
// Layer norm weights.
|
||||
if (is_layer_norm) {
|
||||
if (use_cifg) {
|
||||
input_layer_norm_coefficients_ = AddNullInput();
|
||||
} else {
|
||||
input_layer_norm_coefficients_ = AddInput(TensorType_FLOAT32);
|
||||
}
|
||||
forget_layer_norm_coefficients_ = AddInput(TensorType_FLOAT32);
|
||||
cell_layer_norm_coefficients_ = AddInput(TensorType_FLOAT32);
|
||||
output_layer_norm_coefficients_ = AddInput(TensorType_FLOAT32);
|
||||
}
|
||||
|
||||
output_ = AddOutput(TensorType_FLOAT32);
|
||||
|
||||
SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
|
||||
@ -160,6 +173,22 @@ class LSTMOpModel : public SingleOpModel {
|
||||
PopulateTensor(cell_to_output_weights_, f);
|
||||
}
|
||||
|
||||
void SetInputLayerNormCoefficients(std::vector<float> f) {
|
||||
PopulateTensor(input_layer_norm_coefficients_, f);
|
||||
}
|
||||
|
||||
void SetForgetLayerNormCoefficients(std::vector<float> f) {
|
||||
PopulateTensor(forget_layer_norm_coefficients_, f);
|
||||
}
|
||||
|
||||
void SetCellLayerNormCoefficients(std::vector<float> f) {
|
||||
PopulateTensor(cell_layer_norm_coefficients_, f);
|
||||
}
|
||||
|
||||
void SetOutputLayerNormCoefficients(std::vector<float> f) {
|
||||
PopulateTensor(output_layer_norm_coefficients_, f);
|
||||
}
|
||||
|
||||
void SetInputGateBias(std::vector<float> f) {
|
||||
PopulateTensor(input_gate_bias_, f);
|
||||
}
|
||||
@ -210,6 +239,11 @@ class LSTMOpModel : public SingleOpModel {
|
||||
int cell_to_forget_weights_;
|
||||
int cell_to_output_weights_;
|
||||
|
||||
int input_layer_norm_coefficients_;
|
||||
int forget_layer_norm_coefficients_;
|
||||
int cell_layer_norm_coefficients_;
|
||||
int output_layer_norm_coefficients_;
|
||||
|
||||
int input_gate_bias_;
|
||||
int forget_gate_bias_;
|
||||
int cell_bias_;
|
||||
@ -1392,6 +1426,644 @@ TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
|
||||
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
|
||||
}
|
||||
|
||||
class LayerNormLSTMOpModel : public LSTMOpModel {
|
||||
public:
|
||||
LayerNormLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
|
||||
bool use_cifg, bool use_peephole,
|
||||
bool use_projection_weights, bool use_projection_bias,
|
||||
float cell_clip, float proj_clip,
|
||||
const std::vector<std::vector<int>>& input_shapes,
|
||||
const TensorType& weight_type = TensorType_FLOAT32)
|
||||
: LSTMOpModel(n_batch, n_input, n_cell, n_output, use_cifg, use_peephole,
|
||||
use_projection_weights, use_projection_bias, cell_clip,
|
||||
proj_clip, input_shapes, weight_type,
|
||||
/*is_layer_norm*/ true) {}
|
||||
};
|
||||
|
||||
class HybridLayerNormLSTMOpModel : public LayerNormLSTMOpModel {
|
||||
public:
|
||||
HybridLayerNormLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
|
||||
bool use_cifg, bool use_peephole,
|
||||
bool use_projection_weights,
|
||||
bool use_projection_bias, float cell_clip,
|
||||
float proj_clip,
|
||||
const std::vector<std::vector<int>>& input_shapes)
|
||||
: LayerNormLSTMOpModel(n_batch, n_input, n_cell, n_output, use_cifg,
|
||||
use_peephole, use_projection_weights,
|
||||
use_projection_bias, cell_clip, proj_clip,
|
||||
input_shapes, TensorType_UINT8) {}
|
||||
|
||||
void SetInputToInputWeights(std::vector<float> f) {
|
||||
SymmetricQuantizeAndPopulate(input_to_input_weights_, f);
|
||||
}
|
||||
|
||||
void SetInputToForgetWeights(std::vector<float> f) {
|
||||
SymmetricQuantizeAndPopulate(input_to_forget_weights_, f);
|
||||
}
|
||||
|
||||
void SetInputToCellWeights(std::vector<float> f) {
|
||||
SymmetricQuantizeAndPopulate(input_to_cell_weights_, f);
|
||||
}
|
||||
|
||||
void SetInputToOutputWeights(std::vector<float> f) {
|
||||
SymmetricQuantizeAndPopulate(input_to_output_weights_, f);
|
||||
}
|
||||
|
||||
void SetRecurrentToInputWeights(std::vector<float> f) {
|
||||
SymmetricQuantizeAndPopulate(recurrent_to_input_weights_, f);
|
||||
}
|
||||
|
||||
void SetRecurrentToForgetWeights(std::vector<float> f) {
|
||||
SymmetricQuantizeAndPopulate(recurrent_to_forget_weights_, f);
|
||||
}
|
||||
|
||||
void SetRecurrentToCellWeights(std::vector<float> f) {
|
||||
SymmetricQuantizeAndPopulate(recurrent_to_cell_weights_, f);
|
||||
}
|
||||
|
||||
void SetRecurrentToOutputWeights(std::vector<float> f) {
|
||||
SymmetricQuantizeAndPopulate(recurrent_to_output_weights_, f);
|
||||
}
|
||||
|
||||
void SetCellToInputWeights(std::vector<float> f) {
|
||||
SymmetricQuantizeAndPopulate(cell_to_input_weights_, f);
|
||||
}
|
||||
|
||||
void SetCellToForgetWeights(std::vector<float> f) {
|
||||
SymmetricQuantizeAndPopulate(cell_to_forget_weights_, f);
|
||||
}
|
||||
|
||||
void SetCellToOutputWeights(std::vector<float> f) {
|
||||
SymmetricQuantizeAndPopulate(cell_to_output_weights_, f);
|
||||
}
|
||||
|
||||
void SetInputLayerNormCoefficients(std::vector<float> f) {
|
||||
PopulateTensor(input_layer_norm_coefficients_, f);
|
||||
}
|
||||
|
||||
void SetForgetLayerNormCoefficients(std::vector<float> f) {
|
||||
PopulateTensor(forget_layer_norm_coefficients_, f);
|
||||
}
|
||||
|
||||
void SetCellLayerNormCoefficients(std::vector<float> f) {
|
||||
PopulateTensor(cell_layer_norm_coefficients_, f);
|
||||
}
|
||||
|
||||
void SetOutputLayerNormCoefficients(std::vector<float> f) {
|
||||
PopulateTensor(output_layer_norm_coefficients_, f);
|
||||
}
|
||||
|
||||
void SetProjectionWeights(std::vector<float> f) {
|
||||
SymmetricQuantizeAndPopulate(projection_weights_, f);
|
||||
}
|
||||
};
|
||||
|
||||
class BaseLayerNormLstmTest : public ::testing::Test {
|
||||
protected:
|
||||
// Weights of the Layer Norm LSTM model. Some are optional.
|
||||
std::vector<float> input_to_input_weights_;
|
||||
std::vector<float> input_to_cell_weights_;
|
||||
std::vector<float> input_to_forget_weights_;
|
||||
std::vector<float> input_to_output_weights_;
|
||||
std::vector<float> input_gate_bias_;
|
||||
std::vector<float> cell_gate_bias_;
|
||||
std::vector<float> forget_gate_bias_;
|
||||
std::vector<float> output_gate_bias_;
|
||||
std::vector<float> recurrent_to_input_weights_;
|
||||
std::vector<float> recurrent_to_cell_weights_;
|
||||
std::vector<float> recurrent_to_forget_weights_;
|
||||
std::vector<float> recurrent_to_output_weights_;
|
||||
std::vector<float> cell_to_input_weights_;
|
||||
std::vector<float> cell_to_forget_weights_;
|
||||
std::vector<float> cell_to_output_weights_;
|
||||
std::vector<float> projection_weights_;
|
||||
std::vector<float> input_layer_norm_coefficients_;
|
||||
std::vector<float> forget_layer_norm_coefficients_;
|
||||
std::vector<float> cell_layer_norm_coefficients_;
|
||||
std::vector<float> output_layer_norm_coefficients_;
|
||||
|
||||
// Layer Norm LSTM input is stored as num_batch x num_inputs vector.
|
||||
std::vector<std::vector<float>> layer_norm_lstm_input_;
|
||||
|
||||
// Compares output up to tolerance to the result of the layer_norm_lstm given
|
||||
// the input.
|
||||
void VerifyGoldens(const std::vector<std::vector<float>>& input,
|
||||
const std::vector<std::vector<float>>& output,
|
||||
LayerNormLSTMOpModel* layer_norm_lstm,
|
||||
float tolerance = 1e-5) {
|
||||
const int num_batches = input.size();
|
||||
EXPECT_GT(num_batches, 0);
|
||||
const int num_inputs = layer_norm_lstm->num_inputs();
|
||||
EXPECT_GT(num_inputs, 0);
|
||||
const int input_sequence_size = input[0].size() / num_inputs;
|
||||
EXPECT_GT(input_sequence_size, 0);
|
||||
for (int i = 0; i < input_sequence_size; ++i) {
|
||||
for (int b = 0; b < num_batches; ++b) {
|
||||
const float* batch_start = input[b].data() + i * num_inputs;
|
||||
const float* batch_end = batch_start + num_inputs;
|
||||
|
||||
layer_norm_lstm->SetInput(b * layer_norm_lstm->num_inputs(),
|
||||
batch_start, batch_end);
|
||||
}
|
||||
|
||||
layer_norm_lstm->Invoke();
|
||||
|
||||
const int num_outputs = layer_norm_lstm->num_outputs();
|
||||
std::vector<float> expected;
|
||||
for (int b = 0; b < num_batches; ++b) {
|
||||
const float* golden_start_batch = output[b].data() + i * num_outputs;
|
||||
const float* golden_end_batch = golden_start_batch + num_outputs;
|
||||
expected.insert(expected.end(), golden_start_batch, golden_end_batch);
|
||||
}
|
||||
EXPECT_THAT(layer_norm_lstm->GetOutput(),
|
||||
ElementsAreArray(ArrayFloatNear(expected, tolerance)));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class NoCifgPeepholeProjectionNoClippingLayerNormLstmTest
|
||||
: public BaseLayerNormLstmTest {
|
||||
void SetUp() override {
|
||||
input_to_input_weights_ = {0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2,
|
||||
0.3, -0.4, 0.5, -0.8, 0.7, -0.6, 0.5,
|
||||
-0.4, -0.5, -0.4, -0.3, -0.2, -0.1};
|
||||
|
||||
input_to_forget_weights_ = {-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2,
|
||||
-0.4, 0.3, -0.8, -0.4, 0.3, -0.5, -0.4,
|
||||
-0.6, 0.3, -0.4, -0.6, -0.5, -0.5};
|
||||
|
||||
input_to_cell_weights_ = {-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2,
|
||||
-0.3, -0.2, -0.6, 0.6, -0.1, -0.4, -0.3,
|
||||
-0.7, 0.7, -0.9, -0.5, 0.8, 0.6};
|
||||
|
||||
input_to_output_weights_ = {-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3,
|
||||
-0.3, -0.8, -0.2, 0.6, -0.2, 0.4, -0.7,
|
||||
-0.3, -0.5, 0.1, 0.5, -0.6, -0.4};
|
||||
|
||||
input_gate_bias_ = {0.03, 0.15, 0.22, 0.38};
|
||||
|
||||
forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1};
|
||||
|
||||
cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08};
|
||||
|
||||
output_gate_bias_ = {0.05, -0.01, 0.2, 0.1};
|
||||
|
||||
recurrent_to_input_weights_ = {-0.2, -0.3, 0.4, 0.1, -0.5, 0.9,
|
||||
-0.2, -0.3, -0.7, 0.05, -0.2, -0.6};
|
||||
|
||||
recurrent_to_cell_weights_ = {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08,
|
||||
-0.2, 0.3, 0.8, -0.6, -0.1, 0.2};
|
||||
|
||||
recurrent_to_forget_weights_ = {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4,
|
||||
0.9, 0.3, -0.1, 0.2, 0.5, 0.2};
|
||||
|
||||
recurrent_to_output_weights_ = {0.3, -0.1, 0.1, -0.2, -0.5, -0.7,
|
||||
-0.2, -0.6, -0.1, -0.4, -0.7, -0.2};
|
||||
|
||||
cell_to_input_weights_ = {0.05, 0.1, 0.25, 0.15};
|
||||
|
||||
cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03};
|
||||
|
||||
cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05};
|
||||
|
||||
input_layer_norm_coefficients_ = {0.1, 0.2, 0.3, 0.5};
|
||||
forget_layer_norm_coefficients_ = {0.2, 0.2, 0.4, 0.3};
|
||||
cell_layer_norm_coefficients_ = {0.7, 0.2, 0.3, 0.8};
|
||||
output_layer_norm_coefficients_ = {0.6, 0.2, 0.2, 0.5};
|
||||
|
||||
projection_weights_ = {-0.1, 0.2, 0.01, -0.2, 0.1, 0.5,
|
||||
0.3, 0.08, 0.07, 0.2, -0.4, 0.2};
|
||||
|
||||
layer_norm_lstm_input_ = {
|
||||
{// Batch0: 3 (input_sequence_size) * 5 (n_input)
|
||||
0.7, 0.8, 0.1, 0.2, 0.3, // seq 0
|
||||
0.8, 0.1, 0.2, 0.4, 0.5, // seq 1
|
||||
0.2, 0.7, 0.7, 0.1, 0.7}, // seq 2
|
||||
|
||||
{// Batch1: 3 (input_sequence_size) * 5 (n_input)
|
||||
0.3, 0.2, 0.9, 0.8, 0.1, // seq 0
|
||||
0.1, 0.5, 0.2, 0.4, 0.2, // seq 1
|
||||
0.6, 0.9, 0.2, 0.5, 0.7}, // seq 2
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
||||
LayerNormLstmBlackBoxTest) {
|
||||
const int n_batch = 2;
|
||||
const int n_input = 5;
|
||||
const int n_cell = 4;
|
||||
const int n_output = 3;
|
||||
const float ceil_clip = 0.0;
|
||||
const float proj_clip = 0.0;
|
||||
|
||||
LayerNormLSTMOpModel layer_norm_lstm(
|
||||
n_batch, n_input, n_cell, n_output,
|
||||
/*use_cifg=*/false, /*use_peephole=*/true,
|
||||
/*use_projection_weights=*/true,
|
||||
/*use_projection_bias=*/false, ceil_clip, proj_clip,
|
||||
{
|
||||
{n_batch, n_input}, // input tensor
|
||||
|
||||
{n_cell, n_input}, // input_to_input_weight tensor
|
||||
{n_cell, n_input}, // input_to_forget_weight tensor
|
||||
{n_cell, n_input}, // input_to_cell_weight tensor
|
||||
{n_cell, n_input}, // input_to_output_weight tensor
|
||||
|
||||
{n_cell, n_output}, // recurrent_to_input_weight tensor
|
||||
{n_cell, n_output}, // recurrent_to_forget_weight tensor
|
||||
{n_cell, n_output}, // recurrent_to_cell_weight tensor
|
||||
{n_cell, n_output}, // recurrent_to_output_weight tensor
|
||||
|
||||
{n_cell}, // cell_to_input_weight tensor
|
||||
{n_cell}, // cell_to_forget_weight tensor
|
||||
{n_cell}, // cell_to_output_weight tensor
|
||||
|
||||
{n_cell}, // input_gate_bias tensor
|
||||
{n_cell}, // forget_gate_bias tensor
|
||||
{n_cell}, // cell_bias tensor
|
||||
{n_cell}, // output_gate_bias tensor
|
||||
|
||||
{n_output, n_cell}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
|
||||
{n_batch, n_output}, // activation_state tensor
|
||||
{n_batch, n_cell}, // cell_state tensor
|
||||
|
||||
{n_cell}, // input_layer_norm_coefficient tensor
|
||||
{n_cell}, // forget_layer_norm_coefficient tensor
|
||||
{n_cell}, // cell_layer_norm_coefficient tensor
|
||||
{n_cell}, // output_layer_norm_coefficient tensor
|
||||
});
|
||||
|
||||
layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_);
|
||||
layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||
layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
||||
layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
|
||||
|
||||
layer_norm_lstm.SetInputGateBias(input_gate_bias_);
|
||||
layer_norm_lstm.SetCellBias(cell_gate_bias_);
|
||||
layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
|
||||
layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
|
||||
|
||||
layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
|
||||
layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
||||
layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
||||
layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
||||
|
||||
layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
|
||||
layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
|
||||
layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
|
||||
|
||||
layer_norm_lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients_);
|
||||
layer_norm_lstm.SetForgetLayerNormCoefficients(
|
||||
forget_layer_norm_coefficients_);
|
||||
layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
|
||||
layer_norm_lstm.SetOutputLayerNormCoefficients(
|
||||
output_layer_norm_coefficients_);
|
||||
|
||||
layer_norm_lstm.SetProjectionWeights(projection_weights_);
|
||||
|
||||
// Verify the final output.
|
||||
const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
|
||||
{
|
||||
// Batch0: 3 (input_sequence_size) * 3 (n_output)
|
||||
0.0244077, 0.128027, -0.00170918, // seq 0
|
||||
0.0137642, 0.140751, 0.0395835, // seq 1
|
||||
-0.00459231, 0.155278, 0.0837377, // seq 2
|
||||
},
|
||||
{
|
||||
// Batch1: 3 (input_sequence_size) * 3 (n_output)
|
||||
-0.00692428, 0.0848741, 0.063445, // seq 0
|
||||
-0.00403912, 0.139963, 0.072681, // seq 1
|
||||
0.00752706, 0.161903, 0.0561371, // seq 2
|
||||
}};
|
||||
|
||||
VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
|
||||
&layer_norm_lstm);
|
||||
}
|
||||
|
||||
TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
||||
HybridLayerNormLstmBlackBoxTest) {
|
||||
const int n_batch = 2;
|
||||
const int n_input = 5;
|
||||
const int n_cell = 4;
|
||||
const int n_output = 3;
|
||||
const float ceil_clip = 0.0;
|
||||
const float proj_clip = 0.0;
|
||||
|
||||
HybridLayerNormLSTMOpModel layer_norm_lstm(
|
||||
n_batch, n_input, n_cell, n_output,
|
||||
/*use_cifg=*/false, /*use_peephole=*/true,
|
||||
/*use_projection_weights=*/true,
|
||||
/*use_projection_bias=*/false, ceil_clip, proj_clip,
|
||||
{
|
||||
{n_batch, n_input}, // input tensor
|
||||
|
||||
{n_cell, n_input}, // input_to_input_weight tensor
|
||||
{n_cell, n_input}, // input_to_forget_weight tensor
|
||||
{n_cell, n_input}, // input_to_cell_weight tensor
|
||||
{n_cell, n_input}, // input_to_output_weight tensor
|
||||
|
||||
{n_cell, n_output}, // recurrent_to_input_weight tensor
|
||||
{n_cell, n_output}, // recurrent_to_forget_weight tensor
|
||||
{n_cell, n_output}, // recurrent_to_cell_weight tensor
|
||||
{n_cell, n_output}, // recurrent_to_output_weight tensor
|
||||
|
||||
{n_cell}, // cell_to_input_weight tensor
|
||||
{n_cell}, // cell_to_forget_weight tensor
|
||||
{n_cell}, // cell_to_output_weight tensor
|
||||
|
||||
{n_cell}, // input_gate_bias tensor
|
||||
{n_cell}, // forget_gate_bias tensor
|
||||
{n_cell}, // cell_bias tensor
|
||||
{n_cell}, // output_gate_bias tensor
|
||||
|
||||
{n_output, n_cell}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
|
||||
{n_batch, n_output}, // activation_state tensor
|
||||
{n_batch, n_cell}, // cell_state tensor
|
||||
|
||||
{n_cell}, // input_layer_norm_coefficient tensor
|
||||
{n_cell}, // forget_layer_norm_coefficient tensor
|
||||
{n_cell}, // cell_layer_norm_coefficient tensor
|
||||
{n_cell}, // output_layer_norm_coefficient tensor
|
||||
});
|
||||
|
||||
layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_);
|
||||
layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||
layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
||||
layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
|
||||
|
||||
layer_norm_lstm.SetInputGateBias(input_gate_bias_);
|
||||
layer_norm_lstm.SetCellBias(cell_gate_bias_);
|
||||
layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
|
||||
layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
|
||||
|
||||
layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
|
||||
layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
||||
layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
||||
layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
||||
|
||||
layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
|
||||
layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
|
||||
layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
|
||||
|
||||
layer_norm_lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients_);
|
||||
layer_norm_lstm.SetForgetLayerNormCoefficients(
|
||||
forget_layer_norm_coefficients_);
|
||||
layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
|
||||
layer_norm_lstm.SetOutputLayerNormCoefficients(
|
||||
output_layer_norm_coefficients_);
|
||||
|
||||
layer_norm_lstm.SetProjectionWeights(projection_weights_);
|
||||
|
||||
const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
|
||||
{
|
||||
// Batch0: 3 (input_sequence_size) * 3 (n_output)
|
||||
0.0244576, 0.127847, -0.00181765, // seq 0
|
||||
0.0137518, 0.140892, 0.0402234, // seq 1
|
||||
-0.0048839, 0.155096, 0.0840309, // seq 2
|
||||
},
|
||||
{
|
||||
// Batch1: 3 (input_sequence_size) * 3 (n_output)
|
||||
-0.00728636, 0.0843957, 0.0634786, // seq 0
|
||||
-0.00448382, 0.139278, 0.0737372, // seq 1
|
||||
0.00734616, 0.161793, 0.0560238, // seq 2
|
||||
}};
|
||||
|
||||
VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
|
||||
&layer_norm_lstm);
|
||||
}
|
||||
|
||||
class CifgPeepholeProjectionNoClippingLayerNormLstmTest
|
||||
: public BaseLayerNormLstmTest {
|
||||
void SetUp() override {
|
||||
input_to_forget_weights_ = {-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2,
|
||||
-0.4, 0.3, -0.8, -0.4, 0.3, -0.5, -0.4,
|
||||
-0.6, 0.3, -0.4, -0.6, -0.5, -0.5};
|
||||
input_to_cell_weights_ = {-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2,
|
||||
-0.3, -0.2, -0.6, 0.6, -0.1, -0.4, -0.3,
|
||||
-0.7, 0.7, -0.9, -0.5, 0.8, 0.6};
|
||||
input_to_output_weights_ = {-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3,
|
||||
-0.3, -0.8, -0.2, 0.6, -0.2, 0.4, -0.7,
|
||||
-0.3, -0.5, 0.1, 0.5, -0.6, -0.4};
|
||||
|
||||
forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1};
|
||||
cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08};
|
||||
output_gate_bias_ = {0.05, -0.01, 0.2, 0.1};
|
||||
|
||||
recurrent_to_cell_weights_ = {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08,
|
||||
-0.2, 0.3, 0.8, -0.6, -0.1, 0.2};
|
||||
recurrent_to_forget_weights_ = {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4,
|
||||
0.9, 0.3, -0.1, 0.2, 0.5, 0.2};
|
||||
recurrent_to_output_weights_ = {0.3, -0.1, 0.1, -0.2, -0.5, -0.7,
|
||||
-0.2, -0.6, -0.1, -0.4, -0.7, -0.2};
|
||||
|
||||
cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03};
|
||||
cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05};
|
||||
|
||||
forget_layer_norm_coefficients_ = {0.2, 0.2, 0.4, 0.3};
|
||||
cell_layer_norm_coefficients_ = {0.7, 0.2, 0.3, 0.8};
|
||||
output_layer_norm_coefficients_ = {0.6, 0.2, 0.2, 0.5};
|
||||
projection_weights_ = {-0.1, 0.2, 0.01, -0.2, 0.1, 0.5,
|
||||
0.3, 0.08, 0.07, 0.2, -0.4, 0.2};
|
||||
|
||||
layer_norm_lstm_input_ = {
|
||||
{// Batch0: 3 (input_sequence_size) * 5 (n_input)
|
||||
0.7, 0.8, 0.1, 0.2, 0.3, // seq 0
|
||||
0.8, 0.1, 0.2, 0.4, 0.5, // seq 1
|
||||
0.2, 0.7, 0.7, 0.1, 0.7}, // seq 2
|
||||
|
||||
{// Batch1: 3 (input_sequence_size) * 5 (n_input)
|
||||
0.3, 0.2, 0.9, 0.8, 0.1, // seq 0
|
||||
0.1, 0.5, 0.2, 0.4, 0.2, // seq 1
|
||||
0.6, 0.9, 0.2, 0.5, 0.7}, // seq 2
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
|
||||
LayerNormLstmBlackBoxTest) {
|
||||
const int n_batch = 2;
|
||||
const int n_input = 5;
|
||||
const int n_cell = 4;
|
||||
const int n_output = 3;
|
||||
const float ceil_clip = 0.0;
|
||||
const float proj_clip = 0.0;
|
||||
|
||||
LayerNormLSTMOpModel layer_norm_lstm(
|
||||
n_batch, n_input, n_cell, n_output,
|
||||
/*use_cifg=*/true, /*use_peephole=*/true,
|
||||
/*use_projection_weights=*/true,
|
||||
/*use_projection_bias=*/false, ceil_clip, proj_clip,
|
||||
{
|
||||
{n_batch, n_input}, // input tensor
|
||||
|
||||
{0, 0}, // input_to_input_weight tensor
|
||||
{n_cell, n_input}, // input_to_forget_weight tensor
|
||||
{n_cell, n_input}, // input_to_cell_weight tensor
|
||||
{n_cell, n_input}, // input_to_output_weight tensor
|
||||
|
||||
{0, 0}, // recurrent_to_input_weight tensor
|
||||
{n_cell, n_output}, // recurrent_to_forget_weight tensor
|
||||
{n_cell, n_output}, // recurrent_to_cell_weight tensor
|
||||
{n_cell, n_output}, // recurrent_to_output_weight tensor
|
||||
|
||||
{0}, // cell_to_input_weight tensor
|
||||
{n_cell}, // cell_to_forget_weight tensor
|
||||
{n_cell}, // cell_to_output_weight tensor
|
||||
|
||||
{0}, // input_gate_bias tensor
|
||||
{n_cell}, // forget_gate_bias tensor
|
||||
{n_cell}, // cell_bias tensor
|
||||
{n_cell}, // output_gate_bias tensor
|
||||
|
||||
{n_output, n_cell}, // projection_weight tensor
|
||||
{0}, // projection_bias tensor
|
||||
|
||||
{n_batch, n_output}, // activation_state tensor
|
||||
{n_batch, n_cell}, // cell_state tensor
|
||||
|
||||
{0}, // input_layer_norm_coefficient tensor
|
||||
{n_cell}, // forget_layer_norm_coefficient tensor
|
||||
{n_cell}, // cell_layer_norm_coefficient tensor
|
||||
{n_cell}, // output_layer_norm_coefficient tensor
|
||||
});
|
||||
|
||||
layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
|
||||
layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
|
||||
layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
|
||||
|
||||
layer_norm_lstm.SetCellBias(cell_gate_bias_);
|
||||
layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
|
||||
layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
|
||||
|
||||
layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
|
||||
layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
|
||||
layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
|
||||
|
||||
layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
|
||||
layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
|
||||
|
||||
layer_norm_lstm.SetForgetLayerNormCoefficients(
|
||||
forget_layer_norm_coefficients_);
|
||||
layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
|
||||
layer_norm_lstm.SetOutputLayerNormCoefficients(
|
||||
output_layer_norm_coefficients_);
|
||||
|
||||
layer_norm_lstm.SetProjectionWeights(projection_weights_);
|
||||
|
||||
// Verify the final output.
|
||||
const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
|
||||
{
|
||||
// Batch0: 3 (input_sequence_size) * 3 (n_output)
|
||||
0.02129706, 0.140816242, 0.0112733059, // seq 0
|
||||
0.0132302344, 0.152308047, 0.0346313119, // seq 1
|
||||
-0.0123688057, 0.165790111, 0.0893077999, // seq 2
|
||||
},
|
||||
{
|
||||
// Batch1: 3 (input_sequence_size) * 3 (n_output)
|
||||
-0.0226350538, 0.0916948169, 0.0769175813, // seq 0
|
||||
-0.0269966982, 0.149707705, 0.094149217, // seq 1
|
||||
-0.0103429332, 0.173016444, 0.0720508844, // seq 2
|
||||
}};
|
||||
|
||||
VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
|
||||
&layer_norm_lstm);
|
||||
}
|
||||
|
||||
TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
       HybridLayerNormLstmBlackBoxTest) {
  // Hybrid (quantized-weight) variant of the CIFG + peephole + projection
  // layer-norm LSTM black-box test; golden values allow for quantization
  // error relative to the float test.
  const int n_batch = 2;
  const int n_input = 5;
  const int n_cell = 4;
  const int n_output = 3;
  const float cell_clip = 0.0;  // 0.0 disables cell-state clipping.
  const float proj_clip = 0.0;  // 0.0 disables projection clipping.

  HybridLayerNormLSTMOpModel lstm(
      n_batch, n_input, n_cell, n_output,
      /*use_cifg=*/true, /*use_peephole=*/true,
      /*use_projection_weights=*/true,
      /*use_projection_bias=*/false, cell_clip, proj_clip,
      {
          {n_batch, n_input},  // input tensor

          {0, 0},             // input_to_input_weight tensor (empty: CIFG)
          {n_cell, n_input},  // input_to_forget_weight tensor
          {n_cell, n_input},  // input_to_cell_weight tensor
          {n_cell, n_input},  // input_to_output_weight tensor

          {0, 0},              // recurrent_to_input_weight tensor (empty: CIFG)
          {n_cell, n_output},  // recurrent_to_forget_weight tensor
          {n_cell, n_output},  // recurrent_to_cell_weight tensor
          {n_cell, n_output},  // recurrent_to_output_weight tensor

          {0},       // cell_to_input_weight tensor (empty: CIFG)
          {n_cell},  // cell_to_forget_weight tensor
          {n_cell},  // cell_to_output_weight tensor

          {0},       // input_gate_bias tensor (empty: CIFG)
          {n_cell},  // forget_gate_bias tensor
          {n_cell},  // cell_bias tensor
          {n_cell},  // output_gate_bias tensor

          {n_output, n_cell},  // projection_weight tensor
          {0},                 // projection_bias tensor (unused)

          {n_batch, n_output},  // activation_state tensor
          {n_batch, n_cell},    // cell_state tensor

          {0},       // input_layer_norm_coefficient tensor (empty: CIFG)
          {n_cell},  // forget_layer_norm_coefficient tensor
          {n_cell},  // cell_layer_norm_coefficient tensor
          {n_cell},  // output_layer_norm_coefficient tensor
      });

  // Input-to-gate weights.
  lstm.SetInputToCellWeights(input_to_cell_weights_);
  lstm.SetInputToForgetWeights(input_to_forget_weights_);
  lstm.SetInputToOutputWeights(input_to_output_weights_);

  // Gate biases.
  lstm.SetCellBias(cell_gate_bias_);
  lstm.SetForgetGateBias(forget_gate_bias_);
  lstm.SetOutputGateBias(output_gate_bias_);

  // Recurrent weights.
  lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
  lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
  lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);

  // Peephole connections.
  lstm.SetCellToForgetWeights(cell_to_forget_weights_);
  lstm.SetCellToOutputWeights(cell_to_output_weights_);

  // Layer normalization coefficients.
  lstm.SetForgetLayerNormCoefficients(forget_layer_norm_coefficients_);
  lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
  lstm.SetOutputLayerNormCoefficients(output_layer_norm_coefficients_);

  // Output projection.
  lstm.SetProjectionWeights(projection_weights_);

  // Verify the final output.
  const std::vector<std::vector<float>> expected_output = {
      {
          // Batch0: 3 (input_sequence_size) * 3 (n_output)
          0.0212250091, 0.140474007, 0.0115012666,   // seq 0
          0.0130806509, 0.152660668, 0.0347516984,   // seq 1
          -0.0124010444, 0.166042402, 0.0898982584,  // seq 2
      },
      {
          // Batch1: 3 (input_sequence_size) * 3 (n_output)
          -0.0228835996, 0.0917588323, 0.0778886303,  // seq 0
          -0.0275101066, 0.148769245, 0.0938384682,   // seq 1
          -0.0103605557, 0.172605693, 0.0728750974,   // seq 2
      }};

  VerifyGoldens(layer_norm_lstm_input_, expected_output, &lstm);
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
|
@ -502,6 +502,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
recurrent_to_input_weights, recurrent_to_forget_weights,
|
||||
recurrent_to_cell_weights, recurrent_to_output_weights,
|
||||
cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
|
||||
/*input_layer_norm_coefficients=*/nullptr,
|
||||
/*forget_layer_norm_coefficients=*/nullptr,
|
||||
/*cell_layer_norm_coefficients=*/nullptr,
|
||||
/*output_layer_norm_coefficients=*/nullptr,
|
||||
/*aux_input=*/nullptr,
|
||||
/*aux_input_to_input_weights=*/nullptr,
|
||||
/*aux_input_to_forget_weights=*/nullptr,
|
||||
@ -529,6 +533,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
recurrent_to_input_weights, recurrent_to_forget_weights,
|
||||
recurrent_to_cell_weights, recurrent_to_output_weights,
|
||||
cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
|
||||
/*input_layer_norm_coefficients=*/nullptr,
|
||||
/*forget_layer_norm_coefficients=*/nullptr,
|
||||
/*cell_layer_norm_coefficients=*/nullptr,
|
||||
/*output_layer_norm_coefficients=*/nullptr,
|
||||
/*aux_input=*/nullptr,
|
||||
/*aux_input_to_input_weights=*/nullptr,
|
||||
/*aux_input_to_forget_weights=*/nullptr,
|
||||
|
Loading…
Reference in New Issue
Block a user