Add a version of MatrixBatchVectorMultiplyAccumulate with separate scaling parameters for matrix and vectors.

This greatly simplifies Hybrid LSTM code.

PiperOrigin-RevId: 315928673
Change-Id: Ib912dac1e3b88a973f369a579864265851f5679f
Robert David 2020-06-11 10:35:35 -07:00 committed by TensorFlower Gardener
parent 9033264944
commit d394de5e90
2 changed files with 73 additions and 93 deletions
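In effect, the new overload folds the single matrix scale into the per-batch vector scales before the usual int8 accumulation runs, so each hybrid call site no longer needs its own pre-multiplication loop. The standalone sketch below (not part of the commit; the MatVecAccumulate helper, the sizes, and the values are made up) only illustrates the arithmetic being factored out:

// Illustrative model of the arithmetic only, not the TFLite kernel:
// result[b][r] += (matrix_scale * vector_scale[b]) * dot(int8 row r, int8 vector b).
#include <cstdint>
#include <cstdio>
#include <vector>

// Reference accumulate with one scaling factor per batch.
void MatVecAccumulate(const int8_t* matrix, int m_rows, int m_cols,
                      const int8_t* vectors, const float* scaling_factors,
                      int n_batch, float* result) {
  for (int b = 0; b < n_batch; ++b) {
    for (int r = 0; r < m_rows; ++r) {
      int32_t dot = 0;
      for (int c = 0; c < m_cols; ++c) {
        dot += matrix[r * m_cols + c] * vectors[b * m_cols + c];
      }
      result[b * m_rows + r] += scaling_factors[b] * dot;
    }
  }
}

int main() {
  const int m_rows = 2, m_cols = 3, n_batch = 2;
  const int8_t matrix[] = {1, -2, 3, 4, 5, -6};
  const int8_t vectors[] = {7, 8, 9, -1, 2, -3};
  const float matrix_scaling_factor = 0.5f;
  const float vector_scaling_factors[] = {0.25f, 0.125f};

  // What the new overload does internally: multiply the two scales into a
  // scratch buffer, then call the existing per-batch-scale kernel.
  std::vector<float> scaling_factor_scratch(n_batch);
  for (int b = 0; b < n_batch; ++b) {
    scaling_factor_scratch[b] = vector_scaling_factors[b] * matrix_scaling_factor;
  }

  std::vector<float> result(n_batch * m_rows, 0.0f);
  MatVecAccumulate(matrix, m_rows, m_cols, vectors,
                   scaling_factor_scratch.data(), n_batch, result.data());
  for (float v : result) std::printf("%g\n", v);
  return 0;
}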

View File

@@ -162,6 +162,27 @@ void MatrixBatchVectorMultiplyAccumulate(
     const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
     bool* compute_row_sums, CpuBackendContext* context);
 
+// Same as the function above, but provides separate scaling factor for the
+// matrix and the vectors. The scaling factors are multiplied in the
+// scaling_factor_scratch buffer.
+inline void MatrixBatchVectorMultiplyAccumulate(
+    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
+    const int8_t* __restrict__ vectors, const float matrix_scaling_factor,
+    const float* vector_scaling_factors, int n_batch,
+    float* __restrict__ result, const float* per_channel_scale,
+    const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
+    bool* compute_row_sums, float* scaling_factor_scratch,
+    CpuBackendContext* context) {
+  for (int b = 0; b < n_batch; ++b) {
+    scaling_factor_scratch[b] =
+        vector_scaling_factors[b] * matrix_scaling_factor;
+  }
+  MatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
+                                      scaling_factor_scratch, n_batch, result,
+                                      per_channel_scale, input_offset, scratch,
+                                      row_sums, compute_row_sums, context);
+}
+
 // Same as the function above, but the matrix is stored in block compressed
 // sparse row format with block pattern 1x16 which consists of two arrays:
 //   1. A matrix array stores non-zero blocks of the matrix in row major.
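As a usage note, a hypothetical call site for the overload declared above could look like the wrapper below. Every name here is illustrative and the surrounding declarations (the header above, CpuBackendContext, buffer allocation) are assumed; the only requirement the signature implies is that scaling_factor_scratch holds at least n_batch floats.

// Hypothetical wrapper, illustrative only; the argument order follows the
// declaration added above.
void GateMatmulAccumulate(const int8_t* weights, float weights_scale,
                          int n_cell, int n_input,
                          const int8_t* quantized_input,
                          const float* input_scaling_factors, int n_batch,
                          float* gate_scratch, const int32_t* zero_points,
                          int32_t* accum_scratch, int32_t* row_sums,
                          bool* compute_row_sums,
                          float* scaling_factor_scratch,
                          CpuBackendContext* context) {
  // Folds weights_scale into the per-batch input scales via
  // scaling_factor_scratch, then accumulates weights * quantized_input
  // into gate_scratch.
  MatrixBatchVectorMultiplyAccumulate(
      weights, n_cell, n_input, quantized_input, weights_scale,
      input_scaling_factors, n_batch, gate_scratch,
      /*per_channel_scale=*/nullptr, zero_points, accum_scratch, row_sums,
      compute_row_sums, scaling_factor_scratch, context);
}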

View File

@@ -537,7 +537,7 @@ inline void LstmStepHybrid(
     int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
     int output_batch_leading_dim, float* input_gate_scratch,
     float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
-    float* scaling_factors, float* product_scaling_factors,
+    float* scaling_factors, float* scaling_factors_scratch,
     float* recovered_cell_weights, int8_t* quantized_input_ptr,
     int8_t* quantized_aux_input_ptr, int8_t* quantized_output_state_ptr,
     int8_t* quantized_cell_state_ptr, float* output_state_ptr,
@@ -646,49 +646,34 @@ inline void LstmStepHybrid(
                                       quantized_input_ptr, scaling_factors,
                                       zero_points, asymmetric_quantize_inputs);
     if (!use_cifg) {
-      for (int b = 0; b < n_batch; ++b) {
-        product_scaling_factors[b] =
-            scaling_factors[b] * input_to_input_weights_scale;
-      }
       tensor_utils::MatrixBatchVectorMultiplyAccumulate(
           input_to_input_weights_ptr, n_cell, n_input, quantized_input_ptr,
-          product_scaling_factors, n_batch, input_gate_scratch,
-          /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
-          input_to_input_row_sums, compute_row_sums, context);
-    }
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * input_to_forget_weights_scale;
+          input_to_input_weights_scale, scaling_factors, n_batch,
+          input_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
+          accum_scratch_ptr, input_to_input_row_sums, compute_row_sums,
+          scaling_factors_scratch, context);
     }
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr,
-        product_scaling_factors, n_batch, forget_gate_scratch,
-        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
-        input_to_forget_row_sums, compute_row_sums, context);
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * input_to_cell_weights_scale;
-    }
+        input_to_forget_weights_scale, scaling_factors, n_batch,
+        forget_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
+        accum_scratch_ptr, input_to_forget_row_sums, compute_row_sums,
+        scaling_factors_scratch, context);
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr,
-        product_scaling_factors, n_batch, cell_scratch,
+        input_to_cell_weights_scale, scaling_factors, n_batch, cell_scratch,
         /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
-        input_to_cell_row_sums, compute_row_sums, context);
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * input_to_output_weights_scale;
-    }
+        input_to_cell_row_sums, compute_row_sums, scaling_factors_scratch,
+        context);
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr,
-        product_scaling_factors, n_batch, output_gate_scratch,
-        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
-        input_to_output_row_sums, compute_row_sums, context);
+        input_to_output_weights_scale, scaling_factors, n_batch,
+        output_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
+        accum_scratch_ptr, input_to_output_row_sums, compute_row_sums,
+        scaling_factors_scratch, context);
   }
 
   // For each batch and cell: compute aux_input_weight * aux_input.
@@ -700,49 +685,36 @@ inline void LstmStepHybrid(
                                       zero_points, asymmetric_quantize_inputs);
     if (!use_cifg) {
-      for (int b = 0; b < n_batch; ++b) {
-        product_scaling_factors[b] =
-            scaling_factors[b] * aux_input_to_input_weights_scale;
-      }
       tensor_utils::MatrixBatchVectorMultiplyAccumulate(
           aux_input_to_input_weights_ptr, n_cell, n_aux_input,
-          quantized_aux_input_ptr, product_scaling_factors, n_batch,
-          input_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
-          accum_scratch_ptr, aux_input_to_input_row_sums, compute_row_sums,
-          context);
+          quantized_aux_input_ptr, aux_input_to_input_weights_scale,
+          scaling_factors, n_batch, input_gate_scratch,
+          /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+          aux_input_to_input_row_sums, compute_row_sums,
+          scaling_factors_scratch, context);
     }
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * aux_input_to_forget_weights_scale;
-    }
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
-        quantized_aux_input_ptr, product_scaling_factors, n_batch,
-        forget_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
-        accum_scratch_ptr, aux_input_to_forget_row_sums, compute_row_sums,
+        quantized_aux_input_ptr, aux_input_to_forget_weights_scale,
+        scaling_factors, n_batch, forget_gate_scratch,
+        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+        aux_input_to_forget_row_sums, compute_row_sums, scaling_factors_scratch,
         context);
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * aux_input_to_cell_weights_scale;
-    }
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         aux_input_to_cell_weights_ptr, n_cell, n_aux_input,
-        quantized_aux_input_ptr, product_scaling_factors, n_batch, cell_scratch,
-        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
-        aux_input_to_cell_row_sums, compute_row_sums, context);
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * aux_input_to_output_weights_scale;
-    }
+        quantized_aux_input_ptr, aux_input_to_cell_weights_scale,
+        scaling_factors, n_batch, cell_scratch, /*per_channel_scale=*/nullptr,
+        zero_points, accum_scratch_ptr, aux_input_to_cell_row_sums,
+        compute_row_sums, scaling_factors_scratch, context);
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         aux_input_to_output_weights_ptr, n_cell, n_aux_input,
-        quantized_aux_input_ptr, product_scaling_factors, n_batch,
-        output_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
-        accum_scratch_ptr, aux_input_to_output_row_sums, compute_row_sums,
+        quantized_aux_input_ptr, aux_input_to_output_weights_scale,
+        scaling_factors, n_batch, output_gate_scratch,
+        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+        aux_input_to_output_row_sums, compute_row_sums, scaling_factors_scratch,
         context);
   }
@@ -753,49 +725,36 @@ inline void LstmStepHybrid(
         scaling_factors, zero_points, asymmetric_quantize_inputs);
     // For each batch and cell: compute recurrent_weight * output_state.
     if (!use_cifg) {
-      for (int b = 0; b < n_batch; ++b) {
-        product_scaling_factors[b] =
-            scaling_factors[b] * recurrent_to_input_weights_scale;
-      }
       tensor_utils::MatrixBatchVectorMultiplyAccumulate(
           recurrent_to_input_weights_ptr, n_cell, n_output,
-          quantized_output_state_ptr, product_scaling_factors, n_batch,
-          input_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
-          accum_scratch_ptr, recurrent_to_input_row_sums, compute_row_sums,
-          context);
+          quantized_output_state_ptr, recurrent_to_input_weights_scale,
+          scaling_factors, n_batch, input_gate_scratch,
+          /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+          recurrent_to_input_row_sums, compute_row_sums,
+          scaling_factors_scratch, context);
     }
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * recurrent_to_forget_weights_scale;
-    }
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         recurrent_to_forget_weights_ptr, n_cell, n_output,
-        quantized_output_state_ptr, product_scaling_factors, n_batch,
-        forget_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
-        accum_scratch_ptr, recurrent_to_forget_row_sums, compute_row_sums,
+        quantized_output_state_ptr, recurrent_to_forget_weights_scale,
+        scaling_factors, n_batch, forget_gate_scratch,
+        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+        recurrent_to_forget_row_sums, compute_row_sums, scaling_factors_scratch,
        context);
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * recurrent_to_cell_weights_scale;
-    }
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         recurrent_to_cell_weights_ptr, n_cell, n_output,
-        quantized_output_state_ptr, product_scaling_factors, n_batch,
-        cell_scratch, /*per_channel_scale=*/nullptr, zero_points,
-        accum_scratch_ptr, recurrent_to_cell_row_sums, compute_row_sums,
-        context);
+        quantized_output_state_ptr, recurrent_to_cell_weights_scale,
+        scaling_factors, n_batch, cell_scratch, /*per_channel_scale=*/nullptr,
+        zero_points, accum_scratch_ptr, recurrent_to_cell_row_sums,
+        compute_row_sums, scaling_factors_scratch, context);
-    for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
-          scaling_factors[b] * recurrent_to_output_weights_scale;
-    }
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
         recurrent_to_output_weights_ptr, n_cell, n_output,
-        quantized_output_state_ptr, product_scaling_factors, n_batch,
-        output_gate_scratch, /*per_channel_scale=*/nullptr, zero_points,
-        accum_scratch_ptr, recurrent_to_output_row_sums, compute_row_sums,
+        quantized_output_state_ptr, recurrent_to_output_weights_scale,
+        scaling_factors, n_batch, output_gate_scratch,
+        /*per_channel_scale=*/nullptr, zero_points, accum_scratch_ptr,
+        recurrent_to_output_row_sums, compute_row_sums, scaling_factors_scratch,
         context);
   }
@@ -919,13 +878,13 @@ inline void LstmStepHybrid(
         output_gate_scratch, n_batch, n_cell, quantized_cell_state_ptr,
         scaling_factors, zero_points, asymmetric_quantize_inputs);
     for (int b = 0; b < n_batch; ++b) {
-      product_scaling_factors[b] =
+      scaling_factors_scratch[b] =
           scaling_factors[b] * projection_weights_scale;
     }
     for (int b = 0; b < n_batch; b++) {
       tensor_utils::MatrixBatchVectorMultiplyAccumulate(
           projection_weights_ptr, n_output, n_cell,
-          quantized_cell_state_ptr + b * n_cell, &product_scaling_factors[b],
+          quantized_cell_state_ptr + b * n_cell, &scaling_factors_scratch[b],
           /*n_batch=*/1, output_ptr + b * output_batch_leading_dim,
           /*per_channel_scale=*/nullptr,
           asymmetric_quantize_inputs ? &zero_points[b] : nullptr,