diff --git a/tensorflow/lite/kernels/cpu_backend_gemm.h b/tensorflow/lite/kernels/cpu_backend_gemm.h
index a95c4d15a82..14ff571e7da 100644
--- a/tensorflow/lite/kernels/cpu_backend_gemm.h
+++ b/tensorflow/lite/kernels/cpu_backend_gemm.h
@@ -95,9 +95,26 @@ void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
           CpuBackendContext* context) {
   ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm");
   ValidateParams(lhs_params, rhs_params, dst_params, params);
+  // In some cases we want to unconditionally use ruy as the backend, overriding
+  // the `tflite_with_ruy` setting and the platform default.
+  bool must_use_ruy = false;
   if (context->use_caching()) {
-    // Dispatch to backend that supports caching of prepacked weights
-    // matrices.
+    // Only ruy supports caching of pre-packed matrices. Due to the large
+    // performance impact in the cases where it's typically used, this
+    // overrides the default.
+    must_use_ruy = true;
+  }
+  if (lhs_params.order != Order::kRowMajor ||
+      rhs_params.order != Order::kColMajor ||
+      dst_params.order != Order::kColMajor) {
+    // ruy supports all 2^3=8 combinations of storage orders with comparable
+    // performance. In ruy, it's only a runtime switch. In the other backends
+    // (gemmlowp, Eigen), storage orders are template parameters, so supporting
+    // all 8 combinations would mean up to an 8-fold code size increase. We
+    // therefore prefer to force usage of ruy in these cases.
+    must_use_ruy = true;
+  }
+  if (must_use_ruy) {
     detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar,
                              quantization_flavor>::Run(lhs_params, lhs_data,
                                                        rhs_params, rhs_data,
@@ -105,15 +122,18 @@ void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
                                                        params, context);
     return;
   }
-  const bool do_custom_gemv = (dst_params.cols == 1);
-  if (do_custom_gemv) {
-    // GEMV case: try a custom fast GEMV path.
+  // If we did not choose to force usage of ruy above, we may now consider
+  // using custom GEMV code for the matrix*vector cases.
+  const bool try_custom_gemv = (dst_params.cols == 1);
+  if (try_custom_gemv) {
+    // GEMV case: try a custom fast GEMV path. It returns true if it
+    // actually handled it.
   if (detail::CustomGemv(lhs_params, lhs_data, rhs_params, rhs_data,
                          dst_params, dst_data, params, context)) {
     return;
   }
   }
-  ruy::profiler::ScopeLabel label2("cpu_backend_gemm::Gemm: general GEMM");
+  // Generic case: dispatch to any backend as a general GEMM.
   GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar,
            quantization_flavor>::Run(lhs_params, lhs_data, rhs_params, rhs_data,
                                      dst_params, dst_data, params, context);
diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_params.h b/tensorflow/lite/kernels/cpu_backend_gemm_params.h
index 0040f40cd50..ef06d97331e 100644
--- a/tensorflow/lite/kernels/cpu_backend_gemm_params.h
+++ b/tensorflow/lite/kernels/cpu_backend_gemm_params.h
@@ -236,17 +236,6 @@ void ValidateParams(
   (void)detail::ValidateTypes<LhsScalar, RhsScalar, AccumScalar, DstScalar>();
   ValidateGemmParams(params);
-  // For now, Gemm only supports this particular combination of storage orders.
-  // Actually the generic ruy path already supports all combinations (with
-  // various performance penalties). On the other hand, gemmlowp and Eigen
-  // paths would require more source code and larger binary code to handle
-  // other combinations (because orders are template parameters in gemmlowp
-  // and Eigen). Since this is TFLite's own internal Gemm library, there is
-  // no point in supporting more than what TFlite currently uses, and that
-  // is for now this single combination.
-  TFLITE_DCHECK(lhs_params.order == Order::kRowMajor);
-  TFLITE_DCHECK(rhs_params.order == Order::kColMajor);
-  TFLITE_DCHECK(dst_params.order == Order::kColMajor);
 }
 
 }  // namespace cpu_backend_gemm
 
diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_test.cc b/tensorflow/lite/kernels/cpu_backend_gemm_test.cc
index d79d1357696..521e7bb03fd 100644
--- a/tensorflow/lite/kernels/cpu_backend_gemm_test.cc
+++ b/tensorflow/lite/kernels/cpu_backend_gemm_test.cc
@@ -389,8 +389,13 @@ void TestSomeGemm(int rows, int depth, int cols,
   }
   MakeDeterministicPseudoRandomVector(rows * cols, &dst_data);
 
+  auto random_order = [&]() {
+    return random_engine() % 2 ? cpu_backend_gemm::Order::kRowMajor
+                               : cpu_backend_gemm::Order::kColMajor;
+  };
   MatrixParams<LhsScalar> lhs_params;
-  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
+  lhs_params.order =
+      use_golden ? cpu_backend_gemm::Order::kRowMajor : random_order();
   lhs_params.rows = rows;
   lhs_params.cols = depth;
   if (!std::is_floating_point<LhsScalar>::value) {
@@ -401,7 +406,8 @@
   }
 
   MatrixParams<RhsScalar> rhs_params;
-  rhs_params.order = cpu_backend_gemm::Order::kColMajor;
+  rhs_params.order =
+      use_golden ? cpu_backend_gemm::Order::kColMajor : random_order();
   rhs_params.rows = depth;
   rhs_params.cols = cols;
   if (!std::is_floating_point<RhsScalar>::value) {
@@ -412,7 +418,8 @@
   }
 
   MatrixParams<DstScalar> dst_params;
-  dst_params.order = cpu_backend_gemm::Order::kColMajor;
+  dst_params.order =
+      use_golden ? cpu_backend_gemm::Order::kColMajor : random_order();
   dst_params.rows = rows;
   dst_params.cols = cols;
   if (!std::is_floating_point<DstScalar>::value) {
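
For illustration, here is a minimal sketch (not part of the patch) of a float Gemm call that exercises one of the newly allowed storage-order combinations: with a column-major LHS, the new must_use_ruy logic routes the call to ruy instead of gemmlowp/Eigen. The function name, sizes, and buffer contents below are made up for the example; MatrixParams, GemmParams, Order, and Gemm are the real APIs from cpu_backend_gemm.h.

#include <vector>

#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm.h"

// Hypothetical caller, for illustration only: a float GEMM whose LHS is
// column-major. Since lhs_params.order != kRowMajor, Gemm() sets
// must_use_ruy and dispatches to detail::GemmImplUsingRuy.
void ExampleColMajorLhsGemm(tflite::CpuBackendContext* context) {
  using tflite::cpu_backend_gemm::Gemm;
  using tflite::cpu_backend_gemm::GemmParams;
  using tflite::cpu_backend_gemm::MatrixParams;
  using tflite::cpu_backend_gemm::Order;

  constexpr int kRows = 4, kDepth = 8, kCols = 2;  // Arbitrary small sizes.
  std::vector<float> lhs(kRows * kDepth, 1.0f);    // kRows x kDepth, all ones.
  std::vector<float> rhs(kDepth * kCols, 1.0f);    // kDepth x kCols, all ones.
  std::vector<float> dst(kRows * kCols, 0.0f);     // kRows x kCols result.

  MatrixParams<float> lhs_params;
  lhs_params.order = Order::kColMajor;  // Previously DCHECK-ed to be kRowMajor.
  lhs_params.rows = kRows;
  lhs_params.cols = kDepth;

  MatrixParams<float> rhs_params;
  rhs_params.order = Order::kColMajor;
  rhs_params.rows = kDepth;
  rhs_params.cols = kCols;

  MatrixParams<float> dst_params;
  dst_params.order = Order::kColMajor;
  dst_params.rows = kRows;
  dst_params.cols = kCols;

  GemmParams<float, float> params;  // Defaults: no bias, unbounded clamp range.
  Gemm(lhs_params, lhs.data(), rhs_params, rhs.data(), dst_params, dst.data(),
       params, context);
  // With all-ones inputs, every dst entry is now kDepth == 8.0f.
}

A call like this previously died on the TFLITE_DCHECKs removed from ValidateParams above. It now runs, paying only ruy's runtime cost for that order combination, rather than requiring extra template instantiations of the gemmlowp and Eigen paths.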