All LSTM implementations: Rename cell_scratch to cell_gate_scratch, and cell_bias_ptr to cell_gate_bias_ptr to better reflect what those arrays are.

Note that this is not the same thing as the LSTM cell "state": it is the layer/gate that computes the candidate update. The cell state depends on the input, forget, and cell gates; these arrays hold the output of, and the bias for, that last gate.
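
For readers less familiar with the terminology, here is a minimal, self-contained scalar sketch (illustration only, with made-up numbers and a local Sigmoid helper; it is not the TFLite kernel code) of what the renamed arrays hold: cell_gate_scratch carries the cell gate's candidate update and cell_gate_bias_ptr its bias, while cell_state_ptr carries the persistent state that the input, forget, and cell gates combine to update.

#include <cmath>
#include <cstdio>

namespace {
float Sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }
}  // namespace

int main() {
  // Pre-activation gate values for one cell of one batch entry (made-up numbers).
  const float input_gate_pre = 0.2f, forget_gate_pre = 1.0f;
  const float cell_gate_pre = -0.3f, output_gate_pre = 0.5f;
  const float cell_gate_bias = 0.1f;  // plays the role of cell_gate_bias_ptr[i]
  float cell_state = 0.4f;            // plays the role of cell_state_ptr[i]

  const float input_gate = Sigmoid(input_gate_pre);
  const float forget_gate = Sigmoid(forget_gate_pre);
  // cell_gate_scratch: bias added, then the activation (tanh by default).
  const float cell_gate = std::tanh(cell_gate_pre + cell_gate_bias);
  const float output_gate = Sigmoid(output_gate_pre);

  // The cell *state* update combines the input, forget, and cell gates.
  cell_state = forget_gate * cell_state + input_gate * cell_gate;
  const float output = output_gate * std::tanh(cell_state);

  std::printf("cell_gate=%.4f cell_state=%.4f output=%.4f\n", cell_gate,
              cell_state, output);
  return 0;
}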

PiperOrigin-RevId: 316720132
Change-Id: I71c370dabd27f776987e061b9393022c775589c9
Author: Robert David (committed by TensorFlower Gardener)
Date:   2020-06-16 11:19:41 -07:00
Commit: ed557008d6 (parent 2d50164bdb)
2 changed files with 114 additions and 103 deletions

First changed file:

@@ -212,13 +212,13 @@ inline void LstmStepFloat(
const float* cell_layer_norm_coefficients_ptr,
const float* output_layer_norm_coefficients_ptr,
const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
- const float* cell_bias_ptr, const float* output_gate_bias_ptr,
+ const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
const float* projection_weights_ptr, const float* projection_bias_ptr,
const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
int n_aux_input, int n_output, int output_batch_leading_dim,
float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
- float* output_ptr) {
+ float* forget_gate_scratch, float* cell_gate_scratch,
+ float* output_gate_scratch, float* output_ptr) {
ruy::profiler::ScopeLabel label("LstmStepFloat");
// Since we have already checked that weights are all there or none, we can
// check the existence of only one to the get the condition.
@@ -233,7 +233,7 @@ inline void LstmStepFloat(
std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
}
std::fill_n(forget_gate_scratch, n_cell * n_batch, 0.0f);
- std::fill_n(cell_scratch, n_cell * n_batch, 0.0f);
+ std::fill_n(cell_gate_scratch, n_cell * n_batch, 0.0f);
std::fill_n(output_gate_scratch, n_cell * n_batch, 0.0f);
} else {
if (!use_cifg) {
@@ -242,8 +242,8 @@ inline void LstmStepFloat(
}
tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
forget_gate_scratch);
- tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
- cell_scratch);
+ tensor_utils::VectorBatchVectorAssign(cell_gate_bias_ptr, n_cell, n_batch,
+ cell_gate_scratch);
tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
output_gate_scratch);
}
@@ -262,7 +262,7 @@ inline void LstmStepFloat(
forget_gate_scratch);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_cell_weights_ptr, n_cell, n_input, input_ptr, n_batch,
- cell_scratch);
+ cell_gate_scratch);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_output_weights_ptr, n_cell, n_input, input_ptr, n_batch,
output_gate_scratch);
@@ -283,7 +283,7 @@ inline void LstmStepFloat(
n_batch, forget_gate_scratch);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
- n_batch, cell_scratch);
+ n_batch, cell_gate_scratch);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_output_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
n_batch, output_gate_scratch);
@@ -300,7 +300,7 @@ inline void LstmStepFloat(
n_batch, forget_gate_scratch);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, cell_scratch);
+ n_batch, cell_gate_scratch);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
n_batch, output_gate_scratch);
@@ -347,24 +347,26 @@ inline void LstmStepFloat(
tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
n_batch * n_cell, cell_state_ptr);
if (use_layer_norm) {
- tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
- n_batch);
+ tensor_utils::MeanStddevNormalization(cell_gate_scratch, cell_gate_scratch,
+ n_cell, n_batch);
tensor_utils::VectorBatchVectorCwiseProduct(
- cell_layer_norm_coefficients_ptr, n_cell, cell_scratch, n_batch,
- cell_scratch);
- tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
- cell_scratch);
+ cell_layer_norm_coefficients_ptr, n_cell, cell_gate_scratch, n_batch,
+ cell_gate_scratch);
+ tensor_utils::VectorBatchVectorAdd(cell_gate_bias_ptr, n_cell, n_batch,
+ cell_gate_scratch);
}
- tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
- params->activation, cell_scratch);
+ tensor_utils::ApplyActivationToVector(cell_gate_scratch, n_batch * n_cell,
+ params->activation, cell_gate_scratch);
if (use_cifg) {
tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
forget_gate_scratch);
tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ cell_gate_scratch, forget_gate_scratch, n_batch * n_cell,
+ cell_state_ptr);
} else {
tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ cell_gate_scratch, input_gate_scratch, n_batch * n_cell,
+ cell_state_ptr);
}
if (params->cell_clip > 0.0) {
tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
@@ -389,8 +391,8 @@ inline void LstmStepFloat(
tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
output_gate_scratch);
tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
- params->activation, cell_scratch);
- tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+ params->activation, cell_gate_scratch);
+ tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_gate_scratch,
n_batch * n_cell, output_gate_scratch);
const bool use_projection_weight = (projection_weights_ptr != nullptr);
@@ -525,19 +527,19 @@ inline void LstmStepHybrid(
const float* cell_layer_norm_coefficients_ptr,
const float* output_layer_norm_coefficients_ptr,
const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
- const float* cell_bias_ptr, const float* output_gate_bias_ptr,
+ const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
const int8_t* projection_weights_ptr, float projection_weights_scale,
const float* projection_bias_ptr, const TfLiteLSTMParams* params,
int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
int output_batch_leading_dim, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
- float* scaling_factors, float* scaling_factors_scratch,
- float* recovered_cell_weights, int8_t* quantized_input_ptr,
- int8_t* quantized_aux_input_ptr, int8_t* quantized_output_state_ptr,
- int8_t* quantized_cell_state_ptr, float* output_state_ptr,
- float* cell_state_ptr, int32_t* accum_scratch_ptr, float* output_ptr,
- int32_t* zero_points, int32_t* row_sums, int row_sums_size,
- bool* compute_row_sums, bool asymmetric_quantize_inputs,
+ float* forget_gate_scratch, float* cell_gate_scratch,
+ float* output_gate_scratch, float* scaling_factors,
+ float* scaling_factors_scratch, float* recovered_cell_weights,
+ int8_t* quantized_input_ptr, int8_t* quantized_aux_input_ptr,
+ int8_t* quantized_output_state_ptr, int8_t* quantized_cell_state_ptr,
+ float* output_state_ptr, float* cell_state_ptr, int32_t* accum_scratch_ptr,
+ float* output_ptr, int32_t* zero_points, int32_t* row_sums,
+ int row_sums_size, bool* compute_row_sums, bool asymmetric_quantize_inputs,
CpuBackendContext* context) {
ruy::profiler::ScopeLabel label("LstmStepHybrid");
// Since we have already checked that weights are all there or none, we
@@ -553,7 +555,7 @@ inline void LstmStepHybrid(
std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
}
std::fill_n(forget_gate_scratch, n_cell * n_batch, 0.0f);
- std::fill_n(cell_scratch, n_cell * n_batch, 0.0f);
+ std::fill_n(cell_gate_scratch, n_cell * n_batch, 0.0f);
std::fill_n(output_gate_scratch, n_cell * n_batch, 0.0f);
} else {
if (!use_cifg) {
@@ -562,8 +564,8 @@ inline void LstmStepHybrid(
}
tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
forget_gate_scratch);
- tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
- cell_scratch);
+ tensor_utils::VectorBatchVectorAssign(cell_gate_bias_ptr, n_cell, n_batch,
+ cell_gate_scratch);
tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
output_gate_scratch);
}
@@ -657,7 +659,8 @@ inline void LstmStepHybrid(
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr,
- input_to_cell_weights_scale, scaling_factors, n_batch, cell_scratch,
+ input_to_cell_weights_scale, scaling_factors, n_batch,
+ cell_gate_scratch,
/*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
input_to_cell_row_sums, compute_row_sums, scaling_factors_scratch,
context);
@@ -699,9 +702,10 @@ inline void LstmStepHybrid(
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_cell_weights_ptr, n_cell, n_aux_input,
quantized_aux_input_ptr, aux_input_to_cell_weights_scale,
- scaling_factors, n_batch, cell_scratch, /*per_channel_scale=*/nullptr,
- zero_points, accum_scratch_ptr, aux_input_to_cell_row_sums,
- compute_row_sums, scaling_factors_scratch, context);
+ scaling_factors, n_batch, cell_gate_scratch,
+ /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+ aux_input_to_cell_row_sums, compute_row_sums, scaling_factors_scratch,
+ context);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_output_weights_ptr, n_cell, n_aux_input,
@@ -739,9 +743,10 @@ inline void LstmStepHybrid(
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_cell_weights_ptr, n_cell, n_output,
quantized_output_state_ptr, recurrent_to_cell_weights_scale,
- scaling_factors, n_batch, cell_scratch, /*per_channel_scale=*/nullptr,
- zero_points, accum_scratch_ptr, recurrent_to_cell_row_sums,
- compute_row_sums, scaling_factors_scratch, context);
+ scaling_factors, n_batch, cell_gate_scratch,
+ /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+ recurrent_to_cell_row_sums, compute_row_sums, scaling_factors_scratch,
+ context);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_output_weights_ptr, n_cell, n_output,
@@ -800,24 +805,26 @@ inline void LstmStepHybrid(
tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
n_batch * n_cell, cell_state_ptr);
if (use_layer_norm) {
- tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
- n_batch);
+ tensor_utils::MeanStddevNormalization(cell_gate_scratch, cell_gate_scratch,
+ n_cell, n_batch);
tensor_utils::VectorBatchVectorCwiseProduct(
- cell_layer_norm_coefficients_ptr, n_cell, cell_scratch, n_batch,
- cell_scratch);
- tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
- cell_scratch);
+ cell_layer_norm_coefficients_ptr, n_cell, cell_gate_scratch, n_batch,
+ cell_gate_scratch);
+ tensor_utils::VectorBatchVectorAdd(cell_gate_bias_ptr, n_cell, n_batch,
+ cell_gate_scratch);
}
- tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
- params->activation, cell_scratch);
+ tensor_utils::ApplyActivationToVector(cell_gate_scratch, n_batch * n_cell,
+ params->activation, cell_gate_scratch);
if (use_cifg) {
tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
forget_gate_scratch);
tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ cell_gate_scratch, forget_gate_scratch, n_batch * n_cell,
+ cell_state_ptr);
} else {
tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ cell_gate_scratch, input_gate_scratch, n_batch * n_cell,
+ cell_state_ptr);
}
if (params->cell_clip > 0.0) {
tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
@@ -845,8 +852,8 @@ inline void LstmStepHybrid(
tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
output_gate_scratch);
tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
- params->activation, cell_scratch);
- tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+ params->activation, cell_gate_scratch);
+ tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_gate_scratch,
n_batch * n_cell, output_gate_scratch);
const bool use_projection_weight = (projection_weights_ptr != nullptr);
@@ -940,7 +947,7 @@ inline void LstmStepHybrid(
// Gate biases of size 'n_cell':
// input_bias_ptr - optional
// forget_bias_ptr
- // cell_bias_ptr
+ // cell_gate_bias_ptr
// output_bias_ptr
//
// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
@@ -1028,7 +1035,7 @@ inline void LstmStepInteger(
const int16_t* layer_norm_output_weight_ptr,
int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
const int32_t* input_bias_ptr, const int32_t* forget_bias_ptr,
- const int32_t* cell_bias_ptr, const int32_t* output_bias_ptr,
+ const int32_t* cell_gate_bias_ptr, const int32_t* output_bias_ptr,
int16_t quantized_cell_clip, int8_t quantized_proj_clip, int32_t cell_scale,
int32_t input_variance_guard, int32_t forget_variance_guard,
int32_t cell_variance_guard, int32_t output_variance_guard,
@@ -1115,7 +1122,7 @@ inline void LstmStepInteger(
if (use_layer_norm) {
tensor_utils::ApplyLayerNorm(scratch_2_ptr, layer_norm_cell_weight_ptr,
- cell_bias_ptr, layer_norm_cell_scale_a,
+ cell_gate_bias_ptr, layer_norm_cell_scale_a,
layer_norm_cell_scale_b, cell_variance_guard,
n_batch, n_cell, scratch_2_ptr);
}
@@ -1266,7 +1273,7 @@ inline void LstmStepInteger(
// Gate biases of size 'n_cell':
// input_bias_ptr - optional
// forget_bias_ptr
- // cell_bias_ptr
+ // cell_gate_bias_ptr
// output_bias_ptr
//
// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
@@ -1355,7 +1362,7 @@ void LstmStepInteger(
const int16_t* layer_norm_output_weight_ptr,
int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
const int32_t* input_bias_ptr, const int32_t* forget_bias_ptr,
- const int32_t* cell_bias_ptr, const int32_t* output_bias_ptr,
+ const int32_t* cell_gate_bias_ptr, const int32_t* output_bias_ptr,
const int32_t* proj_bias_ptr, const TfLiteLSTMParams* params,
const int32_t* intermediate_scale_a, const int32_t* intermediate_scale_b,
const int32_t* intermediate_zp, int32 quantized_cell_clip,
@@ -1413,7 +1420,7 @@ void LstmStepInteger(
// Update gate with layer norm.
tensor_utils::ApplyLayerNormFloat(
scratch3, layer_norm_cell_weight_ptr, layer_norm_cell_scale_a,
- layer_norm_cell_scale_b, cell_bias_ptr, n_batch, n_cell, scratch3);
+ layer_norm_cell_scale_b, cell_gate_bias_ptr, n_batch, n_cell, scratch3);
// Update gate tanh.
tensor_utils::ApplyTanhFloat(scratch3, n_batch, n_cell, -12, scratch3);
@@ -1538,16 +1545,16 @@ TfLiteStatus EvalFloat(
// Index the scratch buffers pointers to the global scratch buffer.
float* scratch_buffer_ptr = GetTensorData<float>(scratch_buffer);
float* input_gate_scratch = nullptr;
- float* cell_scratch = nullptr;
+ float* cell_gate_scratch = nullptr;
float* forget_gate_scratch = nullptr;
float* output_gate_scratch = nullptr;
if (use_cifg) {
- cell_scratch = scratch_buffer_ptr;
+ cell_gate_scratch = scratch_buffer_ptr;
forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
} else {
input_gate_scratch = scratch_buffer_ptr;
- cell_scratch = scratch_buffer_ptr + n_cell * n_batch;
+ cell_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch;
}
@@ -1599,7 +1606,8 @@ TfLiteStatus EvalFloat(
n_input, aux_input_size, n_output, output_batch_leading_dim,
GetTensorData<float>(activation_state),
GetTensorData<float>(cell_state), input_gate_scratch,
- forget_gate_scratch, cell_scratch, output_gate_scratch, output_ptr);
+ forget_gate_scratch, cell_gate_scratch, output_gate_scratch,
+ output_ptr);
}
} else {
for (int b = 0; b < n_batch; b++) {
@@ -1628,7 +1636,7 @@ TfLiteStatus EvalFloat(
float* input_gate_scratch_ptr =
input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
- float* cell_scratch_ptr = cell_scratch + b * n_cell;
+ float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
LstmStepFloat(
@@ -1659,8 +1667,8 @@ TfLiteStatus EvalFloat(
GetTensorData<float>(projection_bias), params, /*n_batch=*/1,
n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
activation_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
- forget_gate_scratch_ptr, cell_scratch_ptr, output_gate_scratch_ptr,
- output_ptr);
+ forget_gate_scratch_ptr, cell_gate_scratch_ptr,
+ output_gate_scratch_ptr, output_ptr);
}
}
}
@@ -1723,16 +1731,16 @@ TfLiteStatus EvalHybrid(
float* scratch_buffer_ptr = GetTensorData<float>(scratch_buffer);
float* input_gate_scratch = nullptr;
- float* cell_scratch = nullptr;
+ float* cell_gate_scratch = nullptr;
float* forget_gate_scratch = nullptr;
float* output_gate_scratch = nullptr;
if (use_cifg) {
- cell_scratch = scratch_buffer_ptr;
+ cell_gate_scratch = scratch_buffer_ptr;
forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
} else {
input_gate_scratch = scratch_buffer_ptr;
- cell_scratch = scratch_buffer_ptr + n_cell * n_batch;
+ cell_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch;
}
@@ -1805,7 +1813,7 @@ TfLiteStatus EvalHybrid(
GetTensorScale(projection_weights),
GetTensorData<float>(projection_bias), params, n_batch, n_cell,
n_input, aux_input_size, n_output, output_batch_leading_dim,
- input_gate_scratch, forget_gate_scratch, cell_scratch,
+ input_gate_scratch, forget_gate_scratch, cell_gate_scratch,
output_gate_scratch, GetTensorData<float>(scaling_factors),
GetTensorData<float>(prod_scaling_factors),
GetTensorData<float>(recovered_cell_weights),
@@ -1845,7 +1853,7 @@ TfLiteStatus EvalHybrid(
float* input_gate_scratch_ptr =
input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
- float* cell_scratch_ptr = cell_scratch + b * n_cell;
+ float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
LstmStepHybrid(
@@ -1892,8 +1900,8 @@ TfLiteStatus EvalHybrid(
GetTensorData<float>(projection_bias), params,
/*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
output_batch_leading_dim, input_gate_scratch_ptr,
- forget_gate_scratch_ptr, cell_scratch_ptr, output_gate_scratch_ptr,
- GetTensorData<float>(scaling_factors),
+ forget_gate_scratch_ptr, cell_gate_scratch_ptr,
+ output_gate_scratch_ptr, GetTensorData<float>(scaling_factors),
GetTensorData<float>(prod_scaling_factors),
GetTensorData<float>(recovered_cell_weights),
GetTensorData<int8_t>(input_quantized),
@@ -2119,7 +2127,7 @@ TfLiteStatus EvalInteger8x8_8(
GetTensorData<int16_t>(output_layer_norm_coefficients);
const int32_t* input_bias_ptr = GetTensorData<int32_t>(input_gate_bias);
const int32_t* forget_bias_ptr = GetTensorData<int32_t>(forget_gate_bias);
- const int32_t* cell_bias_ptr = GetTensorData<int32_t>(cell_bias);
+ const int32_t* cell_gate_bias_ptr = GetTensorData<int32_t>(cell_bias);
const int32_t* output_bias_ptr = GetTensorData<int32_t>(output_gate_bias);
const int32_t* proj_bias_ptr = GetTensorData<int32_t>(projection_bias);
int16_t* cell_ptr = GetTensorData<int16_t>(cell_state);
@@ -2206,7 +2214,7 @@ TfLiteStatus EvalInteger8x8_8(
integer_lstm_param->layer_norm_output_scale_a,
integer_lstm_param->layer_norm_output_scale_b,
- input_bias_ptr, forget_bias_ptr, cell_bias_ptr, output_bias_ptr,
+ input_bias_ptr, forget_bias_ptr, cell_gate_bias_ptr, output_bias_ptr,
proj_bias_ptr,
params, integer_lstm_param->intermediate_scale_a,

Second changed file:

@@ -58,13 +58,13 @@ inline void LstmStepWithAuxInput(
const float* cell_layer_norm_coefficients_ptr,
const float* output_layer_norm_coefficients_ptr,
const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
- const float* cell_bias_ptr, const float* output_gate_bias_ptr,
+ const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
const float* projection_weights_ptr, const float* projection_bias_ptr,
const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
int n_aux_input, int n_output, int output_batch_leading_dim,
float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
- float* output_ptr, Logger* logger,
+ float* forget_gate_scratch, float* cell_gate_scratch,
+ float* output_gate_scratch, float* output_ptr, Logger* logger,
const std::vector<int>& intermediate_tensor_indexes,
ErrorReporter* error_reporter) {
// Since we have already checked that weights are all there or none, we can
@@ -80,7 +80,7 @@ inline void LstmStepWithAuxInput(
std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
}
std::fill_n(forget_gate_scratch, n_cell * n_batch, 0.0f);
- std::fill_n(cell_scratch, n_cell * n_batch, 0.0f);
+ std::fill_n(cell_gate_scratch, n_cell * n_batch, 0.0f);
std::fill_n(output_gate_scratch, n_cell * n_batch, 0.0f);
} else {
if (!use_cifg) {
@@ -89,8 +89,8 @@ inline void LstmStepWithAuxInput(
}
tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
forget_gate_scratch);
- tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
- cell_scratch);
+ tensor_utils::VectorBatchVectorAssign(cell_gate_bias_ptr, n_cell, n_batch,
+ cell_gate_scratch);
tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
output_gate_scratch);
}
@@ -107,7 +107,7 @@ inline void LstmStepWithAuxInput(
forget_gate_scratch);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_cell_weights_ptr,
n_cell, n_input, input_ptr,
- n_batch, cell_scratch);
+ n_batch, cell_gate_scratch);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_output_weights_ptr, n_cell, n_input, input_ptr, n_batch,
output_gate_scratch);
@@ -125,7 +125,7 @@ inline void LstmStepWithAuxInput(
n_batch, forget_gate_scratch);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
- n_batch, cell_scratch);
+ n_batch, cell_gate_scratch);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_output_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
n_batch, output_gate_scratch);
@@ -142,7 +142,7 @@ inline void LstmStepWithAuxInput(
n_batch, forget_gate_scratch);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, cell_scratch);
+ n_batch, cell_gate_scratch);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
n_batch, output_gate_scratch);
@@ -193,26 +193,28 @@ inline void LstmStepWithAuxInput(
tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
n_batch * n_cell, cell_state_ptr);
if (use_layer_norm) {
- logger->LogTensorValue(intermediate_tensor_indexes[2], cell_scratch,
+ logger->LogTensorValue(intermediate_tensor_indexes[2], cell_gate_scratch,
n_cell * n_batch, error_reporter);
- tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
- n_batch);
+ tensor_utils::MeanStddevNormalization(cell_gate_scratch, cell_gate_scratch,
+ n_cell, n_batch);
tensor_utils::VectorBatchVectorCwiseProduct(
- cell_layer_norm_coefficients_ptr, n_cell, cell_scratch, n_batch,
- cell_scratch);
- tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
- cell_scratch);
+ cell_layer_norm_coefficients_ptr, n_cell, cell_gate_scratch, n_batch,
+ cell_gate_scratch);
+ tensor_utils::VectorBatchVectorAdd(cell_gate_bias_ptr, n_cell, n_batch,
+ cell_gate_scratch);
}
- tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
- params->activation, cell_scratch);
+ tensor_utils::ApplyActivationToVector(cell_gate_scratch, n_batch * n_cell,
+ params->activation, cell_gate_scratch);
if (use_cifg) {
tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
forget_gate_scratch);
tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ cell_gate_scratch, forget_gate_scratch, n_batch * n_cell,
+ cell_state_ptr);
} else {
tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ cell_gate_scratch, input_gate_scratch, n_batch * n_cell,
+ cell_state_ptr);
}
if (params->cell_clip > 0.0) {
tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
@@ -239,8 +241,8 @@ inline void LstmStepWithAuxInput(
tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
output_gate_scratch);
tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
- params->activation, cell_scratch);
- tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+ params->activation, cell_gate_scratch);
+ tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_gate_scratch,
n_batch * n_cell, output_gate_scratch);
logger->LogTensorValue(intermediate_tensor_indexes[4], output_gate_scratch,
@@ -329,16 +331,16 @@ TfLiteStatus EvalFloat(
// Index the scratch buffers pointers to the global scratch buffer.
float* scratch_buffer_ptr = GetTensorData<float>(scratch_buffer);
float* input_gate_scratch = nullptr;
- float* cell_scratch = nullptr;
+ float* cell_gate_scratch = nullptr;
float* forget_gate_scratch = nullptr;
float* output_gate_scratch = nullptr;
if (use_cifg) {
- cell_scratch = scratch_buffer_ptr;
+ cell_gate_scratch = scratch_buffer_ptr;
forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
} else {
input_gate_scratch = scratch_buffer_ptr;
- cell_scratch = scratch_buffer_ptr + n_cell * n_batch;
+ cell_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch;
}
@@ -390,7 +392,7 @@ TfLiteStatus EvalFloat(
n_input, aux_input_size, n_output, output_batch_leading_dim,
GetTensorData<float>(activation_state),
GetTensorData<float>(cell_state), input_gate_scratch,
- forget_gate_scratch, cell_scratch, output_gate_scratch,
+ forget_gate_scratch, cell_gate_scratch, output_gate_scratch,
output_ptr_time, logger, intermediate_tensor_indexes, error_reporter);
}
} else {
@@ -420,7 +422,7 @@ TfLiteStatus EvalFloat(
float* input_gate_scratch_ptr =
input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
- float* cell_scratch_ptr = cell_scratch + b * n_cell;
+ float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
LstmStepWithAuxInput(
@@ -451,8 +453,9 @@ TfLiteStatus EvalFloat(
GetTensorData<float>(projection_bias), params, /*n_batch=*/1,
n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
activation_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
- forget_gate_scratch_ptr, cell_scratch_ptr, output_gate_scratch_ptr,
- output_ptr, logger, intermediate_tensor_indexes, error_reporter);
+ forget_gate_scratch_ptr, cell_gate_scratch_ptr,
+ output_gate_scratch_ptr, output_ptr, logger,
+ intermediate_tensor_indexes, error_reporter);
}
}
}