Add a version of MatrixBatchVectorMultiplyAccumulate with separate scaling parameters for matrix and vectors.
This greatly simplifies Hybrid LSTM code.

PiperOrigin-RevId: 315928673
Change-Id: Ib912dac1e3b88a973f369a579864265851f5679f
commit d394de5e90 (parent 9033264944)
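The simplification the commit message refers to shows up at every gate matmul in LstmStepHybrid: the per-batch loop that folded the input scales into the weight scale disappears into the new overload. A condensed before/after sketch of one gate, assembled from the hunks below (surrounding code elided):

// Before: fold the per-batch input scales into the weight scale by hand,
// then call the per-vector-scale overload.
for (int b = 0; b < n_batch; ++b) {
  product_scaling_factors[b] =
      scaling_factors[b] * input_to_forget_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
    input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr,
    product_scaling_factors, n_batch, forget_gate_scratch,
    /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
    input_to_forget_row_sums, compute_row_sums, context);

// After: pass the matrix scale and the per-batch vector scales separately;
// the overload writes their products into scaling_factors_scratch itself.
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
    input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr,
    input_to_forget_weights_scale, scaling_factors, n_batch,
    forget_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
    accum_scratch_ptr, input_to_forget_row_sums, compute_row_sums,
    scaling_factors_scratch, context);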
@@ -162,6 +162,27 @@ void MatrixBatchVectorMultiplyAccumulate(
     const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
     bool* compute_row_sums, CpuBackendContext* context);
 
+// Same as the function above, but provides separate scaling factor for the
+// matrix and the vectors. The scaling factors are multiplied in the
+// scaling_factor_scratch buffer.
+inline void MatrixBatchVectorMultiplyAccumulate(
+    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
+    const int8_t* __restrict__ vectors, const float matrix_scaling_factor,
+    const float* vector_scaling_factors, int n_batch,
+    float* __restrict__ result, const float* per_channel_scale,
+    const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
+    bool* compute_row_sums, float* scaling_factor_scratch,
+    CpuBackendContext* context) {
+  for (int b = 0; b < n_batch; ++b) {
+    scaling_factor_scratch[b] =
+        vector_scaling_factors[b] * matrix_scaling_factor;
+  }
+  MatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
+                                      scaling_factor_scratch, n_batch, result,
+                                      per_channel_scale, input_offset, scratch,
+                                      row_sums, compute_row_sums, context);
+}
+
 // Same as the function above, but the matrix is stored in block compressed
 // sparse row format with block pattern 1x16 which consists of two arrays:
 // 1. A matrix array stores non-zero blocks of the matrix in row major.
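Design note: the added overload is a thin, header-only wrapper. It multiplies matrix_scaling_factor into each entry of vector_scaling_factors, stores the per-batch products in the caller-supplied scaling_factor_scratch buffer, and forwards to the existing overload that takes one scaling factor per vector, so no new kernel implementations are needed.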
@@ -537,7 +537,7 @@ inline void LstmStepHybrid(
     int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
     int output_batch_leading_dim, float* input_gate_scratch,
     float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
-    float* scaling_factors, float* product_scaling_factors,
+    float* scaling_factors, float* scaling_factors_scratch,
     float* recovered_cell_weights, int8_t* quantized_input_ptr,
     int8_t* quantized_aux_input_ptr, int8_t* quantized_output_state_ptr,
     int8_t* quantized_cell_state_ptr, float* output_state_ptr,
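With this signature change, LstmStepHybrid no longer fills a product_scaling_factors buffer before each gate matmul; the renamed scaling_factors_scratch buffer is handed straight to the new overload, which writes the per-batch products into it, as the hunks below show.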
@@ -646,49 +646,34 @@ inline void LstmStepHybrid(
                                        quantized_input_ptr, scaling_factors,
                                        zero_points, asymmetric_quantize_inputs);
     if (!use_cifg) {
-      for (int b = 0; b < n_batch; ++b) {
-        product_scaling_factors[b] =
-            scaling_factors[b] * input_to_input_weights_scale;
-      }
       tensor_utils::MatrixBatchVectorMultiplyAccumulate(
           input_to_input_weights_ptr, n_cell, n_input, quantized_input_ptr,
-          product_scaling_factors, n_batch, input_gate_scratch,
-          /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
-          input_to_input_row_sums, compute_row_sums, context);
-    }
-
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * input_to_forget_weights_scale;
+          input_to_input_weights_scale, scaling_factors, n_batch,
+          input_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
+          accum_scratch_ptr, input_to_input_row_sums, compute_row_sums,
+          scaling_factors_scratch, context);
     }
 
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr,
-        product_scaling_factors, n_batch, forget_gate_scratch,
-        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
-        input_to_forget_row_sums, compute_row_sums, context);
-
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * input_to_cell_weights_scale;
-    }
+        input_to_forget_weights_scale, scaling_factors, n_batch,
+        forget_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
+        accum_scratch_ptr, input_to_forget_row_sums, compute_row_sums,
+        scaling_factors_scratch, context);
+
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr,
-        product_scaling_factors, n_batch, cell_scratch,
+        input_to_cell_weights_scale, scaling_factors, n_batch, cell_scratch,
         /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
-        input_to_cell_row_sums, compute_row_sums, context);
-
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * input_to_output_weights_scale;
-    }
+        input_to_cell_row_sums, compute_row_sums, scaling_factors_scratch,
+        context);
+
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr,
-        product_scaling_factors, n_batch, output_gate_scratch,
-        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
-        input_to_output_row_sums, compute_row_sums, context);
+        input_to_output_weights_scale, scaling_factors, n_batch,
+        output_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
+        accum_scratch_ptr, input_to_output_row_sums, compute_row_sums,
+        scaling_factors_scratch, context);
   }
 
   // For each batch and cell: compute aux_input_weight * aux_input.
@@ -700,49 +685,36 @@ inline void LstmStepHybrid(
                                        zero_points, asymmetric_quantize_inputs);
 
     if (!use_cifg) {
-      for (int b = 0; b < n_batch; ++b) {
-        product_scaling_factors[b] =
-            scaling_factors[b] * aux_input_to_input_weights_scale;
-      }
       tensor_utils::MatrixBatchVectorMultiplyAccumulate(
           aux_input_to_input_weights_ptr, n_cell, n_aux_input,
-          quantized_aux_input_ptr, product_scaling_factors, n_batch,
-          input_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
-          accum_scratch_ptr, aux_input_to_input_row_sums, compute_row_sums,
-          context);
+          quantized_aux_input_ptr, aux_input_to_input_weights_scale,
+          scaling_factors, n_batch, input_gate_scratch,
+          /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+          aux_input_to_input_row_sums, compute_row_sums,
+          scaling_factors_scratch, context);
     }
 
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * aux_input_to_forget_weights_scale;
-    }
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
-        quantized_aux_input_ptr, product_scaling_factors, n_batch,
-        forget_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
-        accum_scratch_ptr, aux_input_to_forget_row_sums, compute_row_sums,
+        quantized_aux_input_ptr, aux_input_to_forget_weights_scale,
+        scaling_factors, n_batch, forget_gate_scratch,
+        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+        aux_input_to_forget_row_sums, compute_row_sums, scaling_factors_scratch,
         context);
 
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * aux_input_to_cell_weights_scale;
-    }
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         aux_input_to_cell_weights_ptr, n_cell, n_aux_input,
-        quantized_aux_input_ptr, product_scaling_factors, n_batch, cell_scratch,
-        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
-        aux_input_to_cell_row_sums, compute_row_sums, context);
-
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * aux_input_to_output_weights_scale;
-    }
+        quantized_aux_input_ptr, aux_input_to_cell_weights_scale,
+        scaling_factors, n_batch, cell_scratch, /*per_channel_scale=*/nullptr,
+        zero_points, accum_scratch_ptr, aux_input_to_cell_row_sums,
+        compute_row_sums, scaling_factors_scratch, context);
+
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         aux_input_to_output_weights_ptr, n_cell, n_aux_input,
-        quantized_aux_input_ptr, product_scaling_factors, n_batch,
-        output_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
-        accum_scratch_ptr, aux_input_to_output_row_sums, compute_row_sums,
+        quantized_aux_input_ptr, aux_input_to_output_weights_scale,
+        scaling_factors, n_batch, output_gate_scratch,
+        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+        aux_input_to_output_row_sums, compute_row_sums, scaling_factors_scratch,
         context);
   }
 
@@ -753,49 +725,36 @@ inline void LstmStepHybrid(
         scaling_factors, zero_points, asymmetric_quantize_inputs);
     // For each batch and cell: compute recurrent_weight * output_state.
     if (!use_cifg) {
-      for (int b = 0; b < n_batch; ++b) {
-        product_scaling_factors[b] =
-            scaling_factors[b] * recurrent_to_input_weights_scale;
-      }
       tensor_utils::MatrixBatchVectorMultiplyAccumulate(
           recurrent_to_input_weights_ptr, n_cell, n_output,
-          quantized_output_state_ptr, product_scaling_factors, n_batch,
-          input_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
-          accum_scratch_ptr, recurrent_to_input_row_sums, compute_row_sums,
-          context);
+          quantized_output_state_ptr, recurrent_to_input_weights_scale,
+          scaling_factors, n_batch, input_gate_scratch,
+          /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+          recurrent_to_input_row_sums, compute_row_sums,
+          scaling_factors_scratch, context);
     }
 
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * recurrent_to_forget_weights_scale;
-    }
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         recurrent_to_forget_weights_ptr, n_cell, n_output,
-        quantized_output_state_ptr, product_scaling_factors, n_batch,
-        forget_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
-        accum_scratch_ptr, recurrent_to_forget_row_sums, compute_row_sums,
+        quantized_output_state_ptr, recurrent_to_forget_weights_scale,
+        scaling_factors, n_batch, forget_gate_scratch,
+        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+        recurrent_to_forget_row_sums, compute_row_sums, scaling_factors_scratch,
         context);
 
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * recurrent_to_cell_weights_scale;
-    }
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         recurrent_to_cell_weights_ptr, n_cell, n_output,
-        quantized_output_state_ptr, product_scaling_factors, n_batch,
-        cell_scratch, /*per_channel_scale=*/nullptr, zero_points,
-        accum_scratch_ptr, recurrent_to_cell_row_sums, compute_row_sums,
-        context);
+        quantized_output_state_ptr, recurrent_to_cell_weights_scale,
+        scaling_factors, n_batch, cell_scratch, /*per_channel_scale=*/nullptr,
+        zero_points, accum_scratch_ptr, recurrent_to_cell_row_sums,
+        compute_row_sums, scaling_factors_scratch, context);
 
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * recurrent_to_output_weights_scale;
-    }
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         recurrent_to_output_weights_ptr, n_cell, n_output,
-        quantized_output_state_ptr, product_scaling_factors, n_batch,
-        output_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
-        accum_scratch_ptr, recurrent_to_output_row_sums, compute_row_sums,
+        quantized_output_state_ptr, recurrent_to_output_weights_scale,
+        scaling_factors, n_batch, output_gate_scratch,
+        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+        recurrent_to_output_row_sums, compute_row_sums, scaling_factors_scratch,
         context);
   }
 
@@ -919,13 +878,13 @@ inline void LstmStepHybrid(
         output_gate_scratch, n_batch, n_cell, quantized_cell_state_ptr,
         scaling_factors, zero_points, asymmetric_quantize_inputs);
     for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
+      scaling_factors_scratch[b] =
           scaling_factors[b] * projection_weights_scale;
     }
     for (int b = 0; b < n_batch; b++) {
       tensor_utils::MatrixBatchVectorMultiplyAccumulate(
           projection_weights_ptr, n_output, n_cell,
-          quantized_cell_state_ptr + b * n_cell, &product_scaling_factors[b],
+          quantized_cell_state_ptr + b * n_cell, &scaling_factors_scratch[b],
           /*n_batch=*/1, output_ptr + b * output_batch_leading_dim,
           /*per_channel_scale=*/nullptr,
           asymmetric_quantize_inputs ? &zero_points[b] : nullptr,
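Note that the projection step keeps its explicit per-batch loop: it calls the matmul once per batch with n_batch = 1 and a single scaling factor, so the new overload does not apply there. The loop now simply writes into scaling_factors_scratch in place of the removed product_scaling_factors buffer.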