Nit: rename use_layer_norm_lstm to use_layer_norm, to be consistent with other similar variables.

PiperOrigin-RevId: 292043066
Change-Id: Ic77c7f712432cb717eb477e98d8f4d16a6dac0be
Robert David 2020-01-28 16:54:53 -08:00 committed by TensorFlower Gardener
parent d049992117
commit c3f78d1c1c
3 changed files with 40 additions and 43 deletions

View File

@@ -50,7 +50,7 @@ struct OpData {
TfLiteLSTMKernelType kernel_type;
// Whether the LSTM uses layer normalization.
- bool use_layer_norm_lstm;
+ bool use_layer_norm;
// These fields are only used by full kernel.
int scratch_tensor_index;
@@ -92,7 +92,7 @@ TfLiteStatus PopulateQuantizedLstmParams(
// Calculate effective scales.
OpData* op_data = static_cast<OpData*>(node->user_data);
- const bool use_layer_norm_lstm = op_data->use_layer_norm_lstm;
+ const bool use_layer_norm = op_data->use_layer_norm;
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
@@ -147,7 +147,7 @@ TfLiteStatus PopulateQuantizedLstmParams(
std::vector<float> intermediate_scale;
std::vector<int32> intermediate_zp;
for (int i = 0; i < 4; ++i) {
- if (use_layer_norm_lstm) {
+ if (use_layer_norm) {
const TfLiteTensor* intermediate = GetIntermediates(context, node, i);
auto* params = static_cast<TfLiteAffineQuantization*>(
intermediate->quantization.params);
@@ -218,7 +218,7 @@ TfLiteStatus PopulateQuantizedLstmParams(
cell_to_output_weight_scale = cell_to_output_weights->params.scale;
}
- if (use_layer_norm_lstm) {
+ if (use_layer_norm) {
if (!use_cifg) {
layer_norm_input_scale = input_layer_norm_coefficients->params.scale;
}
@@ -381,8 +381,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TfLiteNode* node, int n_input,
int n_output, int n_cell,
- bool use_layer_norm_lstm,
- bool is_integer) {
+ bool use_layer_norm, bool is_integer) {
const auto* params = static_cast<TfLiteLSTMParams*>(node->builtin_data);
// Making sure clipping parameters have valid values.
@@ -574,7 +573,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
((projection_weights != nullptr) || (projection_bias == nullptr));
TF_LITE_ENSURE(context, projection_tensors_consistent == true);
- if (use_layer_norm_lstm) {
+ if (use_layer_norm) {
const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
context, node, kInputLayerNormCoefficientsTensor);
if (use_cifg) {
@@ -714,7 +713,7 @@ TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context,
// When there is layer normalization, the gate bias does not apply to matmul
// directly:
// y = ln(w * x + w * r + w * c) + b.
- const bool is_layer_norm = op_data->use_layer_norm_lstm;
+ const bool is_layer_norm = op_data->use_layer_norm;
// Forget gate.
const TfLiteTensor* forget_gate_bias =
@@ -801,13 +800,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
context, node, kForgetLayerNormCoefficientsTensor);
if (forget_layer_norm_coefficients == nullptr) {
- op_data->use_layer_norm_lstm = false;
+ op_data->use_layer_norm = false;
} else {
- op_data->use_layer_norm_lstm = true;
+ op_data->use_layer_norm = true;
}
} else if (node->inputs->size == 20) {
// This is deprecated and is only kept here for backward compatibility.
- op_data->use_layer_norm_lstm = false;
+ op_data->use_layer_norm = false;
} else {
context->ReportError(
context, "The LSTM Full kernel expects 20 or 24 inputs. Got %d inputs",
@@ -815,7 +814,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteError;
}
- const bool use_layer_norm_lstm = op_data->use_layer_norm_lstm;
+ const bool use_layer_norm = op_data->use_layer_norm;
// Inferring batch size, number of outputs and number of cells from the
// input tensors.
@@ -839,9 +838,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const int n_output = recurrent_to_output_weights->dims->data[1];
// Check that input tensor dimensions matches with each other.
- TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(
-     context, node, n_input, n_output, n_cell,
-     use_layer_norm_lstm, is_integer));
+ TF_LITE_ENSURE_OK(
+     context, CheckInputTensorDimensions(context, node, n_input, n_output,
+         n_cell, use_layer_norm, is_integer));
// Get the pointer to output, activation_state and cell_state tensors.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
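
Note: the rename does not change the flag's semantics. As a reviewer aid, here is a minimal self-contained sketch of how Prepare() above derives it; the function name DecideUseLayerNorm and the void* stand-in for the optional tensor are illustrative, not kernel code:

// The 24-input LSTM signature carries optional layer-norm coefficient
// tensors; the presence of the forget-gate coefficients turns layer norm
// on. The 20-input signature is the deprecated, non-layer-norm form, and
// Prepare() reports an error for any other input count (omitted here).
bool DecideUseLayerNorm(int num_inputs,
                        const void* forget_layer_norm_coefficients) {
  if (num_inputs == 24) {
    return forget_layer_norm_coefficients != nullptr;
  }
  return false;  // 20 inputs: deprecated path, layer norm off.
}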

View File

@@ -151,12 +151,11 @@ inline void LstmStepFloat(
// check the existence of only one to get the condition.
const bool use_cifg = (input_to_input_weights_ptr == nullptr);
const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
- const bool use_layer_norm_lstm =
-     (forget_layer_norm_coefficients_ptr != nullptr);
+ const bool use_layer_norm = (forget_layer_norm_coefficients_ptr != nullptr);
// Initialize scratch buffers with bias for regular lstm or initialize with
// zero for layer norm lstm.
- if (use_layer_norm_lstm) {
+ if (use_layer_norm) {
if (!use_cifg) {
std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
}
@@ -243,7 +242,7 @@ inline void LstmStepFloat(
cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
input_gate_scratch);
}
- if (use_layer_norm_lstm) {
+ if (use_layer_norm) {
tensor_utils::MeanStddevNormalization(
input_gate_scratch, input_gate_scratch, n_cell, n_batch);
tensor_utils::VectorBatchVectorCwiseProduct(
@@ -262,7 +261,7 @@ inline void LstmStepFloat(
cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
forget_gate_scratch);
}
- if (use_layer_norm_lstm) {
+ if (use_layer_norm) {
tensor_utils::MeanStddevNormalization(forget_gate_scratch,
forget_gate_scratch, n_cell, n_batch);
tensor_utils::VectorBatchVectorCwiseProduct(
@@ -277,7 +276,7 @@ inline void LstmStepFloat(
// For each batch and cell: update the cell.
tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
n_batch * n_cell, cell_state_ptr);
- if (use_layer_norm_lstm) {
+ if (use_layer_norm) {
tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
n_batch);
tensor_utils::VectorBatchVectorCwiseProduct(
@@ -308,7 +307,7 @@ inline void LstmStepFloat(
cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
output_gate_scratch);
}
- if (use_layer_norm_lstm) {
+ if (use_layer_norm) {
tensor_utils::MeanStddevNormalization(output_gate_scratch,
output_gate_scratch, n_cell, n_batch);
tensor_utils::VectorBatchVectorCwiseProduct(
@@ -485,11 +484,11 @@ inline void LstmStepHybrid(
// can check the existence of only one to get the condition.
const bool use_cifg = (input_to_input_weights_ptr == nullptr);
const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
- const bool use_layer_norm_lstm =
-     (forget_layer_norm_coefficients_ptr != nullptr);
+ const bool use_layer_norm = (forget_layer_norm_coefficients_ptr != nullptr);
- // Initialize scratch buffers with bias.
- if (use_layer_norm_lstm) {
+ // Initialize scratch buffers with bias for regular lstm or initialize with
+ // zero for layer norm lstm.
+ if (use_layer_norm) {
if (!use_cifg) {
std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
}
@@ -671,7 +670,7 @@ inline void LstmStepHybrid(
recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
input_gate_scratch);
}
- if (use_layer_norm_lstm) {
+ if (use_layer_norm) {
tensor_utils::MeanStddevNormalization(
input_gate_scratch, input_gate_scratch, n_cell, n_batch);
tensor_utils::VectorBatchVectorCwiseProduct(
@@ -693,7 +692,7 @@ inline void LstmStepHybrid(
recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
forget_gate_scratch);
}
- if (use_layer_norm_lstm) {
+ if (use_layer_norm) {
tensor_utils::MeanStddevNormalization(forget_gate_scratch,
forget_gate_scratch, n_cell, n_batch);
tensor_utils::VectorBatchVectorCwiseProduct(
@@ -708,7 +707,7 @@ inline void LstmStepHybrid(
// For each batch and cell: update the cell.
tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
n_batch * n_cell, cell_state_ptr);
- if (use_layer_norm_lstm) {
+ if (use_layer_norm) {
tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
n_batch);
tensor_utils::VectorBatchVectorCwiseProduct(
@@ -742,7 +741,7 @@ inline void LstmStepHybrid(
recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
output_gate_scratch);
}
- if (use_layer_norm_lstm) {
+ if (use_layer_norm) {
tensor_utils::MeanStddevNormalization(output_gate_scratch,
output_gate_scratch, n_cell, n_batch);
tensor_utils::VectorBatchVectorCwiseProduct(
@@ -976,7 +975,7 @@ inline void LstmStepInteger(
// Get hyper parameters.
const bool use_cifg = (input_to_input_weight_ptr == nullptr);
const bool use_peephole = (cell_to_output_weight_ptr != nullptr);
- const bool use_layer_norm_lstm = (layer_norm_forget_weight_ptr != nullptr);
+ const bool use_layer_norm = (layer_norm_forget_weight_ptr != nullptr);
const bool use_projection = (proj_weight_ptr != nullptr);
// Check for nullptrs.
@@ -1018,7 +1017,7 @@ inline void LstmStepInteger(
scratch_1_ptr);
}
- if (use_layer_norm_lstm) {
+ if (use_layer_norm) {
tensor_utils::ApplyLayerNorm(
scratch_1_ptr, layer_norm_forget_weight_ptr, forget_bias_ptr,
layer_norm_forget_scale_a, layer_norm_forget_scale_b,
@@ -1039,7 +1038,7 @@ inline void LstmStepInteger(
effective_recurrent_to_cell_scale_b, n_batch, n_output, n_cell, 0,
scratch_5_ptr, scratch_2_ptr, context);
- if (use_layer_norm_lstm) {
+ if (use_layer_norm) {
tensor_utils::ApplyLayerNorm(scratch_2_ptr, layer_norm_cell_weight_ptr,
cell_bias_ptr, layer_norm_cell_scale_a,
layer_norm_cell_scale_b, cell_variance_guard,
@@ -1069,7 +1068,7 @@ inline void LstmStepInteger(
scratch_0_ptr);
}
- if (use_layer_norm_lstm) {
+ if (use_layer_norm) {
tensor_utils::ApplyLayerNorm(
scratch_0_ptr, layer_norm_input_weight_ptr, input_bias_ptr,
layer_norm_input_scale_a, layer_norm_input_scale_b,
@@ -1110,7 +1109,7 @@ inline void LstmStepInteger(
scratch_3_ptr);
}
- if (use_layer_norm_lstm) {
+ if (use_layer_norm) {
tensor_utils::ApplyLayerNorm(
scratch_3_ptr, layer_norm_output_weight_ptr, output_bias_ptr,
layer_norm_output_scale_a, layer_norm_output_scale_b,
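
Note: the float, hybrid, and integer paths above all gate the same per-gate sequence on use_layer_norm. Below is a minimal self-contained sketch of the float-path semantics; LayerNormGate and kEpsilon are illustrative names and values, where the kernel itself composes tensor_utils::MeanStddevNormalization, a per-cell coefficient product, and a bias add:

#include <cmath>
#include <vector>

// Per gate and per batch row: normalize the n_cell gate values to zero
// mean / unit variance, scale by the per-cell layer-norm coefficients,
// then add the gate bias (applied here, after normalization, rather than
// being pre-seeded into the scratch buffer).
void LayerNormGate(std::vector<float>& scratch,      // n_batch * n_cell
                   const std::vector<float>& coeff,  // n_cell
                   const std::vector<float>& bias,   // n_cell
                   int n_batch, int n_cell) {
  const float kEpsilon = 1e-8f;  // variance guard; exact value is assumed
  for (int b = 0; b < n_batch; ++b) {
    float* row = &scratch[b * n_cell];
    float mean = 0.0f;
    for (int c = 0; c < n_cell; ++c) mean += row[c];
    mean /= n_cell;
    float variance = 0.0f;
    for (int c = 0; c < n_cell; ++c) {
      variance += (row[c] - mean) * (row[c] - mean);
    }
    variance /= n_cell;
    const float inv_stddev = 1.0f / std::sqrt(variance + kEpsilon);
    for (int c = 0; c < n_cell; ++c) {
      row[c] = (row[c] - mean) * inv_stddev * coeff[c] + bias[c];
    }
  }
}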

View File

@@ -71,12 +71,11 @@ inline void LstmStepWithAuxInput(
// check the existence of only one to get the condition.
const bool use_cifg = (input_to_input_weights_ptr == nullptr);
const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
- const bool is_layer_norm_lstm =
-     (forget_layer_norm_coefficients_ptr != nullptr);
+ const bool use_layer_norm = (forget_layer_norm_coefficients_ptr != nullptr);
// Initialize scratch buffers with bias for regular lstm or initialize with
// zero for layer norm lstm.
- if (is_layer_norm_lstm) {
+ if (use_layer_norm) {
if (!use_cifg) {
std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
}
@@ -158,7 +157,7 @@ inline void LstmStepWithAuxInput(
cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
input_gate_scratch);
}
- if (is_layer_norm_lstm) {
+ if (use_layer_norm) {
logger->LogTensorValue(intemediate_tensor_indexes[0], input_gate_scratch,
n_cell * n_batch, error_reporter);
tensor_utils::MeanStddevNormalization(
@@ -179,7 +178,7 @@ inline void LstmStepWithAuxInput(
cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
forget_gate_scratch);
}
- if (is_layer_norm_lstm) {
+ if (use_layer_norm) {
logger->LogTensorValue(intemediate_tensor_indexes[1], forget_gate_scratch,
n_cell * n_batch, error_reporter);
tensor_utils::MeanStddevNormalization(forget_gate_scratch,
@@ -196,7 +195,7 @@ inline void LstmStepWithAuxInput(
// For each batch and cell: update the cell.
tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
n_batch * n_cell, cell_state_ptr);
- if (is_layer_norm_lstm) {
+ if (use_layer_norm) {
logger->LogTensorValue(intemediate_tensor_indexes[2], cell_scratch,
n_cell * n_batch, error_reporter);
tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
@@ -229,7 +228,7 @@ inline void LstmStepWithAuxInput(
cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
output_gate_scratch);
}
- if (is_layer_norm_lstm) {
+ if (use_layer_norm) {
logger->LogTensorValue(intemediate_tensor_indexes[3], output_gate_scratch,
n_cell * n_batch, error_reporter);
tensor_utils::MeanStddevNormalization(output_gate_scratch,
@@ -483,7 +482,7 @@ struct OpData {
TfLiteLSTMKernelType kernel_type;
// Whether the LSTM uses layer normalization.
- bool is_layer_norm_lstm;
+ bool use_layer_norm;
// These fields are only used by full kernel.
int scratch_tensor_index;
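
Note: the scratch-initialization comments above ("with bias for regular lstm or ... zero for layer norm lstm") and the earlier "y = ln(w * x + w * r + w * c) + b" comment describe one invariant shared by all four step functions. A minimal sketch follows, with the illustrative name InitGateScratch; the kernels themselves use std::fill_n and a batched bias assign:

#include <algorithm>
#include <vector>

// Plain LSTM: seed each gate's scratch with the gate bias so it rides
// through the matmul accumulation and activation. Layer-norm LSTM: seed
// with zeros, since mean subtraction would cancel any constant added
// uniformly across cells; the bias is instead added after ln().
void InitGateScratch(std::vector<float>& scratch,     // n_batch * n_cell
                     const std::vector<float>& bias,  // n_cell
                     int n_batch, int n_cell, bool use_layer_norm) {
  if (use_layer_norm) {
    std::fill(scratch.begin(), scratch.end(), 0.0f);
  } else {
    for (int b = 0; b < n_batch; ++b) {
      std::copy(bias.begin(), bias.end(), scratch.begin() + b * n_cell);
    }
  }
}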