Internal change.

PiperOrigin-RevId: 209828735
A. Unique TensorFlower 2018-08-22 14:19:32 -07:00 committed by TensorFlower Gardener
parent c21e14a133
commit 5022fc95aa
2 changed files with 461 additions and 201 deletions

tensorflow/contrib/lite/kernels/internal/kernel_utils.cc

@@ -127,6 +127,47 @@ void LstmStep(
    float* cell_state_ptr, float* input_gate_scratch,
    float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
    float* output_ptr_batch) {
LstmStepWithAuxInput(
input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr,
input_to_cell_weights_ptr, input_to_output_weights_ptr,
/*aux_input_ptr_batch=*/nullptr,
/*aux_input_to_input_weights_ptr=*/nullptr,
/*aux_input_to_forget_weights_ptr=*/nullptr,
/*aux_input_to_cell_weights_ptr=*/nullptr,
/*aux_input_to_output_weights_ptr=*/nullptr,
recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
cell_to_input_weights_ptr, cell_to_forget_weights_ptr,
cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr,
cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
projection_bias_ptr, params, n_batch, n_cell, n_input, n_output,
output_state_ptr, cell_state_ptr, input_gate_scratch, forget_gate_scratch,
cell_scratch, output_gate_scratch, output_ptr_batch);
}
void LstmStepWithAuxInput(
const float* input_ptr_batch, const float* input_to_input_weights_ptr,
const float* input_to_forget_weights_ptr,
const float* input_to_cell_weights_ptr,
const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch,
const float* aux_input_to_input_weights_ptr,
const float* aux_input_to_forget_weights_ptr,
const float* aux_input_to_cell_weights_ptr,
const float* aux_input_to_output_weights_ptr,
const float* recurrent_to_input_weights_ptr,
const float* recurrent_to_forget_weights_ptr,
const float* recurrent_to_cell_weights_ptr,
const float* recurrent_to_output_weights_ptr,
const float* cell_to_input_weights_ptr,
const float* cell_to_forget_weights_ptr,
const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
const float* output_gate_bias_ptr, const float* projection_weights_ptr,
const float* projection_bias_ptr, const TfLiteLSTMParams* params,
int n_batch, int n_cell, int n_input, int n_output, float* output_state_ptr,
float* cell_state_ptr, float* input_gate_scratch,
float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
float* output_ptr_batch) {
  // Since we have already checked that weights are all there or none, we can
  // check the existence of only one to get the condition.
  const bool use_cifg = (input_to_input_weights_ptr == nullptr);
@@ -160,6 +201,25 @@ void LstmStep(
      input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
      output_gate_scratch, /*result_stride=*/1);
// If auxiliary input is available then compute aux_input_weight * aux_input
if (aux_input_ptr_batch != nullptr) {
if (!use_cifg) {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_input_weights_ptr, n_cell, n_input, aux_input_ptr_batch,
n_batch, input_gate_scratch, /*result_stride=*/1);
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_forget_weights_ptr, n_cell, n_input, aux_input_ptr_batch,
n_batch, forget_gate_scratch, /*result_stride=*/1);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_cell_weights_ptr, n_cell, n_input, aux_input_ptr_batch,
n_batch, cell_scratch, /*result_stride=*/1);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_output_weights_ptr, n_cell, n_input, aux_input_ptr_batch,
n_batch, output_gate_scratch, /*result_stride=*/1);
}
  // For each batch and cell: compute recurrent_weight * output_state.
  if (!use_cifg) {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
@@ -286,227 +346,362 @@ void LstmStep(
    int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr,
    int8_t* quantized_cell_state_ptr, float* output_state_ptr,
    float* cell_state_ptr, float* output_ptr_batch) {
  LstmStepWithAuxInput(
      input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale,
      input_to_forget_weights_ptr, input_to_forget_weights_scale,
      input_to_cell_weights_ptr, input_to_cell_weights_scale,
      input_to_output_weights_ptr, input_to_output_weights_scale,
      /*aux_input_ptr_batch=*/nullptr,
      /*aux_input_to_input_weights_ptr=*/nullptr,
      /*aux_input_to_input_weights_scale=*/0.0f,
      /*aux_input_to_forget_weights_ptr=*/nullptr,
      /*aux_input_to_forget_weights_scale=*/0.0f,
      /*aux_input_to_cell_weights_ptr=*/nullptr,
      /*aux_input_to_cell_weights_scale=*/0.0f,
      /*aux_input_to_output_weights_ptr=*/nullptr,
      /*aux_input_to_output_weights_scale=*/0.0f,
      recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale,
      recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale,
      recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale,
      recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale,
      cell_to_input_weights_ptr, cell_to_input_weights_scale,
      cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
      cell_to_output_weights_ptr, cell_to_output_weights_scale,
      input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
      output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale,
      projection_bias_ptr, params, n_batch, n_cell, n_input, n_output,
      input_gate_scratch, forget_gate_scratch, cell_scratch,
      output_gate_scratch, scaling_factors, product_scaling_factors,
      recovered_cell_weights, quantized_input_ptr_batch,
      /*quantized_aux_input_ptr_batch=*/nullptr, quantized_output_state_ptr,
      quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
      output_ptr_batch);
}

void LstmStepWithAuxInput(
    const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
    float input_to_input_weights_scale,
    const int8_t* input_to_forget_weights_ptr,
    float input_to_forget_weights_scale,
    const int8_t* input_to_cell_weights_ptr,
    float input_to_cell_weights_scale,
    const int8_t* input_to_output_weights_ptr,
    float input_to_output_weights_scale, const float* aux_input_ptr_batch,
    const int8_t* aux_input_to_input_weights_ptr,
    float aux_input_to_input_weights_scale,
    const int8_t* aux_input_to_forget_weights_ptr,
    float aux_input_to_forget_weights_scale,
    const int8_t* aux_input_to_cell_weights_ptr,
    float aux_input_to_cell_weights_scale,
    const int8_t* aux_input_to_output_weights_ptr,
    float aux_input_to_output_weights_scale,
    const int8_t* recurrent_to_input_weights_ptr,
    float recurrent_to_input_weights_scale,
    const int8_t* recurrent_to_forget_weights_ptr,
    float recurrent_to_forget_weights_scale,
    const int8_t* recurrent_to_cell_weights_ptr,
    float recurrent_to_cell_weights_scale,
    const int8_t* recurrent_to_output_weights_ptr,
    float recurrent_to_output_weights_scale,
    const int8_t* cell_to_input_weights_ptr,
    float cell_to_input_weights_scale,
    const int8_t* cell_to_forget_weights_ptr,
    float cell_to_forget_weights_scale,
    const int8_t* cell_to_output_weights_ptr,
    float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
    const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
    const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
    float projection_weights_scale, const float* projection_bias_ptr,
    const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
    int n_output, float* input_gate_scratch, float* forget_gate_scratch,
    float* cell_scratch, float* output_gate_scratch, float* scaling_factors,
    float* product_scaling_factors, float* recovered_cell_weights,
    int8_t* quantized_input_ptr_batch,
    int8_t* quantized_aux_input_ptr_batch,
    int8_t* quantized_output_state_ptr, int8_t* quantized_cell_state_ptr,
    float* output_state_ptr, float* cell_state_ptr,
    float* output_ptr_batch) {
  // Since we have already checked that weights are all there or none, we
  // can check the existence of only one to get the condition.
  const bool use_cifg = (input_to_input_weights_ptr == nullptr);
  const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
  // Initialize scratch buffers with bias.
  if (!use_cifg) {
    tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell,
                                          n_batch, input_gate_scratch);
  }
  tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell,
                                        n_batch, forget_gate_scratch);
  tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
                                        cell_scratch);
  tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell,
                                        n_batch, output_gate_scratch);

  if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
    // Save quantization and matmul computation for all zero input.
    float unused_min, unused_max;
    for (int b = 0; b < n_batch; ++b) {
      const int offset = b * n_input;
      tensor_utils::SymmetricQuantizeFloats(
          input_ptr_batch + offset, n_input,
          quantized_input_ptr_batch + offset, &unused_min, &unused_max,
          &scaling_factors[b]);
    }
    // For each batch and cell: compute input_weight * input.
    if (!use_cifg) {
      for (int b = 0; b < n_batch; ++b) {
        product_scaling_factors[b] =
            scaling_factors[b] * input_to_input_weights_scale;
      }
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          input_to_input_weights_ptr, n_cell, n_input,
          quantized_input_ptr_batch, product_scaling_factors, n_batch,
          input_gate_scratch, /*result_stride=*/1);
    }

    for (int b = 0; b < n_batch; ++b) {
      product_scaling_factors[b] =
          scaling_factors[b] * input_to_forget_weights_scale;
    }
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_forget_weights_ptr, n_cell, n_input,
        quantized_input_ptr_batch, product_scaling_factors, n_batch,
        forget_gate_scratch,
        /*result_stride=*/1);

    for (int b = 0; b < n_batch; ++b) {
      product_scaling_factors[b] =
          scaling_factors[b] * input_to_cell_weights_scale;
    }
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_cell_weights_ptr, n_cell, n_input,
        quantized_input_ptr_batch, product_scaling_factors, n_batch,
        cell_scratch, /*result_stride=*/1);

    for (int b = 0; b < n_batch; ++b) {
      product_scaling_factors[b] =
          scaling_factors[b] * input_to_output_weights_scale;
    }
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_output_weights_ptr, n_cell, n_input,
        quantized_input_ptr_batch, product_scaling_factors, n_batch,
        output_gate_scratch,
        /*result_stride=*/1);
  }

  if (aux_input_ptr_batch != nullptr &&
      !tensor_utils::IsZeroVector(aux_input_ptr_batch, n_batch * n_input)) {
    // Save quantization and matmul computation for all zero input.
    float unused_min, unused_max;
    for (int b = 0; b < n_batch; ++b) {
      const int offset = b * n_input;
      tensor_utils::SymmetricQuantizeFloats(
          aux_input_ptr_batch + offset, n_input,
          quantized_aux_input_ptr_batch + offset, &unused_min, &unused_max,
          &scaling_factors[b]);
    }
    // For each batch and cell: compute input_weight * input.
    if (!use_cifg) {
      for (int b = 0; b < n_batch; ++b) {
        product_scaling_factors[b] =
            scaling_factors[b] * aux_input_to_input_weights_scale;
      }
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          aux_input_to_input_weights_ptr, n_cell, n_input,
          quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
          input_gate_scratch, /*result_stride=*/1);
    }

    for (int b = 0; b < n_batch; ++b) {
      product_scaling_factors[b] =
          scaling_factors[b] * aux_input_to_forget_weights_scale;
    }
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        aux_input_to_forget_weights_ptr, n_cell, n_input,
        quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
        forget_gate_scratch, /*result_stride=*/1);

    for (int b = 0; b < n_batch; ++b) {
      product_scaling_factors[b] =
          scaling_factors[b] * aux_input_to_cell_weights_scale;
    }
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        aux_input_to_cell_weights_ptr, n_cell, n_input,
        quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
        cell_scratch, /*result_stride=*/1);

    for (int b = 0; b < n_batch; ++b) {
      product_scaling_factors[b] =
          scaling_factors[b] * aux_input_to_output_weights_scale;
    }
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        aux_input_to_output_weights_ptr, n_cell, n_input,
        quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
        output_gate_scratch, /*result_stride=*/1);
  }

  if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
    // Save quantization and matmul computation for all zero input.
    float unused_min, unused_max;
    for (int b = 0; b < n_batch; ++b) {
      const int offset = b * n_output;
      tensor_utils::SymmetricQuantizeFloats(
          output_state_ptr + offset, n_output,
          quantized_output_state_ptr + offset, &unused_min, &unused_max,
          &scaling_factors[b]);
    }
    // For each batch and cell: compute recurrent_weight * output_state.
    if (!use_cifg) {
      for (int b = 0; b < n_batch; ++b) {
        product_scaling_factors[b] =
            scaling_factors[b] * recurrent_to_input_weights_scale;
      }
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          recurrent_to_input_weights_ptr, n_cell, n_output,
          quantized_output_state_ptr, product_scaling_factors, n_batch,
          input_gate_scratch, /*result_stride=*/1);
    }

    for (int b = 0; b < n_batch; ++b) {
      product_scaling_factors[b] =
          scaling_factors[b] * recurrent_to_forget_weights_scale;
    }
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        recurrent_to_forget_weights_ptr, n_cell, n_output,
        quantized_output_state_ptr, product_scaling_factors, n_batch,
        forget_gate_scratch, /*result_stride=*/1);

    for (int b = 0; b < n_batch; ++b) {
      product_scaling_factors[b] =
          scaling_factors[b] * recurrent_to_cell_weights_scale;
    }
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        recurrent_to_cell_weights_ptr, n_cell, n_output,
        quantized_output_state_ptr, product_scaling_factors, n_batch,
        cell_scratch, /*result_stride=*/1);

    for (int b = 0; b < n_batch; ++b) {
      product_scaling_factors[b] =
          scaling_factors[b] * recurrent_to_output_weights_scale;
    }
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        recurrent_to_output_weights_ptr, n_cell, n_output,
        quantized_output_state_ptr, product_scaling_factors, n_batch,
        output_gate_scratch, /*result_stride=*/1);
  }

  // Save quantization and matmul computation for all zero input.
  bool is_cell_state_all_zeros =
      tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);

  // For each batch and cell: update input gate.
  if (!use_cifg) {
    if (use_peephole && !is_cell_state_all_zeros) {
      tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell,
                                         cell_to_input_weights_scale,
                                         recovered_cell_weights);
      tensor_utils::VectorBatchVectorCwiseProductAccumulate(
          recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
          input_gate_scratch);
    }
    tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
                                       input_gate_scratch);
  }

  // For each batch and cell: update forget gate.
  if (use_peephole && !is_cell_state_all_zeros) {
    tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell,
                                       cell_to_forget_weights_scale,
                                       recovered_cell_weights);
    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
        forget_gate_scratch);
  }
  tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
                                     forget_gate_scratch);

  // For each batch and cell: update the cell.
  tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch,
                                         cell_state_ptr, n_batch * n_cell,
                                         cell_state_ptr);
  tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
                                        params->activation, cell_scratch);
  if (use_cifg) {
    tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
                             forget_gate_scratch);
    tensor_utils::VectorVectorCwiseProductAccumulate(
        cell_scratch, forget_gate_scratch, n_batch * n_cell,
        cell_state_ptr);
  } else {
    tensor_utils::VectorVectorCwiseProductAccumulate(
        cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
  }
  if (params->cell_clip > 0.0) {
    tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
                             params->cell_clip, cell_state_ptr);
  }

  is_cell_state_all_zeros =
      tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
  // For each batch and cell: update the output gate.
  if (use_peephole && !is_cell_state_all_zeros) {
    tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell,
                                       cell_to_output_weights_scale,
                                       recovered_cell_weights);
    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
        output_gate_scratch);
  }
  tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
                                     output_gate_scratch);
  tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
                                        params->activation, cell_scratch);
  tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
                                         n_batch * n_cell,
                                         output_gate_scratch);

  // For each batch: update the projection and output_state.
  const bool use_projection_weight = (projection_weights_ptr != nullptr);
  const bool use_projection_bias = (projection_bias_ptr != nullptr);
  if (use_projection_weight) {
    if (use_projection_bias) {
      tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
                                            n_batch, output_ptr_batch);
    } else {
      tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
    }
    if (!tensor_utils::IsZeroVector(output_gate_scratch,
                                    n_batch * n_cell)) {
      // Save quantization and matmul computation for all zero input.
      float unused_min, unused_max;
      for (int b = 0; b < n_batch; ++b) {
        const int offset = b * n_cell;
        tensor_utils::SymmetricQuantizeFloats(
            output_gate_scratch + offset, n_cell,
            quantized_cell_state_ptr + offset, &unused_min, &unused_max,
            &scaling_factors[b]);
      }
      for (int b = 0; b < n_batch; ++b) {
        product_scaling_factors[b] =
            scaling_factors[b] * projection_weights_scale;
      }
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          projection_weights_ptr, n_output, n_cell,
          quantized_cell_state_ptr, product_scaling_factors, n_batch,
          output_ptr_batch,
          /*result_stride=*/1);
    }
    if (params->proj_clip > 0.0) {
      tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
                               params->proj_clip, output_ptr_batch);
    }
  } else {
    tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
                             output_ptr_batch);
  }
  tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
                           output_state_ptr);
}
}  // namespace kernel_utils
}  // namespace tflite
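The hybrid path above symmetrically quantizes each batch row of the float input to int8, accumulates the matrix-vector products in integer arithmetic, and recovers float gate values by scaling the accumulator with the per-row input scaling factor times the per-tensor weight scale (the product_scaling_factors computed in the code). The standalone sketch below illustrates that arithmetic for a single row; it is not the TFLite tensor_utils implementation, and the SymmetricQuantize helper here is illustrative only.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Symmetrically quantize a float vector to int8 so that
// float_value ~= int8_value * scaling_factor. Returns the scaling factor.
// (Illustrative helper, not the TFLite tensor_utils API.)
float SymmetricQuantize(const std::vector<float>& values,
                        std::vector<int8_t>* quantized) {
  float max_abs = 0.0f;
  for (float v : values) max_abs = std::max(max_abs, std::fabs(v));
  const float scaling_factor = max_abs > 0.0f ? max_abs / 127.0f : 1.0f;
  quantized->clear();
  for (float v : values) {
    quantized->push_back(static_cast<int8_t>(std::round(v / scaling_factor)));
  }
  return scaling_factor;
}

int main() {
  // One batch row of float input, plus one row of pre-quantized int8 weights
  // with a per-tensor scale (float_weight ~= int8_weight * weight_scale).
  const std::vector<float> input = {0.5f, -1.0f, 0.25f};
  const std::vector<int8_t> weights = {40, -80, 20};
  const float weight_scale = 0.0125f;

  std::vector<int8_t> quantized_input;
  const float input_scale = SymmetricQuantize(input, &quantized_input);

  // Integer accumulation, then a single rescale by the combined factor,
  // mirroring product_scaling_factors[b] = scaling_factors[b] * weight_scale.
  int32_t acc = 0;
  for (int i = 0; i < static_cast<int>(input.size()); ++i) {
    acc += static_cast<int32_t>(quantized_input[i]) * weights[i];
  }
  const float hybrid_result = acc * (input_scale * weight_scale);

  // Float reference for comparison.
  float reference = 0.0f;
  for (int i = 0; i < static_cast<int>(input.size()); ++i) {
    reference += input[i] * (weights[i] * weight_scale);
  }
  std::printf("hybrid=%f reference=%f\n", hybrid_result, reference);
  return 0;
}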

tensorflow/contrib/lite/kernels/internal/kernel_utils.h

@@ -92,6 +92,31 @@ void LstmStep(
    float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
    float* output_ptr_batch);
// Same as above but includes an auxiliary input with the corresponding weights.
void LstmStepWithAuxInput(
const float* input_ptr_batch, const float* input_to_input_weights_ptr,
const float* input_to_forget_weights_ptr,
const float* input_to_cell_weights_ptr,
const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch,
const float* aux_input_to_input_weights_ptr,
const float* aux_input_to_forget_weights_ptr,
const float* aux_input_to_cell_weights_ptr,
const float* aux_input_to_output_weights_ptr,
const float* recurrent_to_input_weights_ptr,
const float* recurrent_to_forget_weights_ptr,
const float* recurrent_to_cell_weights_ptr,
const float* recurrent_to_output_weights_ptr,
const float* cell_to_input_weights_ptr,
const float* cell_to_forget_weights_ptr,
const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
const float* output_gate_bias_ptr, const float* projection_weights_ptr,
const float* projection_bias_ptr, const TfLiteLSTMParams* params,
int n_batch, int n_cell, int n_input, int n_output, float* output_state_ptr,
float* cell_state_ptr, float* input_gate_scratch,
float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
float* output_ptr_batch);
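Conceptually, the auxiliary input only adds one more matrix-vector term to each gate's pre-activation: gate = bias + W_input * input + W_aux * aux_input + W_recurrent * prev_output, with the auxiliary term skipped when aux_input_ptr_batch is null. A minimal float sketch of that accumulation follows; the function and parameter names here are illustrative, not part of the TFLite API.

#include <cstdio>

// Accumulate weights * vec into gate (row-major weights of shape
// [n_cell, n_in]). Illustrative helper only.
void AccumulateGate(const float* weights, int n_cell, int n_in,
                    const float* vec, float* gate) {
  for (int c = 0; c < n_cell; ++c) {
    float sum = 0.0f;
    for (int i = 0; i < n_in; ++i) sum += weights[c * n_in + i] * vec[i];
    gate[c] += sum;
  }
}

// gate = bias + W_input * input + W_aux * aux_input + W_recurrent * prev_output,
// where the auxiliary term is skipped when no auxiliary input is given.
void GatePreActivation(const float* input, const float* aux_input,
                       const float* prev_output, const float* bias,
                       const float* w_input, const float* w_aux,
                       const float* w_recurrent, int n_cell, int n_input,
                       int n_output, float* gate) {
  for (int c = 0; c < n_cell; ++c) gate[c] = bias[c];
  AccumulateGate(w_input, n_cell, n_input, input, gate);
  if (aux_input != nullptr) {
    AccumulateGate(w_aux, n_cell, n_input, aux_input, gate);
  }
  AccumulateGate(w_recurrent, n_cell, n_output, prev_output, gate);
}

int main() {
  // Toy sizes: 2 cells, 2 inputs, 1 recurrent output.
  const float input[] = {1.0f, 2.0f};
  const float aux[] = {0.5f, -0.5f};
  const float prev_output[] = {0.25f};
  const float bias[] = {0.1f, -0.1f};
  const float w_input[] = {1.0f, 0.0f, 0.0f, 1.0f};
  const float w_aux[] = {1.0f, 1.0f, 1.0f, 1.0f};
  const float w_recurrent[] = {2.0f, -2.0f};
  float gate[2];
  GatePreActivation(input, aux, prev_output, bias, w_input, w_aux,
                    w_recurrent, /*n_cell=*/2, /*n_input=*/2, /*n_output=*/1,
                    gate);
  std::printf("gate[0]=%f gate[1]=%f\n", gate[0], gate[1]);
  return 0;
}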
// Same as above but with quantized weight matrices. In detail:
// Input of size 'n_batch * n_input':
//   input_ptr_batch
@@ -175,6 +200,46 @@ void LstmStep(
    int8_t* quantized_cell_state_ptr, float* output_state_ptr,
    float* cell_state_ptr, float* output_ptr_batch);
void LstmStepWithAuxInput(
const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
float input_to_input_weights_scale,
const int8_t* input_to_forget_weights_ptr,
float input_to_forget_weights_scale,
const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
const int8_t* input_to_output_weights_ptr,
float input_to_output_weights_scale, const float* aux_input_ptr_batch,
const int8_t* aux_input_to_input_weights_ptr,
float aux_input_to_input_weights_scale,
const int8_t* aux_input_to_forget_weights_ptr,
float aux_input_to_forget_weights_scale,
const int8_t* aux_input_to_cell_weights_ptr,
float aux_input_to_cell_weights_scale,
const int8_t* aux_input_to_output_weights_ptr,
float aux_input_to_output_weights_scale,
const int8_t* recurrent_to_input_weights_ptr,
float recurrent_to_input_weights_scale,
const int8_t* recurrent_to_forget_weights_ptr,
float recurrent_to_forget_weights_scale,
const int8_t* recurrent_to_cell_weights_ptr,
float recurrent_to_cell_weights_scale,
const int8_t* recurrent_to_output_weights_ptr,
float recurrent_to_output_weights_scale,
const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
const int8_t* cell_to_forget_weights_ptr,
float cell_to_forget_weights_scale,
const int8_t* cell_to_output_weights_ptr,
float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
float projection_weights_scale, const float* projection_bias_ptr,
const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
int n_output, float* input_gate_scratch, float* forget_gate_scratch,
float* cell_scratch, float* output_gate_scratch, float* scaling_factors,
float* product_scaling_factors, float* recovered_cell_weights,
int8_t* quantized_input_ptr_batch, int8_t* quantized_aux_input_ptr_batch,
int8_t* quantized_output_state_ptr, int8_t* quantized_cell_state_ptr,
float* output_state_ptr, float* cell_state_ptr, float* output_ptr_batch);
}  // namespace kernel_utils
}  // namespace tflite

#endif  // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_