diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD
index ae49d798fa8..5b6fe4b5b21 100644
--- a/tensorflow/lite/kernels/BUILD
+++ b/tensorflow/lite/kernels/BUILD
@@ -338,6 +338,7 @@ cc_library(
         # Depend on ruy regardless of `tflite_with_ruy`. See the comment in
         # cpu_backend_gemm.h about why ruy is the generic path.
         "@ruy//ruy",
+        "@ruy//ruy:matrix",
         "@ruy//ruy:path",
         "@ruy//ruy/profiler:instrumentation",
         # We only need to depend on gemmlowp and Eigen when tflite_with_ruy
@@ -525,6 +526,7 @@ cc_library(
     visibility = ["//visibility:private"],
     deps = [
         ":cpu_backend_context",
+        ":cpu_backend_gemm",
        ":cpu_backend_threadpool",
        ":eigen_support",
        ":kernel_util",
diff --git a/tensorflow/lite/kernels/cpu_backend_context.cc b/tensorflow/lite/kernels/cpu_backend_context.cc
index 0fa4175973a..d6de9bf8d61 100644
--- a/tensorflow/lite/kernels/cpu_backend_context.cc
+++ b/tensorflow/lite/kernels/cpu_backend_context.cc
@@ -55,9 +55,6 @@ CpuBackendContext::CpuBackendContext()
       ruy_context_(new ruy::Context),
       gemmlowp_context_(new gemmlowp::GemmContext) {
   SetMaxNumThreads(kDefaultNumThreadpoolThreads);
-#ifdef TFLITE_WITH_RUY_GEMV
-  ruy_context_->set_cache_policy(ruy::CachePolicy::kCacheLHSOnNarrowMul);
-#endif
 }
 
 CpuBackendContext::~CpuBackendContext() {}
diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_params.h b/tensorflow/lite/kernels/cpu_backend_gemm_params.h
index 66700ea9cdf..0040f40cd50 100644
--- a/tensorflow/lite/kernels/cpu_backend_gemm_params.h
+++ b/tensorflow/lite/kernels/cpu_backend_gemm_params.h
@@ -29,6 +29,17 @@ namespace cpu_backend_gemm {
 // Matrix storage order: column-major or row-major.
 enum class Order { kColMajor, kRowMajor };
 
+enum class CachePolicy : std::uint8_t {
+  kNeverCache,
+  kCacheIfLargeSpeedup,
+  kAlwaysCache,
+};
+
+inline CachePolicy DefaultCachePolicy(bool is_constant_data) {
+  return is_constant_data ? CachePolicy::kCacheIfLargeSpeedup
+                          : CachePolicy::kNeverCache;
+}
+
 // MatrixParams encapsulates the parameters that Gemm needs about each
 // matrix, besides the buffer data pointer.
 // Compare to ruy::Matrix, which also encapsulates the data pointer.
@@ -47,10 +58,13 @@ struct MatrixParams {
   // The zero_point, i.e. which Scalar value is to be interpreted as zero.
   // When Scalar is floating-point, this must be 0.
   Scalar zero_point = 0;
-  // Indicate whether the underlying data will remain unchanged for
-  // some period of time. Defaults to false, but should be set to true
-  // for unchanging data (e.g. weights buffers in many cases)
-  bool cacheable = false;
+  // When the data pointed to by this matrix is constant data, so that it is
+  // valid to assume that equality of pointers implies equality of data,
+  // a CachePolicy may be used instead of the default kNeverCache,
+  // which will enable ruy to take advantage of this constancy of the data to
+  // cache the packing work, which can be a large speedup in matrix*vector
+  // and other narrow shapes.
+  CachePolicy cache_policy = CachePolicy::kNeverCache;
 };
 
 // Enumeration of broad categories of Gemm.
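The `CachePolicy` enum and `DefaultCachePolicy` helper introduced above are the core of this change: call sites now state whether their data is constant, instead of setting the old `cacheable` boolean directly. A minimal standalone sketch of the intended semantics, with the declarations mirrored outside the TFLite headers so the snippet compiles on its own:

    // Standalone mirror of the additions to cpu_backend_gemm_params.h.
    #include <cassert>
    #include <cstdint>

    enum class CachePolicy : std::uint8_t {
      kNeverCache,
      kCacheIfLargeSpeedup,
      kAlwaysCache,
    };

    // Constant data (e.g. weights buffers) opts into caching the packing
    // work; mutable data keeps the safe default of never caching.
    inline CachePolicy DefaultCachePolicy(bool is_constant_data) {
      return is_constant_data ? CachePolicy::kCacheIfLargeSpeedup
                              : CachePolicy::kNeverCache;
    }

    int main() {
      assert(DefaultCachePolicy(/*is_constant_data=*/true) ==
             CachePolicy::kCacheIfLargeSpeedup);
      assert(DefaultCachePolicy(/*is_constant_data=*/false) ==
             CachePolicy::kNeverCache);
      return 0;
    }

Note that `DefaultCachePolicy` never returns `kAlwaysCache`: as the new MatrixParams comment spells out, the cache is keyed on data pointers, so caching is only sound when pointer equality implies data equality.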
diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h b/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h
index 9023e93d83d..b441628a67b 100644
--- a/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h
+++ b/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_
 #define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_
 
+#include "ruy/matrix.h"  // from @ruy
 #include "ruy/path.h"  // from @ruy
 #include "ruy/ruy.h"  // from @ruy
 #include "tensorflow/lite/kernels/cpu_backend_context.h"
@@ -25,6 +26,20 @@ namespace tflite {
 namespace cpu_backend_gemm {
 namespace detail {
 
+inline ruy::CachePolicy ToRuyCachePolicy(CachePolicy cache_policy) {
+  switch (cache_policy) {
+    case CachePolicy::kNeverCache:
+      return ruy::CachePolicy::kNeverCache;
+    case CachePolicy::kCacheIfLargeSpeedup:
+      return ruy::CachePolicy::kCacheIfLargeSpeedup;
+    case CachePolicy::kAlwaysCache:
+      return ruy::CachePolicy::kAlwaysCache;
+    default:
+      TFLITE_DCHECK(false);
+      return ruy::CachePolicy::kNeverCache;
+  }
+}
+
 template <typename Scalar, typename DataPointer>
 void MakeRuyMatrix(const MatrixParams<Scalar>& params, DataPointer data_ptr,
                    ruy::Matrix<Scalar>* dst) {
@@ -38,7 +53,7 @@ void MakeRuyMatrix(const MatrixParams<Scalar>& params, DataPointer data_ptr,
   dst->set_data(data_ptr);
   dst->set_zero_point(params.zero_point);
 #ifdef TFLITE_WITH_RUY_GEMV
-  dst->set_cacheable(params.cacheable);
+  dst->set_cache_policy(ToRuyCachePolicy(params.cache_policy));
 #endif
 }
 
diff --git a/tensorflow/lite/kernels/fully_connected.cc b/tensorflow/lite/kernels/fully_connected.cc
index 5faf13303d8..62a4ede9a06 100644
--- a/tensorflow/lite/kernels/fully_connected.cc
+++ b/tensorflow/lite/kernels/fully_connected.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/cpu_backend_context.h"
+#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc
index 32584fcd027..07f3117dac7 100644
--- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc
+++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -1042,7 +1042,7 @@ void NeonCpuBackendGemm(const int8_t* input, const int32_t* bias,
   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
   lhs_params.rows = n_output;
   lhs_params.cols = n_input;
-  lhs_params.cacheable = true;
+  lhs_params.cache_policy = cpu_backend_gemm::CachePolicy::kCacheIfLargeSpeedup;
 
   MatrixParams<int8_t> rhs_params;
   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
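On the ruy side, the `ToRuyCachePolicy` switch above translates the TFLite enum one-to-one, and `MakeRuyMatrix` forwards the policy only when `TFLITE_WITH_RUY_GEMV` is defined. What that path ultimately requests from ruy looks roughly like the following matrix*vector multiply with a cacheable constant LHS; this sketch assumes the current ruy frontend (`ruy::Matrix`, `ruy::MulParams`, `ruy::Mul`) and uses dummy buffers:

    #include "ruy/ruy.h"  // from @ruy

    int main() {
      constexpr int kRows = 4, kDepth = 3;
      const float weights[kRows * kDepth] = {1.0f};  // dummy constant weights
      const float input[kDepth] = {1.0f, 2.0f, 3.0f};
      float output[kRows] = {};

      ruy::Context context;

      ruy::Matrix<float> lhs;
      ruy::MakeSimpleLayout(kRows, kDepth, ruy::Order::kRowMajor,
                            lhs.mutable_layout());
      lhs.set_data(weights);
      // The per-matrix knob this change plumbs through: cache the packed
      // weights when that is a large speedup, as in this narrow shape.
      lhs.set_cache_policy(ruy::CachePolicy::kCacheIfLargeSpeedup);

      ruy::Matrix<float> rhs;
      ruy::MakeSimpleLayout(kDepth, 1, ruy::Order::kColMajor,
                            rhs.mutable_layout());
      rhs.set_data(input);

      ruy::Matrix<float> dst;
      ruy::MakeSimpleLayout(kRows, 1, ruy::Order::kColMajor,
                            dst.mutable_layout());
      dst.set_data(output);

      ruy::Mul(lhs, rhs, ruy::MulParams<float, float>(), &context, &dst);
      return 0;
    }

This also explains the removal in cpu_backend_context.cc further up: caching becomes a per-matrix property set where the matrices are built, rather than a global setting on the `ruy::Context`.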
diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
index 6e1f805f7f4..5f183de7269 100644
--- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
@@ -286,13 +286,15 @@ inline void FullyConnected(
   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
   rhs_params.rows = input_rows;
   rhs_params.cols = input_shape.FlatSize() / input_rows;
-  rhs_params.cacheable = params.rhs_cacheable;
+  rhs_params.cache_policy =
+      cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
   TFLITE_DCHECK_EQ(input_shape.FlatSize(), rhs_params.rows * rhs_params.cols);
   cpu_backend_gemm::MatrixParams<float> lhs_params;
   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
   lhs_params.cols = weights_shape.Dims(dims_count - 1);
   lhs_params.rows = FlatSizeSkipDim(weights_shape, dims_count - 1);
-  lhs_params.cacheable = params.lhs_cacheable;
+  lhs_params.cache_policy =
+      cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
   cpu_backend_gemm::MatrixParams<float> dst_params;
   dst_params.order = cpu_backend_gemm::Order::kColMajor;
   dst_params.rows = output_shape.Dims(output_shape.DimensionsCount() - 1);
@@ -345,13 +347,15 @@ inline void FullyConnected(
   lhs_params.cols = filter_cols;
   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
   lhs_params.zero_point = -filter_offset;
-  lhs_params.cacheable = params.lhs_cacheable;
+  lhs_params.cache_policy =
+      cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
   cpu_backend_gemm::MatrixParams<uint8> rhs_params;
   rhs_params.rows = filter_cols;
   rhs_params.cols = batches;
   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
   rhs_params.zero_point = -input_offset;
-  rhs_params.cacheable = params.rhs_cacheable;
+  rhs_params.cache_policy =
+      cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
   cpu_backend_gemm::MatrixParams<uint8> dst_params;
   dst_params.rows = filter_rows;
   dst_params.cols = batches;
@@ -404,13 +408,15 @@ inline void FullyConnected(
   lhs_params.cols = accum_depth;
   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
   lhs_params.zero_point = -filter_offset;
-  lhs_params.cacheable = params.lhs_cacheable;
+  lhs_params.cache_policy =
+      cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
   cpu_backend_gemm::MatrixParams<uint8> rhs_params;
   rhs_params.rows = accum_depth;
   rhs_params.cols = batches;
   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
   rhs_params.zero_point = -input_offset;
-  rhs_params.cacheable = params.rhs_cacheable;
+  rhs_params.cache_policy =
+      cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
   cpu_backend_gemm::MatrixParams<int16> dst_params;
   dst_params.rows = output_depth;
   dst_params.cols = batches;
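The optimized_ops.h hunks show the migration pattern for every call site: each `foo_params.cacheable = params.x_cacheable;` becomes `foo_params.cache_policy = cpu_backend_gemm::DefaultCachePolicy(params.x_cacheable);`. A sketch of what a float fully-connected call site looks like after this change; the function and parameter names here are invented for illustration, while the `MatrixParams` fields, `DefaultCachePolicy`, `GemmParams`, and the `Gemm` entry point are the ones declared in cpu_backend_gemm.h and cpu_backend_gemm_params.h:

    #include "tensorflow/lite/kernels/cpu_backend_context.h"
    #include "tensorflow/lite/kernels/cpu_backend_gemm.h"
    #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"

    namespace tflite {

    // Illustrative only: output = weights * input, with constant weights.
    void FullyConnectedSketch(const float* weights, int output_depth,
                              int accum_depth, const float* input, int batches,
                              float* output, CpuBackendContext* context) {
      cpu_backend_gemm::MatrixParams<float> lhs_params;
      lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
      lhs_params.rows = output_depth;
      lhs_params.cols = accum_depth;
      // Weights are constant for the lifetime of the model: this yields
      // kCacheIfLargeSpeedup.
      lhs_params.cache_policy =
          cpu_backend_gemm::DefaultCachePolicy(/*is_constant_data=*/true);

      cpu_backend_gemm::MatrixParams<float> rhs_params;
      rhs_params.order = cpu_backend_gemm::Order::kColMajor;
      rhs_params.rows = accum_depth;
      rhs_params.cols = batches;
      // Activations change on every invocation: this yields kNeverCache.
      rhs_params.cache_policy =
          cpu_backend_gemm::DefaultCachePolicy(/*is_constant_data=*/false);

      cpu_backend_gemm::MatrixParams<float> dst_params;
      dst_params.order = cpu_backend_gemm::Order::kColMajor;
      dst_params.rows = output_depth;
      dst_params.cols = batches;

      cpu_backend_gemm::GemmParams<float, float> gemm_params;
      cpu_backend_gemm::Gemm(lhs_params, weights, rhs_params, input,
                             dst_params, output, gemm_params, context);
    }

    }  // namespace tflite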