diff --git a/tensorflow/lite/kernels/cpu_backend_gemm.h b/tensorflow/lite/kernels/cpu_backend_gemm.h index 6ebbcb8c21e..16ccc14557f 100644 --- a/tensorflow/lite/kernels/cpu_backend_gemm.h +++ b/tensorflow/lite/kernels/cpu_backend_gemm.h @@ -94,15 +94,19 @@ void Gemm(const MatrixParams& lhs_params, const LhsScalar* lhs_data, CpuBackendContext* context) { ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm"); ValidateParams(lhs_params, rhs_params, dst_params, params); -#ifndef TFLITE_WITH_RUY_GEMV - if (dst_params.cols == 1) { + bool do_custom_gemv = dst_params.cols == 1; +#ifdef TFLITE_WITH_RUY_GEMV + // Prefer a Ruy GEMM to Custom GEMV unless we are doing float math. + // TODO(b/148692500): Add float GEMV kernels to Ruy. + do_custom_gemv = do_custom_gemv && std::is_floating_point::value; +#endif + if (do_custom_gemv) { // GEMV case: try a custom fast GEMV path. if (detail::CustomGemv(lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data, params, context)) { return; } } -#endif ruy::profiler::ScopeLabel label2("cpu_backend_gemm::Gemm: general GEMM"); GemmImpl::Run(lhs_params, lhs_data, rhs_params, rhs_data,