Add sparse hybrid FullyConnected kernel.

PiperOrigin-RevId: 351611259
Change-Id: I0bb1bfbb04b8279baf81fb872df93fcfe2ef8cd6
Yunlu Li 2021-01-13 10:10:18 -08:00 committed by TensorFlower Gardener
parent caae51a72c
commit d4d06d09d1
2 changed files with 143 additions and 11 deletions


@@ -53,6 +53,47 @@ bool SupportedSparsityFormat(const TfLiteSparsity& sparsity) {
static const int kDimMetadataSizeRandomSparse = 2;
static const int kDimMetadataSizeBlockSparse = 3;
TfLiteStatus CreateLedgerTensor(const TfLiteSparsity* sparsity,
                                TfLiteContext* context, TfLiteTensor* ledger) {
  TF_LITE_ENSURE(context, sparsity != nullptr);
  ledger->type = kTfLiteUInt8;
  ledger->allocation_type = kTfLiteArenaRwPersistent;
  TfLiteIntArray* ledger_size = TfLiteIntArrayCreate(1);
  ledger_size->data[0] = sparsity->dim_metadata[1].array_indices->size +
                         sparsity->dim_metadata[1].array_segments->size - 1;
  return context->ResizeTensor(context, ledger, ledger_size);
}

TfLiteStatus PopulateLedgerData(const TfLiteSparsity* sparsity,
                                TfLiteContext* context, uint8_t* ledger_data) {
  TF_LITE_ENSURE(context, sparsity != nullptr);
  const auto* array_segments = sparsity->dim_metadata[1].array_segments;
  const auto* array_indices = sparsity->dim_metadata[1].array_indices;
  int output_data_ptr = 0;
  for (int i = 0; i < array_segments->size - 1; i++) {
    int row_start = array_segments->data[i];
    int row_end = array_segments->data[i + 1];
    if (row_end - row_start > UINT8_MAX) {
      return kTfLiteError;
    }
    // Copy num of non-zero blocks in row i.
    ledger_data[output_data_ptr] = static_cast<uint8_t>(row_end - row_start);
    output_data_ptr++;
    for (int j = row_start; j < row_end; j++) {
      if (array_indices->data[j] > UINT8_MAX) {
        return kTfLiteError;
      }
      // Copy indices of non-zero blocks in row i.
      ledger_data[output_data_ptr] =
          static_cast<uint8_t>(array_indices->data[j]);
      output_data_ptr++;
    }
  }
  return kTfLiteOk;
}

}  // namespace
// This file has four implementations of FullyConnected
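As context for the helpers above: the ledger is a flat byte stream that stores, for each block row of the compressed filter, one byte holding the count of non-zero blocks followed by one byte per non-zero block-column index. A minimal standalone sketch of the same encoding (hypothetical values, plain std::vector instead of the TfLite sparsity types):

#include <cstdint>
#include <vector>

// Hypothetical standalone version of the encoding performed by
// PopulateLedgerData: one count byte per block row, then the block-column
// indices of that row.
std::vector<uint8_t> BuildLedgerSketch(const std::vector<int>& array_segments,
                                       const std::vector<int>& array_indices) {
  std::vector<uint8_t> ledger;
  for (size_t i = 0; i + 1 < array_segments.size(); ++i) {
    const int row_start = array_segments[i];
    const int row_end = array_segments[i + 1];
    ledger.push_back(static_cast<uint8_t>(row_end - row_start));  // block count
    for (int j = row_start; j < row_end; ++j) {
      ledger.push_back(static_cast<uint8_t>(array_indices[j]));  // block column
    }
  }
  return ledger;
}

// Example: segments {0, 3, 4} with indices {0, 2, 5, 1} describe two block
// rows and produce the ledger {3, 0, 2, 5, 1, 1}; its length equals
// indices.size() + segments.size() - 1, matching CreateLedgerTensor above.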
@@ -74,6 +115,8 @@ struct OpData {
  // The index of the temporary tensor where the quantized inputs are cached.
  int scratch_tensor_index;
  bool compute_row_sums = false;
  // Only used for sparse hybrid fully connected kernels.
  bool ledger_initialized;
};
constexpr int kInputTensor = 0;
@@ -134,7 +177,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  // Instead, we allocate a new object to carry information from Prepare() to
  // Eval().
  auto* op_data = new OpData();
  context->AddTensors(context, /*tensors_to_add=*/5,
  context->AddTensors(context, /*tensors_to_add=*/6,
                      &op_data->scratch_tensor_index);
  return op_data;
}
@@ -212,11 +255,18 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
  // buffer to store the intermediate quantized values.
  // Additionally, we allocate a temporary buffer to store the accumulated
  // quantized values prior to multiplication by the scaling factor.
  if (input->type == kTfLiteFloat32 &&
      (filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8)) {
  const bool is_hybrid =
      (input->type == kTfLiteFloat32 &&
       (filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8));
  const bool is_sparse = filter->sparsity != nullptr;
  if (is_hybrid) {
    TfLiteIntArrayFree(node->temporaries);
    data->compute_row_sums = true;
    node->temporaries = TfLiteIntArrayCreate(5);
    if (is_sparse) {
      node->temporaries = TfLiteIntArrayCreate(6);
    } else {
      node->temporaries = TfLiteIntArrayCreate(5);
    }
    node->temporaries->data[0] = data->scratch_tensor_index;
    TfLiteTensor* input_quantized;
@@ -285,6 +335,16 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, row_sums, row_sums_size));
    }
    if (is_sparse) {
      data->ledger_initialized = false;
      node->temporaries->data[5] = data->scratch_tensor_index + 5;
      TfLiteTensor* filter_ledger =
          &context->tensors[node->temporaries->data[5]];
      auto status =
          CreateLedgerTensor(filter->sparsity, context, filter_ledger);
      if (status != kTfLiteOk) return status;
    }
  }

  // Resize output.
@@ -386,6 +446,7 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
  const int input_size = filter->dims->data[1];
  const int batch_size = total_input_size / filter->dims->data[1];
  const int num_units = filter->dims->data[0];
  const bool is_sparse = filter->sparsity != nullptr;
  // Output = bias if bias tensor exists.
  if (bias) {
@@ -426,11 +487,24 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
  // Compute output += weight * quantized_input
  int32_t* scratch = GetTensorData<int32_t>(accum_scratch);
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      filter_data, num_units, input_size, quant_data, scaling_factors_ptr,
      batch_size, GetTensorData<float>(output), /*per_channel_scale=*/nullptr,
      input_offset_ptr, scratch, row_sums_ptr, &data->compute_row_sums,
      CpuBackendContext::GetFromContext(context));
  if (is_sparse) {
    TfLiteTensor* filter_ledger =
        &context->tensors[node->temporaries->data[5]];
    if (!data->ledger_initialized) {
      PopulateLedgerData(filter->sparsity, context,
                         GetTensorData<uint8_t>(filter_ledger));
      data->ledger_initialized = true;
    }
    tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
        GetTensorData<int8_t>(filter), GetTensorData<uint8_t>(filter_ledger),
        num_units, input_size, quant_data, scaling_factors_ptr, batch_size,
        GetTensorData<float>(output));
  } else {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        filter_data, num_units, input_size, quant_data, scaling_factors_ptr,
        batch_size, GetTensorData<float>(output), /*per_channel_scale=*/nullptr,
        input_offset_ptr, scratch, row_sums_ptr, &data->compute_row_sums,
        CpuBackendContext::GetFromContext(context));
  }

  // Apply activation function to floats.
  tensor_utils::ApplyActivationToVector(

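The optimized tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate routine itself is outside this diff. As a rough mental model only — a sketch assuming the 1x16 block-sparse layout exercised by the test below, not the library's actual implementation — the ledger drives the hybrid accumulation roughly like this:

#include <cstdint>

constexpr int kBlockSize = 16;  // assumed 1x16 block sparsity, as in the test

// Hypothetical reference sketch:
//   output[b][r] += scaling_factors[b] *
//                   sum over non-zero blocks of (filter block . input block)
void SparseHybridMatmulSketch(const int8_t* filter, const uint8_t* ledger,
                              int num_units, int input_size,
                              const int8_t* quant_input,
                              const float* scaling_factors, int batch_size,
                              float* output) {
  for (int b = 0; b < batch_size; ++b) {
    const int8_t* input = quant_input + b * input_size;
    const uint8_t* ledger_ptr = ledger;  // rescan the ledger for each batch
    const int8_t* block_ptr = filter;    // compressed filter: non-zero blocks only
    for (int r = 0; r < num_units; ++r) {
      int32_t acc = 0;
      const int num_blocks = *ledger_ptr++;  // non-zero blocks in row r
      for (int i = 0; i < num_blocks; ++i) {
        const int block_col = *ledger_ptr++;  // block-column index of this block
        for (int c = 0; c < kBlockSize; ++c) {
          acc += block_ptr[c] * input[block_col * kBlockSize + c];
        }
        block_ptr += kBlockSize;
      }
      output[b * num_units + r] += acc * scaling_factors[b];
    }
  }
}

Each batch rescans the ledger and the compressed filter, multiplies each stored 16-wide block against the matching slice of the quantized input, and scales the int32 accumulator by the per-batch scaling factor before adding it into the float output.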

@@ -1145,7 +1145,8 @@ class SparseFullyConnectedOpModel : public SingleOpModel {
                              int batches, const TensorData& input,
                              const TensorData& weights,
                              const std::vector<T>& weights_data,
                              int num_threads = 1)
                              int num_threads = 1,
                              bool symmetric_quantize_weights = false)
      : batches_(batches), units_(units) {
    int total_input_size = 1;
    for (size_t i = 0; i < input.shape.size(); ++i) {
@@ -1154,7 +1155,8 @@ class SparseFullyConnectedOpModel : public SingleOpModel {
    input_size_ = total_input_size / batches_;
    input_ = AddInput(input);
    weights_ = AddConstSparseInput(weights, weights_data);
    weights_ =
        AddConstSparseInput(weights, weights_data, symmetric_quantize_weights);

    TensorData bias{input.type, {units_}};
    bias_ = AddInput(bias);
@@ -1355,6 +1357,62 @@ TEST_P(SparseFullyConnectedOpTest, Simple1x4TestMultiThreadedMoreBatches) {
));
}
}
TEST_P(SparseFullyConnectedOpTest, SparseHybrid1x16Test) {
  std::initializer_list<float> weight_data = {
      /* 1st row */
      1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
      14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9,
      10.1, 11.11, 12.12, 13.13, 14.14, 15.15, 16.16,
      /* 2nd row */
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, -1.1, -2.2, -3.3, -4.4, -5.5, -6.6, -7.7, -8.8, -9.9, -10.1, -11.11,
      -12.12, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      /* 3rd row */
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 1.1, -2.2, 3.3, -4.4, 5.5, -6.6, 7.7, -8.8, 9.9, -10.1, 11.11,
      -12.12, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      /* 4th row */
      -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
      -13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7,
      8.8, -9.9, 10.1, -11.11, 12.12, 0.0, 0.0, 0.0, 0.0};
  TensorData weight = {};
  weight.type = TensorType_FLOAT32;
  weight.shape = {4, 48};
  weight.traversal_order = {0, 1, 2};
  weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
  weight.block_map = {1};
  weight.block_size = {16};
  SparseFullyConnectedOpModel<float> m(
      GetRegistration(),
      /*units=*/4, /*batches=*/2,
      /*input=*/{TensorType_FLOAT32, {2, 48}}, weight, weight_data,
      /*num_threads=*/1, /*symmetric_quantize_weights=*/true);
  m.SetBias({1, 2, 3, 4});
  m.SetInput({
      1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
      1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
      1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
      1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
      1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,  // b = 0
      2.5, 0.0, -2.1, 0.0, 3.0, 0.0, -1.3, 0.0, 1.3, 0.0,
      -1.1, 0.0, 2.0, 0.0, -1.7, 0.0, 1.9, 0.0, -1.5, 0.0,
      0.5, 0.0, -0.7, 0.0, 0.8, 0.0, -0.3, 0.0, 2.8, 0.0,
      -2.8, 0.0, 1.1, -2.3, 1.9, -1.9, 2.1, -0.5, 2.4, -0.1,
      1.0, -2.5, 0.7, -1.9, 0.2, 0.1, 0.2, 0.3,  // b = 1
  });
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 4));
  EXPECT_THAT(m.GetOutput(),
              ElementsAreArray(ArrayFloatNear(
                  {0, 7.4715, 85.8359, 0, 5.9655, 3.0520, 1.9480, 0}, 1e-3)));
}
// TODO(b/148391360): Add tests for unsupported sparsity format.
// TEST_P(SparseFullyConnectedOpTest, TestUnsupportedSparsityFormat)