Use the separate matrix/vector scaling factor version of MatrixBatchVectorMultiplyAccumulate in projection too.
PiperOrigin-RevId: 316926050 Change-Id: I64e9febce8590231c78a90a0a5aec5da11996195
This commit is contained in:
parent
c8dd07ae28
commit
ab7cb8336c
@@ -874,16 +874,13 @@ inline void LstmStepHybrid(
       tensor_utils::BatchQuantizeFloats(
           output_gate_scratch, n_batch, n_cell, quantized_cell_state_ptr,
           scaling_factors, zero_points, asymmetric_quantize_inputs);
-      for (int b = 0; b < n_batch; ++b) {
-        scaling_factors_scratch[b] =
-            scaling_factors[b] * projection_weights_scale;
-      }
       tensor_utils::MatrixBatchVectorMultiplyAccumulate(
           projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr,
-          scaling_factors_scratch, n_batch, output_state_ptr,
+          projection_weights_scale, scaling_factors, n_batch, output_state_ptr,
           /*per_channel_scale=*/nullptr,
           asymmetric_quantize_inputs ? zero_points : nullptr, accum_scratch_ptr,
-          projection_weights_row_sums, compute_row_sums, context);
+          projection_weights_row_sums, compute_row_sums,
+          scaling_factors_scratch, context);
     }
     if (params->proj_clip > 0.0) {
       tensor_utils::ClipVector(output_state_ptr, n_batch * n_output,
Loading…
Reference in New Issue
Block a user