Change ReductionSumVector to overwrite output instead of accumulating.

Almost all of the calls were preceded by a memset/fill_n of the output, which is no longer needed. In the few places (the SVDF kernels) where the old accumulate-into-output behavior was used to add the bias, it is replaced with a VectorBatchVectorAdd after the reduction.
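For reference, the behavioral change at a typical call site looks like this (a minimal sketch; the variable names and sizes are illustrative, not taken from any one kernel):

// Illustrative setup: row sums of a row-major weight matrix.
constexpr int num_units = 4;
constexpr int input_size = 16;
int8_t weights[num_units * input_size] = {};
int32_t row_sums[num_units];

// Before: ReductionSumVector accumulated into its output, so callers had to
// zero it first.
std::fill_n(row_sums, num_units, 0);
tensor_utils::ReductionSumVector(weights, row_sums, num_units, input_size);

// After: the output is overwritten, so the explicit zeroing is dropped.
tensor_utils::ReductionSumVector(weights, row_sums, num_units, input_size);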

PiperOrigin-RevId: 351474602
Change-Id: I5477464b98040d99d8e7c06244c829055bb67b5e
Robert David 2021-01-12 16:38:45 -08:00 committed by TensorFlower Gardener
parent f4a9a6b097
commit ed36c75165
10 changed files with 35 additions and 74 deletions


@@ -140,26 +140,6 @@ void RnnBatchStep(
compute_row_sums);
}
void ComputeMatrixSums(int32_t* input_row_sums, int32_t* aux_input_row_sums,
int32_t* recurrent_row_sums, int32_t* row_sums,
const float* aux_input_ptr_batch, int num_units,
int input_size, int aux_input_size,
const int8_t* input_weights_ptr,
const int8_t* aux_input_weights_ptr,
const int8_t* recurrent_weights_ptr) {
memset(input_row_sums, 0, sizeof(int32_t) * num_units);
tensor_utils::ReductionSumVector(input_weights_ptr, input_row_sums, num_units,
input_size);
if (aux_input_ptr_batch) {
memset(aux_input_row_sums, 0, sizeof(int32_t) * num_units);
tensor_utils::ReductionSumVector(aux_input_weights_ptr, aux_input_row_sums,
num_units, aux_input_size);
}
memset(recurrent_row_sums, 0, sizeof(int32_t) * num_units);
tensor_utils::ReductionSumVector(recurrent_weights_ptr, recurrent_row_sums,
num_units, num_units);
}
void RnnBatchStep(
const float* input_ptr_batch, const int8_t* input_weights_ptr,
float input_weights_scale, const float* aux_input_ptr_batch,
@@ -187,10 +167,15 @@ void RnnBatchStep(
}
recurrent_row_sums = aux_input_row_sums + num_units;
if (*compute_row_sums) {
ComputeMatrixSums(input_row_sums, aux_input_row_sums, recurrent_row_sums,
row_sums, aux_input_ptr_batch, num_units, input_size,
aux_input_size, input_weights_ptr,
aux_input_weights_ptr, recurrent_weights_ptr);
tensor_utils::ReductionSumVector(input_weights_ptr, input_row_sums,
num_units, input_size);
if (aux_input_ptr_batch) {
tensor_utils::ReductionSumVector(aux_input_weights_ptr,
aux_input_row_sums, num_units,
aux_input_size);
}
tensor_utils::ReductionSumVector(
recurrent_weights_ptr, recurrent_row_sums, num_units, num_units);
*compute_row_sums = false;
}
}
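After this change, the row sums in RnnBatchStep are computed directly and cached; the new body of the block above reads:

if (*compute_row_sums) {
  // Each ReductionSumVector call now initializes its own destination,
  // so the ComputeMatrixSums helper and its memsets are gone.
  tensor_utils::ReductionSumVector(input_weights_ptr, input_row_sums,
                                   num_units, input_size);
  if (aux_input_ptr_batch) {
    tensor_utils::ReductionSumVector(aux_input_weights_ptr, aux_input_row_sums,
                                     num_units, aux_input_size);
  }
  tensor_utils::ReductionSumVector(recurrent_weights_ptr, recurrent_row_sums,
                                   num_units, num_units);
  *compute_row_sums = false;  // Cache the row sums for subsequent batches.
}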


@@ -180,7 +180,6 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data,
for (int i = 1; i < extended_lhs_shape.DimensionsCount() - 2; ++i) {
num_weights_matrices *= extended_lhs_shape.Dims(i);
}
memset(row_sums, 0, sizeof(int32_t) * lhs_rows * num_weights_matrices);
tensor_utils::ReductionSumVector(
lhs_data, row_sums, num_weights_matrices * lhs_rows, accum_depth);
if (compute_row_sums) {


@@ -1276,7 +1276,6 @@ void NeonMatrixBatchVectorMultiplyAccumulateImpl(
int32_t* row_sums_ptr = row_sums;
if (row_sums == nullptr) {
row_sums_ptr = static_cast<int32_t*>(malloc(sizeof(int32_t) * m_rows));
memset(row_sums_ptr, 0, sizeof(int32_t) * m_rows);
NeonReductionSumVector(matrix, row_sums_ptr, m_rows, m_cols);
}
@@ -1385,7 +1384,6 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
}
if (compute_row_sums == nullptr || *compute_row_sums) {
memset(row_sums, 0, sizeof(int32_t) * m_rows);
NeonReductionSumVector(matrix, row_sums, m_rows, m_cols);
if (compute_row_sums) {
*compute_row_sums = false;
@@ -2454,7 +2452,6 @@ float NeonVectorVectorDotProduct(const float* vector1, const float* vector2,
void NeonReductionSumVector(const float* input_vector, float* output_vector,
int output_size, int reduction_size) {
const float* input_vector_ptr = input_vector;
for (int o = 0; o < output_size; o++) {
// If v_size is not divisible by the vector size, then we need to process
// the final few elements sequentially. postamble_start shows the start
@@ -2464,16 +2461,16 @@ void NeonReductionSumVector(const float* input_vector, float* output_vector,
float32x4_t sum_f32x4 = vmovq_n_f32(0.0);
int r = 0;
for (; r < postamble_start; r += kFloatValuesPerNeonVector) {
float32x4_t v1_f32x4 = vld1q_f32(input_vector_ptr + r);
float32x4_t v1_f32x4 = vld1q_f32(input_vector + r);
sum_f32x4 = vaddq_f32(sum_f32x4, v1_f32x4);
}
output_vector[o] += AccumulateNeonLane(sum_f32x4);
input_vector_ptr += postamble_start;
float sum = AccumulateNeonLane(sum_f32x4);
// Postamble loop.
for (; r < reduction_size; r++) {
output_vector[o] += *input_vector_ptr++;
sum += input_vector[r];
}
output_vector[o] = sum;
input_vector += reduction_size;
}
}
@@ -2484,24 +2481,23 @@ void NeonReductionSumVector(const int8_t* input_vector, int32_t* output_vector,
const int postamble_start =
reduction_size & ~((kWeightsPerNeonLane >> 1) - 1);
for (int o = 0; o < output_size; ++o) {
// Get the address of the first element of the row.
int8_t* row_ptr = (int8_t*)input_vector + o * reduction_size; // NOLINT
int32x4_t sum_32x4 = vmovq_n_s32(0);
int r = 0;
for (; r < postamble_half_start; r += kWeightsPerNeonLane) {
const int8x16_t s2_8x16 = vld1q_s8((const int8_t*)(row_ptr + r));
const int8x16_t s2_8x16 = vld1q_s8(input_vector + r);
sum_32x4 = vpadalq_s16(sum_32x4, vpaddlq_s8(s2_8x16));
}
if (r < postamble_start) {
const int8x8_t s2_8x8 = vld1_s8((const int8_t*)(row_ptr + r));
const int8x8_t s2_8x8 = vld1_s8(input_vector + r);
sum_32x4 = vpadalq_s16(sum_32x4, vmovl_s8(s2_8x8));
r += (kWeightsPerNeonLane >> 1);
}
int32_t sum = AccumulateNeonLane(sum_32x4);
for (; r < reduction_size; ++r) {
sum += row_ptr[r];
sum += input_vector[r];
}
output_vector[o] += sum;
output_vector[o] = sum;
input_vector += reduction_size;
}
}


@@ -1476,7 +1476,6 @@ inline void HybridConvPerChannel(
TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
TFLITE_DCHECK_EQ(scratch_shape.FlatSize(), output_shape.FlatSize());
if (!compute_row_sums || *compute_row_sums) {
memset(row_sums, 0, sizeof(int32_t) * filter_rows);
tensor_utils::ReductionSumVector(filter_data, row_sums, filter_rows,
filter_cols);
if (compute_row_sums) {


@@ -274,7 +274,6 @@ void SseMatrixBatchVectorMultiplyAccumulate(
const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
bool* compute_row_sums, CpuBackendContext* context) {
if ((input_offset != nullptr) && (!compute_row_sums || *compute_row_sums)) {
memset(row_sums, 0, sizeof(int32_t) * m_rows);
SseReductionSumVector(matrix, row_sums, m_rows, m_cols);
if (compute_row_sums) {
*compute_row_sums = false;
@@ -447,9 +446,9 @@ void SseReductionSumVector(const int8_t* input_vector, int32_t* output_vector,
#pragma clang loop unroll(disable) vectorize(disable)
#endif
for (; col < reduction_size; col++) {
row_sum += *(row_ptr + col);
row_sum += row_ptr[col];
}
*(output_vector + row) += row_sum;
output_vector[row] = row_sum;
}
}


@@ -165,7 +165,6 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data,
for (int i = 1; i < extended_lhs_shape.DimensionsCount() - 2; ++i) {
num_weights_matrices *= extended_lhs_shape.Dims(i);
}
memset(row_sums, 0, sizeof(int32_t) * lhs_rows * num_weights_matrices);
tensor_utils::ReductionSumVector(
lhs_data, row_sums, num_weights_matrices * lhs_rows, accum_depth);
if (compute_row_sums) {


@@ -172,7 +172,6 @@ void PortableMatrixBatchVectorMultiplyAccumulate(
return;
}
if (!compute_row_sums || *compute_row_sums) {
memset(row_sums, 0, sizeof(int32_t) * m_rows);
PortableReductionSumVector(matrix, row_sums, m_rows, m_cols);
if (compute_row_sums) {
*compute_row_sums = false;


@@ -204,9 +204,12 @@ template <typename IN, typename OUT>
void PortableReductionSumVector(const IN* input_vector, OUT* output_vector,
int output_size, int reduction_size) {
for (int o = 0; o < output_size; o++) {
OUT result = 0;
for (int r = 0; r < reduction_size; r++) {
output_vector[o] += *input_vector++;
result += input_vector[r];
}
output_vector[o] = result;
input_vector += reduction_size;
}
}
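A usage illustration of the overwrite semantics (hypothetical values): each output element becomes the sum of the next reduction_size inputs, regardless of what the output held before.

// Row sums of a 2x3 row-major matrix: output_size = 2, reduction_size = 3.
const int8_t input[] = {1, 2, 3, 4, 5, 6};
int32_t sums[2] = {100, 100};  // Stale contents are simply overwritten.
PortableReductionSumVector(input, sums, /*output_size=*/2, /*reduction_size=*/3);
// sums is now {6, 15}.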


@@ -49,17 +49,14 @@ static inline void ApplyTimeWeightsBiasAndActivation(
scratch_ptr_batch);
}
// Initialize output with bias if provided.
if (bias_ptr) {
tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size,
output_ptr);
} else {
std::fill_n(output_ptr, batch_size * num_units, 0.0f);
}
// Reduction sum.
tensor_utils::ReductionSumVector(scratch_ptr, output_ptr,
batch_size * num_units, rank);
// Add bias if provided.
if (bias_ptr) {
tensor_utils::VectorBatchVectorAdd(bias_ptr, num_units, batch_size,
output_ptr);
}
// Apply activation.
tensor_utils::ApplyActivationToVector(output_ptr, batch_size * num_units,
@@ -131,16 +128,14 @@ inline void EvalIntegerSVDF(
// Reduce, add bias, rescale, activation.
{
// Add bias.
if (bias_data) {
tensor_utils::VectorBatchVectorAssign(bias_data, n_unit, n_batch,
output_temp_data);
} else {
std::fill_n(output_temp_data, n_batch * n_unit, 0);
}
// Reduce.
tensor_utils::ReductionSumVector(scratch_data, output_temp_data,
n_batch * n_unit, n_rank);
// Add bias.
if (bias_data) {
tensor_utils::VectorBatchVectorAdd(bias_data, n_unit, n_batch,
output_temp_data);
}
// Rescale.
const int32_t output_max = std::numeric_limits<int8_t>::max();
const int32_t output_min = std::numeric_limits<int8_t>::min();
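The bias handling above relies on VectorBatchVectorAdd. Assuming the usual tensor_utils semantics of adding one vector to every row of a batch, a portable-style sketch (not the library implementation) is:

// Assumed semantics: add `vector` (length v_size) to each of the n_batch
// rows of batch_vector.
void VectorBatchVectorAddSketch(const float* vector, int v_size, int n_batch,
                                float* batch_vector) {
  for (int b = 0; b < n_batch; ++b) {
    for (int i = 0; i < v_size; ++i) {
      batch_vector[b * v_size + i] += vector[i];
    }
  }
}

Reducing into the output first and then adding the bias therefore yields the same result as the old assign-bias-then-accumulate order, without the fill_n fallback when no bias is given.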


@@ -63,61 +63,48 @@ void ComputeRowSums(
const float* aux_input_ptr) {
// Compute the row sums for dequantization
if (!use_cifg) {
std::fill_n(input_to_input_row_sums, n_cell, 0);
tensor_utils::ReductionSumVector(input_to_input_weights_ptr,
input_to_input_row_sums, n_cell, n_input);
}
std::fill_n(input_to_forget_row_sums, n_cell, 0);
tensor_utils::ReductionSumVector(input_to_forget_weights_ptr,
input_to_forget_row_sums, n_cell, n_input);
std::fill_n(input_to_cell_row_sums, n_cell, 0);
tensor_utils::ReductionSumVector(input_to_cell_weights_ptr,
input_to_cell_row_sums, n_cell, n_input);
std::fill_n(input_to_output_row_sums, n_cell, 0);
tensor_utils::ReductionSumVector(input_to_output_weights_ptr,
input_to_output_row_sums, n_cell, n_input);
if (aux_input_ptr) {
if (!use_cifg) {
std::fill_n(aux_input_to_input_row_sums, n_cell, 0);
tensor_utils::ReductionSumVector(aux_input_to_input_weights_ptr,
aux_input_to_input_row_sums, n_cell,
n_aux_input);
}
std::fill_n(aux_input_to_forget_row_sums, n_cell, 0);
tensor_utils::ReductionSumVector(aux_input_to_forget_weights_ptr,
aux_input_to_forget_row_sums, n_cell,
n_aux_input);
std::fill_n(aux_input_to_cell_row_sums, n_cell, 0);
tensor_utils::ReductionSumVector(aux_input_to_cell_weights_ptr,
aux_input_to_cell_row_sums, n_cell,
n_aux_input);
std::fill_n(aux_input_to_output_row_sums, n_cell, 0);
tensor_utils::ReductionSumVector(aux_input_to_output_weights_ptr,
aux_input_to_output_row_sums, n_cell,
n_aux_input);
}
if (!use_cifg) {
std::fill_n(recurrent_to_input_row_sums, n_cell, 0);
tensor_utils::ReductionSumVector(recurrent_to_input_weights_ptr,
recurrent_to_input_row_sums, n_cell,
n_output);
}
std::fill_n(recurrent_to_forget_row_sums, n_cell, 0);
tensor_utils::ReductionSumVector(recurrent_to_forget_weights_ptr,
recurrent_to_forget_row_sums, n_cell,
n_output);
std::fill_n(recurrent_to_cell_row_sums, n_cell, 0);
tensor_utils::ReductionSumVector(recurrent_to_cell_weights_ptr,
recurrent_to_cell_row_sums, n_cell,
n_output);
std::fill_n(recurrent_to_output_row_sums, n_cell, 0);
tensor_utils::ReductionSumVector(recurrent_to_output_weights_ptr,
recurrent_to_output_row_sums, n_cell,
n_output);
if (projection_weights_ptr != nullptr) {
std::fill_n(projection_weights_row_sums, n_output, 0);
tensor_utils::ReductionSumVector(
projection_weights_ptr, projection_weights_row_sums, n_output, n_cell);
}