LSTM: Split gate calculations to separate functions.
PiperOrigin-RevId: 319113154
Change-Id: I3f6d5358ac2bc619e3705d6f8aaf66c38c7c1b66
commit ec07b637ce (parent c0bb8efaf1)
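As context for the diff below: each per-gate call to the new helper evaluates the standard LSTM gate equation. Reading from the body of CalculateLstmGateFloat, the non-layer-norm path is

$$ \text{gate} = \phi\big(W_{x}\,x + W_{aux}\,x_{aux} + W_{rec}\,h_{prev} + w_{peep} \odot c_{prev} + b\big) $$

and, when layer normalization is enabled,

$$ \text{gate} = \phi\big(\gamma \odot \mathrm{LN}\big(W_{x}\,x + W_{aux}\,x_{aux} + W_{rec}\,h_{prev} + w_{peep} \odot c_{prev}\big) + b\big) $$

where $\phi$ is sigmoid for the input, forget, and output gates and params->activation for the cell update gate; the aux-input and peephole terms drop out when the corresponding weight pointers are null. The symbols here are editorial shorthand read off the code, not identifiers from it.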
@@ -37,6 +37,64 @@ namespace builtin {
namespace {

inline void CalculateLstmGateFloat(
    const float* input, const float* input_to_gate_weights,
    const float* aux_input, const float* aux_input_to_gate_weights,
    const float* output_state, const float* recurrent_to_gate_weights,
    const float* cell_state, const float* cell_to_gate_weights,
    const float* layer_norm_coefficients, const float* gate_bias,
    const int n_batch, const int n_input, const int n_aux_input,
    const int n_output, const int n_cell,
    const TfLiteFusedActivation activation, float* gate,
    const bool is_input_all_zeros, const bool is_aux_input_all_zeros,
    Logger* logger, int intermediate_tensor_index,
    ErrorReporter* error_reporter) {
  const bool use_peephole = (cell_to_gate_weights != nullptr);
  const bool use_layer_norm = (layer_norm_coefficients != nullptr);

  // Initialize scratch buffers with bias for regular lstm or initialize with
  // zero for layer norm lstm.
  if (use_layer_norm) {
    std::fill_n(gate, n_cell * n_batch, 0.0f);
  } else {
    tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch, gate);
  }
  // For each batch and cell: compute input_weight * input.
  // Skip if input is all zeros.
  if (!is_input_all_zeros) {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_gate_weights, n_cell, n_input, input, n_batch, gate);
  }
  // For each batch and cell: compute aux_input_weight * aux_input.
  // Skip if auxiliary input is not available or all zeros.
  if (!is_aux_input_all_zeros) {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(aux_input_to_gate_weights,
                                                      n_cell, n_aux_input,
                                                      aux_input, n_batch, gate);
  }
  // For each batch and cell: compute recurrent_weight * output_state.
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      recurrent_to_gate_weights, n_cell, n_output, output_state, n_batch, gate);
  // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
  if (use_peephole) {
    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        cell_to_gate_weights, n_cell, cell_state, n_batch, gate);
  }
  // Do layer normalization (if layer norm LSTM)
  if (use_layer_norm) {
    logger->LogTensorValue(intermediate_tensor_index, gate, n_cell * n_batch,
                           error_reporter);
    tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch);
    tensor_utils::VectorBatchVectorCwiseProduct(layer_norm_coefficients, n_cell,
                                                gate, n_batch, gate);
    tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate);
  }
  // Apply activation
  tensor_utils::ApplyActivationToVector(gate, n_batch * n_cell, activation,
                                        gate);
}

// TODO(b/159066113): This is the exact same function as UpdateLstmCellFloat in
// kernels/lstm_eval.cc, make that public and remove this.
void UpdateLstmCellFloat(int n_batch, int n_cell, float* cell_state,
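For readers unfamiliar with the tensor_utils helpers used above, the following is a minimal, self-contained sketch of what the non-layer-norm path of a single sigmoid gate computes, written with plain loops. It is an illustration under stated assumptions rather than TFLite code: the function and variable names are invented, the auxiliary-input term is omitted, and weights are assumed row-major with shape [n_cell x n_input] / [n_cell x n_output], matching how they are passed to MatrixBatchVectorMultiplyAccumulate above.

// Editorial sketch (hypothetical, not TFLite code): the non-layer-norm path of a
// single sigmoid gate, i.e. gate = sigmoid(bias + W_x*x + W_rec*h + w_peep.*cell),
// with weights stored row-major as [n_cell x n_input] / [n_cell x n_output].
#include <cmath>
#include <cstdio>
#include <vector>

void ReferenceSigmoidGate(const std::vector<float>& x,       // [n_batch * n_input]
                          const std::vector<float>& w_x,     // [n_cell * n_input]
                          const std::vector<float>& h,       // [n_batch * n_output]
                          const std::vector<float>& w_rec,   // [n_cell * n_output]
                          const std::vector<float>& cell,    // [n_batch * n_cell]
                          const std::vector<float>& w_peep,  // [n_cell], empty = no peephole
                          const std::vector<float>& bias,    // [n_cell]
                          int n_batch, int n_input, int n_output, int n_cell,
                          std::vector<float>& gate) {        // [n_batch * n_cell]
  for (int b = 0; b < n_batch; ++b) {
    for (int c = 0; c < n_cell; ++c) {
      float acc = bias[c];  // start from the gate bias, like VectorBatchVectorAssign
      for (int i = 0; i < n_input; ++i) {  // input_to_gate_weights * input
        acc += w_x[c * n_input + i] * x[b * n_input + i];
      }
      for (int o = 0; o < n_output; ++o) {  // recurrent_to_gate_weights * output_state
        acc += w_rec[c * n_output + o] * h[b * n_output + o];
      }
      if (!w_peep.empty()) {  // peephole: cell_to_gate_weights .* cell_state
        acc += w_peep[c] * cell[b * n_cell + c];
      }
      gate[b * n_cell + c] = 1.0f / (1.0f + std::exp(-acc));  // sigmoid activation
    }
  }
}

int main() {
  // One batch, two inputs, two recurrent outputs, two cells, no peephole.
  std::vector<float> x = {1.0f, -1.0f};
  std::vector<float> w_x = {0.5f, 0.25f, -0.5f, 0.75f};
  std::vector<float> h = {0.1f, 0.2f};
  std::vector<float> w_rec = {0.3f, -0.3f, 0.2f, 0.1f};
  std::vector<float> cell = {0.0f, 0.0f};
  std::vector<float> w_peep;  // empty: peephole disabled
  std::vector<float> bias = {0.0f, 0.1f};
  std::vector<float> gate(2);
  ReferenceSigmoidGate(x, w_x, h, w_rec, cell, w_peep, bias,
                       /*n_batch=*/1, /*n_input=*/2, /*n_output=*/2, /*n_cell=*/2, gate);
  std::printf("gate = [%f, %f]\n", gate[0], gate[1]);
  return 0;
}

The real helper additionally handles the auxiliary input, the layer-norm path, and a configurable activation; the is_input_all_zeros / is_aux_input_all_zeros flags only let it skip matrix multiplications whose inputs are entirely zero.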
@@ -130,178 +188,72 @@ inline void LstmStepCalibration(
  // Since we have already checked that weights are all there or none, we can
  // check the existence of only one to get the condition.
  const bool use_cifg = (input_to_input_weights_ptr == nullptr);
  const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
  const bool use_layer_norm = (forget_layer_norm_coefficients_ptr != nullptr);

  // Make named scratch buffers for the different gates.
  // Make named scratch buffers.
  float* input_gate_scratch = scratch0;
  float* forget_gate_scratch = scratch1;
  float* cell_gate_scratch = scratch2;
  float* output_gate_scratch = scratch3;

  // Initialize scratch buffers with bias for regular lstm or initialize with
  // zero for layer norm lstm.
  if (use_layer_norm) {
    if (!use_cifg) {
      std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
    }
    std::fill_n(forget_gate_scratch, n_cell * n_batch, 0.0f);
    std::fill_n(cell_gate_scratch, n_cell * n_batch, 0.0f);
    std::fill_n(output_gate_scratch, n_cell * n_batch, 0.0f);
  } else {
    if (!use_cifg) {
      tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell,
                                            n_batch, input_gate_scratch);
    }
    tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
                                          forget_gate_scratch);
    tensor_utils::VectorBatchVectorAssign(cell_gate_bias_ptr, n_cell, n_batch,
                                          cell_gate_scratch);
    tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
                                          output_gate_scratch);
  }

  // For each batch and cell: compute input_weight * input.
  // Check if inputs are all zeros so we can skip some computations.
  const bool is_input_all_zeros =
      tensor_utils::IsZeroVector(input_ptr, n_batch * n_input);
  const bool is_aux_input_all_zeros =
      (aux_input_ptr == nullptr ||
       tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
  if (!use_cifg) {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_input_weights_ptr, n_cell, n_input, input_ptr, n_batch,
        input_gate_scratch);
    // Calculate the input gate. (If not CIFG.)
    CalculateLstmGateFloat(
        input_ptr, input_to_input_weights_ptr, aux_input_ptr,
        aux_input_to_input_weights_ptr, output_state_ptr,
        recurrent_to_input_weights_ptr, cell_state_ptr,
        cell_to_input_weights_ptr, input_layer_norm_coefficients_ptr,
        input_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
        /*activation=*/kTfLiteActSigmoid, input_gate_scratch,
        is_input_all_zeros, is_aux_input_all_zeros, logger,
        intermediate_tensor_indexes[0], error_reporter);
  }

  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      input_to_forget_weights_ptr, n_cell, n_input, input_ptr, n_batch,
      forget_gate_scratch);
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_cell_weights_ptr,
                                                    n_cell, n_input, input_ptr,
                                                    n_batch, cell_gate_scratch);
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      input_to_output_weights_ptr, n_cell, n_input, input_ptr, n_batch,
      output_gate_scratch);

  // For each batch and cell: compute aux_input_weight * aux_input.
  // Skip if auxiliary input is not available.
  if (aux_input_ptr != nullptr) {
    if (!use_cifg) {
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          aux_input_to_input_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
          n_batch, input_gate_scratch);
    }

    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        aux_input_to_forget_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
        n_batch, forget_gate_scratch);
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
        n_batch, cell_gate_scratch);
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        aux_input_to_output_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
        n_batch, output_gate_scratch);
  }

  // For each batch and cell: compute recurrent_weight * output_state.
  if (!use_cifg) {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
        n_batch, input_gate_scratch);
  }
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
      n_batch, forget_gate_scratch);
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
      n_batch, cell_gate_scratch);
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
      n_batch, output_gate_scratch);

  // For each batch and cell: update input gate.
  if (!use_cifg) {
    if (use_peephole) {
      tensor_utils::VectorBatchVectorCwiseProductAccumulate(
          cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
          input_gate_scratch);
    }
    if (use_layer_norm) {
      logger->LogTensorValue(intermediate_tensor_indexes[0], input_gate_scratch,
                             n_cell * n_batch, error_reporter);
      tensor_utils::MeanStddevNormalization(
          input_gate_scratch, input_gate_scratch, n_cell, n_batch);
      tensor_utils::VectorBatchVectorCwiseProduct(
          input_layer_norm_coefficients_ptr, n_cell, input_gate_scratch,
          n_batch, input_gate_scratch);
      tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
                                         input_gate_scratch);
    }
    tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
                                       input_gate_scratch);
  }

  // For each batch and cell: update forget gate.
  if (use_peephole) {
    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
        forget_gate_scratch);
  }
  if (use_layer_norm) {
    logger->LogTensorValue(intermediate_tensor_indexes[1], forget_gate_scratch,
                           n_cell * n_batch, error_reporter);
    tensor_utils::MeanStddevNormalization(forget_gate_scratch,
                                          forget_gate_scratch, n_cell, n_batch);
    tensor_utils::VectorBatchVectorCwiseProduct(
        forget_layer_norm_coefficients_ptr, n_cell, forget_gate_scratch,
        n_batch, forget_gate_scratch);
    tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
                                       forget_gate_scratch);
  }
  tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
                                     forget_gate_scratch);

  // For each batch and cell: update the cell.
  if (use_layer_norm) {
    logger->LogTensorValue(intermediate_tensor_indexes[2], cell_gate_scratch,
                           n_cell * n_batch, error_reporter);
    tensor_utils::MeanStddevNormalization(cell_gate_scratch, cell_gate_scratch,
                                          n_cell, n_batch);
    tensor_utils::VectorBatchVectorCwiseProduct(
        cell_layer_norm_coefficients_ptr, n_cell, cell_gate_scratch, n_batch,
        cell_gate_scratch);
    tensor_utils::VectorBatchVectorAdd(cell_gate_bias_ptr, n_cell, n_batch,
                                       cell_gate_scratch);
  }
  tensor_utils::ApplyActivationToVector(cell_gate_scratch, n_batch * n_cell,
                                        params->activation, cell_gate_scratch);

  // Calculate the forget gate.
  CalculateLstmGateFloat(
      input_ptr, input_to_forget_weights_ptr, aux_input_ptr,
      aux_input_to_forget_weights_ptr, output_state_ptr,
      recurrent_to_forget_weights_ptr, cell_state_ptr,
      cell_to_forget_weights_ptr, forget_layer_norm_coefficients_ptr,
      forget_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
      /*activation=*/kTfLiteActSigmoid, forget_gate_scratch, is_input_all_zeros,
      is_aux_input_all_zeros, logger, intermediate_tensor_indexes[1],
      error_reporter);
  // Calculate the cell update gate.
  CalculateLstmGateFloat(input_ptr, input_to_cell_weights_ptr, aux_input_ptr,
                         aux_input_to_cell_weights_ptr, output_state_ptr,
                         recurrent_to_cell_weights_ptr, /*cell_state=*/nullptr,
                         /*cell_to_gate_weights=*/nullptr,
                         cell_layer_norm_coefficients_ptr, cell_gate_bias_ptr,
                         n_batch, n_input, n_aux_input, n_output, n_cell,
                         params->activation, cell_gate_scratch,
                         is_input_all_zeros, is_aux_input_all_zeros, logger,
                         intermediate_tensor_indexes[2], error_reporter);
  // Update the cell state.
  UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch,
                      forget_gate_scratch, cell_gate_scratch, use_cifg,
                      params->cell_clip);

  // For each batch and cell: update the output gate.
  if (use_peephole) {
    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
        output_gate_scratch);
  }
  if (use_layer_norm) {
    logger->LogTensorValue(intermediate_tensor_indexes[3], output_gate_scratch,
                           n_cell * n_batch, error_reporter);
    tensor_utils::MeanStddevNormalization(output_gate_scratch,
                                          output_gate_scratch, n_cell, n_batch);
    tensor_utils::VectorBatchVectorCwiseProduct(
        output_layer_norm_coefficients_ptr, n_cell, output_gate_scratch,
        n_batch, output_gate_scratch);
    tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
                                       output_gate_scratch);
  }
  tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
                                     output_gate_scratch);

  // Calculate output gate.
  CalculateLstmGateFloat(
      input_ptr, input_to_output_weights_ptr, aux_input_ptr,
      aux_input_to_output_weights_ptr, output_state_ptr,
      recurrent_to_output_weights_ptr, cell_state_ptr,
      cell_to_output_weights_ptr, output_layer_norm_coefficients_ptr,
      output_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
      /*activation=*/kTfLiteActSigmoid, output_gate_scratch, is_input_all_zeros,
      is_aux_input_all_zeros, logger, intermediate_tensor_indexes[3],
      error_reporter);
  // Update the output state.
  CalculateLstmOutputCalibration(
      n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
      params->activation, projection_weights_ptr, projection_bias_ptr,
      params->proj_clip, output_state_ptr, scratch2, logger,
      intermediate_tensor_indexes, error_reporter);

  // Copy output_state to the output. Note that the output batch rows may not be
  // Copy output state to the output. Note that the output's rows may not be
  // contiguous (output_batch_leading_dim != n_output).
  for (int b = 0; b < n_batch; b++) {
    std::copy_n(output_state_ptr + b * n_output, n_output,
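The copy loop at the end of the diff is cut off by the page, but the comment above it explains its point: when output_batch_leading_dim differs from n_output, each batch row of output_state is written into the output buffer at a stride of output_batch_leading_dim rather than back to back. A hypothetical standalone illustration of that kind of strided row copy (the helper and its names are editorial, not taken from the diff):

// Hypothetical illustration (not TFLite code): copy n_output values per batch
// into a destination whose rows are leading_dim floats apart; leading_dim may be
// larger than n_output, so destination rows need not be contiguous.
#include <algorithm>

void CopyRowsWithLeadingDim(const float* src, float* dst, int n_batch,
                            int n_output, int leading_dim) {
  for (int b = 0; b < n_batch; ++b) {
    std::copy_n(src + b * n_output, n_output, dst + b * leading_dim);
  }
}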