Consolidate redundant BatchVectorBatchVectorDotProduct implementations across portable, Neon, and SSE versions into one function.

PiperOrigin-RevId: 266989976
This commit is contained in:
A. Unique TensorFlower 2019-09-03 12:38:53 -07:00 committed by TensorFlower Gardener
parent 13c3302982
commit 99f8d44812
8 changed files with 11 additions and 71 deletions

View File

@ -1813,21 +1813,6 @@ float NeonVectorVectorDotProduct(const float* vector1, const float* vector2,
return result;
}
void NeonBatchVectorBatchVectorDotProduct(const float* vector1,
const float* vector2, int v_size,
int n_batch, float* result,
int result_stride) {
float* result_ptr = result;
const float* vector1_ptr = vector1;
const float* vector2_ptr = vector2;
for (int b = 0; b < n_batch; b++) {
*result_ptr = NeonVectorVectorDotProduct(vector1_ptr, vector2_ptr, v_size);
vector1_ptr += v_size;
vector2_ptr += v_size;
result_ptr += result_stride;
}
}
void NeonReductionSumVector(const float* input_vector, float* output_vector,
int output_size, int reduction_size) {
const float* input_vector_ptr = input_vector;

View File

@ -175,14 +175,6 @@ float VectorVectorDotProduct(const float* vector1, const float* vector2,
return NEON_OR_PORTABLE(VectorVectorDotProduct, vector1, vector2, v_size);
}
void BatchVectorBatchVectorDotProduct(const float* vector1,
const float* vector2, int v_size,
int n_batch, float* result,
int result_stride) {
NEON_OR_PORTABLE(BatchVectorBatchVectorDotProduct, vector1, vector2, v_size,
n_batch, result, result_stride);
}
void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
float* batch_vector) {
PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector);

View File

@ -115,12 +115,6 @@ void NeonVectorVectorCwiseProductAccumulate(const float* vector1,
float NeonVectorVectorDotProduct(const float* vector1, const float* vector2,
int v_size);
// Dot product of two batch vectors.
void NeonBatchVectorBatchVectorDotProduct(const float* vector1,
const float* vector2, int v_size,
int n_batch, float* result,
int result_stride);
// Cwise product of a vector and a batch-vector.
void NeonVectorBatchVectorCwiseProduct(const float* vector, int v_size,
const float* batch_vector, int n_batch,

View File

@ -182,14 +182,6 @@ float VectorVectorDotProduct(const float* vector1, const float* vector2,
return NEON_OR_PORTABLE(VectorVectorDotProduct, vector1, vector2, v_size);
}
void BatchVectorBatchVectorDotProduct(const float* vector1,
const float* vector2, int v_size,
int n_batch, float* result,
int result_stride) {
NEON_OR_PORTABLE(BatchVectorBatchVectorDotProduct, vector1, vector2, v_size,
n_batch, result, result_stride);
}
void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
float* batch_vector) {
PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector);

View File

@ -439,22 +439,6 @@ float PortableVectorVectorDotProduct(const float* vector1, const float* vector2,
return result;
}
void PortableBatchVectorBatchVectorDotProduct(const float* vector1,
const float* vector2, int v_size,
int n_batch, float* result,
int result_stride) {
float* result_ptr = result;
const float* vector1_ptr = vector1;
const float* vector2_ptr = vector2;
for (int b = 0; b < n_batch; b++) {
*result_ptr =
PortableVectorVectorDotProduct(vector1_ptr, vector2_ptr, v_size);
vector1_ptr += v_size;
vector2_ptr += v_size;
result_ptr += result_stride;
}
}
void PortableVectorVectorCwiseProductAccumulate(const float* vector1,
const float* vector2,
int v_size, float* result) {

View File

@ -190,14 +190,6 @@ float VectorVectorDotProduct(const float* vector1, const float* vector2,
return PortableVectorVectorDotProduct(vector1, vector2, v_size);
}
void BatchVectorBatchVectorDotProduct(const float* vector1,
const float* vector2, int v_size,
int n_batch, float* result,
int result_stride) {
PortableBatchVectorBatchVectorDotProduct(vector1, vector2, v_size, n_batch,
result, result_stride);
}
void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
float* batch_vector) {
PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector);

View File

@ -76,12 +76,6 @@ void PortableVectorVectorCwiseProductAccumulate(const float* vector1,
float PortableVectorVectorDotProduct(const float* vector1, const float* vector2,
int v_size);
// Dot product of two batch vectors.
void PortableBatchVectorBatchVectorDotProduct(const float* vector1,
const float* vector2, int v_size,
int n_batch, float* result,
int result_stride);
// Cwise product of a vector and a batch-vector.
void PortableVectorBatchVectorCwiseProduct(const float* vector, int v_size,
const float* batch_vector,

View File

@ -183,10 +183,17 @@ float VectorVectorDotProduct(const float* vector1, const float* vector2,
// x_2_1 * y_2_1 + x_2_2 * y_2_2 + ... + x_2_vsize * y_2_vsize,
// ...
// x_nbatch_1 * y_nbatch_1 + ... + x_nbatch_vsize * y_nbatch_vsize]
void BatchVectorBatchVectorDotProduct(const float* vector1,
const float* vector2, int v_size,
int n_batch, float* result,
int result_stride);
template <typename T>
inline void BatchVectorBatchVectorDotProduct(const T* vector1, const T* vector2,
int v_size, int n_batch, T* result,
int result_stride) {
for (int b = 0; b < n_batch; b++) {
*result = VectorVectorDotProduct(vector1, vector2, v_size);
vector1 += v_size;
vector2 += v_size;
result += result_stride;
}
}
// Cwise product of a vector and a batch-vector.
void VectorBatchVectorCwiseProduct(const float* vector, int v_size,