All LSTM implementations: Rename cell_scratch to cell_gate_scratch, and cell_bias_ptr to cell_gate_bias_ptr to better reflect what those arrays are.
Note that this is not the same thing as the LSTM cell "state"; it is the gate that computes the candidate update. The cell state depends on the input, forget, and cell gates; these arrays hold the output and the bias of that last gate.

PiperOrigin-RevId: 316720132
Change-Id: I71c370dabd27f776987e061b9393022c775589c9
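For context, a minimal self-contained sketch (plain C++, not the actual TFLite kernel code) of why the renamed buffer is a gate output rather than the state: the function name and loop below are illustrative only, while `cell_gate_scratch` mirrors the renamed buffer.

#include <cmath>
#include <vector>

// Illustrative LSTM cell-state update (no CIFG, layer norm, or projection).
// cell_gate_scratch holds the output of the cell ("update") gate g_t,
// which is distinct from the cell state c_t that it helps update.
void LstmCellUpdateSketch(const std::vector<float>& input_gate_scratch,   // i_t
                          const std::vector<float>& forget_gate_scratch,  // f_t
                          std::vector<float>& cell_gate_scratch,          // g_t (pre-tanh)
                          std::vector<float>& cell_state) {               // c_t
  for (size_t i = 0; i < cell_state.size(); ++i) {
    // Cell gate: tanh of the accumulated (weights * inputs + cell_gate_bias),
    // which the caller is assumed to have placed in cell_gate_scratch.
    cell_gate_scratch[i] = std::tanh(cell_gate_scratch[i]);
    // Cell state: c_t = f_t * c_{t-1} + i_t * g_t. The state depends on the
    // input, forget, and cell gates; cell_gate_scratch is only the last one.
    cell_state[i] = forget_gate_scratch[i] * cell_state[i] +
                    input_gate_scratch[i] * cell_gate_scratch[i];
  }
}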
This commit is contained in:
parent 2d50164bdb
commit ed557008d6
@@ -212,13 +212,13 @@ inline void LstmStepFloat(
 const float* cell_layer_norm_coefficients_ptr,
 const float* output_layer_norm_coefficients_ptr,
 const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
-const float* cell_bias_ptr, const float* output_gate_bias_ptr,
+const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
 const float* projection_weights_ptr, const float* projection_bias_ptr,
 const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
 int n_aux_input, int n_output, int output_batch_leading_dim,
 float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
-float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
-float* output_ptr) {
+float* forget_gate_scratch, float* cell_gate_scratch,
+float* output_gate_scratch, float* output_ptr) {
 ruy::profiler::ScopeLabel label("LstmStepFloat");
 // Since we have already checked that weights are all there or none, we can
 // check the existence of only one to the get the condition.
@@ -233,7 +233,7 @@ inline void LstmStepFloat(
 std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
 }
 std::fill_n(forget_gate_scratch, n_cell * n_batch, 0.0f);
-std::fill_n(cell_scratch, n_cell * n_batch, 0.0f);
+std::fill_n(cell_gate_scratch, n_cell * n_batch, 0.0f);
 std::fill_n(output_gate_scratch, n_cell * n_batch, 0.0f);
 } else {
 if (!use_cifg) {
@@ -242,8 +242,8 @@ inline void LstmStepFloat(
 }
 tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
 forget_gate_scratch);
-tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
-cell_scratch);
+tensor_utils::VectorBatchVectorAssign(cell_gate_bias_ptr, n_cell, n_batch,
+cell_gate_scratch);
 tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
 output_gate_scratch);
 }
@@ -262,7 +262,7 @@ inline void LstmStepFloat(
 forget_gate_scratch);
 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
 input_to_cell_weights_ptr, n_cell, n_input, input_ptr, n_batch,
-cell_scratch);
+cell_gate_scratch);
 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
 input_to_output_weights_ptr, n_cell, n_input, input_ptr, n_batch,
 output_gate_scratch);
@@ -283,7 +283,7 @@ inline void LstmStepFloat(
 n_batch, forget_gate_scratch);
 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
 aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
-n_batch, cell_scratch);
+n_batch, cell_gate_scratch);
 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
 aux_input_to_output_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
 n_batch, output_gate_scratch);
@@ -300,7 +300,7 @@ inline void LstmStepFloat(
 n_batch, forget_gate_scratch);
 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
 recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
-n_batch, cell_scratch);
+n_batch, cell_gate_scratch);
 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
 recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
 n_batch, output_gate_scratch);
@@ -347,24 +347,26 @@ inline void LstmStepFloat(
 tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
 n_batch * n_cell, cell_state_ptr);
 if (use_layer_norm) {
-tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
-n_batch);
+tensor_utils::MeanStddevNormalization(cell_gate_scratch, cell_gate_scratch,
+n_cell, n_batch);
 tensor_utils::VectorBatchVectorCwiseProduct(
-cell_layer_norm_coefficients_ptr, n_cell, cell_scratch, n_batch,
-cell_scratch);
-tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
-cell_scratch);
+cell_layer_norm_coefficients_ptr, n_cell, cell_gate_scratch, n_batch,
+cell_gate_scratch);
+tensor_utils::VectorBatchVectorAdd(cell_gate_bias_ptr, n_cell, n_batch,
+cell_gate_scratch);
 }
-tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
-params->activation, cell_scratch);
+tensor_utils::ApplyActivationToVector(cell_gate_scratch, n_batch * n_cell,
+params->activation, cell_gate_scratch);
 if (use_cifg) {
 tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
 forget_gate_scratch);
 tensor_utils::VectorVectorCwiseProductAccumulate(
-cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
+cell_gate_scratch, forget_gate_scratch, n_batch * n_cell,
+cell_state_ptr);
 } else {
 tensor_utils::VectorVectorCwiseProductAccumulate(
-cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
+cell_gate_scratch, input_gate_scratch, n_batch * n_cell,
+cell_state_ptr);
 }
 if (params->cell_clip > 0.0) {
 tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
@@ -389,8 +391,8 @@ inline void LstmStepFloat(
 tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
 output_gate_scratch);
 tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
-params->activation, cell_scratch);
-tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+params->activation, cell_gate_scratch);
+tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_gate_scratch,
 n_batch * n_cell, output_gate_scratch);
 
 const bool use_projection_weight = (projection_weights_ptr != nullptr);
@@ -525,19 +527,19 @@ inline void LstmStepHybrid(
 const float* cell_layer_norm_coefficients_ptr,
 const float* output_layer_norm_coefficients_ptr,
 const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
-const float* cell_bias_ptr, const float* output_gate_bias_ptr,
+const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
 const int8_t* projection_weights_ptr, float projection_weights_scale,
 const float* projection_bias_ptr, const TfLiteLSTMParams* params,
 int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
 int output_batch_leading_dim, float* input_gate_scratch,
-float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
-float* scaling_factors, float* scaling_factors_scratch,
-float* recovered_cell_weights, int8_t* quantized_input_ptr,
-int8_t* quantized_aux_input_ptr, int8_t* quantized_output_state_ptr,
-int8_t* quantized_cell_state_ptr, float* output_state_ptr,
-float* cell_state_ptr, int32_t* accum_scratch_ptr, float* output_ptr,
-int32_t* zero_points, int32_t* row_sums, int row_sums_size,
-bool* compute_row_sums, bool asymmetric_quantize_inputs,
+float* forget_gate_scratch, float* cell_gate_scratch,
+float* output_gate_scratch, float* scaling_factors,
+float* scaling_factors_scratch, float* recovered_cell_weights,
+int8_t* quantized_input_ptr, int8_t* quantized_aux_input_ptr,
+int8_t* quantized_output_state_ptr, int8_t* quantized_cell_state_ptr,
+float* output_state_ptr, float* cell_state_ptr, int32_t* accum_scratch_ptr,
+float* output_ptr, int32_t* zero_points, int32_t* row_sums,
+int row_sums_size, bool* compute_row_sums, bool asymmetric_quantize_inputs,
 CpuBackendContext* context) {
 ruy::profiler::ScopeLabel label("LstmStepHybrid");
 // Since we have already checked that weights are all there or none, we
@@ -553,7 +555,7 @@ inline void LstmStepHybrid(
 std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
 }
 std::fill_n(forget_gate_scratch, n_cell * n_batch, 0.0f);
-std::fill_n(cell_scratch, n_cell * n_batch, 0.0f);
+std::fill_n(cell_gate_scratch, n_cell * n_batch, 0.0f);
 std::fill_n(output_gate_scratch, n_cell * n_batch, 0.0f);
 } else {
 if (!use_cifg) {
@@ -562,8 +564,8 @@ inline void LstmStepHybrid(
 }
 tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
 forget_gate_scratch);
-tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
-cell_scratch);
+tensor_utils::VectorBatchVectorAssign(cell_gate_bias_ptr, n_cell, n_batch,
+cell_gate_scratch);
 tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
 output_gate_scratch);
 }
@@ -657,7 +659,8 @@ inline void LstmStepHybrid(
 
 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
 input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr,
-input_to_cell_weights_scale, scaling_factors, n_batch, cell_scratch,
+input_to_cell_weights_scale, scaling_factors, n_batch,
+cell_gate_scratch,
 /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
 input_to_cell_row_sums, compute_row_sums, scaling_factors_scratch,
 context);
@@ -699,9 +702,10 @@ inline void LstmStepHybrid(
 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
 aux_input_to_cell_weights_ptr, n_cell, n_aux_input,
 quantized_aux_input_ptr, aux_input_to_cell_weights_scale,
-scaling_factors, n_batch, cell_scratch, /*per_channel_scale=*/nullptr,
-zero_points, accum_scratch_ptr, aux_input_to_cell_row_sums,
-compute_row_sums, scaling_factors_scratch, context);
+scaling_factors, n_batch, cell_gate_scratch,
+/*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+aux_input_to_cell_row_sums, compute_row_sums, scaling_factors_scratch,
+context);
 
 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
 aux_input_to_output_weights_ptr, n_cell, n_aux_input,
@@ -739,9 +743,10 @@ inline void LstmStepHybrid(
 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
 recurrent_to_cell_weights_ptr, n_cell, n_output,
 quantized_output_state_ptr, recurrent_to_cell_weights_scale,
-scaling_factors, n_batch, cell_scratch, /*per_channel_scale=*/nullptr,
-zero_points, accum_scratch_ptr, recurrent_to_cell_row_sums,
-compute_row_sums, scaling_factors_scratch, context);
+scaling_factors, n_batch, cell_gate_scratch,
+/*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+recurrent_to_cell_row_sums, compute_row_sums, scaling_factors_scratch,
+context);
 
 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
 recurrent_to_output_weights_ptr, n_cell, n_output,
@@ -800,24 +805,26 @@ inline void LstmStepHybrid(
 tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
 n_batch * n_cell, cell_state_ptr);
 if (use_layer_norm) {
-tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
-n_batch);
+tensor_utils::MeanStddevNormalization(cell_gate_scratch, cell_gate_scratch,
+n_cell, n_batch);
 tensor_utils::VectorBatchVectorCwiseProduct(
-cell_layer_norm_coefficients_ptr, n_cell, cell_scratch, n_batch,
-cell_scratch);
-tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
-cell_scratch);
+cell_layer_norm_coefficients_ptr, n_cell, cell_gate_scratch, n_batch,
+cell_gate_scratch);
+tensor_utils::VectorBatchVectorAdd(cell_gate_bias_ptr, n_cell, n_batch,
+cell_gate_scratch);
 }
-tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
-params->activation, cell_scratch);
+tensor_utils::ApplyActivationToVector(cell_gate_scratch, n_batch * n_cell,
+params->activation, cell_gate_scratch);
 if (use_cifg) {
 tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
 forget_gate_scratch);
 tensor_utils::VectorVectorCwiseProductAccumulate(
-cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
+cell_gate_scratch, forget_gate_scratch, n_batch * n_cell,
+cell_state_ptr);
 } else {
 tensor_utils::VectorVectorCwiseProductAccumulate(
-cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
+cell_gate_scratch, input_gate_scratch, n_batch * n_cell,
+cell_state_ptr);
 }
 if (params->cell_clip > 0.0) {
 tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
@@ -845,8 +852,8 @@ inline void LstmStepHybrid(
 tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
 output_gate_scratch);
 tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
-params->activation, cell_scratch);
-tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+params->activation, cell_gate_scratch);
+tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_gate_scratch,
 n_batch * n_cell, output_gate_scratch);
 
 const bool use_projection_weight = (projection_weights_ptr != nullptr);
@@ -940,7 +947,7 @@ inline void LstmStepHybrid(
 // Gate biases of size 'n_cell':
 // input_bias_ptr - optional
 // forget_bias_ptr
-// cell_bias_ptr
+// cell_gate_bias_ptr
 // output_bias_ptr
 //
 // Layer norm coefficients of size 'n_cell', representing diagonal matrices.
@@ -1028,7 +1035,7 @@ inline void LstmStepInteger(
 const int16_t* layer_norm_output_weight_ptr,
 int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
 const int32_t* input_bias_ptr, const int32_t* forget_bias_ptr,
-const int32_t* cell_bias_ptr, const int32_t* output_bias_ptr,
+const int32_t* cell_gate_bias_ptr, const int32_t* output_bias_ptr,
 int16_t quantized_cell_clip, int8_t quantized_proj_clip, int32_t cell_scale,
 int32_t input_variance_guard, int32_t forget_variance_guard,
 int32_t cell_variance_guard, int32_t output_variance_guard,
@@ -1115,7 +1122,7 @@ inline void LstmStepInteger(
 
 if (use_layer_norm) {
 tensor_utils::ApplyLayerNorm(scratch_2_ptr, layer_norm_cell_weight_ptr,
-cell_bias_ptr, layer_norm_cell_scale_a,
+cell_gate_bias_ptr, layer_norm_cell_scale_a,
 layer_norm_cell_scale_b, cell_variance_guard,
 n_batch, n_cell, scratch_2_ptr);
 }
@@ -1266,7 +1273,7 @@ inline void LstmStepInteger(
 // Gate biases of size 'n_cell':
 // input_bias_ptr - optional
 // forget_bias_ptr
-// cell_bias_ptr
+// cell_gate_bias_ptr
 // output_bias_ptr
 //
 // Layer norm coefficients of size 'n_cell', representing diagonal matrices.
@@ -1355,7 +1362,7 @@ void LstmStepInteger(
 const int16_t* layer_norm_output_weight_ptr,
 int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
 const int32_t* input_bias_ptr, const int32_t* forget_bias_ptr,
-const int32_t* cell_bias_ptr, const int32_t* output_bias_ptr,
+const int32_t* cell_gate_bias_ptr, const int32_t* output_bias_ptr,
 const int32_t* proj_bias_ptr, const TfLiteLSTMParams* params,
 const int32_t* intermediate_scale_a, const int32_t* intermediate_scale_b,
 const int32_t* intermediate_zp, int32 quantized_cell_clip,
@@ -1413,7 +1420,7 @@ void LstmStepInteger(
 // Update gate with layer norm.
 tensor_utils::ApplyLayerNormFloat(
 scratch3, layer_norm_cell_weight_ptr, layer_norm_cell_scale_a,
-layer_norm_cell_scale_b, cell_bias_ptr, n_batch, n_cell, scratch3);
+layer_norm_cell_scale_b, cell_gate_bias_ptr, n_batch, n_cell, scratch3);
 
 // Update gate tanh.
 tensor_utils::ApplyTanhFloat(scratch3, n_batch, n_cell, -12, scratch3);
@@ -1538,16 +1545,16 @@ TfLiteStatus EvalFloat(
 // Index the scratch buffers pointers to the global scratch buffer.
 float* scratch_buffer_ptr = GetTensorData<float>(scratch_buffer);
 float* input_gate_scratch = nullptr;
-float* cell_scratch = nullptr;
+float* cell_gate_scratch = nullptr;
 float* forget_gate_scratch = nullptr;
 float* output_gate_scratch = nullptr;
 if (use_cifg) {
-cell_scratch = scratch_buffer_ptr;
+cell_gate_scratch = scratch_buffer_ptr;
 forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
 output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
 } else {
 input_gate_scratch = scratch_buffer_ptr;
-cell_scratch = scratch_buffer_ptr + n_cell * n_batch;
+cell_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
 forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
 output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch;
 }
@@ -1599,7 +1606,8 @@ TfLiteStatus EvalFloat(
 n_input, aux_input_size, n_output, output_batch_leading_dim,
 GetTensorData<float>(activation_state),
 GetTensorData<float>(cell_state), input_gate_scratch,
-forget_gate_scratch, cell_scratch, output_gate_scratch, output_ptr);
+forget_gate_scratch, cell_gate_scratch, output_gate_scratch,
+output_ptr);
 }
 } else {
 for (int b = 0; b < n_batch; b++) {
@@ -1628,7 +1636,7 @@ TfLiteStatus EvalFloat(
 float* input_gate_scratch_ptr =
 input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
 float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
-float* cell_scratch_ptr = cell_scratch + b * n_cell;
+float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
 float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
 
 LstmStepFloat(
@@ -1659,8 +1667,8 @@ TfLiteStatus EvalFloat(
 GetTensorData<float>(projection_bias), params, /*n_batch=*/1,
 n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
 activation_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
-forget_gate_scratch_ptr, cell_scratch_ptr, output_gate_scratch_ptr,
-output_ptr);
+forget_gate_scratch_ptr, cell_gate_scratch_ptr,
+output_gate_scratch_ptr, output_ptr);
 }
 }
 }
@@ -1723,16 +1731,16 @@ TfLiteStatus EvalHybrid(
 
 float* scratch_buffer_ptr = GetTensorData<float>(scratch_buffer);
 float* input_gate_scratch = nullptr;
-float* cell_scratch = nullptr;
+float* cell_gate_scratch = nullptr;
 float* forget_gate_scratch = nullptr;
 float* output_gate_scratch = nullptr;
 if (use_cifg) {
-cell_scratch = scratch_buffer_ptr;
+cell_gate_scratch = scratch_buffer_ptr;
 forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
 output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
 } else {
 input_gate_scratch = scratch_buffer_ptr;
-cell_scratch = scratch_buffer_ptr + n_cell * n_batch;
+cell_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
 forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
 output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch;
 }
@@ -1805,7 +1813,7 @@ TfLiteStatus EvalHybrid(
 GetTensorScale(projection_weights),
 GetTensorData<float>(projection_bias), params, n_batch, n_cell,
 n_input, aux_input_size, n_output, output_batch_leading_dim,
-input_gate_scratch, forget_gate_scratch, cell_scratch,
+input_gate_scratch, forget_gate_scratch, cell_gate_scratch,
 output_gate_scratch, GetTensorData<float>(scaling_factors),
 GetTensorData<float>(prod_scaling_factors),
 GetTensorData<float>(recovered_cell_weights),
@@ -1845,7 +1853,7 @@ TfLiteStatus EvalHybrid(
 float* input_gate_scratch_ptr =
 input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
 float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
-float* cell_scratch_ptr = cell_scratch + b * n_cell;
+float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
 float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
 
 LstmStepHybrid(
@@ -1892,8 +1900,8 @@ TfLiteStatus EvalHybrid(
 GetTensorData<float>(projection_bias), params,
 /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
 output_batch_leading_dim, input_gate_scratch_ptr,
-forget_gate_scratch_ptr, cell_scratch_ptr, output_gate_scratch_ptr,
-GetTensorData<float>(scaling_factors),
+forget_gate_scratch_ptr, cell_gate_scratch_ptr,
+output_gate_scratch_ptr, GetTensorData<float>(scaling_factors),
 GetTensorData<float>(prod_scaling_factors),
 GetTensorData<float>(recovered_cell_weights),
 GetTensorData<int8_t>(input_quantized),
@@ -2119,7 +2127,7 @@ TfLiteStatus EvalInteger8x8_8(
 GetTensorData<int16_t>(output_layer_norm_coefficients);
 const int32_t* input_bias_ptr = GetTensorData<int32_t>(input_gate_bias);
 const int32_t* forget_bias_ptr = GetTensorData<int32_t>(forget_gate_bias);
-const int32_t* cell_bias_ptr = GetTensorData<int32_t>(cell_bias);
+const int32_t* cell_gate_bias_ptr = GetTensorData<int32_t>(cell_bias);
 const int32_t* output_bias_ptr = GetTensorData<int32_t>(output_gate_bias);
 const int32_t* proj_bias_ptr = GetTensorData<int32_t>(projection_bias);
 int16_t* cell_ptr = GetTensorData<int16_t>(cell_state);
@@ -2206,7 +2214,7 @@ TfLiteStatus EvalInteger8x8_8(
 integer_lstm_param->layer_norm_output_scale_a,
 integer_lstm_param->layer_norm_output_scale_b,
 
-input_bias_ptr, forget_bias_ptr, cell_bias_ptr, output_bias_ptr,
+input_bias_ptr, forget_bias_ptr, cell_gate_bias_ptr, output_bias_ptr,
 proj_bias_ptr,
 
 params, integer_lstm_param->intermediate_scale_a,
@@ -58,13 +58,13 @@ inline void LstmStepWithAuxInput(
 const float* cell_layer_norm_coefficients_ptr,
 const float* output_layer_norm_coefficients_ptr,
 const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
-const float* cell_bias_ptr, const float* output_gate_bias_ptr,
+const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
 const float* projection_weights_ptr, const float* projection_bias_ptr,
 const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
 int n_aux_input, int n_output, int output_batch_leading_dim,
 float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
-float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
-float* output_ptr, Logger* logger,
+float* forget_gate_scratch, float* cell_gate_scratch,
+float* output_gate_scratch, float* output_ptr, Logger* logger,
 const std::vector<int>& intermediate_tensor_indexes,
 ErrorReporter* error_reporter) {
 // Since we have already checked that weights are all there or none, we can
@@ -80,7 +80,7 @@ inline void LstmStepWithAuxInput(
 std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
 }
 std::fill_n(forget_gate_scratch, n_cell * n_batch, 0.0f);
-std::fill_n(cell_scratch, n_cell * n_batch, 0.0f);
+std::fill_n(cell_gate_scratch, n_cell * n_batch, 0.0f);
 std::fill_n(output_gate_scratch, n_cell * n_batch, 0.0f);
 } else {
 if (!use_cifg) {
@@ -89,8 +89,8 @@ inline void LstmStepWithAuxInput(
 }
 tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
 forget_gate_scratch);
-tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
-cell_scratch);
+tensor_utils::VectorBatchVectorAssign(cell_gate_bias_ptr, n_cell, n_batch,
+cell_gate_scratch);
 tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
 output_gate_scratch);
 }
@@ -107,7 +107,7 @@ inline void LstmStepWithAuxInput(
 forget_gate_scratch);
 tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_cell_weights_ptr,
 n_cell, n_input, input_ptr,
-n_batch, cell_scratch);
+n_batch, cell_gate_scratch);
 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
 input_to_output_weights_ptr, n_cell, n_input, input_ptr, n_batch,
 output_gate_scratch);
@@ -125,7 +125,7 @@ inline void LstmStepWithAuxInput(
 n_batch, forget_gate_scratch);
 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
 aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
-n_batch, cell_scratch);
+n_batch, cell_gate_scratch);
 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
 aux_input_to_output_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
 n_batch, output_gate_scratch);
@@ -142,7 +142,7 @@ inline void LstmStepWithAuxInput(
 n_batch, forget_gate_scratch);
 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
 recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
-n_batch, cell_scratch);
+n_batch, cell_gate_scratch);
 tensor_utils::MatrixBatchVectorMultiplyAccumulate(
 recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
 n_batch, output_gate_scratch);
@@ -193,26 +193,28 @@ inline void LstmStepWithAuxInput(
 tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
 n_batch * n_cell, cell_state_ptr);
 if (use_layer_norm) {
-logger->LogTensorValue(intermediate_tensor_indexes[2], cell_scratch,
+logger->LogTensorValue(intermediate_tensor_indexes[2], cell_gate_scratch,
 n_cell * n_batch, error_reporter);
-tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
-n_batch);
+tensor_utils::MeanStddevNormalization(cell_gate_scratch, cell_gate_scratch,
+n_cell, n_batch);
 tensor_utils::VectorBatchVectorCwiseProduct(
-cell_layer_norm_coefficients_ptr, n_cell, cell_scratch, n_batch,
-cell_scratch);
-tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
-cell_scratch);
+cell_layer_norm_coefficients_ptr, n_cell, cell_gate_scratch, n_batch,
+cell_gate_scratch);
+tensor_utils::VectorBatchVectorAdd(cell_gate_bias_ptr, n_cell, n_batch,
+cell_gate_scratch);
 }
-tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
-params->activation, cell_scratch);
+tensor_utils::ApplyActivationToVector(cell_gate_scratch, n_batch * n_cell,
+params->activation, cell_gate_scratch);
 if (use_cifg) {
 tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
 forget_gate_scratch);
 tensor_utils::VectorVectorCwiseProductAccumulate(
-cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
+cell_gate_scratch, forget_gate_scratch, n_batch * n_cell,
+cell_state_ptr);
 } else {
 tensor_utils::VectorVectorCwiseProductAccumulate(
-cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
+cell_gate_scratch, input_gate_scratch, n_batch * n_cell,
+cell_state_ptr);
 }
 if (params->cell_clip > 0.0) {
 tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
@@ -239,8 +241,8 @@ inline void LstmStepWithAuxInput(
 tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
 output_gate_scratch);
 tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
-params->activation, cell_scratch);
-tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+params->activation, cell_gate_scratch);
+tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_gate_scratch,
 n_batch * n_cell, output_gate_scratch);
 
 logger->LogTensorValue(intermediate_tensor_indexes[4], output_gate_scratch,
@@ -329,16 +331,16 @@ TfLiteStatus EvalFloat(
 // Index the scratch buffers pointers to the global scratch buffer.
 float* scratch_buffer_ptr = GetTensorData<float>(scratch_buffer);
 float* input_gate_scratch = nullptr;
-float* cell_scratch = nullptr;
+float* cell_gate_scratch = nullptr;
 float* forget_gate_scratch = nullptr;
 float* output_gate_scratch = nullptr;
 if (use_cifg) {
-cell_scratch = scratch_buffer_ptr;
+cell_gate_scratch = scratch_buffer_ptr;
 forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
 output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
 } else {
 input_gate_scratch = scratch_buffer_ptr;
-cell_scratch = scratch_buffer_ptr + n_cell * n_batch;
+cell_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
 forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
 output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch;
 }
@@ -390,7 +392,7 @@ TfLiteStatus EvalFloat(
 n_input, aux_input_size, n_output, output_batch_leading_dim,
 GetTensorData<float>(activation_state),
 GetTensorData<float>(cell_state), input_gate_scratch,
-forget_gate_scratch, cell_scratch, output_gate_scratch,
+forget_gate_scratch, cell_gate_scratch, output_gate_scratch,
 output_ptr_time, logger, intermediate_tensor_indexes, error_reporter);
 }
 } else {
@@ -420,7 +422,7 @@ TfLiteStatus EvalFloat(
 float* input_gate_scratch_ptr =
 input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
 float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
-float* cell_scratch_ptr = cell_scratch + b * n_cell;
+float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
 float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
 
 LstmStepWithAuxInput(
@@ -451,8 +453,9 @@ TfLiteStatus EvalFloat(
 GetTensorData<float>(projection_bias), params, /*n_batch=*/1,
 n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
 activation_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
-forget_gate_scratch_ptr, cell_scratch_ptr, output_gate_scratch_ptr,
-output_ptr, logger, intermediate_tensor_indexes, error_reporter);
+forget_gate_scratch_ptr, cell_gate_scratch_ptr,
+output_gate_scratch_ptr, output_ptr, logger,
+intermediate_tensor_indexes, error_reporter);
 }
 }
 }