Integer LSTMs: Name scratch arrays based on the gate they represent. Make naming consistent with the float/hybrid versions.
PiperOrigin-RevId: 317420201
Change-Id: Ia9447e51fce1530e75103c4db3759908592af983
Commit 429c0b423e (parent b1fdd334e7)
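The pattern applied throughout the diff below is the same in every function: keep generic scratchN parameters in the signature and alias them to gate-named pointers on entry, so the integer kernels read like the float/hybrid ones. A minimal, self-contained sketch of that pattern (hypothetical function name and simplified signature, not the actual TFLite kernel):

#include <algorithm>
#include <cstdint>

// Hypothetical, simplified step function illustrating the renaming pattern:
// the caller passes generic scratch buffers and the kernel aliases them to
// the gate each one represents. The aliases are plain pointer copies and
// cost nothing at runtime.
void LstmStepSketch(int n_batch, int n_cell, int16_t* scratch0,
                    int16_t* scratch1, int16_t* scratch2, int16_t* scratch3) {
  // Make named scratch buffers for the different gates. Each buffer holds
  // n_batch * n_cell accumulator values.
  int16_t* input_gate_scratch = scratch0;
  int16_t* forget_gate_scratch = scratch1;
  int16_t* cell_gate_scratch = scratch2;
  int16_t* output_gate_scratch = scratch3;

  // Gate math is then written against the named pointers, e.g. clearing the
  // accumulators before the matrix-vector products.
  std::fill_n(input_gate_scratch, n_batch * n_cell, 0);
  std::fill_n(forget_gate_scratch, n_batch * n_cell, 0);
  std::fill_n(cell_gate_scratch, n_batch * n_cell, 0);
  std::fill_n(output_gate_scratch, n_batch * n_cell, 0);
}

Note that in the 16x8 LstmStepInteger variant near the end of the diff the mapping is not one-to-one (input gate uses scratch5, forget uses scratch2, cell uses scratch3, output uses scratch4), which is exactly where the named aliases help.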
@@ -216,9 +216,8 @@ inline void LstmStepFloat(
const float* projection_weights_ptr, const float* projection_bias_ptr,
const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
int n_aux_input, int n_output, int output_batch_leading_dim,
- float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_gate_scratch,
- float* output_gate_scratch, float* output_ptr) {
+ float* output_state_ptr, float* cell_state_ptr, float* scratch0,
+ float* scratch1, float* scratch2, float* scratch3, float* output_ptr) {
ruy::profiler::ScopeLabel label("LstmStepFloat");
// Since we have already checked that weights are all there or none, we can
// check the existence of only one to the get the condition.
@@ -226,6 +225,12 @@ inline void LstmStepFloat(
const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
const bool use_layer_norm = (forget_layer_norm_coefficients_ptr != nullptr);

+ // Make named scratch buffers for the different gates.
+ float* input_gate_scratch = scratch0;
+ float* forget_gate_scratch = scratch1;
+ float* cell_gate_scratch = scratch2;
+ float* output_gate_scratch = scratch3;
+
// Initialize scratch buffers with bias for regular lstm or initialize with
// zero for layer norm lstm.
if (use_layer_norm) {
@@ -531,9 +536,8 @@ inline void LstmStepHybrid(
const int8_t* projection_weights_ptr, float projection_weights_scale,
const float* projection_bias_ptr, const TfLiteLSTMParams* params,
int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
- int output_batch_leading_dim, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_gate_scratch,
- float* output_gate_scratch, float* scaling_factors,
+ int output_batch_leading_dim, float* scratch0, float* scratch1,
+ float* scratch2, float* scratch3, float* scaling_factors,
float* scaling_factors_scratch, float* recovered_cell_weights,
int8_t* quantized_input_ptr, int8_t* quantized_aux_input_ptr,
int8_t* quantized_output_state_ptr, int8_t* quantized_cell_state_ptr,
@@ -548,6 +552,12 @@ inline void LstmStepHybrid(
const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
const bool use_layer_norm = (forget_layer_norm_coefficients_ptr != nullptr);

+ // Make named scratch buffers for the different gates.
+ float* input_gate_scratch = scratch0;
+ float* forget_gate_scratch = scratch1;
+ float* cell_gate_scratch = scratch2;
+ float* output_gate_scratch = scratch3;
+
// Initialize scratch buffers with bias for regular lstm or initialize with
// zero for layer norm lstm.
if (use_layer_norm) {
@@ -974,12 +984,12 @@ inline void LstmStepHybrid(
//
// Temporary pre-allocated storage for the calculation. Each is of size n_cell *
// n_batch.
- // scratch_0
- // scratch_1
- // scratch_2
- // scratch_3
- // scratch_4
- // scratch_5: this scratch buffer is created purely for optimizing the
+ // scratch0
+ // scratch1
+ // scratch2
+ // scratch3
+ // scratch4
+ // scratch5: this scratch buffer is created purely for optimizing the
// MatrixBatchVectorMultiplyAccumulate.
//
// Outputs:
@@ -1047,10 +1057,15 @@ inline void LstmStepInteger(
const int32_t* projection_effective_bias, int n_batch, int n_cell,
int n_input, int n_output, int8_t* output_state_ptr,
int32_t output_state_zp, int16_t* cell_ptr, int8_t* output_ptr,
- int16_t* scratch_0_ptr, int16_t* scratch_1_ptr, int16_t* scratch_2_ptr,
- int16_t* scratch_3_ptr, int8_t* scratch_4_ptr, int32_t* scratch_5_ptr,
- CpuBackendContext* context) {
+ int16_t* scratch0, int16_t* scratch1, int16_t* scratch2, int16_t* scratch3,
+ int8_t* scratch4, int32_t* scratch5, CpuBackendContext* context) {
ruy::profiler::ScopeLabel label("LstmStepInteger");
+ // Make named scratch buffers for the different gates.
+ int16_t* input_gate_scratch = scratch0;
+ int16_t* forget_gate_scratch = scratch1;
+ int16_t* cell_gate_scratch = scratch2;
+ int16_t* output_gate_scratch = scratch3;
+
// Get hyper parameters.
const bool use_cifg = (input_to_input_weight_ptr == nullptr);
const bool use_peephole = (cell_to_output_weight_ptr != nullptr);
@@ -1072,99 +1087,103 @@ inline void LstmStepInteger(

// Set scratch to 0.
if (!use_cifg) {
- std::fill_n(scratch_0_ptr, n_batch * n_cell, 0);
+ std::fill_n(input_gate_scratch, n_batch * n_cell, 0);
}
- std::fill_n(scratch_1_ptr, n_batch * n_cell, 0);
- std::fill_n(scratch_2_ptr, n_batch * n_cell, 0);
- std::fill_n(scratch_3_ptr, n_batch * n_cell, 0);
+ std::fill_n(forget_gate_scratch, n_batch * n_cell, 0);
+ std::fill_n(cell_gate_scratch, n_batch * n_cell, 0);
+ std::fill_n(output_gate_scratch, n_batch * n_cell, 0);

// Forget gate.
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_ptr, input_to_forget_effective_bias, input_to_forget_weight_ptr,
effective_input_to_forget_scale_a, effective_input_to_forget_scale_b,
- n_batch, n_input, n_cell, 0, scratch_5_ptr, scratch_1_ptr, context);
+ n_batch, n_input, n_cell, 0, scratch5, forget_gate_scratch, context);

tensor_utils::MatrixBatchVectorMultiplyAccumulate(
output_state_ptr, recurrent_to_forget_effective_bias,
recurrent_to_forget_weight_ptr, effective_recurrent_to_forget_scale_a,
effective_recurrent_to_forget_scale_b, n_batch, n_output, n_cell, 0,
- scratch_5_ptr, scratch_1_ptr, context);
+ scratch5, forget_gate_scratch, context);
if (use_peephole) {
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
cell_to_forget_weight_ptr, n_output, cell_ptr, n_batch,
effective_cell_to_forget_scale_a, effective_cell_to_forget_scale_b,
- scratch_1_ptr);
+ forget_gate_scratch);
}

if (use_layer_norm) {
tensor_utils::ApplyLayerNorm(
- scratch_1_ptr, layer_norm_forget_weight_ptr, forget_gate_bias_ptr,
+ forget_gate_scratch, layer_norm_forget_weight_ptr, forget_gate_bias_ptr,
layer_norm_forget_scale_a, layer_norm_forget_scale_b,
- forget_variance_guard, n_batch, n_cell, scratch_1_ptr);
+ forget_variance_guard, n_batch, n_cell, forget_gate_scratch);
}

- tensor_utils::ApplySigmoid(scratch_1_ptr, n_batch, n_cell, scratch_1_ptr);
+ tensor_utils::ApplySigmoid(forget_gate_scratch, n_batch, n_cell,
+ forget_gate_scratch);

- // Modulation gate.
+ // Cell gate.
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_ptr, input_to_cell_effective_bias, input_to_cell_weight_ptr,
effective_input_to_cell_scale_a, effective_input_to_cell_scale_b, n_batch,
- n_input, n_cell, 0, scratch_5_ptr, scratch_2_ptr, context);
+ n_input, n_cell, 0, scratch5, cell_gate_scratch, context);

tensor_utils::MatrixBatchVectorMultiplyAccumulate(
output_state_ptr, recurrent_to_cell_effective_bias,
recurrent_to_cell_weight_ptr, effective_recurrent_to_cell_scale_a,
effective_recurrent_to_cell_scale_b, n_batch, n_output, n_cell, 0,
- scratch_5_ptr, scratch_2_ptr, context);
+ scratch5, cell_gate_scratch, context);

if (use_layer_norm) {
- tensor_utils::ApplyLayerNorm(scratch_2_ptr, layer_norm_cell_weight_ptr,
+ tensor_utils::ApplyLayerNorm(cell_gate_scratch, layer_norm_cell_weight_ptr,
cell_gate_bias_ptr, layer_norm_cell_scale_a,
layer_norm_cell_scale_b, cell_variance_guard,
- n_batch, n_cell, scratch_2_ptr);
+ n_batch, n_cell, cell_gate_scratch);
}

- tensor_utils::ApplyTanh(3, scratch_2_ptr, n_batch, n_cell, scratch_2_ptr);
+ tensor_utils::ApplyTanh(3, cell_gate_scratch, n_batch, n_cell,
+ cell_gate_scratch);

// Input gate.
if (use_cifg) {
- tensor_utils::Sub1Vector(scratch_1_ptr, n_batch * n_cell, scratch_0_ptr);
+ tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
+ input_gate_scratch);
} else {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_ptr, input_to_input_effective_bias, input_to_input_weight_ptr,
effective_input_to_input_scale_a, effective_input_to_input_scale_b,
- n_batch, n_input, n_cell, 0, scratch_5_ptr, scratch_0_ptr, context);
+ n_batch, n_input, n_cell, 0, scratch5, input_gate_scratch, context);

tensor_utils::MatrixBatchVectorMultiplyAccumulate(
output_state_ptr, recurrent_to_input_effective_bias,
recurrent_to_input_weight_ptr, effective_recurrent_to_input_scale_a,
effective_recurrent_to_input_scale_b, n_batch, n_output, n_cell, 0,
- scratch_5_ptr, scratch_0_ptr, context);
+ scratch5, input_gate_scratch, context);
if (use_peephole) {
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
cell_to_input_weight_ptr, n_output, cell_ptr, n_batch,
effective_cell_to_input_scale_a, effective_cell_to_input_scale_b,
- scratch_0_ptr);
+ input_gate_scratch);
}

if (use_layer_norm) {
tensor_utils::ApplyLayerNorm(
- scratch_0_ptr, layer_norm_input_weight_ptr, input_gate_bias_ptr,
+ input_gate_scratch, layer_norm_input_weight_ptr, input_gate_bias_ptr,
layer_norm_input_scale_a, layer_norm_input_scale_b,
- input_variance_guard, n_batch, n_cell, scratch_0_ptr);
+ input_variance_guard, n_batch, n_cell, input_gate_scratch);
}
- tensor_utils::ApplySigmoid(scratch_0_ptr, n_batch, n_cell, scratch_0_ptr);
+ tensor_utils::ApplySigmoid(input_gate_scratch, n_batch, n_cell,
+ input_gate_scratch);
}

// New cell.
- tensor_utils::CwiseMul(scratch_1_ptr, cell_ptr, n_batch, n_cell, 15,
- scratch_1_ptr);
+ tensor_utils::CwiseMul(forget_gate_scratch, cell_ptr, n_batch, n_cell, 15,
+ forget_gate_scratch);

- tensor_utils::CwiseMul(scratch_0_ptr, scratch_2_ptr, n_batch, n_cell,
- 30 + cell_scale, scratch_2_ptr);
+ tensor_utils::CwiseMul(input_gate_scratch, cell_gate_scratch, n_batch, n_cell,
+ 30 + cell_scale, cell_gate_scratch);

- tensor_utils::CwiseAdd(scratch_1_ptr, scratch_2_ptr, n_batch, n_cell,
- cell_ptr);
+ tensor_utils::CwiseAdd(forget_gate_scratch, cell_gate_scratch, n_batch,
+ n_cell, cell_ptr);

if (quantized_cell_clip > 0) {
tensor_utils::CwiseClipping(cell_ptr, quantized_cell_clip, n_batch, n_cell);
@@ -1174,49 +1193,50 @@ inline void LstmStepInteger(
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_ptr, input_to_output_effective_bias, input_to_output_weight_ptr,
effective_input_to_output_scale_a, effective_input_to_output_scale_b,
- n_batch, n_input, n_cell, 0, scratch_5_ptr, scratch_3_ptr, context);
+ n_batch, n_input, n_cell, 0, scratch5, output_gate_scratch, context);

tensor_utils::MatrixBatchVectorMultiplyAccumulate(
output_state_ptr, recurrent_to_output_effective_bias,
recurrent_to_output_weight_ptr, effective_recurrent_to_output_scale_a,
effective_recurrent_to_output_scale_b, n_batch, n_output, n_cell, 0,
- scratch_5_ptr, scratch_3_ptr, context);
+ scratch5, output_gate_scratch, context);
if (use_peephole) {
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
cell_to_output_weight_ptr, n_output, cell_ptr, n_batch,
effective_cell_to_output_scale_a, effective_cell_to_output_scale_b,
- scratch_3_ptr);
+ output_gate_scratch);
}

if (use_layer_norm) {
tensor_utils::ApplyLayerNorm(
- scratch_3_ptr, layer_norm_output_weight_ptr, output_gate_bias_ptr,
+ output_gate_scratch, layer_norm_output_weight_ptr, output_gate_bias_ptr,
layer_norm_output_scale_a, layer_norm_output_scale_b,
- output_variance_guard, n_batch, n_cell, scratch_3_ptr);
+ output_variance_guard, n_batch, n_cell, output_gate_scratch);
}

- tensor_utils::ApplySigmoid(scratch_3_ptr, n_batch, n_cell, scratch_3_ptr);
+ tensor_utils::ApplySigmoid(output_gate_scratch, n_batch, n_cell,
+ output_gate_scratch);

// Hidden.
tensor_utils::ApplyTanh(15 + cell_scale, cell_ptr, n_batch, n_cell,
- scratch_0_ptr);
+ input_gate_scratch);

- tensor_utils::CwiseMul(scratch_3_ptr, scratch_0_ptr, effective_hidden_scale_a,
- effective_hidden_scale_b, n_batch, n_cell, hidden_zp,
- scratch_4_ptr);
+ tensor_utils::CwiseMul(output_gate_scratch, input_gate_scratch,
+ effective_hidden_scale_a, effective_hidden_scale_b,
+ n_batch, n_cell, hidden_zp, scratch4);
// Projection.
if (use_projection) {
std::fill_n(output_ptr, n_batch * n_output, 0);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- scratch_4_ptr, projection_effective_bias, projection_weight_ptr,
+ scratch4, projection_effective_bias, projection_weight_ptr,
effective_proj_scale_a, effective_proj_scale_b, n_batch, n_cell,
- n_output, output_state_zp, scratch_5_ptr, output_ptr, context);
+ n_output, output_state_zp, scratch5, output_ptr, context);
if (quantized_proj_clip > 0) {
tensor_utils::CwiseClipping(output_ptr, quantized_proj_clip, n_batch,
n_output);
}
} else {
- std::copy_n(scratch_4_ptr, n_batch * n_output, output_ptr);
+ std::copy_n(scratch4, n_batch * n_output, output_ptr);
}
std::copy_n(output_ptr, n_batch * n_output, output_state_ptr);
}
@@ -1300,14 +1320,14 @@ inline void LstmStepInteger(
//
// Temporary pre-allocated storage for the calculation. Each is of size n_cell *
// n_batch.
- // scratch_0
- // scratch_1
- // scratch_2
- // scratch_3
- // scratch_4
- // scratch_5
- // scratch_6
- // scratch_7
+ // scratch0
+ // scratch1
+ // scratch2
+ // scratch3
+ // scratch4
+ // scratch5
+ // scratch6
+ // scratch7
//
// Outputs:
// output_state_ptr - size 'n_batch * n_output'
@@ -1369,6 +1389,12 @@ void LstmStepInteger(
int8_t* scratch0, int8_t* scratch1, int16_t* scratch2, int16_t* scratch3,
int16_t* scratch4, int16_t* scratch5, int16_t* scratch6,
int16_t* scratch7) {
+ // Make named scratch buffers for the different gates.
+ int16_t* input_gate_scratch = scratch5;
+ int16_t* forget_gate_scratch = scratch2;
+ int16_t* cell_gate_scratch = scratch3;
+ int16_t* output_gate_scratch = scratch4;
+
// Forget gate.
std::fill_n(scratch0, n_batch * n_cell, 0);
std::fill_n(scratch1, n_batch * n_cell, 0);
@@ -1386,16 +1412,17 @@ void LstmStepInteger(
tensor_utils::TwoGateSaturationgAdd(
scratch0, intermediate_zp[4], scratch1, intermediate_zp[5],
intermediate_scale_a[2], intermediate_scale_b[2], intermediate_scale_a[3],
- intermediate_scale_b[3], n_batch, n_cell, scratch2);
+ intermediate_scale_b[3], n_batch, n_cell, forget_gate_scratch);

// Forget gate layer norm.
tensor_utils::ApplyLayerNormFloat(
- scratch2, layer_norm_forget_weight_ptr, layer_norm_forget_scale_a,
- layer_norm_forget_scale_b, forget_gate_bias_ptr, n_batch, n_cell,
- scratch2);
+ forget_gate_scratch, layer_norm_forget_weight_ptr,
+ layer_norm_forget_scale_a, layer_norm_forget_scale_b,
+ forget_gate_bias_ptr, n_batch, n_cell, forget_gate_scratch);

// Forget gate sigmoid.
- tensor_utils::ApplySigmoidFloat(scratch2, n_batch, n_cell, scratch2);
+ tensor_utils::ApplySigmoidFloat(forget_gate_scratch, n_batch, n_cell,
+ forget_gate_scratch);

// Update gate.
std::fill_n(scratch0, n_batch * n_cell, 0);
@@ -1413,15 +1440,17 @@ void LstmStepInteger(
tensor_utils::TwoGateSaturationgAdd(
scratch0, intermediate_zp[7], scratch1, intermediate_zp[8],
intermediate_scale_a[4], intermediate_scale_b[4], intermediate_scale_a[5],
- intermediate_scale_b[5], n_batch, n_cell, scratch3);
+ intermediate_scale_b[5], n_batch, n_cell, cell_gate_scratch);

- // Update gate with layer norm.
+ // Update gate layer norm.
tensor_utils::ApplyLayerNormFloat(
- scratch3, layer_norm_cell_weight_ptr, layer_norm_cell_scale_a,
- layer_norm_cell_scale_b, cell_gate_bias_ptr, n_batch, n_cell, scratch3);
+ cell_gate_scratch, layer_norm_cell_weight_ptr, layer_norm_cell_scale_a,
+ layer_norm_cell_scale_b, cell_gate_bias_ptr, n_batch, n_cell,
+ cell_gate_scratch);

// Update gate tanh.
- tensor_utils::ApplyTanhFloat(scratch3, n_batch, n_cell, -12, scratch3);
+ tensor_utils::ApplyTanhFloat(cell_gate_scratch, n_batch, n_cell, -12,
+ cell_gate_scratch);

// Output gate.
std::fill_n(scratch0, n_batch * n_cell, 0);
@@ -1440,26 +1469,28 @@ void LstmStepInteger(
tensor_utils::TwoGateSaturationgAdd(
scratch0, intermediate_zp[10], scratch1, intermediate_zp[11],
intermediate_scale_a[6], intermediate_scale_b[6], intermediate_scale_a[7],
- intermediate_scale_b[7], n_batch, n_cell, scratch4);
+ intermediate_scale_b[7], n_batch, n_cell, output_gate_scratch);

// Output gate with layer norm.
tensor_utils::ApplyLayerNormFloat(
- scratch4, layer_norm_output_weight_ptr, layer_norm_output_scale_a,
- layer_norm_output_scale_b, output_gate_bias_ptr, n_batch, n_cell,
- scratch4);
+ output_gate_scratch, layer_norm_output_weight_ptr,
+ layer_norm_output_scale_a, layer_norm_output_scale_b,
+ output_gate_bias_ptr, n_batch, n_cell, output_gate_scratch);

// Output gate sigmoid.
- tensor_utils::ApplySigmoidFloat(scratch4, n_batch, n_cell, scratch4);
+ tensor_utils::ApplySigmoidFloat(output_gate_scratch, n_batch, n_cell,
+ output_gate_scratch);

// Input gate with cifg
- tensor_utils::Sub1Vector(scratch2, n_batch * n_cell, scratch5);
+ tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
+ input_gate_scratch);

// New cell.
- tensor_utils::CwiseMul(scratch2, cell_ptr, n_batch, n_cell, 15 + 15 - 15,
- scratch6);
+ tensor_utils::CwiseMul(forget_gate_scratch, cell_ptr, n_batch, n_cell,
+ 15 + 15 - 15, scratch6);

- tensor_utils::CwiseMul(scratch5, scratch3, n_batch, n_cell, 15 + 15 - 15,
- scratch7);
+ tensor_utils::CwiseMul(input_gate_scratch, cell_gate_scratch, n_batch, n_cell,
+ 15 + 15 - 15, scratch7);

tensor_utils::CwiseAdd(scratch6, scratch7, n_batch, n_cell, cell_ptr);

@@ -1468,15 +1499,16 @@ void LstmStepInteger(
}

// Cell to hidden.
- tensor_utils::ApplyTanhFloat(cell_ptr, n_batch, n_cell, -15, scratch2);
+ tensor_utils::ApplyTanhFloat(cell_ptr, n_batch, n_cell, -15,
+ forget_gate_scratch);

std::vector<int16_t> hidden(n_batch * n_cell);
- tensor_utils::CwiseMul(scratch4, scratch2, n_batch, n_cell, 15 + 15 - 15,
- scratch3);
+ tensor_utils::CwiseMul(output_gate_scratch, forget_gate_scratch, n_batch,
+ n_cell, 15 + 15 - 15, cell_gate_scratch);

// Projection.
tensor_utils::MatrixBatchVectorMultiply(
- scratch3, projection_weight_ptr, effective_proj_scale_a,
+ cell_gate_scratch, projection_weight_ptr, effective_proj_scale_a,
effective_proj_scale_b, projection_bias_ptr, n_batch, n_cell, n_output,
output_state_zp, output_ptr);

@@ -62,11 +62,16 @@ inline void LstmStepWithAuxInput(
const float* projection_weights_ptr, const float* projection_bias_ptr,
const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
int n_aux_input, int n_output, int output_batch_leading_dim,
- float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_gate_scratch,
- float* output_gate_scratch, float* output_ptr, Logger* logger,
- const std::vector<int>& intermediate_tensor_indexes,
+ float* output_state_ptr, float* cell_state_ptr, float* scratch0,
+ float* scratch1, float* scratch2, float* scratch3, float* output_ptr,
+ Logger* logger, const std::vector<int>& intermediate_tensor_indexes,
ErrorReporter* error_reporter) {
+ // Make named scratch buffers for the different gates.
+ float* input_gate_scratch = scratch0;
+ float* forget_gate_scratch = scratch1;
+ float* cell_gate_scratch = scratch2;
+ float* output_gate_scratch = scratch3;
+
// Since we have already checked that weights are all there or none, we can
// check the existence of only one to the get the condition.
const bool use_cifg = (input_to_input_weights_ptr == nullptr);