Add sparse hybrid FullyConnected kernel.
PiperOrigin-RevId: 351611259
Change-Id: I0bb1bfbb04b8279baf81fb872df93fcfe2ef8cd6
parent caae51a72c
commit d4d06d09d1
@@ -53,6 +53,47 @@ bool SupportedSparsityFormat(const TfLiteSparsity& sparsity) {
 static const int kDimMetadataSizeRandomSparse = 2;
 static const int kDimMetadataSizeBlockSparse = 3;
 
+TfLiteStatus CreateLedgerTensor(const TfLiteSparsity* sparsity,
+                                TfLiteContext* context, TfLiteTensor* ledger) {
+  TF_LITE_ENSURE(context, sparsity != nullptr);
+  ledger->type = kTfLiteUInt8;
+  ledger->allocation_type = kTfLiteArenaRwPersistent;
+  TfLiteIntArray* ledger_size = TfLiteIntArrayCreate(1);
+  ledger_size->data[0] = sparsity->dim_metadata[1].array_indices->size +
+                         sparsity->dim_metadata[1].array_segments->size - 1;
+  return context->ResizeTensor(context, ledger, ledger_size);
+}
+
+TfLiteStatus PopulateLedgerData(const TfLiteSparsity* sparsity,
+                                TfLiteContext* context, uint8_t* ledger_data) {
+  TF_LITE_ENSURE(context, sparsity != nullptr);
+  const auto* array_segments = sparsity->dim_metadata[1].array_segments;
+  const auto* array_indices = sparsity->dim_metadata[1].array_indices;
+  int output_data_ptr = 0;
+
+  for (int i = 0; i < array_segments->size - 1; i++) {
+    int row_start = array_segments->data[i];
+    int row_end = array_segments->data[i + 1];
+    if (row_end - row_start > UINT8_MAX) {
+      return kTfLiteError;
+    }
+    // Copy num of non-zero blocks in row i.
+    ledger_data[output_data_ptr] = static_cast<uint8_t>(row_end - row_start);
+    output_data_ptr++;
+
+    for (int j = row_start; j < row_end; j++) {
+      if (array_indices->data[j] > UINT8_MAX) {
+        return kTfLiteError;
+      }
+      // Copy indices of non-zero blocks in row i.
+      ledger_data[output_data_ptr] =
+          static_cast<uint8_t>(array_indices->data[j]);
+      output_data_ptr++;
+    }
+  }
+  return kTfLiteOk;
+}
+
 }  // namespace
 
 // This file has four implementations of FullyConnected
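Review note: the ledger is a compact uint8 encoding of the CSR dimension metadata — for each output row, one byte holding the number of non-zero blocks, followed by one byte per non-zero block-column index. A minimal standalone sketch of the same encoding (EncodeLedger is a hypothetical helper for illustration, not part of this change):

#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical standalone mirror of PopulateLedgerData: encodes CSR row
// segments and block-column indices into a [count, idx0, idx1, ...] stream.
std::vector<uint8_t> EncodeLedger(const std::vector<int>& segments,
                                  const std::vector<int>& indices) {
  std::vector<uint8_t> ledger;
  for (size_t i = 0; i + 1 < segments.size(); ++i) {
    int row_start = segments[i];
    int row_end = segments[i + 1];
    // Number of non-zero blocks in row i (must fit in uint8_t).
    ledger.push_back(static_cast<uint8_t>(row_end - row_start));
    for (int j = row_start; j < row_end; ++j) {
      // Block-column index of each non-zero block in row i.
      ledger.push_back(static_cast<uint8_t>(indices[j]));
    }
  }
  return ledger;
}

int main() {
  // 3 rows: row 0 has blocks at columns 0 and 2, row 1 none, row 2 one at 1.
  std::vector<int> segments = {0, 2, 2, 3};
  std::vector<int> indices = {0, 2, 1};
  for (uint8_t b : EncodeLedger(segments, indices)) printf("%d ", b);
  printf("\n");  // Prints: 2 0 2 0 1 1
  return 0;
}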
@@ -74,6 +115,8 @@ struct OpData {
   // The index of the temporary tensor where the quantized inputs are cached.
   int scratch_tensor_index;
   bool compute_row_sums = false;
+  // Only used for sparse hybrid fully connected kernels.
+  bool ledger_initialized;
 };
 
 constexpr int kInputTensor = 0;
@@ -134,7 +177,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   // Instead, we allocate a new object to carry information from Prepare() to
   // Eval().
   auto* op_data = new OpData();
-  context->AddTensors(context, /*tensors_to_add=*/5,
+  context->AddTensors(context, /*tensors_to_add=*/6,
                       &op_data->scratch_tensor_index);
   return op_data;
 }
@@ -212,11 +255,18 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
   // buffer to store the intermediate quantized values.
   // Additionally, we allocate a temporary buffer to store the accumulated
   // quantized values prior to multiplication by the scaling factor.
-  if (input->type == kTfLiteFloat32 &&
-      (filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8)) {
+  const bool is_hybrid =
+      (input->type == kTfLiteFloat32 &&
+       (filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8));
+  const bool is_sparse = filter->sparsity != nullptr;
+  if (is_hybrid) {
     TfLiteIntArrayFree(node->temporaries);
     data->compute_row_sums = true;
-    node->temporaries = TfLiteIntArrayCreate(5);
+    if (is_sparse) {
+      node->temporaries = TfLiteIntArrayCreate(6);
+    } else {
+      node->temporaries = TfLiteIntArrayCreate(5);
+    }
     node->temporaries->data[0] = data->scratch_tensor_index;
 
     TfLiteTensor* input_quantized;
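Review note: the dense hybrid path keeps five temporaries and the sparse hybrid path adds a sixth for the ledger, which is why Init now reserves six tensor slots. A sketch of the slot layout (the labels below are descriptive assumptions drawn from the surrounding kernel code, not identifiers introduced by this diff):

// Hypothetical labels for node->temporaries slots in the hybrid path.
// Slots 0-4 exist for every hybrid FullyConnected; slot 5 only when the
// filter is sparse.
enum HybridTemporary {
  kInputQuantized = 0,  // int8 copy of the float input
  kScalingFactors = 1,  // per-batch quantization scales
  kAccumScratch = 2,    // int32 accumulator buffer
  kRowSums = 3,         // cached filter row sums (asymmetric inputs)
  kInputOffsets = 4,    // per-batch input zero points
  kFilterLedger = 5,    // sparse-only: uint8 CSR ledger
};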
@@ -285,6 +335,16 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
       TF_LITE_ENSURE_OK(
           context, context->ResizeTensor(context, row_sums, row_sums_size));
     }
+
+    if (is_sparse) {
+      data->ledger_initialized = false;
+      node->temporaries->data[5] = data->scratch_tensor_index + 5;
+      TfLiteTensor* filter_ledger =
+          &context->tensors[node->temporaries->data[5]];
+      auto status =
+          CreateLedgerTensor(filter->sparsity, context, filter_ledger);
+      if (status != kTfLiteOk) return status;
+    }
   }
 
   // Resize output.
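Review note: CreateLedgerTensor's size expression follows directly from the encoding — one count byte per row (array_segments->size - 1 rows) plus one index byte per non-zero block (array_indices->size). A quick worked check under assumed sizes:

#include <cassert>

int main() {
  // Assumed example: a 4-row block-sparse filter with 6 non-zero blocks,
  // i.e. array_segments->size == 5 and array_indices->size == 6.
  const int num_segments = 5, num_indices = 6;
  // CreateLedgerTensor's rule: 6 index bytes + (5 - 1) per-row count bytes.
  const int ledger_bytes = num_indices + num_segments - 1;
  assert(ledger_bytes == 10);
  return 0;
}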
@@ -386,6 +446,7 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
   const int input_size = filter->dims->data[1];
   const int batch_size = total_input_size / filter->dims->data[1];
   const int num_units = filter->dims->data[0];
+  const bool is_sparse = filter->sparsity != nullptr;
 
   // Output = bias if bias tensor exists.
   if (bias) {
@@ -426,11 +487,24 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
 
   // Compute output += weight * quantized_input
   int32_t* scratch = GetTensorData<int32_t>(accum_scratch);
-  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-      filter_data, num_units, input_size, quant_data, scaling_factors_ptr,
-      batch_size, GetTensorData<float>(output), /*per_channel_scale=*/nullptr,
-      input_offset_ptr, scratch, row_sums_ptr, &data->compute_row_sums,
-      CpuBackendContext::GetFromContext(context));
+  if (is_sparse) {
+    TfLiteTensor* filter_ledger = &context->tensors[node->temporaries->data[5]];
+    if (!data->ledger_initialized) {
+      PopulateLedgerData(filter->sparsity, context,
+                         GetTensorData<uint8_t>(filter_ledger));
+      data->ledger_initialized = true;
+    }
+    tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
+        GetTensorData<int8_t>(filter), GetTensorData<uint8_t>(filter_ledger),
+        num_units, input_size, quant_data, scaling_factors_ptr, batch_size,
+        GetTensorData<float>(output));
+  } else {
+    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+        filter_data, num_units, input_size, quant_data, scaling_factors_ptr,
+        batch_size, GetTensorData<float>(output), /*per_channel_scale=*/nullptr,
+        input_offset_ptr, scratch, row_sums_ptr, &data->compute_row_sums,
+        CpuBackendContext::GetFromContext(context));
+  }
 
   // Apply activation function to floats.
   tensor_utils::ApplyActivationToVector(
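Review note: SparseMatrixBatchVectorMultiplyAccumulate lives in tensor_utils and is not part of this diff. A rough scalar model of how a ledger-driven kernel can walk the packed non-zero weights (simplified to block size 1 and a single batch; the real routine is vectorized and block-aware):

#include <cstdint>

// Rough scalar sketch of a ledger-driven sparse mat*vec accumulate. `matrix`
// holds only the non-zero weights, packed row by row in ledger order.
void SparseMatVecAccumulate(const int8_t* matrix, const uint8_t* ledger,
                            int num_rows, const int8_t* vector,
                            float scaling_factor, float* output) {
  int value_index = 0;   // cursor into the packed non-zero weights
  int ledger_index = 0;  // cursor into the [count, idx0, idx1, ...] stream
  for (int row = 0; row < num_rows; ++row) {
    int32_t acc = 0;
    const int num_nonzero = ledger[ledger_index++];  // count byte for this row
    for (int k = 0; k < num_nonzero; ++k) {
      const int col = ledger[ledger_index++];  // column of a non-zero weight
      acc += static_cast<int32_t>(matrix[value_index++]) * vector[col];
    }
    // Scale the integer accumulator back to float, as the hybrid path does.
    output[row] += acc * scaling_factor;
  }
}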
@@ -1145,7 +1145,8 @@ class SparseFullyConnectedOpModel : public SingleOpModel {
                               int batches, const TensorData& input,
                               const TensorData& weights,
                               const std::vector<T>& weights_data,
-                              int num_threads = 1)
+                              int num_threads = 1,
+                              bool symmetric_quantize_weights = false)
       : batches_(batches), units_(units) {
     int total_input_size = 1;
     for (size_t i = 0; i < input.shape.size(); ++i) {
@@ -1154,7 +1155,8 @@ class SparseFullyConnectedOpModel : public SingleOpModel {
     input_size_ = total_input_size / batches_;
 
     input_ = AddInput(input);
-    weights_ = AddConstSparseInput(weights, weights_data);
+    weights_ =
+        AddConstSparseInput(weights, weights_data, symmetric_quantize_weights);
 
     TensorData bias{input.type, {units_}};
     bias_ = AddInput(bias);
@@ -1355,6 +1357,62 @@ TEST_P(SparseFullyConnectedOpTest, Simple1x4TestMultiThreadedMoreBatches) {
   ));
   }
 }
 
+TEST_P(SparseFullyConnectedOpTest, SparseHybrid1x16Test) {
+  std::initializer_list<float> weight_data = {
+      /* 1st row */
+      1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
+      14.14, 15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9,
+      10.1, 11.11, 12.12, 13.13, 14.14, 15.15, 16.16,
+      /* 2nd row */
+      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+      0.0, -1.1, -2.2, -3.3, -4.4, -5.5, -6.6, -7.7, -8.8, -9.9, -10.1, -11.11,
+      -12.12, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+      /* 3rd row */
+      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+      0.0, 1.1, -2.2, 3.3, -4.4, 5.5, -6.6, 7.7, -8.8, 9.9, -10.1, 11.11,
+      -12.12, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+      /* 4th row */
+      -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
+      -13.13, 14.14, -15.15, 16.16, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+      0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7,
+      8.8, -9.9, 10.1, -11.11, 12.12, 0.0, 0.0, 0.0, 0.0};
+  TensorData weight = {};
+  weight.type = TensorType_FLOAT32;
+  weight.shape = {4, 48};
+  weight.traversal_order = {0, 1, 2};
+  weight.format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
+  weight.block_map = {1};
+  weight.block_size = {16};
+  SparseFullyConnectedOpModel<float> m(
+      GetRegistration(),
+      /*units=*/4, /*batches=*/2,
+      /*input=*/{TensorType_FLOAT32, {2, 48}}, weight, weight_data,
+      /*num_threads=*/1, /*symmetric_quantize_weights=*/true);
+  m.SetBias({1, 2, 3, 4});
+  m.SetInput({
+      1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
+      1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
+      1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
+      1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,
+      1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0,  // b = 0
+      2.5, 0.0, -2.1, 0.0, 3.0, 0.0, -1.3, 0.0, 1.3, 0.0,
+      -1.1, 0.0, 2.0, 0.0, -1.7, 0.0, 1.9, 0.0, -1.5, 0.0,
+      0.5, 0.0, -0.7, 0.0, 0.8, 0.0, -0.3, 0.0, 2.8, 0.0,
+      -2.8, 0.0, 1.1, -2.3, 1.9, -1.9, 2.1, -0.5, 2.4, -0.1,
+      1.0, -2.5, 0.7, -1.9, 0.2, 0.1, 0.2, 0.3,  // b = 1
+  });
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 4));
+  EXPECT_THAT(m.GetOutput(),
+              ElementsAreArray(ArrayFloatNear(
+                  {0, 7.4715, 85.8359, 0, 5.9655, 3.0520, 1.9480, 0}, 1e-3)));
+}
 // TODO(b/148391360): Add tests for unsupported sparsity format.
 // TEST_P(SparseFullyConnectedOpTest, TestUnsupportedSparsityFormat)