Make default sparse FullyConnected kernel take the standard CSR format.

PiperOrigin-RevId: 297967339
Change-Id: I65e86496367e125c46729f828818cd53b4f91945
This commit is contained in:
Yunlu Li 2020-02-28 17:01:14 -08:00 committed by TensorFlower Gardener
parent 0d787d7756
commit f94c8b283a
4 changed files with 34 additions and 19 deletions

View File

@ -77,6 +77,16 @@ TEST(DensifyOpTest, Float) {
EXPECT_THAT(m.GetOutput(), ElementsAreArray(dense_values));
}
TEST(DensifyOpTest, Float3D) {
std::initializer_list<float> dense_values = {6, 0, 9, 8, 0, 0,
0, 0, 5, 0, 0, 7};
std::initializer_list<float> sparse_values = {6, 9, 8, 5, 7};
DensifyOpModel<float> m(TensorType_FLOAT32, {3, 2, 2}, dense_values);
m.Invoke();
EXPECT_THAT(m.GetInput(), ElementsAreArray(sparse_values));
EXPECT_THAT(m.GetOutput(), ElementsAreArray(dense_values));
}
TEST(DensifyOpTest, Int8) {
std::initializer_list<int8_t> dense_values = {6, 0, 9, 8, 0, 0,
0, 0, 5, 0, 0, 7};

View File

@ -42,7 +42,7 @@ namespace fully_connected {
namespace {
bool SupportedSparsityFormat(const TfLiteSparsity& sparsity) {
if (sparsity.dim_metadata[0].format == kTfLiteDimSparseCSR &&
if (sparsity.dim_metadata[0].format == kTfLiteDimDense &&
sparsity.dim_metadata[1].format == kTfLiteDimSparseCSR) {
return true;
}

View File

@ -40,8 +40,7 @@ inline void FullyConnectedSparseWeight(
const int output_depth = MatchingDim(weights_shape, weights_dims_count - 2,
output_shape, output_dims_count - 1);
const int accum_depth = weights_shape.Dims(weights_dims_count - 1);
const int* w0_segments = sparsity.dim_metadata[0].array_segments->data;
const int* w0_indices = sparsity.dim_metadata[0].array_indices->data;
const int w0_size = sparsity.dim_metadata[0].dense_size;
const int* w1_segments = sparsity.dim_metadata[1].array_segments->data;
const int* w1_indices = sparsity.dim_metadata[1].array_indices->data;
@ -50,9 +49,8 @@ inline void FullyConnectedSparseWeight(
}
for (int b = 0; b < batches; ++b) {
for (int pw0 = w0_segments[0]; pw0 < w0_segments[1]; ++pw0) {
int idx_0 = w0_indices[pw0];
for (int pw1 = w1_segments[pw0]; pw1 < w1_segments[pw0 + 1]; ++pw1) {
for (int idx_0 = 0; idx_0 < w0_size; ++idx_0) {
for (int pw1 = w1_segments[idx_0]; pw1 < w1_segments[idx_0 + 1]; ++pw1) {
int idx_1 = w1_indices[pw1];
output_data[b * output_depth + idx_0] +=
weights_data[pw1] * input_data[b * accum_depth + idx_1];

View File

@ -614,11 +614,12 @@ class SingleOpModel {
std::vector<int> traversal_order(dims_count);
std::vector<T> dense_data(data);
// Compress all dimensions and traverse them in the original order.
// Compress only the last dimension and traverse in the original order.
for (int i = 0; i < dims_count; i++) {
format[i] = kTfLiteDimSparseCSR;
format[i] = kTfLiteDimDense;
traversal_order[i] = i;
}
format[dims_count - 1] = kTfLiteDimSparseCSR;
tflite::optimize::sparsity::FormatConverter<T> converter(
shape, traversal_order, format);
@ -630,21 +631,27 @@ class SingleOpModel {
// Build sparsity parameter.
std::vector<flatbuffers::Offset<DimensionMetadata>> fb_dim_metadata(
dims_count);
for (int i = 0; i < dims_count; i++) {
for (int i = 0; i < dims_count - 1; i++) {
const int metadata_idx = 2 * i;
auto array_segments =
CreateInt32Vector(builder_,
builder_.CreateVector(dim_metadata[metadata_idx]))
.Union();
auto array_indices =
CreateInt32Vector(
builder_, builder_.CreateVector(dim_metadata[metadata_idx + 1]))
.Union();
fb_dim_metadata[i] = CreateDimensionMetadata(
builder_, DimensionType_SPARSE_CSR, 0, SparseIndexVector_Int32Vector,
array_segments, SparseIndexVector_Int32Vector, array_indices);
builder_, DimensionType_DENSE, dim_metadata[metadata_idx][0]);
}
// Parameters for the last compressed dimension.
const int compressed_metadata_idx = 2 * (dims_count - 1);
auto array_segments =
CreateInt32Vector(builder_, builder_.CreateVector(
dim_metadata[compressed_metadata_idx]))
.Union();
auto array_indices =
CreateInt32Vector(
builder_,
builder_.CreateVector(dim_metadata[compressed_metadata_idx + 1]))
.Union();
fb_dim_metadata[dims_count - 1] = CreateDimensionMetadata(
builder_, DimensionType_SPARSE_CSR, 0, SparseIndexVector_Int32Vector,
array_segments, SparseIndexVector_Int32Vector, array_indices);
flatbuffers::Offset<SparsityParameters> s_param = CreateSparsityParameters(
builder_, builder_.CreateVector(traversal_order), 0,
builder_.CreateVector(fb_dim_metadata));