Use the new ruy API for caching constant matrices.

PiperOrigin-RevId: 309220053
Change-Id: I0d1f0040cd84aade4c9f9e84c240909865e818bc
Benoit Jacob authored on 2020-04-30 07:17:56 -07:00; committed by TensorFlower Gardener
parent 09a1c84cfe
commit 3b3d320385
7 changed files with 50 additions and 15 deletions

tensorflow/lite/kernels/BUILD

@@ -338,6 +338,7 @@ cc_library(
         # Depend on ruy regardless of `tflite_with_ruy`. See the comment in
         # cpu_backend_gemm.h about why ruy is the generic path.
         "@ruy//ruy",
+        "@ruy//ruy:matrix",
         "@ruy//ruy:path",
         "@ruy//ruy/profiler:instrumentation",
         # We only need to depend on gemmlowp and Eigen when tflite_with_ruy
@@ -525,6 +526,7 @@ cc_library(
     visibility = ["//visibility:private"],
     deps = [
         ":cpu_backend_context",
+        ":cpu_backend_gemm",
        ":cpu_backend_threadpool",
        ":eigen_support",
        ":kernel_util",

tensorflow/lite/kernels/cpu_backend_context.cc

@@ -55,9 +55,6 @@ CpuBackendContext::CpuBackendContext()
       ruy_context_(new ruy::Context),
       gemmlowp_context_(new gemmlowp::GemmContext) {
   SetMaxNumThreads(kDefaultNumThreadpoolThreads);
-#ifdef TFLITE_WITH_RUY_GEMV
-  ruy_context_->set_cache_policy(ruy::CachePolicy::kCacheLHSOnNarrowMul);
-#endif
 }
 
 CpuBackendContext::~CpuBackendContext() {}
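
The context-wide policy removed above is superseded by a per-matrix setting on ruy::Matrix, which the cpu_backend_gemm_ruy.h hunk below wires up. A minimal sketch of the new per-matrix API (the enum value comes from this commit; the surrounding variables are illustrative):

    #include "ruy/matrix.h"  // from @ruy

    // Mark a single matrix (typically constant weights) as cacheable, rather
    // than setting one policy on the whole ruy::Context as before.
    ruy::Matrix<float> lhs;
    lhs.set_cache_policy(ruy::CachePolicy::kCacheIfLargeSpeedup);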

tensorflow/lite/kernels/cpu_backend_gemm_params.h

@@ -29,6 +29,17 @@ namespace cpu_backend_gemm {
 // Matrix storage order: column-major or row-major.
 enum class Order { kColMajor, kRowMajor };
 
+enum class CachePolicy : std::uint8_t {
+  kNeverCache,
+  kCacheIfLargeSpeedup,
+  kAlwaysCache,
+};
+
+inline CachePolicy DefaultCachePolicy(bool is_constant_data) {
+  return is_constant_data ? CachePolicy::kCacheIfLargeSpeedup
+                          : CachePolicy::kNeverCache;
+}
+
 // MatrixParams encapsulates the parameters that Gemm needs about each
 // matrix, besides the buffer data pointer.
 // Compare to ruy::Matrix, which also encapsulates the data pointer.
@@ -47,10 +58,13 @@ struct MatrixParams {
   // The zero_point, i.e. which Scalar value is to be interpreted as zero.
   // When Scalar is floating-point, this must be 0.
   Scalar zero_point = 0;
-  // Indicate whether the underlying data will remain unchanged for
-  // some period of time. Defaults to false, but should be set to true
-  // for unchanging data (e.g. weights buffers in many cases)
-  bool cacheable = false;
+  // When the data pointed to by this matrix is constant data, so that it is
+  // valid to assume that equality of pointers implies equality of data,
+  // a CachePolicy may be used instead of the default kNeverCache,
+  // which will enable ruy to take advantage of this constancy of the data to
+  // cache the packing work, which can be a large speedup in matrix*vector
+  // and other narrow shapes.
+  CachePolicy cache_policy = CachePolicy::kNeverCache;
 };
 
 // Enumeration of broad categories of Gemm.
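
With the new field, a caller that knows a matrix holds constant data opts in per matrix. A short sketch, with illustrative shapes and assuming the buffer behind lhs_params really is unchanging:

    cpu_backend_gemm::MatrixParams<float> lhs_params;
    lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
    lhs_params.rows = 256;  // illustrative
    lhs_params.cols = 128;  // illustrative
    // Constant weights: pointer equality implies data equality, so ruy may
    // cache the packed form of this matrix.
    lhs_params.cache_policy =
        cpu_backend_gemm::DefaultCachePolicy(/*is_constant_data=*/true);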

tensorflow/lite/kernels/cpu_backend_gemm_ruy.h

@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_
 #define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_
 
+#include "ruy/matrix.h"  // from @ruy
 #include "ruy/path.h"  // from @ruy
 #include "ruy/ruy.h"  // from @ruy
 #include "tensorflow/lite/kernels/cpu_backend_context.h"
@@ -25,6 +26,20 @@ namespace tflite {
 namespace cpu_backend_gemm {
 namespace detail {
 
+inline ruy::CachePolicy ToRuyCachePolicy(CachePolicy cache_policy) {
+  switch (cache_policy) {
+    case CachePolicy::kNeverCache:
+      return ruy::CachePolicy::kNeverCache;
+    case CachePolicy::kCacheIfLargeSpeedup:
+      return ruy::CachePolicy::kCacheIfLargeSpeedup;
+    case CachePolicy::kAlwaysCache:
+      return ruy::CachePolicy::kAlwaysCache;
+    default:
+      TFLITE_DCHECK(false);
+      return ruy::CachePolicy::kNeverCache;
+  }
+}
+
 template <typename Scalar, typename DataPointer>
 void MakeRuyMatrix(const MatrixParams<Scalar>& params, DataPointer data_ptr,
                    ruy::Matrix<Scalar>* dst) {
@@ -38,7 +53,7 @@ void MakeRuyMatrix(const MatrixParams<Scalar>& params, DataPointer data_ptr,
   dst->set_data(data_ptr);
   dst->set_zero_point(params.zero_point);
 #ifdef TFLITE_WITH_RUY_GEMV
-  dst->set_cacheable(params.cacheable);
+  dst->set_cache_policy(ToRuyCachePolicy(params.cache_policy));
 #endif
 }
 
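
To see how the pieces above fit together, here is an illustrative use of MakeRuyMatrix (weights_data is a placeholder pointer to constant data):

    cpu_backend_gemm::MatrixParams<float> params;
    params.order = cpu_backend_gemm::Order::kRowMajor;
    params.rows = 16;
    params.cols = 8;
    params.cache_policy = cpu_backend_gemm::CachePolicy::kCacheIfLargeSpeedup;

    ruy::Matrix<float> ruy_lhs;
    // Copies order, shape and zero_point onto the ruy::Matrix and, when built
    // with TFLITE_WITH_RUY_GEMV, forwards the cache policy via ToRuyCachePolicy.
    cpu_backend_gemm::detail::MakeRuyMatrix(params, weights_data, &ruy_lhs);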

tensorflow/lite/kernels/fully_connected.cc

@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/cpu_backend_context.h"
+#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"

tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc

@@ -1042,7 +1042,7 @@ void NeonCpuBackendGemm(const int8_t* input, const int32_t* bias,
   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
   lhs_params.rows = n_output;
   lhs_params.cols = n_input;
-  lhs_params.cacheable = true;
+  lhs_params.cache_policy = cpu_backend_gemm::CachePolicy::kCacheIfLargeSpeedup;
 
   MatrixParams<int8_t> rhs_params;
   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
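
Only the LHS (the constant weights) opts in here; the RHS activations change on every call, so they keep the default. Spelled out (comments are editorial):

    lhs_params.cache_policy =
        cpu_backend_gemm::CachePolicy::kCacheIfLargeSpeedup;  // constant weights
    // rhs_params.cache_policy stays at its default, CachePolicy::kNeverCache,
    // because the input activations vary between invocations.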

tensorflow/lite/kernels/internal/optimized/optimized_ops.h

@@ -286,13 +286,15 @@ inline void FullyConnected(
   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
   rhs_params.rows = input_rows;
   rhs_params.cols = input_shape.FlatSize() / input_rows;
-  rhs_params.cacheable = params.rhs_cacheable;
+  rhs_params.cache_policy =
+      cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
   TFLITE_DCHECK_EQ(input_shape.FlatSize(), rhs_params.rows * rhs_params.cols);
   cpu_backend_gemm::MatrixParams<float> lhs_params;
   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
   lhs_params.cols = weights_shape.Dims(dims_count - 1);
   lhs_params.rows = FlatSizeSkipDim(weights_shape, dims_count - 1);
-  lhs_params.cacheable = params.lhs_cacheable;
+  lhs_params.cache_policy =
+      cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
   cpu_backend_gemm::MatrixParams<float> dst_params;
   dst_params.order = cpu_backend_gemm::Order::kColMajor;
   dst_params.rows = output_shape.Dims(output_shape.DimensionsCount() - 1);
@@ -345,13 +347,15 @@ inline void FullyConnected(
   lhs_params.cols = filter_cols;
   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
   lhs_params.zero_point = -filter_offset;
-  lhs_params.cacheable = params.lhs_cacheable;
+  lhs_params.cache_policy =
+      cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
   cpu_backend_gemm::MatrixParams<uint8> rhs_params;
   rhs_params.rows = filter_cols;
   rhs_params.cols = batches;
   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
   rhs_params.zero_point = -input_offset;
-  rhs_params.cacheable = params.rhs_cacheable;
+  rhs_params.cache_policy =
+      cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
   cpu_backend_gemm::MatrixParams<uint8> dst_params;
   dst_params.rows = filter_rows;
   dst_params.cols = batches;
@@ -404,13 +408,15 @@ inline void FullyConnected(
   lhs_params.cols = accum_depth;
   lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
   lhs_params.zero_point = -filter_offset;
-  lhs_params.cacheable = params.lhs_cacheable;
+  lhs_params.cache_policy =
+      cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
   cpu_backend_gemm::MatrixParams<uint8> rhs_params;
   rhs_params.rows = accum_depth;
   rhs_params.cols = batches;
   rhs_params.order = cpu_backend_gemm::Order::kColMajor;
   rhs_params.zero_point = -input_offset;
-  rhs_params.cacheable = params.rhs_cacheable;
+  rhs_params.cache_policy =
+      cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
   cpu_backend_gemm::MatrixParams<int16> dst_params;
   dst_params.rows = output_depth;
   dst_params.cols = batches;
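
For the cacheable flags threaded through these kernels, DefaultCachePolicy resolves as in this sketch (variable names are illustrative):

    using cpu_backend_gemm::CachePolicy;
    using cpu_backend_gemm::DefaultCachePolicy;

    // Constant data (lhs_cacheable == true) opts in to the packing cache:
    CachePolicy lhs = DefaultCachePolicy(/*is_constant_data=*/true);   // kCacheIfLargeSpeedup
    // Varying data keeps the safe default:
    CachePolicy rhs = DefaultCachePolicy(/*is_constant_data=*/false);  // kNeverCache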