Sparse fully connected kernel with 1x4 block config.
PiperOrigin-RevId: 310222893 Change-Id: Ia0e7d4172f06ee62043d641f3d24409139ece712
This commit is contained in:
parent
84135ce118
commit
3e78c70559
@ -50,6 +50,10 @@ bool SupportedSparsityFormat(const TfLiteSparsity& sparsity) {
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
static const int kDimMetadataSizeRandomSparse = 2;
|
||||
static const int kDimMetadataSizeBlockSparse = 3;
|
||||
|
||||
} // namespace
|
||||
|
||||
// This file has four implementations of FullyConnected
|
||||
@ -652,15 +656,33 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
|
||||
|
||||
const auto& sparsity = *filter->sparsity;
|
||||
if (!SupportedSparsityFormat(sparsity)) {
|
||||
context->ReportError(context,
|
||||
"Unsupported sparse fully-connected weight format.");
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"Unsupported sparse fully-connected weight format.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
if (sparsity.dim_metadata_size == kDimMetadataSizeRandomSparse) {
|
||||
// Random sparse.
|
||||
optimized_ops::FullyConnectedSparseWeight(
|
||||
sparsity, op_params, GetTensorShape(input),
|
||||
GetTensorData<float>(input), GetTensorShape(filter),
|
||||
GetTensorData<float>(filter), GetTensorShape(bias),
|
||||
GetTensorData<float>(bias), GetTensorShape(output),
|
||||
GetTensorData<float>(output));
|
||||
} else if (sparsity.dim_metadata_size == kDimMetadataSizeBlockSparse &&
|
||||
sparsity.dim_metadata[2].dense_size == 4) {
|
||||
// Block sparse with block size of 1x4.
|
||||
optimized_ops::FullyConnectedSparseWeight1x4(
|
||||
sparsity, op_params, GetTensorShape(input),
|
||||
GetTensorData<float>(input), GetTensorShape(filter),
|
||||
GetTensorData<float>(filter), GetTensorShape(bias),
|
||||
GetTensorData<float>(bias), GetTensorShape(output),
|
||||
GetTensorData<float>(output));
|
||||
} else {
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"Unsupported sparse fully-connected weight format.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
optimized_ops::FullyConnectedSparseWeight(
|
||||
sparsity, op_params, GetTensorShape(input), GetTensorData<float>(input),
|
||||
GetTensorShape(filter), GetTensorData<float>(filter),
|
||||
GetTensorShape(bias), GetTensorData<float>(bias),
|
||||
GetTensorShape(output), GetTensorData<float>(output));
|
||||
} else if (kernel_type == kLegacyPie) {
|
||||
return EvalPie(context, node, params, data, input, filter, bias, output);
|
||||
} else {
|
||||
|
@ -1243,6 +1243,35 @@ TEST_P(SparseFullyConnectedOpTest, SimpleTest2) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAre(11, 9));
|
||||
}
|
||||
|
||||
TEST_P(SparseFullyConnectedOpTest, Simple1x4Test) {
|
||||
std::initializer_list<float> weight_data = {
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, // u = 0
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, // u = 1
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, // u = 2
|
||||
};
|
||||
TensorData weight = {};
|
||||
weight.type = TensorType_FLOAT32;
|
||||
weight.shape = {3, 12};
|
||||
weight.traversal_order = {0, 1, 2};
|
||||
weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
|
||||
weight.block_map = {1};
|
||||
weight.block_size = {4};
|
||||
SparseFullyConnectedOpModel<float> m(GetRegistration(),
|
||||
/*units=*/3, /*batches=*/2,
|
||||
/*input=*/{TensorType_FLOAT32, {2, 12}},
|
||||
weight, weight_data);
|
||||
m.SetBias({1, 2, 3});
|
||||
|
||||
m.SetInput({
|
||||
1, 2, 3, 4, 5, 6, 7, 8, -9, -10, 11, 12, // b = 0
|
||||
1, 2, 3, 4, 5, 6, 7, -8, 9, -10, -11, 12, // b = 1
|
||||
});
|
||||
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 3));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAre(289, 290, 291, 81, 82, 83));
|
||||
}
|
||||
// TODO(b/148391360): Add tests for unsupported sparsity format.
|
||||
// TEST_P(SparseFullyConnectedOpTest, TestUnsupportedSparsityFormat)
|
||||
|
||||
|
@ -1954,6 +1954,36 @@ void NeonCwiseClipping(int8_t* input, const int8_t clipping_value,
|
||||
}
|
||||
}
|
||||
|
||||
void NeonSparseMatrixBatchVectorMultiplyAccumulate1x4(
|
||||
const float* __restrict__ matrix, const int32_t* __restrict__ segments,
|
||||
const int32_t* __restrict__ indices, int m_rows, int m_cols,
|
||||
const float* __restrict__ vector, int n_batch, float* __restrict__ result) {
|
||||
const int kBlockSize = 4;
|
||||
TFLITE_DCHECK_EQ(m_cols % kBlockSize, 0);
|
||||
|
||||
for (int batch = 0; batch < n_batch; batch++) {
|
||||
const float* matrix_ptr = matrix;
|
||||
for (int row = 0; row < m_rows; row++) {
|
||||
float32x4_t acc_32x4 = vmovq_n_f32(0.0);
|
||||
const float* vector_in_batch = vector + batch * m_cols;
|
||||
|
||||
for (int i = segments[row]; i < segments[row + 1]; i++) {
|
||||
const int block_start_index = indices[i] * kBlockSize;
|
||||
const float* vector_block_in_batch_ptr =
|
||||
vector_in_batch + block_start_index;
|
||||
|
||||
// Load 4 float values from the vector and matrix row.
|
||||
float32x4_t vector_f32x4 = vld1q_f32(vector_block_in_batch_ptr);
|
||||
float32x4_t matrix_f32x4 = vld1q_f32(matrix_ptr);
|
||||
// Multiply the vector and matrix row and add to accumulator.
|
||||
acc_32x4 = vmlaq_f32(acc_32x4, matrix_f32x4, vector_f32x4);
|
||||
matrix_ptr += kBlockSize;
|
||||
}
|
||||
result[batch * m_rows + row] += AccumulateNeonLane(acc_32x4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void NeonSparseMatrixBatchVectorMultiplyAccumulate(
|
||||
const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
|
||||
int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
|
||||
|
@ -76,6 +76,14 @@ void MatrixBatchVectorMultiplyAccumulate(
|
||||
input_offset, scratch, row_sums, compute_row_sums, context);
|
||||
}
|
||||
|
||||
void SparseMatrixBatchVectorMultiplyAccumulate1x4(
|
||||
const float* __restrict__ matrix, const int32_t* __restrict__ segments,
|
||||
const int32_t* __restrict__ indices, int m_rows, int m_cols,
|
||||
const float* __restrict__ vector, int n_batch, float* __restrict__ result) {
|
||||
NEON_OR_PORTABLE(SparseMatrixBatchVectorMultiplyAccumulate1x4, matrix,
|
||||
segments, indices, m_rows, m_cols, vector, n_batch, result);
|
||||
}
|
||||
|
||||
void SparseMatrixBatchVectorMultiplyAccumulate(
|
||||
const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
|
||||
int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
|
||||
|
@ -111,6 +111,11 @@ void NeonMatrixScalarMultiplyAccumulate(const int8_t* matrix, int32_t scalar,
|
||||
int32_t n_row, int32_t n_col,
|
||||
int32_t* output);
|
||||
|
||||
void NeonSparseMatrixBatchVectorMultiplyAccumulate1x4(
|
||||
const float* __restrict__ matrix, const int32_t* __restrict__ segments,
|
||||
const int32_t* __restrict__ indices, int m_rows, int m_cols,
|
||||
const float* __restrict__ vector, int n_batch, float* __restrict__ result);
|
||||
|
||||
// Multiply a matrix by a batch vector, and store results in a batch-size
|
||||
// vector. Sparse version.
|
||||
void NeonSparseMatrixBatchVectorMultiplyAccumulate(
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/internal/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/cppmath.h"
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -68,6 +69,42 @@ inline void FullyConnectedSparseWeight(
|
||||
}
|
||||
}
|
||||
|
||||
inline void FullyConnectedSparseWeight1x4(
|
||||
const TfLiteSparsity& sparsity, const FullyConnectedParams& params,
|
||||
const RuntimeShape& input_shape, const float* input_data,
|
||||
const RuntimeShape& weights_shape, const float* weights_data,
|
||||
const RuntimeShape& bias_shape, const float* bias_data,
|
||||
const RuntimeShape& output_shape, float* output_data) {
|
||||
const float output_activation_min = params.float_activation_min;
|
||||
const float output_activation_max = params.float_activation_max;
|
||||
|
||||
const int output_elements = output_shape.FlatSize();
|
||||
const int output_dims_count = output_shape.DimensionsCount();
|
||||
const int weights_dims_count = weights_shape.DimensionsCount();
|
||||
const int batches = FlatSizeSkipDim(output_shape, output_dims_count - 1);
|
||||
const int output_depth = MatchingDim(weights_shape, weights_dims_count - 2,
|
||||
output_shape, output_dims_count - 1);
|
||||
const int* w1_segments = sparsity.dim_metadata[1].array_segments->data;
|
||||
const int* w1_indices = sparsity.dim_metadata[1].array_indices->data;
|
||||
|
||||
for (int i = 0; i < output_elements; ++i) {
|
||||
output_data[i] = 0.f;
|
||||
}
|
||||
|
||||
tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate1x4(
|
||||
weights_data, w1_segments, w1_indices, weights_shape.Dims(0),
|
||||
weights_shape.Dims(1), input_data, batches, output_data);
|
||||
|
||||
for (int b = 0; b < batches; ++b) {
|
||||
for (int i = 0; i < output_depth; ++i) {
|
||||
float total = output_data[b * output_depth + i];
|
||||
float bias_value = bias_data[i];
|
||||
output_data[b * output_depth + i] = ActivationFunctionWithMinMax(
|
||||
total + bias_value, output_activation_min, output_activation_max);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace optimized_ops
|
||||
} // namespace tflite
|
||||
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SPARSE_OPS_FULLY_CONNECTED_H_
|
||||
|
@ -86,6 +86,14 @@ void MatrixBatchVectorMultiplyAccumulate(
|
||||
input_offset);
|
||||
}
|
||||
|
||||
void SparseMatrixBatchVectorMultiplyAccumulate1x4(
|
||||
const float* __restrict__ matrix, const int32_t* __restrict__ segments,
|
||||
const int32_t* __restrict__ indices, int m_rows, int m_cols,
|
||||
const float* __restrict__ vector, int n_batch, float* __restrict__ result) {
|
||||
NEON_OR_PORTABLE(SparseMatrixBatchVectorMultiplyAccumulate1x4, matrix,
|
||||
segments, indices, m_rows, m_cols, vector, n_batch, result);
|
||||
}
|
||||
|
||||
void SparseMatrixBatchVectorMultiplyAccumulate(
|
||||
const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
|
||||
int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
|
||||
|
@ -234,6 +234,30 @@ void PortableMatrixBatchVectorMultiplyAccumulate(
|
||||
} // for batch
|
||||
}
|
||||
|
||||
void PortableSparseMatrixBatchVectorMultiplyAccumulate1x4(
|
||||
const float* __restrict__ matrix, const int32_t* __restrict__ segments,
|
||||
const int32_t* __restrict__ indices, int m_rows, int m_cols,
|
||||
const float* __restrict__ vector, int n_batch, float* __restrict__ result) {
|
||||
const int kBlockSize = 4;
|
||||
TFLITE_DCHECK_EQ(m_cols % kBlockSize, 0);
|
||||
for (int batch = 0; batch < n_batch; batch++) {
|
||||
const float* matrix_ptr = matrix;
|
||||
for (int row = 0; row < m_rows; row++) {
|
||||
float dot_prod = 0.0f;
|
||||
const float* vector_in_batch = vector + batch * m_cols;
|
||||
for (int i = segments[row]; i < segments[row + 1]; i++) {
|
||||
const int block_start_index = indices[i] * kBlockSize;
|
||||
const float* vector_block_in_batch_ptr =
|
||||
vector_in_batch + block_start_index;
|
||||
for (int c = 0; c < kBlockSize; c++) {
|
||||
dot_prod += *matrix_ptr++ * *vector_block_in_batch_ptr++;
|
||||
}
|
||||
}
|
||||
result[batch * m_rows + row] += dot_prod;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PortableSparseMatrixBatchVectorMultiplyAccumulate(
|
||||
const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
|
||||
int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
|
||||
|
@ -109,6 +109,14 @@ void MatrixBatchVectorMultiplyAccumulate(
|
||||
per_channel_scale, input_offset);
|
||||
}
|
||||
|
||||
void SparseMatrixBatchVectorMultiplyAccumulate1x4(
|
||||
const float* __restrict__ matrix, const int32_t* __restrict__ segments,
|
||||
const int32_t* __restrict__ indices, int m_rows, int m_cols,
|
||||
const float* __restrict__ vector, int n_batch, float* __restrict__ result) {
|
||||
PortableSparseMatrixBatchVectorMultiplyAccumulate1x4(
|
||||
matrix, segments, indices, m_rows, m_cols, vector, n_batch, result);
|
||||
}
|
||||
|
||||
void SparseMatrixBatchVectorMultiplyAccumulate(
|
||||
const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
|
||||
int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
|
||||
|
@ -85,6 +85,11 @@ void PortableMatrixBatchVectorMultiplyAccumulate(
|
||||
int n_batch, float* __restrict__ result, const float* per_channel_scale,
|
||||
const int32_t* input_offset);
|
||||
|
||||
void PortableSparseMatrixBatchVectorMultiplyAccumulate1x4(
|
||||
const float* __restrict__ matrix, const int32_t* __restrict__ segments,
|
||||
const int32_t* __restrict__ indices, int m_rows, int m_cols,
|
||||
const float* __restrict__ vector, int n_batch, float* __restrict__ result);
|
||||
|
||||
void PortableSparseMatrixBatchVectorMultiplyAccumulate(
|
||||
const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
|
||||
int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
|
||||
|
@ -65,6 +65,15 @@ void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
|
||||
int m_cols, const float* vector,
|
||||
int n_batch, float* result);
|
||||
|
||||
// Same as the function above, but the matrix is a sparse tensor with block
|
||||
// pattern 1x4.
|
||||
// This function assumes that m_cols is a multiple of the block size (4 in this
|
||||
// case) so that there's no incomplete block.
|
||||
void SparseMatrixBatchVectorMultiplyAccumulate1x4(
|
||||
const float* __restrict__ matrix, const int32_t* __restrict__ segments,
|
||||
const int32_t* __restrict__ indices, int m_rows, int m_cols,
|
||||
const float* __restrict__ vector, int n_batch, float* __restrict__ result);
|
||||
|
||||
// Same as the function above, but the matrix is stored in block compressed
|
||||
// sparse row format with block pattern 1x16 which consists of two arrays:
|
||||
// 1. A matrix array stores non-zero blocks of the matrix in row major.
|
||||
|
Loading…
x
Reference in New Issue
Block a user