From 8a9c9b6af71303fb695bcf5ea216fa5bb8d96c50 Mon Sep 17 00:00:00 2001
From: Yunlu Li <yunluli@google.com>
Date: Wed, 5 Feb 2020 14:21:41 -0800
Subject: [PATCH] Add an example of a sparse FullyConnected kernel.

PiperOrigin-RevId: 293450452
Change-Id: I4ad38bc8f871291a6db11b80b44660f5ae9e6421
---
 tensorflow/lite/kernels/fully_connected.cc    |  62 ++++++++++
 tensorflow/lite/kernels/fully_connected.h     |   2 +
 .../lite/kernels/fully_connected_test.cc      | 114 ++++++++++++++++++
 tensorflow/lite/kernels/internal/BUILD        |   2 +
 .../optimized/sparse_ops/fully_connected.h    |  75 ++++++++++++
 .../reference/sparse_ops/fully_connected.h    |  46 +++++++
 6 files changed, 301 insertions(+)
 create mode 100644 tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h
 create mode 100644 tensorflow/lite/kernels/internal/reference/sparse_ops/fully_connected.h

diff --git a/tensorflow/lite/kernels/fully_connected.cc b/tensorflow/lite/kernels/fully_connected.cc
index 36dab796a28..fe39971f303 100644
--- a/tensorflow/lite/kernels/fully_connected.cc
+++ b/tensorflow/lite/kernels/fully_connected.cc
@@ -23,10 +23,12 @@ limitations under the License.
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/cpu_backend_context.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
 #include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
 #include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/lite/kernels/internal/reference/sparse_ops/fully_connected.h"
 #include "tensorflow/lite/kernels/internal/tensor.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
@@ -38,11 +40,24 @@ namespace ops {
 namespace builtin {
 namespace fully_connected {
 
+namespace {
+bool SupportedSparsityFormat(const TfLiteSparsity& sparsity) {
+  if (sparsity.dim_metadata[0].format == kTfLiteDimSparseCSR &&
+      sparsity.dim_metadata[1].format == kTfLiteDimSparseCSR) {
+    return true;
+  }
+
+  return false;
+}
+}  // namespace
+
 // This file has four implementations of FullyConnected
 enum KernelType {
   kReference,
   kGenericOptimized,
   kLegacyPie,  // Legacy path used by the PIE team and related clients.
+  kSparseReference,
+  kSparseOptimized,
 };
 
 struct OpData {
@@ -574,11 +589,41 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
     FullyConnectedParams op_params;
     op_params.float_activation_min = output_activation_min;
     op_params.float_activation_max = output_activation_max;
+
     reference_ops::FullyConnected(
         op_params, GetTensorShape(input), GetTensorData<float>(input),
         GetTensorShape(filter), GetTensorData<float>(filter),
         GetTensorShape(bias), GetTensorData<float>(bias),
         GetTensorShape(output), GetTensorData<float>(output));
+  } else if (kernel_type == kSparseReference) {
+    FullyConnectedParams op_params;
+    op_params.float_activation_min = output_activation_min;
+    op_params.float_activation_max = output_activation_max;
+    TF_LITE_ENSURE(context, filter->sparsity != nullptr);
+
+    const auto& sparsity = *filter->sparsity;
+    reference_ops::FullyConnectedSparseWeight(
+        sparsity, op_params, GetTensorShape(input), GetTensorData<float>(input),
+        GetTensorShape(filter), GetTensorData<float>(filter),
+        GetTensorShape(bias), GetTensorData<float>(bias),
+        GetTensorShape(output), GetTensorData<float>(output));
+  } else if (kernel_type == kSparseOptimized) {
+    FullyConnectedParams op_params;
+    op_params.float_activation_min = output_activation_min;
+    op_params.float_activation_max = output_activation_max;
+    TF_LITE_ENSURE(context, filter->sparsity != nullptr);
+
+    const auto& sparsity = *filter->sparsity;
+    if (!SupportedSparsityFormat(sparsity)) {
+      context->ReportError(context,
+                           "Unsupported sparse fully-connected weight format.");
+      return kTfLiteError;
+    }
+    optimized_ops::FullyConnectedSparseWeight(
+        sparsity, op_params, GetTensorShape(input), GetTensorData<float>(input),
+        GetTensorShape(filter), GetTensorData<float>(filter),
+        GetTensorShape(bias), GetTensorData<float>(bias),
+        GetTensorShape(output), GetTensorData<float>(output));
   } else if (kernel_type == kLegacyPie) {
     return EvalPie(context, node, params, data, input, filter, bias, output);
   } else {
@@ -653,6 +698,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 
 }  // namespace fully_connected
 
+// TODO(b/147449640): Clean up sparse registrations after conversion is done.
+TfLiteRegistration* Register_FULLY_CONNECTED_SPARSE_REF() {
+  static TfLiteRegistration r = {
+      fully_connected::Init, fully_connected::Free,
+      fully_connected::Prepare<fully_connected::kSparseReference>,
+      fully_connected::Eval<fully_connected::kSparseReference>};
+  return &r;
+}
+
+TfLiteRegistration* Register_FULLY_CONNECTED_SPARSE_OPT() {
+  static TfLiteRegistration r = {
+      fully_connected::Init, fully_connected::Free,
+      fully_connected::Prepare<fully_connected::kSparseOptimized>,
+      fully_connected::Eval<fully_connected::kSparseOptimized>};
+  return &r;
+}
+
 TfLiteRegistration* Register_FULLY_CONNECTED_REF() {
   static TfLiteRegistration r = {
       fully_connected::Init, fully_connected::Free,
diff --git a/tensorflow/lite/kernels/fully_connected.h b/tensorflow/lite/kernels/fully_connected.h
index a9a6cc70a80..badc9e7a91c 100644
--- a/tensorflow/lite/kernels/fully_connected.h
+++ b/tensorflow/lite/kernels/fully_connected.h
@@ -29,6 +29,8 @@ namespace builtin {
 TfLiteRegistration* Register_FULLY_CONNECTED_REF();
 TfLiteRegistration* Register_FULLY_CONNECTED_GENERIC_OPT();
 TfLiteRegistration* Register_FULLY_CONNECTED_PIE();
+TfLiteRegistration* Register_FULLY_CONNECTED_SPARSE_REF();
+TfLiteRegistration* Register_FULLY_CONNECTED_SPARSE_OPT();
 }  // namespace builtin
 }  // namespace ops
 }  // namespace tflite
diff --git a/tensorflow/lite/kernels/fully_connected_test.cc b/tensorflow/lite/kernels/fully_connected_test.cc
index a4b49c59efb..21273cd7a3a 100644
--- a/tensorflow/lite/kernels/fully_connected_test.cc
+++ b/tensorflow/lite/kernels/fully_connected_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
 
 #include "tensorflow/lite/kernels/fully_connected.h"
 
+#include <initializer_list>
 #include <iomanip>
 #include <random>
 #include <vector>
@@ -356,6 +357,11 @@ const auto kKernelMapNoPie = new std::map<string, TfLiteRegistration*>({
     {"GenericOptimized", ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT()},
 });
 
+const auto kKernelMapSparse = new std::map<string, TfLiteRegistration*>({
+    {"SparseReference", ops::builtin::Register_FULLY_CONNECTED_SPARSE_REF()},
+    {"SparseOptimized", ops::builtin::Register_FULLY_CONNECTED_SPARSE_OPT()},
+});
+
 class QuantizedFullyConnectedOpTest : public SingleOpTest {
  protected:
   const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
@@ -1061,5 +1067,113 @@ TEST_P(FloatFullyConnectedOpTest, BlackBoxTest) {
   }
 }
 
+template <typename T>
+class SparseFullyConnectedOpModel : public SingleOpModel {
+ public:
+  SparseFullyConnectedOpModel(TfLiteRegistration* registration, int units,
+                              int batches, const TensorData& input,
+                              std::initializer_list<int> weights_shape,
+                              std::initializer_list<T> weights_data)
+      : batches_(batches), units_(units) {
+    int total_input_size = 1;
+    for (size_t i = 0; i < input.shape.size(); ++i) {
+      total_input_size *= input.shape[i];
+    }
+    input_size_ = total_input_size / batches_;
+
+    input_ = AddInput(input);
+    weights_ = AddConstSparseInput(input.type, weights_shape, weights_data);
+
+    TensorData bias{input.type, {units_}};
+    bias_ = AddInput(bias);
+
+    output_ = AddOutput({input.type});
+
+    SetBuiltinOp(
+        BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions,
+        CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU)
+            .Union());
+    resolver_ = absl::make_unique<SingleOpResolver>(
+        BuiltinOperator_FULLY_CONNECTED, registration);
+    BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)});
+  }
+  void SetBias(const std::vector<T>& data) { PopulateTensor(bias_, data); }
+  void SetInput(const std::vector<T>& data) { PopulateTensor(input_, data); }
+  std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
+  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+  int input_size() { return input_size_; }
+  int num_units() { return units_; }
+  int num_batches() { return batches_; }
+
+ protected:
+  int input_;
+  int weights_;
+  int bias_;
+  int output_;
+
+  int batches_;
+  int units_;
+  int input_size_;
+};
+
+class SparseFullyConnectedOpTest : public SingleOpTest {
+ protected:
+  const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
+    return *kKernelMapSparse;
+  }
+};
+
+TEST_P(SparseFullyConnectedOpTest, SimpleTest) {
+  std::initializer_list<int> weight_shape = {3, 10};
+  std::initializer_list<float> weight_data = {
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
+      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
+  };
+  SparseFullyConnectedOpModel<float> m(
+      GetRegistration(), /*units=*/3, /*batches=*/2,
+      /*input=*/{TensorType_FLOAT32, {2, 10}}, weight_shape, weight_data);
+  m.SetBias({1, 2, 3});
+
+  m.SetInput({
+      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
+      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
+  });
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 3));
+  EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
+}
+
+TEST_P(SparseFullyConnectedOpTest, SimpleTest2) {
+  std::initializer_list<int> weight_shape = {1, 2};
+  std::initializer_list<float> weight_data = {
+      2, 4  // u = 0
+  };
+  SparseFullyConnectedOpModel<float> m(
+      GetRegistration(), /*units=*/1, /*batches=*/2,
+      /*input=*/{TensorType_FLOAT32, {2, 2}}, weight_shape, weight_data);
+  m.SetBias({1});
+
+  m.SetInput({
+      1, 2,  // b = 0
+      2, 1   // b = 1
+  });
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 1));
+  EXPECT_THAT(m.GetOutput(), ElementsAre(11, 9));
+}
+
+// TODO(b/148391360): Add tests for unsupported sparsity format.
+// TEST_P(SparseFullyConnectedOpTest, TestUnsupportedSparsityFormat)
+
+INSTANTIATE_TEST_SUITE_P(
+    SparseFullyConnectedOpTest, SparseFullyConnectedOpTest,
+    ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMapSparse)));
+
 }  // namespace
 }  // namespace tflite
diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD
index 0c063ed6338..03bdfd1cf36 100644
--- a/tensorflow/lite/kernels/internal/BUILD
+++ b/tensorflow/lite/kernels/internal/BUILD
@@ -232,6 +232,7 @@ cc_library(
         "optimized/integer_ops/softmax.h",
         "optimized/integer_ops/transpose_conv.h",
         "optimized/optimized_ops.h",
+        "optimized/sparse_ops/fully_connected.h",
     ],
     copts = tflite_copts(),
     deps = [
@@ -456,6 +457,7 @@ cc_library(
         "reference/requantize.h",
         "reference/round.h",
         "reference/softmax.h",
+        "reference/sparse_ops/fully_connected.h",
         "reference/strided_slice.h",
         "reference/svdf.h",
     ],
diff --git a/tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h b/tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h
new file mode 100644
index 00000000000..9a05b61c006
--- /dev/null
+++ b/tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h
@@ -0,0 +1,75 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SPARSE_OPS_FULLY_CONNECTED_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SPARSE_OPS_FULLY_CONNECTED_H_
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/round.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace optimized_ops {
+
+inline void FullyConnectedSparseWeight(
+    const TfLiteSparsity& sparsity, const FullyConnectedParams& params,
+    const RuntimeShape& input_shape, const float* input_data,
+    const RuntimeShape& weights_shape, const float* weights_data,
+    const RuntimeShape& bias_shape, const float* bias_data,
+    const RuntimeShape& output_shape, float* output_data) {
+  const float output_activation_min = params.float_activation_min;
+  const float output_activation_max = params.float_activation_max;
+
+  const int output_elements = output_shape.FlatSize();
+  const int output_dims_count = output_shape.DimensionsCount();
+  const int weights_dims_count = weights_shape.DimensionsCount();
+  const int batches = FlatSizeSkipDim(output_shape, output_dims_count - 1);
+  const int output_depth = MatchingDim(weights_shape, weights_dims_count - 2,
+                                       output_shape, output_dims_count - 1);
+  const int accum_depth = weights_shape.Dims(weights_dims_count - 1);
+  const int* w0_segments = sparsity.dim_metadata[0].array_segments->data;
+  const int* w0_indices = sparsity.dim_metadata[0].array_indices->data;
+  const int* w1_segments = sparsity.dim_metadata[1].array_segments->data;
+  const int* w1_indices = sparsity.dim_metadata[1].array_indices->data;
+
+  for (int i = 0; i < output_elements; ++i) {
+    output_data[i] = 0.f;
+  }
+
+  for (int b = 0; b < batches; ++b) {
+    for (int pw0 = w0_segments[0]; pw0 < w0_segments[1]; ++pw0) {
+      int idx_0 = w0_indices[pw0];
+      for (int pw1 = w1_segments[pw0]; pw1 < w1_segments[pw0 + 1]; ++pw1) {
+        int idx_1 = w1_indices[pw1];
+        output_data[b * output_depth + idx_0] +=
+            weights_data[pw1] * input_data[b * accum_depth + idx_1];
+      }
+    }
+  }
+
+  for (int b = 0; b < batches; ++b) {
+    for (int i = 0; i < output_depth; ++i) {
+      float total = output_data[b * output_depth + i];
+      float bias_value = bias_data[i];
+      output_data[b * output_depth + i] = ActivationFunctionWithMinMax(
+          total + bias_value, output_activation_min, output_activation_max);
+    }
+  }
+}
+
+}  // namespace optimized_ops
+}  // namespace tflite
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_SPARSE_OPS_FULLY_CONNECTED_H_
diff --git a/tensorflow/lite/kernels/internal/reference/sparse_ops/fully_connected.h b/tensorflow/lite/kernels/internal/reference/sparse_ops/fully_connected.h
new file mode 100644
index 00000000000..0f8a248d61c
--- /dev/null
+++ b/tensorflow/lite/kernels/internal/reference/sparse_ops/fully_connected.h
@@ -0,0 +1,46 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPARSE_OPS_FULLY_CONNECTED_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPARSE_OPS_FULLY_CONNECTED_H_
+
+#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
+#include "tensorflow/lite/tools/optimize/sparsity/format_converter.h"
+
+namespace tflite {
+namespace reference_ops {
+
+// Convert weights to dense format and run dense fully connected.
+inline void FullyConnectedSparseWeight(
+    const TfLiteSparsity& sparsity, const FullyConnectedParams& params,
+    const RuntimeShape& input_shape, const float* input_data,
+    const RuntimeShape& weights_shape, const float* weights_data,
+    const RuntimeShape& bias_shape, const float* bias_data,
+    const RuntimeShape& output_shape, float* output_data) {
+  std::vector<int> weights_shape_vector(weights_shape.DimensionsCount());
+  for (int i = 0; i < weights_shape.DimensionsCount(); i++) {
+    weights_shape_vector[i] = weights_shape.Dims(i);
+  }
+  tflite::optimize::sparsity::FormatConverter<float> converter(
+      weights_shape_vector, sparsity);
+  converter.SparseToDense(weights_data);
+  const std::vector<float> dense_weights_data = converter.GetData();
+  FullyConnected(params, input_shape, input_data, weights_shape,
+                 dense_weights_data.data(), bias_shape, bias_data, output_shape,
+                 output_data);
+}
+
+}  // namespace reference_ops
+}  // namespace tflite
+#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPARSE_OPS_FULLY_CONNECTED_H_