Templatize Gemm in a 'quantization flavor' distinguishing per-channel vs.
uniform quantization.

This isn't our preferred design, and is unnecessary as far as ruy is
concerned: in ruy, this is just a runtime switch. But we also want to avoid
code-size regressions in the currently default case where gemmlowp, not ruy,
is used, and in gemmlowp this is a compile-time switch. Moreover, the actual
instantiations that TFLite needs are such that there is a real overall
binary-size regression unless this is templatized in this way. That is
because TFLite uses only 2 combinations of { input type, quantization flavor }:
  1. { uint8, uniform quantization }
  2. { int8, per-channel quantization }
TFLite does not use other combinations such as { uint8, per-channel }, so if
we force gemmlowp to instantiate that case, we regress binary size overall.

PiperOrigin-RevId: 246070679
parent a5acee6458
commit 7114b99543
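
For context, a minimal sketch of the two parameter types those combinations
correspond to (illustrative call-site code, not part of this commit). The
defaulted template argument on GemmParams, introduced below, means only the
per-channel case has to spell out a flavor:

#include <cstdint>

#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"

using tflite::cpu_backend_gemm::GemmParams;
using tflite::cpu_backend_gemm::QuantizationFlavor;

void IllustrateTfliteFlavors() {
  // 1. {uint8, uniform quantization}: the flavor defaults to
  //    kIntegerWithUniformMultiplier because AccumScalar is integral.
  GemmParams<std::int32_t, std::uint8_t> uniform_params;

  // 2. {int8, per-channel quantization}: the flavor is spelled explicitly.
  GemmParams<std::int32_t, std::int8_t,
             QuantizationFlavor::kIntegerWithPerRowMultiplier>
      perchannel_params;

  (void)uniform_params;
  (void)perchannel_params;
}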
@@ -36,60 +36,65 @@ namespace cpu_backend_gemm {
  */
 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
-          typename DstScalar>
-struct GemmImpl
-    : detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar> {};
+          typename DstScalar, QuantizationFlavor quantization_flavor>
+struct GemmImpl : detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar,
+                                           DstScalar, quantization_flavor> {};
 
 #ifndef TFLITE_WITH_RUY
 
 /* Specializations using gemmlowp */
 
-template <typename SrcScalar, typename DstScalar>
-struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, DstScalar>
+template <typename SrcScalar, typename DstScalar,
+          QuantizationFlavor quantization_flavor>
+struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, DstScalar,
+                quantization_flavor>
     : detail::GemmImplUsingGemmlowp<SrcScalar, SrcScalar, std::int32_t,
-                                    DstScalar> {};
+                                    DstScalar, quantization_flavor> {};
 
 // When SrcScalar=int8 or DstScalar=int8, gemmlowp fails to compile
 // outside of NEON. We avoid the compilation failure by subspecializing these
 // cases, rerouting it back to ruy.
 #ifndef GEMMLOWP_NEON
-template <typename SrcScalar>
-struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, std::int8_t>
-    : detail::GemmImplUsingRuy<SrcScalar, SrcScalar, std::int32_t,
-                               std::int8_t> {};
+template <typename SrcScalar, QuantizationFlavor quantization_flavor>
+struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, std::int8_t,
+                quantization_flavor>
+    : detail::GemmImplUsingRuy<SrcScalar, SrcScalar, std::int32_t, std::int8_t,
+                               quantization_flavor> {};
 
-template <typename DstScalar>
-struct GemmImpl<std::int8_t, std::int8_t, std::int32_t, DstScalar>
+template <typename DstScalar, QuantizationFlavor quantization_flavor>
+struct GemmImpl<std::int8_t, std::int8_t, std::int32_t, DstScalar,
+                quantization_flavor>
     : detail::GemmImplUsingRuy<std::int8_t, std::int8_t, std::int32_t,
-                               DstScalar> {};
+                               DstScalar, quantization_flavor> {};
 
-template <>
-struct GemmImpl<std::int8_t, std::int8_t, std::int32_t, std::int8_t>
+template <QuantizationFlavor quantization_flavor>
+struct GemmImpl<std::int8_t, std::int8_t, std::int32_t, std::int8_t,
+                quantization_flavor>
     : detail::GemmImplUsingRuy<std::int8_t, std::int8_t, std::int32_t,
-                               std::int8_t> {};
+                               std::int8_t, quantization_flavor> {};
 #endif  // not GEMMLOWP_NEON
 
 /* Specializations using Eigen */
 
 template <>
-struct GemmImpl<float, float, float, float>
-    : detail::GemmImplUsingEigen<float, float, float, float> {};
+struct GemmImpl<float, float, float, float, QuantizationFlavor::kFloatingPoint>
+    : detail::GemmImplUsingEigen<float> {};
 
 #endif  // not TFLITE_WITH_RUY
 
 /* Public entry point */
 
 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
-          typename DstScalar>
+          typename DstScalar, QuantizationFlavor quantization_flavor>
 void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
           const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
           const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
-          const GemmParams<AccumScalar, DstScalar>& params,
+          const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
          CpuBackendContext* context) {
   ValidateParams(lhs_params, rhs_params, dst_params, params);
-  GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar>::Run(
-      lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data, params,
-      context);
+  GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar,
+           quantization_flavor>::Run(lhs_params, lhs_data, rhs_params, rhs_data,
+                                     dst_params, dst_data, params, context);
 }
 
 }  // namespace cpu_backend_gemm
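
The dispatch idea above, reduced to a self-contained toy (generic names, not
the TFLite types): because GemmImpl is now templatized on the flavor, a
backend specialization is only compiled if some call site actually names that
{type, flavor} combination.

#include <cstdint>
#include <iostream>

enum class Flavor { kUniform, kPerChannel };

// Primary template: a generic backend (ruy's role) handles every flavor.
template <typename Scalar, Flavor flavor>
struct Impl {
  static void Run() { std::cout << "generic backend\n"; }
};

// Specialization: a fast path (gemmlowp's role) for uint8 + uniform only.
// The uint8 + per-channel combination is never named by any call site, so
// no code for it is ever instantiated, which is the binary-size win.
template <>
struct Impl<std::uint8_t, Flavor::kUniform> {
  static void Run() { std::cout << "uint8 uniform fast path\n"; }
};

int main() {
  Impl<std::uint8_t, Flavor::kUniform>::Run();    // fast path
  Impl<std::int8_t, Flavor::kPerChannel>::Run();  // generic backend
}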
@@ -24,13 +24,9 @@ namespace tflite {
 namespace cpu_backend_gemm {
 namespace detail {
 
-template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
-          typename DstScalar>
+template <typename Scalar>
 struct GemmImplUsingEigen {
-  static_assert(std::is_same<LhsScalar, float>::value, "");
-  static_assert(std::is_same<RhsScalar, float>::value, "");
-  static_assert(std::is_same<AccumScalar, float>::value, "");
-  static_assert(std::is_same<DstScalar, float>::value, "");
+  static_assert(std::is_same<Scalar, float>::value, "");
   static void Run(const MatrixParams<float>& lhs_params, const float* lhs_data,
                   const MatrixParams<float>& rhs_params, const float* rhs_data,
                   const MatrixParams<float>& dst_params, float* dst_data,
@@ -59,94 +59,115 @@ struct GemmlowpBitDepthParams<std::int8_t> {
   using Type = gemmlowp::SignedL8R8WithLhsNonzeroBitDepthParams;
 };
 
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
+          typename DstScalar, QuantizationFlavor quantization_flavor>
+struct GemmImplUsingGemmlowp {};
+
 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
           typename DstScalar>
-struct GemmImplUsingGemmlowp {
+struct GemmImplUsingGemmlowp<
+    LhsScalar, RhsScalar, AccumScalar, DstScalar,
+    QuantizationFlavor::kIntegerWithUniformMultiplier> {
   static_assert(std::is_same<LhsScalar, RhsScalar>::value, "");
   static_assert(std::is_same<AccumScalar, std::int32_t>::value, "");
   using SrcScalar = LhsScalar;
 
-  static void Run(const MatrixParams<SrcScalar>& lhs_params,
-                  const SrcScalar* lhs_data,
-                  const MatrixParams<SrcScalar>& rhs_params,
-                  const SrcScalar* rhs_data,
-                  const MatrixParams<DstScalar>& dst_params,
-                  DstScalar* dst_data,
-                  const GemmParams<std::int32_t, DstScalar>& params,
-                  CpuBackendContext* context) {
-    if (params.multiplier_exponent_perchannel) {
-      // gemmlowp support for this per-channel path is limited to NEON.
-      // We fall back to ruy outside of NEON.
+  static void Run(
+      const MatrixParams<SrcScalar>& lhs_params, const SrcScalar* lhs_data,
+      const MatrixParams<SrcScalar>& rhs_params, const SrcScalar* rhs_data,
+      const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
+      const GemmParams<std::int32_t, DstScalar,
+                       QuantizationFlavor::kIntegerWithUniformMultiplier>&
+          params,
+      CpuBackendContext* context) {
+    gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::RowMajor>
+        gemmlowp_lhs(lhs_data, lhs_params.rows, lhs_params.cols);
+    gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::ColMajor>
+        gemmlowp_rhs(rhs_data, rhs_params.rows, rhs_params.cols);
+    gemmlowp::MatrixMap<DstScalar, gemmlowp::MapOrder::ColMajor> gemmlowp_dst(
+        dst_data, dst_params.rows, dst_params.cols);
+
+    using ColVectorMap =
+        gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>;
+    ColVectorMap bias_vector(params.bias, lhs_params.rows);
+    gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
+    bias_addition_stage.bias_vector = bias_vector;
+    gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent scale_stage;
+    scale_stage.result_offset_after_shift = dst_params.zero_point;
+    scale_stage.result_fixedpoint_multiplier = params.multiplier_fixedpoint;
+    scale_stage.result_exponent = params.multiplier_exponent;
+    using SaturatingCastStageType =
+        typename GemmlowpSaturatingCastStage<DstScalar>::Type;
+    gemmlowp::OutputStageClamp clamp_stage;
+    clamp_stage.min = params.clamp_min;
+    clamp_stage.max = params.clamp_max;
+    SaturatingCastStageType saturating_cast_stage;
+    auto output_pipeline = std::make_tuple(bias_addition_stage, scale_stage,
+                                           clamp_stage, saturating_cast_stage);
+    using BitDepthParams = typename GemmlowpBitDepthParams<SrcScalar>::Type;
+    gemmlowp::GemmWithOutputPipeline<SrcScalar, DstScalar, BitDepthParams>(
+        context->gemmlowp_context(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst,
+        -lhs_params.zero_point, -rhs_params.zero_point, output_pipeline);
+  }
+};
+
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
+          typename DstScalar>
+struct GemmImplUsingGemmlowp<LhsScalar, RhsScalar, AccumScalar, DstScalar,
+                             QuantizationFlavor::kIntegerWithPerRowMultiplier> {
+  static_assert(std::is_same<LhsScalar, RhsScalar>::value, "");
+  static_assert(std::is_same<AccumScalar, std::int32_t>::value, "");
+  using SrcScalar = LhsScalar;
+
+  static void Run(
+      const MatrixParams<SrcScalar>& lhs_params, const SrcScalar* lhs_data,
+      const MatrixParams<SrcScalar>& rhs_params, const SrcScalar* rhs_data,
+      const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
+      const GemmParams<std::int32_t, DstScalar,
+                       QuantizationFlavor::kIntegerWithPerRowMultiplier>&
+          params,
+      CpuBackendContext* context) {
+    // gemmlowp support for this per-channel path is limited to NEON.
+    // We fall back to ruy outside of NEON.
 #ifdef GEMMLOWP_NEON
-      gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::RowMajor>
-          gemmlowp_lhs(lhs_data, lhs_params.rows, lhs_params.cols);
-      gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::ColMajor>
-          gemmlowp_rhs(rhs_data, rhs_params.rows, rhs_params.cols);
-      gemmlowp::MatrixMap<DstScalar, gemmlowp::MapOrder::ColMajor> gemmlowp_dst(
-          dst_data, dst_params.rows, dst_params.cols);
+    gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::RowMajor>
+        gemmlowp_lhs(lhs_data, lhs_params.rows, lhs_params.cols);
+    gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::ColMajor>
+        gemmlowp_rhs(rhs_data, rhs_params.rows, rhs_params.cols);
+    gemmlowp::MatrixMap<DstScalar, gemmlowp::MapOrder::ColMajor> gemmlowp_dst(
+        dst_data, dst_params.rows, dst_params.cols);
 
-      using ColVectorMap =
-          gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>;
-      ColVectorMap bias_vector(params.bias, lhs_params.rows);
-      gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
-      bias_addition_stage.bias_vector = bias_vector;
-      gemmlowp::OutputStageScaleInt32ByFixedPointAndExponentPC<
-          gemmlowp::VectorShape::Col>
-          scale_stage;
-      scale_stage.result_offset_after_shift = dst_params.zero_point;
-      scale_stage.result_fixedpoint_multiplier = ColVectorMap(
-          params.multiplier_fixedpoint_perchannel, dst_params.rows);
-      scale_stage.result_exponent =
-          ColVectorMap(params.multiplier_exponent_perchannel, dst_params.rows);
-      using SaturatingCastStageType =
-          typename GemmlowpSaturatingCastStage<DstScalar>::Type;
-      gemmlowp::OutputStageClamp clamp_stage;
-      clamp_stage.min = params.clamp_min;
-      clamp_stage.max = params.clamp_max;
-      SaturatingCastStageType saturating_cast_stage;
-      auto output_pipeline = std::make_tuple(
-          bias_addition_stage, scale_stage, clamp_stage, saturating_cast_stage);
-      using BitDepthParams = typename GemmlowpBitDepthParams<SrcScalar>::Type;
-      gemmlowp::GemmWithOutputPipeline<SrcScalar, DstScalar, BitDepthParams>(
-          context->gemmlowp_context(), gemmlowp_lhs, gemmlowp_rhs,
-          &gemmlowp_dst, -lhs_params.zero_point, -rhs_params.zero_point,
-          output_pipeline);
+    using ColVectorMap =
+        gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>;
+    ColVectorMap bias_vector(params.bias, lhs_params.rows);
+    gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
+    bias_addition_stage.bias_vector = bias_vector;
+    gemmlowp::OutputStageScaleInt32ByFixedPointAndExponentPC<
+        gemmlowp::VectorShape::Col>
+        scale_stage;
+    scale_stage.result_offset_after_shift = dst_params.zero_point;
+    scale_stage.result_fixedpoint_multiplier =
+        ColVectorMap(params.multiplier_fixedpoint_perchannel, dst_params.rows);
+    scale_stage.result_exponent =
+        ColVectorMap(params.multiplier_exponent_perchannel, dst_params.rows);
+    using SaturatingCastStageType =
+        typename GemmlowpSaturatingCastStage<DstScalar>::Type;
+    gemmlowp::OutputStageClamp clamp_stage;
+    clamp_stage.min = params.clamp_min;
+    clamp_stage.max = params.clamp_max;
+    SaturatingCastStageType saturating_cast_stage;
+    auto output_pipeline = std::make_tuple(bias_addition_stage, scale_stage,
+                                           clamp_stage, saturating_cast_stage);
+    using BitDepthParams = typename GemmlowpBitDepthParams<SrcScalar>::Type;
+    gemmlowp::GemmWithOutputPipeline<SrcScalar, DstScalar, BitDepthParams>(
+        context->gemmlowp_context(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst,
+        -lhs_params.zero_point, -rhs_params.zero_point, output_pipeline);
 #else
-      GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar>::Run(
-          lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data,
-          params, context);
+    GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar,
+                     QuantizationFlavor::kIntegerWithPerRowMultiplier>::
+        Run(lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data,
+            params, context);
 #endif
-    } else {
-      gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::RowMajor>
-          gemmlowp_lhs(lhs_data, lhs_params.rows, lhs_params.cols);
-      gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::ColMajor>
-          gemmlowp_rhs(rhs_data, rhs_params.rows, rhs_params.cols);
-      gemmlowp::MatrixMap<DstScalar, gemmlowp::MapOrder::ColMajor> gemmlowp_dst(
-          dst_data, dst_params.rows, dst_params.cols);
-
-      using ColVectorMap =
-          gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>;
-      ColVectorMap bias_vector(params.bias, lhs_params.rows);
-      gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
-      bias_addition_stage.bias_vector = bias_vector;
-      gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent scale_stage;
-      scale_stage.result_offset_after_shift = dst_params.zero_point;
-      scale_stage.result_fixedpoint_multiplier = params.multiplier_fixedpoint;
-      scale_stage.result_exponent = params.multiplier_exponent;
-      using SaturatingCastStageType =
-          typename GemmlowpSaturatingCastStage<DstScalar>::Type;
-      gemmlowp::OutputStageClamp clamp_stage;
-      clamp_stage.min = params.clamp_min;
-      clamp_stage.max = params.clamp_max;
-      SaturatingCastStageType saturating_cast_stage;
-      auto output_pipeline = std::make_tuple(
-          bias_addition_stage, scale_stage, clamp_stage, saturating_cast_stage);
-      using BitDepthParams = typename GemmlowpBitDepthParams<SrcScalar>::Type;
-      gemmlowp::GemmWithOutputPipeline<SrcScalar, DstScalar, BitDepthParams>(
-          context->gemmlowp_context(), gemmlowp_lhs, gemmlowp_rhs,
-          &gemmlowp_dst, -lhs_params.zero_point, -rhs_params.zero_point,
-          output_pipeline);
-    }
   }
 };
 
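
For readers unfamiliar with gemmlowp output pipelines: both specializations
above apply, in order, bias addition, fixed-point rescaling, clamping, and a
saturating cast. A schematic restatement for a single accumulator — plain
64-bit arithmetic stands in for gemmlowp's saturating rounding doubling
high-mul, so the rounding behavior differs from the real stages:

#include <algorithm>
#include <cstdint>

// Schematic per-accumulator view of the output pipeline built above. For
// the per-row flavor, multiplier_fixedpoint / multiplier_exponent would be
// looked up per destination row instead of being scalars.
std::int32_t ApplyOutputStages(std::int32_t acc, std::int32_t bias,
                               std::int32_t multiplier_fixedpoint,
                               int multiplier_exponent,
                               std::int32_t dst_zero_point,
                               std::int32_t clamp_min, std::int32_t clamp_max) {
  std::int64_t v = acc;
  v += bias;  // OutputStageBiasAddition
  // OutputStageScaleInt32ByFixedPointAndExponent (rounding simplified):
  v = (v * multiplier_fixedpoint) >> 31;
  if (multiplier_exponent >= 0) {
    v *= std::int64_t{1} << multiplier_exponent;
  } else {
    v >>= -multiplier_exponent;
  }
  v += dst_zero_point;  // result_offset_after_shift
  // OutputStageClamp; the real pipeline then saturating-casts to DstScalar.
  v = std::min<std::int64_t>(std::max<std::int64_t>(v, clamp_min), clamp_max);
  return static_cast<std::int32_t>(v);
}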
@@ -49,12 +49,54 @@ struct MatrixParams {
   Scalar zero_point = 0;
 };
 
+// Enumeration of broad categories of Gemm.
+//
+// The primary reason for this to exist is to allow Gemm to compile
+// only uniform-quantized or only per-channel-quantized code paths.
+// This is unneeded with ruy as the back-end, as this is only a runtime
+// difference in ruy, but with gemmlowp these really are separate code
+// paths and templatizing in a QuantizationFlavor is necessary to avoid
+// compiling unused gemmlowp code. Indeed, TFLite currently uses
+// uint8 with uniform quantization and int8 with per-channel quantization,
+// and does not use uint8 with per-channel. We want to avoid compiling
+// the gemmlowp uint8 per-channel path when gemmlowp is the back-end.
+//
+// It's possible to drop this in the future if gemmlowp goes away and no
+// other then-relevant backend library handles quantized paths in a way that
+// requires knowing this at compile-time.
+enum class QuantizationFlavor {
+  // Floating-point Gemm: the accumulators are not multiplied by any
+  // 'multiplier'.
+  kFloatingPoint,
+  // Quantized Gemm using a single multiplier for all accumulators.
+  kIntegerWithUniformMultiplier,
+  // Quantized Gemm using separate multipliers for accumulators of each
+  // row of the destination matrix. This is what is called 'per-channel'
+  // in GemmParams. Here we use the more specific 'per-row' terminology
+  // to allow for the possibility of 'per-column' in the future, and to
+  // allow for that to be a separate code path in some back-end such as
+  // gemmlowp.
+  kIntegerWithPerRowMultiplier
+};
+
 // Additional parameters that Gemm needs, beyond what falls into
 // the MatrixParams that it takes. Compare to ruy::Spec.
 //
 // Decoupling AccumScalar from DstScalar (rather than deducing it from that)
 // is useful future-proofing. Think of a float16 path using float32 accum.
-template <typename AccumScalar, typename DstScalar>
+//
+// QuantizationFlavor is passed here even though it's technically not used
+// in this class. This is so that we retain the ability in the future to
+// specialize this class for quantization flavor, and this allows for
+// Gemm to be templatized in quantization_flavor via the GemmParams that it
+// takes, allowing for automatic template parameter deduction to take place,
+// so that most call sites don't need to specify a QuantizationFlavor
+// (only those that need per-channel quantization do).
+template <typename AccumScalar, typename DstScalar,
+          QuantizationFlavor quantization_flavor =
+              std::is_floating_point<AccumScalar>::value
+                  ? QuantizationFlavor::kFloatingPoint
+                  : QuantizationFlavor::kIntegerWithUniformMultiplier>
 struct GemmParams {
   // Only for non-floating-point cases. The fixed-point part (i.e. the mantissa)
   // of the multiplier by which accumulators are multiplied before being casted
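
The effect of the defaulted template argument, as a sketch (the
static_asserts are mine, but follow directly from the default above): float
accumulation defaults to the floating-point flavor and integer accumulation
to the uniform flavor, so only per-channel call sites need to name a flavor.

#include <cstdint>
#include <type_traits>

#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"

using tflite::cpu_backend_gemm::GemmParams;
using tflite::cpu_backend_gemm::QuantizationFlavor;

static_assert(
    std::is_same<GemmParams<float, float>,
                 GemmParams<float, float,
                            QuantizationFlavor::kFloatingPoint>>::value,
    "float accumulators default to the floating-point flavor");
static_assert(
    std::is_same<
        GemmParams<std::int32_t, std::uint8_t>,
        GemmParams<std::int32_t, std::uint8_t,
                   QuantizationFlavor::kIntegerWithUniformMultiplier>>::value,
    "integer accumulators default to the uniform-multiplier flavor");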
@@ -104,41 +146,72 @@ using FloatGemmParams = GemmParams<float, float>;
 // a release-build assertion. See b/131587258.
 
 // Validates self-consistency of GemmParams.
-template <typename AccumScalar, typename DstScalar>
-void ValidateGemmParams(const GemmParams<AccumScalar, DstScalar>& params) {
+template <typename AccumScalar, typename DstScalar,
+          QuantizationFlavor quantization_flavor>
+void ValidateGemmParams(
+    const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params) {
   // For now require a bias vector. Again, ruy does not rely on that requirement
   // but the gemmlowp and Eigen path would require more code to handle it,
   // and currently TFLite only uses the case where there is a bias vector.
   TFLITE_DCHECK(params.bias);
   // Guard consistency of the quantized multiplier fields.
-  if (std::is_floating_point<AccumScalar>::value) {
-    // Floating point case: must not have any quantized multipliers
+  if (quantization_flavor == QuantizationFlavor::kFloatingPoint) {
     TFLITE_DCHECK(!params.multiplier_fixedpoint);
     TFLITE_DCHECK(!params.multiplier_exponent);
     TFLITE_DCHECK(!params.multiplier_fixedpoint_perchannel);
     TFLITE_DCHECK(!params.multiplier_exponent_perchannel);
-  } else {
-    // Quantized case. Must have either uniform or perchannel multiplier,
-    // not both.
-    TFLITE_DCHECK((params.multiplier_fixedpoint == 0) !=
-                  (params.multiplier_fixedpoint_perchannel == nullptr));
-    // Consistency of the two _perchannel fields.
-    TFLITE_DCHECK((params.multiplier_exponent_perchannel == nullptr) ==
-                  (params.multiplier_fixedpoint_perchannel == nullptr));
+  } else if (quantization_flavor ==
+             QuantizationFlavor::kIntegerWithUniformMultiplier) {
+    TFLITE_DCHECK(params.multiplier_fixedpoint);
+    // Nothing to check about multiplier_exponent
+    TFLITE_DCHECK(!params.multiplier_fixedpoint_perchannel);
+    TFLITE_DCHECK(!params.multiplier_exponent_perchannel);
+  } else if (quantization_flavor ==
+             QuantizationFlavor::kIntegerWithPerRowMultiplier) {
+    TFLITE_DCHECK(!params.multiplier_fixedpoint);
+    TFLITE_DCHECK(!params.multiplier_exponent);
+    TFLITE_DCHECK(params.multiplier_fixedpoint_perchannel);
+    TFLITE_DCHECK(params.multiplier_exponent_perchannel);
   }
 }
 
-// Validates overall consistency of all the parameters taken by a Gemm call:
-// the 3 MatrixParams and the GemmParams. Even if currently these are
-// checked only separately, it's good to have this validation done in one
-// function taking all of these parameters at once, as in the future there
-// may be mutual consistency requirements.
+namespace detail {
+
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
+          typename DstScalar, QuantizationFlavor quantization_flavor>
+struct ValidateTypes {
+  // This generic implementation is for quantized flavors.
+  // kFloatingPoint will be a specialization below.
+  static_assert(!std::is_floating_point<LhsScalar>::value, "");
+  static_assert(!std::is_floating_point<RhsScalar>::value, "");
+  static_assert(!std::is_floating_point<AccumScalar>::value, "");
+  // No requirement on DstScalar --- we might in the future allow it
+  // to be floating point even in a quantized Gemm.
+};
+
 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
           typename DstScalar>
-void ValidateParams(const MatrixParams<LhsScalar>& lhs_params,
-                    const MatrixParams<RhsScalar>& rhs_params,
-                    const MatrixParams<DstScalar>& dst_params,
-                    const GemmParams<AccumScalar, DstScalar>& params) {
+struct ValidateTypes<LhsScalar, RhsScalar, AccumScalar, DstScalar,
+                     QuantizationFlavor::kFloatingPoint> {
+  static_assert(std::is_floating_point<LhsScalar>::value, "");
+  static_assert(std::is_floating_point<RhsScalar>::value, "");
+  static_assert(std::is_floating_point<AccumScalar>::value, "");
+  static_assert(std::is_floating_point<DstScalar>::value, "");
+};
+
+}  // namespace detail
+
+// Validates overall consistency of all the parameters taken by a Gemm call:
+// the 3 MatrixParams and the GemmParams.
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
+          typename DstScalar, QuantizationFlavor quantization_flavor>
+void ValidateParams(
+    const MatrixParams<LhsScalar>& lhs_params,
+    const MatrixParams<RhsScalar>& rhs_params,
+    const MatrixParams<DstScalar>& dst_params,
+    const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params) {
+  (void)detail::ValidateTypes<LhsScalar, RhsScalar, AccumScalar, DstScalar,
+                              quantization_flavor>();
+  ValidateGemmParams(params);
   // For now, Gemm only supports this particular combination of storage orders.
   // Actually the generic ruy path already supports all combinations (with
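
The `(void)detail::ValidateTypes<...>()` expression works by constructing a
throwaway value: that forces the class template, and therefore its
static_asserts, to be instantiated for the deduced types. A toy of the same
trick (names are mine):

#include <type_traits>

template <typename T>
struct RequireIntegral {
  static_assert(std::is_integral<T>::value, "T must be an integer type");
};

template <typename T>
void Process(T /*value*/) {
  // Creating a temporary instantiates RequireIntegral<T>, firing the
  // static_assert at compile time; (void) silences unused-value warnings.
  (void)RequireIntegral<T>();
}

int main() {
  Process(42);     // compiles
  // Process(3.5); // would fail: "T must be an integer type"
}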
@@ -151,6 +224,20 @@ void ValidateParams(const MatrixParams<LhsScalar>& lhs_params,
   TFLITE_DCHECK(lhs_params.order == Order::kRowMajor);
   TFLITE_DCHECK(rhs_params.order == Order::kColMajor);
   TFLITE_DCHECK(dst_params.order == Order::kColMajor);
+  // Guard against the case when both LHS and RHS zero_point's are equal to
+  // the minimum representable value. In that case, padding with zero_point
+  // values will generate the bad case for fast int8 kernels on NEON
+  // (pre-dotprod) which attempt to multiply-accumulate two pairs of int8
+  // into an int16: this is safe except in the bad case -128*-128 + -128*-128.
+  // Accordingly, this is banned by gemmlowp and ruy. However, they were not
+  // guarding against that, which allowed a real bug to happen, b/131609283.
+  // Checking this here lets callers of this cpu_backend_gemm library be
+  // safe regardless of backends.
+  if (quantization_flavor != QuantizationFlavor::kFloatingPoint) {
+    TFLITE_DCHECK(
+        lhs_params.zero_point != std::numeric_limits<LhsScalar>::lowest() ||
+        rhs_params.zero_point != std::numeric_limits<RhsScalar>::lowest());
+  }
 }
 
 }  // namespace cpu_backend_gemm
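
The arithmetic behind that guard, written out: those NEON kernels accumulate
two int8 x int8 products into an int16 lane, and the only inputs that
overflow are the banned ones, since (-128)*(-128) + (-128)*(-128) =
16384 + 16384 = 32768, one past INT16_MAX = 32767. A one-liner check:

#include <cstdint>
#include <iostream>

int main() {
  // Worst case for a pre-dotprod NEON kernel summing two int8*int8
  // products into int16: only -128*-128 + -128*-128 overflows.
  const int worst = (-128) * (-128) + (-128) * (-128);  // 32768
  std::cout << worst << " > " << INT16_MAX << '\n';     // 32768 > 32767
}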
@@ -61,16 +61,14 @@ void MakeRuySpec(const GemmParamsType& params, RuySpecType* ruy_spec) {
 }
 
 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
-          typename DstScalar>
+          typename DstScalar, QuantizationFlavor quantization_flavor>
 struct GemmImplUsingRuy {
-  static void Run(const MatrixParams<LhsScalar>& lhs_params,
-                  const LhsScalar* lhs_data,
-                  const MatrixParams<RhsScalar>& rhs_params,
-                  const RhsScalar* rhs_data,
-                  const MatrixParams<DstScalar>& dst_params,
-                  DstScalar* dst_data,
-                  const GemmParams<AccumScalar, DstScalar>& params,
-                  CpuBackendContext* context) {
+  static void Run(
+      const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
+      const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
+      const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
+      const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
+      CpuBackendContext* context) {
     ruy::Matrix<LhsScalar> ruy_lhs;
     ruy::Matrix<RhsScalar> ruy_rhs;
     ruy::Matrix<DstScalar> ruy_dst;
@@ -26,6 +26,7 @@ limitations under the License.
 
 #include <gtest/gtest.h>
 #include "tensorflow/lite/experimental/ruy/ruy.h"
+#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
 
 namespace tflite {
 
@@ -34,6 +35,7 @@ namespace {
 using cpu_backend_gemm::Gemm;
 using cpu_backend_gemm::GemmParams;
 using cpu_backend_gemm::MatrixParams;
+using cpu_backend_gemm::QuantizationFlavor;
 
 template <typename Scalar>
 std::string ToString(const std::vector<Scalar>& vector) {
@@ -125,9 +127,11 @@ void Clamp(const std::vector<Scalar>& src, Scalar clamp_min, Scalar clamp_max,
   }
 }
 
-template <typename AccumScalar, typename DstScalar>
-void Clamp(const GemmParams<AccumScalar, DstScalar>& src, DstScalar clamp_min,
-           DstScalar clamp_max, GemmParams<AccumScalar, DstScalar>* dst) {
+template <typename AccumScalar, typename DstScalar,
+          QuantizationFlavor quantization_flavor>
+void Clamp(const GemmParams<AccumScalar, DstScalar, quantization_flavor>& src,
+           DstScalar clamp_min, DstScalar clamp_max,
+           GemmParams<AccumScalar, DstScalar, quantization_flavor>* dst) {
   *dst = src;
   dst->clamp_min = clamp_min;
   dst->clamp_max = clamp_max;
@@ -236,14 +240,14 @@ void CheckErrorForAccumulation(int accumulation_depth,
 }
 
 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
-          typename DstScalar>
+          typename DstScalar, QuantizationFlavor quantization_flavor>
 void PerformGemmThenCompareResultsThenAgainWithClamping(
     const MatrixParams<LhsScalar>& lhs_params,
     const std::vector<LhsScalar>& lhs_data,
     const MatrixParams<RhsScalar>& rhs_params,
     const std::vector<RhsScalar>& rhs_data,
     const MatrixParams<DstScalar>& dst_params, std::vector<DstScalar>* dst_data,
-    const GemmParams<AccumScalar, DstScalar>& params,
+    const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
     const std::vector<DstScalar>& expected,
     CpuBackendContext* cpu_backend_context) {
   const int accumulation_depth = lhs_params.cols;
@@ -253,7 +257,7 @@ void PerformGemmThenCompareResultsThenAgainWithClamping(
       expected);
   DstScalar expected_median = Median(expected);
   std::vector<DstScalar> expected_with_clamp;
-  GemmParams<AccumScalar, DstScalar> params_with_clamp;
+  GemmParams<AccumScalar, DstScalar, quantization_flavor> params_with_clamp;
   DstScalar clamp_min, clamp_max;
 
   clamp_min = std::numeric_limits<DstScalar>::lowest();
@@ -453,9 +457,14 @@ void TestSomeGemm(int rows, int depth, int cols,
       rows, params.multiplier_fixedpoint);
   std::vector<int> multiplier_exponent_perchannel(rows,
                                                   params.multiplier_exponent);
-  GemmParams<AccumScalar, DstScalar> params_perchannel = params;
-  params_perchannel.multiplier_fixedpoint = 0;
-  params_perchannel.multiplier_exponent = 0;
+  static constexpr QuantizationFlavor perchannel_flavor =
+      std::is_floating_point<AccumScalar>::value
+          ? QuantizationFlavor::kFloatingPoint
+          : QuantizationFlavor::kIntegerWithPerRowMultiplier;
+  GemmParams<AccumScalar, DstScalar, perchannel_flavor> params_perchannel;
+  params_perchannel.bias = params.bias;
+  params_perchannel.clamp_min = params.clamp_min;
+  params_perchannel.clamp_max = params.clamp_max;
   params_perchannel.multiplier_fixedpoint_perchannel =
       multiplier_fixedpoint_perchannel.data();
   params_perchannel.multiplier_exponent_perchannel =