diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 9d3e5929d82..cb00d73adac 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -361,6 +361,7 @@ cc_library( "@ruy//ruy", "@ruy//ruy:matrix", "@ruy//ruy:path", + "@ruy//ruy:mul_params", "@ruy//ruy/profiler:instrumentation", # We only need to depend on gemmlowp and Eigen when tflite_with_ruy # is false, but putting these dependencies in a select() seems to diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h b/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h index 07ae2ff08b7..6a818834d30 100644 --- a/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h +++ b/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_ #include "ruy/matrix.h" // from @ruy +#include "ruy/mul_params.h" // from @ruy #include "ruy/ruy.h" // from @ruy #include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" @@ -57,23 +58,65 @@ void MakeRuyMatrix(const MatrixParams& params, DataPointer data_ptr, } } -template -void MakeRuyMulParams(const GemmParamsType& params, - RuySpecType* ruy_mul_params) { - // This validation has already been performed by the Gemm API entry point, - // but it doesn't hurt to test specifically this again here, where it's - // being used. - ValidateGemmParams(params); +// Floating-point case. +template +struct MakeRuyMulParamsImpl final { + static void Run( + const GemmParams& params, + ruy::MulParams* ruy_mul_params) { + static_assert(quantization_flavor == QuantizationFlavor::kFloatingPoint, + ""); + ruy_mul_params->set_bias(params.bias); + ruy_mul_params->set_clamp_min(params.clamp_min); + ruy_mul_params->set_clamp_max(params.clamp_max); + } +}; - ruy_mul_params->set_multiplier_fixedpoint(params.multiplier_fixedpoint); - ruy_mul_params->set_multiplier_exponent(params.multiplier_exponent); - ruy_mul_params->set_multiplier_fixedpoint_perchannel( - params.multiplier_fixedpoint_perchannel); - ruy_mul_params->set_multiplier_exponent_perchannel( - params.multiplier_exponent_perchannel); - ruy_mul_params->set_bias(params.bias); - ruy_mul_params->set_clamp_min(params.clamp_min); - ruy_mul_params->set_clamp_max(params.clamp_max); +// Integer-quantized case with destination type narrower than int32 +template +struct MakeRuyMulParamsImpl + final { + static void Run( + const GemmParams& params, + ruy::MulParams* ruy_mul_params) { + static_assert(sizeof(DstScalar) < sizeof(std::int32_t), ""); + if (quantization_flavor == + QuantizationFlavor::kIntegerWithUniformMultiplier) { + ruy_mul_params->set_multiplier_fixedpoint(params.multiplier_fixedpoint); + ruy_mul_params->set_multiplier_exponent(params.multiplier_exponent); + } + if (quantization_flavor == + QuantizationFlavor::kIntegerWithPerRowMultiplier) { + ruy_mul_params->set_multiplier_fixedpoint_perchannel( + params.multiplier_fixedpoint_perchannel); + ruy_mul_params->set_multiplier_exponent_perchannel( + params.multiplier_exponent_perchannel); + } + ruy_mul_params->set_bias(params.bias); + ruy_mul_params->set_clamp_min(params.clamp_min); + ruy_mul_params->set_clamp_max(params.clamp_max); + } +}; + +// Raw-integer case with destination type int32. +template +struct MakeRuyMulParamsImpl + final { + static void Run( + const GemmParams& params, + ruy::MulParams* ruy_mul_params) { + ruy_mul_params->set_bias(params.bias); + } +}; + +template +void MakeRuyMulParams( + const GemmParams& params, + ruy::MulParams* ruy_mul_params) { + MakeRuyMulParamsImpl::Run( + params, ruy_mul_params); } template