TFLITE_WITH_RUY_GEMV uses CustomGEMV for float
PiperOrigin-RevId: 292930831 Change-Id: I3786ba562af1cf5f3a8e000af4abac65696bf3a3
This commit is contained in:
parent
21f9f51abc
commit
05a122df52
@ -94,15 +94,19 @@ void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
|
|||||||
CpuBackendContext* context) {
|
CpuBackendContext* context) {
|
||||||
ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm");
|
ruy::profiler::ScopeLabel label("cpu_backend_gemm::Gemm");
|
||||||
ValidateParams(lhs_params, rhs_params, dst_params, params);
|
ValidateParams(lhs_params, rhs_params, dst_params, params);
|
||||||
#ifndef TFLITE_WITH_RUY_GEMV
|
bool do_custom_gemv = dst_params.cols == 1;
|
||||||
if (dst_params.cols == 1) {
|
#ifdef TFLITE_WITH_RUY_GEMV
|
||||||
|
// Prefer a Ruy GEMM to Custom GEMV unless we are doing float math.
|
||||||
|
// TODO(b/148692500): Add float GEMV kernels to Ruy.
|
||||||
|
do_custom_gemv = do_custom_gemv && std::is_floating_point<DstScalar>::value;
|
||||||
|
#endif
|
||||||
|
if (do_custom_gemv) {
|
||||||
// GEMV case: try a custom fast GEMV path.
|
// GEMV case: try a custom fast GEMV path.
|
||||||
if (detail::CustomGemv(lhs_params, lhs_data, rhs_params, rhs_data,
|
if (detail::CustomGemv(lhs_params, lhs_data, rhs_params, rhs_data,
|
||||||
dst_params, dst_data, params, context)) {
|
dst_params, dst_data, params, context)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
ruy::profiler::ScopeLabel label2("cpu_backend_gemm::Gemm: general GEMM");
|
ruy::profiler::ScopeLabel label2("cpu_backend_gemm::Gemm: general GEMM");
|
||||||
GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar,
|
GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar,
|
||||||
quantization_flavor>::Run(lhs_params, lhs_data, rhs_params, rhs_data,
|
quantization_flavor>::Run(lhs_params, lhs_data, rhs_params, rhs_data,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user