Consolidate redundant BatchVectorBatchVectorDotProduct implementations across portable, Neon, and SSE versions into one function.
PiperOrigin-RevId: 266989976
This commit is contained in:
parent
13c3302982
commit
99f8d44812
@ -1813,21 +1813,6 @@ float NeonVectorVectorDotProduct(const float* vector1, const float* vector2,
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
void NeonBatchVectorBatchVectorDotProduct(const float* vector1,
|
|
||||||
const float* vector2, int v_size,
|
|
||||||
int n_batch, float* result,
|
|
||||||
int result_stride) {
|
|
||||||
float* result_ptr = result;
|
|
||||||
const float* vector1_ptr = vector1;
|
|
||||||
const float* vector2_ptr = vector2;
|
|
||||||
for (int b = 0; b < n_batch; b++) {
|
|
||||||
*result_ptr = NeonVectorVectorDotProduct(vector1_ptr, vector2_ptr, v_size);
|
|
||||||
vector1_ptr += v_size;
|
|
||||||
vector2_ptr += v_size;
|
|
||||||
result_ptr += result_stride;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void NeonReductionSumVector(const float* input_vector, float* output_vector,
|
void NeonReductionSumVector(const float* input_vector, float* output_vector,
|
||||||
int output_size, int reduction_size) {
|
int output_size, int reduction_size) {
|
||||||
const float* input_vector_ptr = input_vector;
|
const float* input_vector_ptr = input_vector;
|
||||||
|
@ -175,14 +175,6 @@ float VectorVectorDotProduct(const float* vector1, const float* vector2,
|
|||||||
return NEON_OR_PORTABLE(VectorVectorDotProduct, vector1, vector2, v_size);
|
return NEON_OR_PORTABLE(VectorVectorDotProduct, vector1, vector2, v_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
void BatchVectorBatchVectorDotProduct(const float* vector1,
|
|
||||||
const float* vector2, int v_size,
|
|
||||||
int n_batch, float* result,
|
|
||||||
int result_stride) {
|
|
||||||
NEON_OR_PORTABLE(BatchVectorBatchVectorDotProduct, vector1, vector2, v_size,
|
|
||||||
n_batch, result, result_stride);
|
|
||||||
}
|
|
||||||
|
|
||||||
void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
|
void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
|
||||||
float* batch_vector) {
|
float* batch_vector) {
|
||||||
PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector);
|
PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector);
|
||||||
|
@ -115,12 +115,6 @@ void NeonVectorVectorCwiseProductAccumulate(const float* vector1,
|
|||||||
float NeonVectorVectorDotProduct(const float* vector1, const float* vector2,
|
float NeonVectorVectorDotProduct(const float* vector1, const float* vector2,
|
||||||
int v_size);
|
int v_size);
|
||||||
|
|
||||||
// Dot product of two batch vectors.
|
|
||||||
void NeonBatchVectorBatchVectorDotProduct(const float* vector1,
|
|
||||||
const float* vector2, int v_size,
|
|
||||||
int n_batch, float* result,
|
|
||||||
int result_stride);
|
|
||||||
|
|
||||||
// Cwise product of a vector and a batch-vector.
|
// Cwise product of a vector and a batch-vector.
|
||||||
void NeonVectorBatchVectorCwiseProduct(const float* vector, int v_size,
|
void NeonVectorBatchVectorCwiseProduct(const float* vector, int v_size,
|
||||||
const float* batch_vector, int n_batch,
|
const float* batch_vector, int n_batch,
|
||||||
|
@ -182,14 +182,6 @@ float VectorVectorDotProduct(const float* vector1, const float* vector2,
|
|||||||
return NEON_OR_PORTABLE(VectorVectorDotProduct, vector1, vector2, v_size);
|
return NEON_OR_PORTABLE(VectorVectorDotProduct, vector1, vector2, v_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
void BatchVectorBatchVectorDotProduct(const float* vector1,
|
|
||||||
const float* vector2, int v_size,
|
|
||||||
int n_batch, float* result,
|
|
||||||
int result_stride) {
|
|
||||||
NEON_OR_PORTABLE(BatchVectorBatchVectorDotProduct, vector1, vector2, v_size,
|
|
||||||
n_batch, result, result_stride);
|
|
||||||
}
|
|
||||||
|
|
||||||
void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
|
void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
|
||||||
float* batch_vector) {
|
float* batch_vector) {
|
||||||
PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector);
|
PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector);
|
||||||
|
@ -439,22 +439,6 @@ float PortableVectorVectorDotProduct(const float* vector1, const float* vector2,
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
void PortableBatchVectorBatchVectorDotProduct(const float* vector1,
|
|
||||||
const float* vector2, int v_size,
|
|
||||||
int n_batch, float* result,
|
|
||||||
int result_stride) {
|
|
||||||
float* result_ptr = result;
|
|
||||||
const float* vector1_ptr = vector1;
|
|
||||||
const float* vector2_ptr = vector2;
|
|
||||||
for (int b = 0; b < n_batch; b++) {
|
|
||||||
*result_ptr =
|
|
||||||
PortableVectorVectorDotProduct(vector1_ptr, vector2_ptr, v_size);
|
|
||||||
vector1_ptr += v_size;
|
|
||||||
vector2_ptr += v_size;
|
|
||||||
result_ptr += result_stride;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void PortableVectorVectorCwiseProductAccumulate(const float* vector1,
|
void PortableVectorVectorCwiseProductAccumulate(const float* vector1,
|
||||||
const float* vector2,
|
const float* vector2,
|
||||||
int v_size, float* result) {
|
int v_size, float* result) {
|
||||||
|
@ -190,14 +190,6 @@ float VectorVectorDotProduct(const float* vector1, const float* vector2,
|
|||||||
return PortableVectorVectorDotProduct(vector1, vector2, v_size);
|
return PortableVectorVectorDotProduct(vector1, vector2, v_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
void BatchVectorBatchVectorDotProduct(const float* vector1,
|
|
||||||
const float* vector2, int v_size,
|
|
||||||
int n_batch, float* result,
|
|
||||||
int result_stride) {
|
|
||||||
PortableBatchVectorBatchVectorDotProduct(vector1, vector2, v_size, n_batch,
|
|
||||||
result, result_stride);
|
|
||||||
}
|
|
||||||
|
|
||||||
void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
|
void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
|
||||||
float* batch_vector) {
|
float* batch_vector) {
|
||||||
PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector);
|
PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector);
|
||||||
|
@ -76,12 +76,6 @@ void PortableVectorVectorCwiseProductAccumulate(const float* vector1,
|
|||||||
float PortableVectorVectorDotProduct(const float* vector1, const float* vector2,
|
float PortableVectorVectorDotProduct(const float* vector1, const float* vector2,
|
||||||
int v_size);
|
int v_size);
|
||||||
|
|
||||||
// Dot product of two batch vectors.
|
|
||||||
void PortableBatchVectorBatchVectorDotProduct(const float* vector1,
|
|
||||||
const float* vector2, int v_size,
|
|
||||||
int n_batch, float* result,
|
|
||||||
int result_stride);
|
|
||||||
|
|
||||||
// Cwise product of a vector and a batch-vector.
|
// Cwise product of a vector and a batch-vector.
|
||||||
void PortableVectorBatchVectorCwiseProduct(const float* vector, int v_size,
|
void PortableVectorBatchVectorCwiseProduct(const float* vector, int v_size,
|
||||||
const float* batch_vector,
|
const float* batch_vector,
|
||||||
|
@ -183,10 +183,17 @@ float VectorVectorDotProduct(const float* vector1, const float* vector2,
|
|||||||
// x_2_1 * y_2_1 + x_2_2 * y_2_2 + ... + x_2_vsize * y_2_vsize,
|
// x_2_1 * y_2_1 + x_2_2 * y_2_2 + ... + x_2_vsize * y_2_vsize,
|
||||||
// ...
|
// ...
|
||||||
// x_nbatch_1 * y_nbatch_1 + ... + x_nbatch_vsize * y_nbatch_vsize]
|
// x_nbatch_1 * y_nbatch_1 + ... + x_nbatch_vsize * y_nbatch_vsize]
|
||||||
void BatchVectorBatchVectorDotProduct(const float* vector1,
|
template <typename T>
|
||||||
const float* vector2, int v_size,
|
inline void BatchVectorBatchVectorDotProduct(const T* vector1, const T* vector2,
|
||||||
int n_batch, float* result,
|
int v_size, int n_batch, T* result,
|
||||||
int result_stride);
|
int result_stride) {
|
||||||
|
for (int b = 0; b < n_batch; b++) {
|
||||||
|
*result = VectorVectorDotProduct(vector1, vector2, v_size);
|
||||||
|
vector1 += v_size;
|
||||||
|
vector2 += v_size;
|
||||||
|
result += result_stride;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Cwise product of a vector and a batch-vector.
|
// Cwise product of a vector and a batch-vector.
|
||||||
void VectorBatchVectorCwiseProduct(const float* vector, int v_size,
|
void VectorBatchVectorCwiseProduct(const float* vector, int v_size,
|
||||||
|
Loading…
Reference in New Issue
Block a user