LSTM: Split gate calculations to separate functions.

PiperOrigin-RevId: 319113154
Change-Id: I3f6d5358ac2bc619e3705d6f8aaf66c38c7c1b66
Robert David 2020-06-30 15:39:21 -07:00 committed by TensorFlower Gardener
parent c0bb8efaf1
commit ec07b637ce
2 changed files with 633 additions and 737 deletions

File diff suppressed because it is too large.


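The diff below introduces a CalculateLstmGateFloat helper that folds the per-gate arithmetic previously repeated for the input, forget, cell, and output gates into a single call. As a rough editorial summary (the symbols are shorthand introduced here, not identifiers from the code), each gate computes

\mathrm{gate} = \phi\bigl(W_x\,x_t + W_a\,a_t + W_h\,h_{t-1} + w_c \odot c_{t-1} + b\bigr)

where W_x, W_a, W_h are the input, auxiliary-input, and recurrent weights, w_c the optional peephole weights, b the gate bias, and \phi the gate's fused activation (sigmoid for the input/forget/output gates). In the layer-norm variant the bias is kept out of the accumulation, the sum is normalized, and the bias is added back:

\mathrm{gate} = \phi\bigl(\gamma \odot \mathrm{LN}(W_x\,x_t + W_a\,a_t + W_h\,h_{t-1} + w_c \odot c_{t-1}) + b\bigr)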
@@ -37,6 +37,64 @@ namespace builtin {
namespace {
inline void CalculateLstmGateFloat(
const float* input, const float* input_to_gate_weights,
const float* aux_input, const float* aux_input_to_gate_weights,
const float* output_state, const float* recurrent_to_gate_weights,
const float* cell_state, const float* cell_to_gate_weights,
const float* layer_norm_coefficients, const float* gate_bias,
const int n_batch, const int n_input, const int n_aux_input,
const int n_output, const int n_cell,
const TfLiteFusedActivation activation, float* gate,
const bool is_input_all_zeros, const bool is_aux_input_all_zeros,
Logger* logger, int intermediate_tensor_index,
ErrorReporter* error_reporter) {
const bool use_peephole = (cell_to_gate_weights != nullptr);
const bool use_layer_norm = (layer_norm_coefficients != nullptr);
// Initialize the gate buffer with the bias for a regular LSTM, or with zeros
// for a layer-norm LSTM (the bias is added after normalization below).
if (use_layer_norm) {
std::fill_n(gate, n_cell * n_batch, 0.0f);
} else {
tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch, gate);
}
// For each batch and cell: compute input_weight * input.
// Skip if input is all zeros.
if (!is_input_all_zeros) {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_gate_weights, n_cell, n_input, input, n_batch, gate);
}
// For each batch and cell: compute aux_input_weight * aux_input.
// Skip if auxiliary input is not available or all zeros.
if (!is_aux_input_all_zeros) {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(aux_input_to_gate_weights,
n_cell, n_aux_input,
aux_input, n_batch, gate);
}
// For each batch and cell: compute recurrent_weight * output_state.
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_gate_weights, n_cell, n_output, output_state, n_batch, gate);
// For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
if (use_peephole) {
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
cell_to_gate_weights, n_cell, cell_state, n_batch, gate);
}
// Do layer normalization (if layer norm LSTM)
if (use_layer_norm) {
logger->LogTensorValue(intermediate_tensor_index, gate, n_cell * n_batch,
error_reporter);
tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch);
tensor_utils::VectorBatchVectorCwiseProduct(layer_norm_coefficients, n_cell,
gate, n_batch, gate);
tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate);
}
// Apply activation
tensor_utils::ApplyActivationToVector(gate, n_batch * n_cell, activation,
gate);
}
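// Rough sketch of the cell update performed by the UpdateLstmCellFloat helper
// below, assuming it matches the float kernel in kernels/lstm_eval.cc (see the
// TODO): cell_state = forget_gate .* cell_state + input_gate .* cell_gate, with
// CIFG substituting (1 - forget_gate) for input_gate, and the result clipped to
// +/- cell_clip when cell_clip > 0.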
// TODO(b/159066113): This is the exact same function as UpdateLstmCellFloat in
// kernels/lstm_eval.cc, make that public and remove this.
void UpdateLstmCellFloat(int n_batch, int n_cell, float* cell_state,
@@ -130,178 +188,72 @@ inline void LstmStepCalibration(
// Since we have already checked that weights are all there or none, we can
// check the existence of only one to get the condition.
const bool use_cifg = (input_to_input_weights_ptr == nullptr);
const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
const bool use_layer_norm = (forget_layer_norm_coefficients_ptr != nullptr);
// Make named scratch buffers for the different gates.
// Make named scratch buffers.
float* input_gate_scratch = scratch0;
float* forget_gate_scratch = scratch1;
float* cell_gate_scratch = scratch2;
float* output_gate_scratch = scratch3;
// Initialize scratch buffers with bias for regular lstm or initialize with
// zero for layer norm lstm.
if (use_layer_norm) {
if (!use_cifg) {
std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
}
std::fill_n(forget_gate_scratch, n_cell * n_batch, 0.0f);
std::fill_n(cell_gate_scratch, n_cell * n_batch, 0.0f);
std::fill_n(output_gate_scratch, n_cell * n_batch, 0.0f);
} else {
if (!use_cifg) {
tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell,
n_batch, input_gate_scratch);
}
tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
forget_gate_scratch);
tensor_utils::VectorBatchVectorAssign(cell_gate_bias_ptr, n_cell, n_batch,
cell_gate_scratch);
tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
output_gate_scratch);
}
// For each batch and cell: compute input_weight * input.
// Check if inputs are all zeros so we can skip some computations.
const bool is_input_all_zeros =
tensor_utils::IsZeroVector(input_ptr, n_batch * n_input);
const bool is_aux_input_all_zeros =
(aux_input_ptr == nullptr ||
tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
if (!use_cifg) {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_input_weights_ptr, n_cell, n_input, input_ptr, n_batch,
input_gate_scratch);
// Calculate the input gate. (If not CIFG.)
CalculateLstmGateFloat(
input_ptr, input_to_input_weights_ptr, aux_input_ptr,
aux_input_to_input_weights_ptr, output_state_ptr,
recurrent_to_input_weights_ptr, cell_state_ptr,
cell_to_input_weights_ptr, input_layer_norm_coefficients_ptr,
input_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
/*activation=*/kTfLiteActSigmoid, input_gate_scratch,
is_input_all_zeros, is_aux_input_all_zeros, logger,
intermediate_tensor_indexes[0], error_reporter);
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_forget_weights_ptr, n_cell, n_input, input_ptr, n_batch,
forget_gate_scratch);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_cell_weights_ptr,
n_cell, n_input, input_ptr,
n_batch, cell_gate_scratch);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_output_weights_ptr, n_cell, n_input, input_ptr, n_batch,
output_gate_scratch);
// For each batch and cell: compute aux_input_weight * aux_input.
// Skip if auxiliary input is not available.
if (aux_input_ptr != nullptr) {
if (!use_cifg) {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_input_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
n_batch, input_gate_scratch);
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_forget_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
n_batch, forget_gate_scratch);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
n_batch, cell_gate_scratch);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_output_weights_ptr, n_cell, n_aux_input, aux_input_ptr,
n_batch, output_gate_scratch);
}
// For each batch and cell: compute recurrent_weight * output_state.
if (!use_cifg) {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
n_batch, input_gate_scratch);
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
n_batch, forget_gate_scratch);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
n_batch, cell_gate_scratch);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
n_batch, output_gate_scratch);
// For each batch and cell: update input gate.
if (!use_cifg) {
if (use_peephole) {
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
input_gate_scratch);
}
if (use_layer_norm) {
logger->LogTensorValue(intermediate_tensor_indexes[0], input_gate_scratch,
n_cell * n_batch, error_reporter);
tensor_utils::MeanStddevNormalization(
input_gate_scratch, input_gate_scratch, n_cell, n_batch);
tensor_utils::VectorBatchVectorCwiseProduct(
input_layer_norm_coefficients_ptr, n_cell, input_gate_scratch,
n_batch, input_gate_scratch);
tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
input_gate_scratch);
}
tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
input_gate_scratch);
}
// For each batch and cell: update forget gate.
if (use_peephole) {
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
forget_gate_scratch);
}
if (use_layer_norm) {
logger->LogTensorValue(intermediate_tensor_indexes[1], forget_gate_scratch,
n_cell * n_batch, error_reporter);
tensor_utils::MeanStddevNormalization(forget_gate_scratch,
forget_gate_scratch, n_cell, n_batch);
tensor_utils::VectorBatchVectorCwiseProduct(
forget_layer_norm_coefficients_ptr, n_cell, forget_gate_scratch,
n_batch, forget_gate_scratch);
tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
forget_gate_scratch);
}
tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
forget_gate_scratch);
// For each batch and cell: update the cell.
if (use_layer_norm) {
logger->LogTensorValue(intermediate_tensor_indexes[2], cell_gate_scratch,
n_cell * n_batch, error_reporter);
tensor_utils::MeanStddevNormalization(cell_gate_scratch, cell_gate_scratch,
n_cell, n_batch);
tensor_utils::VectorBatchVectorCwiseProduct(
cell_layer_norm_coefficients_ptr, n_cell, cell_gate_scratch, n_batch,
cell_gate_scratch);
tensor_utils::VectorBatchVectorAdd(cell_gate_bias_ptr, n_cell, n_batch,
cell_gate_scratch);
}
tensor_utils::ApplyActivationToVector(cell_gate_scratch, n_batch * n_cell,
params->activation, cell_gate_scratch);
// Calculate the forget gate.
CalculateLstmGateFloat(
input_ptr, input_to_forget_weights_ptr, aux_input_ptr,
aux_input_to_forget_weights_ptr, output_state_ptr,
recurrent_to_forget_weights_ptr, cell_state_ptr,
cell_to_forget_weights_ptr, forget_layer_norm_coefficients_ptr,
forget_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
/*activation=*/kTfLiteActSigmoid, forget_gate_scratch, is_input_all_zeros,
is_aux_input_all_zeros, logger, intermediate_tensor_indexes[1],
error_reporter);
// Calculate the cell update gate.
CalculateLstmGateFloat(input_ptr, input_to_cell_weights_ptr, aux_input_ptr,
aux_input_to_cell_weights_ptr, output_state_ptr,
recurrent_to_cell_weights_ptr, /*cell_state=*/nullptr,
/*cell_to_gate_weights=*/nullptr,
cell_layer_norm_coefficients_ptr, cell_gate_bias_ptr,
n_batch, n_input, n_aux_input, n_output, n_cell,
params->activation, cell_gate_scratch,
is_input_all_zeros, is_aux_input_all_zeros, logger,
intermediate_tensor_indexes[2], error_reporter);
// Update the cell state.
UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch,
forget_gate_scratch, cell_gate_scratch, use_cifg,
params->cell_clip);
// For each batch and cell: update the output gate.
if (use_peephole) {
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
output_gate_scratch);
}
if (use_layer_norm) {
logger->LogTensorValue(intermediate_tensor_indexes[3], output_gate_scratch,
n_cell * n_batch, error_reporter);
tensor_utils::MeanStddevNormalization(output_gate_scratch,
output_gate_scratch, n_cell, n_batch);
tensor_utils::VectorBatchVectorCwiseProduct(
output_layer_norm_coefficients_ptr, n_cell, output_gate_scratch,
n_batch, output_gate_scratch);
tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
output_gate_scratch);
}
tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
output_gate_scratch);
// Calculate output gate.
CalculateLstmGateFloat(
input_ptr, input_to_output_weights_ptr, aux_input_ptr,
aux_input_to_output_weights_ptr, output_state_ptr,
recurrent_to_output_weights_ptr, cell_state_ptr,
cell_to_output_weights_ptr, output_layer_norm_coefficients_ptr,
output_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
/*activation=*/kTfLiteActSigmoid, output_gate_scratch, is_input_all_zeros,
is_aux_input_all_zeros, logger, intermediate_tensor_indexes[3],
error_reporter);
// Update the output state.
CalculateLstmOutputCalibration(
n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
params->activation, projection_weights_ptr, projection_bias_ptr,
params->proj_clip, output_state_ptr, scratch2, logger,
intermediate_tensor_indexes, error_reporter);
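// Rough sketch, assuming CalculateLstmOutputCalibration above mirrors the
// float output calculation in kernels/lstm_eval.cc:
//   output_state = output_gate .* activation(cell_state)
// followed, when projection weights are present, by
//   output_state = clip(projection_weights * output_state + projection_bias,
//                       proj_clip).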
// Copy output_state to the output. Note that the output batch rows may not be
// Copy output state to the output. Note that the output's rows may not be
// contiguous (output_batch_leading_dim != n_output).
for (int b = 0; b < n_batch; b++) {
std::copy_n(output_state_ptr + b * n_output, n_output,