Sparse fully connected kernel with 1x4 block config.

PiperOrigin-RevId: 310222893
Change-Id: Ia0e7d4172f06ee62043d641f3d24409139ece712
This commit is contained in:
Yunlu Li 2020-05-06 14:02:31 -07:00 committed by TensorFlower Gardener
parent 84135ce118
commit 3e78c70559
11 changed files with 192 additions and 7 deletions

View File

@ -50,6 +50,10 @@ bool SupportedSparsityFormat(const TfLiteSparsity& sparsity) {
return false;
}
// Value of sparsity.dim_metadata_size that selects the random-sparse
// fully-connected kernel.
static const int kDimMetadataSizeRandomSparse = 2;
// Value of sparsity.dim_metadata_size that selects a block-sparse
// fully-connected kernel (currently only the 1x4 block kernel).
static const int kDimMetadataSizeBlockSparse = 3;
} // namespace
// This file has four implementations of FullyConnected
@ -652,15 +656,33 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
const auto& sparsity = *filter->sparsity;
if (!SupportedSparsityFormat(sparsity)) {
context->ReportError(context,
"Unsupported sparse fully-connected weight format.");
TF_LITE_KERNEL_LOG(context,
"Unsupported sparse fully-connected weight format.");
return kTfLiteError;
}
if (sparsity.dim_metadata_size == kDimMetadataSizeRandomSparse) {
// Random sparse.
optimized_ops::FullyConnectedSparseWeight(
sparsity, op_params, GetTensorShape(input),
GetTensorData<float>(input), GetTensorShape(filter),
GetTensorData<float>(filter), GetTensorShape(bias),
GetTensorData<float>(bias), GetTensorShape(output),
GetTensorData<float>(output));
} else if (sparsity.dim_metadata_size == kDimMetadataSizeBlockSparse &&
sparsity.dim_metadata[2].dense_size == 4) {
// Block sparse with block size of 1x4.
optimized_ops::FullyConnectedSparseWeight1x4(
sparsity, op_params, GetTensorShape(input),
GetTensorData<float>(input), GetTensorShape(filter),
GetTensorData<float>(filter), GetTensorShape(bias),
GetTensorData<float>(bias), GetTensorShape(output),
GetTensorData<float>(output));
} else {
TF_LITE_KERNEL_LOG(context,
"Unsupported sparse fully-connected weight format.");
return kTfLiteError;
}
optimized_ops::FullyConnectedSparseWeight(
sparsity, op_params, GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(filter), GetTensorData<float>(filter),
GetTensorShape(bias), GetTensorData<float>(bias),
GetTensorShape(output), GetTensorData<float>(output));
} else if (kernel_type == kLegacyPie) {
return EvalPie(context, node, params, data, input, filter, bias, output);
} else {

View File

@ -1243,6 +1243,35 @@ TEST_P(SparseFullyConnectedOpTest, SimpleTest2) {
EXPECT_THAT(m.GetOutput(), ElementsAre(11, 9));
}
// Runs the sparse fully-connected op on a 3x12 weight declared with a block
// size of 4 and checks the output shape and values for a batch of two inputs.
// NOTE(review): every 1x4 block of the weight data is non-zero, so this may
// not cover rows that skip all-zero blocks -- confirm.
TEST_P(SparseFullyConnectedOpTest, Simple1x4Test) {
  // Three identical weight rows, one per output unit.
  std::initializer_list<float> filter_values = {
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,  // u = 2
  };

  TensorData filter_spec = {};
  filter_spec.type = TensorType_FLOAT32;
  filter_spec.shape = {3, 12};
  filter_spec.traversal_order = {0, 1, 2};
  filter_spec.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  // Presumably: one block dimension mapped onto weight dimension 1, with four
  // elements per block -- confirm against the TensorData definition.
  filter_spec.block_map = {1};
  filter_spec.block_size = {4};

  SparseFullyConnectedOpModel<float> model(
      GetRegistration(),
      /*units=*/3, /*batches=*/2,
      /*input=*/{TensorType_FLOAT32, {2, 12}}, filter_spec, filter_values);
  model.SetBias({1, 2, 3});
  model.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10, 11,  12,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10, -11, 12,  // b = 1
  });

  model.Invoke();

  // The row/input dot products are 288 (b = 0) and 80 (b = 1); each output
  // adds the bias of its unit on top.
  EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3));
  EXPECT_THAT(model.GetOutput(), ElementsAre(289, 290, 291, 81, 82, 83));
}
// TODO(b/148391360): Add tests for unsupported sparsity format.
// TEST_P(SparseFullyConnectedOpTest, TestUnsupportedSparsityFormat)

View File

@ -1954,6 +1954,36 @@ void NeonCwiseClipping(int8_t* input, const int8_t clipping_value,
}
}
void NeonSparseMatrixBatchVectorMultiplyAccumulate1x4(
const float* __restrict__ matrix, const int32_t* __restrict__ segments,
const int32_t* __restrict__ indices, int m_rows, int m_cols,
const float* __restrict__ vector, int n_batch, float* __restrict__ result) {
const int kBlockSize = 4;
TFLITE_DCHECK_EQ(m_cols % kBlockSize, 0);
for (int batch = 0; batch < n_batch; batch++) {
const float* matrix_ptr = matrix;
for (int row = 0; row < m_rows; row++) {
float32x4_t acc_32x4 = vmovq_n_f32(0.0);
const float* vector_in_batch = vector + batch * m_cols;
for (int i = segments[row]; i < segments[row + 1]; i++) {
const int block_start_index = indices[i] * kBlockSize;
const float* vector_block_in_batch_ptr =
vector_in_batch + block_start_index;
// Load 4 float values from the vector and matrix row.
float32x4_t vector_f32x4 = vld1q_f32(vector_block_in_batch_ptr);
float32x4_t matrix_f32x4 = vld1q_f32(matrix_ptr);
// Multiply the vector and matrix row and add to accumulator.
acc_32x4 = vmlaq_f32(acc_32x4, matrix_f32x4, vector_f32x4);
matrix_ptr += kBlockSize;
}
result[batch * m_rows + row] += AccumulateNeonLane(acc_32x4);
}
}
}
void NeonSparseMatrixBatchVectorMultiplyAccumulate(
const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,

View File

@ -76,6 +76,14 @@ void MatrixBatchVectorMultiplyAccumulate(
input_offset, scratch, row_sums, compute_row_sums, context);
}
// 1x4 block-sparse matrix * batch-vector multiply-accumulate.
// Forwards every argument unchanged through NEON_OR_PORTABLE.
// NOTE(review): presumably the macro resolves to the Neon... or Portable...
// implementation of the same name depending on NEON availability -- confirm
// against the macro definition.
void SparseMatrixBatchVectorMultiplyAccumulate1x4(
const float* __restrict__ matrix, const int32_t* __restrict__ segments,
const int32_t* __restrict__ indices, int m_rows, int m_cols,
const float* __restrict__ vector, int n_batch, float* __restrict__ result) {
NEON_OR_PORTABLE(SparseMatrixBatchVectorMultiplyAccumulate1x4, matrix,
segments, indices, m_rows, m_cols, vector, n_batch, result);
}
void SparseMatrixBatchVectorMultiplyAccumulate(
const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,

View File

@ -111,6 +111,11 @@ void NeonMatrixScalarMultiplyAccumulate(const int8_t* matrix, int32_t scalar,
int32_t n_row, int32_t n_col,
int32_t* output);
// Multiply a sparse matrix stored as 1x4 blocks by a batch vector, and
// accumulate the per-row dot products into result[batch * m_rows + row].
// `segments` gives, per row, the range of stored blocks; `indices` gives the
// block-column of each stored block. m_cols must be a multiple of 4.
void NeonSparseMatrixBatchVectorMultiplyAccumulate1x4(
const float* __restrict__ matrix, const int32_t* __restrict__ segments,
const int32_t* __restrict__ indices, int m_rows, int m_cols,
const float* __restrict__ vector, int n_batch, float* __restrict__ result);
// Multiply a matrix by a batch vector, and store results in a batch-size
// vector. Sparse version.
void NeonSparseMatrixBatchVectorMultiplyAccumulate(

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/cppmath.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
@ -68,6 +69,42 @@ inline void FullyConnectedSparseWeight(
}
}
// Float fully-connected layer whose weights are block sparse with 1x4 blocks.
//
// Computes, for every batch b and output unit i:
//   output[b, i] = clamp(dot(weights[i, :], input[b, :]) + bias[i],
//                        params.float_activation_min,
//                        params.float_activation_max)
//
//   sparsity     - sparsity metadata of the weights; dim_metadata[1] supplies
//                  the segment and index arrays of the stored blocks.
//   weights_data - values of the stored 1x4 blocks.
//   bias_data    - one value per output unit, or nullptr when the op has no
//                  bias (treated as an all-zero bias).
//   output_data  - overwritten with the result.
//
// `input_shape` and `bias_shape` are accepted for signature parity with the
// other fully-connected kernels but are not read here.
inline void FullyConnectedSparseWeight1x4(
    const TfLiteSparsity& sparsity, const FullyConnectedParams& params,
    const RuntimeShape& input_shape, const float* input_data,
    const RuntimeShape& weights_shape, const float* weights_data,
    const RuntimeShape& bias_shape, const float* bias_data,
    const RuntimeShape& output_shape, float* output_data) {
  const float output_activation_min = params.float_activation_min;
  const float output_activation_max = params.float_activation_max;

  const int output_elements = output_shape.FlatSize();
  const int output_dims_count = output_shape.DimensionsCount();
  const int weights_dims_count = weights_shape.DimensionsCount();
  const int batches = FlatSizeSkipDim(output_shape, output_dims_count - 1);
  const int output_depth = MatchingDim(weights_shape, weights_dims_count - 2,
                                       output_shape, output_dims_count - 1);
  const int* w1_segments = sparsity.dim_metadata[1].array_segments->data;
  const int* w1_indices = sparsity.dim_metadata[1].array_indices->data;

  // The multiply helper below accumulates into its result buffer, so the
  // output has to start from zero.
  for (int i = 0; i < output_elements; ++i) {
    output_data[i] = 0.f;
  }

  tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate1x4(
      weights_data, w1_segments, w1_indices, weights_shape.Dims(0),
      weights_shape.Dims(1), input_data, batches, output_data);

  // Add the bias and apply the fused activation clamp in place.
  for (int b = 0; b < batches; ++b) {
    for (int i = 0; i < output_depth; ++i) {
      const float total = output_data[b * output_depth + i];
      // The bias tensor is optional; a missing bias contributes nothing.
      const float bias_value = (bias_data != nullptr) ? bias_data[i] : 0.f;
      output_data[b * output_depth + i] = ActivationFunctionWithMinMax(
          total + bias_value, output_activation_min, output_activation_max);
    }
  }
}
} // namespace optimized_ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SPARSE_OPS_FULLY_CONNECTED_H_

View File

@ -86,6 +86,14 @@ void MatrixBatchVectorMultiplyAccumulate(
input_offset);
}
// 1x4 block-sparse matrix * batch-vector multiply-accumulate.
// Forwards every argument unchanged through NEON_OR_PORTABLE.
// NOTE(review): presumably the macro resolves to the Neon... or Portable...
// implementation of the same name depending on NEON availability -- confirm
// against the macro definition.
void SparseMatrixBatchVectorMultiplyAccumulate1x4(
const float* __restrict__ matrix, const int32_t* __restrict__ segments,
const int32_t* __restrict__ indices, int m_rows, int m_cols,
const float* __restrict__ vector, int n_batch, float* __restrict__ result) {
NEON_OR_PORTABLE(SparseMatrixBatchVectorMultiplyAccumulate1x4, matrix,
segments, indices, m_rows, m_cols, vector, n_batch, result);
}
void SparseMatrixBatchVectorMultiplyAccumulate(
const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,

View File

@ -234,6 +234,30 @@ void PortableMatrixBatchVectorMultiplyAccumulate(
} // for batch
}
// Portable (scalar) 1x4 block-sparse matrix * batch-vector
// multiply-accumulate.
//
// For every batch and every matrix row, the stored 1x4 blocks of that row are
// multiplied with the matching four entries of the batch vector, and the sum
// is added to result[batch * m_rows + row].
//
//   matrix   - values of the stored blocks, four floats per block, consumed
//              sequentially row by row.
//   segments - segments[row] .. segments[row + 1] is the range of stored
//              blocks that belong to `row`.
//   indices  - block-column index of each stored block; the block starts at
//              vector column indices[i] * 4.
//   m_cols   - length of one batch vector; checked to be a multiple of 4.
void PortableSparseMatrixBatchVectorMultiplyAccumulate1x4(
    const float* __restrict__ matrix, const int32_t* __restrict__ segments,
    const int32_t* __restrict__ indices, int m_rows, int m_cols,
    const float* __restrict__ vector, int n_batch, float* __restrict__ result) {
  const int kBlockSize = 4;
  TFLITE_DCHECK_EQ(m_cols % kBlockSize, 0);

  for (int b = 0; b < n_batch; ++b) {
    float* batch_output = result + b * m_rows;
    // The stored blocks are re-read from the beginning for every batch.
    const float* block_values = matrix;
    for (int r = 0; r < m_rows; ++r) {
      const float* batch_input = vector + b * m_cols;
      const int block_end = segments[r + 1];
      float row_sum = 0.0f;
      for (int blk = segments[r]; blk < block_end; ++blk) {
        // Four vector entries covered by this stored block.
        const float* input_block = batch_input + indices[blk] * kBlockSize;
        for (int k = 0; k < kBlockSize; ++k) {
          row_sum += block_values[k] * input_block[k];
        }
        block_values += kBlockSize;
      }
      // Accumulate into, rather than overwrite, the existing output value.
      batch_output[r] += row_sum;
    }
  }
}
void PortableSparseMatrixBatchVectorMultiplyAccumulate(
const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,

View File

@ -109,6 +109,14 @@ void MatrixBatchVectorMultiplyAccumulate(
per_channel_scale, input_offset);
}
// 1x4 block-sparse matrix * batch-vector multiply-accumulate.
// This variant always forwards to the portable implementation, passing every
// argument through unchanged.
void SparseMatrixBatchVectorMultiplyAccumulate1x4(
const float* __restrict__ matrix, const int32_t* __restrict__ segments,
const int32_t* __restrict__ indices, int m_rows, int m_cols,
const float* __restrict__ vector, int n_batch, float* __restrict__ result) {
PortableSparseMatrixBatchVectorMultiplyAccumulate1x4(
matrix, segments, indices, m_rows, m_cols, vector, n_batch, result);
}
void SparseMatrixBatchVectorMultiplyAccumulate(
const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,

View File

@ -85,6 +85,11 @@ void PortableMatrixBatchVectorMultiplyAccumulate(
int n_batch, float* __restrict__ result, const float* per_channel_scale,
const int32_t* input_offset);
// Multiply a sparse matrix stored as 1x4 blocks by a batch vector, and
// accumulate the per-row dot products into result[batch * m_rows + row].
// `segments` gives, per row, the range of stored blocks; `indices` gives the
// block-column of each stored block. m_cols must be a multiple of 4.
void PortableSparseMatrixBatchVectorMultiplyAccumulate1x4(
const float* __restrict__ matrix, const int32_t* __restrict__ segments,
const int32_t* __restrict__ indices, int m_rows, int m_cols,
const float* __restrict__ vector, int n_batch, float* __restrict__ result);
void PortableSparseMatrixBatchVectorMultiplyAccumulate(
const float* __restrict__ matrix, const uint8_t* __restrict__ ledger,
int m_rows, int m_cols, const float* __restrict__ vector, int n_batch,

View File

@ -65,6 +65,15 @@ void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
int m_cols, const float* vector,
int n_batch, float* result);
// Same as the function above, but the matrix is a sparse tensor with block
// pattern 1x4.
// This function assumes that m_cols is a multiple of the block size (4 in this
// case) so that there's no incomplete block.
// Arguments specific to the sparse layout:
// - matrix: values of the stored blocks only, four floats per block, laid out
//   consecutively row by row.
// - segments: for each row r, segments[r] .. segments[r + 1] is the range of
//   stored blocks that belong to that row.
// - indices: block-column index of each stored block; the block covers vector
//   columns [4 * indices[i], 4 * indices[i] + 4).
// The per-row dot products are added to (not stored over) the existing
// contents of result[batch * m_rows + row].
void SparseMatrixBatchVectorMultiplyAccumulate1x4(
const float* __restrict__ matrix, const int32_t* __restrict__ segments,
const int32_t* __restrict__ indices, int m_rows, int m_cols,
const float* __restrict__ vector, int n_batch, float* __restrict__ result);
// Same as the function above, but the matrix is stored in block compressed
// sparse row format with block pattern 1x16 which consists of two arrays:
// 1. A matrix array stores non-zero blocks of the matrix in row major.