Use the new ruy API for caching constant matrices.
PiperOrigin-RevId: 309220053 Change-Id: I0d1f0040cd84aade4c9f9e84c240909865e818bc
This commit is contained in:
parent
09a1c84cfe
commit
3b3d320385
@ -338,6 +338,7 @@ cc_library(
|
|||||||
# Depend on ruy regardless of `tflite_with_ruy`. See the comment in
|
# Depend on ruy regardless of `tflite_with_ruy`. See the comment in
|
||||||
# cpu_backend_gemm.h about why ruy is the generic path.
|
# cpu_backend_gemm.h about why ruy is the generic path.
|
||||||
"@ruy//ruy",
|
"@ruy//ruy",
|
||||||
|
"@ruy//ruy:matrix",
|
||||||
"@ruy//ruy:path",
|
"@ruy//ruy:path",
|
||||||
"@ruy//ruy/profiler:instrumentation",
|
"@ruy//ruy/profiler:instrumentation",
|
||||||
# We only need to depend on gemmlowp and Eigen when tflite_with_ruy
|
# We only need to depend on gemmlowp and Eigen when tflite_with_ruy
|
||||||
@ -525,6 +526,7 @@ cc_library(
|
|||||||
visibility = ["//visibility:private"],
|
visibility = ["//visibility:private"],
|
||||||
deps = [
|
deps = [
|
||||||
":cpu_backend_context",
|
":cpu_backend_context",
|
||||||
|
":cpu_backend_gemm",
|
||||||
":cpu_backend_threadpool",
|
":cpu_backend_threadpool",
|
||||||
":eigen_support",
|
":eigen_support",
|
||||||
":kernel_util",
|
":kernel_util",
|
||||||
|
@ -55,9 +55,6 @@ CpuBackendContext::CpuBackendContext()
|
|||||||
ruy_context_(new ruy::Context),
|
ruy_context_(new ruy::Context),
|
||||||
gemmlowp_context_(new gemmlowp::GemmContext) {
|
gemmlowp_context_(new gemmlowp::GemmContext) {
|
||||||
SetMaxNumThreads(kDefaultNumThreadpoolThreads);
|
SetMaxNumThreads(kDefaultNumThreadpoolThreads);
|
||||||
#ifdef TFLITE_WITH_RUY_GEMV
|
|
||||||
ruy_context_->set_cache_policy(ruy::CachePolicy::kCacheLHSOnNarrowMul);
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
CpuBackendContext::~CpuBackendContext() {}
|
CpuBackendContext::~CpuBackendContext() {}
|
||||||
|
@ -29,6 +29,17 @@ namespace cpu_backend_gemm {
|
|||||||
// Matrix storage order: column-major or row-major.
|
// Matrix storage order: column-major or row-major.
|
||||||
enum class Order { kColMajor, kRowMajor };
|
enum class Order { kColMajor, kRowMajor };
|
||||||
|
|
||||||
|
enum class CachePolicy : std::uint8_t {
|
||||||
|
kNeverCache,
|
||||||
|
kCacheIfLargeSpeedup,
|
||||||
|
kAlwaysCache,
|
||||||
|
};
|
||||||
|
|
||||||
|
inline CachePolicy DefaultCachePolicy(bool is_constant_data) {
|
||||||
|
return is_constant_data ? CachePolicy::kCacheIfLargeSpeedup
|
||||||
|
: CachePolicy::kNeverCache;
|
||||||
|
}
|
||||||
|
|
||||||
// MatrixParams encapsulates the parameters that Gemm needs about each
|
// MatrixParams encapsulates the parameters that Gemm needs about each
|
||||||
// matrix, besides the buffer data pointer.
|
// matrix, besides the buffer data pointer.
|
||||||
// Compare to ruy::Matrix, which also encapsulates the data pointer.
|
// Compare to ruy::Matrix, which also encapsulates the data pointer.
|
||||||
@ -47,10 +58,13 @@ struct MatrixParams {
|
|||||||
// The zero_point, i.e. which Scalar value is to be interpreted as zero.
|
// The zero_point, i.e. which Scalar value is to be interpreted as zero.
|
||||||
// When Scalar is floating-point, this must be 0.
|
// When Scalar is floating-point, this must be 0.
|
||||||
Scalar zero_point = 0;
|
Scalar zero_point = 0;
|
||||||
// Indicate whether the underlying data will remain unchanged for
|
// When the data pointed to by this matrix is constant data, so that it is
|
||||||
// some period of time. Defaults to false, but should be set to true
|
// valid to assume that equality of pointers implies equality of data,
|
||||||
// for unchanging data (e.g. weights buffers in many cases)
|
// a CachePolicy may be used instead of the default kNeverCache,
|
||||||
bool cacheable = false;
|
// which will enable ruy to take advantage of this constancy of the data to
|
||||||
|
// cache the packing work, which can be a large speedup in matrix*vector
|
||||||
|
// and other narrow shapes.
|
||||||
|
CachePolicy cache_policy = CachePolicy::kNeverCache;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Enumeration of broad categories of Gemm.
|
// Enumeration of broad categories of Gemm.
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_
|
#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_
|
||||||
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_
|
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_
|
||||||
|
|
||||||
|
#include "ruy/matrix.h" // from @ruy
|
||||||
#include "ruy/path.h" // from @ruy
|
#include "ruy/path.h" // from @ruy
|
||||||
#include "ruy/ruy.h" // from @ruy
|
#include "ruy/ruy.h" // from @ruy
|
||||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||||
@ -25,6 +26,20 @@ namespace tflite {
|
|||||||
namespace cpu_backend_gemm {
|
namespace cpu_backend_gemm {
|
||||||
namespace detail {
|
namespace detail {
|
||||||
|
|
||||||
|
inline ruy::CachePolicy ToRuyCachePolicy(CachePolicy cache_policy) {
|
||||||
|
switch (cache_policy) {
|
||||||
|
case CachePolicy::kNeverCache:
|
||||||
|
return ruy::CachePolicy::kNeverCache;
|
||||||
|
case CachePolicy::kCacheIfLargeSpeedup:
|
||||||
|
return ruy::CachePolicy::kCacheIfLargeSpeedup;
|
||||||
|
case CachePolicy::kAlwaysCache:
|
||||||
|
return ruy::CachePolicy::kAlwaysCache;
|
||||||
|
default:
|
||||||
|
TFLITE_DCHECK(false);
|
||||||
|
return ruy::CachePolicy::kNeverCache;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename Scalar, typename DataPointer>
|
template <typename Scalar, typename DataPointer>
|
||||||
void MakeRuyMatrix(const MatrixParams<Scalar>& params, DataPointer data_ptr,
|
void MakeRuyMatrix(const MatrixParams<Scalar>& params, DataPointer data_ptr,
|
||||||
ruy::Matrix<Scalar>* dst) {
|
ruy::Matrix<Scalar>* dst) {
|
||||||
@ -38,7 +53,7 @@ void MakeRuyMatrix(const MatrixParams<Scalar>& params, DataPointer data_ptr,
|
|||||||
dst->set_data(data_ptr);
|
dst->set_data(data_ptr);
|
||||||
dst->set_zero_point(params.zero_point);
|
dst->set_zero_point(params.zero_point);
|
||||||
#ifdef TFLITE_WITH_RUY_GEMV
|
#ifdef TFLITE_WITH_RUY_GEMV
|
||||||
dst->set_cacheable(params.cacheable);
|
dst->set_cache_policy(ToRuyCachePolicy(params.cache_policy));
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||||
#include "tensorflow/lite/c/common.h"
|
#include "tensorflow/lite/c/common.h"
|
||||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||||
|
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
|
||||||
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
|
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
|
||||||
#include "tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h"
|
#include "tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h"
|
||||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||||
|
@ -1042,7 +1042,7 @@ void NeonCpuBackendGemm(const int8_t* input, const int32_t* bias,
|
|||||||
lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
|
lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
|
||||||
lhs_params.rows = n_output;
|
lhs_params.rows = n_output;
|
||||||
lhs_params.cols = n_input;
|
lhs_params.cols = n_input;
|
||||||
lhs_params.cacheable = true;
|
lhs_params.cache_policy = cpu_backend_gemm::CachePolicy::kCacheIfLargeSpeedup;
|
||||||
|
|
||||||
MatrixParams<int8_t> rhs_params;
|
MatrixParams<int8_t> rhs_params;
|
||||||
rhs_params.order = cpu_backend_gemm::Order::kColMajor;
|
rhs_params.order = cpu_backend_gemm::Order::kColMajor;
|
||||||
|
@ -286,13 +286,15 @@ inline void FullyConnected(
|
|||||||
rhs_params.order = cpu_backend_gemm::Order::kColMajor;
|
rhs_params.order = cpu_backend_gemm::Order::kColMajor;
|
||||||
rhs_params.rows = input_rows;
|
rhs_params.rows = input_rows;
|
||||||
rhs_params.cols = input_shape.FlatSize() / input_rows;
|
rhs_params.cols = input_shape.FlatSize() / input_rows;
|
||||||
rhs_params.cacheable = params.rhs_cacheable;
|
rhs_params.cache_policy =
|
||||||
|
cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
|
||||||
TFLITE_DCHECK_EQ(input_shape.FlatSize(), rhs_params.rows * rhs_params.cols);
|
TFLITE_DCHECK_EQ(input_shape.FlatSize(), rhs_params.rows * rhs_params.cols);
|
||||||
cpu_backend_gemm::MatrixParams<float> lhs_params;
|
cpu_backend_gemm::MatrixParams<float> lhs_params;
|
||||||
lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
|
lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
|
||||||
lhs_params.cols = weights_shape.Dims(dims_count - 1);
|
lhs_params.cols = weights_shape.Dims(dims_count - 1);
|
||||||
lhs_params.rows = FlatSizeSkipDim(weights_shape, dims_count - 1);
|
lhs_params.rows = FlatSizeSkipDim(weights_shape, dims_count - 1);
|
||||||
lhs_params.cacheable = params.lhs_cacheable;
|
lhs_params.cache_policy =
|
||||||
|
cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
|
||||||
cpu_backend_gemm::MatrixParams<float> dst_params;
|
cpu_backend_gemm::MatrixParams<float> dst_params;
|
||||||
dst_params.order = cpu_backend_gemm::Order::kColMajor;
|
dst_params.order = cpu_backend_gemm::Order::kColMajor;
|
||||||
dst_params.rows = output_shape.Dims(output_shape.DimensionsCount() - 1);
|
dst_params.rows = output_shape.Dims(output_shape.DimensionsCount() - 1);
|
||||||
@ -345,13 +347,15 @@ inline void FullyConnected(
|
|||||||
lhs_params.cols = filter_cols;
|
lhs_params.cols = filter_cols;
|
||||||
lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
|
lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
|
||||||
lhs_params.zero_point = -filter_offset;
|
lhs_params.zero_point = -filter_offset;
|
||||||
lhs_params.cacheable = params.lhs_cacheable;
|
lhs_params.cache_policy =
|
||||||
|
cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
|
||||||
cpu_backend_gemm::MatrixParams<uint8> rhs_params;
|
cpu_backend_gemm::MatrixParams<uint8> rhs_params;
|
||||||
rhs_params.rows = filter_cols;
|
rhs_params.rows = filter_cols;
|
||||||
rhs_params.cols = batches;
|
rhs_params.cols = batches;
|
||||||
rhs_params.order = cpu_backend_gemm::Order::kColMajor;
|
rhs_params.order = cpu_backend_gemm::Order::kColMajor;
|
||||||
rhs_params.zero_point = -input_offset;
|
rhs_params.zero_point = -input_offset;
|
||||||
rhs_params.cacheable = params.rhs_cacheable;
|
rhs_params.cache_policy =
|
||||||
|
cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
|
||||||
cpu_backend_gemm::MatrixParams<uint8> dst_params;
|
cpu_backend_gemm::MatrixParams<uint8> dst_params;
|
||||||
dst_params.rows = filter_rows;
|
dst_params.rows = filter_rows;
|
||||||
dst_params.cols = batches;
|
dst_params.cols = batches;
|
||||||
@ -404,13 +408,15 @@ inline void FullyConnected(
|
|||||||
lhs_params.cols = accum_depth;
|
lhs_params.cols = accum_depth;
|
||||||
lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
|
lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
|
||||||
lhs_params.zero_point = -filter_offset;
|
lhs_params.zero_point = -filter_offset;
|
||||||
lhs_params.cacheable = params.lhs_cacheable;
|
lhs_params.cache_policy =
|
||||||
|
cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
|
||||||
cpu_backend_gemm::MatrixParams<uint8> rhs_params;
|
cpu_backend_gemm::MatrixParams<uint8> rhs_params;
|
||||||
rhs_params.rows = accum_depth;
|
rhs_params.rows = accum_depth;
|
||||||
rhs_params.cols = batches;
|
rhs_params.cols = batches;
|
||||||
rhs_params.order = cpu_backend_gemm::Order::kColMajor;
|
rhs_params.order = cpu_backend_gemm::Order::kColMajor;
|
||||||
rhs_params.zero_point = -input_offset;
|
rhs_params.zero_point = -input_offset;
|
||||||
rhs_params.cacheable = params.rhs_cacheable;
|
rhs_params.cache_policy =
|
||||||
|
cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
|
||||||
cpu_backend_gemm::MatrixParams<int16> dst_params;
|
cpu_backend_gemm::MatrixParams<int16> dst_params;
|
||||||
dst_params.rows = output_depth;
|
dst_params.rows = output_depth;
|
||||||
dst_params.cols = batches;
|
dst_params.cols = batches;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user