Make default sparse FullyConnected kernel take the standard CSR format.
PiperOrigin-RevId: 297967339 Change-Id: I65e86496367e125c46729f828818cd53b4f91945
This commit is contained in:
parent
0d787d7756
commit
f94c8b283a
@ -77,6 +77,16 @@ TEST(DensifyOpTest, Float) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray(dense_values));
|
||||
}
|
||||
|
||||
TEST(DensifyOpTest, Float3D) {
|
||||
std::initializer_list<float> dense_values = {6, 0, 9, 8, 0, 0,
|
||||
0, 0, 5, 0, 0, 7};
|
||||
std::initializer_list<float> sparse_values = {6, 9, 8, 5, 7};
|
||||
DensifyOpModel<float> m(TensorType_FLOAT32, {3, 2, 2}, dense_values);
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetInput(), ElementsAreArray(sparse_values));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray(dense_values));
|
||||
}
|
||||
|
||||
TEST(DensifyOpTest, Int8) {
|
||||
std::initializer_list<int8_t> dense_values = {6, 0, 9, 8, 0, 0,
|
||||
0, 0, 5, 0, 0, 7};
|
||||
|
@ -42,7 +42,7 @@ namespace fully_connected {
|
||||
|
||||
namespace {
|
||||
bool SupportedSparsityFormat(const TfLiteSparsity& sparsity) {
|
||||
if (sparsity.dim_metadata[0].format == kTfLiteDimSparseCSR &&
|
||||
if (sparsity.dim_metadata[0].format == kTfLiteDimDense &&
|
||||
sparsity.dim_metadata[1].format == kTfLiteDimSparseCSR) {
|
||||
return true;
|
||||
}
|
||||
|
@ -40,8 +40,7 @@ inline void FullyConnectedSparseWeight(
|
||||
const int output_depth = MatchingDim(weights_shape, weights_dims_count - 2,
|
||||
output_shape, output_dims_count - 1);
|
||||
const int accum_depth = weights_shape.Dims(weights_dims_count - 1);
|
||||
const int* w0_segments = sparsity.dim_metadata[0].array_segments->data;
|
||||
const int* w0_indices = sparsity.dim_metadata[0].array_indices->data;
|
||||
const int w0_size = sparsity.dim_metadata[0].dense_size;
|
||||
const int* w1_segments = sparsity.dim_metadata[1].array_segments->data;
|
||||
const int* w1_indices = sparsity.dim_metadata[1].array_indices->data;
|
||||
|
||||
@ -50,9 +49,8 @@ inline void FullyConnectedSparseWeight(
|
||||
}
|
||||
|
||||
for (int b = 0; b < batches; ++b) {
|
||||
for (int pw0 = w0_segments[0]; pw0 < w0_segments[1]; ++pw0) {
|
||||
int idx_0 = w0_indices[pw0];
|
||||
for (int pw1 = w1_segments[pw0]; pw1 < w1_segments[pw0 + 1]; ++pw1) {
|
||||
for (int idx_0 = 0; idx_0 < w0_size; ++idx_0) {
|
||||
for (int pw1 = w1_segments[idx_0]; pw1 < w1_segments[idx_0 + 1]; ++pw1) {
|
||||
int idx_1 = w1_indices[pw1];
|
||||
output_data[b * output_depth + idx_0] +=
|
||||
weights_data[pw1] * input_data[b * accum_depth + idx_1];
|
||||
|
@ -614,11 +614,12 @@ class SingleOpModel {
|
||||
std::vector<int> traversal_order(dims_count);
|
||||
std::vector<T> dense_data(data);
|
||||
|
||||
// Compress all dimensions and traverse them in the original order.
|
||||
// Compress only the last dimension and traverse in the original order.
|
||||
for (int i = 0; i < dims_count; i++) {
|
||||
format[i] = kTfLiteDimSparseCSR;
|
||||
format[i] = kTfLiteDimDense;
|
||||
traversal_order[i] = i;
|
||||
}
|
||||
format[dims_count - 1] = kTfLiteDimSparseCSR;
|
||||
|
||||
tflite::optimize::sparsity::FormatConverter<T> converter(
|
||||
shape, traversal_order, format);
|
||||
@ -630,21 +631,27 @@ class SingleOpModel {
|
||||
// Build sparsity parameter.
|
||||
std::vector<flatbuffers::Offset<DimensionMetadata>> fb_dim_metadata(
|
||||
dims_count);
|
||||
for (int i = 0; i < dims_count; i++) {
|
||||
for (int i = 0; i < dims_count - 1; i++) {
|
||||
const int metadata_idx = 2 * i;
|
||||
auto array_segments =
|
||||
CreateInt32Vector(builder_,
|
||||
builder_.CreateVector(dim_metadata[metadata_idx]))
|
||||
.Union();
|
||||
auto array_indices =
|
||||
CreateInt32Vector(
|
||||
builder_, builder_.CreateVector(dim_metadata[metadata_idx + 1]))
|
||||
.Union();
|
||||
fb_dim_metadata[i] = CreateDimensionMetadata(
|
||||
builder_, DimensionType_SPARSE_CSR, 0, SparseIndexVector_Int32Vector,
|
||||
array_segments, SparseIndexVector_Int32Vector, array_indices);
|
||||
builder_, DimensionType_DENSE, dim_metadata[metadata_idx][0]);
|
||||
}
|
||||
|
||||
// Parameters for the last compressed dimension.
|
||||
const int compressed_metadata_idx = 2 * (dims_count - 1);
|
||||
auto array_segments =
|
||||
CreateInt32Vector(builder_, builder_.CreateVector(
|
||||
dim_metadata[compressed_metadata_idx]))
|
||||
.Union();
|
||||
auto array_indices =
|
||||
CreateInt32Vector(
|
||||
builder_,
|
||||
builder_.CreateVector(dim_metadata[compressed_metadata_idx + 1]))
|
||||
.Union();
|
||||
fb_dim_metadata[dims_count - 1] = CreateDimensionMetadata(
|
||||
builder_, DimensionType_SPARSE_CSR, 0, SparseIndexVector_Int32Vector,
|
||||
array_segments, SparseIndexVector_Int32Vector, array_indices);
|
||||
|
||||
flatbuffers::Offset<SparsityParameters> s_param = CreateSparsityParameters(
|
||||
builder_, builder_.CreateVector(traversal_order), 0,
|
||||
builder_.CreateVector(fb_dim_metadata));
|
||||
|
Loading…
Reference in New Issue
Block a user