From 3b3d32038525798fdb9c0ac9ef55e64cbd4dcb84 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <benoitjacob@google.com>
Date: Thu, 30 Apr 2020 07:17:56 -0700
Subject: [PATCH] Use the new ruy API for caching constant matrices.

PiperOrigin-RevId: 309220053
Change-Id: I0d1f0040cd84aade4c9f9e84c240909865e818bc
---
 tensorflow/lite/kernels/BUILD                 |  2 ++
 .../lite/kernels/cpu_backend_context.cc       |  3 ---
 .../lite/kernels/cpu_backend_gemm_params.h    | 22 +++++++++++++++----
 .../lite/kernels/cpu_backend_gemm_ruy.h       | 17 +++++++++++++-
 tensorflow/lite/kernels/fully_connected.cc    |  1 +
 .../internal/optimized/neon_tensor_utils.cc   |  2 +-
 .../internal/optimized/optimized_ops.h        | 18 ++++++++++-----
 7 files changed, 50 insertions(+), 15 deletions(-)

diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD
index ae49d798fa8..5b6fe4b5b21 100644
--- a/tensorflow/lite/kernels/BUILD
+++ b/tensorflow/lite/kernels/BUILD
@@ -338,6 +338,7 @@ cc_library(
         # Depend on ruy regardless of `tflite_with_ruy`. See the comment in
         # cpu_backend_gemm.h about why ruy is the generic path.
         "@ruy//ruy",
+        "@ruy//ruy:matrix",
         "@ruy//ruy:path",
         "@ruy//ruy/profiler:instrumentation",
         # We only need to depend on gemmlowp and Eigen when tflite_with_ruy
@@ -525,6 +526,7 @@ cc_library(
     visibility = ["//visibility:private"],
     deps = [
         ":cpu_backend_context",
+        ":cpu_backend_gemm",
         ":cpu_backend_threadpool",
         ":eigen_support",
         ":kernel_util",
diff --git a/tensorflow/lite/kernels/cpu_backend_context.cc b/tensorflow/lite/kernels/cpu_backend_context.cc
index 0fa4175973a..d6de9bf8d61 100644
--- a/tensorflow/lite/kernels/cpu_backend_context.cc
+++ b/tensorflow/lite/kernels/cpu_backend_context.cc
@@ -55,9 +55,6 @@ CpuBackendContext::CpuBackendContext()
       ruy_context_(new ruy::Context),
       gemmlowp_context_(new gemmlowp::GemmContext) {
   SetMaxNumThreads(kDefaultNumThreadpoolThreads);
-#ifdef TFLITE_WITH_RUY_GEMV
-  ruy_context_->set_cache_policy(ruy::CachePolicy::kCacheLHSOnNarrowMul);
-#endif
 }
 
 CpuBackendContext::~CpuBackendContext() {}
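The removed call configured caching globally on the ruy context, gated on
TFLITE_WITH_RUY_GEMV; with the new API the policy travels with each matrix
instead. A minimal before/after sketch (the forwarding line is from the
cpu_backend_gemm_ruy.h hunk below):

    // Before: one context-wide knob, set at construction time.
    ruy_context_->set_cache_policy(ruy::CachePolicy::kCacheLHSOnNarrowMul);

    // After: each matrix carries its own policy, forwarded to ruy.
    dst->set_cache_policy(ToRuyCachePolicy(params.cache_policy));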
diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_params.h b/tensorflow/lite/kernels/cpu_backend_gemm_params.h
index 66700ea9cdf..0040f40cd50 100644
--- a/tensorflow/lite/kernels/cpu_backend_gemm_params.h
+++ b/tensorflow/lite/kernels/cpu_backend_gemm_params.h
@@ -29,6 +29,17 @@ namespace cpu_backend_gemm {
 // Matrix storage order: column-major or row-major.
 enum class Order { kColMajor, kRowMajor };
 
+enum class CachePolicy : std::uint8_t {
+  kNeverCache,
+  kCacheIfLargeSpeedup,
+  kAlwaysCache,
+};
+
+inline CachePolicy DefaultCachePolicy(bool is_constant_data) {
+  return is_constant_data ? CachePolicy::kCacheIfLargeSpeedup
+                          : CachePolicy::kNeverCache;
+}
+
 // MatrixParams encapsulates the parameters that Gemm needs about each
 // matrix, besides the buffer data pointer.
 // Compare to ruy::Matrix, which also encapsulates the data pointer.
@@ -47,10 +58,13 @@ struct MatrixParams {
   // The zero_point, i.e. which Scalar value is to be interpreted as zero.
   // When Scalar is floating-point, this must be 0.
   Scalar zero_point = 0;
-  // Indicate whether the underlying data will remain unchanged for
-  // some period of time. Defaults to false, but should be set to true
-  // for unchanging data (e.g. weights buffers in many cases)
-  bool cacheable = false;
+  // When the data pointed to by this matrix is constant, so that it is
+  // valid to assume that pointer equality implies data equality, a
+  // CachePolicy other than the default kNeverCache may be used. This
+  // lets ruy cache the packing work for the matrix, which can be a
+  // large speedup for matrix*vector products and other narrow shapes
+  // where packing would otherwise dominate the cost.
+  CachePolicy cache_policy = CachePolicy::kNeverCache;
 };
 
 // Enumeration of broad categories of Gemm.
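For illustration, a caller owning constant weights might fill in MatrixParams
like this (a minimal sketch; the dimension names and the weights buffer are
hypothetical, not part of this patch):

    cpu_backend_gemm::MatrixParams<float> lhs_params;
    lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
    lhs_params.rows = output_depth;  // hypothetical dimensions
    lhs_params.cols = accum_depth;
    // Weights are constant for the lifetime of the model, so let ruy cache
    // the packing work when it estimates a large speedup (e.g. GEMV).
    lhs_params.cache_policy =
        cpu_backend_gemm::DefaultCachePolicy(/*is_constant_data=*/true);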
diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h b/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h
index 9023e93d83d..b441628a67b 100644
--- a/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h
+++ b/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_
 #define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_
 
+#include "ruy/matrix.h"  // from @ruy
 #include "ruy/path.h"  // from @ruy
 #include "ruy/ruy.h"  // from @ruy
 #include "tensorflow/lite/kernels/cpu_backend_context.h"
@@ -25,6 +26,20 @@ namespace tflite {
 namespace cpu_backend_gemm {
 namespace detail {
 
+inline ruy::CachePolicy ToRuyCachePolicy(CachePolicy cache_policy) {
+  switch (cache_policy) {
+    case CachePolicy::kNeverCache:
+      return ruy::CachePolicy::kNeverCache;
+    case CachePolicy::kCacheIfLargeSpeedup:
+      return ruy::CachePolicy::kCacheIfLargeSpeedup;
+    case CachePolicy::kAlwaysCache:
+      return ruy::CachePolicy::kAlwaysCache;
+    default:
+      TFLITE_DCHECK(false);
+      return ruy::CachePolicy::kNeverCache;
+  }
+}
+
 template <typename Scalar, typename DataPointer>
 void MakeRuyMatrix(const MatrixParams<Scalar>& params, DataPointer data_ptr,
                    ruy::Matrix<Scalar>* dst) {
@@ -38,7 +53,7 @@ void MakeRuyMatrix(const MatrixParams<Scalar>& params, DataPointer data_ptr,
   dst->set_data(data_ptr);
   dst->set_zero_point(params.zero_point);
 #ifdef TFLITE_WITH_RUY_GEMV
-  dst->set_cacheable(params.cacheable);
+  dst->set_cache_policy(ToRuyCachePolicy(params.cache_policy));
 #endif
 }
 
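The mapping is intended to be 1:1; a minimal sanity check, usable as a
unit-test body (assumes <cassert> and the tflite::cpu_backend_gemm::detail
namespace):

    assert(ToRuyCachePolicy(CachePolicy::kNeverCache) ==
           ruy::CachePolicy::kNeverCache);
    assert(ToRuyCachePolicy(CachePolicy::kCacheIfLargeSpeedup) ==
           ruy::CachePolicy::kCacheIfLargeSpeedup);
    assert(ToRuyCachePolicy(CachePolicy::kAlwaysCache) ==
           ruy::CachePolicy::kAlwaysCache);

Note that the forwarding in MakeRuyMatrix stays inside
#ifdef TFLITE_WITH_RUY_GEMV, so builds without that flag never opt a matrix
into caching.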
diff --git a/tensorflow/lite/kernels/fully_connected.cc b/tensorflow/lite/kernels/fully_connected.cc
index 5faf13303d8..62a4ede9a06 100644
--- a/tensorflow/lite/kernels/fully_connected.cc
+++ b/tensorflow/lite/kernels/fully_connected.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/cpu_backend_context.h"
+#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc
index 32584fcd027..07f3117dac7 100644
--- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc
+++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -1042,7 +1042,7 @@ void NeonCpuBackendGemm(const int8_t* input, const int32_t* bias,
   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
   lhs_params.rows = n_output;
   lhs_params.cols = n_input;
-  lhs_params.cacheable = true;
+  lhs_params.cache_policy = cpu_backend_gemm::CachePolicy::kCacheIfLargeSpeedup;
 
   MatrixParams<int8_t> rhs_params;
   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
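Here the policy is set to kCacheIfLargeSpeedup directly rather than via
DefaultCachePolicy, since in this GEMV path the LHS is always the constant
weights matrix. Caching is only safe because pointer equality must imply data
equality for the cached matrix; a sketch of the contrast with the per-call
RHS:

    // LHS: constant weights at a stable address -> opt in to caching.
    lhs_params.cache_policy =
        cpu_backend_gemm::CachePolicy::kCacheIfLargeSpeedup;
    // RHS: fresh activations on every call -> keep the default kNeverCache.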
diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
index 6e1f805f7f4..5f183de7269 100644
--- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
@@ -286,13 +286,15 @@ inline void FullyConnected(
   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
   rhs_params.rows = input_rows;
   rhs_params.cols = input_shape.FlatSize() / input_rows;
-  rhs_params.cacheable = params.rhs_cacheable;
+  rhs_params.cache_policy =
+      cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
   TFLITE_DCHECK_EQ(input_shape.FlatSize(), rhs_params.rows * rhs_params.cols);
   cpu_backend_gemm::MatrixParams<float> lhs_params;
   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
   lhs_params.cols = weights_shape.Dims(dims_count - 1);
   lhs_params.rows = FlatSizeSkipDim(weights_shape, dims_count - 1);
-  lhs_params.cacheable = params.lhs_cacheable;
+  lhs_params.cache_policy =
+      cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
   cpu_backend_gemm::MatrixParams<float> dst_params;
   dst_params.order = cpu_backend_gemm::Order::kColMajor;
   dst_params.rows = output_shape.Dims(output_shape.DimensionsCount() - 1);
@@ -345,13 +347,15 @@ inline void FullyConnected(
   lhs_params.cols = filter_cols;
   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
   lhs_params.zero_point = -filter_offset;
-  lhs_params.cacheable = params.lhs_cacheable;
+  lhs_params.cache_policy =
+      cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
   cpu_backend_gemm::MatrixParams<uint8> rhs_params;
   rhs_params.rows = filter_cols;
   rhs_params.cols = batches;
   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
   rhs_params.zero_point = -input_offset;
-  rhs_params.cacheable = params.rhs_cacheable;
+  rhs_params.cache_policy =
+      cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
   cpu_backend_gemm::MatrixParams<uint8> dst_params;
   dst_params.rows = filter_rows;
   dst_params.cols = batches;
@@ -404,13 +408,15 @@ inline void FullyConnected(
   lhs_params.cols = accum_depth;
   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
   lhs_params.zero_point = -filter_offset;
-  lhs_params.cacheable = params.lhs_cacheable;
+  lhs_params.cache_policy =
+      cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
   cpu_backend_gemm::MatrixParams<uint8> rhs_params;
   rhs_params.rows = accum_depth;
   rhs_params.cols = batches;
   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
   rhs_params.zero_point = -input_offset;
-  rhs_params.cacheable = params.rhs_cacheable;
+  rhs_params.cache_policy =
+      cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
   cpu_backend_gemm::MatrixParams<int16> dst_params;
   dst_params.rows = output_depth;
   dst_params.cols = batches;
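
In these FullyConnected paths the caller only knows a boolean
(params.lhs_cacheable / params.rhs_cacheable), so DefaultCachePolicy maps it
conservatively. A sanity-check sketch of the resulting behavior (assumes
<cassert>):

    using cpu_backend_gemm::CachePolicy;
    using cpu_backend_gemm::DefaultCachePolicy;
    // Constant weights: allow caching when ruy estimates a large speedup.
    assert(DefaultCachePolicy(/*is_constant_data=*/true) ==
           CachePolicy::kCacheIfLargeSpeedup);
    // Per-inference activations: never cache the packing work.
    assert(DefaultCachePolicy(/*is_constant_data=*/false) ==
           CachePolicy::kNeverCache);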