Consolidate redundant BatchVectorBatchVectorDotProduct implementations across portable, Neon, and SSE versions into one function.
PiperOrigin-RevId: 266989976
This commit is contained in:
parent
13c3302982
commit
99f8d44812
@ -1813,21 +1813,6 @@ float NeonVectorVectorDotProduct(const float* vector1, const float* vector2,
|
||||
return result;
|
||||
}
|
||||
|
||||
void NeonBatchVectorBatchVectorDotProduct(const float* vector1,
|
||||
const float* vector2, int v_size,
|
||||
int n_batch, float* result,
|
||||
int result_stride) {
|
||||
float* result_ptr = result;
|
||||
const float* vector1_ptr = vector1;
|
||||
const float* vector2_ptr = vector2;
|
||||
for (int b = 0; b < n_batch; b++) {
|
||||
*result_ptr = NeonVectorVectorDotProduct(vector1_ptr, vector2_ptr, v_size);
|
||||
vector1_ptr += v_size;
|
||||
vector2_ptr += v_size;
|
||||
result_ptr += result_stride;
|
||||
}
|
||||
}
|
||||
|
||||
void NeonReductionSumVector(const float* input_vector, float* output_vector,
|
||||
int output_size, int reduction_size) {
|
||||
const float* input_vector_ptr = input_vector;
|
||||
|
@ -175,14 +175,6 @@ float VectorVectorDotProduct(const float* vector1, const float* vector2,
|
||||
return NEON_OR_PORTABLE(VectorVectorDotProduct, vector1, vector2, v_size);
|
||||
}
|
||||
|
||||
// Public entry point for the batched dot product. All arguments are passed
// unchanged to the kernel chosen by the NEON_OR_PORTABLE macro (defined
// elsewhere; presumably it resolves to the Neon* or Portable* variant of
// BatchVectorBatchVectorDotProduct -- confirm against the macro definition).
void BatchVectorBatchVectorDotProduct(const float* vector1,
                                      const float* vector2, int v_size,
                                      int n_batch, float* result,
                                      int result_stride) {
  NEON_OR_PORTABLE(BatchVectorBatchVectorDotProduct,
                   vector1,
                   vector2,
                   v_size,
                   n_batch,
                   result,
                   result_stride);
}
|
||||
|
||||
void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
|
||||
float* batch_vector) {
|
||||
PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector);
|
||||
|
@ -115,12 +115,6 @@ void NeonVectorVectorCwiseProductAccumulate(const float* vector1,
|
||||
float NeonVectorVectorDotProduct(const float* vector1, const float* vector2,
|
||||
int v_size);
|
||||
|
||||
// Batched dot product (Neon variant).
//   vector1, vector2: input buffers, consumed v_size elements per batch.
//   n_batch:          number of dot products to compute.
//   result:           output buffer; one value is written per batch.
//   result_stride:    distance, in elements, between consecutive outputs.
void NeonBatchVectorBatchVectorDotProduct(
    const float* vector1, const float* vector2, int v_size, int n_batch,
    float* result, int result_stride);
|
||||
|
||||
// Cwise product of a vector and a batch-vector.
|
||||
void NeonVectorBatchVectorCwiseProduct(const float* vector, int v_size,
|
||||
const float* batch_vector, int n_batch,
|
||||
|
@ -182,14 +182,6 @@ float VectorVectorDotProduct(const float* vector1, const float* vector2,
|
||||
return NEON_OR_PORTABLE(VectorVectorDotProduct, vector1, vector2, v_size);
|
||||
}
|
||||
|
||||
// Public entry point for the batched dot product. All arguments are passed
// unchanged to the kernel chosen by the NEON_OR_PORTABLE macro (defined
// elsewhere; presumably it resolves to the Neon* or Portable* variant of
// BatchVectorBatchVectorDotProduct -- confirm against the macro definition).
void BatchVectorBatchVectorDotProduct(const float* vector1,
                                      const float* vector2, int v_size,
                                      int n_batch, float* result,
                                      int result_stride) {
  NEON_OR_PORTABLE(BatchVectorBatchVectorDotProduct,
                   vector1,
                   vector2,
                   v_size,
                   n_batch,
                   result,
                   result_stride);
}
|
||||
|
||||
void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
|
||||
float* batch_vector) {
|
||||
PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector);
|
||||
|
@ -439,22 +439,6 @@ float PortableVectorVectorDotProduct(const float* vector1, const float* vector2,
|
||||
return result;
|
||||
}
|
||||
|
||||
void PortableBatchVectorBatchVectorDotProduct(const float* vector1,
|
||||
const float* vector2, int v_size,
|
||||
int n_batch, float* result,
|
||||
int result_stride) {
|
||||
float* result_ptr = result;
|
||||
const float* vector1_ptr = vector1;
|
||||
const float* vector2_ptr = vector2;
|
||||
for (int b = 0; b < n_batch; b++) {
|
||||
*result_ptr =
|
||||
PortableVectorVectorDotProduct(vector1_ptr, vector2_ptr, v_size);
|
||||
vector1_ptr += v_size;
|
||||
vector2_ptr += v_size;
|
||||
result_ptr += result_stride;
|
||||
}
|
||||
}
|
||||
|
||||
void PortableVectorVectorCwiseProductAccumulate(const float* vector1,
|
||||
const float* vector2,
|
||||
int v_size, float* result) {
|
||||
|
@ -190,14 +190,6 @@ float VectorVectorDotProduct(const float* vector1, const float* vector2,
|
||||
return PortableVectorVectorDotProduct(vector1, vector2, v_size);
|
||||
}
|
||||
|
||||
// Public entry point for the batched dot product in this build: it simply
// forwards every argument to PortableBatchVectorBatchVectorDotProduct.
void BatchVectorBatchVectorDotProduct(const float* vector1,
                                      const float* vector2, int v_size,
                                      int n_batch, float* result,
                                      int result_stride) {
  PortableBatchVectorBatchVectorDotProduct(
      vector1, vector2, v_size, n_batch, result, result_stride);
}
|
||||
|
||||
void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
|
||||
float* batch_vector) {
|
||||
PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector);
|
||||
|
@ -76,12 +76,6 @@ void PortableVectorVectorCwiseProductAccumulate(const float* vector1,
|
||||
float PortableVectorVectorDotProduct(const float* vector1, const float* vector2,
|
||||
int v_size);
|
||||
|
||||
// Batched dot product (portable variant).
//   vector1, vector2: input buffers, consumed v_size elements per batch.
//   n_batch:          number of dot products to compute.
//   result:           output buffer; one value is written per batch.
//   result_stride:    distance, in elements, between consecutive outputs.
void PortableBatchVectorBatchVectorDotProduct(
    const float* vector1, const float* vector2, int v_size, int n_batch,
    float* result, int result_stride);
|
||||
|
||||
// Cwise product of a vector and a batch-vector.
|
||||
void PortableVectorBatchVectorCwiseProduct(const float* vector, int v_size,
|
||||
const float* batch_vector,
|
||||
|
@ -183,10 +183,17 @@ float VectorVectorDotProduct(const float* vector1, const float* vector2,
|
||||
// x_2_1 * y_2_1 + x_2_2 * y_2_2 + ... + x_2_vsize * y_2_vsize,
|
||||
// ...
|
||||
// x_nbatch_1 * y_nbatch_1 + ... + x_nbatch_vsize * y_nbatch_vsize]
|
||||
// Declaration only. v_size is the per-batch vector length, n_batch the number
// of batches, and result_stride the spacing (in elements) between outputs.
void BatchVectorBatchVectorDotProduct(
    const float* vector1, const float* vector2, int v_size, int n_batch,
    float* result, int result_stride);
|
||||
// Generic batched dot product. For each of the n_batch batches, calls
// VectorVectorDotProduct on the current v_size-element slices of vector1 and
// vector2 and stores the returned value through the output cursor, which then
// advances by result_stride elements. Both input cursors advance by v_size
// per batch.
template <typename T>
inline void BatchVectorBatchVectorDotProduct(const T* vector1, const T* vector2,
                                             int v_size, int n_batch, T* result,
                                             int result_stride) {
  const T* lhs = vector1;
  const T* rhs = vector2;
  T* out = result;
  for (int batch = 0; batch < n_batch; ++batch) {
    *out = VectorVectorDotProduct(lhs, rhs, v_size);
    lhs += v_size;
    rhs += v_size;
    out += result_stride;
  }
}
|
||||
|
||||
// Cwise product of a vector and a batch-vector.
|
||||
void VectorBatchVectorCwiseProduct(const float* vector, int v_size,
|
||||
|
Loading…
Reference in New Issue
Block a user