Let cpu_backend_gemm support all storage order combinations, unconditionally using ruy as the backend in combinations other than RowMajor*ColMajor->ColMajor, which were so far not supported. Ruy is different from other back-ends in that it supports all combinations as runtime parameters without a code size increase.

PiperOrigin-RevId: 323786939
Change-Id: Ib81abb5ca621a01cd8453a4a08b27601ad75c7dc
This commit is contained in:
Benoit Jacob 2020-07-29 07:57:51 -07:00 committed by TensorFlower Gardener
parent 9f1aadb48f
commit 59d53b7425
3 changed files with 36 additions and 20 deletions

View File

@ -95,9 +95,26 @@ void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
CpuBackendContext* context) { CpuBackendContext* context) {
ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm"); ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm");
ValidateParams(lhs_params, rhs_params, dst_params, params); ValidateParams(lhs_params, rhs_params, dst_params, params);
// In some cases we want to unconditionally use ruy as the backend, overriding
// the `tflite_with_ruy` setting and the platform default.
bool must_use_ruy = false;
if (context->use_caching()) { if (context->use_caching()) {
// Dispatch to backend that supports caching of prepacked weights // Only ruy supports caching of pre-packed matrices. Due to the large
// matrices. // performance impact in the cases where it's typically used, this overrides
// the default.
must_use_ruy = true;
}
if (lhs_params.order != Order::kRowMajor ||
rhs_params.order != Order::kColMajor ||
dst_params.order != Order::kColMajor) {
// ruy supports all 2^3=8 combinations of storage orders with comparable
// performance. In ruy, it's only a runtime switch. In other backends
// (gemmlowp, Eigen), storage orders are template parameters, supporting
// all 8 combinations would be up to a 8-fold code size increase, so we
// prefer to force usage of ruy in these cases.
must_use_ruy = true;
}
if (must_use_ruy) {
detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar, detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar,
quantization_flavor>::Run(lhs_params, lhs_data, quantization_flavor>::Run(lhs_params, lhs_data,
rhs_params, rhs_data, rhs_params, rhs_data,
@ -105,15 +122,18 @@ void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
params, context); params, context);
return; return;
} }
const bool do_custom_gemv = (dst_params.cols == 1); // If we did not choose to force usage of ruy above, then we may now consider
if (do_custom_gemv) { // using custom GEMV code for the matrix*vector cases.
// GEMV case: try a custom fast GEMV path. const bool try_custom_gemv = (dst_params.cols == 1);
if (try_custom_gemv) {
// GEMV case: try a custom fast GEMV path. It will return true if it
// actually handled it.
if (detail::CustomGemv(lhs_params, lhs_data, rhs_params, rhs_data, if (detail::CustomGemv(lhs_params, lhs_data, rhs_params, rhs_data,
dst_params, dst_data, params, context)) { dst_params, dst_data, params, context)) {
return; return;
} }
} }
ruy::profiler::ScopeLabel label2("cpu_backend_gemm::Gemm: general GEMM"); // Generic case: dispatch to any backend as a general GEMM.
GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar, GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar,
quantization_flavor>::Run(lhs_params, lhs_data, rhs_params, rhs_data, quantization_flavor>::Run(lhs_params, lhs_data, rhs_params, rhs_data,
dst_params, dst_data, params, context); dst_params, dst_data, params, context);

View File

@ -236,17 +236,6 @@ void ValidateParams(
(void)detail::ValidateTypes<LhsScalar, RhsScalar, AccumScalar, DstScalar, (void)detail::ValidateTypes<LhsScalar, RhsScalar, AccumScalar, DstScalar,
quantization_flavor>(); quantization_flavor>();
ValidateGemmParams(params); ValidateGemmParams(params);
// For now, Gemm only supports this particular combination of storage orders.
// Actually the generic ruy path already supports all combinations (with
// various performance penalties). On the other hand, gemmlowp and Eigen
// paths would require more source code and larger binary code to handle
// other combinations (because orders are template parameters in gemmlowp
// and Eigen). Since this is TFLite's own internal Gemm library, there is
// no point in supporting more than what TFlite currently uses, and that
// is for now this single combination.
TFLITE_DCHECK(lhs_params.order == Order::kRowMajor);
TFLITE_DCHECK(rhs_params.order == Order::kColMajor);
TFLITE_DCHECK(dst_params.order == Order::kColMajor);
} }
} // namespace cpu_backend_gemm } // namespace cpu_backend_gemm

View File

@ -389,8 +389,13 @@ void TestSomeGemm(int rows, int depth, int cols,
} }
MakeDeterministicPseudoRandomVector(rows * cols, &dst_data); MakeDeterministicPseudoRandomVector(rows * cols, &dst_data);
auto random_order = [&]() {
return random_engine() % 2 ? cpu_backend_gemm::Order::kRowMajor
: cpu_backend_gemm::Order::kColMajor;
};
MatrixParams<LhsScalar> lhs_params; MatrixParams<LhsScalar> lhs_params;
lhs_params.order = cpu_backend_gemm::Order::kRowMajor; lhs_params.order =
use_golden ? cpu_backend_gemm::Order::kRowMajor : random_order();
lhs_params.rows = rows; lhs_params.rows = rows;
lhs_params.cols = depth; lhs_params.cols = depth;
if (!std::is_floating_point<LhsScalar>::value) { if (!std::is_floating_point<LhsScalar>::value) {
@ -401,7 +406,8 @@ void TestSomeGemm(int rows, int depth, int cols,
} }
MatrixParams<RhsScalar> rhs_params; MatrixParams<RhsScalar> rhs_params;
rhs_params.order = cpu_backend_gemm::Order::kColMajor; rhs_params.order =
use_golden ? cpu_backend_gemm::Order::kColMajor : random_order();
rhs_params.rows = depth; rhs_params.rows = depth;
rhs_params.cols = cols; rhs_params.cols = cols;
if (!std::is_floating_point<RhsScalar>::value) { if (!std::is_floating_point<RhsScalar>::value) {
@ -412,7 +418,8 @@ void TestSomeGemm(int rows, int depth, int cols,
} }
MatrixParams<DstScalar> dst_params; MatrixParams<DstScalar> dst_params;
dst_params.order = cpu_backend_gemm::Order::kColMajor; dst_params.order =
use_golden ? cpu_backend_gemm::Order::kColMajor : random_order();
dst_params.rows = rows; dst_params.rows = rows;
dst_params.cols = cols; dst_params.cols = cols;
if (!std::is_floating_point<DstScalar>::value) { if (!std::is_floating_point<DstScalar>::value) {