Internal change.
PiperOrigin-RevId: 209828735
commit 5022fc95aa (parent c21e14a133)
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
@@ -127,6 +127,47 @@ void LstmStep(
     float* cell_state_ptr, float* input_gate_scratch,
     float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
     float* output_ptr_batch) {
+  LstmStepWithAuxInput(
+      input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr,
+      input_to_cell_weights_ptr, input_to_output_weights_ptr,
+      /*aux_input_ptr_batch=*/nullptr,
+      /*aux_input_to_input_weights_ptr=*/nullptr,
+      /*aux_input_to_forget_weights_ptr=*/nullptr,
+      /*aux_input_to_cell_weights_ptr=*/nullptr,
+      /*aux_input_to_output_weights_ptr=*/nullptr,
+      recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
+      recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
+      cell_to_input_weights_ptr, cell_to_forget_weights_ptr,
+      cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr,
+      cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
+      projection_bias_ptr, params, n_batch, n_cell, n_input, n_output,
+      output_state_ptr, cell_state_ptr, input_gate_scratch, forget_gate_scratch,
+      cell_scratch, output_gate_scratch, output_ptr_batch);
+}
+
+void LstmStepWithAuxInput(
+    const float* input_ptr_batch, const float* input_to_input_weights_ptr,
+    const float* input_to_forget_weights_ptr,
+    const float* input_to_cell_weights_ptr,
+    const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch,
+    const float* aux_input_to_input_weights_ptr,
+    const float* aux_input_to_forget_weights_ptr,
+    const float* aux_input_to_cell_weights_ptr,
+    const float* aux_input_to_output_weights_ptr,
+    const float* recurrent_to_input_weights_ptr,
+    const float* recurrent_to_forget_weights_ptr,
+    const float* recurrent_to_cell_weights_ptr,
+    const float* recurrent_to_output_weights_ptr,
+    const float* cell_to_input_weights_ptr,
+    const float* cell_to_forget_weights_ptr,
+    const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
+    const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+    const float* output_gate_bias_ptr, const float* projection_weights_ptr,
+    const float* projection_bias_ptr, const TfLiteLSTMParams* params,
+    int n_batch, int n_cell, int n_input, int n_output, float* output_state_ptr,
+    float* cell_state_ptr, float* input_gate_scratch,
+    float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
+    float* output_ptr_batch) {
   // Since we have already checked that weights are all there or none, we can
   // check the existense of only one to the get the condition.
   const bool use_cifg = (input_to_input_weights_ptr == nullptr);
@@ -160,6 +201,25 @@ void LstmStep(
       input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
       output_gate_scratch, /*result_stride=*/1);
 
+  // If auxiliary input is available then compute aux_input_weight * aux_input
+  if (aux_input_ptr_batch != nullptr) {
+    if (!use_cifg) {
+      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+          aux_input_to_input_weights_ptr, n_cell, n_input, aux_input_ptr_batch,
+          n_batch, input_gate_scratch, /*result_stride=*/1);
+    }
+
+    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+        aux_input_to_forget_weights_ptr, n_cell, n_input, aux_input_ptr_batch,
+        n_batch, forget_gate_scratch, /*result_stride=*/1);
+    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+        aux_input_to_cell_weights_ptr, n_cell, n_input, aux_input_ptr_batch,
+        n_batch, cell_scratch, /*result_stride=*/1);
+    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+        aux_input_to_output_weights_ptr, n_cell, n_input, aux_input_ptr_batch,
+        n_batch, output_gate_scratch, /*result_stride=*/1);
+  }
+
   // For each batch and cell: compute recurrent_weight * output_state.
   if (!use_cifg) {
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
@@ -286,227 +346,362 @@ void LstmStep(
     int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr,
     int8_t* quantized_cell_state_ptr, float* output_state_ptr,
     float* cell_state_ptr, float* output_ptr_batch) {
-  // Since we have already checked that weights are all there or none, we can
-  // check the existense of only one to the get the condition.
-  const bool use_cifg = (input_to_input_weights_ptr == nullptr);
-  const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
-  // Initialize scratch buffers with bias.
-  if (!use_cifg) {
-    tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch,
-                                          input_gate_scratch);
-  }
-  tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
-                                        forget_gate_scratch);
-  tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
-                                        cell_scratch);
-  tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
-                                        output_gate_scratch);
-
-  if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
-    // Save quantization and matmul computation for all zero input.
-    float unused_min, unused_max;
-    for (int b = 0; b < n_batch; ++b) {
-      const int offset = b * n_input;
-      tensor_utils::SymmetricQuantizeFloats(
-          input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset,
-          &unused_min, &unused_max, &scaling_factors[b]);
-    }
-    // For each batch and cell: compute input_weight * input.
-    if (!use_cifg) {
-      for (int b = 0; b < n_batch; ++b) {
-        product_scaling_factors[b] =
-            scaling_factors[b] * input_to_input_weights_scale;
-      }
-      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-          input_to_input_weights_ptr, n_cell, n_input,
-          quantized_input_ptr_batch, product_scaling_factors, n_batch,
-          input_gate_scratch, /*result_stride=*/1);
-    }
-
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * input_to_forget_weights_scale;
-    }
-    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-        input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
-        product_scaling_factors, n_batch, forget_gate_scratch,
-        /*result_stride=*/1);
-
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * input_to_cell_weights_scale;
-    }
-    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-        input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
-        product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1);
-
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * input_to_output_weights_scale;
-    }
-    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-        input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
-        product_scaling_factors, n_batch, output_gate_scratch,
-        /*result_stride=*/1);
-  }
-
-  if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
-    // Save quantization and matmul computation for all zero input.
-    float unused_min, unused_max;
-    for (int b = 0; b < n_batch; ++b) {
-      const int offset = b * n_output;
-      tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output,
-                                            quantized_output_state_ptr + offset,
-                                            &unused_min, &unused_max,
-                                            &scaling_factors[b]);
-    }
-    // For each batch and cell: compute recurrent_weight * output_state.
-    if (!use_cifg) {
-      for (int b = 0; b < n_batch; ++b) {
-        product_scaling_factors[b] =
-            scaling_factors[b] * recurrent_to_input_weights_scale;
-      }
-      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-          recurrent_to_input_weights_ptr, n_cell, n_output,
-          quantized_output_state_ptr, product_scaling_factors, n_batch,
-          input_gate_scratch, /*result_stride=*/1);
-    }
-
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * recurrent_to_forget_weights_scale;
-    }
-    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-        recurrent_to_forget_weights_ptr, n_cell, n_output,
-        quantized_output_state_ptr, product_scaling_factors, n_batch,
-        forget_gate_scratch, /*result_stride=*/1);
-
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * recurrent_to_cell_weights_scale;
-    }
-    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-        recurrent_to_cell_weights_ptr, n_cell, n_output,
-        quantized_output_state_ptr, product_scaling_factors, n_batch,
-        cell_scratch, /*result_stride=*/1);
-
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * recurrent_to_output_weights_scale;
-    }
-    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-        recurrent_to_output_weights_ptr, n_cell, n_output,
-        quantized_output_state_ptr, product_scaling_factors, n_batch,
-        output_gate_scratch, /*result_stride=*/1);
-  }
-
-  // Save quantization and matmul computation for all zero input.
-  bool is_cell_state_all_zeros =
-      tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
-
-  // For each batch and cell: update input gate.
-  if (!use_cifg) {
-    if (use_peephole && !is_cell_state_all_zeros) {
-      tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell,
-                                         cell_to_input_weights_scale,
-                                         recovered_cell_weights);
-      tensor_utils::VectorBatchVectorCwiseProductAccumulate(
-          recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
-          input_gate_scratch);
-    }
-    tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
-                                       input_gate_scratch);
-  }
-
-  // For each batch and cell: update forget gate.
-  if (use_peephole && !is_cell_state_all_zeros) {
-    tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell,
-                                       cell_to_forget_weights_scale,
-                                       recovered_cell_weights);
-    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
-        recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
-        forget_gate_scratch);
-  }
-  tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
-                                     forget_gate_scratch);
-
-  // For each batch and cell: update the cell.
-  tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
-                                         n_batch * n_cell, cell_state_ptr);
-  tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
-                                        params->activation, cell_scratch);
-  if (use_cifg) {
-    tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
-                             forget_gate_scratch);
-    tensor_utils::VectorVectorCwiseProductAccumulate(
-        cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
-  } else {
-    tensor_utils::VectorVectorCwiseProductAccumulate(
-        cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
-  }
-  if (params->cell_clip > 0.0) {
-    tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
-                             params->cell_clip, cell_state_ptr);
-  }
-
-  is_cell_state_all_zeros =
-      tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
-  // For each batch and cell: update the output gate.
-  if (use_peephole && !is_cell_state_all_zeros) {
-    tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell,
-                                       cell_to_output_weights_scale,
-                                       recovered_cell_weights);
-    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
-        recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
-        output_gate_scratch);
-  }
-  tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
-                                     output_gate_scratch);
-  tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
-                                        params->activation, cell_scratch);
-  tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
-                                         n_batch * n_cell, output_gate_scratch);
-
-  // For each batch: update the projection and output_state.
-  const bool use_projection_weight = (projection_weights_ptr != nullptr);
-  const bool use_projection_bias = (projection_bias_ptr != nullptr);
-  if (use_projection_weight) {
-    if (use_projection_bias) {
-      tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
-                                            n_batch, output_ptr_batch);
-    } else {
-      tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
-    }
-    if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) {
-      // Save quantization and matmul computation for all zero input.
-      float unused_min, unused_max;
-      for (int b = 0; b < n_batch; ++b) {
-        const int offset = b * n_cell;
-        tensor_utils::SymmetricQuantizeFloats(
-            output_gate_scratch + offset, n_cell,
-            quantized_cell_state_ptr + offset, &unused_min, &unused_max,
-            &scaling_factors[b]);
-      }
-      for (int b = 0; b < n_batch; ++b) {
-        product_scaling_factors[b] =
-            scaling_factors[b] * projection_weights_scale;
-      }
-      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-          projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr,
-          product_scaling_factors, n_batch, output_ptr_batch,
-          /*result_stride=*/1);
-    }
-    if (params->proj_clip > 0.0) {
-      tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
-                               params->proj_clip, output_ptr_batch);
-    }
-  } else {
-    tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
-                             output_ptr_batch);
-  }
-  tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
-                           output_state_ptr);
-}
+  LstmStepWithAuxInput(
+      input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale,
+      input_to_forget_weights_ptr, input_to_forget_weights_scale,
+      input_to_cell_weights_ptr, input_to_cell_weights_scale,
+      input_to_output_weights_ptr, input_to_output_weights_scale,
+      /*aux_input_ptr_batch=*/nullptr,
+      /*aux_input_to_input_weights_ptr=*/nullptr,
+      /*aux_input_to_input_weights_scale=*/0.0f,
+      /*aux_input_to_forget_weights_ptr=*/nullptr,
+      /*aux_input_to_forget_weights_scale=*/0.0f,
+      /*aux_input_to_cell_weights_ptr=*/nullptr,
+      /*aux_input_to_cell_weights_scale=*/0.0f,
+      /*aux_input_to_output_weights_ptr=*/nullptr,
+      /*aux_input_to_output_weights_scale=*/0.0f,
+      recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale,
+      recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale,
+      recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale,
+      recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale,
+      cell_to_input_weights_ptr, cell_to_input_weights_scale,
+      cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
+      cell_to_output_weights_ptr, cell_to_output_weights_scale,
+      input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
+      output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale,
+      projection_bias_ptr, params, n_batch, n_cell, n_input, n_output,
+      input_gate_scratch, forget_gate_scratch, cell_scratch,
+      output_gate_scratch, scaling_factors, product_scaling_factors,
+      recovered_cell_weights, quantized_input_ptr_batch,
+      /*quantized_aux_input_ptr_batch=*/nullptr, quantized_output_state_ptr,
+      quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
+      output_ptr_batch);
+}
+
+void LstmStepWithAuxInput(
+    const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
+    float input_to_input_weights_scale,
+    const int8_t* input_to_forget_weights_ptr,
+    float input_to_forget_weights_scale,
+    const int8_t* input_to_cell_weights_ptr,
+    float input_to_cell_weights_scale,
+    const int8_t* input_to_output_weights_ptr,
+    float input_to_output_weights_scale, const float* aux_input_ptr_batch,
+    const int8_t* aux_input_to_input_weights_ptr,
+    float aux_input_to_input_weights_scale,
+    const int8_t* aux_input_to_forget_weights_ptr,
+    float aux_input_to_forget_weights_scale,
+    const int8_t* aux_input_to_cell_weights_ptr,
+    float aux_input_to_cell_weights_scale,
+    const int8_t* aux_input_to_output_weights_ptr,
+    float aux_input_to_output_weights_scale,
+    const int8_t* recurrent_to_input_weights_ptr,
+    float recurrent_to_input_weights_scale,
+    const int8_t* recurrent_to_forget_weights_ptr,
+    float recurrent_to_forget_weights_scale,
+    const int8_t* recurrent_to_cell_weights_ptr,
+    float recurrent_to_cell_weights_scale,
+    const int8_t* recurrent_to_output_weights_ptr,
+    float recurrent_to_output_weights_scale,
+    const int8_t* cell_to_input_weights_ptr,
+    float cell_to_input_weights_scale,
+    const int8_t* cell_to_forget_weights_ptr,
+    float cell_to_forget_weights_scale,
+    const int8_t* cell_to_output_weights_ptr,
+    float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
+    const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+    const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
+    float projection_weights_scale, const float* projection_bias_ptr,
+    const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
+    int n_output, float* input_gate_scratch, float* forget_gate_scratch,
+    float* cell_scratch, float* output_gate_scratch, float* scaling_factors,
+    float* product_scaling_factors, float* recovered_cell_weights,
+    int8_t* quantized_input_ptr_batch,
+    int8_t* quantized_aux_input_ptr_batch,
+    int8_t* quantized_output_state_ptr, int8_t* quantized_cell_state_ptr,
+    float* output_state_ptr, float* cell_state_ptr,
+    float* output_ptr_batch) {
+    // Since we have already checked that weights are all there or none, we
+    // can check the existense of only one to the get the condition.
+    const bool use_cifg = (input_to_input_weights_ptr == nullptr);
+    const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
+    // Initialize scratch buffers with bias.
+    if (!use_cifg) {
+      tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell,
+                                            n_batch, input_gate_scratch);
+    }
+    tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell,
+                                          n_batch, forget_gate_scratch);
+    tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
+                                          cell_scratch);
+    tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell,
+                                          n_batch, output_gate_scratch);
+
+    if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
+      // Save quantization and matmul computation for all zero input.
+      float unused_min, unused_max;
+      for (int b = 0; b < n_batch; ++b) {
+        const int offset = b * n_input;
+        tensor_utils::SymmetricQuantizeFloats(
+            input_ptr_batch + offset, n_input,
+            quantized_input_ptr_batch + offset, &unused_min, &unused_max,
+            &scaling_factors[b]);
+      }
+      // For each batch and cell: compute input_weight * input.
+      if (!use_cifg) {
+        for (int b = 0; b < n_batch; ++b) {
+          product_scaling_factors[b] =
+              scaling_factors[b] * input_to_input_weights_scale;
+        }
+        tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+            input_to_input_weights_ptr, n_cell, n_input,
+            quantized_input_ptr_batch, product_scaling_factors, n_batch,
+            input_gate_scratch, /*result_stride=*/1);
+      }
+
+      for (int b = 0; b < n_batch; ++b) {
+        product_scaling_factors[b] =
+            scaling_factors[b] * input_to_forget_weights_scale;
+      }
+      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+          input_to_forget_weights_ptr, n_cell, n_input,
+          quantized_input_ptr_batch, product_scaling_factors, n_batch,
+          forget_gate_scratch,
+          /*result_stride=*/1);
+
+      for (int b = 0; b < n_batch; ++b) {
+        product_scaling_factors[b] =
+            scaling_factors[b] * input_to_cell_weights_scale;
+      }
+      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+          input_to_cell_weights_ptr, n_cell, n_input,
+          quantized_input_ptr_batch, product_scaling_factors, n_batch,
+          cell_scratch, /*result_stride=*/1);
+
+      for (int b = 0; b < n_batch; ++b) {
+        product_scaling_factors[b] =
+            scaling_factors[b] * input_to_output_weights_scale;
+      }
+      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+          input_to_output_weights_ptr, n_cell, n_input,
+          quantized_input_ptr_batch, product_scaling_factors, n_batch,
+          output_gate_scratch,
+          /*result_stride=*/1);
+    }
+
+    if (aux_input_ptr_batch != nullptr &&
+        !tensor_utils::IsZeroVector(aux_input_ptr_batch, n_batch * n_input)) {
+      // Save quantization and matmul computation for all zero input.
+      float unused_min, unused_max;
+      for (int b = 0; b < n_batch; ++b) {
+        const int offset = b * n_input;
+        tensor_utils::SymmetricQuantizeFloats(
+            aux_input_ptr_batch + offset, n_input,
+            quantized_aux_input_ptr_batch + offset, &unused_min, &unused_max,
+            &scaling_factors[b]);
+      }
+      // For each batch and cell: compute input_weight * input.
+      if (!use_cifg) {
+        for (int b = 0; b < n_batch; ++b) {
+          product_scaling_factors[b] =
+              scaling_factors[b] * aux_input_to_input_weights_scale;
+        }
+        tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+            aux_input_to_input_weights_ptr, n_cell, n_input,
+            quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
+            input_gate_scratch, /*result_stride=*/1);
+      }
+
+      for (int b = 0; b < n_batch; ++b) {
+        product_scaling_factors[b] =
+            scaling_factors[b] * aux_input_to_forget_weights_scale;
+      }
+      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+          aux_input_to_forget_weights_ptr, n_cell, n_input,
+          quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
+          forget_gate_scratch, /*result_stride=*/1);
+
+      for (int b = 0; b < n_batch; ++b) {
+        product_scaling_factors[b] =
+            scaling_factors[b] * aux_input_to_cell_weights_scale;
+      }
+      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+          aux_input_to_cell_weights_ptr, n_cell, n_input,
+          quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
+          cell_scratch, /*result_stride=*/1);
+
+      for (int b = 0; b < n_batch; ++b) {
+        product_scaling_factors[b] =
+            scaling_factors[b] * aux_input_to_output_weights_scale;
+      }
+      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+          aux_input_to_output_weights_ptr, n_cell, n_input,
+          quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
+          output_gate_scratch, /*result_stride=*/1);
+    }
+
+    if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
+      // Save quantization and matmul computation for all zero input.
+      float unused_min, unused_max;
+      for (int b = 0; b < n_batch; ++b) {
+        const int offset = b * n_output;
+        tensor_utils::SymmetricQuantizeFloats(
+            output_state_ptr + offset, n_output,
+            quantized_output_state_ptr + offset, &unused_min, &unused_max,
+            &scaling_factors[b]);
+      }
+      // For each batch and cell: compute recurrent_weight * output_state.
+      if (!use_cifg) {
+        for (int b = 0; b < n_batch; ++b) {
+          product_scaling_factors[b] =
+              scaling_factors[b] * recurrent_to_input_weights_scale;
+        }
+        tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+            recurrent_to_input_weights_ptr, n_cell, n_output,
+            quantized_output_state_ptr, product_scaling_factors, n_batch,
+            input_gate_scratch, /*result_stride=*/1);
+      }
+
+      for (int b = 0; b < n_batch; ++b) {
+        product_scaling_factors[b] =
+            scaling_factors[b] * recurrent_to_forget_weights_scale;
+      }
+      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+          recurrent_to_forget_weights_ptr, n_cell, n_output,
+          quantized_output_state_ptr, product_scaling_factors, n_batch,
+          forget_gate_scratch, /*result_stride=*/1);
+
+      for (int b = 0; b < n_batch; ++b) {
+        product_scaling_factors[b] =
+            scaling_factors[b] * recurrent_to_cell_weights_scale;
+      }
+      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+          recurrent_to_cell_weights_ptr, n_cell, n_output,
+          quantized_output_state_ptr, product_scaling_factors, n_batch,
+          cell_scratch, /*result_stride=*/1);
+
+      for (int b = 0; b < n_batch; ++b) {
+        product_scaling_factors[b] =
+            scaling_factors[b] * recurrent_to_output_weights_scale;
+      }
+      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+          recurrent_to_output_weights_ptr, n_cell, n_output,
+          quantized_output_state_ptr, product_scaling_factors, n_batch,
+          output_gate_scratch, /*result_stride=*/1);
+    }
+
+    // Save quantization and matmul computation for all zero input.
+    bool is_cell_state_all_zeros =
+        tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
+    // For each batch and cell: update input gate.
+    if (!use_cifg) {
+      if (use_peephole && !is_cell_state_all_zeros) {
+        tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell,
+                                           cell_to_input_weights_scale,
+                                           recovered_cell_weights);
+        tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+            recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
+            input_gate_scratch);
+      }
+      tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
+                                         input_gate_scratch);
+    }
+    // For each batch and cell: update forget gate.
+    if (use_peephole && !is_cell_state_all_zeros) {
+      tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell,
+                                         cell_to_forget_weights_scale,
+                                         recovered_cell_weights);
+      tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+          recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
+          forget_gate_scratch);
+    }
+    tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
+                                       forget_gate_scratch);
+    // For each batch and cell: update the cell.
+    tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch,
+                                           cell_state_ptr, n_batch * n_cell,
+                                           cell_state_ptr);
+    tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
+                                          params->activation, cell_scratch);
+    if (use_cifg) {
+      tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
+                               forget_gate_scratch);
+      tensor_utils::VectorVectorCwiseProductAccumulate(
+          cell_scratch, forget_gate_scratch, n_batch * n_cell,
+          cell_state_ptr);
+    } else {
+      tensor_utils::VectorVectorCwiseProductAccumulate(
+          cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
+    }
+    if (params->cell_clip > 0.0) {
+      tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
+                               params->cell_clip, cell_state_ptr);
+    }
+
+    is_cell_state_all_zeros =
+        tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
+    // For each batch and cell: update the output gate.
+    if (use_peephole && !is_cell_state_all_zeros) {
+      tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell,
+                                         cell_to_output_weights_scale,
+                                         recovered_cell_weights);
+      tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+          recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
+          output_gate_scratch);
+    }
+    tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
+                                       output_gate_scratch);
+    tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
+                                          params->activation, cell_scratch);
+    tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+                                           n_batch * n_cell,
+                                           output_gate_scratch);
+
+    // For each batch: update the projection and output_state.
+    const bool use_projection_weight = (projection_weights_ptr != nullptr);
+    const bool use_projection_bias = (projection_bias_ptr != nullptr);
+    if (use_projection_weight) {
+      if (use_projection_bias) {
+        tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
+                                              n_batch, output_ptr_batch);
+      } else {
+        tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
+      }
+      if (!tensor_utils::IsZeroVector(output_gate_scratch,
+                                      n_batch * n_cell)) {
+        // Save quantization and matmul computation for all zero input.
+        float unused_min, unused_max;
+        for (int b = 0; b < n_batch; ++b) {
+          const int offset = b * n_cell;
+          tensor_utils::SymmetricQuantizeFloats(
+              output_gate_scratch + offset, n_cell,
+              quantized_cell_state_ptr + offset, &unused_min, &unused_max,
+              &scaling_factors[b]);
+        }
+        for (int b = 0; b < n_batch; ++b) {
+          product_scaling_factors[b] =
+              scaling_factors[b] * projection_weights_scale;
+        }
+        tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+            projection_weights_ptr, n_output, n_cell,
+            quantized_cell_state_ptr, product_scaling_factors, n_batch,
+            output_ptr_batch,
+            /*result_stride=*/1);
+      }
+      if (params->proj_clip > 0.0) {
+        tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
+                                 params->proj_clip, output_ptr_batch);
+      }
+    } else {
+      tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
+                               output_ptr_batch);
+    }
+    tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
+                             output_state_ptr);
+}
 
 }  // namespace kernel_utils
 }  // namespace tflite
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
@@ -92,6 +92,31 @@ void LstmStep(
     float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
     float* output_ptr_batch);
 
+// Same as above but includes an auxiliary input with the corresponding weights.
+void LstmStepWithAuxInput(
+    const float* input_ptr_batch, const float* input_to_input_weights_ptr,
+    const float* input_to_forget_weights_ptr,
+    const float* input_to_cell_weights_ptr,
+    const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch,
+    const float* aux_input_to_input_weights_ptr,
+    const float* aux_input_to_forget_weights_ptr,
+    const float* aux_input_to_cell_weights_ptr,
+    const float* aux_input_to_output_weights_ptr,
+    const float* recurrent_to_input_weights_ptr,
+    const float* recurrent_to_forget_weights_ptr,
+    const float* recurrent_to_cell_weights_ptr,
+    const float* recurrent_to_output_weights_ptr,
+    const float* cell_to_input_weights_ptr,
+    const float* cell_to_forget_weights_ptr,
+    const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
+    const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+    const float* output_gate_bias_ptr, const float* projection_weights_ptr,
+    const float* projection_bias_ptr, const TfLiteLSTMParams* params,
+    int n_batch, int n_cell, int n_input, int n_output, float* output_state_ptr,
+    float* cell_state_ptr, float* input_gate_scratch,
+    float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
+    float* output_ptr_batch);
+
 // Same as above but with quantized weight matrices. In detail:
 // Input of size 'n_batch * n_input':
 //   input_ptr_batch
@@ -175,6 +200,46 @@ void LstmStep(
     int8_t* quantized_cell_state_ptr, float* output_state_ptr,
     float* cell_state_ptr, float* output_ptr_batch);
 
+void LstmStepWithAuxInput(
+    const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
+    float input_to_input_weights_scale,
+    const int8_t* input_to_forget_weights_ptr,
+    float input_to_forget_weights_scale,
+    const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
+    const int8_t* input_to_output_weights_ptr,
+    float input_to_output_weights_scale, const float* aux_input_ptr_batch,
+    const int8_t* aux_input_to_input_weights_ptr,
+    float aux_input_to_input_weights_scale,
+    const int8_t* aux_input_to_forget_weights_ptr,
+    float aux_input_to_forget_weights_scale,
+    const int8_t* aux_input_to_cell_weights_ptr,
+    float aux_input_to_cell_weights_scale,
+    const int8_t* aux_input_to_output_weights_ptr,
+    float aux_input_to_output_weights_scale,
+    const int8_t* recurrent_to_input_weights_ptr,
+    float recurrent_to_input_weights_scale,
+    const int8_t* recurrent_to_forget_weights_ptr,
+    float recurrent_to_forget_weights_scale,
+    const int8_t* recurrent_to_cell_weights_ptr,
+    float recurrent_to_cell_weights_scale,
+    const int8_t* recurrent_to_output_weights_ptr,
+    float recurrent_to_output_weights_scale,
+    const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
+    const int8_t* cell_to_forget_weights_ptr,
+    float cell_to_forget_weights_scale,
+    const int8_t* cell_to_output_weights_ptr,
+    float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
+    const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+    const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
+    float projection_weights_scale, const float* projection_bias_ptr,
+    const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
+    int n_output, float* input_gate_scratch, float* forget_gate_scratch,
+    float* cell_scratch, float* output_gate_scratch, float* scaling_factors,
+    float* product_scaling_factors, float* recovered_cell_weights,
+    int8_t* quantized_input_ptr_batch, int8_t* quantized_aux_input_ptr_batch,
+    int8_t* quantized_output_state_ptr, int8_t* quantized_cell_state_ptr,
+    float* output_state_ptr, float* cell_state_ptr, float* output_ptr_batch);
+
 }  // namespace kernel_utils
 }  // namespace tflite
 #endif  // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_
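For readers skimming the diff: the change keeps `LstmStep` as a thin wrapper that forwards to the new `LstmStepWithAuxInput` with the auxiliary pointers set to `nullptr` (and the auxiliary scales set to `0.0f` in the quantized path), so existing callers are unaffected. The toy program below is not from this commit; it is a minimal sketch of that wrapper-plus-optional-pointer pattern using invented names (`Accumulate`, `AccumulateWithAux`) so the control flow can be seen in isolation.

#include <cstddef>
#include <iostream>

// Toy stand-in for the gate math; the auxiliary term is skipped when the
// pointer is null, mirroring aux_input_ptr_batch in LstmStepWithAuxInput.
void AccumulateWithAux(const float* input, const float* input_weights,
                       const float* aux_input, const float* aux_weights,
                       std::size_t n, float* acc) {
  for (std::size_t i = 0; i < n; ++i) {
    *acc += input[i] * input_weights[i];
    if (aux_input != nullptr) *acc += aux_input[i] * aux_weights[i];
  }
}

// Thin wrapper, analogous to how LstmStep now forwards with null aux arguments.
void Accumulate(const float* input, const float* input_weights, std::size_t n,
                float* acc) {
  AccumulateWithAux(input, input_weights, /*aux_input=*/nullptr,
                    /*aux_weights=*/nullptr, n, acc);
}

int main() {
  const float x[] = {1.f, 2.f}, w[] = {0.5f, 0.25f};
  const float aux[] = {3.f, 4.f}, aux_w[] = {0.1f, 0.1f};
  float a = 0.f, b = 0.f;
  Accumulate(x, w, 2, &a);                     // no auxiliary input
  AccumulateWithAux(x, w, aux, aux_w, 2, &b);  // with auxiliary input
  std::cout << a << " " << b << "\n";          // prints "1 1.7"
}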