LSTM: do the projection into output_state instead of output.

Because the two arrays differ only in stride (the state rows are contiguous, while the output rows may be strided), the projection can be done in a batched manner.

Copy the result to the strided output after projection.

PiperOrigin-RevId: 316560275
Change-Id: I60c544d10a64437ece1fa75eea891af4b97df231
Authored by Robert David on 2020-06-15 15:42:09 -07:00; committed by TensorFlower Gardener
parent 1eeaa79d66
commit 1158611838
2 changed files with 46 additions and 82 deletions
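To illustrate the reasoning in the commit message: the projection now accumulates into a contiguous [n_batch x n_output] output_state buffer, which a single batched matrix call can fill, and only the final copy has to respect the output stride (output_batch_leading_dim != n_output). Below is a minimal standalone sketch of that idea, not the TFLite kernel itself; the plain loops stand in for tensor_utils::MatrixBatchVectorMultiplyAccumulate, and all buffer names and sizes are illustrative.

```cpp
#include <algorithm>
#include <vector>

int main() {
  const int n_batch = 2, n_cell = 3, n_output = 2;
  // Output rows are strided: each batch row starts every 5 floats, not every 2.
  const int output_batch_leading_dim = 5;

  std::vector<float> projection_weights(n_output * n_cell, 0.5f);  // [n_output, n_cell]
  std::vector<float> output_gate_scratch(n_batch * n_cell, 1.0f);  // [n_batch, n_cell]
  std::vector<float> output_state(n_batch * n_output, 0.0f);       // contiguous state
  std::vector<float> output(n_batch * output_batch_leading_dim, 0.0f);  // strided output

  // Batched projection into the contiguous state: because the destination has
  // no per-row padding, one batched call (modeled here by plain loops) can
  // produce all n_batch rows back-to-back, instead of one call per batch row.
  for (int b = 0; b < n_batch; ++b) {
    for (int o = 0; o < n_output; ++o) {
      float acc = 0.0f;
      for (int c = 0; c < n_cell; ++c) {
        acc += projection_weights[o * n_cell + c] * output_gate_scratch[b * n_cell + c];
      }
      output_state[b * n_output + o] += acc;
    }
  }

  // Copy the contiguous state into the strided output, one row per batch;
  // only this step needs to know about output_batch_leading_dim.
  for (int b = 0; b < n_batch; ++b) {
    std::copy_n(output_state.data() + b * n_output, n_output,
                output.data() + b * output_batch_leading_dim);
  }
  return 0;
}
```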


@@ -391,40 +391,29 @@ inline void LstmStepFloat(
   const bool use_projection_weight = (projection_weights_ptr != nullptr);
   const bool use_projection_bias = (projection_bias_ptr != nullptr);
-  // For each batch: update the projection and output_state. Note that since
-  // the output batch rows may not be contiguous (output_batch_leading_dim !=
-  // n_output), we unroll batched operations.
+  // For each batch: update output_state.
   if (use_projection_weight) {
     if (use_projection_bias) {
-      for (int b = 0; b < n_batch; b++) {
-        std::copy_n(projection_bias_ptr, n_output,
-                    output_ptr + b * output_batch_leading_dim);
-      }
+      tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
+                                            n_batch, output_state_ptr);
     } else {
-      for (int b = 0; b < n_batch; b++) {
-        std::fill_n(output_ptr + b * output_batch_leading_dim, n_output, 0.0f);
-      }
+      std::fill_n(output_state_ptr, n_batch * n_output, 0.0f);
     }
-    for (int b = 0; b < n_batch; b++) {
-      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-          projection_weights_ptr, n_output, n_cell,
-          output_gate_scratch + b * n_cell,
-          /*n_batch=*/1, output_ptr + b * output_batch_leading_dim);
-      if (params->proj_clip > 0.0) {
-        tensor_utils::ClipVector(output_ptr + b * output_batch_leading_dim,
-                                 n_output, params->proj_clip,
-                                 output_ptr + b * output_batch_leading_dim);
-      }
+    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+        projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch,
+        output_state_ptr);
+    if (params->proj_clip > 0.0) {
+      tensor_utils::ClipVector(output_state_ptr, n_batch * n_output,
+                               params->proj_clip, output_state_ptr);
     }
   } else {
-    for (int b = 0; b < n_batch; b++) {
-      std::copy_n(output_gate_scratch + b * n_output, n_output,
-                  output_ptr + b * output_batch_leading_dim);
-    }
+    std::copy_n(output_gate_scratch, n_batch * n_output, output_state_ptr);
   }
+  // Copy output_state to the output. Note that the output batch rows may not be
+  // contiguous (output_batch_leading_dim != n_output).
   for (int b = 0; b < n_batch; b++) {
-    std::copy_n(output_ptr + b * output_batch_leading_dim, n_output,
-                output_state_ptr + b * n_output);
+    std::copy_n(output_state_ptr + b * n_output, n_output,
+                output_ptr + b * output_batch_leading_dim);
   }
 }
 // LINT.ThenChange(//tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc)
@@ -863,14 +852,10 @@ inline void LstmStepHybrid(
   // n_output), we unroll the batched operations.
   if (use_projection_weight) {
     if (use_projection_bias) {
-      for (int b = 0; b < n_batch; b++) {
-        std::copy_n(projection_bias_ptr, n_output,
-                    output_ptr + b * output_batch_leading_dim);
-      }
+      tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
+                                            n_batch, output_state_ptr);
     } else {
-      for (int b = 0; b < n_batch; b++) {
-        std::fill_n(output_ptr + b * output_batch_leading_dim, n_output, 0.0f);
-      }
+      std::fill_n(output_state_ptr, n_batch * n_output, 0.0f);
     }
     if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) {
       // Save quantization and matmul computation for all zero input.
@@ -881,35 +866,25 @@ inline void LstmStepHybrid(
         scaling_factors_scratch[b] =
             scaling_factors[b] * projection_weights_scale;
       }
-      for (int b = 0; b < n_batch; b++) {
-        tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-            projection_weights_ptr, n_output, n_cell,
-            quantized_cell_state_ptr + b * n_cell, &scaling_factors_scratch[b],
-            /*n_batch=*/1, output_ptr + b * output_batch_leading_dim,
-            /*per_channel_scale=*/nullptr,
-            asymmetric_quantize_inputs ? &zero_points[b] : nullptr,
-            accum_scratch_ptr, projection_weights_row_sums, compute_row_sums,
-            context);
-      }
+      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+          projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr,
+          scaling_factors_scratch, n_batch, output_state_ptr,
+          /*per_channel_scale=*/nullptr,
+          asymmetric_quantize_inputs ? zero_points : nullptr, accum_scratch_ptr,
+          projection_weights_row_sums, compute_row_sums, context);
     }
     if (params->proj_clip > 0.0) {
-      for (int b = 0; b < n_batch; b++) {
-        tensor_utils::ClipVector(output_ptr + b * output_batch_leading_dim,
-                                 n_output, params->proj_clip,
-                                 output_ptr + b * output_batch_leading_dim);
-      }
+      tensor_utils::ClipVector(output_state_ptr, n_batch * n_output,
+                               params->proj_clip, output_state_ptr);
     }
   } else {
-    for (int b = 0; b < n_batch; b++) {
-      std::copy_n(output_gate_scratch + b * n_output, n_output,
-                  output_ptr + b * output_batch_leading_dim);
-    }
+    std::copy_n(output_gate_scratch, n_batch * n_output, output_state_ptr);
   }
   for (int b = 0; b < n_batch; b++) {
-    std::copy_n(output_ptr + b * output_batch_leading_dim, n_output,
-                output_state_ptr + b * n_output);
+    std::copy_n(output_state_ptr + b * n_output, n_output,
+                output_ptr + b * output_batch_leading_dim);
   }
 }
 }  // namespace
 // Fully quantized lstm kernel for 16 bit gate matmul output.
 //


@@ -249,40 +249,29 @@ inline void LstmStepWithAuxInput(
   const bool use_projection_weight = (projection_weights_ptr != nullptr);
   const bool use_projection_bias = (projection_bias_ptr != nullptr);
-  // For each batch: update the projection and output_state. Note that since
-  // the output batch rows may not be contiguous (output_batch_leading_dim !=
-  // n_output), we unroll batched operations.
+  // For each batch: update output_state.
   if (use_projection_weight) {
     if (use_projection_bias) {
-      for (int k = 0; k < n_batch; k++) {
-        std::copy_n(projection_bias_ptr, n_output,
-                    output_ptr + k * output_batch_leading_dim);
-      }
+      tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
+                                            n_batch, output_state_ptr);
     } else {
-      for (int k = 0; k < n_batch; k++) {
-        std::fill_n(output_ptr + k * output_batch_leading_dim, n_output, 0.0f);
-      }
+      std::fill_n(output_state_ptr, n_batch * n_output, 0.0f);
     }
-    for (int k = 0; k < n_batch; k++) {
-      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-          projection_weights_ptr, n_output, n_cell,
-          output_gate_scratch + k * n_cell,
-          /*n_batch=*/1, output_ptr + k * output_batch_leading_dim);
-      if (params->proj_clip > 0.0) {
-        tensor_utils::ClipVector(output_ptr + k * output_batch_leading_dim,
-                                 n_output, params->proj_clip,
-                                 output_ptr + k * output_batch_leading_dim);
-      }
+    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+        projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch,
+        output_state_ptr);
+    if (params->proj_clip > 0.0) {
+      tensor_utils::ClipVector(output_state_ptr, n_batch * n_output,
+                               params->proj_clip, output_state_ptr);
     }
   } else {
-    for (int k = 0; k < n_batch; k++) {
-      std::copy_n(output_gate_scratch + k * n_output, n_output,
-                  output_ptr + k * output_batch_leading_dim);
-    }
+    std::copy_n(output_gate_scratch, n_batch * n_output, output_state_ptr);
   }
-  for (int k = 0; k < n_batch; k++) {
-    std::copy_n(output_ptr + k * output_batch_leading_dim, n_output,
-                output_state_ptr + k * n_output);
+  // Copy output_state to the output. Note that the output batch rows may not be
+  // contiguous (output_batch_leading_dim != n_output).
+  for (int b = 0; b < n_batch; b++) {
+    std::copy_n(output_state_ptr + b * n_output, n_output,
+                output_ptr + b * output_batch_leading_dim);
   }
 }