From 57b89f0dfa8abc5da12d6e47a45cd9c8a2f86653 Mon Sep 17 00:00:00 2001
From: Yunlu Li
Date: Thu, 25 Feb 2021 22:29:44 -0800
Subject: [PATCH] Add multi-threaded hybrid-sparse FC kernel.

PiperOrigin-RevId: 359690999
Change-Id: Iddb23feff57ffca306795f91ee244b19bc41ad3a
---
 tensorflow/lite/kernels/fully_connected.cc    | 221 +++++++++++++++---
 .../lite/kernels/fully_connected_test.cc      | 115 ++++++++-
 2 files changed, 308 insertions(+), 28 deletions(-)

diff --git a/tensorflow/lite/kernels/fully_connected.cc b/tensorflow/lite/kernels/fully_connected.cc
index 51cc86c09d0..6da963bd440 100644
--- a/tensorflow/lite/kernels/fully_connected.cc
+++ b/tensorflow/lite/kernels/fully_connected.cc
@@ -431,13 +431,13 @@ TfLiteStatus EvalPie(TfLiteContext* context, TfLiteNode* node,
   return kTfLiteOk;
 }
 
-TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
-                        TfLiteFullyConnectedParams* params, OpData* data,
-                        const TfLiteTensor* input, const TfLiteTensor* filter,
-                        const TfLiteTensor* bias, TfLiteTensor* input_quantized,
-                        TfLiteTensor* scaling_factors,
-                        TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
-                        TfLiteTensor* input_offsets, TfLiteTensor* output) {
+TfLiteStatus EvalHybridDense(
+    TfLiteContext* context, TfLiteNode* node,
+    TfLiteFullyConnectedParams* params, OpData* data, const TfLiteTensor* input,
+    const TfLiteTensor* filter, const TfLiteTensor* bias,
+    TfLiteTensor* input_quantized, TfLiteTensor* scaling_factors,
+    TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
+    TfLiteTensor* input_offsets, TfLiteTensor* output) {
   int total_input_size = 1;
   for (int i = 0; i < input->dims->size; i++) {
     total_input_size *= input->dims->data[i];
@@ -446,7 +446,6 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
   const int input_size = filter->dims->data[1];
   const int batch_size = total_input_size / filter->dims->data[1];
   const int num_units = filter->dims->data[0];
-  const bool is_sparse = filter->sparsity != nullptr;
 
   // Output = bias if bias tensor exists.
   if (bias) {
@@ -487,24 +486,11 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
 
   // Compute output += weight * quantized_input
   int32_t* scratch = GetTensorData<int32_t>(accum_scratch);
-  if (is_sparse) {
-    TfLiteTensor* filter_ledger = &context->tensors[node->temporaries->data[5]];
-    if (!data->ledger_initialized) {
-      PopulateLedgerData(filter->sparsity, context,
-                         GetTensorData<uint8_t>(filter_ledger));
-      data->ledger_initialized = true;
-    }
-    tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
-        GetTensorData<int8_t>(filter), GetTensorData<uint8_t>(filter_ledger),
-        num_units, input_size, quant_data, scaling_factors_ptr, batch_size,
-        GetTensorData<float>(output));
-  } else {
-    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-        filter_data, num_units, input_size, quant_data, scaling_factors_ptr,
-        batch_size, GetTensorData<float>(output), /*per_channel_scale=*/nullptr,
-        input_offset_ptr, scratch, row_sums_ptr, &data->compute_row_sums,
-        CpuBackendContext::GetFromContext(context));
-  }
+  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+      filter_data, num_units, input_size, quant_data, scaling_factors_ptr,
+      batch_size, GetTensorData<float>(output), /*per_channel_scale=*/nullptr,
+      input_offset_ptr, scratch, row_sums_ptr, &data->compute_row_sums,
+      CpuBackendContext::GetFromContext(context));
 
   // Apply activation function to floats.
   tensor_utils::ApplyActivationToVector(
@@ -513,6 +499,189 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
   return kTfLiteOk;
 }
 
+void EvalSparseHybridImpl(TfLiteContext* context, TfLiteNode* node,
+                          TfLiteFullyConnectedParams* params, OpData* data,
+                          const TfLiteTensor* input, const TfLiteTensor* filter,
+                          const TfLiteTensor* bias, int thread_start,
+                          int thread_end, TfLiteTensor* input_quantized,
+                          TfLiteTensor* scaling_factors,
+                          TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
+                          TfLiteTensor* input_offsets, TfLiteTensor* output) {
+  ruy::profiler::ScopeLabel label("FullyConnected");
+  ruy::profiler::ScopeLabel inner_label("Sparse Hybrid Kernel");
+  const auto& input_shape = GetTensorShape(input);
+  const auto& output_shape = GetTensorShape(output);
+  const auto& filter_shape = GetTensorShape(filter);
+  const int input_dims_count = input_shape.DimensionsCount();
+  const int output_dims_count = output_shape.DimensionsCount();
+  const int filter_dims_count = filter_shape.DimensionsCount();
+  const int batch_size = thread_end - thread_start;
+  const int input_depth = MatchingDim(filter_shape, filter_dims_count - 1,
+                                      input_shape, input_dims_count - 1);
+  const int output_depth = MatchingDim(filter_shape, filter_dims_count - 2,
+                                       output_shape, output_dims_count - 1);
+  const int per_thread_input_size = batch_size * input_depth;
+
+  const float* per_thread_input =
+      GetTensorData<float>(input) + thread_start * input_depth;
+  float* per_thread_output =
+      GetTensorData<float>(output) + thread_start * output_depth;
+
+  // Output = bias if bias tensor exists.
+  if (bias) {
+    tensor_utils::VectorBatchVectorAssign(GetTensorData<float>(bias),
+                                          output_depth, batch_size,
+                                          per_thread_output);
+  } else {
+    std::fill_n(per_thread_output, batch_size * output_depth, 0.0f);
+  }
+
+  // Save matrix multiplication computation for all zero input.
+  if (tensor_utils::IsZeroVector(per_thread_input, per_thread_input_size)) {
+    tensor_utils::ApplyActivationToVector(
+        per_thread_output, batch_size * output_depth, params->activation,
+        per_thread_output);
+    return;
+  }
+
+  // Quantize input from float to uint8 + quantization params (scaling factor).
+  float* scaling_factors_ptr =
+      GetTensorData<float>(scaling_factors) + thread_start;
+  int32_t* input_offset_ptr = nullptr;
+  int32_t* row_sums_ptr = nullptr;
+  if (params->asymmetric_quantize_inputs) {
+    input_offset_ptr = GetTensorData<int32_t>(input_offsets) + thread_start;
+    row_sums_ptr = GetTensorData<int32_t>(row_sums);
+  }
+  int8_t* quant_data =
+      GetTensorData<int8_t>(input_quantized) + thread_start * input_depth;
+  tensor_utils::BatchQuantizeFloats(per_thread_input, batch_size, input_depth,
+                                    quant_data, scaling_factors_ptr,
+                                    input_offset_ptr,
+                                    params->asymmetric_quantize_inputs);
+  for (int b = 0; b < batch_size; ++b) {
+    // Incorporate scaling of the filter.
+    scaling_factors_ptr[b] *= filter->params.scale;
+  }
+
+  // Compute output += weight * quantized_input
+  TfLiteTensor* filter_ledger = &context->tensors[node->temporaries->data[5]];
+  tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
+      GetTensorData<int8_t>(filter), GetTensorData<uint8_t>(filter_ledger),
+      output_depth, input_depth, quant_data, scaling_factors_ptr, batch_size,
+      per_thread_output);
+
+  // Apply activation function to floats.
+  tensor_utils::ApplyActivationToVector(per_thread_output,
+                                        batch_size * output_depth,
+                                        params->activation, per_thread_output);
+}
+
+struct SparseHybridFullyConnectedTask : cpu_backend_threadpool::Task {
+  SparseHybridFullyConnectedTask(
+      TfLiteContext* context, TfLiteNode* node,
+      TfLiteFullyConnectedParams* params, OpData* data,
+      const TfLiteTensor* input, const TfLiteTensor* filter,
+      const TfLiteTensor* bias, const int thread_start, const int thread_end,
+      TfLiteTensor* input_quantized, TfLiteTensor* scaling_factors,
+      TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
+      TfLiteTensor* input_offsets, TfLiteTensor* output)
+      : context(context),
+        node(node),
+        params(params),
+        data(data),
+        input(input),
+        filter(filter),
+        bias(bias),
+        thread_start(thread_start),
+        thread_end(thread_end),
+        input_quantized(input_quantized),
+        scaling_factors(scaling_factors),
+        accum_scratch(accum_scratch),
+        row_sums(row_sums),
+        input_offsets(input_offsets),
+        output(output) {}
+
+  void Run() override {
+    EvalSparseHybridImpl(context, node, params, data, input, filter, bias,
+                         thread_start, thread_end, input_quantized,
+                         scaling_factors, accum_scratch, row_sums,
+                         input_offsets, output);
+  }
+
+ private:
+  TfLiteContext* context;
+  TfLiteNode* node;
+  TfLiteFullyConnectedParams* params;
+  OpData* data;
+  const TfLiteTensor* input;
+  const TfLiteTensor* filter;
+  const TfLiteTensor* bias;
+  const int thread_start;
+  const int thread_end;
+  TfLiteTensor* input_quantized;
+  TfLiteTensor* scaling_factors;
+  TfLiteTensor* accum_scratch;
+  TfLiteTensor* row_sums;
+  TfLiteTensor* input_offsets;
+  TfLiteTensor* output;
+};
+
+TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
+                        TfLiteFullyConnectedParams* params, OpData* data,
+                        const TfLiteTensor* input, const TfLiteTensor* filter,
+                        const TfLiteTensor* bias, TfLiteTensor* input_quantized,
+                        TfLiteTensor* scaling_factors,
+                        TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
+                        TfLiteTensor* input_offsets, TfLiteTensor* output) {
+  const auto& output_shape = GetTensorShape(output);
+  CpuBackendContext* cpu_backend_context =
+      CpuBackendContext::GetFromContext(context);
+  const bool is_dense = filter->sparsity == nullptr;
+  if (is_dense) {
+    return EvalHybridDense(context, node, params, data, input, filter, bias,
+                           input_quantized, scaling_factors, accum_scratch,
+                           row_sums, input_offsets, output);
+  }
+
+  TfLiteTensor* filter_ledger = &context->tensors[node->temporaries->data[5]];
+  if (!data->ledger_initialized) {
+    PopulateLedgerData(filter->sparsity, context,
+                       GetTensorData<uint8_t>(filter_ledger));
+    data->ledger_initialized = true;
+  }
+
+  // The multi-threaded kernel slices the workload along the batch dimension.
+  // If there are not enough batches of data, the number of threads used is
+  // equal to the batch size.
+  // TODO(b/173442777): If needed, we can improve this later with slicing
+  // along the row dimension of the weight.
+  const int max_threads = cpu_backend_context->max_num_threads();
+  const int batches =
+      FlatSizeSkipDim(output_shape, output_shape.DimensionsCount() - 1);
+  const int thread_count = std::max(1, std::min(batches, max_threads));
+
+  std::vector<SparseHybridFullyConnectedTask> tasks;
+  tasks.reserve(thread_count);
+  int thread_start = 0;
+  for (int i = 0; i < thread_count; ++i) {
+    // This makes sure the workload is relatively balanced when batches is not
+    // a multiple of thread_count. The first mod(batches, thread_count) tasks
+    // need to process one more batch than the rest.
+    int thread_end = thread_start + batches / thread_count;
+    if (i < batches % thread_count) thread_end++;
+
+    tasks.emplace_back(context, node, params, data, input, filter, bias,
+                       thread_start, thread_end, input_quantized,
+                       scaling_factors, accum_scratch, row_sums, input_offsets,
+                       output);
+    thread_start = thread_end;
+  }
+  cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
+                                  cpu_backend_context);
+  return kTfLiteOk;
+}
+
 namespace {
 template <KernelType kernel_type>
 void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input,
diff --git a/tensorflow/lite/kernels/fully_connected_test.cc b/tensorflow/lite/kernels/fully_connected_test.cc
index ba480308870..0adc80a6dab 100644
--- a/tensorflow/lite/kernels/fully_connected_test.cc
+++ b/tensorflow/lite/kernels/fully_connected_test.cc
@@ -296,7 +296,8 @@ class HybridFullyConnectedOpModel : public SingleOpModel {
   HybridFullyConnectedOpModel(int units, int batches, const TensorData& input,
                               const TensorData& weights,
                               const TensorData& output = {TensorType_FLOAT32},
-                              bool asymmetric_inputs = false)
+                              bool asymmetric_inputs = false,
+                              int num_threads = 1)
       : batches_(batches), units_(units) {
     int total_input_size = 1;
     for (size_t i = 0; i < input.shape.size(); ++i) {
@@ -322,7 +323,9 @@ class HybridFullyConnectedOpModel : public SingleOpModel {
     resolver_ = absl::make_unique<SingleOpResolver>(
         BuiltinOperator_FULLY_CONNECTED,
         ops::builtin::Register_FULLY_CONNECTED_PIE());
-    BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)});
+    BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)},
+                     num_threads, /*allow_fp32_relax_to_fp16=*/false,
+                     /*apply_delegate=*/false);
   }
   void SetBias(const std::vector<float>& f) { PopulateTensor(bias_, f); }
   void SetWeights(const std::vector<float>& data) {
@@ -879,6 +882,44 @@ TEST(HybridFullyConnectedOpTest, SimpleTestQuantizedInt8) {
                      /*max_abs_error=*/1.3f)));
 }
 
+TEST(HybridFullyConnectedOpTest, SimpleTestQuantizedInt8MultiThreaded) {
+  for (int num_threads = 1; num_threads <= 4; ++num_threads) {
+    HybridFullyConnectedOpModel m(
+        /*units=*/3, /*batches=*/4,
+        /*input=*/{TensorType_FLOAT32, {4, 10}},
+        /*weights=*/
+        {TensorType_INT8, {3, 10}, 0, 0, 10.0 / 127.0, 0},
+        /*output=*/{TensorType_FLOAT32}, /*asymmetric_inputs=*/false,
+        /*num_threads=*/num_threads);  // Hybrid
+
+    m.SetSignedWeights({
+        1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
+        1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
+        1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
+    });
+    m.SetBias({1, 2, 3});
+
+    m.SetInput({
+        1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 0
+        1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 1
+        1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 2
+        1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 3
+    });
+
+    m.Invoke();
+
+    EXPECT_THAT(m.GetOutputShape(), ElementsAre(4, 3));
+    EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
+                                   {
+                                       24, 25, 26,  //
+                                       58, 59, 60,  //
+                                       24, 25, 26,  //
+                                       58, 59, 60,  //
+                                   },
+                                   /*max_abs_error=*/1.3f)));
+  }
+}
+
 TEST(HybridAsymmetricInputFullyConnectedOpTest, SimpleTestQuantizedUint8) {
   HybridFullyConnectedOpModel m(
       /*units=*/3, /*batches=*/2,
@@ -1413,6 +1454,76 @@ TEST_P(SparseFullyConnectedOpTest, SparseHybrid1x16Test) {
               ElementsAreArray(ArrayFloatNear(
                   {0, 7.4715, 85.8359, 0, 5.9655, 3.0520, 1.9480, 0}, 1e-3)));
 }
+
+TEST_P(SparseFullyConnectedOpTest, SparseHybrid1x16TestMultiThreaded) {
+  std::initializer_list<float> weight_data = {
+      /* 1st row */
+      1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
+      14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9,
+      10.1, 11.11, 12.12, 13.13, 14.14, 15.15, 16.16,
+      /* 2nd row */
+      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+      0.0, -1.1, -2.2, -3.3, -4.4, -5.5, -6.6, -7.7, -8.8, -9.9, -10.1, -11.11,
+      -12.12, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+      /* 3rd row */
+      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+      0.0, 1.1, -2.2, 3.3, -4.4, 5.5, -6.6, 7.7, -8.8, 9.9, -10.1, 11.11,
+      -12.12, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+      /* 4th row */
+      -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
+      -13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7,
+      8.8, -9.9, 10.1, -11.11, 12.12, 0.0, 0.0, 0.0, 0.0};
+  TensorData weight = {};
+  weight.type = TensorType_FLOAT32;
+  weight.shape = {4, 48};
+  weight.traversal_order = {0, 1, 2};
+  weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
+  weight.block_map = {1};
+  weight.block_size = {16};
+  for (int num_threads = 1; num_threads <= 4; ++num_threads) {
+    SparseFullyConnectedOpModel<float> m(
+        GetRegistration(),
+        /*units=*/4, /*batches=*/4,
+        /*input=*/{TensorType_FLOAT32, {4, 48}}, weight, weight_data,
+        /*num_threads=*/num_threads, /*symmetric_quantize_weights=*/true);
+    m.SetBias({1, 2, 3, 4});
+    m.SetInput({
+        1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
+        1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
+        1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
+        1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
+        1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,  // b = 0
+        2.5, 0.0, -2.1, 0.0, 3.0, 0.0, -1.3, 0.0, 1.3, 0.0,
+        -1.1, 0.0, 2.0, 0.0, -1.7, 0.0, 1.9, 0.0, -1.5, 0.0,
+        0.5, 0.0, -0.7, 0.0, 0.8, 0.0, -0.3, 0.0, 2.8, 0.0,
+        -2.8, 0.0, 1.1, -2.3, 1.9, -1.9, 2.1, -0.5, 2.4, -0.1,
+        1.0, -2.5, 0.7, -1.9, 0.2, 0.1, 0.2, 0.3,  // b = 1
+        1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
+        1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
+        1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
+        1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
+        1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,  // b = 2
+        2.5, 0.0, -2.1, 0.0, 3.0, 0.0, -1.3, 0.0, 1.3, 0.0,
+        -1.1, 0.0, 2.0, 0.0, -1.7, 0.0, 1.9, 0.0, -1.5, 0.0,
+        0.5, 0.0, -0.7, 0.0, 0.8, 0.0, -0.3, 0.0, 2.8, 0.0,
+        -2.8, 0.0, 1.1, -2.3, 1.9, -1.9, 2.1, -0.5, 2.4, -0.1,
+        1.0, -2.5, 0.7, -1.9, 0.2, 0.1, 0.2, 0.3,  // b = 3
+    });
+
+    m.Invoke();
+
+    EXPECT_THAT(m.GetOutputShape(), ElementsAre(4, 4));
+    EXPECT_THAT(m.GetOutput(),
+                ElementsAreArray(ArrayFloatNear(
+                    {0, 7.4715, 85.8359, 0, 5.9655, 3.0520, 1.9480, 0, 0,
+                     7.4715, 85.8359, 0, 5.9655, 3.0520, 1.9480, 0},
+                    1e-3)));
+  }
+}
 // TODO(b/148391360): Add tests for unsupported sparsity format.
 // TEST_P(SparseFullyConnectedOpTest, TestUnsupportedSparsityFormat)
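Note on the scheduling scheme: the new EvalHybrid splits the batch dimension across tasks so that the first (batches % thread_count) tasks each process one extra batch, keeping per-task slices within one batch of each other in size. Below is a minimal standalone sketch of that split rule; it is not part of the patch, and the values batches = 10 and max_threads = 4 are hypothetical examples.

#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
  // Hypothetical example values; in the kernel these come from the output
  // shape and cpu_backend_context->max_num_threads().
  const int batches = 10;
  const int max_threads = 4;
  const int thread_count = std::max(1, std::min(batches, max_threads));

  // Same split rule as the task-creation loop in EvalHybrid.
  std::vector<std::pair<int, int>> slices;  // [thread_start, thread_end) per task
  int thread_start = 0;
  for (int i = 0; i < thread_count; ++i) {
    int thread_end = thread_start + batches / thread_count;
    // The first (batches % thread_count) tasks take one extra batch.
    if (i < batches % thread_count) thread_end++;
    slices.emplace_back(thread_start, thread_end);
    thread_start = thread_end;
  }

  // Prints "[0,3) [3,6) [6,8) [8,10)": all 10 batches covered, slice sizes
  // differing by at most one.
  for (const auto& s : slices) std::printf("[%d,%d) ", s.first, s.second);
  std::printf("\n");
  return 0;
}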