Fix data race in sparse FC kernel: move the one-time filter-ledger initialization out of the multi-threaded EvalHybridImpl worker path and into the single-threaded EvalHybrid entry point, so concurrent threads no longer race on populating the ledger and setting data->ledger_initialized.

PiperOrigin-RevId: 358077976
Change-Id: I25c53bfd443a90c9db1b34300d794bcc96961c7d
This commit is contained in:
Yunlu Li 2021-02-17 18:21:12 -08:00 committed by TensorFlower Gardener
parent 425283d7f4
commit 630a18b20f

View File

@ -504,11 +504,6 @@ void EvalHybridImpl(TfLiteContext* context, TfLiteNode* node,
GetTensorData<int32_t>(accum_scratch) + thread_start * output_depth;
if (is_sparse) {
TfLiteTensor* filter_ledger = &context->tensors[node->temporaries->data[5]];
if (!data->ledger_initialized) {
PopulateLedgerData(filter->sparsity, context,
GetTensorData<uint8_t>(filter_ledger));
data->ledger_initialized = true;
}
tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
GetTensorData<int8_t>(filter), GetTensorData<uint8_t>(filter_ledger),
output_depth, input_depth, quant_data, scaling_factors_ptr, batch_size,
@ -592,6 +587,16 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
const auto& output_shape = GetTensorShape(output);
CpuBackendContext* cpu_backend_context =
CpuBackendContext::GetFromContext(context);
const bool is_sparse = filter->sparsity != nullptr;
if (is_sparse) {
TfLiteTensor* filter_ledger = &context->tensors[node->temporaries->data[5]];
if (!data->ledger_initialized) {
PopulateLedgerData(filter->sparsity, context,
GetTensorData<uint8_t>(filter_ledger));
data->ledger_initialized = true;
}
}
const int max_threads = cpu_backend_context->max_num_threads();
const int batches =
FlatSizeSkipDim(output_shape, output_shape.DimensionsCount() - 1);