Use the new ruy API for caching constant matrices.

PiperOrigin-RevId: 309220053
Change-Id: I0d1f0040cd84aade4c9f9e84c240909865e818bc
This commit is contained in:
parent
09a1c84cfe
commit
3b3d320385
@ -338,6 +338,7 @@ cc_library(
|
||||
# Depend on ruy regardless of `tflite_with_ruy`. See the comment in
|
||||
# cpu_backend_gemm.h about why ruy is the generic path.
|
||||
"@ruy//ruy",
|
||||
"@ruy//ruy:matrix",
|
||||
"@ruy//ruy:path",
|
||||
"@ruy//ruy/profiler:instrumentation",
|
||||
# We only need to depend on gemmlowp and Eigen when tflite_with_ruy
|
||||
@ -525,6 +526,7 @@ cc_library(
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
":cpu_backend_context",
|
||||
":cpu_backend_gemm",
|
||||
":cpu_backend_threadpool",
|
||||
":eigen_support",
|
||||
":kernel_util",
|
||||
|
@ -55,9 +55,6 @@ CpuBackendContext::CpuBackendContext()
|
||||
ruy_context_(new ruy::Context),
|
||||
gemmlowp_context_(new gemmlowp::GemmContext) {
|
||||
SetMaxNumThreads(kDefaultNumThreadpoolThreads);
|
||||
#ifdef TFLITE_WITH_RUY_GEMV
|
||||
ruy_context_->set_cache_policy(ruy::CachePolicy::kCacheLHSOnNarrowMul);
|
||||
#endif
|
||||
}
|
||||
|
||||
CpuBackendContext::~CpuBackendContext() {}
|
||||
|
@ -29,6 +29,17 @@ namespace cpu_backend_gemm {
|
||||
// Matrix storage order: column-major or row-major.
|
||||
enum class Order { kColMajor, kRowMajor };
|
||||
|
||||
// Policy controlling whether a gemm backend may cache the packed
// (pre-transformed) form of a matrix's data across Gemm calls.
// Mirrors ruy::CachePolicy; see ToRuyCachePolicy for the mapping.
enum class CachePolicy : std::uint8_t {
  kNeverCache,
  kCacheIfLargeSpeedup,
  kAlwaysCache,
};

// Returns the cache policy appropriate for a matrix operand.
//
// Caching packed data is only valid for constant data (where pointer
// equality implies data equality), and even then it is only worthwhile
// when the backend estimates a large speedup — e.g. matrix*vector and
// other narrow shapes. Non-constant data is never cached.
inline CachePolicy DefaultCachePolicy(bool is_constant_data) {
  if (is_constant_data) {
    return CachePolicy::kCacheIfLargeSpeedup;
  }
  return CachePolicy::kNeverCache;
}
|
||||
|
||||
// MatrixParams encapsulates the parameters that Gemm needs about each
|
||||
// matrix, besides the buffer data pointer.
|
||||
// Compare to ruy::Matrix, which also encapsulates the data pointer.
|
||||
@ -47,10 +58,13 @@ struct MatrixParams {
|
||||
// The zero_point, i.e. which Scalar value is to be interpreted as zero.
|
||||
// When Scalar is floating-point, this must be 0.
|
||||
Scalar zero_point = 0;
|
||||
// Indicate whether the underlying data will remain unchanged for
|
||||
// some period of time. Defaults to false, but should be set to true
|
||||
// for unchanging data (e.g. weights buffers in many cases)
|
||||
bool cacheable = false;
|
||||
// When the data pointed to by this matrix is constant data, so that it is
|
||||
// valid to assume that equality of pointers implies equality of data,
|
||||
// a CachePolicy may be used instead of the default kNeverCache,
|
||||
// which will enable ruy to take advantage of this constancy of the data to
|
||||
// cache the packing work, which can be a large speedup in matrix*vector
|
||||
// and other narrow shapes.
|
||||
CachePolicy cache_policy = CachePolicy::kNeverCache;
|
||||
};
|
||||
|
||||
// Enumeration of broad categories of Gemm.
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_
|
||||
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_
|
||||
|
||||
#include "ruy/matrix.h" // from @ruy
|
||||
#include "ruy/path.h" // from @ruy
|
||||
#include "ruy/ruy.h" // from @ruy
|
||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||
@ -25,6 +26,20 @@ namespace tflite {
|
||||
namespace cpu_backend_gemm {
|
||||
namespace detail {
|
||||
|
||||
// Translates the tflite-local CachePolicy enum into the equivalent
// ruy::CachePolicy value. The two enums have matching enumerators.
// An out-of-range value trips a debug check and conservatively falls
// back to never caching.
inline ruy::CachePolicy ToRuyCachePolicy(CachePolicy cache_policy) {
  if (cache_policy == CachePolicy::kNeverCache) {
    return ruy::CachePolicy::kNeverCache;
  }
  if (cache_policy == CachePolicy::kCacheIfLargeSpeedup) {
    return ruy::CachePolicy::kCacheIfLargeSpeedup;
  }
  if (cache_policy == CachePolicy::kAlwaysCache) {
    return ruy::CachePolicy::kAlwaysCache;
  }
  TFLITE_DCHECK(false);  // Unreachable for valid enum values.
  return ruy::CachePolicy::kNeverCache;
}
|
||||
|
||||
template <typename Scalar, typename DataPointer>
|
||||
void MakeRuyMatrix(const MatrixParams<Scalar>& params, DataPointer data_ptr,
|
||||
ruy::Matrix<Scalar>* dst) {
|
||||
@ -38,7 +53,7 @@ void MakeRuyMatrix(const MatrixParams<Scalar>& params, DataPointer data_ptr,
|
||||
dst->set_data(data_ptr);
|
||||
dst->set_zero_point(params.zero_point);
|
||||
#ifdef TFLITE_WITH_RUY_GEMV
|
||||
dst->set_cacheable(params.cacheable);
|
||||
dst->set_cache_policy(ToRuyCachePolicy(params.cache_policy));
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
|
||||
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
|
||||
#include "tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h"
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
|
@ -1042,7 +1042,7 @@ void NeonCpuBackendGemm(const int8_t* input, const int32_t* bias,
|
||||
lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
|
||||
lhs_params.rows = n_output;
|
||||
lhs_params.cols = n_input;
|
||||
lhs_params.cacheable = true;
|
||||
lhs_params.cache_policy = cpu_backend_gemm::CachePolicy::kCacheIfLargeSpeedup;
|
||||
|
||||
MatrixParams<int8_t> rhs_params;
|
||||
rhs_params.order = cpu_backend_gemm::Order::kColMajor;
|
||||
|
@ -286,13 +286,15 @@ inline void FullyConnected(
|
||||
rhs_params.order = cpu_backend_gemm::Order::kColMajor;
|
||||
rhs_params.rows = input_rows;
|
||||
rhs_params.cols = input_shape.FlatSize() / input_rows;
|
||||
rhs_params.cacheable = params.rhs_cacheable;
|
||||
rhs_params.cache_policy =
|
||||
cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
|
||||
TFLITE_DCHECK_EQ(input_shape.FlatSize(), rhs_params.rows * rhs_params.cols);
|
||||
cpu_backend_gemm::MatrixParams<float> lhs_params;
|
||||
lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
|
||||
lhs_params.cols = weights_shape.Dims(dims_count - 1);
|
||||
lhs_params.rows = FlatSizeSkipDim(weights_shape, dims_count - 1);
|
||||
lhs_params.cacheable = params.lhs_cacheable;
|
||||
lhs_params.cache_policy =
|
||||
cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
|
||||
cpu_backend_gemm::MatrixParams<float> dst_params;
|
||||
dst_params.order = cpu_backend_gemm::Order::kColMajor;
|
||||
dst_params.rows = output_shape.Dims(output_shape.DimensionsCount() - 1);
|
||||
@ -345,13 +347,15 @@ inline void FullyConnected(
|
||||
lhs_params.cols = filter_cols;
|
||||
lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
|
||||
lhs_params.zero_point = -filter_offset;
|
||||
lhs_params.cacheable = params.lhs_cacheable;
|
||||
lhs_params.cache_policy =
|
||||
cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
|
||||
cpu_backend_gemm::MatrixParams<uint8> rhs_params;
|
||||
rhs_params.rows = filter_cols;
|
||||
rhs_params.cols = batches;
|
||||
rhs_params.order = cpu_backend_gemm::Order::kColMajor;
|
||||
rhs_params.zero_point = -input_offset;
|
||||
rhs_params.cacheable = params.rhs_cacheable;
|
||||
rhs_params.cache_policy =
|
||||
cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
|
||||
cpu_backend_gemm::MatrixParams<uint8> dst_params;
|
||||
dst_params.rows = filter_rows;
|
||||
dst_params.cols = batches;
|
||||
@ -404,13 +408,15 @@ inline void FullyConnected(
|
||||
lhs_params.cols = accum_depth;
|
||||
lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
|
||||
lhs_params.zero_point = -filter_offset;
|
||||
lhs_params.cacheable = params.lhs_cacheable;
|
||||
lhs_params.cache_policy =
|
||||
cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
|
||||
cpu_backend_gemm::MatrixParams<uint8> rhs_params;
|
||||
rhs_params.rows = accum_depth;
|
||||
rhs_params.cols = batches;
|
||||
rhs_params.order = cpu_backend_gemm::Order::kColMajor;
|
||||
rhs_params.zero_point = -input_offset;
|
||||
rhs_params.cacheable = params.rhs_cacheable;
|
||||
rhs_params.cache_policy =
|
||||
cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
|
||||
cpu_backend_gemm::MatrixParams<int16> dst_params;
|
||||
dst_params.rows = output_depth;
|
||||
dst_params.cols = batches;
|
||||
|
Loading…
Reference in New Issue
Block a user