TFLITE_WITH_RUY_GEMV uses CustomGEMV for float

PiperOrigin-RevId: 292930831
Change-Id: I3786ba562af1cf5f3a8e000af4abac65696bf3a3
This commit is contained in:
T.J. Alumbaugh 2020-02-03 09:02:06 -08:00 committed by TensorFlower Gardener
parent 21f9f51abc
commit 05a122df52

View File

@ -94,15 +94,19 @@ void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
CpuBackendContext* context) {
ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm");
ValidateParams(lhs_params, rhs_params, dst_params, params);
#ifndef TFLITE_WITH_RUY_GEMV
if (dst_params.cols == 1) {
bool do_custom_gemv = dst_params.cols == 1;
#ifdef TFLITE_WITH_RUY_GEMV
// Prefer a Ruy GEMM to Custom GEMV unless we are doing float math.
// TODO(b/148692500): Add float GEMV kernels to Ruy.
do_custom_gemv = do_custom_gemv && std::is_floating_point<DstScalar>::value;
#endif
if (do_custom_gemv) {
// GEMV case: try a custom fast GEMV path.
if (detail::CustomGemv(lhs_params, lhs_data, rhs_params, rhs_data,
dst_params, dst_data, params, context)) {
return;
}
}
#endif
ruy::profiler::ScopeLabel label2("cpu_backend_gemm::Gemm: general GEMM");
GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar,
quantization_flavor>::Run(lhs_params, lhs_data, rhs_params, rhs_data,