Minor refactor: Index batches consistently with 'b'.
PiperOrigin-RevId: 289460257 Change-Id: Ib584b9eb73f2775f4a16380863bb089baf7fb680
This commit is contained in:
parent
1a3bf55d13
commit
274ebd7a6b
@@ -320,36 +320,36 @@ inline void LstmStepFloat(
|
||||
// n_output), we unroll batched operations.
|
||||
if (use_projection_weight) {
|
||||
if (use_projection_bias) {
|
||||
for (int k = 0; k < n_batch; k++) {
|
||||
for (int b = 0; b < n_batch; b++) {
|
||||
std::copy_n(projection_bias_ptr, n_output,
|
||||
output_ptr + k * output_batch_leading_dim);
|
||||
output_ptr + b * output_batch_leading_dim);
|
||||
}
|
||||
} else {
|
||||
for (int k = 0; k < n_batch; k++) {
|
||||
std::fill_n(output_ptr + k * output_batch_leading_dim, n_output, 0.0f);
|
||||
for (int b = 0; b < n_batch; b++) {
|
||||
std::fill_n(output_ptr + b * output_batch_leading_dim, n_output, 0.0f);
|
||||
}
|
||||
}
|
||||
for (int k = 0; k < n_batch; k++) {
|
||||
for (int b = 0; b < n_batch; b++) {
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
projection_weights_ptr, n_output, n_cell,
|
||||
output_gate_scratch + k * n_cell,
|
||||
/*n_batch=*/1, output_ptr + k * output_batch_leading_dim,
|
||||
output_gate_scratch + b * n_cell,
|
||||
/*n_batch=*/1, output_ptr + b * output_batch_leading_dim,
|
||||
/*result_stride=*/1);
|
||||
if (params->proj_clip > 0.0) {
|
||||
tensor_utils::ClipVector(output_ptr + k * output_batch_leading_dim,
|
||||
tensor_utils::ClipVector(output_ptr + b * output_batch_leading_dim,
|
||||
n_output, params->proj_clip,
|
||||
output_ptr + k * output_batch_leading_dim);
|
||||
output_ptr + b * output_batch_leading_dim);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int k = 0; k < n_batch; k++) {
|
||||
std::copy_n(output_gate_scratch + k * n_output, n_output,
|
||||
output_ptr + k * output_batch_leading_dim);
|
||||
for (int b = 0; b < n_batch; b++) {
|
||||
std::copy_n(output_gate_scratch + b * n_output, n_output,
|
||||
output_ptr + b * output_batch_leading_dim);
|
||||
}
|
||||
}
|
||||
for (int k = 0; k < n_batch; k++) {
|
||||
std::copy_n(output_ptr + k * output_batch_leading_dim, n_output,
|
||||
output_state_ptr + k * n_output);
|
||||
for (int b = 0; b < n_batch; b++) {
|
||||
std::copy_n(output_ptr + b * output_batch_leading_dim, n_output,
|
||||
output_state_ptr + b * n_output);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -750,13 +750,13 @@ inline void LstmStepHybrid(
|
||||
// n_output), we unroll the batched operations.
|
||||
if (use_projection_weight) {
|
||||
if (use_projection_bias) {
|
||||
for (int k = 0; k < n_batch; k++) {
|
||||
for (int b = 0; b < n_batch; b++) {
|
||||
std::copy_n(projection_bias_ptr, n_output,
|
||||
output_ptr + k * output_batch_leading_dim);
|
||||
output_ptr + b * output_batch_leading_dim);
|
||||
}
|
||||
} else {
|
||||
for (int k = 0; k < n_batch; k++) {
|
||||
std::fill_n(output_ptr + k * output_batch_leading_dim, n_output, 0.0f);
|
||||
for (int b = 0; b < n_batch; b++) {
|
||||
std::fill_n(output_ptr + b * output_batch_leading_dim, n_output, 0.0f);
|
||||
}
|
||||
}
|
||||
if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) {
|
||||
@@ -773,30 +773,30 @@ inline void LstmStepHybrid(
|
||||
product_scaling_factors[b] =
|
||||
scaling_factors[b] * projection_weights_scale;
|
||||
}
|
||||
for (int k = 0; k < n_batch; k++) {
|
||||
for (int b = 0; b < n_batch; b++) {
|
||||
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
|
||||
projection_weights_ptr, n_output, n_cell,
|
||||
quantized_cell_state_ptr + k * n_cell, &product_scaling_factors[k],
|
||||
/*n_batch=*/1, output_ptr + k * output_batch_leading_dim,
|
||||
quantized_cell_state_ptr + b * n_cell, &product_scaling_factors[b],
|
||||
/*n_batch=*/1, output_ptr + b * output_batch_leading_dim,
|
||||
/*result_stride=*/1);
|
||||
}
|
||||
}
|
||||
if (params->proj_clip > 0.0) {
|
||||
for (int k = 0; k < n_batch; k++) {
|
||||
tensor_utils::ClipVector(output_ptr + k * output_batch_leading_dim,
|
||||
for (int b = 0; b < n_batch; b++) {
|
||||
tensor_utils::ClipVector(output_ptr + b * output_batch_leading_dim,
|
||||
n_output, params->proj_clip,
|
||||
output_ptr + k * output_batch_leading_dim);
|
||||
output_ptr + b * output_batch_leading_dim);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int k = 0; k < n_batch; k++) {
|
||||
std::copy_n(output_gate_scratch + k * n_output, n_output,
|
||||
output_ptr + k * output_batch_leading_dim);
|
||||
for (int b = 0; b < n_batch; b++) {
|
||||
std::copy_n(output_gate_scratch + b * n_output, n_output,
|
||||
output_ptr + b * output_batch_leading_dim);
|
||||
}
|
||||
}
|
||||
for (int k = 0; k < n_batch; k++) {
|
||||
std::copy_n(output_ptr + k * output_batch_leading_dim, n_output,
|
||||
output_state_ptr + k * n_output);
|
||||
for (int b = 0; b < n_batch; b++) {
|
||||
std::copy_n(output_ptr + b * output_batch_leading_dim, n_output,
|
||||
output_state_ptr + b * n_output);
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user