Let cpu_backend_gemm support all storage order combinations, unconditionally using ruy as the backend in combinations other than RowMajor*ColMajor->ColMajor, which were so far not supported. Ruy is different from other back-ends in that it supports all combinations as runtime parameters without a code size increase.
PiperOrigin-RevId: 323786939 Change-Id: Ib81abb5ca621a01cd8453a4a08b27601ad75c7dc
This commit is contained in:
parent
9f1aadb48f
commit
59d53b7425
@ -95,9 +95,26 @@ void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
|
||||
CpuBackendContext* context) {
|
||||
ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm");
|
||||
ValidateParams(lhs_params, rhs_params, dst_params, params);
|
||||
// In some cases we want to unconditionally use ruy as the backend, overriding
|
||||
// the `tflite_with_ruy` setting and the platform default.
|
||||
bool must_use_ruy = false;
|
||||
if (context->use_caching()) {
|
||||
// Dispatch to backend that supports caching of prepacked weights
|
||||
// matrices.
|
||||
// Only ruy supports caching of pre-packed matrices. Due to the large
|
||||
// performance impact in the cases where it's typically used, this overrides
|
||||
// the default.
|
||||
must_use_ruy = true;
|
||||
}
|
||||
if (lhs_params.order != Order::kRowMajor ||
|
||||
rhs_params.order != Order::kColMajor ||
|
||||
dst_params.order != Order::kColMajor) {
|
||||
// ruy supports all 2^3=8 combinations of storage orders with comparable
|
||||
// performance. In ruy, it's only a runtime switch. In other backends
|
||||
// (gemmlowp, Eigen), storage orders are template parameters, supporting
|
||||
// all 8 combinations would be up to a 8-fold code size increase, so we
|
||||
// prefer to force usage of ruy in these cases.
|
||||
must_use_ruy = true;
|
||||
}
|
||||
if (must_use_ruy) {
|
||||
detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar,
|
||||
quantization_flavor>::Run(lhs_params, lhs_data,
|
||||
rhs_params, rhs_data,
|
||||
@ -105,15 +122,18 @@ void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
|
||||
params, context);
|
||||
return;
|
||||
}
|
||||
const bool do_custom_gemv = (dst_params.cols == 1);
|
||||
if (do_custom_gemv) {
|
||||
// GEMV case: try a custom fast GEMV path.
|
||||
// If we did not choose to force usage of ruy above, then we may now consider
|
||||
// using custom GEMV code for the matrix*vector cases.
|
||||
const bool try_custom_gemv = (dst_params.cols == 1);
|
||||
if (try_custom_gemv) {
|
||||
// GEMV case: try a custom fast GEMV path. It will return true if it
|
||||
// actually handled it.
|
||||
if (detail::CustomGemv(lhs_params, lhs_data, rhs_params, rhs_data,
|
||||
dst_params, dst_data, params, context)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
ruy::profiler::ScopeLabel label2("cpu_backend_gemm::Gemm: general GEMM");
|
||||
// Generic case: dispatch to any backend as a general GEMM.
|
||||
GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar,
|
||||
quantization_flavor>::Run(lhs_params, lhs_data, rhs_params, rhs_data,
|
||||
dst_params, dst_data, params, context);
|
||||
|
@ -236,17 +236,6 @@ void ValidateParams(
|
||||
(void)detail::ValidateTypes<LhsScalar, RhsScalar, AccumScalar, DstScalar,
|
||||
quantization_flavor>();
|
||||
ValidateGemmParams(params);
|
||||
// For now, Gemm only supports this particular combination of storage orders.
|
||||
// Actually the generic ruy path already supports all combinations (with
|
||||
// various performance penalties). On the other hand, gemmlowp and Eigen
|
||||
// paths would require more source code and larger binary code to handle
|
||||
// other combinations (because orders are template parameters in gemmlowp
|
||||
// and Eigen). Since this is TFLite's own internal Gemm library, there is
|
||||
// no point in supporting more than what TFlite currently uses, and that
|
||||
// is for now this single combination.
|
||||
TFLITE_DCHECK(lhs_params.order == Order::kRowMajor);
|
||||
TFLITE_DCHECK(rhs_params.order == Order::kColMajor);
|
||||
TFLITE_DCHECK(dst_params.order == Order::kColMajor);
|
||||
}
|
||||
|
||||
} // namespace cpu_backend_gemm
|
||||
|
@ -389,8 +389,13 @@ void TestSomeGemm(int rows, int depth, int cols,
|
||||
}
|
||||
MakeDeterministicPseudoRandomVector(rows * cols, &dst_data);
|
||||
|
||||
auto random_order = [&]() {
|
||||
return random_engine() % 2 ? cpu_backend_gemm::Order::kRowMajor
|
||||
: cpu_backend_gemm::Order::kColMajor;
|
||||
};
|
||||
MatrixParams<LhsScalar> lhs_params;
|
||||
lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
|
||||
lhs_params.order =
|
||||
use_golden ? cpu_backend_gemm::Order::kRowMajor : random_order();
|
||||
lhs_params.rows = rows;
|
||||
lhs_params.cols = depth;
|
||||
if (!std::is_floating_point<LhsScalar>::value) {
|
||||
@ -401,7 +406,8 @@ void TestSomeGemm(int rows, int depth, int cols,
|
||||
}
|
||||
|
||||
MatrixParams<RhsScalar> rhs_params;
|
||||
rhs_params.order = cpu_backend_gemm::Order::kColMajor;
|
||||
rhs_params.order =
|
||||
use_golden ? cpu_backend_gemm::Order::kColMajor : random_order();
|
||||
rhs_params.rows = depth;
|
||||
rhs_params.cols = cols;
|
||||
if (!std::is_floating_point<RhsScalar>::value) {
|
||||
@ -412,7 +418,8 @@ void TestSomeGemm(int rows, int depth, int cols,
|
||||
}
|
||||
|
||||
MatrixParams<DstScalar> dst_params;
|
||||
dst_params.order = cpu_backend_gemm::Order::kColMajor;
|
||||
dst_params.order =
|
||||
use_golden ? cpu_backend_gemm::Order::kColMajor : random_order();
|
||||
dst_params.rows = rows;
|
||||
dst_params.cols = cols;
|
||||
if (!std::is_floating_point<DstScalar>::value) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user