diff --git a/tensorflow/lite/kernels/lstm.cc b/tensorflow/lite/kernels/lstm.cc
index c97745f77e8..8fd0e9bdc98 100644
--- a/tensorflow/lite/kernels/lstm.cc
+++ b/tensorflow/lite/kernels/lstm.cc
@@ -50,7 +50,7 @@ struct OpData {
   TfLiteLSTMKernelType kernel_type;
 
   // If the lstm is layer norm.
-  bool use_layer_norm_lstm;
+  bool use_layer_norm;
 
   // These fields are only used by full kernel.
   int scratch_tensor_index;
@@ -92,7 +92,7 @@ TfLiteStatus PopulateQuantizedLstmParams(
   // Calculate effective scales.
   OpData* op_data = static_cast<OpData*>(node->user_data);
-  const bool use_layer_norm_lstm = op_data->use_layer_norm_lstm;
+  const bool use_layer_norm = op_data->use_layer_norm;
 
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
@@ -147,7 +147,7 @@ TfLiteStatus PopulateQuantizedLstmParams(
   std::vector<float> intermediate_scale;
   std::vector<int32_t> intermediate_zp;
   for (int i = 0; i < 4; ++i) {
-    if (use_layer_norm_lstm) {
+    if (use_layer_norm) {
       const TfLiteTensor* intermediate = GetIntermediates(context, node, i);
       auto* params = static_cast<TfLiteAffineQuantization*>(
           intermediate->quantization.params);
@@ -218,7 +218,7 @@ TfLiteStatus PopulateQuantizedLstmParams(
     cell_to_output_weight_scale = cell_to_output_weights->params.scale;
   }
 
-  if (use_layer_norm_lstm) {
+  if (use_layer_norm) {
     if (!use_cifg) {
       layer_norm_input_scale = input_layer_norm_coefficients->params.scale;
     }
@@ -381,8 +381,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
 TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
                                         TfLiteNode* node, int n_input,
                                         int n_output, int n_cell,
-                                        bool use_layer_norm_lstm,
-                                        bool is_integer) {
+                                        bool use_layer_norm, bool is_integer) {
   const auto* params = static_cast<TfLiteLSTMParams*>(node->builtin_data);
 
   // Making sure clipping parameters have valid values.
@@ -574,7 +573,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
       ((projection_weights != nullptr) || (projection_bias == nullptr));
   TF_LITE_ENSURE(context, projection_tensors_consistent == true);
 
-  if (use_layer_norm_lstm) {
+  if (use_layer_norm) {
     const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
         context, node, kInputLayerNormCoefficientsTensor);
     if (use_cifg) {
@@ -714,7 +713,7 @@ TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context,
   // When there is layer normalization, the gate bias does not apply to matmul
   // directly:
   //   y = ln(w * x + w * r + w * c) + b.
-  const bool is_layer_norm = op_data->use_layer_norm_lstm;
+  const bool is_layer_norm = op_data->use_layer_norm;
 
   // Forget gate.
   const TfLiteTensor* forget_gate_bias =
@@ -801,13 +800,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
         context, node, kForgetLayerNormCoefficientsTensor);
     if (forget_layer_norm_coefficients == nullptr) {
-      op_data->use_layer_norm_lstm = false;
+      op_data->use_layer_norm = false;
     } else {
-      op_data->use_layer_norm_lstm = true;
+      op_data->use_layer_norm = true;
     }
   } else if (node->inputs->size == 20) {
     // This is deprecated and is only kept here for backward compatibility.
-    op_data->use_layer_norm_lstm = false;
+    op_data->use_layer_norm = false;
   } else {
     context->ReportError(
         context, "The LSTM Full kernel expects 20 or 24 inputs. Got %d inputs",
@@ -815,7 +814,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     return kTfLiteError;
   }
 
-  const bool use_layer_norm_lstm = op_data->use_layer_norm_lstm;
+  const bool use_layer_norm = op_data->use_layer_norm;
 
   // Inferring batch size, number of outputs and number of cells from the
   // input tensors.
@@ -839,9 +838,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   const int n_output = recurrent_to_output_weights->dims->data[1];
 
   // Check that input tensor dimensions matches with each other.
-  TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(
-                                 context, node, n_input, n_output, n_cell,
-                                 use_layer_norm_lstm, is_integer));
+  TF_LITE_ENSURE_OK(
+      context, CheckInputTensorDimensions(context, node, n_input, n_output,
+                                          n_cell, use_layer_norm, is_integer));
 
   // Get the pointer to output, activation_state and cell_state tensors.
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
diff --git a/tensorflow/lite/kernels/lstm_eval.cc b/tensorflow/lite/kernels/lstm_eval.cc
index 869fd9abf49..b718912c23d 100644
--- a/tensorflow/lite/kernels/lstm_eval.cc
+++ b/tensorflow/lite/kernels/lstm_eval.cc
@@ -151,12 +151,11 @@ inline void LstmStepFloat(
   // check the existence of only one to the get the condition.
   const bool use_cifg = (input_to_input_weights_ptr == nullptr);
   const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
-  const bool use_layer_norm_lstm =
-      (forget_layer_norm_coefficients_ptr != nullptr);
+  const bool use_layer_norm = (forget_layer_norm_coefficients_ptr != nullptr);
 
   // Initialize scratch buffers with bias for regular lstm or initialize with
   // zero for layer norm lstm.
-  if (use_layer_norm_lstm) {
+  if (use_layer_norm) {
     if (!use_cifg) {
       std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
     }
@@ -243,7 +242,7 @@ inline void LstmStepFloat(
         cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
         input_gate_scratch);
   }
-  if (use_layer_norm_lstm) {
+  if (use_layer_norm) {
     tensor_utils::MeanStddevNormalization(
         input_gate_scratch, input_gate_scratch, n_cell, n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
@@ -262,7 +261,7 @@ inline void LstmStepFloat(
         cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
         forget_gate_scratch);
   }
-  if (use_layer_norm_lstm) {
+  if (use_layer_norm) {
     tensor_utils::MeanStddevNormalization(forget_gate_scratch,
                                           forget_gate_scratch, n_cell, n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
@@ -277,7 +276,7 @@ inline void LstmStepFloat(
   // For each batch and cell: update the cell.
   tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
                                          n_batch * n_cell, cell_state_ptr);
-  if (use_layer_norm_lstm) {
+  if (use_layer_norm) {
     tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
                                           n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
@@ -308,7 +307,7 @@ inline void LstmStepFloat(
         cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
         output_gate_scratch);
   }
-  if (use_layer_norm_lstm) {
+  if (use_layer_norm) {
     tensor_utils::MeanStddevNormalization(output_gate_scratch,
                                           output_gate_scratch, n_cell, n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
@@ -485,11 +484,11 @@ inline void LstmStepHybrid(
   // can check the existence of only one to the get the condition.
   const bool use_cifg = (input_to_input_weights_ptr == nullptr);
   const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
-  const bool use_layer_norm_lstm =
-      (forget_layer_norm_coefficients_ptr != nullptr);
+  const bool use_layer_norm = (forget_layer_norm_coefficients_ptr != nullptr);
 
-  // Initialize scratch buffers with bias.
-  if (use_layer_norm_lstm) {
+  // Initialize scratch buffers with bias for regular lstm or initialize with
+  // zero for layer norm lstm.
+  if (use_layer_norm) {
     if (!use_cifg) {
       std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
     }
@@ -671,7 +670,7 @@ inline void LstmStepHybrid(
         recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
         input_gate_scratch);
   }
-  if (use_layer_norm_lstm) {
+  if (use_layer_norm) {
     tensor_utils::MeanStddevNormalization(
         input_gate_scratch, input_gate_scratch, n_cell, n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
@@ -693,7 +692,7 @@ inline void LstmStepHybrid(
         recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
         forget_gate_scratch);
   }
-  if (use_layer_norm_lstm) {
+  if (use_layer_norm) {
     tensor_utils::MeanStddevNormalization(forget_gate_scratch,
                                           forget_gate_scratch, n_cell, n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
@@ -708,7 +707,7 @@ inline void LstmStepHybrid(
   // For each batch and cell: update the cell.
   tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
                                          n_batch * n_cell, cell_state_ptr);
-  if (use_layer_norm_lstm) {
+  if (use_layer_norm) {
     tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
                                           n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
@@ -742,7 +741,7 @@ inline void LstmStepHybrid(
         recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
         output_gate_scratch);
   }
-  if (use_layer_norm_lstm) {
+  if (use_layer_norm) {
     tensor_utils::MeanStddevNormalization(output_gate_scratch,
                                           output_gate_scratch, n_cell, n_batch);
     tensor_utils::VectorBatchVectorCwiseProduct(
@@ -976,7 +975,7 @@ inline void LstmStepInteger(
   // Get hyper parameters.
   const bool use_cifg = (input_to_input_weight_ptr == nullptr);
   const bool use_peephole = (cell_to_output_weight_ptr != nullptr);
-  const bool use_layer_norm_lstm = (layer_norm_forget_weight_ptr != nullptr);
+  const bool use_layer_norm = (layer_norm_forget_weight_ptr != nullptr);
   const bool use_projection = (proj_weight_ptr != nullptr);
 
   // Check for nullptrs.
@@ -1018,7 +1017,7 @@ inline void LstmStepInteger(
         scratch_1_ptr);
   }
 
-  if (use_layer_norm_lstm) {
+  if (use_layer_norm) {
     tensor_utils::ApplyLayerNorm(
         scratch_1_ptr, layer_norm_forget_weight_ptr, forget_bias_ptr,
         layer_norm_forget_scale_a, layer_norm_forget_scale_b,
@@ -1039,7 +1038,7 @@ inline void LstmStepInteger(
       effective_recurrent_to_cell_scale_b, n_batch, n_output, n_cell, 0,
       scratch_5_ptr, scratch_2_ptr, context);
 
-  if (use_layer_norm_lstm) {
+  if (use_layer_norm) {
     tensor_utils::ApplyLayerNorm(scratch_2_ptr, layer_norm_cell_weight_ptr,
                                  cell_bias_ptr, layer_norm_cell_scale_a,
                                  layer_norm_cell_scale_b, cell_variance_guard,
@@ -1069,7 +1068,7 @@ inline void LstmStepInteger(
           scratch_0_ptr);
     }
 
-    if (use_layer_norm_lstm) {
+    if (use_layer_norm) {
       tensor_utils::ApplyLayerNorm(
           scratch_0_ptr, layer_norm_input_weight_ptr, input_bias_ptr,
           layer_norm_input_scale_a, layer_norm_input_scale_b,
@@ -1110,7 +1109,7 @@ inline void LstmStepInteger(
         scratch_3_ptr);
   }
 
-  if (use_layer_norm_lstm) {
+  if (use_layer_norm) {
     tensor_utils::ApplyLayerNorm(
         scratch_3_ptr, layer_norm_output_weight_ptr, output_bias_ptr,
         layer_norm_output_scale_a, layer_norm_output_scale_b,
diff --git a/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc b/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc
index 379a58e1d90..083a85c14f2 100644
--- a/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc
+++ b/tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc
@@ -71,12 +71,11 @@ inline void LstmStepWithAuxInput(
   // check the existence of only one to the get the condition.
   const bool use_cifg = (input_to_input_weights_ptr == nullptr);
   const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
-  const bool is_layer_norm_lstm =
-      (forget_layer_norm_coefficients_ptr != nullptr);
+  const bool use_layer_norm = (forget_layer_norm_coefficients_ptr != nullptr);
 
   // Initialize scratch buffers with bias for regular lstm or initialize with
   // zero for layer norm lstm.
-  if (is_layer_norm_lstm) {
+  if (use_layer_norm) {
     if (!use_cifg) {
       std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
     }
@@ -158,7 +157,7 @@ inline void LstmStepWithAuxInput(
         cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
         input_gate_scratch);
   }
-  if (is_layer_norm_lstm) {
+  if (use_layer_norm) {
     logger->LogTensorValue(intemediate_tensor_indexes[0], input_gate_scratch,
                            n_cell * n_batch, error_reporter);
     tensor_utils::MeanStddevNormalization(
@@ -179,7 +178,7 @@ inline void LstmStepWithAuxInput(
         cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
         forget_gate_scratch);
   }
-  if (is_layer_norm_lstm) {
+  if (use_layer_norm) {
     logger->LogTensorValue(intemediate_tensor_indexes[1], forget_gate_scratch,
                            n_cell * n_batch, error_reporter);
     tensor_utils::MeanStddevNormalization(forget_gate_scratch,
@@ -196,7 +195,7 @@ inline void LstmStepWithAuxInput(
   // For each batch and cell: update the cell.
   tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
                                          n_batch * n_cell, cell_state_ptr);
-  if (is_layer_norm_lstm) {
+  if (use_layer_norm) {
     logger->LogTensorValue(intemediate_tensor_indexes[2], cell_scratch,
                            n_cell * n_batch, error_reporter);
     tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
@@ -229,7 +228,7 @@ inline void LstmStepWithAuxInput(
         cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
         output_gate_scratch);
   }
-  if (is_layer_norm_lstm) {
+  if (use_layer_norm) {
     logger->LogTensorValue(intemediate_tensor_indexes[3], output_gate_scratch,
                            n_cell * n_batch, error_reporter);
     tensor_utils::MeanStddevNormalization(output_gate_scratch,
@@ -483,7 +482,7 @@ struct OpData {
   TfLiteLSTMKernelType kernel_type;
 
   // If the lstm is layer norm.
-  bool is_layer_norm_lstm;
+  bool use_layer_norm;
 
   // These fields are only used by full kernel.
   int scratch_tensor_index;
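
For context (an illustration, not part of the patch): the `use_layer_norm` flag renamed above selects between two gate pipelines. Without layer norm, each gate's scratch buffer is seeded with the gate bias before the matmuls accumulate into it; with layer norm, the buffer is zero-seeded, the accumulated sum is mean/stddev-normalized across the cell dimension, scaled by the per-cell coefficients, and only then biased, matching the `y = ln(w * x + w * r + w * c) + b` comment in lstm.cc. The following single-batch C++ sketch shows that layer-norm path; the function name `ApplyLayerNormAndBias` and the `1e-8f` epsilon are hypothetical, and it does not claim to match `tensor_utils::MeanStddevNormalization` bit-for-bit.

// lstm_layer_norm_sketch.cc -- standalone illustration, not TFLite code.
#include <cmath>
#include <cstddef>
#include <vector>

// `gate` holds the accumulated pre-activation w*x + w*r (+ w*c) for one
// batch; it was zero-initialized, so the bias has not been applied yet.
void ApplyLayerNormAndBias(const std::vector<float>& coeff,
                           const std::vector<float>& bias,
                           std::vector<float>* gate) {
  const std::size_t n_cell = gate->size();
  // Normalize to zero mean / unit stddev across the cell dimension.
  float mean = 0.0f;
  for (float v : *gate) mean += v;
  mean /= static_cast<float>(n_cell);
  float variance = 0.0f;
  for (float v : *gate) variance += (v - mean) * (v - mean);
  variance /= static_cast<float>(n_cell);
  const float inv_stddev = 1.0f / std::sqrt(variance + 1e-8f);  // hypothetical guard
  for (std::size_t i = 0; i < n_cell; ++i) {
    // Scale by the per-cell layer-norm coefficient, then add the bias after
    // normalization -- the reason the kernels above zero-initialize the
    // scratch buffers in the layer-norm path.
    (*gate)[i] = ((*gate)[i] - mean) * inv_stddev * coeff[i] + bias[i];
  }
}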