Minor refactor: Index batches consistently with 'b'.

PiperOrigin-RevId: 289460257
Change-Id: Ib584b9eb73f2775f4a16380863bb089baf7fb680
Author: Robert David
Date:   2020-01-13 09:43:04 -08:00
Committed by: TensorFlower Gardener
Parent: 1a3bf55d13
Commit: 274ebd7a6b

@@ -320,36 +320,36 @@ inline void LstmStepFloat(
   // n_output), we unroll batched operations.
   if (use_projection_weight) {
     if (use_projection_bias) {
-      for (int k = 0; k < n_batch; k++) {
+      for (int b = 0; b < n_batch; b++) {
         std::copy_n(projection_bias_ptr, n_output,
-                    output_ptr + k * output_batch_leading_dim);
+                    output_ptr + b * output_batch_leading_dim);
       }
     } else {
-      for (int k = 0; k < n_batch; k++) {
-        std::fill_n(output_ptr + k * output_batch_leading_dim, n_output, 0.0f);
+      for (int b = 0; b < n_batch; b++) {
+        std::fill_n(output_ptr + b * output_batch_leading_dim, n_output, 0.0f);
       }
     }
-    for (int k = 0; k < n_batch; k++) {
+    for (int b = 0; b < n_batch; b++) {
       tensor_utils::MatrixBatchVectorMultiplyAccumulate(
           projection_weights_ptr, n_output, n_cell,
-          output_gate_scratch + k * n_cell,
-          /*n_batch=*/1, output_ptr + k * output_batch_leading_dim,
+          output_gate_scratch + b * n_cell,
+          /*n_batch=*/1, output_ptr + b * output_batch_leading_dim,
           /*result_stride=*/1);
       if (params->proj_clip > 0.0) {
-        tensor_utils::ClipVector(output_ptr + k * output_batch_leading_dim,
+        tensor_utils::ClipVector(output_ptr + b * output_batch_leading_dim,
                                  n_output, params->proj_clip,
-                                 output_ptr + k * output_batch_leading_dim);
+                                 output_ptr + b * output_batch_leading_dim);
       }
     }
   } else {
-    for (int k = 0; k < n_batch; k++) {
-      std::copy_n(output_gate_scratch + k * n_output, n_output,
-                  output_ptr + k * output_batch_leading_dim);
+    for (int b = 0; b < n_batch; b++) {
+      std::copy_n(output_gate_scratch + b * n_output, n_output,
+                  output_ptr + b * output_batch_leading_dim);
     }
   }
-  for (int k = 0; k < n_batch; k++) {
-    std::copy_n(output_ptr + k * output_batch_leading_dim, n_output,
-                output_state_ptr + k * n_output);
+  for (int b = 0; b < n_batch; b++) {
+    std::copy_n(output_ptr + b * output_batch_leading_dim, n_output,
+                output_state_ptr + b * n_output);
   }
 }
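
For reference, the loops touched above implement the unrolling announced by the comment at the top of the hunk: consecutive batch rows of output_ptr are output_batch_leading_dim floats apart, a stride that can exceed n_output, so the projection is issued once per batch with n_batch == 1 rather than as a single batched matmul. Below is a minimal standalone C++ sketch of that pattern; MatVecAccumulate, ProjectPerBatch, and their simplified signatures are hypothetical stand-ins, not the real tensor_utils API.

#include <algorithm>

// Hypothetical stand-in for the matmul helper: accumulates
// weights (n_output x n_cell) times input (n_cell) into output (n_output)
// for a single batch.
void MatVecAccumulate(const float* weights, int n_output, int n_cell,
                      const float* input, float* output) {
  for (int r = 0; r < n_output; ++r) {
    for (int c = 0; c < n_cell; ++c) {
      output[r] += weights[r * n_cell + c] * input[c];
    }
  }
}

// Per-batch unrolled projection: batch row b starts at
// output + b * leading_dim, so batches are handled one at a time.
void ProjectPerBatch(const float* weights, const float* bias,
                     const float* scratch, float proj_clip, int n_batch,
                     int n_cell, int n_output, int leading_dim,
                     float* output) {
  for (int b = 0; b < n_batch; b++) {
    // Seed each row with the bias (the no-bias path zero-fills instead).
    std::copy_n(bias, n_output, output + b * leading_dim);
  }
  for (int b = 0; b < n_batch; b++) {
    float* row = output + b * leading_dim;
    MatVecAccumulate(weights, n_output, n_cell, scratch + b * n_cell, row);
    if (proj_clip > 0.0f) {  // mirrors the ClipVector step in the hunk
      for (int i = 0; i < n_output; ++i) {
        row[i] = std::min(proj_clip, std::max(-proj_clip, row[i]));
      }
    }
  }
}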
@@ -750,13 +750,13 @@ inline void LstmStepHybrid(
   // n_output), we unroll the batched operations.
   if (use_projection_weight) {
     if (use_projection_bias) {
-      for (int k = 0; k < n_batch; k++) {
+      for (int b = 0; b < n_batch; b++) {
         std::copy_n(projection_bias_ptr, n_output,
-                    output_ptr + k * output_batch_leading_dim);
+                    output_ptr + b * output_batch_leading_dim);
       }
     } else {
-      for (int k = 0; k < n_batch; k++) {
-        std::fill_n(output_ptr + k * output_batch_leading_dim, n_output, 0.0f);
+      for (int b = 0; b < n_batch; b++) {
+        std::fill_n(output_ptr + b * output_batch_leading_dim, n_output, 0.0f);
       }
     }
     if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) {
@@ -773,30 +773,30 @@ inline void LstmStepHybrid(
         product_scaling_factors[b] =
             scaling_factors[b] * projection_weights_scale;
       }
-      for (int k = 0; k < n_batch; k++) {
+      for (int b = 0; b < n_batch; b++) {
         tensor_utils::MatrixBatchVectorMultiplyAccumulate(
             projection_weights_ptr, n_output, n_cell,
-            quantized_cell_state_ptr + k * n_cell, &product_scaling_factors[k],
-            /*n_batch=*/1, output_ptr + k * output_batch_leading_dim,
+            quantized_cell_state_ptr + b * n_cell, &product_scaling_factors[b],
+            /*n_batch=*/1, output_ptr + b * output_batch_leading_dim,
             /*result_stride=*/1);
       }
     }
     if (params->proj_clip > 0.0) {
-      for (int k = 0; k < n_batch; k++) {
-        tensor_utils::ClipVector(output_ptr + k * output_batch_leading_dim,
+      for (int b = 0; b < n_batch; b++) {
+        tensor_utils::ClipVector(output_ptr + b * output_batch_leading_dim,
                                  n_output, params->proj_clip,
-                                 output_ptr + k * output_batch_leading_dim);
+                                 output_ptr + b * output_batch_leading_dim);
       }
     }
   } else {
-    for (int k = 0; k < n_batch; k++) {
-      std::copy_n(output_gate_scratch + k * n_output, n_output,
-                  output_ptr + k * output_batch_leading_dim);
+    for (int b = 0; b < n_batch; b++) {
+      std::copy_n(output_gate_scratch + b * n_output, n_output,
+                  output_ptr + b * output_batch_leading_dim);
     }
   }
-  for (int k = 0; k < n_batch; k++) {
-    std::copy_n(output_ptr + k * output_batch_leading_dim, n_output,
-                output_state_ptr + k * n_output);
+  for (int b = 0; b < n_batch; b++) {
+    std::copy_n(output_ptr + b * output_batch_leading_dim, n_output,
+                output_state_ptr + b * n_output);
   }
 }
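
The hybrid path adds one wrinkle to the same per-batch structure: the cell state is quantized to int8, so each batch's integer accumulation is dequantized with its own scale, scaling_factors[b] * projection_weights_scale (the product_scaling_factors[b] computed in the hunk above). A rough sketch of that step, again using hypothetical simplified signatures rather than the real tensor_utils API:

#include <cstdint>

// Hypothetical stand-in for the quantized matmul helper: integer
// multiply-accumulate over one batch, dequantized by a single float scale.
void QuantizedMatVecAccumulate(const int8_t* weights, int n_output,
                               int n_cell, const int8_t* input, float scale,
                               float* output) {
  for (int r = 0; r < n_output; ++r) {
    int32_t acc = 0;
    for (int c = 0; c < n_cell; ++c) {
      acc += static_cast<int32_t>(weights[r * n_cell + c]) *
             static_cast<int32_t>(input[c]);
    }
    output[r] += scale * static_cast<float>(acc);  // dequantize, accumulate
  }
}

// Per-batch hybrid projection: batch b uses its own scale, the product of
// the input scaling factor and the projection weights' scale.
void HybridProjectPerBatch(const int8_t* weights, float weights_scale,
                           const int8_t* quantized_cell_state,
                           const float* scaling_factors, int n_batch,
                           int n_cell, int n_output, int leading_dim,
                           float* output) {
  for (int b = 0; b < n_batch; b++) {
    const float product_scale = scaling_factors[b] * weights_scale;
    QuantizedMatVecAccumulate(weights, n_output, n_cell,
                              quantized_cell_state + b * n_cell,
                              product_scale, output + b * leading_dim);
  }
}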