Integer LSTMs: Name scratch arrays based on the gate they represent. Make the naming consistent with the float/hybrid versions.

PiperOrigin-RevId: 317420201
Change-Id: Ia9447e51fce1530e75103c4db3759908592af983
Robert David 2020-06-19 19:22:08 -07:00 committed by TensorFlower Gardener
parent b1fdd334e7
commit 429c0b423e
2 changed files with 130 additions and 93 deletions
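
The change is mechanical: at the top of each LSTM step function, the anonymous scratch pointers are aliased to gate-specific names, and the body then refers only to those aliases. A minimal sketch of the pattern, with a simplified hypothetical signature (the real LstmStepFloat/LstmStepHybrid/LstmStepInteger functions in the diff below take many more parameters), is:

#include <algorithm>
#include <cstdint>

// Sketch only: a hypothetical helper that illustrates the aliasing pattern
// introduced by this commit, not the actual TFLite kernel signature.
void LstmStepSketch(int n_batch, int n_cell, int16_t* scratch0,
                    int16_t* scratch1, int16_t* scratch2, int16_t* scratch3) {
  // Make named scratch buffers for the different gates, so the gate math
  // below reads in terms of gates rather than buffer indices.
  int16_t* input_gate_scratch = scratch0;
  int16_t* forget_gate_scratch = scratch1;
  int16_t* cell_gate_scratch = scratch2;
  int16_t* output_gate_scratch = scratch3;

  // Every later use goes through the gate name, e.g. zero-initialization.
  std::fill_n(input_gate_scratch, n_batch * n_cell, 0);
  std::fill_n(forget_gate_scratch, n_batch * n_cell, 0);
  std::fill_n(cell_gate_scratch, n_batch * n_cell, 0);
  std::fill_n(output_gate_scratch, n_batch * n_cell, 0);
}

Note that the second integer variant below maps the buffers differently (input_gate_scratch = scratch5, forget_gate_scratch = scratch2, and so on), which the named aliases make explicit at the top of the function.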


@@ -216,9 +216,8 @@ inline void LstmStepFloat(
const float* projection_weights_ptr, const float* projection_bias_ptr,
const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
int n_aux_input, int n_output, int output_batch_leading_dim,
float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
float* forget_gate_scratch, float* cell_gate_scratch,
float* output_gate_scratch, float* output_ptr) {
float* output_state_ptr, float* cell_state_ptr, float* scratch0,
float* scratch1, float* scratch2, float* scratch3, float* output_ptr) {
ruy::profiler::ScopeLabel label("LstmStepFloat");
// Since we have already checked that weights are all there or none, we can
// check the existence of only one to get the condition.
@@ -226,6 +225,12 @@ inline void LstmStepFloat(
const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
const bool use_layer_norm = (forget_layer_norm_coefficients_ptr != nullptr);
// Make named scratch buffers for the different gates.
float* input_gate_scratch = scratch0;
float* forget_gate_scratch = scratch1;
float* cell_gate_scratch = scratch2;
float* output_gate_scratch = scratch3;
// Initialize scratch buffers with bias for regular lstm or initialize with
// zero for layer norm lstm.
if (use_layer_norm) {
@@ -531,9 +536,8 @@ inline void LstmStepHybrid(
const int8_t* projection_weights_ptr, float projection_weights_scale,
const float* projection_bias_ptr, const TfLiteLSTMParams* params,
int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
int output_batch_leading_dim, float* input_gate_scratch,
float* forget_gate_scratch, float* cell_gate_scratch,
float* output_gate_scratch, float* scaling_factors,
int output_batch_leading_dim, float* scratch0, float* scratch1,
float* scratch2, float* scratch3, float* scaling_factors,
float* scaling_factors_scratch, float* recovered_cell_weights,
int8_t* quantized_input_ptr, int8_t* quantized_aux_input_ptr,
int8_t* quantized_output_state_ptr, int8_t* quantized_cell_state_ptr,
@@ -548,6 +552,12 @@ inline void LstmStepHybrid(
const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
const bool use_layer_norm = (forget_layer_norm_coefficients_ptr != nullptr);
// Make named scratch buffers for the different gates.
float* input_gate_scratch = scratch0;
float* forget_gate_scratch = scratch1;
float* cell_gate_scratch = scratch2;
float* output_gate_scratch = scratch3;
// Initialize scratch buffers with bias for regular lstm or initialize with
// zero for layer norm lstm.
if (use_layer_norm) {
@@ -974,12 +984,12 @@ inline void LstmStepHybrid(
//
// Temporary pre-allocated storage for the calculation. Each is of size n_cell *
// n_batch.
// scratch_0
// scratch_1
// scratch_2
// scratch_3
// scratch_4
// scratch_5: this scratch buffer is created purely for optimizing the
// scratch0
// scratch1
// scratch2
// scratch3
// scratch4
// scratch5: this scratch buffer is created purely for optimizing the
// MatrixBatchVectorMultiplyAccumulate.
//
// Outputs:
@@ -1047,10 +1057,15 @@ inline void LstmStepInteger(
const int32_t* projection_effective_bias, int n_batch, int n_cell,
int n_input, int n_output, int8_t* output_state_ptr,
int32_t output_state_zp, int16_t* cell_ptr, int8_t* output_ptr,
int16_t* scratch_0_ptr, int16_t* scratch_1_ptr, int16_t* scratch_2_ptr,
int16_t* scratch_3_ptr, int8_t* scratch_4_ptr, int32_t* scratch_5_ptr,
CpuBackendContext* context) {
int16_t* scratch0, int16_t* scratch1, int16_t* scratch2, int16_t* scratch3,
int8_t* scratch4, int32_t* scratch5, CpuBackendContext* context) {
ruy::profiler::ScopeLabel label("LstmStepInteger");
// Make named scratch buffers for the different gates.
int16_t* input_gate_scratch = scratch0;
int16_t* forget_gate_scratch = scratch1;
int16_t* cell_gate_scratch = scratch2;
int16_t* output_gate_scratch = scratch3;
// Get hyper parameters.
const bool use_cifg = (input_to_input_weight_ptr == nullptr);
const bool use_peephole = (cell_to_output_weight_ptr != nullptr);
@@ -1072,99 +1087,103 @@ inline void LstmStepInteger(
// Set scratch to 0.
if (!use_cifg) {
std::fill_n(scratch_0_ptr, n_batch * n_cell, 0);
std::fill_n(input_gate_scratch, n_batch * n_cell, 0);
}
std::fill_n(scratch_1_ptr, n_batch * n_cell, 0);
std::fill_n(scratch_2_ptr, n_batch * n_cell, 0);
std::fill_n(scratch_3_ptr, n_batch * n_cell, 0);
std::fill_n(forget_gate_scratch, n_batch * n_cell, 0);
std::fill_n(cell_gate_scratch, n_batch * n_cell, 0);
std::fill_n(output_gate_scratch, n_batch * n_cell, 0);
// Forget gate.
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_ptr, input_to_forget_effective_bias, input_to_forget_weight_ptr,
effective_input_to_forget_scale_a, effective_input_to_forget_scale_b,
n_batch, n_input, n_cell, 0, scratch_5_ptr, scratch_1_ptr, context);
n_batch, n_input, n_cell, 0, scratch5, forget_gate_scratch, context);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
output_state_ptr, recurrent_to_forget_effective_bias,
recurrent_to_forget_weight_ptr, effective_recurrent_to_forget_scale_a,
effective_recurrent_to_forget_scale_b, n_batch, n_output, n_cell, 0,
scratch_5_ptr, scratch_1_ptr, context);
scratch5, forget_gate_scratch, context);
if (use_peephole) {
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
cell_to_forget_weight_ptr, n_output, cell_ptr, n_batch,
effective_cell_to_forget_scale_a, effective_cell_to_forget_scale_b,
scratch_1_ptr);
forget_gate_scratch);
}
if (use_layer_norm) {
tensor_utils::ApplyLayerNorm(
scratch_1_ptr, layer_norm_forget_weight_ptr, forget_gate_bias_ptr,
forget_gate_scratch, layer_norm_forget_weight_ptr, forget_gate_bias_ptr,
layer_norm_forget_scale_a, layer_norm_forget_scale_b,
forget_variance_guard, n_batch, n_cell, scratch_1_ptr);
forget_variance_guard, n_batch, n_cell, forget_gate_scratch);
}
tensor_utils::ApplySigmoid(scratch_1_ptr, n_batch, n_cell, scratch_1_ptr);
tensor_utils::ApplySigmoid(forget_gate_scratch, n_batch, n_cell,
forget_gate_scratch);
// Modulation gate.
// Cell gate.
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_ptr, input_to_cell_effective_bias, input_to_cell_weight_ptr,
effective_input_to_cell_scale_a, effective_input_to_cell_scale_b, n_batch,
n_input, n_cell, 0, scratch_5_ptr, scratch_2_ptr, context);
n_input, n_cell, 0, scratch5, cell_gate_scratch, context);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
output_state_ptr, recurrent_to_cell_effective_bias,
recurrent_to_cell_weight_ptr, effective_recurrent_to_cell_scale_a,
effective_recurrent_to_cell_scale_b, n_batch, n_output, n_cell, 0,
scratch_5_ptr, scratch_2_ptr, context);
scratch5, cell_gate_scratch, context);
if (use_layer_norm) {
tensor_utils::ApplyLayerNorm(scratch_2_ptr, layer_norm_cell_weight_ptr,
tensor_utils::ApplyLayerNorm(cell_gate_scratch, layer_norm_cell_weight_ptr,
cell_gate_bias_ptr, layer_norm_cell_scale_a,
layer_norm_cell_scale_b, cell_variance_guard,
n_batch, n_cell, scratch_2_ptr);
n_batch, n_cell, cell_gate_scratch);
}
tensor_utils::ApplyTanh(3, scratch_2_ptr, n_batch, n_cell, scratch_2_ptr);
tensor_utils::ApplyTanh(3, cell_gate_scratch, n_batch, n_cell,
cell_gate_scratch);
// Input gate.
if (use_cifg) {
tensor_utils::Sub1Vector(scratch_1_ptr, n_batch * n_cell, scratch_0_ptr);
tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
input_gate_scratch);
} else {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_ptr, input_to_input_effective_bias, input_to_input_weight_ptr,
effective_input_to_input_scale_a, effective_input_to_input_scale_b,
n_batch, n_input, n_cell, 0, scratch_5_ptr, scratch_0_ptr, context);
n_batch, n_input, n_cell, 0, scratch5, input_gate_scratch, context);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
output_state_ptr, recurrent_to_input_effective_bias,
recurrent_to_input_weight_ptr, effective_recurrent_to_input_scale_a,
effective_recurrent_to_input_scale_b, n_batch, n_output, n_cell, 0,
scratch_5_ptr, scratch_0_ptr, context);
scratch5, input_gate_scratch, context);
if (use_peephole) {
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
cell_to_input_weight_ptr, n_output, cell_ptr, n_batch,
effective_cell_to_input_scale_a, effective_cell_to_input_scale_b,
scratch_0_ptr);
input_gate_scratch);
}
if (use_layer_norm) {
tensor_utils::ApplyLayerNorm(
scratch_0_ptr, layer_norm_input_weight_ptr, input_gate_bias_ptr,
input_gate_scratch, layer_norm_input_weight_ptr, input_gate_bias_ptr,
layer_norm_input_scale_a, layer_norm_input_scale_b,
input_variance_guard, n_batch, n_cell, scratch_0_ptr);
input_variance_guard, n_batch, n_cell, input_gate_scratch);
}
tensor_utils::ApplySigmoid(scratch_0_ptr, n_batch, n_cell, scratch_0_ptr);
tensor_utils::ApplySigmoid(input_gate_scratch, n_batch, n_cell,
input_gate_scratch);
}
// New cell.
tensor_utils::CwiseMul(scratch_1_ptr, cell_ptr, n_batch, n_cell, 15,
scratch_1_ptr);
tensor_utils::CwiseMul(forget_gate_scratch, cell_ptr, n_batch, n_cell, 15,
forget_gate_scratch);
tensor_utils::CwiseMul(scratch_0_ptr, scratch_2_ptr, n_batch, n_cell,
30 + cell_scale, scratch_2_ptr);
tensor_utils::CwiseMul(input_gate_scratch, cell_gate_scratch, n_batch, n_cell,
30 + cell_scale, cell_gate_scratch);
tensor_utils::CwiseAdd(scratch_1_ptr, scratch_2_ptr, n_batch, n_cell,
cell_ptr);
tensor_utils::CwiseAdd(forget_gate_scratch, cell_gate_scratch, n_batch,
n_cell, cell_ptr);
if (quantized_cell_clip > 0) {
tensor_utils::CwiseClipping(cell_ptr, quantized_cell_clip, n_batch, n_cell);
@@ -1174,49 +1193,50 @@ inline void LstmStepInteger(
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_ptr, input_to_output_effective_bias, input_to_output_weight_ptr,
effective_input_to_output_scale_a, effective_input_to_output_scale_b,
n_batch, n_input, n_cell, 0, scratch_5_ptr, scratch_3_ptr, context);
n_batch, n_input, n_cell, 0, scratch5, output_gate_scratch, context);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
output_state_ptr, recurrent_to_output_effective_bias,
recurrent_to_output_weight_ptr, effective_recurrent_to_output_scale_a,
effective_recurrent_to_output_scale_b, n_batch, n_output, n_cell, 0,
scratch_5_ptr, scratch_3_ptr, context);
scratch5, output_gate_scratch, context);
if (use_peephole) {
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
cell_to_output_weight_ptr, n_output, cell_ptr, n_batch,
effective_cell_to_output_scale_a, effective_cell_to_output_scale_b,
scratch_3_ptr);
output_gate_scratch);
}
if (use_layer_norm) {
tensor_utils::ApplyLayerNorm(
scratch_3_ptr, layer_norm_output_weight_ptr, output_gate_bias_ptr,
output_gate_scratch, layer_norm_output_weight_ptr, output_gate_bias_ptr,
layer_norm_output_scale_a, layer_norm_output_scale_b,
output_variance_guard, n_batch, n_cell, scratch_3_ptr);
output_variance_guard, n_batch, n_cell, output_gate_scratch);
}
tensor_utils::ApplySigmoid(scratch_3_ptr, n_batch, n_cell, scratch_3_ptr);
tensor_utils::ApplySigmoid(output_gate_scratch, n_batch, n_cell,
output_gate_scratch);
// Hidden.
tensor_utils::ApplyTanh(15 + cell_scale, cell_ptr, n_batch, n_cell,
scratch_0_ptr);
input_gate_scratch);
tensor_utils::CwiseMul(scratch_3_ptr, scratch_0_ptr, effective_hidden_scale_a,
effective_hidden_scale_b, n_batch, n_cell, hidden_zp,
scratch_4_ptr);
tensor_utils::CwiseMul(output_gate_scratch, input_gate_scratch,
effective_hidden_scale_a, effective_hidden_scale_b,
n_batch, n_cell, hidden_zp, scratch4);
// Projection.
if (use_projection) {
std::fill_n(output_ptr, n_batch * n_output, 0);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
scratch_4_ptr, projection_effective_bias, projection_weight_ptr,
scratch4, projection_effective_bias, projection_weight_ptr,
effective_proj_scale_a, effective_proj_scale_b, n_batch, n_cell,
n_output, output_state_zp, scratch_5_ptr, output_ptr, context);
n_output, output_state_zp, scratch5, output_ptr, context);
if (quantized_proj_clip > 0) {
tensor_utils::CwiseClipping(output_ptr, quantized_proj_clip, n_batch,
n_output);
}
} else {
std::copy_n(scratch_4_ptr, n_batch * n_output, output_ptr);
std::copy_n(scratch4, n_batch * n_output, output_ptr);
}
std::copy_n(output_ptr, n_batch * n_output, output_state_ptr);
}
@@ -1300,14 +1320,14 @@ inline void LstmStepInteger(
//
// Temporary pre-allocated storage for the calculation. Each is of size n_cell *
// n_batch.
// scratch_0
// scratch_1
// scratch_2
// scratch_3
// scratch_4
// scratch_5
// scratch_6
// scratch_7
// scratch0
// scratch1
// scratch2
// scratch3
// scratch4
// scratch5
// scratch6
// scratch7
//
// Outputs:
// output_state_ptr - size 'n_batch * n_output'
@@ -1369,6 +1389,12 @@ void LstmStepInteger(
int8_t* scratch0, int8_t* scratch1, int16_t* scratch2, int16_t* scratch3,
int16_t* scratch4, int16_t* scratch5, int16_t* scratch6,
int16_t* scratch7) {
// Make named scratch buffers for the different gates.
int16_t* input_gate_scratch = scratch5;
int16_t* forget_gate_scratch = scratch2;
int16_t* cell_gate_scratch = scratch3;
int16_t* output_gate_scratch = scratch4;
// Forget gate.
std::fill_n(scratch0, n_batch * n_cell, 0);
std::fill_n(scratch1, n_batch * n_cell, 0);
@@ -1386,16 +1412,17 @@ void LstmStepInteger(
tensor_utils::TwoGateSaturationgAdd(
scratch0, intermediate_zp[4], scratch1, intermediate_zp[5],
intermediate_scale_a[2], intermediate_scale_b[2], intermediate_scale_a[3],
intermediate_scale_b[3], n_batch, n_cell, scratch2);
intermediate_scale_b[3], n_batch, n_cell, forget_gate_scratch);
// Forget gate layer norm.
tensor_utils::ApplyLayerNormFloat(
scratch2, layer_norm_forget_weight_ptr, layer_norm_forget_scale_a,
layer_norm_forget_scale_b, forget_gate_bias_ptr, n_batch, n_cell,
scratch2);
forget_gate_scratch, layer_norm_forget_weight_ptr,
layer_norm_forget_scale_a, layer_norm_forget_scale_b,
forget_gate_bias_ptr, n_batch, n_cell, forget_gate_scratch);
// Forget gate sigmoid.
tensor_utils::ApplySigmoidFloat(scratch2, n_batch, n_cell, scratch2);
tensor_utils::ApplySigmoidFloat(forget_gate_scratch, n_batch, n_cell,
forget_gate_scratch);
// Update gate.
std::fill_n(scratch0, n_batch * n_cell, 0);
@@ -1413,15 +1440,17 @@ void LstmStepInteger(
tensor_utils::TwoGateSaturationgAdd(
scratch0, intermediate_zp[7], scratch1, intermediate_zp[8],
intermediate_scale_a[4], intermediate_scale_b[4], intermediate_scale_a[5],
intermediate_scale_b[5], n_batch, n_cell, scratch3);
intermediate_scale_b[5], n_batch, n_cell, cell_gate_scratch);
// Update gate with layer norm.
// Update gate layer norm.
tensor_utils::ApplyLayerNormFloat(
scratch3, layer_norm_cell_weight_ptr, layer_norm_cell_scale_a,
layer_norm_cell_scale_b, cell_gate_bias_ptr, n_batch, n_cell, scratch3);
cell_gate_scratch, layer_norm_cell_weight_ptr, layer_norm_cell_scale_a,
layer_norm_cell_scale_b, cell_gate_bias_ptr, n_batch, n_cell,
cell_gate_scratch);
// Update gate tanh.
tensor_utils::ApplyTanhFloat(scratch3, n_batch, n_cell, -12, scratch3);
tensor_utils::ApplyTanhFloat(cell_gate_scratch, n_batch, n_cell, -12,
cell_gate_scratch);
// Output gate.
std::fill_n(scratch0, n_batch * n_cell, 0);
@@ -1440,26 +1469,28 @@ void LstmStepInteger(
tensor_utils::TwoGateSaturationgAdd(
scratch0, intermediate_zp[10], scratch1, intermediate_zp[11],
intermediate_scale_a[6], intermediate_scale_b[6], intermediate_scale_a[7],
intermediate_scale_b[7], n_batch, n_cell, scratch4);
intermediate_scale_b[7], n_batch, n_cell, output_gate_scratch);
// Output gate with layer norm.
tensor_utils::ApplyLayerNormFloat(
scratch4, layer_norm_output_weight_ptr, layer_norm_output_scale_a,
layer_norm_output_scale_b, output_gate_bias_ptr, n_batch, n_cell,
scratch4);
output_gate_scratch, layer_norm_output_weight_ptr,
layer_norm_output_scale_a, layer_norm_output_scale_b,
output_gate_bias_ptr, n_batch, n_cell, output_gate_scratch);
// Output gate sigmoid.
tensor_utils::ApplySigmoidFloat(scratch4, n_batch, n_cell, scratch4);
tensor_utils::ApplySigmoidFloat(output_gate_scratch, n_batch, n_cell,
output_gate_scratch);
// Input gate with cifg
tensor_utils::Sub1Vector(scratch2, n_batch * n_cell, scratch5);
tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
input_gate_scratch);
// New cell.
tensor_utils::CwiseMul(scratch2, cell_ptr, n_batch, n_cell, 15 + 15 - 15,
scratch6);
tensor_utils::CwiseMul(forget_gate_scratch, cell_ptr, n_batch, n_cell,
15 + 15 - 15, scratch6);
tensor_utils::CwiseMul(scratch5, scratch3, n_batch, n_cell, 15 + 15 - 15,
scratch7);
tensor_utils::CwiseMul(input_gate_scratch, cell_gate_scratch, n_batch, n_cell,
15 + 15 - 15, scratch7);
tensor_utils::CwiseAdd(scratch6, scratch7, n_batch, n_cell, cell_ptr);
@@ -1468,15 +1499,16 @@ void LstmStepInteger(
}
// Cell to hidden.
tensor_utils::ApplyTanhFloat(cell_ptr, n_batch, n_cell, -15, scratch2);
tensor_utils::ApplyTanhFloat(cell_ptr, n_batch, n_cell, -15,
forget_gate_scratch);
std::vector<int16_t> hidden(n_batch * n_cell);
tensor_utils::CwiseMul(scratch4, scratch2, n_batch, n_cell, 15 + 15 - 15,
scratch3);
tensor_utils::CwiseMul(output_gate_scratch, forget_gate_scratch, n_batch,
n_cell, 15 + 15 - 15, cell_gate_scratch);
// Projection.
tensor_utils::MatrixBatchVectorMultiply(
scratch3, projection_weight_ptr, effective_proj_scale_a,
cell_gate_scratch, projection_weight_ptr, effective_proj_scale_a,
effective_proj_scale_b, projection_bias_ptr, n_batch, n_cell, n_output,
output_state_zp, output_ptr);


@@ -62,11 +62,16 @@ inline void LstmStepWithAuxInput(
const float* projection_weights_ptr, const float* projection_bias_ptr,
const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
int n_aux_input, int n_output, int output_batch_leading_dim,
float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
float* forget_gate_scratch, float* cell_gate_scratch,
float* output_gate_scratch, float* output_ptr, Logger* logger,
const std::vector<int>& intermediate_tensor_indexes,
float* output_state_ptr, float* cell_state_ptr, float* scratch0,
float* scratch1, float* scratch2, float* scratch3, float* output_ptr,
Logger* logger, const std::vector<int>& intermediate_tensor_indexes,
ErrorReporter* error_reporter) {
// Make named scratch buffers for the different gates.
float* input_gate_scratch = scratch0;
float* forget_gate_scratch = scratch1;
float* cell_gate_scratch = scratch2;
float* output_gate_scratch = scratch3;
// Since we have already checked that weights are all there or none, we can
// check the existence of only one to get the condition.
const bool use_cifg = (input_to_input_weights_ptr == nullptr);