From 3e78c70559d6b04a58db77e7f5563016315384e7 Mon Sep 17 00:00:00 2001
From: Yunlu Li
Date: Wed, 6 May 2020 14:02:31 -0700
Subject: [PATCH] Sparse fully connected kernel with 1x4 block config.

PiperOrigin-RevId: 310222893
Change-Id: Ia0e7d4172f06ee62043d641f3d24409139ece712
---
 tensorflow/lite/kernels/fully_connected.cc     | 36 ++++++++++++++----
 .../lite/kernels/fully_connected_test.cc       | 29 +++++++++++++++
 .../internal/optimized/neon_tensor_utils.cc    | 30 +++++++++++++++
 .../internal/optimized/neon_tensor_utils.h     |  8 ++++
 .../optimized/neon_tensor_utils_impl.h         |  5 +++
 .../optimized/sparse_ops/fully_connected.h     | 37 +++++++++++++++++++
 .../internal/optimized/sse_tensor_utils.h      |  8 ++++
 .../reference/portable_tensor_utils.cc         | 24 ++++++++++++
 .../reference/portable_tensor_utils.h          |  8 ++++
 .../reference/portable_tensor_utils_impl.h     |  5 +++
 .../lite/kernels/internal/tensor_utils.h       |  9 +++++
 11 files changed, 192 insertions(+), 7 deletions(-)

diff --git a/tensorflow/lite/kernels/fully_connected.cc b/tensorflow/lite/kernels/fully_connected.cc
index 62a4ede9a06..1cd1b14e7a8 100644
--- a/tensorflow/lite/kernels/fully_connected.cc
+++ b/tensorflow/lite/kernels/fully_connected.cc
@@ -50,6 +50,10 @@ bool SupportedSparsityFormat(const TfLiteSparsity& sparsity) {
   return false;
 }
+
+static const int kDimMetadataSizeRandomSparse = 2;
+static const int kDimMetadataSizeBlockSparse = 3;
+
 }  // namespace
 
 // This file has four implementations of FullyConnected
@@ -652,15 +656,33 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
     const auto& sparsity = *filter->sparsity;
     if (!SupportedSparsityFormat(sparsity)) {
-      context->ReportError(context,
-                           "Unsupported sparse fully-connected weight format.");
+      TF_LITE_KERNEL_LOG(context,
+                         "Unsupported sparse fully-connected weight format.");
+      return kTfLiteError;
+    }
+
+    if (sparsity.dim_metadata_size == kDimMetadataSizeRandomSparse) {
+      // Random sparse.
+      optimized_ops::FullyConnectedSparseWeight(
+          sparsity, op_params, GetTensorShape(input),
+          GetTensorData<float>(input), GetTensorShape(filter),
+          GetTensorData<float>(filter), GetTensorShape(bias),
+          GetTensorData<float>(bias), GetTensorShape(output),
+          GetTensorData<float>(output));
+    } else if (sparsity.dim_metadata_size == kDimMetadataSizeBlockSparse &&
+               sparsity.dim_metadata[2].dense_size == 4) {
+      // Block sparse with block size of 1x4.
+      optimized_ops::FullyConnectedSparseWeight1x4(
+          sparsity, op_params, GetTensorShape(input),
+          GetTensorData<float>(input), GetTensorShape(filter),
+          GetTensorData<float>(filter), GetTensorShape(bias),
+          GetTensorData<float>(bias), GetTensorShape(output),
+          GetTensorData<float>(output));
+    } else {
+      TF_LITE_KERNEL_LOG(context,
+                         "Unsupported sparse fully-connected weight format.");
       return kTfLiteError;
     }
-    optimized_ops::FullyConnectedSparseWeight(
-        sparsity, op_params, GetTensorShape(input), GetTensorData<float>(input),
-        GetTensorShape(filter), GetTensorData<float>(filter),
-        GetTensorShape(bias), GetTensorData<float>(bias),
-        GetTensorShape(output), GetTensorData<float>(output));
   } else if (kernel_type == kLegacyPie) {
     return EvalPie(context, node, params, data, input, filter, bias, output);
   } else {
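For readers skimming the dispatch above: the kernel is selected purely from the filter's sparsity metadata. Below is a minimal standalone sketch of that decision, using a simplified stand-in struct rather than the real TfLiteSparsity type (the stand-in and its field names are illustrative only, not TfLite API):

    // Sketch only: mirrors the dim_metadata check in EvalFloat above.
    #include <cstdio>

    struct DimMetadata {
      int dense_size;  // Block dimension size for block-sparse dims.
    };

    struct SparsityInfo {
      int dim_metadata_size;
      DimMetadata dim_metadata[3];
    };

    static const int kDimMetadataSizeRandomSparse = 2;
    static const int kDimMetadataSizeBlockSparse = 3;

    const char* SelectSparseKernel(const SparsityInfo& s) {
      if (s.dim_metadata_size == kDimMetadataSizeRandomSparse) {
        return "FullyConnectedSparseWeight";     // random sparse path
      }
      if (s.dim_metadata_size == kDimMetadataSizeBlockSparse &&
          s.dim_metadata[2].dense_size == 4) {
        return "FullyConnectedSparseWeight1x4";  // 1x4 block-sparse path
      }
      return "unsupported";
    }

    int main() {
      SparsityInfo block_1x4 = {3, {{0}, {0}, {4}}};
      std::printf("%s\n", SelectSparseKernel(block_1x4));  // FullyConnectedSparseWeight1x4
    }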
diff --git a/tensorflow/lite/kernels/fully_connected_test.cc b/tensorflow/lite/kernels/fully_connected_test.cc
index a600af82145..34d68cf0b0d 100644
--- a/tensorflow/lite/kernels/fully_connected_test.cc
+++ b/tensorflow/lite/kernels/fully_connected_test.cc
@@ -1243,6 +1243,35 @@ TEST_P(SparseFullyConnectedOpTest, SimpleTest2) {
   EXPECT_THAT(m.GetOutput(), ElementsAre(11, 9));
 }
 
+TEST_P(SparseFullyConnectedOpTest, Simple1x4Test) {
+  std::initializer_list<float> weight_data = {
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 0
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 1
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 2
+  };
+  TensorData weight = {};
+  weight.type = TensorType_FLOAT32;
+  weight.shape = {3, 12};
+  weight.traversal_order = {0, 1, 2};
+  weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
+  weight.block_map = {1};
+  weight.block_size = {4};
+  SparseFullyConnectedOpModel<float> m(GetRegistration(),
+                                       /*units=*/3, /*batches=*/2,
+                                       /*input=*/{TensorType_FLOAT32, {2, 12}},
+                                       weight, weight_data);
+  m.SetBias({1, 2, 3});
+
+  m.SetInput({
+      1, 2, 3, 4, 5, 6, 7, 8, -9, -10, 11, 12,   // b = 0
+      1, 2, 3, 4, 5, 6, 7, -8, 9, -10, -11, 12,  // b = 1
+  });
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 3));
+  EXPECT_THAT(m.GetOutput(), ElementsAre(289, 290, 291, 81, 82, 83));
+}
 // TODO(b/148391360): Add tests for unsupported sparsity format.
 // TEST_P(SparseFullyConnectedOpTest, TestUnsupportedSparsityFormat)
diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc
index 07f3117dac7..4d8c20074d5 100644
--- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc
+++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -1954,6 +1954,36 @@ void NeonCwiseClipping(int8_t* input, const int8_t clipping_value,
   }
 }
 
+void NeonSparseMatrixBatchVectorMultiplyAccumulate1x4(
+    const float* __restrict__ matrix, const int32_t* __restrict__ segments,
+    const int32_t* __restrict__ indices, int m_rows, int m_cols,
+    const float* __restrict__ vector, int n_batch, float* __restrict__ result) {
+  const int kBlockSize = 4;
+  TFLITE_DCHECK_EQ(m_cols % kBlockSize, 0);
+
+  for (int batch = 0; batch < n_batch; batch++) {
+    const float* matrix_ptr = matrix;
+    for (int row = 0; row < m_rows; row++) {
+      float32x4_t acc_32x4 = vmovq_n_f32(0.0);
+      const float* vector_in_batch = vector + batch * m_cols;
+
+      for (int i = segments[row]; i < segments[row + 1]; i++) {
+        const int block_start_index = indices[i] * kBlockSize;
+        const float* vector_block_in_batch_ptr =
+            vector_in_batch + block_start_index;
+
+        // Load 4 float values from the vector and matrix row.
+        float32x4_t vector_f32x4 = vld1q_f32(vector_block_in_batch_ptr);
+        float32x4_t matrix_f32x4 = vld1q_f32(matrix_ptr);
+        // Multiply the vector and matrix row and add to accumulator.
+        acc_32x4 = vmlaq_f32(acc_32x4, matrix_f32x4, vector_f32x4);
+        matrix_ptr += kBlockSize;
+      }
+      result[batch * m_rows + row] += AccumulateNeonLane(acc_32x4);
+    }
+  }
+}
+
 void NeonSparseMatrixBatchVectorMultiplyAccumulate(
     const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
     int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
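As a quick sanity check on the expected values in Simple1x4Test: every 1x4 block of the 3x12 weight is non-zero, so the block-sparse result must equal a plain dense matmul plus bias. A small standalone cross-check (not part of the test suite, for illustration only):

    // Dense cross-check of Simple1x4Test's expected outputs.
    #include <cstdio>

    int main() {
      const float w[12] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};  // same row for all 3 units
      const float in[2][12] = {{1, 2, 3, 4, 5, 6, 7, 8, -9, -10, 11, 12},
                               {1, 2, 3, 4, 5, 6, 7, -8, 9, -10, -11, 12}};
      const float bias[3] = {1, 2, 3};
      for (int b = 0; b < 2; ++b) {
        for (int u = 0; u < 3; ++u) {
          float acc = bias[u];
          for (int c = 0; c < 12; ++c) acc += w[c] * in[b][c];
          std::printf("%g ", acc);  // prints 289 290 291 81 82 83
        }
      }
      std::printf("\n");
    }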
diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h
index 98ae3c976df..b978bf5f3bb 100644
--- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h
+++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h
@@ -76,6 +76,14 @@ void MatrixBatchVectorMultiplyAccumulate(
       input_offset, scratch, row_sums, compute_row_sums, context);
 }
 
+void SparseMatrixBatchVectorMultiplyAccumulate1x4(
+    const float* __restrict__ matrix, const int32_t* __restrict__ segments,
+    const int32_t* __restrict__ indices, int m_rows, int m_cols,
+    const float* __restrict__ vector, int n_batch, float* __restrict__ result) {
+  NEON_OR_PORTABLE(SparseMatrixBatchVectorMultiplyAccumulate1x4, matrix,
+                   segments, indices, m_rows, m_cols, vector, n_batch, result);
+}
+
 void SparseMatrixBatchVectorMultiplyAccumulate(
     const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
     int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h
index 059accb0222..1b043390c22 100644
--- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h
+++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h
@@ -111,6 +111,11 @@ void NeonMatrixScalarMultiplyAccumulate(const int8_t* matrix, int32_t scalar,
                                         int32_t n_row, int32_t n_col,
                                         int32_t* output);
 
+void NeonSparseMatrixBatchVectorMultiplyAccumulate1x4(
+    const float* __restrict__ matrix, const int32_t* __restrict__ segments,
+    const int32_t* __restrict__ indices, int m_rows, int m_cols,
+    const float* __restrict__ vector, int n_batch, float* __restrict__ result);
+
 // Multiply a matrix by a batch vector, and store results in a batch-size
 // vector. Sparse version.
 void NeonSparseMatrixBatchVectorMultiplyAccumulate(
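The segments/indices arguments declared here are CSR-style arrays over 1x4 blocks rather than over individual elements. For the (fully dense) 3x12 weight used in Simple1x4Test, the encoding would look like the following sketch; values are shown for illustration only:

    // 1x4 block-CSR view of a dense 3x12 matrix: m_rows = 3, m_cols = 12,
    // so each row has 12 / 4 = 3 blocks, and every block is retained.
    #include <cstdint>

    const int32_t kSegments[] = {0, 3, 6, 9};  // row r owns blocks [kSegments[r], kSegments[r + 1])
    const int32_t kIndices[] = {0, 1, 2, 0, 1, 2, 0, 1, 2};  // block column of each retained block
    // The packed value array holds the retained blocks back to back; with no
    // zero blocks and a 1x4 block shape this is simply the dense 3x12 data in
    // row-major order.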
#include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/cppmath.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/tensor_utils.h" #include "tensorflow/lite/kernels/internal/types.h" namespace tflite { @@ -68,6 +69,42 @@ inline void FullyConnectedSparseWeight( } } +inline void FullyConnectedSparseWeight1x4( + const TfLiteSparsity& sparsity, const FullyConnectedParams& params, + const RuntimeShape& input_shape, const float* input_data, + const RuntimeShape& weights_shape, const float* weights_data, + const RuntimeShape& bias_shape, const float* bias_data, + const RuntimeShape& output_shape, float* output_data) { + const float output_activation_min = params.float_activation_min; + const float output_activation_max = params.float_activation_max; + + const int output_elements = output_shape.FlatSize(); + const int output_dims_count = output_shape.DimensionsCount(); + const int weights_dims_count = weights_shape.DimensionsCount(); + const int batches = FlatSizeSkipDim(output_shape, output_dims_count - 1); + const int output_depth = MatchingDim(weights_shape, weights_dims_count - 2, + output_shape, output_dims_count - 1); + const int* w1_segments = sparsity.dim_metadata[1].array_segments->data; + const int* w1_indices = sparsity.dim_metadata[1].array_indices->data; + + for (int i = 0; i < output_elements; ++i) { + output_data[i] = 0.f; + } + + tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate1x4( + weights_data, w1_segments, w1_indices, weights_shape.Dims(0), + weights_shape.Dims(1), input_data, batches, output_data); + + for (int b = 0; b < batches; ++b) { + for (int i = 0; i < output_depth; ++i) { + float total = output_data[b * output_depth + i]; + float bias_value = bias_data[i]; + output_data[b * output_depth + i] = ActivationFunctionWithMinMax( + total + bias_value, output_activation_min, output_activation_max); + } + } +} + } // namespace optimized_ops } // namespace tflite #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SPARSE_OPS_FULLY_CONNECTED_H_ diff --git a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h index 1d0d2273e93..986e70a7823 100644 --- a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h +++ b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h @@ -86,6 +86,14 @@ void MatrixBatchVectorMultiplyAccumulate( input_offset); } +void SparseMatrixBatchVectorMultiplyAccumulate1x4( + const float* __restrict__ matrix, const int32_t* __restrict__ segments, + const int32_t* __restrict__ indices, int m_rows, int m_cols, + const float* __restrict__ vector, int n_batch, float* __restrict__ result) { + NEON_OR_PORTABLE(SparseMatrixBatchVectorMultiplyAccumulate1x4, matrix, + segments, indices, m_rows, m_cols, vector, n_batch, result); +} + void SparseMatrixBatchVectorMultiplyAccumulate( const float* __restrict__ matrix, const uint8_t* __restrict__ ledger, int m_rows, int m_cols, const float* __restrict__ vector, int n_batch, diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc index 19c74973aeb..22e37d5af71 100644 --- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc @@ -234,6 +234,30 @@ void PortableMatrixBatchVectorMultiplyAccumulate( } // for batch } +void 
diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h
index c4f886f6a5c..9a365074513 100644
--- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h
+++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h
@@ -109,6 +109,14 @@ void MatrixBatchVectorMultiplyAccumulate(
       per_channel_scale, input_offset);
 }
 
+void SparseMatrixBatchVectorMultiplyAccumulate1x4(
+    const float* __restrict__ matrix, const int32_t* __restrict__ segments,
+    const int32_t* __restrict__ indices, int m_rows, int m_cols,
+    const float* __restrict__ vector, int n_batch, float* __restrict__ result) {
+  PortableSparseMatrixBatchVectorMultiplyAccumulate1x4(
+      matrix, segments, indices, m_rows, m_cols, vector, n_batch, result);
+}
+
 void SparseMatrixBatchVectorMultiplyAccumulate(
     const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
     int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h
index 04fedc327d0..d8bd70f3722 100644
--- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h
+++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h
@@ -85,6 +85,11 @@ void PortableMatrixBatchVectorMultiplyAccumulate(
     int n_batch, float* __restrict__ result, const float* per_channel_scale,
     const int32_t* input_offset);
 
+void PortableSparseMatrixBatchVectorMultiplyAccumulate1x4(
+    const float* __restrict__ matrix, const int32_t* __restrict__ segments,
+    const int32_t* __restrict__ indices, int m_rows, int m_cols,
+    const float* __restrict__ vector, int n_batch, float* __restrict__ result);
+
 void PortableSparseMatrixBatchVectorMultiplyAccumulate(
     const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
     int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,
diff --git a/tensorflow/lite/kernels/internal/tensor_utils.h b/tensorflow/lite/kernels/internal/tensor_utils.h
index e8edf2f59f2..1929c2e2ff4 100644
--- a/tensorflow/lite/kernels/internal/tensor_utils.h
+++ b/tensorflow/lite/kernels/internal/tensor_utils.h
@@ -65,6 +65,15 @@ void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
                                          int m_cols, const float* vector,
                                          int n_batch, float* result);
 
+// Same as the function above, but the matrix is a sparse tensor with block
+// pattern 1x4.
+// This function assumes that m_cols is a multiple of the block size (4 in this
+// case) so that there's no incomplete block.
+void SparseMatrixBatchVectorMultiplyAccumulate1x4(
+    const float* __restrict__ matrix, const int32_t* __restrict__ segments,
+    const int32_t* __restrict__ indices, int m_rows, int m_cols,
+    const float* __restrict__ vector, int n_batch, float* __restrict__ result);
+
 // Same as the function above, but the matrix is stored in block compressed
 // sparse row format with block pattern 1x16 which consists of two arrays:
 // 1. A matrix array stores non-zero blocks of the matrix in row major.
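A small caller-side sketch of the documented contract, assuming the patch is applied and the program links against the TFLite kernels library: m_cols must be a multiple of 4, and the routine accumulates into result rather than overwriting it, so calling it twice doubles the contribution.

    // Contract check on a 1x4 matrix containing a single non-zero block.
    #include <cstdint>
    #include <cstdio>

    #include "tensorflow/lite/kernels/internal/tensor_utils.h"

    int main() {
      const float matrix[] = {1, 1, 1, 1};  // one 1x4 block
      const int32_t segments[] = {0, 1};    // row 0 owns block range [0, 1)
      const int32_t indices[] = {0};        // that block sits at block column 0
      const float vector[] = {1, 2, 3, 4};  // m_cols = 4, a multiple of the block size
      float result[1] = {0.f};

      tflite::tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate1x4(
          matrix, segments, indices, /*m_rows=*/1, /*m_cols=*/4, vector,
          /*n_batch=*/1, result);
      tflite::tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate1x4(
          matrix, segments, indices, 1, 4, vector, 1, result);

      std::printf("%g\n", result[0]);  // prints 20: 10 accumulated from each call
    }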