Add __restrict__ keywords to (Portable)MeanStddevNormalization input and output parameters.
PiperOrigin-RevId: 352891183
Change-Id: I6b812ccdf76caecda13cc3484934619a3d63add8
This commit is contained in:
parent
b65ecec1bc
commit
a60a884eab
@ -290,8 +290,9 @@ void ReductionSumVector(const int8_t* input_vector, int32_t* output_vector,
|
|||||||
reduction_size);
|
reduction_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
void MeanStddevNormalization(const float* input_vector, float* output_vector,
|
void MeanStddevNormalization(const float* __restrict__ input_vector,
|
||||||
int v_size, int n_batch) {
|
float* __restrict__ output_vector, int v_size,
|
||||||
|
int n_batch) {
|
||||||
PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch);
|
PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -300,8 +300,9 @@ void ReductionSumVector(const int8_t* input_vector, int32_t* output_vector,
|
|||||||
reduction_size);
|
reduction_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
void MeanStddevNormalization(const float* input_vector, float* output_vector,
|
void MeanStddevNormalization(const float* __restrict__ input_vector,
|
||||||
int v_size, int n_batch) {
|
float* __restrict__ output_vector, int v_size,
|
||||||
|
int n_batch) {
|
||||||
PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch);
|
PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -715,9 +715,9 @@ void PortableVectorScalarMultiply(const int8_t* vector, const int v_size,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void PortableMeanStddevNormalization(const float* input_vector,
|
void PortableMeanStddevNormalization(const float* __restrict__ input_vector,
|
||||||
float* output_vector, int v_size,
|
float* __restrict__ output_vector,
|
||||||
int n_batch) {
|
int v_size, int n_batch) {
|
||||||
for (int batch = 0; batch < n_batch; ++batch) {
|
for (int batch = 0; batch < n_batch; ++batch) {
|
||||||
float sum = 0.0f;
|
float sum = 0.0f;
|
||||||
for (int i = 0; i < v_size; ++i) {
|
for (int i = 0; i < v_size; ++i) {
|
||||||
|
@ -294,8 +294,9 @@ void ReductionSumVector(const int8_t* input_vector, int32_t* output_vector,
|
|||||||
reduction_size);
|
reduction_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
void MeanStddevNormalization(const float* input_vector, float* output_vector,
|
void MeanStddevNormalization(const float* __restrict__ input_vector,
|
||||||
int v_size, int n_batch) {
|
float* __restrict__ output_vector, int v_size,
|
||||||
|
int n_batch) {
|
||||||
PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch);
|
PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -214,9 +214,9 @@ void PortableReductionSumVector(const IN* input_vector, OUT* output_vector,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Layer norm for each batch.
|
// Layer norm for each batch.
|
||||||
void PortableMeanStddevNormalization(const float* input_vector,
|
void PortableMeanStddevNormalization(const float* __restrict__ input_vector,
|
||||||
float* output_vector, int v_size,
|
float* __restrict__ output_vector,
|
||||||
int n_batch);
|
int v_size, int n_batch);
|
||||||
|
|
||||||
// Saturate Add.
|
// Saturate Add.
|
||||||
void PortableTwoGateSaturatingAdd(const int8_t* input, int8_t input_zp,
|
void PortableTwoGateSaturatingAdd(const int8_t* input, int8_t input_zp,
|
||||||
|
@ -630,8 +630,9 @@ void ReductionSumVector(const int8_t* input_vector, int32_t* output_vector,
|
|||||||
int output_size, int reduction_size);
|
int output_size, int reduction_size);
|
||||||
|
|
||||||
// Layer norm for each batch.
|
// Layer norm for each batch.
|
||||||
void MeanStddevNormalization(const float* input_vector, float* output_vector,
|
void MeanStddevNormalization(const float* __restrict__ input_vector,
|
||||||
int v_size, int n_batch);
|
float* __restrict__ output_vector, int v_size,
|
||||||
|
int n_batch);
|
||||||
|
|
||||||
// Saturate Add with rescale on both inputs.
|
// Saturate Add with rescale on both inputs.
|
||||||
void TwoGateSaturatingAdd(const int8_t* input, int8_t input_zp,
|
void TwoGateSaturatingAdd(const int8_t* input, int8_t input_zp,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user