LSTM: Split gate calculations to separate functions.

PiperOrigin-RevId: 319113154
Change-Id: I3f6d5358ac2bc619e3705d6f8aaf66c38c7c1b66
Robert David 2020-06-30 15:39:21 -07:00 committed by TensorFlower Gardener
parent c0bb8efaf1
commit ec07b637ce
2 changed files with 633 additions and 737 deletions
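All four LSTM gates (input, forget, cell update, output) share the same computation, which is what makes extracting a single CalculateLstmGateFloat helper possible. As a rough summary in standard LSTM notation (not taken from the source; the aux-input, peephole, and layer-norm terms are optional and mirror the nullptr checks in the helper below):

% Pre-activation accumulator for one gate; optional terms are dropped when the
% corresponding weights are nullptr (no aux input, no peephole, no layer norm).
\[
  a \;=\; W_{x}\,x_t \;+\; W_{aux}\,x^{aux}_t \;+\; W_{rec}\,h_{t-1}
          \;+\; w_{cell} \odot c_{t-1}
\]
% Regular LSTM folds the bias into the accumulator before the activation;
% layer-norm LSTM normalizes first, scales by the coefficients, then adds the bias.
\[
  \text{gate} \;=\;
  \begin{cases}
    \phi\!\left(a + b\right) & \text{regular LSTM} \\
    \phi\!\left(\mathrm{LN}(a)\odot\gamma + b\right) & \text{layer-norm LSTM}
  \end{cases}
\]

Here \(\phi\) is the gate's activation (sigmoid for the input, forget, and output gates; params->activation for the cell update gate), \(\mathrm{LN}\) is mean/stddev normalization, and \(\gamma\) are the per-gate layer-norm coefficients.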

File diff suppressed because it is too large


@@ -37,6 +37,64 @@ namespace builtin {
 namespace {
+
+inline void CalculateLstmGateFloat(
+    const float* input, const float* input_to_gate_weights,
+    const float* aux_input, const float* aux_input_to_gate_weights,
+    const float* output_state, const float* recurrent_to_gate_weights,
+    const float* cell_state, const float* cell_to_gate_weights,
+    const float* layer_norm_coefficients, const float* gate_bias,
+    const int n_batch, const int n_input, const int n_aux_input,
+    const int n_output, const int n_cell,
+    const TfLiteFusedActivation activation, float* gate,
+    const bool is_input_all_zeros, const bool is_aux_input_all_zeros,
+    Logger* logger, int intermediate_tensor_index,
+    ErrorReporter* error_reporter) {
+  const bool use_peephole = (cell_to_gate_weights != nullptr);
+  const bool use_layer_norm = (layer_norm_coefficients != nullptr);
+  // Initialize scratch buffers with bias for regular lstm or initialize with
+  // zero for layer norm lstm.
+  if (use_layer_norm) {
+    std::fill_n(gate, n_cell * n_batch, 0.0f);
+  } else {
+    tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch, gate);
+  }
+  // For each batch and cell: compute input_weight * input.
+  // Skip if input is all zeros.
+  if (!is_input_all_zeros) {
+    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+        input_to_gate_weights, n_cell, n_input, input, n_batch, gate);
+  }
+  // For each batch and cell: compute aux_input_weight * aux_input.
+  // Skip if auxiliary input is not available or all zeros.
+  if (!is_aux_input_all_zeros) {
+    tensor_utils::MatrixBatchVectorMultiplyAccumulate(aux_input_to_gate_weights,
+                                                      n_cell, n_aux_input,
+                                                      aux_input, n_batch, gate);
+  }
+  // For each batch and cell: compute recurrent_weight * output_state.
+  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+      recurrent_to_gate_weights, n_cell, n_output, output_state, n_batch, gate);
+  // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
+  if (use_peephole) {
+    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+        cell_to_gate_weights, n_cell, cell_state, n_batch, gate);
+  }
+  // Do layer normalization (if layer norm LSTM)
+  if (use_layer_norm) {
+    logger->LogTensorValue(intermediate_tensor_index, gate, n_cell * n_batch,
+                           error_reporter);
+    tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch);
+    tensor_utils::VectorBatchVectorCwiseProduct(layer_norm_coefficients, n_cell,
+                                                gate, n_batch, gate);
+    tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate);
+  }
+  // Apply activation
+  tensor_utils::ApplyActivationToVector(gate, n_batch * n_cell, activation,
+                                        gate);
+}
+
 // TODO(b/159066113): This is the exact same function as UpdateLstmCellFloat in
 // kernels/lstm_eval.cc, make that public and remove this.
 void UpdateLstmCellFloat(int n_batch, int n_cell, float* cell_state,
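Beyond the refactoring itself, the new helper takes is_input_all_zeros / is_aux_input_all_zeros flags: when an input tensor is entirely zero, the corresponding MatrixBatchVectorMultiplyAccumulate call is skipped, since accumulating W * 0 cannot change the gate scratch. A minimal standalone sketch of that pattern (the functions here are hypothetical stand-ins for tensor_utils::IsZeroVector and the matmul-accumulate, not the TFLite implementations):

#include <algorithm>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for tensor_utils::IsZeroVector.
bool IsZeroVector(const float* v, int n) {
  return std::all_of(v, v + n, [](float x) { return x == 0.0f; });
}

// Hypothetical stand-in for the matmul-accumulate: out += W * x.
void MatrixVectorMultiplyAccumulate(const float* w, int rows, int cols,
                                    const float* x, float* out) {
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      out[r] += w[r * cols + c] * x[c];
    }
  }
}

int main() {
  const int n_cell = 4, n_input = 3;
  std::vector<float> weights(n_cell * n_input, 0.5f);
  std::vector<float> input(n_input, 0.0f);  // all-zero input
  std::vector<float> gate(n_cell, 1.0f);    // already holds the bias

  // Accumulating W * 0 leaves `gate` unchanged, so the matmul can be skipped
  // entirely when the input is known to be all zeros.
  if (!IsZeroVector(input.data(), static_cast<int>(input.size()))) {
    MatrixVectorMultiplyAccumulate(weights.data(), n_cell, n_input,
                                   input.data(), gate.data());
  }
  for (float g : gate) std::printf("%f\n", g);
  return 0;
}

Because the skip never changes the result, the two flags can be computed once per LstmStepCalibration call, as in the hunk below, and shared across all four gate calls.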
@@ -130,178 +188,72 @@ inline void LstmStepCalibration(
   // Since we have already checked that weights are all there or none, we can
   // check the existence of only one to the get the condition.
   const bool use_cifg = (input_to_input_weights_ptr == nullptr);
-  const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
-  const bool use_layer_norm = (forget_layer_norm_coefficients_ptr != nullptr);
-  // Make named scratch buffers for the different gates.
+  // Make named scratch buffers.
   float* input_gate_scratch = scratch0;
   float* forget_gate_scratch = scratch1;
   float* cell_gate_scratch = scratch2;
   float* output_gate_scratch = scratch3;
-  // Initialize scratch buffers with bias for regular lstm or initialize with
-  // zero for layer norm lstm.
-  if (use_layer_norm) {
-    if (!use_cifg) {
-      std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
-    }
-    std::fill_n(forget_gate_scratch, n_cell * n_batch, 0.0f);
-    std::fill_n(cell_gate_scratch, n_cell * n_batch, 0.0f);
-    std::fill_n(output_gate_scratch, n_cell * n_batch, 0.0f);
-  } else {
-    if (!use_cifg) {
-      tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell,
-                                            n_batch, input_gate_scratch);
-    }
-    tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
-                                          forget_gate_scratch);
-    tensor_utils::VectorBatchVectorAssign(cell_gate_bias_ptr, n_cell, n_batch,
-                                          cell_gate_scratch);
-    tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
-                                          output_gate_scratch);
-  }
-  // For each batch and cell: compute input_weight * input.
+  // Check if inputs are all zeros so we can skip some computations.
+  const bool is_input_all_zeros =
+      tensor_utils::IsZeroVector(input_ptr, n_batch * n_input);
+  const bool is_aux_input_all_zeros =
+      (aux_input_ptr == nullptr ||
+       tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
   if (!use_cifg) {
-    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-        input_to_input_weights_ptr, n_cell, n_input, input_ptr, n_batch,
-        input_gate_scratch);
+    // Calculate the input gate. (If not CIFG.)
+    CalculateLstmGateFloat(
+        input_ptr, input_to_input_weights_ptr, aux_input_ptr,
+        aux_input_to_input_weights_ptr, output_state_ptr,
+        recurrent_to_input_weights_ptr, cell_state_ptr,
+        cell_to_input_weights_ptr, input_layer_norm_coefficients_ptr,
+        input_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
+        /*activation=*/kTfLiteActSigmoid, input_gate_scratch,
+        is_input_all_zeros, is_aux_input_all_zeros, logger,
+        intermediate_tensor_indexes[0], error_reporter);
   }
-  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-      input_to_forget_weights_ptr, n_cell, n_input, input_ptr, n_batch,
-      forget_gate_scratch);
-  tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_cell_weights_ptr,
-                                                    n_cell, n_input, input_ptr,
-                                                    n_batch, cell_gate_scratch);
-  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-      input_to_output_weights_ptr, n_cell, n_input, input_ptr, n_batch,
-      output_gate_scratch);
-  // For each batch and cell: compute aux_input_weight * aux_input.
-  // Skip if auxiliary input is not available.
-  if (aux_input_ptr != nullptr) {
-    if (!use_cifg) {
-      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-          aux_input_to_input_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
-          n_batch, input_gate_scratch);
-    }
-    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-        aux_input_to_forget_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
-        n_batch, forget_gate_scratch);
-    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-        aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
-        n_batch, cell_gate_scratch);
-    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-        aux_input_to_output_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
-        n_batch, output_gate_scratch);
-  }
-  // For each batch and cell: compute recurrent_weight * output_state.
-  if (!use_cifg) {
-    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-        recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
-        n_batch, input_gate_scratch);
-  }
-  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-      recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
-      n_batch, forget_gate_scratch);
-  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-      recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
-      n_batch, cell_gate_scratch);
-  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-      recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
-      n_batch, output_gate_scratch);
-  // For each batch and cell: update input gate.
-  if (!use_cifg) {
-    if (use_peephole) {
-      tensor_utils::VectorBatchVectorCwiseProductAccumulate(
-          cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
-          input_gate_scratch);
-    }
-    if (use_layer_norm) {
-      logger->LogTensorValue(intermediate_tensor_indexes[0], input_gate_scratch,
-                             n_cell * n_batch, error_reporter);
-      tensor_utils::MeanStddevNormalization(
-          input_gate_scratch, input_gate_scratch, n_cell, n_batch);
-      tensor_utils::VectorBatchVectorCwiseProduct(
-          input_layer_norm_coefficients_ptr, n_cell, input_gate_scratch,
-          n_batch, input_gate_scratch);
-      tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
-                                         input_gate_scratch);
-    }
-    tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
-                                       input_gate_scratch);
-  }
-  // For each batch and cell: update forget gate.
-  if (use_peephole) {
-    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
-        cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
-        forget_gate_scratch);
-  }
-  if (use_layer_norm) {
-    logger->LogTensorValue(intermediate_tensor_indexes[1], forget_gate_scratch,
-                           n_cell * n_batch, error_reporter);
-    tensor_utils::MeanStddevNormalization(forget_gate_scratch,
-                                          forget_gate_scratch, n_cell, n_batch);
-    tensor_utils::VectorBatchVectorCwiseProduct(
-        forget_layer_norm_coefficients_ptr, n_cell, forget_gate_scratch,
-        n_batch, forget_gate_scratch);
-    tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
-                                       forget_gate_scratch);
-  }
-  tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
-                                     forget_gate_scratch);
-  // For each batch and cell: update the cell.
-  if (use_layer_norm) {
-    logger->LogTensorValue(intermediate_tensor_indexes[2], cell_gate_scratch,
-                           n_cell * n_batch, error_reporter);
-    tensor_utils::MeanStddevNormalization(cell_gate_scratch, cell_gate_scratch,
-                                          n_cell, n_batch);
-    tensor_utils::VectorBatchVectorCwiseProduct(
-        cell_layer_norm_coefficients_ptr, n_cell, cell_gate_scratch, n_batch,
-        cell_gate_scratch);
-    tensor_utils::VectorBatchVectorAdd(cell_gate_bias_ptr, n_cell, n_batch,
-                                       cell_gate_scratch);
-  }
-  tensor_utils::ApplyActivationToVector(cell_gate_scratch, n_batch * n_cell,
-                                        params->activation, cell_gate_scratch);
+  // Calculate the forget gate.
+  CalculateLstmGateFloat(
+      input_ptr, input_to_forget_weights_ptr, aux_input_ptr,
+      aux_input_to_forget_weights_ptr, output_state_ptr,
+      recurrent_to_forget_weights_ptr, cell_state_ptr,
+      cell_to_forget_weights_ptr, forget_layer_norm_coefficients_ptr,
+      forget_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
+      /*activation=*/kTfLiteActSigmoid, forget_gate_scratch, is_input_all_zeros,
+      is_aux_input_all_zeros, logger, intermediate_tensor_indexes[1],
+      error_reporter);
+  // Calculate the cell update gate.
+  CalculateLstmGateFloat(input_ptr, input_to_cell_weights_ptr, aux_input_ptr,
+                         aux_input_to_cell_weights_ptr, output_state_ptr,
+                         recurrent_to_cell_weights_ptr, /*cell_state=*/nullptr,
+                         /*cell_to_gate_weights=*/nullptr,
+                         cell_layer_norm_coefficients_ptr, cell_gate_bias_ptr,
+                         n_batch, n_input, n_aux_input, n_output, n_cell,
+                         params->activation, cell_gate_scratch,
+                         is_input_all_zeros, is_aux_input_all_zeros, logger,
+                         intermediate_tensor_indexes[2], error_reporter);
+  // Update the cell state.
   UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch,
                       forget_gate_scratch, cell_gate_scratch, use_cifg,
                       params->cell_clip);
-  // For each batch and cell: update the output gate.
-  if (use_peephole) {
-    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
-        cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
-        output_gate_scratch);
-  }
-  if (use_layer_norm) {
-    logger->LogTensorValue(intermediate_tensor_indexes[3], output_gate_scratch,
-                           n_cell * n_batch, error_reporter);
-    tensor_utils::MeanStddevNormalization(output_gate_scratch,
-                                          output_gate_scratch, n_cell, n_batch);
-    tensor_utils::VectorBatchVectorCwiseProduct(
-        output_layer_norm_coefficients_ptr, n_cell, output_gate_scratch,
-        n_batch, output_gate_scratch);
-    tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
-                                       output_gate_scratch);
-  }
-  tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
-                                     output_gate_scratch);
+  // Calculate output gate.
+  CalculateLstmGateFloat(
+      input_ptr, input_to_output_weights_ptr, aux_input_ptr,
+      aux_input_to_output_weights_ptr, output_state_ptr,
+      recurrent_to_output_weights_ptr, cell_state_ptr,
+      cell_to_output_weights_ptr, output_layer_norm_coefficients_ptr,
+      output_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
+      /*activation=*/kTfLiteActSigmoid, output_gate_scratch, is_input_all_zeros,
+      is_aux_input_all_zeros, logger, intermediate_tensor_indexes[3],
+      error_reporter);
+  // Update the output state.
   CalculateLstmOutputCalibration(
       n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
       params->activation, projection_weights_ptr, projection_bias_ptr,
       params->proj_clip, output_state_ptr, scratch2, logger,
       intermediate_tensor_indexes, error_reporter);
-  // Copy output_state to the output. Note that the output batch rows may not be
+  // Copy output state to the output. Note that the output's rows may not be
   // contiguous (output_batch_leading_dim != n_output).
   for (int b = 0; b < n_batch; b++) {
     std::copy_n(output_state_ptr + b * n_output, n_output,