From 630a18b20f90b34c73768feed67176b48ec9c254 Mon Sep 17 00:00:00 2001 From: Yunlu Li Date: Wed, 17 Feb 2021 18:21:12 -0800 Subject: [PATCH] Fix data race in sparse FC kernel. PiperOrigin-RevId: 358077976 Change-Id: I25c53bfd443a90c9db1b34300d794bcc96961c7d --- tensorflow/lite/kernels/fully_connected.cc | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/tensorflow/lite/kernels/fully_connected.cc b/tensorflow/lite/kernels/fully_connected.cc index 2a5041c88a7..12448cc6145 100644 --- a/tensorflow/lite/kernels/fully_connected.cc +++ b/tensorflow/lite/kernels/fully_connected.cc @@ -504,11 +504,6 @@ void EvalHybridImpl(TfLiteContext* context, TfLiteNode* node, GetTensorData(accum_scratch) + thread_start * output_depth; if (is_sparse) { TfLiteTensor* filter_ledger = &context->tensors[node->temporaries->data[5]]; - if (!data->ledger_initialized) { - PopulateLedgerData(filter->sparsity, context, - GetTensorData(filter_ledger)); - data->ledger_initialized = true; - } tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate( GetTensorData(filter), GetTensorData(filter_ledger), output_depth, input_depth, quant_data, scaling_factors_ptr, batch_size, @@ -592,6 +587,16 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, const auto& output_shape = GetTensorShape(output); CpuBackendContext* cpu_backend_context = CpuBackendContext::GetFromContext(context); + const bool is_sparse = filter->sparsity != nullptr; + if (is_sparse) { + TfLiteTensor* filter_ledger = &context->tensors[node->temporaries->data[5]]; + if (!data->ledger_initialized) { + PopulateLedgerData(filter->sparsity, context, + GetTensorData(filter_ledger)); + data->ledger_initialized = true; + } + } + const int max_threads = cpu_backend_context->max_num_threads(); const int batches = FlatSizeSkipDim(output_shape, output_shape.DimensionsCount() - 1);