Nit: rename use_layer_norm_lstm (and is_layer_norm_lstm) to use_layer_norm, to be consistent with other similar variables (e.g. use_cifg, use_peephole).
PiperOrigin-RevId: 292043066 Change-Id: Ic77c7f712432cb717eb477e98d8f4d16a6dac0be
This commit is contained in:
parent
d049992117
commit
c3f78d1c1c
@ -50,7 +50,7 @@ struct OpData {
|
||||
TfLiteLSTMKernelType kernel_type;
|
||||
|
||||
// If the lstm is layer norm.
|
||||
bool use_layer_norm_lstm;
|
||||
bool use_layer_norm;
|
||||
|
||||
// These fields are only used by full kernel.
|
||||
int scratch_tensor_index;
|
||||
@ -92,7 +92,7 @@ TfLiteStatus PopulateQuantizedLstmParams(
|
||||
|
||||
// Calculate effective scales.
|
||||
OpData* op_data = static_cast<OpData*>(node->user_data);
|
||||
const bool use_layer_norm_lstm = op_data->use_layer_norm_lstm;
|
||||
const bool use_layer_norm = op_data->use_layer_norm;
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
|
||||
@ -147,7 +147,7 @@ TfLiteStatus PopulateQuantizedLstmParams(
|
||||
std::vector<float> intermediate_scale;
|
||||
std::vector<int32> intermediate_zp;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
if (use_layer_norm_lstm) {
|
||||
if (use_layer_norm) {
|
||||
const TfLiteTensor* intermediate = GetIntermediates(context, node, i);
|
||||
auto* params = static_cast<TfLiteAffineQuantization*>(
|
||||
intermediate->quantization.params);
|
||||
@ -218,7 +218,7 @@ TfLiteStatus PopulateQuantizedLstmParams(
|
||||
cell_to_output_weight_scale = cell_to_output_weights->params.scale;
|
||||
}
|
||||
|
||||
if (use_layer_norm_lstm) {
|
||||
if (use_layer_norm) {
|
||||
if (!use_cifg) {
|
||||
layer_norm_input_scale = input_layer_norm_coefficients->params.scale;
|
||||
}
|
||||
@ -381,8 +381,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
|
||||
TfLiteNode* node, int n_input,
|
||||
int n_output, int n_cell,
|
||||
bool use_layer_norm_lstm,
|
||||
bool is_integer) {
|
||||
bool use_layer_norm, bool is_integer) {
|
||||
const auto* params = static_cast<TfLiteLSTMParams*>(node->builtin_data);
|
||||
|
||||
// Making sure clipping parameters have valid values.
|
||||
@ -574,7 +573,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
|
||||
((projection_weights != nullptr) || (projection_bias == nullptr));
|
||||
TF_LITE_ENSURE(context, projection_tensors_consistent == true);
|
||||
|
||||
if (use_layer_norm_lstm) {
|
||||
if (use_layer_norm) {
|
||||
const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
|
||||
context, node, kInputLayerNormCoefficientsTensor);
|
||||
if (use_cifg) {
|
||||
@ -714,7 +713,7 @@ TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context,
|
||||
// When there is layer normalization, the gate bias does not apply to matmul
|
||||
// directly:
|
||||
// y = ln(w * x + w * r + w * c) + b.
|
||||
const bool is_layer_norm = op_data->use_layer_norm_lstm;
|
||||
const bool is_layer_norm = op_data->use_layer_norm;
|
||||
|
||||
// Forget gate.
|
||||
const TfLiteTensor* forget_gate_bias =
|
||||
@ -801,13 +800,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
|
||||
context, node, kForgetLayerNormCoefficientsTensor);
|
||||
if (forget_layer_norm_coefficients == nullptr) {
|
||||
op_data->use_layer_norm_lstm = false;
|
||||
op_data->use_layer_norm = false;
|
||||
} else {
|
||||
op_data->use_layer_norm_lstm = true;
|
||||
op_data->use_layer_norm = true;
|
||||
}
|
||||
} else if (node->inputs->size == 20) {
|
||||
// This is deprecated and is only kept here for backward compatibility.
|
||||
op_data->use_layer_norm_lstm = false;
|
||||
op_data->use_layer_norm = false;
|
||||
} else {
|
||||
context->ReportError(
|
||||
context, "The LSTM Full kernel expects 20 or 24 inputs. Got %d inputs",
|
||||
@ -815,7 +814,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
const bool use_layer_norm_lstm = op_data->use_layer_norm_lstm;
|
||||
const bool use_layer_norm = op_data->use_layer_norm;
|
||||
|
||||
// Inferring batch size, number of outputs and number of cells from the
|
||||
// input tensors.
|
||||
@ -839,9 +838,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
const int n_output = recurrent_to_output_weights->dims->data[1];
|
||||
|
||||
// Check that input tensor dimensions match each other.
|
||||
TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(
|
||||
context, node, n_input, n_output, n_cell,
|
||||
use_layer_norm_lstm, is_integer));
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, CheckInputTensorDimensions(context, node, n_input, n_output,
|
||||
n_cell, use_layer_norm, is_integer));
|
||||
|
||||
// Get the pointer to output, activation_state and cell_state tensors.
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
@ -151,12 +151,11 @@ inline void LstmStepFloat(
|
||||
// check the existence of only one to get the condition.
|
||||
const bool use_cifg = (input_to_input_weights_ptr == nullptr);
|
||||
const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
|
||||
const bool use_layer_norm_lstm =
|
||||
(forget_layer_norm_coefficients_ptr != nullptr);
|
||||
const bool use_layer_norm = (forget_layer_norm_coefficients_ptr != nullptr);
|
||||
|
||||
// Initialize scratch buffers with bias for regular lstm or initialize with
|
||||
// zero for layer norm lstm.
|
||||
if (use_layer_norm_lstm) {
|
||||
if (use_layer_norm) {
|
||||
if (!use_cifg) {
|
||||
std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
|
||||
}
|
||||
@ -243,7 +242,7 @@ inline void LstmStepFloat(
|
||||
cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
|
||||
input_gate_scratch);
|
||||
}
|
||||
if (use_layer_norm_lstm) {
|
||||
if (use_layer_norm) {
|
||||
tensor_utils::MeanStddevNormalization(
|
||||
input_gate_scratch, input_gate_scratch, n_cell, n_batch);
|
||||
tensor_utils::VectorBatchVectorCwiseProduct(
|
||||
@ -262,7 +261,7 @@ inline void LstmStepFloat(
|
||||
cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
|
||||
forget_gate_scratch);
|
||||
}
|
||||
if (use_layer_norm_lstm) {
|
||||
if (use_layer_norm) {
|
||||
tensor_utils::MeanStddevNormalization(forget_gate_scratch,
|
||||
forget_gate_scratch, n_cell, n_batch);
|
||||
tensor_utils::VectorBatchVectorCwiseProduct(
|
||||
@ -277,7 +276,7 @@ inline void LstmStepFloat(
|
||||
// For each batch and cell: update the cell.
|
||||
tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
|
||||
n_batch * n_cell, cell_state_ptr);
|
||||
if (use_layer_norm_lstm) {
|
||||
if (use_layer_norm) {
|
||||
tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
|
||||
n_batch);
|
||||
tensor_utils::VectorBatchVectorCwiseProduct(
|
||||
@ -308,7 +307,7 @@ inline void LstmStepFloat(
|
||||
cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
|
||||
output_gate_scratch);
|
||||
}
|
||||
if (use_layer_norm_lstm) {
|
||||
if (use_layer_norm) {
|
||||
tensor_utils::MeanStddevNormalization(output_gate_scratch,
|
||||
output_gate_scratch, n_cell, n_batch);
|
||||
tensor_utils::VectorBatchVectorCwiseProduct(
|
||||
@ -485,11 +484,11 @@ inline void LstmStepHybrid(
|
||||
// can check the existence of only one to get the condition.
|
||||
const bool use_cifg = (input_to_input_weights_ptr == nullptr);
|
||||
const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
|
||||
const bool use_layer_norm_lstm =
|
||||
(forget_layer_norm_coefficients_ptr != nullptr);
|
||||
const bool use_layer_norm = (forget_layer_norm_coefficients_ptr != nullptr);
|
||||
|
||||
// Initialize scratch buffers with bias.
|
||||
if (use_layer_norm_lstm) {
|
||||
// Initialize scratch buffers with bias for regular lstm or initialize with
|
||||
// zero for layer norm lstm.
|
||||
if (use_layer_norm) {
|
||||
if (!use_cifg) {
|
||||
std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
|
||||
}
|
||||
@ -671,7 +670,7 @@ inline void LstmStepHybrid(
|
||||
recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
|
||||
input_gate_scratch);
|
||||
}
|
||||
if (use_layer_norm_lstm) {
|
||||
if (use_layer_norm) {
|
||||
tensor_utils::MeanStddevNormalization(
|
||||
input_gate_scratch, input_gate_scratch, n_cell, n_batch);
|
||||
tensor_utils::VectorBatchVectorCwiseProduct(
|
||||
@ -693,7 +692,7 @@ inline void LstmStepHybrid(
|
||||
recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
|
||||
forget_gate_scratch);
|
||||
}
|
||||
if (use_layer_norm_lstm) {
|
||||
if (use_layer_norm) {
|
||||
tensor_utils::MeanStddevNormalization(forget_gate_scratch,
|
||||
forget_gate_scratch, n_cell, n_batch);
|
||||
tensor_utils::VectorBatchVectorCwiseProduct(
|
||||
@ -708,7 +707,7 @@ inline void LstmStepHybrid(
|
||||
// For each batch and cell: update the cell.
|
||||
tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
|
||||
n_batch * n_cell, cell_state_ptr);
|
||||
if (use_layer_norm_lstm) {
|
||||
if (use_layer_norm) {
|
||||
tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
|
||||
n_batch);
|
||||
tensor_utils::VectorBatchVectorCwiseProduct(
|
||||
@ -742,7 +741,7 @@ inline void LstmStepHybrid(
|
||||
recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
|
||||
output_gate_scratch);
|
||||
}
|
||||
if (use_layer_norm_lstm) {
|
||||
if (use_layer_norm) {
|
||||
tensor_utils::MeanStddevNormalization(output_gate_scratch,
|
||||
output_gate_scratch, n_cell, n_batch);
|
||||
tensor_utils::VectorBatchVectorCwiseProduct(
|
||||
@ -976,7 +975,7 @@ inline void LstmStepInteger(
|
||||
// Get hyper parameters.
|
||||
const bool use_cifg = (input_to_input_weight_ptr == nullptr);
|
||||
const bool use_peephole = (cell_to_output_weight_ptr != nullptr);
|
||||
const bool use_layer_norm_lstm = (layer_norm_forget_weight_ptr != nullptr);
|
||||
const bool use_layer_norm = (layer_norm_forget_weight_ptr != nullptr);
|
||||
const bool use_projection = (proj_weight_ptr != nullptr);
|
||||
|
||||
// Check for nullptrs.
|
||||
@ -1018,7 +1017,7 @@ inline void LstmStepInteger(
|
||||
scratch_1_ptr);
|
||||
}
|
||||
|
||||
if (use_layer_norm_lstm) {
|
||||
if (use_layer_norm) {
|
||||
tensor_utils::ApplyLayerNorm(
|
||||
scratch_1_ptr, layer_norm_forget_weight_ptr, forget_bias_ptr,
|
||||
layer_norm_forget_scale_a, layer_norm_forget_scale_b,
|
||||
@ -1039,7 +1038,7 @@ inline void LstmStepInteger(
|
||||
effective_recurrent_to_cell_scale_b, n_batch, n_output, n_cell, 0,
|
||||
scratch_5_ptr, scratch_2_ptr, context);
|
||||
|
||||
if (use_layer_norm_lstm) {
|
||||
if (use_layer_norm) {
|
||||
tensor_utils::ApplyLayerNorm(scratch_2_ptr, layer_norm_cell_weight_ptr,
|
||||
cell_bias_ptr, layer_norm_cell_scale_a,
|
||||
layer_norm_cell_scale_b, cell_variance_guard,
|
||||
@ -1069,7 +1068,7 @@ inline void LstmStepInteger(
|
||||
scratch_0_ptr);
|
||||
}
|
||||
|
||||
if (use_layer_norm_lstm) {
|
||||
if (use_layer_norm) {
|
||||
tensor_utils::ApplyLayerNorm(
|
||||
scratch_0_ptr, layer_norm_input_weight_ptr, input_bias_ptr,
|
||||
layer_norm_input_scale_a, layer_norm_input_scale_b,
|
||||
@ -1110,7 +1109,7 @@ inline void LstmStepInteger(
|
||||
scratch_3_ptr);
|
||||
}
|
||||
|
||||
if (use_layer_norm_lstm) {
|
||||
if (use_layer_norm) {
|
||||
tensor_utils::ApplyLayerNorm(
|
||||
scratch_3_ptr, layer_norm_output_weight_ptr, output_bias_ptr,
|
||||
layer_norm_output_scale_a, layer_norm_output_scale_b,
|
||||
|
@ -71,12 +71,11 @@ inline void LstmStepWithAuxInput(
|
||||
// check the existence of only one to get the condition.
|
||||
const bool use_cifg = (input_to_input_weights_ptr == nullptr);
|
||||
const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
|
||||
const bool is_layer_norm_lstm =
|
||||
(forget_layer_norm_coefficients_ptr != nullptr);
|
||||
const bool use_layer_norm = (forget_layer_norm_coefficients_ptr != nullptr);
|
||||
|
||||
// Initialize scratch buffers with bias for regular lstm or initialize with
|
||||
// zero for layer norm lstm.
|
||||
if (is_layer_norm_lstm) {
|
||||
if (use_layer_norm) {
|
||||
if (!use_cifg) {
|
||||
std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
|
||||
}
|
||||
@ -158,7 +157,7 @@ inline void LstmStepWithAuxInput(
|
||||
cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
|
||||
input_gate_scratch);
|
||||
}
|
||||
if (is_layer_norm_lstm) {
|
||||
if (use_layer_norm) {
|
||||
logger->LogTensorValue(intemediate_tensor_indexes[0], input_gate_scratch,
|
||||
n_cell * n_batch, error_reporter);
|
||||
tensor_utils::MeanStddevNormalization(
|
||||
@ -179,7 +178,7 @@ inline void LstmStepWithAuxInput(
|
||||
cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
|
||||
forget_gate_scratch);
|
||||
}
|
||||
if (is_layer_norm_lstm) {
|
||||
if (use_layer_norm) {
|
||||
logger->LogTensorValue(intemediate_tensor_indexes[1], forget_gate_scratch,
|
||||
n_cell * n_batch, error_reporter);
|
||||
tensor_utils::MeanStddevNormalization(forget_gate_scratch,
|
||||
@ -196,7 +195,7 @@ inline void LstmStepWithAuxInput(
|
||||
// For each batch and cell: update the cell.
|
||||
tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
|
||||
n_batch * n_cell, cell_state_ptr);
|
||||
if (is_layer_norm_lstm) {
|
||||
if (use_layer_norm) {
|
||||
logger->LogTensorValue(intemediate_tensor_indexes[2], cell_scratch,
|
||||
n_cell * n_batch, error_reporter);
|
||||
tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
|
||||
@ -229,7 +228,7 @@ inline void LstmStepWithAuxInput(
|
||||
cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
|
||||
output_gate_scratch);
|
||||
}
|
||||
if (is_layer_norm_lstm) {
|
||||
if (use_layer_norm) {
|
||||
logger->LogTensorValue(intemediate_tensor_indexes[3], output_gate_scratch,
|
||||
n_cell * n_batch, error_reporter);
|
||||
tensor_utils::MeanStddevNormalization(output_gate_scratch,
|
||||
@ -483,7 +482,7 @@ struct OpData {
|
||||
TfLiteLSTMKernelType kernel_type;
|
||||
|
||||
// If the lstm is layer norm.
|
||||
bool is_layer_norm_lstm;
|
||||
bool use_layer_norm;
|
||||
|
||||
// These fields are only used by full kernel.
|
||||
int scratch_tensor_index;
|
||||
|
Loading…
x
Reference in New Issue
Block a user