Make PortableReductionSumVector a template, reducing code duplication.
PiperOrigin-RevId: 351389401 Change-Id: Id1032673543d6b118e34ec573c5db1ffb1a5ce23
This commit is contained in:
parent
696cb254a1
commit
61a7ab7203
@ -726,38 +726,6 @@ void PortableVectorScalarMultiply(const int8_t* vector, const int v_size,
|
||||
}
|
||||
}
|
||||
|
||||
// Reduce-sum on a float input vector.
//   input_vector:   pointer to output_size * reduction_size consecutive floats.
//   output_vector:  pointer to output_size floats. Each sum is ADDED to the
//                   value already stored there; nothing is zeroed here.
//   output_size:    number of output elements.
//   reduction_size: number of consecutive input elements added into each
//                   output element.
void PortableReductionSumVector(const float* input_vector, float* output_vector,
                                int output_size, int reduction_size) {
  // A single read cursor walks the input once, front to back.
  const float* input_vector_ptr = input_vector;
  for (int o = 0; o < output_size; o++) {
    for (int r = 0; r < reduction_size; r++) {
      // Accumulate directly into the output slot, in input order.
      output_vector[o] += *input_vector_ptr++;
    }
  }
}
|
||||
|
||||
// Reduce-sum on an int32_t input vector.
//   input_vector:   pointer to output_size * reduction_size consecutive
//                   int32_t values.
//   output_vector:  pointer to output_size int32_t values. Each sum is ADDED
//                   to the value already stored there; nothing is zeroed here.
//   output_size:    number of output elements.
//   reduction_size: number of consecutive input elements added into each
//                   output element.
void PortableReductionSumVector(const int32_t* input_vector,
                                int32_t* output_vector, int output_size,
                                int reduction_size) {
  // A single read cursor walks the input once, front to back.
  const int32_t* input_vector_ptr = input_vector;
  for (int o = 0; o < output_size; o++) {
    for (int r = 0; r < reduction_size; r++) {
      // Accumulate directly into the output slot, in input order.
      output_vector[o] += *input_vector_ptr++;
    }
  }
}
|
||||
|
||||
// Reduce-sum on an int8_t input vector, accumulating into int32_t outputs.
//   input_vector:   pointer to output_size * reduction_size consecutive
//                   int8_t values.
//   output_vector:  pointer to output_size int32_t values. Each sum is ADDED
//                   to the value already stored there; nothing is zeroed here.
//   output_size:    number of output elements.
//   reduction_size: number of consecutive input elements added into each
//                   output element.
void PortableReductionSumVector(const int8_t* input_vector,
                                int32_t* output_vector, int output_size,
                                int reduction_size) {
  // A single read cursor walks the input once, front to back.
  const int8_t* input_vector_ptr = input_vector;
  for (int o = 0; o < output_size; o++) {
    for (int r = 0; r < reduction_size; r++) {
      // Each int8_t element is widened before being added to the int32_t slot.
      output_vector[o] += *input_vector_ptr++;
    }
  }
}
|
||||
|
||||
void PortableMeanStddevNormalization(const float* input_vector,
|
||||
float* output_vector, int v_size,
|
||||
int n_batch) {
|
||||
|
@ -198,22 +198,21 @@ void PortableSub1Vector(const int16_t* vector, int v_size, int16_t* result);
|
||||
void PortableVectorScalarMultiply(const int8_t* vector, int v_size, float scale,
|
||||
float* result);
|
||||
|
||||
// Reduce-sum on a float input vector:
|
||||
// input_vector: float pointer to input vector.
|
||||
// output_vector: float pointer to vector.
|
||||
// Reduce-sum on a vector:
|
||||
// input_vector: pointer to input vector.
|
||||
// output_vector: pointer to output vector (sums are added to its contents).
|
||||
// output_size: output vector size.
|
||||
// reduction_size: number of consecutive elements from input vector which are
|
||||
// added to get one element of output.
|
||||
void PortableReductionSumVector(const float* input_vector, float* output_vector,
|
||||
int output_size, int reduction_size);
|
||||
|
||||
void PortableReductionSumVector(const int32_t* input_vector,
|
||||
int32_t* output_vector, int output_size,
|
||||
int reduction_size);
|
||||
|
||||
void PortableReductionSumVector(const int8_t* input_vector,
|
||||
int32_t* output_vector, int output_size,
|
||||
int reduction_size);
|
||||
// Generic reduce-sum: adds each run of reduction_size consecutive input
// elements into one output element.
//   IN / OUT:       element types of the input and output vectors; each input
//                   element is added to an OUT slot with `+=`.
//   input_vector:   pointer to output_size * reduction_size consecutive inputs.
//   output_vector:  pointer to output_size outputs. Each sum is ADDED to the
//                   value already stored there; nothing is zeroed here.
//   output_size:    number of output elements.
//   reduction_size: number of consecutive input elements per output element.
template <typename IN, typename OUT>
void PortableReductionSumVector(const IN* input_vector, OUT* output_vector,
                                int output_size, int reduction_size) {
  // Walk the input once, front to back, with a local read cursor.
  const IN* input_cursor = input_vector;
  for (int o = 0; o < output_size; o++) {
    for (int r = 0; r < reduction_size; r++) {
      // Accumulate directly into the output slot, in input order.
      output_vector[o] += *input_cursor++;
    }
  }
}
|
||||
|
||||
// Layer norm for each batch.
|
||||
void PortableMeanStddevNormalization(const float* input_vector,
|
||||
|
Loading…
x
Reference in New Issue
Block a user