Make PortableReductionSumVector a template, reducing code duplication.
PiperOrigin-RevId: 351389401
Change-Id: Id1032673543d6b118e34ec573c5db1ffb1a5ce23
This commit is contained in:
parent
696cb254a1
commit
61a7ab7203
@ -726,38 +726,6 @@ void PortableVectorScalarMultiply(const int8_t* vector, const int v_size,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void PortableReductionSumVector(const float* input_vector, float* output_vector,
|
|
||||||
int output_size, int reduction_size) {
|
|
||||||
const float* input_vector_ptr = input_vector;
|
|
||||||
for (int o = 0; o < output_size; o++) {
|
|
||||||
for (int r = 0; r < reduction_size; r++) {
|
|
||||||
output_vector[o] += *input_vector_ptr++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void PortableReductionSumVector(const int32_t* input_vector,
|
|
||||||
int32_t* output_vector, int output_size,
|
|
||||||
int reduction_size) {
|
|
||||||
const int32_t* input_vector_ptr = input_vector;
|
|
||||||
for (int o = 0; o < output_size; o++) {
|
|
||||||
for (int r = 0; r < reduction_size; r++) {
|
|
||||||
output_vector[o] += *input_vector_ptr++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void PortableReductionSumVector(const int8_t* input_vector,
|
|
||||||
int32_t* output_vector, int output_size,
|
|
||||||
int reduction_size) {
|
|
||||||
const int8_t* input_vector_ptr = input_vector;
|
|
||||||
for (int o = 0; o < output_size; o++) {
|
|
||||||
for (int r = 0; r < reduction_size; r++) {
|
|
||||||
output_vector[o] += *input_vector_ptr++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void PortableMeanStddevNormalization(const float* input_vector,
|
void PortableMeanStddevNormalization(const float* input_vector,
|
||||||
float* output_vector, int v_size,
|
float* output_vector, int v_size,
|
||||||
int n_batch) {
|
int n_batch) {
|
||||||
|
|||||||
@ -198,22 +198,21 @@ void PortableSub1Vector(const int16_t* vector, int v_size, int16_t* result);
|
|||||||
void PortableVectorScalarMultiply(const int8_t* vector, int v_size, float scale,
|
void PortableVectorScalarMultiply(const int8_t* vector, int v_size, float scale,
|
||||||
float* result);
|
float* result);
|
||||||
|
|
||||||
// Reduce-sum on a float input vector:
|
// Reduce-sum on a vector:
|
||||||
// input_vector: float pointer to input vector.
|
// input_vector: pointer to input vector.
|
||||||
// output_vector: float pointer to vector.
|
// output_vector: pointer to output vector; each sum is accumulated (+=) into its existing contents.
|
||||||
// output_size: output vector size.
|
// output_size: output vector size.
|
||||||
// reduction_size: number of consecutive elements from input vector which are
|
// reduction_size: number of consecutive elements from input vector which are
|
||||||
// added to get one element of output.
|
// added to get one element of output.
|
||||||
void PortableReductionSumVector(const float* input_vector, float* output_vector,
|
template <typename IN, typename OUT>
|
||||||
int output_size, int reduction_size);
|
void PortableReductionSumVector(const IN* input_vector, OUT* output_vector,
|
||||||
|
int output_size, int reduction_size) {
|
||||||
void PortableReductionSumVector(const int32_t* input_vector,
|
for (int o = 0; o < output_size; o++) {
|
||||||
int32_t* output_vector, int output_size,
|
for (int r = 0; r < reduction_size; r++) {
|
||||||
int reduction_size);
|
output_vector[o] += *input_vector++;
|
||||||
|
}
|
||||||
void PortableReductionSumVector(const int8_t* input_vector,
|
}
|
||||||
int32_t* output_vector, int output_size,
|
}
|
||||||
int reduction_size);
|
|
||||||
|
|
||||||
// Layer norm for each batch.
|
// Layer norm for each batch.
|
||||||
void PortableMeanStddevNormalization(const float* input_vector,
|
void PortableMeanStddevNormalization(const float* input_vector,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user