Let cpu_backend_gemm support all storage order combinations, unconditionally using ruy as the backend for combinations other than RowMajor*ColMajor->ColMajor, which were previously unsupported. Ruy differs from the other backends in that it supports all combinations as runtime parameters, without any code size increase.
PiperOrigin-RevId: 323786939
Change-Id: Ib81abb5ca621a01cd8453a4a08b27601ad75c7dc
parent 9f1aadb48f
commit 59d53b7425
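
For context, the sketch below (not part of the commit) shows the kind of call this change makes legal: a float Gemm whose storage orders are not the canonical RowMajor*ColMajor->ColMajor combination. The include paths, the default-constructed CpuBackendContext, and the float GemmParams instantiation are assumptions about the surrounding TFLite API and may differ across versions; the MatrixParams/Order/Gemm names themselves appear in the diff below.

#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"

void ExampleNonCanonicalOrderGemm() {
  using tflite::cpu_backend_gemm::Gemm;
  using tflite::cpu_backend_gemm::GemmParams;
  using tflite::cpu_backend_gemm::MatrixParams;
  using tflite::cpu_backend_gemm::Order;

  const int rows = 4, depth = 3, cols = 2;
  float lhs_data[rows * depth] = {};  // LHS may now be column-major.
  float rhs_data[depth * cols] = {};
  float dst_data[rows * cols] = {};

  MatrixParams<float> lhs_params;
  lhs_params.order = Order::kColMajor;  // Previously unsupported; now routed to ruy.
  lhs_params.rows = rows;
  lhs_params.cols = depth;

  MatrixParams<float> rhs_params;
  rhs_params.order = Order::kColMajor;
  rhs_params.rows = depth;
  rhs_params.cols = cols;

  MatrixParams<float> dst_params;
  dst_params.order = Order::kRowMajor;  // Any of the 2^3 combinations is accepted.
  dst_params.rows = rows;
  dst_params.cols = cols;

  GemmParams<float, float> params;  // Accumulator and destination both float.
  tflite::CpuBackendContext context;
  Gemm(lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data,
       params, &context);
}

Before this change, a call like this would have fired the three TFLITE_DCHECKs removed from ValidateParams in the third hunk below.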
@@ -95,9 +95,26 @@ void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
                  CpuBackendContext* context) {
   ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm");
   ValidateParams(lhs_params, rhs_params, dst_params, params);
+  // In some cases we want to unconditionally use ruy as the backend, overriding
+  // the `tflite_with_ruy` setting and the platform default.
+  bool must_use_ruy = false;
   if (context->use_caching()) {
-    // Dispatch to backend that supports caching of prepacked weights
-    // matrices.
+    // Only ruy supports caching of pre-packed matrices. Due to the large
+    // performance impact in the cases where it's typically used, this overrides
+    // the default.
+    must_use_ruy = true;
+  }
+  if (lhs_params.order != Order::kRowMajor ||
+      rhs_params.order != Order::kColMajor ||
+      dst_params.order != Order::kColMajor) {
+    // ruy supports all 2^3=8 combinations of storage orders with comparable
+    // performance. In ruy, it's only a runtime switch. In other backends
+    // (gemmlowp, Eigen), storage orders are template parameters; supporting
+    // all 8 combinations would mean up to an 8-fold code size increase, so we
+    // prefer to force usage of ruy in these cases.
+    must_use_ruy = true;
+  }
+  if (must_use_ruy) {
     detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar,
                              quantization_flavor>::Run(lhs_params, lhs_data,
                                                        rhs_params, rhs_data,
@@ -105,15 +122,18 @@ void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
                                                        params, context);
     return;
   }
-  const bool do_custom_gemv = (dst_params.cols == 1);
-  if (do_custom_gemv) {
-    // GEMV case: try a custom fast GEMV path.
+  // If we did not choose to force usage of ruy above, then we may now consider
+  // using custom GEMV code for the matrix*vector cases.
+  const bool try_custom_gemv = (dst_params.cols == 1);
+  if (try_custom_gemv) {
+    // GEMV case: try a custom fast GEMV path. It will return true if it
+    // actually handled it.
     if (detail::CustomGemv(lhs_params, lhs_data, rhs_params, rhs_data,
                            dst_params, dst_data, params, context)) {
       return;
     }
   }
-  ruy::profiler::ScopeLabel label2("cpu_backend_gemm::Gemm: general GEMM");
+  // Generic case: dispatch to any backend as a general GEMM.
   GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar,
            quantization_flavor>::Run(lhs_params, lhs_data, rhs_params, rhs_data,
                                      dst_params, dst_data, params, context);
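
Taken together, the two hunks above reorder Gemm's dispatch into three stages. A condensed, illustrative-only view follows; the enum and function names here are invented for exposition and are not part of the commit:

// Minimal sketch of the dispatch order that Gemm now implements.
enum class Backend { kRuy, kCustomGemv, kDefaultGemmImpl };

Backend ChooseBackend(bool use_caching, bool default_order_combination,
                      bool is_gemv, bool custom_gemv_can_handle) {
  // 1. Caching of prepacked matrices and non-default storage orders both
  //    force ruy, regardless of the tflite_with_ruy setting.
  if (use_caching || !default_order_combination) return Backend::kRuy;
  // 2. Otherwise, matrix*vector shapes may take the custom GEMV fast path.
  if (is_gemv && custom_gemv_can_handle) return Backend::kCustomGemv;
  // 3. Everything else goes to the platform-default general GEMM backend.
  return Backend::kDefaultGemmImpl;
}

Note that only stage 3 still respects the compile-time backend choice; caching and non-default orders override it.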
@@ -236,17 +236,6 @@ void ValidateParams(
   (void)detail::ValidateTypes<LhsScalar, RhsScalar, AccumScalar, DstScalar,
                               quantization_flavor>();
   ValidateGemmParams(params);
-  // For now, Gemm only supports this particular combination of storage orders.
-  // Actually the generic ruy path already supports all combinations (with
-  // various performance penalties). On the other hand, gemmlowp and Eigen
-  // paths would require more source code and larger binary code to handle
-  // other combinations (because orders are template parameters in gemmlowp
-  // and Eigen). Since this is TFLite's own internal Gemm library, there is
-  // no point in supporting more than what TFlite currently uses, and that
-  // is for now this single combination.
-  TFLITE_DCHECK(lhs_params.order == Order::kRowMajor);
-  TFLITE_DCHECK(rhs_params.order == Order::kColMajor);
-  TFLITE_DCHECK(dst_params.order == Order::kColMajor);
 }

 }  // namespace cpu_backend_gemm
@@ -389,8 +389,13 @@ void TestSomeGemm(int rows, int depth, int cols,
   }
   MakeDeterministicPseudoRandomVector(rows * cols, &dst_data);

+  auto random_order = [&]() {
+    return random_engine() % 2 ? cpu_backend_gemm::Order::kRowMajor
+                               : cpu_backend_gemm::Order::kColMajor;
+  };
   MatrixParams<LhsScalar> lhs_params;
-  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
+  lhs_params.order =
+      use_golden ? cpu_backend_gemm::Order::kRowMajor : random_order();
   lhs_params.rows = rows;
   lhs_params.cols = depth;
   if (!std::is_floating_point<LhsScalar>::value) {
@@ -401,7 +406,8 @@ void TestSomeGemm(int rows, int depth, int cols,
   }

   MatrixParams<RhsScalar> rhs_params;
-  rhs_params.order = cpu_backend_gemm::Order::kColMajor;
+  rhs_params.order =
+      use_golden ? cpu_backend_gemm::Order::kColMajor : random_order();
   rhs_params.rows = depth;
   rhs_params.cols = cols;
   if (!std::is_floating_point<RhsScalar>::value) {
@@ -412,7 +418,8 @@ void TestSomeGemm(int rows, int depth, int cols,
   }

   MatrixParams<DstScalar> dst_params;
-  dst_params.order = cpu_backend_gemm::Order::kColMajor;
+  dst_params.order =
+      use_golden ? cpu_backend_gemm::Order::kColMajor : random_order();
   dst_params.rows = rows;
   dst_params.cols = cols;
   if (!std::is_floating_point<DstScalar>::value) {
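
The test-side pattern in the last three hunks is worth calling out: golden runs (use_golden) keep the canonical RowMajor*ColMajor->ColMajor orders so results stay comparable with the stored golden data, while non-golden runs randomize each of the three orders independently, so repeated runs cover all 2^3 = 8 combinations now accepted by Gemm. A standalone sketch of the same coin-flip helper, with an assumed include path and a hypothetical engine parameter in place of the test fixture's random_engine:

#include <random>

#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"  // assumed path

// Standalone version of the test's random_order lambda (illustrative only).
tflite::cpu_backend_gemm::Order RandomOrder(std::mt19937& engine) {
  return engine() % 2 ? tflite::cpu_backend_gemm::Order::kRowMajor
                      : tflite::cpu_backend_gemm::Order::kColMajor;
}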