Add optimized MatrixBatchVectorMultiplyAccumulate for asymmetric inputs for SSE
PiperOrigin-RevId: 312035618 Change-Id: I5ae85ae9b0b646d2fe1e665c25aae6b99622dd2b
commit 76853076b3
parent 344f898250
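The SSE path now handles asymmetrically quantized int8 inputs (non-zero zero point) and per-channel scales directly, and the separate per-channel/input-offset overloads in the NEON, SSE, and portable back ends are removed in favor of the single row-sums entry point. The kernels rely on the identity sum_j w[j] * (v[j] - z) = sum_j w[j] * v[j] - z * sum_j w[j]: per-row weight sums are computed once (SseReductionSumVector, guarded by compute_row_sums) and the zero-point correction is applied after the raw int8 dot product. A minimal scalar sketch of that identity, with illustrative names only:

#include <cstdint>
#include <vector>

// Scalar reference for the row-sum trick (illustrative, not the TFLite API):
// sum_j w[j] * (v[j] - z) == sum_j w[j] * v[j] - z * sum_j w[j].
int32_t DotWithZeroPoint(const std::vector<int8_t>& w,
                         const std::vector<int8_t>& v, int32_t z) {
  int32_t dot = 0;      // raw int8 dot product (what the SIMD loop computes)
  int32_t row_sum = 0;  // weight sum, precomputed once per row in practice
  for (size_t j = 0; j < w.size(); ++j) {
    dot += w[j] * v[j];
    row_sum += w[j];
  }
  return dot - z * row_sum;
}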
@@ -1466,16 +1466,20 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
   int i = 0;
   int32_t* scratch_ptr = scratch;
   for (; i <= total_size - 8; i += 8, result += 8) {
-    float batch_scaling_factor0 = scaling_factors[i / m_rows];
-    float batch_scaling_factor1 = scaling_factors[(i + 4) / m_rows];
-    if (per_channel_scale) {
-      batch_scaling_factor0 *= per_channel_scale[i % m_rows];
-      batch_scaling_factor1 *= per_channel_scale[(i + 4) % m_rows];
-    }
+    const float batch_scaling_factor0 = scaling_factors[i / m_rows];
+    const float batch_scaling_factor1 = scaling_factors[(i + 4) / m_rows];
     const int batch_input_offset0 = -input_offset[i / m_rows];
     const int batch_input_offset1 = -input_offset[(i + 4) / m_rows];
-    const float32x4_t scaling_factor0 = vdupq_n_f32(batch_scaling_factor0);
-    const float32x4_t scaling_factor1 = vdupq_n_f32(batch_scaling_factor1);
+    float32x4_t scaling_factor0 = vdupq_n_f32(batch_scaling_factor0);
+    float32x4_t scaling_factor1 = vdupq_n_f32(batch_scaling_factor1);
+    if (per_channel_scale) {
+      const float32x4_t per_channel_scale0 =
+          vld1q_f32(&per_channel_scale[i % m_rows]);
+      const float32x4_t per_channel_scale1 =
+          vld1q_f32(&per_channel_scale[(i + 4) % m_rows]);
+      scaling_factor0 = vmulq_f32(scaling_factor0, per_channel_scale0);
+      scaling_factor1 = vmulq_f32(scaling_factor1, per_channel_scale1);
+    }
     const int32x4_t input_offset0 = vdupq_n_s32(batch_input_offset0);
     const int32x4_t input_offset1 = vdupq_n_s32(batch_input_offset1);
     const int32x4_t row_sum0 = vld1q_s32(row_sums + (i % m_rows));
@@ -1498,7 +1502,10 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
 
   scratch_ptr += i;
   for (; i < total_size; i++) {
-    const float batch_scaling_factor = scaling_factors[i / m_rows];
+    float batch_scaling_factor = scaling_factors[i / m_rows];
+    if (per_channel_scale) {
+      batch_scaling_factor *= per_channel_scale[i % m_rows];
+    }
     const int32_t zero_point = input_offset[i / m_rows];
     int32_t dotprod = *(scratch_ptr++);
     dotprod -= row_sums[i % m_rows] * zero_point;
@@ -1514,16 +1521,6 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
       per_channel_scale, input_offset, row_sums);
 }
 
-void NeonMatrixBatchVectorMultiplyAccumulate(
-    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
-    const int8_t* __restrict__ vectors, const float* scaling_factors,
-    int n_batch, float* __restrict__ result, const float* per_channel_scale,
-    const int32_t* input_offset) {
-  NeonMatrixBatchVectorMultiplyAccumulateImpl(
-      matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
-      per_channel_scale, input_offset, nullptr);
-}
-
 inline int64x2x2_t MulAdd(int32x4_t acc, int32x4_t lhs, int32x4_t rhs) {
   int64x2x2_t result;
   const int64x2_t lhs_low = vmovl_s32(vget_low_s32(lhs));
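In the NEON hunks above, the 8-wide main loop now applies per-channel scales with a 4-lane vld1q_f32/vmulq_f32 instead of folding a single scalar per_channel_scale[i % m_rows] into the batch factor, and the scalar tail loop applies the per-row scale element by element after subtracting row_sum * zero_point. A scalar sketch of what one 4-lane group computes, assuming dotprod and row_sum hold the values the SIMD code has already accumulated for rows row0..row0+3 (names are illustrative):

#include <cstdint>

// Illustrative scalar equivalent of one 4-lane group in the NEON main loop.
void ScaleFourRows(const int32_t dotprod[4], const int32_t row_sum[4],
                   float batch_scale, const float* per_channel_scale,
                   int32_t zero_point, int row0, float result[4]) {
  for (int k = 0; k < 4; ++k) {
    const float scale = per_channel_scale
                            ? batch_scale * per_channel_scale[row0 + k]
                            : batch_scale;
    result[k] += (dotprod[k] - zero_point * row_sum[k]) * scale;
  }
}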
@@ -55,16 +55,6 @@ void MatrixBatchVectorMultiplyAccumulate(const int8_t* __restrict__ matrix,
                    vectors, scaling_factors, n_batch, scratch, result, context);
 }
 
-void MatrixBatchVectorMultiplyAccumulate(
-    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
-    const int8_t* __restrict__ vectors, const float* scaling_factors,
-    int n_batch, float* __restrict__ result, const float* per_channel_scale,
-    const int32_t* input_offset) {
-  NEON_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols,
-                   vectors, scaling_factors, n_batch, result, per_channel_scale,
-                   input_offset);
-}
-
 void MatrixBatchVectorMultiplyAccumulate(
     const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
     const int8_t* __restrict__ vectors, const float* scaling_factors,
@@ -62,12 +62,6 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
     const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
     bool* compute_row_sums, CpuBackendContext* context);
 
-void NeonMatrixBatchVectorMultiplyAccumulate(
-    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
-    const int8_t* __restrict__ vectors, const float* scaling_factors,
-    int n_batch, float* __restrict__ result, const float* per_channel_scale,
-    const int32_t* input_offset);
-
 void NeonApplyLayerNorm(const int16_t* input, const int16_t* layer_norm_weights,
                         const int32_t* bias, int32_t layer_norm_scale_a,
                         int32_t layer_norm_scale_b, int32_t variance_limit,
@@ -24,6 +24,7 @@ limitations under the License.
 
 #include <cstdint>
 
+#include "tensorflow/lite/kernels/cpu_backend_context.h"
 #include "tensorflow/lite/kernels/internal/compatibility.h"
 
 namespace tflite {
@@ -89,18 +90,24 @@ float GetFloatVectorElement(__m128 v) {
 
 }  // namespace
 
-void SseMatrixBatchVectorMultiplyAccumulate(
+void SseMatrixBatchVectorMultiplyAccumulateImpl(
     const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
     const int8_t* __restrict__ vectors,
     const float* __restrict__ scaling_factors, int n_batch,
-    float* __restrict__ result) {
+    float* __restrict__ result, const float* per_channel_scale,
+    const int32_t* input_offset, const int32_t* row_sums) {
   for (std::intptr_t batch = 0; batch < n_batch; ++batch) {
     const float batch_scaling_factor = scaling_factors[batch];
+    const int32_t batch_offset = input_offset ? input_offset[batch] : 0;
     // Compute dot-product for every column.
     for (std::intptr_t row = 0; row < m_rows; ++row) {
       // Get the address of the first element of the row.
       const int8_t* __restrict__ row_ptr = matrix + row * m_cols;
 
+      const float row_scale =
+          per_channel_scale ? per_channel_scale[row] * batch_scaling_factor
+                            : batch_scaling_factor;
+      const int32_t row_offset =
+          row_sums && batch_offset ? batch_offset * row_sums[row] : 0;
       // Initialize the dot product sum for the row to 0.
       __m128i dotprod_32x4 = _mm_setzero_si128();
       std::intptr_t col = 0;
@@ -152,8 +159,10 @@ void SseMatrixBatchVectorMultiplyAccumulate(
       for (; col < m_cols; ++col) {
         sum += row_ptr[col] * vectors[col];
       }  // for col
-
-      *result += sum * batch_scaling_factor;
+      if (row_offset) {
+        sum -= row_offset;
+      }
+      *result += sum * row_scale;
       ++result;
     }  // for row
 
@@ -165,56 +174,30 @@ void SseMatrixBatchVectorMultiplyAccumulate(
     const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
     const int8_t* __restrict__ vectors,
     const float* __restrict__ scaling_factors, int n_batch,
-    float* __restrict__ result, const float* __restrict__ per_channel_scale,
-    const int32_t* __restrict__ input_offset) {
-  if (input_offset == nullptr) {
-    SseMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
-                                           scaling_factors, n_batch, result);
-    return;
-  }
-  static constexpr std::intptr_t kBlockSize = 16;
-  for (std::intptr_t batch = 0; batch < n_batch; ++batch) {
-    const float batch_scaling_factor = scaling_factors[batch];
-    for (std::intptr_t row = 0; row < m_rows; ++row) {
-      const int8_t* __restrict__ row_ptr = matrix + row * m_cols;
-      float scale = batch_scaling_factor;
-      if (per_channel_scale != nullptr) {
-        scale *= per_channel_scale[row];
-      }
-      __m128i dotprod_32x4 = _mm_setzero_si128();
-      __m128i row_sum_16x8 = _mm_setzero_si128();
-      std::intptr_t col = 0;
-      for (; col < (m_cols & ~(kBlockSize - 1)); col += kBlockSize) {
-        const __m128i vec_8x16 =
-            _mm_loadu_si128(reinterpret_cast<const __m128i*>(vectors + col));
-        const __m128i row_8x16 =
-            _mm_loadu_si128(reinterpret_cast<const __m128i*>(row_ptr + col));
-        // dotprod += vec · row
-        dotprod_32x4 =
-            _mm_add_epi32(dotprod_32x4, DotProdInt8x4x4(vec_8x16, row_8x16));
+    float* __restrict__ result) {
+  SseMatrixBatchVectorMultiplyAccumulateImpl(
+      matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
+      /*per_channel_scale=*/nullptr, /*input_offset=*/nullptr,
+      /*row_sums=*/nullptr);
+}
+
-        // Pairwise add 16x 8-bit values; equivalently, multipy-add with 1.
-        // Result is 8x 16-bit values.
-        const __m128i row_16x8 = _mm_maddubs_epi16(_mm_set1_epi8(1), row_8x16);
-        row_sum_16x8 = _mm_add_epi16(row_sum_16x8, row_16x8);
-      }  // for col
-      // Pairwise add 8x 16-bit values; equivalently, multipy-add with 1.
-      // Result is 4x 32-bit values.
-      const __m128i row_sum_32x4 =
-          _mm_madd_epi16(row_sum_16x8, _mm_set1_epi16(1));
-      int32_t sum = ReduceInt32x4(dotprod_32x4);
-      int32_t row_sum = ReduceInt32x4(row_sum_32x4);
-      // Postamble loop.
-      for (; col < m_cols; ++col) {
-        sum += row_ptr[col] * vectors[col];
-        row_sum += row_ptr[col];
-      }  // for col
-      sum -= row_sum * input_offset[batch];
-      *result += sum * scale;
-      ++result;
-    }  // for row
-    vectors += m_cols;
-  }  // for batch
+void SseMatrixBatchVectorMultiplyAccumulate(
+    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
+    const int8_t* __restrict__ vectors,
+    const float* __restrict__ scaling_factors, int n_batch,
+    float* __restrict__ result, const float* per_channel_scale,
+    const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
+    bool* compute_row_sums, CpuBackendContext* context) {
+  if ((input_offset != nullptr) && (!compute_row_sums || *compute_row_sums)) {
+    memset(row_sums, 0, sizeof(int32_t) * m_rows);
+    SseReductionSumVector(matrix, row_sums, m_rows, m_cols);
+    if (compute_row_sums) {
+      *compute_row_sums = false;
+    }
+  }
+  SseMatrixBatchVectorMultiplyAccumulateImpl(
+      matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
+      per_channel_scale, input_offset, row_sums);
 }
 
 namespace {
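The new SseMatrixBatchVectorMultiplyAccumulate overload above precomputes the per-row weight sums once, caches them in row_sums, and clears *compute_row_sums so later calls with the same weights skip the reduction. A hypothetical caller-side sketch of that contract (the wrapper function and include path are assumptions; buffer sizes mirror the updated test at the bottom of this diff):

#include <cstdint>
#include <vector>

#include "tensorflow/lite/kernels/cpu_backend_context.h"
// Header path assumed for the declarations shown in this diff.
#include "tensorflow/lite/kernels/internal/optimized/sse_tensor_utils_impl.h"

// Illustrative wrapper, not part of the commit.
void RunAsymmetricMatmul(const int8_t* matrix, int m_rows, int m_cols,
                         const int8_t* vectors, const float* scaling_factors,
                         int n_batch, const float* per_channel_scale,
                         const int32_t* input_offset, float* result,
                         tflite::CpuBackendContext* context) {
  std::vector<int32_t> scratch(m_rows * n_batch);  // int32 accumulator scratch
  std::vector<int32_t> row_sums(m_rows);           // per-row weight sum cache
  bool compute_row_sums = true;  // first call fills row_sums, then clears this
  tflite::tensor_utils::SseMatrixBatchVectorMultiplyAccumulate(
      matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
      per_channel_scale, input_offset, scratch.data(), row_sums.data(),
      &compute_row_sums, context);
}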
@@ -347,6 +330,44 @@ void SseSparseMatrixBatchVectorMultiplyAccumulate(
   }  // for batch
 }
 
+void SseReductionSumVector(const int8_t* input_vector, int32_t* output_vector,
+                           const int output_size, const int reduction_size) {
+  static constexpr std::intptr_t kBlockSize = 16;
+  for (std::intptr_t row = 0; row < output_size; ++row) {
+    const int8_t* __restrict__ row_ptr = input_vector + row * reduction_size;
+    __m128i row_sum_16x8 = _mm_setzero_si128();
+    std::intptr_t col = 0;
+    for (; col < (reduction_size & ~(kBlockSize - 1)); col += kBlockSize) {
+      const __m128i row_8x16 =
+          _mm_loadu_si128(reinterpret_cast<const __m128i*>(row_ptr + col));
+      const __m128i row_16x8 = _mm_maddubs_epi16(_mm_set1_epi8(1), row_8x16);
+      row_sum_16x8 = _mm_add_epi16(row_sum_16x8, row_16x8);
+    }  // for col
+#ifdef __SSE4_1__
+    // Postamble for 8x 8-bit inputs.
+    if (col < (reduction_size & ~7)) {
+      // _mm_loadu_si64 not supported in gcc versions < 9, breaks kokoro build.
+      const __m128i row_16x8 = _mm_cvtepi8_epi16(
+          _mm_loadl_epi64(reinterpret_cast<const __m128i*>(row_ptr + col)));
+      // dotprod += vec · row
+      row_sum_16x8 = _mm_add_epi16(row_sum_16x8, row_16x8);
+      col += 8;
+    }
+#endif
+    const __m128i row_sum_32x4 =
+        _mm_madd_epi16(row_sum_16x8, _mm_set1_epi16(1));
+    int32_t row_sum = ReduceInt32x4(row_sum_32x4);
+#if defined(__SSE4_1__) && defined(__clang__)
+    // SSE 4.1: Don't try to unroll and vectorize this, already done above.
+#pragma clang loop unroll(disable) vectorize(disable)
+#endif
+    for (; col < reduction_size; col++) {
+      row_sum += *(row_ptr + col);
+    }
+    *(output_vector + row) += row_sum;
+  }
+}
+
 }  // namespace tensor_utils
 }  // namespace tflite
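SseReductionSumVector computes the per-row sums of an int8 matrix: _mm_maddubs_epi16 against a vector of ones pairwise-adds 16 int8 lanes into 8 int16 lanes, and _mm_madd_epi16 against ones folds those into 4 int32 lanes before the final reduction. Note that it accumulates into output_vector rather than overwriting it, which is why the caller memsets row_sums first. A scalar reference of the same computation, for comparison only:

#include <cstdint>

// Scalar reference for SseReductionSumVector (illustrative): add the sum of
// each row's int8 elements into output_vector[row]. Callers are expected to
// zero the output first, as SseMatrixBatchVectorMultiplyAccumulate does.
void ReferenceReductionSumVector(const int8_t* input_vector,
                                 int32_t* output_vector, int output_size,
                                 int reduction_size) {
  for (int row = 0; row < output_size; ++row) {
    int32_t row_sum = 0;
    for (int col = 0; col < reduction_size; ++col) {
      row_sum += input_vector[row * reduction_size + col];
    }
    output_vector[row] += row_sum;
  }
}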
@@ -59,10 +59,9 @@ void MatrixBatchVectorMultiplyAccumulate(
     int n_batch, float* __restrict__ result, const float* per_channel_scale,
     const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
     bool* compute_row_sums, CpuBackendContext* context) {
-  PortableMatrixBatchVectorMultiplyAccumulate(
-      matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result,
-      per_channel_scale, input_offset, scratch, row_sums, compute_row_sums,
-      context);
+  SSE_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols,
+                  vectors, scaling_factors, n_batch, result, per_channel_scale,
+                  input_offset, scratch, row_sums, compute_row_sums, context);
 }
 
 void MatrixBatchVectorMultiplyAccumulate(
@@ -75,17 +74,6 @@ void MatrixBatchVectorMultiplyAccumulate(
                   vectors, scaling_factors, n_batch, result);
 }
 
-void MatrixBatchVectorMultiplyAccumulate(
-    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
-    const int8_t* __restrict__ vectors,
-    const float* __restrict__ scaling_factors, int n_batch,
-    float* __restrict__ result, const float* __restrict__ per_channel_scale,
-    const int32_t* __restrict__ input_offset) {
-  SSE_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols,
-                  vectors, scaling_factors, n_batch, result, per_channel_scale,
-                  input_offset);
-}
-
 void SparseMatrixBatchVectorMultiplyAccumulate1x4(
     const float* __restrict__ matrix, const int32_t* __restrict__ segments,
     const int32_t* __restrict__ indices, int m_rows, int m_cols,
@@ -315,8 +303,8 @@ void ReductionSumVector(const int32_t* input_vector, int32_t* output_vector,
 
 void ReductionSumVector(const int8_t* input_vector, int32_t* output_vector,
                         int output_size, int reduction_size) {
-  NEON_OR_PORTABLE(ReductionSumVector, input_vector, output_vector, output_size,
-                   reduction_size);
+  SSE_OR_PORTABLE(ReductionSumVector, input_vector, output_vector, output_size,
+                  reduction_size);
 }
 
 void MeanStddevNormalization(const float* input_vector, float* output_vector,
@@ -17,6 +17,8 @@ limitations under the License.
 
 #include <cstdint>
 
+#include "tensorflow/lite/kernels/cpu_backend_context.h"
+
 #if defined(_MSC_VER)
 #define __restrict__ __restrict
 #endif
@@ -38,8 +40,9 @@ void SseMatrixBatchVectorMultiplyAccumulate(
     const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
     const int8_t* __restrict__ vectors,
     const float* __restrict__ scaling_factors, int n_batch,
-    float* __restrict__ result, const float* __restrict__ per_channel_scale,
-    const int32_t* __restrict__ input_offset);
+    float* __restrict__ result, const float* per_channel_scale,
+    const int32_t* input_offset, int32_t* scratch, int32_t* row_sums,
+    bool* compute_row_sums, CpuBackendContext* context);
 
 // Matrix multiplication for quantized values using symmetric quantization.
 // Sparse version.
@@ -49,6 +52,9 @@ void SseSparseMatrixBatchVectorMultiplyAccumulate(
     const float* __restrict__ scaling_factors, int n_batch,
     float* __restrict__ result);
 
+void SseReductionSumVector(const int8_t* input_vector, int32_t* output_vector,
+                           const int output_size, const int reduction_size);
+
 #endif  // __SSSE3__
 
 }  // namespace tensor_utils
@@ -161,35 +161,6 @@ void PortableMatrixBatchVectorMultiplyAccumulate(
   }  // for batch
 }
 
-void PortableMatrixBatchVectorMultiplyAccumulate(
-    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
-    const int8_t* __restrict__ vectors, const float* scaling_factors,
-    int n_batch, float* __restrict__ result, const float* per_channel_scale,
-    const int32_t* input_offset) {
-  for (int batch = 0; batch < n_batch; ++batch, vectors += m_cols) {
-    const float batch_scaling_factor = scaling_factors[batch];
-    const float batch_offset = input_offset[batch];
-    const int8_t* row_ptr = matrix;
-    for (int row = 0; row < m_rows; ++row) {
-      int32_t dotprod = 0;
-      float scale = batch_scaling_factor;
-      if (per_channel_scale) {
-        scale *= per_channel_scale[row];
-      }
-#if defined(__GNUC__)
-      // Prefetch the row to cache.
-      __builtin_prefetch(row_ptr, 0 /* prefetch for read */,
-                         3 /* temporal locality */);
-#endif
-      for (int col = 0; col < m_cols; ++col, ++row_ptr) {
-        dotprod += (*row_ptr) * (vectors[col] - batch_offset);
-      }  // for col
-      *result += dotprod * scale;
-      ++result;
-    }  // for row
-  }  // for batch
-}
-
 void PortableMatrixBatchVectorMultiplyAccumulate(
     const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
     const int8_t* __restrict__ vectors, const float* scaling_factors,
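The removed portable overload above subtracted the zero point inside the inner loop (dotprod += (*row_ptr) * (vectors[col] - batch_offset)); the surviving row-sums path computes the same value while keeping the inner loop a plain int8 dot product. A small self-contained check of that equivalence, with made-up values:

#include <cassert>
#include <cstdint>

// Quick equivalence check (illustrative): the per-element form
//   dotprod += w * (v - offset)
// and the row-sums form
//   dotprod - offset * row_sum
// produce the same result for any int8 row.
int main() {
  const int8_t w[4] = {1, -2, 3, -4};
  const int8_t v[4] = {5, 6, -7, 8};
  const int32_t offset = 3;
  int32_t per_element = 0, raw = 0, row_sum = 0;
  for (int j = 0; j < 4; ++j) {
    per_element += w[j] * (v[j] - offset);
    raw += w[j] * v[j];
    row_sum += w[j];
  }
  assert(per_element == raw - offset * row_sum);
  return 0;
}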
@@ -98,16 +98,6 @@ void MatrixBatchVectorMultiplyAccumulate(const int8_t* __restrict__ matrix,
                                               scaling_factors, n_batch, result);
 }
 
-void MatrixBatchVectorMultiplyAccumulate(
-    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
-    const int8_t* __restrict__ vectors, const float* scaling_factors,
-    int n_batch, float* __restrict__ result, const float* per_channel_scale,
-    const int32_t* input_offset) {
-  PortableMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors,
-                                              scaling_factors, n_batch, result,
-                                              per_channel_scale, input_offset);
-}
-
 void SparseMatrixBatchVectorMultiplyAccumulate1x4(
     const float* __restrict__ matrix, const int32_t* __restrict__ segments,
     const int32_t* __restrict__ indices, int m_rows, int m_cols,
@@ -83,12 +83,6 @@ void PortableMatrixBatchVectorMultiplyAccumulate(
     int n_batch, int32_t* scratch, float* __restrict__ result,
     CpuBackendContext* context);
 
-void PortableMatrixBatchVectorMultiplyAccumulate(
-    const int8_t* __restrict__ matrix, const int m_rows, const int m_cols,
-    const int8_t* __restrict__ vectors, const float* scaling_factors,
-    int n_batch, float* __restrict__ result, const float* per_channel_scale,
-    const int32_t* input_offset);
-
 void PortableSparseMatrixBatchVectorMultiplyAccumulate1x4(
     const float* __restrict__ matrix, const int32_t* __restrict__ segments,
     const int32_t* __restrict__ indices, int m_rows, int m_cols,
@@ -1136,11 +1136,15 @@ std::vector<float> TestPerChannelDotprodMatrixBatchVectorMultiply(
     bool is_per_channel = true) {
   MatrixVectorData data =
       SetupMatrixVectorData(rows, cols, batch, negative, is_per_channel);
 
+  std::vector<int32_t> scratch(rows * batch);
+  std::vector<int32_t> row_sums(rows);
+  bool compute_row_sums = true;
+  CpuBackendContext context;
   MatrixBatchVectorMultiplyAccumulate(
       data.matrix.data(), rows, cols, data.vectors.data(),
       data.scale_factors.data(), batch, &data.results[0],
-      data.per_channel_scales.data(), data.input_offsets.data());
+      data.per_channel_scales.data(), data.input_offsets.data(), scratch.data(),
+      row_sums.data(), &compute_row_sums, &context);
   return data.results;
 }