diff --git a/tensorflow/lite/kernels/cpu_backend_gemm.h b/tensorflow/lite/kernels/cpu_backend_gemm.h
index b165d1118aa..e004bb88241 100644
--- a/tensorflow/lite/kernels/cpu_backend_gemm.h
+++ b/tensorflow/lite/kernels/cpu_backend_gemm.h
@@ -36,60 +36,65 @@ namespace cpu_backend_gemm {
 */
 
 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
-          typename DstScalar>
-struct GemmImpl
-    : detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar> {
-};
+          typename DstScalar, QuantizationFlavor quantization_flavor>
+struct GemmImpl : detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar,
+                                           DstScalar, quantization_flavor> {};
 
 #ifndef TFLITE_WITH_RUY
 
 /* Specializations using gemmlowp */
 
-template <typename SrcScalar, typename DstScalar>
-struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, DstScalar>
+template <typename SrcScalar, typename DstScalar,
+          QuantizationFlavor quantization_flavor>
+struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, DstScalar,
+                quantization_flavor>
     : detail::GemmImplUsingGemmlowp<SrcScalar, SrcScalar, std::int32_t,
-                                    DstScalar> {};
+                                    DstScalar, quantization_flavor> {};
 
 // When SrcScalar=int8 or DstScalar=int8, gemmlowp fails to compile
 // outside of NEON. We avoid the compilation failure by subspecializing these
 // cases, rerouting it back to ruy.
 #ifndef GEMMLOWP_NEON
-template <typename SrcScalar>
-struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, std::int8_t>
-    : detail::GemmImplUsingRuy<SrcScalar, SrcScalar, std::int32_t,
-                               std::int8_t> {};
+template <typename SrcScalar, QuantizationFlavor quantization_flavor>
+struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, std::int8_t,
+                quantization_flavor>
+    : detail::GemmImplUsingRuy<SrcScalar, SrcScalar, std::int32_t, std::int8_t,
+                               quantization_flavor> {};
 
-template <typename DstScalar>
-struct GemmImpl<std::int8_t, std::int8_t, std::int32_t, DstScalar>
+template <typename DstScalar, QuantizationFlavor quantization_flavor>
+struct GemmImpl<std::int8_t, std::int8_t, std::int32_t, DstScalar,
+                quantization_flavor>
     : detail::GemmImplUsingRuy<std::int8_t, std::int8_t, std::int32_t,
-                               DstScalar> {};
+                               DstScalar, quantization_flavor> {};
 
-template <>
-struct GemmImpl<std::int8_t, std::int8_t, std::int32_t, std::int8_t>
+template <QuantizationFlavor quantization_flavor>
+struct GemmImpl<std::int8_t, std::int8_t, std::int32_t, std::int8_t,
+                quantization_flavor>
    : detail::GemmImplUsingRuy<std::int8_t, std::int8_t, std::int32_t,
-                               std::int8_t> {};
+                               std::int8_t, quantization_flavor> {};
 #endif  // not GEMMLOWP_NEON
 
 /* Specializations using Eigen */
 
 template <>
-struct GemmImpl<float, float, float, float>
-    : detail::GemmImplUsingEigen<float, float, float, float> {};
+struct GemmImpl<float, float, float, float, QuantizationFlavor::kFloatingPoint>
+    : detail::GemmImplUsingEigen<float> {};
 
 #endif  // not TFLITE_WITH_RUY
 
 /* Public entry point */
 
 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
-          typename DstScalar>
+          typename DstScalar, QuantizationFlavor quantization_flavor>
 void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
           const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
           const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
-          const GemmParams<AccumScalar, DstScalar>& params,
+          const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
           CpuBackendContext* context) {
   ValidateParams(lhs_params, rhs_params, dst_params, params);
-  GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar>::Run(
-      lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data, params,
-      context);
+  GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar,
+           quantization_flavor>::Run(lhs_params, lhs_data, rhs_params, rhs_data,
+                                     dst_params, dst_data, params, context);
 }
 
 }  // namespace cpu_backend_gemm
diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_eigen.h b/tensorflow/lite/kernels/cpu_backend_gemm_eigen.h
index 3cac7b084e2..678aef98597 100644
--- a/tensorflow/lite/kernels/cpu_backend_gemm_eigen.h
+++ b/tensorflow/lite/kernels/cpu_backend_gemm_eigen.h
@@ -24,13 +24,9 @@ namespace tflite {
 namespace cpu_backend_gemm {
 namespace detail {
 
-template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
-          typename DstScalar>
+template <typename DstScalar>
 struct GemmImplUsingEigen {
-  static_assert(std::is_same<LhsScalar, float>::value, "");
-  static_assert(std::is_same<RhsScalar, float>::value, "");
-  static_assert(std::is_same<AccumScalar, float>::value, "");
-  static_assert(std::is_same<DstScalar, float>::value, "");
+  static_assert(std::is_same<DstScalar, float>::value, "");
   static void Run(const MatrixParams<float>& lhs_params, const float* lhs_data,
                   const MatrixParams<float>& rhs_params, const float* rhs_data,
                   const MatrixParams<float>& dst_params, float* dst_data,
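[Reviewer note, not part of the patch: a minimal sketch of what a float call site looks like after this change. Shapes, data, and the wrapper function name are invented for illustration. The point is that because GemmParams now defaults its QuantizationFlavor from AccumScalar, float call sites mention no flavor at all and still resolve to the float (Eigen/ruy) path.]

```cpp
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm.h"

// Hypothetical example: multiply a 2x2 row-major LHS by a 2x2 col-major RHS.
void ExampleFloatGemm(tflite::CpuBackendContext* context) {
  using tflite::cpu_backend_gemm::Gemm;
  using tflite::cpu_backend_gemm::GemmParams;
  using tflite::cpu_backend_gemm::MatrixParams;
  using tflite::cpu_backend_gemm::Order;

  MatrixParams<float> lhs_params;
  lhs_params.order = Order::kRowMajor;  // required by ValidateParams
  lhs_params.rows = 2;
  lhs_params.cols = 2;
  MatrixParams<float> rhs_params;
  rhs_params.order = Order::kColMajor;
  rhs_params.rows = 2;
  rhs_params.cols = 2;
  MatrixParams<float> dst_params;
  dst_params.order = Order::kColMajor;
  dst_params.rows = 2;
  dst_params.cols = 2;

  const float lhs_data[] = {1, 2, 3, 4};
  const float rhs_data[] = {1, 0, 0, 1};
  const float bias_data[] = {0, 0};  // a bias vector is currently required
  float dst_data[4] = {};

  // The QuantizationFlavor template argument defaults to kFloatingPoint
  // because AccumScalar=float, so this type spells out nothing new.
  GemmParams<float, float> params;
  params.bias = bias_data;
  Gemm(lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data,
       params, context);
}
```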
diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h b/tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h
index c121e86be38..c39fc4b5f75 100644
--- a/tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h
+++ b/tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h
@@ -59,94 +59,115 @@ struct GemmlowpBitDepthParams {
   using Type = gemmlowp::SignedL8R8WithLhsNonzeroBitDepthParams;
 };
 
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
+          typename DstScalar, QuantizationFlavor quantization_flavor>
+struct GemmImplUsingGemmlowp {};
+
 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
           typename DstScalar>
-struct GemmImplUsingGemmlowp {
+struct GemmImplUsingGemmlowp<
+    LhsScalar, RhsScalar, AccumScalar, DstScalar,
+    QuantizationFlavor::kIntegerWithUniformMultiplier> {
   static_assert(std::is_same<LhsScalar, RhsScalar>::value, "");
   static_assert(std::is_same<AccumScalar, std::int32_t>::value, "");
   using SrcScalar = LhsScalar;
 
-  static void Run(const MatrixParams<SrcScalar>& lhs_params,
-                  const SrcScalar* lhs_data,
-                  const MatrixParams<SrcScalar>& rhs_params,
-                  const SrcScalar* rhs_data,
-                  const MatrixParams<DstScalar>& dst_params,
-                  DstScalar* dst_data,
-                  const GemmParams<std::int32_t, DstScalar>& params,
-                  CpuBackendContext* context) {
-    if (params.multiplier_exponent_perchannel) {
-      // gemmlowp support for this per-channel path is limited to NEON.
-      // We fall back to ruy outside of NEON.
+  static void Run(
+      const MatrixParams<SrcScalar>& lhs_params, const SrcScalar* lhs_data,
+      const MatrixParams<SrcScalar>& rhs_params, const SrcScalar* rhs_data,
+      const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
+      const GemmParams<std::int32_t, DstScalar,
+                       QuantizationFlavor::kIntegerWithUniformMultiplier>&
+          params,
+      CpuBackendContext* context) {
+    gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::RowMajor>
+        gemmlowp_lhs(lhs_data, lhs_params.rows, lhs_params.cols);
+    gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::ColMajor>
+        gemmlowp_rhs(rhs_data, rhs_params.rows, rhs_params.cols);
+    gemmlowp::MatrixMap<DstScalar, gemmlowp::MapOrder::ColMajor> gemmlowp_dst(
+        dst_data, dst_params.rows, dst_params.cols);
+
+    using ColVectorMap =
+        gemmlowp::VectorMap<const std::int32_t, gemmlowp::VectorShape::Col>;
+    ColVectorMap bias_vector(params.bias, lhs_params.rows);
+    gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
+    bias_addition_stage.bias_vector = bias_vector;
+    gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent scale_stage;
+    scale_stage.result_offset_after_shift = dst_params.zero_point;
+    scale_stage.result_fixedpoint_multiplier = params.multiplier_fixedpoint;
+    scale_stage.result_exponent = params.multiplier_exponent;
+    using SaturatingCastStageType =
+        typename GemmlowpSaturatingCastStage<DstScalar>::Type;
+    gemmlowp::OutputStageClamp clamp_stage;
+    clamp_stage.min = params.clamp_min;
+    clamp_stage.max = params.clamp_max;
+    SaturatingCastStageType saturating_cast_stage;
+    auto output_pipeline = std::make_tuple(bias_addition_stage, scale_stage,
+                                           clamp_stage, saturating_cast_stage);
+    using BitDepthParams = typename GemmlowpBitDepthParams<SrcScalar>::Type;
+    gemmlowp::GemmWithOutputPipeline<SrcScalar, DstScalar, BitDepthParams>(
+        context->gemmlowp_context(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst,
+        -lhs_params.zero_point, -rhs_params.zero_point, output_pipeline);
+  }
+};
+
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
+          typename DstScalar>
+struct GemmImplUsingGemmlowp<LhsScalar, RhsScalar, AccumScalar, DstScalar,
+                             QuantizationFlavor::kIntegerWithPerRowMultiplier> {
+  static_assert(std::is_same<LhsScalar, RhsScalar>::value, "");
+  static_assert(std::is_same<AccumScalar, std::int32_t>::value, "");
+  using SrcScalar = LhsScalar;
+
+  static void Run(
+      const MatrixParams<SrcScalar>& lhs_params, const SrcScalar* lhs_data,
+      const MatrixParams<SrcScalar>& rhs_params, const SrcScalar* rhs_data,
+      const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
+      const GemmParams<std::int32_t, DstScalar,
+                       QuantizationFlavor::kIntegerWithPerRowMultiplier>&
+          params,
+      CpuBackendContext* context) {
+    // gemmlowp support for this per-channel path is limited to NEON.
+    // We fall back to ruy outside of NEON.
#ifdef GEMMLOWP_NEON
-      gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::RowMajor>
-          gemmlowp_lhs(lhs_data, lhs_params.rows, lhs_params.cols);
-      gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::ColMajor>
-          gemmlowp_rhs(rhs_data, rhs_params.rows, rhs_params.cols);
-      gemmlowp::MatrixMap<DstScalar, gemmlowp::MapOrder::ColMajor> gemmlowp_dst(
-          dst_data, dst_params.rows, dst_params.cols);
+    gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::RowMajor>
+        gemmlowp_lhs(lhs_data, lhs_params.rows, lhs_params.cols);
+    gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::ColMajor>
+        gemmlowp_rhs(rhs_data, rhs_params.rows, rhs_params.cols);
+    gemmlowp::MatrixMap<DstScalar, gemmlowp::MapOrder::ColMajor> gemmlowp_dst(
+        dst_data, dst_params.rows, dst_params.cols);
 
-      using ColVectorMap =
-          gemmlowp::VectorMap<const std::int32_t, gemmlowp::VectorShape::Col>;
-      ColVectorMap bias_vector(params.bias, lhs_params.rows);
-      gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
-      bias_addition_stage.bias_vector = bias_vector;
-      gemmlowp::OutputStageScaleInt32ByFixedPointAndExponentPC<
-          gemmlowp::VectorShape::Col>
-          scale_stage;
-      scale_stage.result_offset_after_shift = dst_params.zero_point;
-      scale_stage.result_fixedpoint_multiplier = ColVectorMap(
-          params.multiplier_fixedpoint_perchannel, dst_params.rows);
-      scale_stage.result_exponent =
-          ColVectorMap(params.multiplier_exponent_perchannel, dst_params.rows);
-      using SaturatingCastStageType =
-          typename GemmlowpSaturatingCastStage<DstScalar>::Type;
-      gemmlowp::OutputStageClamp clamp_stage;
-      clamp_stage.min = params.clamp_min;
-      clamp_stage.max = params.clamp_max;
-      SaturatingCastStageType saturating_cast_stage;
-      auto output_pipeline = std::make_tuple(
-          bias_addition_stage, scale_stage, clamp_stage, saturating_cast_stage);
-      using BitDepthParams = typename GemmlowpBitDepthParams<SrcScalar>::Type;
-      gemmlowp::GemmWithOutputPipeline<SrcScalar, DstScalar, BitDepthParams>(
-          context->gemmlowp_context(), gemmlowp_lhs, gemmlowp_rhs,
-          &gemmlowp_dst, -lhs_params.zero_point, -rhs_params.zero_point,
-          output_pipeline);
+    using ColVectorMap =
+        gemmlowp::VectorMap<const std::int32_t, gemmlowp::VectorShape::Col>;
+    ColVectorMap bias_vector(params.bias, lhs_params.rows);
+    gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
+    bias_addition_stage.bias_vector = bias_vector;
+    gemmlowp::OutputStageScaleInt32ByFixedPointAndExponentPC<
+        gemmlowp::VectorShape::Col>
+        scale_stage;
+    scale_stage.result_offset_after_shift = dst_params.zero_point;
+    scale_stage.result_fixedpoint_multiplier =
+        ColVectorMap(params.multiplier_fixedpoint_perchannel, dst_params.rows);
+    scale_stage.result_exponent =
+        ColVectorMap(params.multiplier_exponent_perchannel, dst_params.rows);
+    using SaturatingCastStageType =
+        typename GemmlowpSaturatingCastStage<DstScalar>::Type;
+    gemmlowp::OutputStageClamp clamp_stage;
+    clamp_stage.min = params.clamp_min;
+    clamp_stage.max = params.clamp_max;
+    SaturatingCastStageType saturating_cast_stage;
+    auto output_pipeline = std::make_tuple(bias_addition_stage, scale_stage,
+                                           clamp_stage, saturating_cast_stage);
+    using BitDepthParams = typename GemmlowpBitDepthParams<SrcScalar>::Type;
+    gemmlowp::GemmWithOutputPipeline<SrcScalar, DstScalar, BitDepthParams>(
+        context->gemmlowp_context(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst,
+        -lhs_params.zero_point, -rhs_params.zero_point, output_pipeline);
 #else
-    GemmImplUsingRuy<SrcScalar, SrcScalar, std::int32_t, DstScalar>::Run(
-        lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data,
-        params, context);
+    GemmImplUsingRuy<SrcScalar, SrcScalar, std::int32_t, DstScalar,
+                     QuantizationFlavor::kIntegerWithPerRowMultiplier>::
+        Run(lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data,
+            params, context);
 #endif
-    } else {
-      gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::RowMajor>
-          gemmlowp_lhs(lhs_data, lhs_params.rows, lhs_params.cols);
-      gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::ColMajor>
-          gemmlowp_rhs(rhs_data, rhs_params.rows, rhs_params.cols);
-      gemmlowp::MatrixMap<DstScalar, gemmlowp::MapOrder::ColMajor> gemmlowp_dst(
-          dst_data, dst_params.rows, dst_params.cols);
-
-      using ColVectorMap =
-          gemmlowp::VectorMap<const std::int32_t, gemmlowp::VectorShape::Col>;
-      ColVectorMap bias_vector(params.bias, lhs_params.rows);
-
-      gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
-      bias_addition_stage.bias_vector = bias_vector;
-      gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent scale_stage;
-      scale_stage.result_offset_after_shift = dst_params.zero_point;
-      scale_stage.result_fixedpoint_multiplier = params.multiplier_fixedpoint;
-      scale_stage.result_exponent = params.multiplier_exponent;
-      using SaturatingCastStageType =
-          typename GemmlowpSaturatingCastStage<DstScalar>::Type;
-      gemmlowp::OutputStageClamp clamp_stage;
-      clamp_stage.min = params.clamp_min;
-      clamp_stage.max = params.clamp_max;
-      SaturatingCastStageType saturating_cast_stage;
-      auto output_pipeline = std::make_tuple(
-          bias_addition_stage, scale_stage, clamp_stage, saturating_cast_stage);
-      using BitDepthParams = typename GemmlowpBitDepthParams<SrcScalar>::Type;
-      gemmlowp::GemmWithOutputPipeline<SrcScalar, DstScalar, BitDepthParams>(
-          context->gemmlowp_context(), gemmlowp_lhs, gemmlowp_rhs,
-          &gemmlowp_dst, -lhs_params.zero_point, -rhs_params.zero_point,
-          output_pipeline);
-    }
   }
 };
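[Reviewer note, not part of the patch: the old runtime if/else is now two template specializations, and both build the same four-stage gemmlowp output pipeline; they differ only in where the multiplier comes from (one scalar, or one entry per destination row). As a reference for what that pipeline computes per destination entry, here is a scalar sketch. It is an approximation: the real gemmlowp stages use saturating, rounding fixed-point ops (SaturatingRoundingDoublingHighMul and rounding shifts), modeled here with a plain double multiply.]

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Approximate scalar model of the output pipeline for one destination entry.
// 'accum' is the raw int32 accumulator for that entry.
inline std::int8_t RequantizeOneEntry(std::int32_t accum, std::int32_t bias,
                                      std::int32_t multiplier_fixedpoint,
                                      int multiplier_exponent,
                                      std::int32_t dst_zero_point,
                                      std::int8_t clamp_min,
                                      std::int8_t clamp_max) {
  // OutputStageBiasAddition: add the per-row bias.
  const std::int32_t with_bias = accum + bias;
  // OutputStageScaleInt32ByFixedPointAndExponent: the effective scale is
  // multiplier_fixedpoint * 2^(multiplier_exponent - 31), then the
  // destination zero_point is added ('result_offset_after_shift').
  const double scale = std::ldexp(static_cast<double>(multiplier_fixedpoint),
                                  multiplier_exponent - 31);
  const std::int32_t scaled =
      static_cast<std::int32_t>(std::lround(with_bias * scale)) +
      dst_zero_point;
  // OutputStageClamp, then the saturating cast to the destination type.
  const std::int32_t clamped = std::min<std::int32_t>(
      std::max<std::int32_t>(scaled, clamp_min), clamp_max);
  return static_cast<std::int8_t>(clamped);
}
```

In the per-row flavor, multiplier_fixedpoint and multiplier_exponent are simply looked up per destination row from the _perchannel arrays instead of being single scalars.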
diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_params.h b/tensorflow/lite/kernels/cpu_backend_gemm_params.h
index 58a47a91df4..58c636e7196 100644
--- a/tensorflow/lite/kernels/cpu_backend_gemm_params.h
+++ b/tensorflow/lite/kernels/cpu_backend_gemm_params.h
@@ -49,12 +49,54 @@ struct MatrixParams {
   Scalar zero_point = 0;
 };
 
+// Enumeration of broad categories of Gemm.
+//
+// The primary reason for this to exist is to allow Gemm to compile
+// only uniform-quantized or only per-channel-quantized code paths.
+// This is unneeded with ruy as the back-end, where this is only a runtime
+// difference, but with gemmlowp these really are separate code
+// paths and templatizing on a QuantizationFlavor is necessary to avoid
+// compiling unused gemmlowp code. Indeed, TFLite currently uses
+// uint8 with uniform quantization and int8 with per-channel quantization,
+// and does not use uint8 with per-channel. We want to avoid compiling
+// the gemmlowp uint8 per-channel path when gemmlowp is the back-end.
+//
+// It's possible to drop this in the future if gemmlowp goes away and no
+// other then-relevant backend library handles quantized paths in a way that
+// requires knowing this at compile-time.
+enum class QuantizationFlavor {
+  // Floating-point Gemm: the accumulators are not multiplied by any
+  // 'multiplier'.
+  kFloatingPoint,
+  // Quantized Gemm using a single multiplier for all accumulators.
+  kIntegerWithUniformMultiplier,
+  // Quantized Gemm using a separate multiplier for the accumulators of each
+  // row of the destination matrix. This is what is called 'per-channel'
+  // in GemmParams. Here we use the more specific 'per-row' terminology
+  // to allow for the possibility of 'per-column' in the future, and to
+  // allow for that to be a separate code path in some back-end such as
+  // gemmlowp.
+  kIntegerWithPerRowMultiplier
+};
+
 // Additional parameters that Gemm needs, beyond what falls into
 // the MatrixParams that it takes. Compare to ruy::Spec.
 //
 // Decoupling AccumScalar from DstScalar (rather than deducing it from that)
 // is useful future-proofing. Think of a float16 path using float32 accum.
-template <typename AccumScalar, typename DstScalar>
+//
+// QuantizationFlavor is passed here even though it's technically not used
+// in this class. This is so that we retain the ability in the future to
+// specialize this class per quantization flavor, and it allows Gemm to be
+// templatized on quantization_flavor via the GemmParams that it takes, so
+// that automatic template parameter deduction applies and most call sites
+// don't need to specify a QuantizationFlavor (only those that need
+// per-channel quantization do).
+template <typename AccumScalar, typename DstScalar,
+          QuantizationFlavor quantization_flavor =
+              std::is_floating_point<AccumScalar>::value
+                  ? QuantizationFlavor::kFloatingPoint
+                  : QuantizationFlavor::kIntegerWithUniformMultiplier>
 struct GemmParams {
   // Only for non-floating-point cases. The fixed-point part (i.e. the mantissa)
   // of the multiplier by which accumulators are multiplied before being casted
@@ -104,41 +146,72 @@ using FloatGemmParams = GemmParams<float, float>;
 // a release-build assertion. See b/131587258.
 
 // Validates self-consistency of GemmParams.
-template <typename AccumScalar, typename DstScalar>
-void ValidateGemmParams(const GemmParams<AccumScalar, DstScalar>& params) {
+template <typename AccumScalar, typename DstScalar,
+          QuantizationFlavor quantization_flavor>
+void ValidateGemmParams(
+    const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params) {
   // For now require a bias vector. Again, ruy does not rely on that requirement
   // but the gemmlowp and Eigen path would require more code to handle it,
   // and currently TFLite only uses the case where there is a bias vector.
   TFLITE_DCHECK(params.bias);
   // Guard consistency of the quantized multiplier fields.
-  if (std::is_floating_point<AccumScalar>::value) {
-    // Floating point case: must not have any quantized multipliers
+  if (quantization_flavor == QuantizationFlavor::kFloatingPoint) {
     TFLITE_DCHECK(!params.multiplier_fixedpoint);
     TFLITE_DCHECK(!params.multiplier_exponent);
     TFLITE_DCHECK(!params.multiplier_fixedpoint_perchannel);
     TFLITE_DCHECK(!params.multiplier_exponent_perchannel);
-  } else {
-    // Quantized case. Must have either uniform or perchannel multiplier,
-    // not both.
-    TFLITE_DCHECK((params.multiplier_fixedpoint == 0) !=
-                  (params.multiplier_fixedpoint_perchannel == nullptr));
-    // Consistency of the two _perchannel fields.
-    TFLITE_DCHECK((params.multiplier_exponent_perchannel == nullptr) ==
-                  (params.multiplier_fixedpoint_perchannel == nullptr));
+  } else if (quantization_flavor ==
+             QuantizationFlavor::kIntegerWithUniformMultiplier) {
+    TFLITE_DCHECK(params.multiplier_fixedpoint);
+    // Nothing to check about multiplier_exponent
+    TFLITE_DCHECK(!params.multiplier_fixedpoint_perchannel);
+    TFLITE_DCHECK(!params.multiplier_exponent_perchannel);
+  } else if (quantization_flavor ==
+             QuantizationFlavor::kIntegerWithPerRowMultiplier) {
+    TFLITE_DCHECK(!params.multiplier_fixedpoint);
+    TFLITE_DCHECK(!params.multiplier_exponent);
+    TFLITE_DCHECK(params.multiplier_fixedpoint_perchannel);
+    TFLITE_DCHECK(params.multiplier_exponent_perchannel);
   }
 }
 
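[Reviewer note, not part of the patch: to make the flavor/field pairing that ValidateGemmParams enforces concrete, here is a sketch of the two quantized configurations a call site might build. Field values are invented and the helper names are hypothetical; in real code the params would then be passed to cpu_backend_gemm::Gemm.]

```cpp
#include <cstdint>

#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"

using tflite::cpu_backend_gemm::GemmParams;
using tflite::cpu_backend_gemm::QuantizationFlavor;

// Uniform quantization: with AccumScalar=int32, the flavor defaults to
// kIntegerWithUniformMultiplier, so the call site does not name it.
// Validation requires the uniform multiplier and forbids the _perchannel
// fields (which stay at their nullptr defaults here).
void MakeUniformParams(const std::int32_t* bias) {
  GemmParams<std::int32_t, std::int8_t> params;
  params.bias = bias;
  params.multiplier_fixedpoint = 1 << 30;  // ~0.5 in Q0.31 fixed point
  params.multiplier_exponent = 0;
}

// Per-channel (per-row) quantization is the one case where a call site must
// spell out the flavor; validation then requires the _perchannel arrays and
// forbids the uniform fields (left at their zero defaults here).
void MakePerChannelParams(const std::int32_t* bias,
                          const std::int32_t* fixedpoint_perchannel,
                          const int* exponent_perchannel) {
  GemmParams<std::int32_t, std::int8_t,
             QuantizationFlavor::kIntegerWithPerRowMultiplier>
      params;
  params.bias = bias;
  params.multiplier_fixedpoint_perchannel = fixedpoint_perchannel;
  params.multiplier_exponent_perchannel = exponent_perchannel;
}
```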
-// Validates overall consistency of all the parameters taken by a Gemm call:
-// the 3 MatrixParams and the GemmParams. Even if currently these are
-// checked only separately, it's good to have this validation done in one
-// function taking all of these parameters at once, as in the future there
-// may be mutual consistency requirements.
+namespace detail {
+
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
+          typename DstScalar, QuantizationFlavor quantization_flavor>
+struct ValidateTypes {
+  // This generic implementation is for quantized flavors.
+  // kFloatingPoint will be a specialization below.
+  static_assert(!std::is_floating_point<LhsScalar>::value, "");
+  static_assert(!std::is_floating_point<RhsScalar>::value, "");
+  static_assert(!std::is_floating_point<AccumScalar>::value, "");
+  // No requirement on DstScalar --- we might in the future allow it
+  // to be floating point even in a quantized Gemm.
+};
+
 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
           typename DstScalar>
-void ValidateParams(const MatrixParams<LhsScalar>& lhs_params,
-                    const MatrixParams<RhsScalar>& rhs_params,
-                    const MatrixParams<DstScalar>& dst_params,
-                    const GemmParams<AccumScalar, DstScalar>& params) {
+struct ValidateTypes<LhsScalar, RhsScalar, AccumScalar, DstScalar,
+                     QuantizationFlavor::kFloatingPoint> {
+  static_assert(std::is_floating_point<LhsScalar>::value, "");
+  static_assert(std::is_floating_point<RhsScalar>::value, "");
+  static_assert(std::is_floating_point<AccumScalar>::value, "");
+  static_assert(std::is_floating_point<DstScalar>::value, "");
+};
+
+}  // namespace detail
+
+// Validates overall consistency of all the parameters taken by a Gemm call:
+// the 3 MatrixParams and the GemmParams.
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
+          typename DstScalar, QuantizationFlavor quantization_flavor>
+void ValidateParams(
+    const MatrixParams<LhsScalar>& lhs_params,
+    const MatrixParams<RhsScalar>& rhs_params,
+    const MatrixParams<DstScalar>& dst_params,
+    const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params) {
+  (void)detail::ValidateTypes<LhsScalar, RhsScalar, AccumScalar, DstScalar,
+                              quantization_flavor>();
   ValidateGemmParams(params);
   // For now, Gemm only supports this particular combination of storage orders.
   // Actually the generic ruy path already supports all combinations (with
@@ -151,6 +224,20 @@ void ValidateParams(const MatrixParams<LhsScalar>& lhs_params,
   TFLITE_DCHECK(lhs_params.order == Order::kRowMajor);
   TFLITE_DCHECK(rhs_params.order == Order::kColMajor);
   TFLITE_DCHECK(dst_params.order == Order::kColMajor);
+  // Guard against the case when both LHS and RHS zero_point's are equal to
+  // the minimum representable value. In that case, padding with zero_point
+  // values will generate the bad case for fast int8 kernels on NEON
+  // (pre-dotprod) which attempt to multiply-accumulate two pairs of int8
+  // into an int16: this is safe except in the bad case -128*-128 + -128*-128.
+  // Accordingly, this is banned by gemmlowp and ruy. However, they were not
+  // guarding against it, which allowed a real bug to happen, b/131609283.
+  // Checking this here lets callers of this cpu_backend_gemm library be
+  // safe regardless of backends.
+  if (quantization_flavor != QuantizationFlavor::kFloatingPoint) {
+    TFLITE_DCHECK(
+        lhs_params.zero_point != std::numeric_limits<LhsScalar>::lowest() ||
+        rhs_params.zero_point != std::numeric_limits<RhsScalar>::lowest());
+  }
 }
 
 }  // namespace cpu_backend_gemm
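[Reviewer note, not part of the patch: the arithmetic behind the new zero_point guard, worked through. The fast pre-dotprod NEON int8 kernels multiply-accumulate two int8*int8 products into a 16-bit lane before widening. Each product fits, since (-128)*(-128) = 16384 <= 32767, but the worst-case sum 16384 + 16384 = 32768 exceeds the int16 maximum. A tiny standalone demonstration:]

```cpp
#include <cstdint>
#include <iostream>

int main() {
  const std::int16_t product = (-128) * (-128);      // 16384: fits in int16.
  const std::int32_t exact_sum = product + product;  // 32768, via int promotion.
  // 32768 is not representable in int16; a 16-bit accumulator wraps to
  // -32768 on common implementations (and by definition since C++20).
  // This is the corruption behind b/131609283 that the DCHECK rules out.
  const std::int16_t wrapped = static_cast<std::int16_t>(exact_sum);
  std::cout << "exact=" << exact_sum << " wrapped=" << wrapped << "\n";
  return 0;
}
```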
diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h b/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h
index 43fb1e49533..abee370b982 100644
--- a/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h
+++ b/tensorflow/lite/kernels/cpu_backend_gemm_ruy.h
@@ -61,16 +61,14 @@ void MakeRuySpec(const GemmParamsType& params, RuySpecType* ruy_spec) {
 }
 
 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
-          typename DstScalar>
+          typename DstScalar, QuantizationFlavor quantization_flavor>
 struct GemmImplUsingRuy {
-  static void Run(const MatrixParams<LhsScalar>& lhs_params,
-                  const LhsScalar* lhs_data,
-                  const MatrixParams<RhsScalar>& rhs_params,
-                  const RhsScalar* rhs_data,
-                  const MatrixParams<DstScalar>& dst_params,
-                  DstScalar* dst_data,
-                  const GemmParams<AccumScalar, DstScalar>& params,
-                  CpuBackendContext* context) {
+  static void Run(
+      const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
+      const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
+      const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
+      const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
+      CpuBackendContext* context) {
     ruy::Matrix<const LhsScalar> ruy_lhs;
     ruy::Matrix<const RhsScalar> ruy_rhs;
     ruy::Matrix<DstScalar> ruy_dst;
diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_test.cc b/tensorflow/lite/kernels/cpu_backend_gemm_test.cc
index c763fa6483f..d65a8394144 100644
--- a/tensorflow/lite/kernels/cpu_backend_gemm_test.cc
+++ b/tensorflow/lite/kernels/cpu_backend_gemm_test.cc
@@ -26,6 +26,7 @@ limitations under the License.
 #include <gtest/gtest.h>
 
 #include "tensorflow/lite/experimental/ruy/ruy.h"
+#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
 
 namespace tflite {
 
@@ -34,6 +35,7 @@ namespace {
 using cpu_backend_gemm::Gemm;
 using cpu_backend_gemm::GemmParams;
 using cpu_backend_gemm::MatrixParams;
+using cpu_backend_gemm::QuantizationFlavor;
 
 template <typename Scalar>
 std::string ToString(const std::vector<Scalar>& vector) {
@@ -125,9 +127,11 @@ void Clamp(const std::vector<Scalar>& src, Scalar clamp_min, Scalar clamp_max,
   }
 }
 
-template <typename AccumScalar, typename DstScalar>
-void Clamp(const GemmParams<AccumScalar, DstScalar>& src, DstScalar clamp_min,
-           DstScalar clamp_max, GemmParams<AccumScalar, DstScalar>* dst) {
+template <typename AccumScalar, typename DstScalar,
+          QuantizationFlavor quantization_flavor>
+void Clamp(const GemmParams<AccumScalar, DstScalar, quantization_flavor>& src,
+           DstScalar clamp_min, DstScalar clamp_max,
+           GemmParams<AccumScalar, DstScalar, quantization_flavor>* dst) {
   *dst = src;
   dst->clamp_min = clamp_min;
   dst->clamp_max = clamp_max;
@@ -236,14 +240,14 @@ void CheckErrorForAccumulation(int accumulation_depth,
 }
 
 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
-          typename DstScalar>
+          typename DstScalar, QuantizationFlavor quantization_flavor>
 void PerformGemmThenCompareResultsThenAgainWithClamping(
     const MatrixParams<LhsScalar>& lhs_params,
     const std::vector<LhsScalar>& lhs_data,
     const MatrixParams<RhsScalar>& rhs_params,
     const std::vector<RhsScalar>& rhs_data,
     const MatrixParams<DstScalar>& dst_params, std::vector<DstScalar>* dst_data,
-    const GemmParams<AccumScalar, DstScalar>& params,
+    const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
     const std::vector<DstScalar>& expected,
     CpuBackendContext* cpu_backend_context) {
   const int accumulation_depth = lhs_params.cols;
@@ -253,7 +257,7 @@ void PerformGemmThenCompareResultsThenAgainWithClamping(
                    expected);
   DstScalar expected_median = Median(expected);
   std::vector<DstScalar> expected_with_clamp;
-  GemmParams<AccumScalar, DstScalar> params_with_clamp;
+  GemmParams<AccumScalar, DstScalar, quantization_flavor> params_with_clamp;
   DstScalar clamp_min, clamp_max;
 
   clamp_min = std::numeric_limits<DstScalar>::lowest();
@@ -453,9 +457,14 @@ void TestSomeGemm(int rows, int depth, int cols,
       rows, params.multiplier_fixedpoint);
   std::vector<int> multiplier_exponent_perchannel(rows,
                                                   params.multiplier_exponent);
-  GemmParams<AccumScalar, DstScalar> params_perchannel = params;
-  params_perchannel.multiplier_fixedpoint = 0;
-  params_perchannel.multiplier_exponent = 0;
+  static constexpr QuantizationFlavor perchannel_flavor =
+      std::is_floating_point<AccumScalar>::value
+          ? QuantizationFlavor::kFloatingPoint
+          : QuantizationFlavor::kIntegerWithPerRowMultiplier;
+  GemmParams<AccumScalar, DstScalar, perchannel_flavor> params_perchannel;
+  params_perchannel.bias = params.bias;
+  params_perchannel.clamp_min = params.clamp_min;
+  params_perchannel.clamp_max = params.clamp_max;
   params_perchannel.multiplier_fixedpoint_perchannel =
       multiplier_fixedpoint_perchannel.data();
   params_perchannel.multiplier_exponent_perchannel =