TFLITE_WITH_RUY_GEMV uses CustomGEMV for float
PiperOrigin-RevId: 292930831 Change-Id: I3786ba562af1cf5f3a8e000af4abac65696bf3a3
This commit is contained in:
parent
21f9f51abc
commit
05a122df52
@ -94,15 +94,19 @@ void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
|
||||
CpuBackendContext* context) {
|
||||
ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm");
|
||||
ValidateParams(lhs_params, rhs_params, dst_params, params);
|
||||
#ifndef TFLITE_WITH_RUY_GEMV
|
||||
if (dst_params.cols == 1) {
|
||||
bool do_custom_gemv = dst_params.cols == 1;
|
||||
#ifdef TFLITE_WITH_RUY_GEMV
|
||||
// Prefer a Ruy GEMM to Custom GEMV unless we are doing float math.
|
||||
// TODO(b/148692500): Add float GEMV kernels to Ruy.
|
||||
do_custom_gemv = do_custom_gemv && std::is_floating_point<DstScalar>::value;
|
||||
#endif
|
||||
if (do_custom_gemv) {
|
||||
// GEMV case: try a custom fast GEMV path.
|
||||
if (detail::CustomGemv(lhs_params, lhs_data, rhs_params, rhs_data,
|
||||
dst_params, dst_data, params, context)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
ruy::profiler::ScopeLabel label2("cpu_backend_gemm::Gemm: general GEMM");
|
||||
GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar,
|
||||
quantization_flavor>::Run(lhs_params, lhs_data, rhs_params, rhs_data,
|
||||
|
Loading…
Reference in New Issue
Block a user