Templatize Gemm in a 'quantization flavor' distinguishing per-channel vs uniform quantization.

This isn't our preference design-wise, and is unnecessary as far as ruy is concerned, as in ruy this is just a runtime switch.

But we want to avoid code size regressions also in the currently default case where gemmlowp, not ruy, is used, and in gemmlowp this is a compile-time switch.

Moreover, the actual instantiations that TFLite needs are such that there is a real overall binary size regression unless this is templatized in this way. That is because TFLite uses 2 combinations of { input type, quantization flavor }:
 1. { uint8, uniform quantization }
 2. { int8, per-channel quantization }
TFLite does not use other combinations such as {uint8, per-channel}, so if we force gemmlowp to instantiate this case, we regress binary size overall.

PiperOrigin-RevId: 246070679
This commit is contained in:
Benoit Jacob 2019-04-30 20:06:00 -07:00 committed by TensorFlower Gardener
parent a5acee6458
commit 7114b99543
6 changed files with 263 additions and 147 deletions

View File

@ -36,60 +36,65 @@ namespace cpu_backend_gemm {
*/
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
typename DstScalar>
struct GemmImpl
: detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar> {};
typename DstScalar, QuantizationFlavor quantization_flavor>
struct GemmImpl : detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar,
DstScalar, quantization_flavor> {};
#ifndef TFLITE_WITH_RUY
/* Specializations using gemmlowp */
template <typename SrcScalar, typename DstScalar>
struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, DstScalar>
template <typename SrcScalar, typename DstScalar,
QuantizationFlavor quantization_flavor>
struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, DstScalar,
quantization_flavor>
: detail::GemmImplUsingGemmlowp<SrcScalar, SrcScalar, std::int32_t,
DstScalar> {};
DstScalar, quantization_flavor> {};
// When SrcScalar=int8 or DstScalar=int8, gemmlowp fails to compile
// outside of NEON. We avoid the compilation failure by subspecializing these
// cases, rerouting it back to ruy.
#ifndef GEMMLOWP_NEON
template <typename SrcScalar>
struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, std::int8_t>
: detail::GemmImplUsingRuy<SrcScalar, SrcScalar, std::int32_t,
std::int8_t> {};
template <typename SrcScalar, QuantizationFlavor quantization_flavor>
struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, std::int8_t,
quantization_flavor>
: detail::GemmImplUsingRuy<SrcScalar, SrcScalar, std::int32_t, std::int8_t,
quantization_flavor> {};
template <typename DstScalar>
struct GemmImpl<std::int8_t, std::int8_t, std::int32_t, DstScalar>
template <typename DstScalar, QuantizationFlavor quantization_flavor>
struct GemmImpl<std::int8_t, std::int8_t, std::int32_t, DstScalar,
quantization_flavor>
: detail::GemmImplUsingRuy<std::int8_t, std::int8_t, std::int32_t,
DstScalar> {};
DstScalar, quantization_flavor> {};
template <>
struct GemmImpl<std::int8_t, std::int8_t, std::int32_t, std::int8_t>
template <QuantizationFlavor quantization_flavor>
struct GemmImpl<std::int8_t, std::int8_t, std::int32_t, std::int8_t,
quantization_flavor>
: detail::GemmImplUsingRuy<std::int8_t, std::int8_t, std::int32_t,
std::int8_t> {};
std::int8_t, quantization_flavor> {};
#endif // not GEMMLOWP_NEON
/* Specializations using Eigen */
template <>
struct GemmImpl<float, float, float, float>
: detail::GemmImplUsingEigen<float, float, float, float> {};
struct GemmImpl<float, float, float, float, QuantizationFlavor::kFloatingPoint>
: detail::GemmImplUsingEigen<float> {};
#endif // not TFLITE_WITH_RUY
/* Public entry point */
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
typename DstScalar>
typename DstScalar, QuantizationFlavor quantization_flavor>
void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
const GemmParams<AccumScalar, DstScalar>& params,
const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
CpuBackendContext* context) {
ValidateParams(lhs_params, rhs_params, dst_params, params);
GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar>::Run(
lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data, params,
context);
GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar,
quantization_flavor>::Run(lhs_params, lhs_data, rhs_params, rhs_data,
dst_params, dst_data, params, context);
}
} // namespace cpu_backend_gemm

View File

@ -24,13 +24,9 @@ namespace tflite {
namespace cpu_backend_gemm {
namespace detail {
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
typename DstScalar>
template <typename Scalar>
struct GemmImplUsingEigen {
static_assert(std::is_same<LhsScalar, float>::value, "");
static_assert(std::is_same<RhsScalar, float>::value, "");
static_assert(std::is_same<AccumScalar, float>::value, "");
static_assert(std::is_same<DstScalar, float>::value, "");
static_assert(std::is_same<Scalar, float>::value, "");
static void Run(const MatrixParams<float>& lhs_params, const float* lhs_data,
const MatrixParams<float>& rhs_params, const float* rhs_data,
const MatrixParams<float>& dst_params, float* dst_data,

View File

@ -59,94 +59,115 @@ struct GemmlowpBitDepthParams<std::int8_t> {
using Type = gemmlowp::SignedL8R8WithLhsNonzeroBitDepthParams;
};
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
typename DstScalar, QuantizationFlavor quantization_flavor>
struct GemmImplUsingGemmlowp {};
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
typename DstScalar>
struct GemmImplUsingGemmlowp {
struct GemmImplUsingGemmlowp<
LhsScalar, RhsScalar, AccumScalar, DstScalar,
QuantizationFlavor::kIntegerWithUniformMultiplier> {
static_assert(std::is_same<LhsScalar, RhsScalar>::value, "");
static_assert(std::is_same<AccumScalar, std::int32_t>::value, "");
using SrcScalar = LhsScalar;
static void Run(const MatrixParams<SrcScalar>& lhs_params,
const SrcScalar* lhs_data,
const MatrixParams<SrcScalar>& rhs_params,
const SrcScalar* rhs_data,
const MatrixParams<DstScalar>& dst_params,
DstScalar* dst_data,
const GemmParams<std::int32_t, DstScalar>& params,
CpuBackendContext* context) {
if (params.multiplier_exponent_perchannel) {
// gemmlowp support for this per-channel path is limited to NEON.
// We fall back to ruy outside of NEON.
static void Run(
const MatrixParams<SrcScalar>& lhs_params, const SrcScalar* lhs_data,
const MatrixParams<SrcScalar>& rhs_params, const SrcScalar* rhs_data,
const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
const GemmParams<std::int32_t, DstScalar,
QuantizationFlavor::kIntegerWithUniformMultiplier>&
params,
CpuBackendContext* context) {
gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::RowMajor>
gemmlowp_lhs(lhs_data, lhs_params.rows, lhs_params.cols);
gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::ColMajor>
gemmlowp_rhs(rhs_data, rhs_params.rows, rhs_params.cols);
gemmlowp::MatrixMap<DstScalar, gemmlowp::MapOrder::ColMajor> gemmlowp_dst(
dst_data, dst_params.rows, dst_params.cols);
using ColVectorMap =
gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>;
ColVectorMap bias_vector(params.bias, lhs_params.rows);
gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
bias_addition_stage.bias_vector = bias_vector;
gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent scale_stage;
scale_stage.result_offset_after_shift = dst_params.zero_point;
scale_stage.result_fixedpoint_multiplier = params.multiplier_fixedpoint;
scale_stage.result_exponent = params.multiplier_exponent;
using SaturatingCastStageType =
typename GemmlowpSaturatingCastStage<DstScalar>::Type;
gemmlowp::OutputStageClamp clamp_stage;
clamp_stage.min = params.clamp_min;
clamp_stage.max = params.clamp_max;
SaturatingCastStageType saturating_cast_stage;
auto output_pipeline = std::make_tuple(bias_addition_stage, scale_stage,
clamp_stage, saturating_cast_stage);
using BitDepthParams = typename GemmlowpBitDepthParams<SrcScalar>::Type;
gemmlowp::GemmWithOutputPipeline<SrcScalar, DstScalar, BitDepthParams>(
context->gemmlowp_context(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst,
-lhs_params.zero_point, -rhs_params.zero_point, output_pipeline);
}
};
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
typename DstScalar>
struct GemmImplUsingGemmlowp<LhsScalar, RhsScalar, AccumScalar, DstScalar,
QuantizationFlavor::kIntegerWithPerRowMultiplier> {
static_assert(std::is_same<LhsScalar, RhsScalar>::value, "");
static_assert(std::is_same<AccumScalar, std::int32_t>::value, "");
using SrcScalar = LhsScalar;
static void Run(
const MatrixParams<SrcScalar>& lhs_params, const SrcScalar* lhs_data,
const MatrixParams<SrcScalar>& rhs_params, const SrcScalar* rhs_data,
const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
const GemmParams<std::int32_t, DstScalar,
QuantizationFlavor::kIntegerWithPerRowMultiplier>&
params,
CpuBackendContext* context) {
// gemmlowp support for this per-channel path is limited to NEON.
// We fall back to ruy outside of NEON.
#ifdef GEMMLOWP_NEON
gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::RowMajor>
gemmlowp_lhs(lhs_data, lhs_params.rows, lhs_params.cols);
gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::ColMajor>
gemmlowp_rhs(rhs_data, rhs_params.rows, rhs_params.cols);
gemmlowp::MatrixMap<DstScalar, gemmlowp::MapOrder::ColMajor> gemmlowp_dst(
dst_data, dst_params.rows, dst_params.cols);
gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::RowMajor>
gemmlowp_lhs(lhs_data, lhs_params.rows, lhs_params.cols);
gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::ColMajor>
gemmlowp_rhs(rhs_data, rhs_params.rows, rhs_params.cols);
gemmlowp::MatrixMap<DstScalar, gemmlowp::MapOrder::ColMajor> gemmlowp_dst(
dst_data, dst_params.rows, dst_params.cols);
using ColVectorMap =
gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>;
ColVectorMap bias_vector(params.bias, lhs_params.rows);
gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
bias_addition_stage.bias_vector = bias_vector;
gemmlowp::OutputStageScaleInt32ByFixedPointAndExponentPC<
gemmlowp::VectorShape::Col>
scale_stage;
scale_stage.result_offset_after_shift = dst_params.zero_point;
scale_stage.result_fixedpoint_multiplier = ColVectorMap(
params.multiplier_fixedpoint_perchannel, dst_params.rows);
scale_stage.result_exponent =
ColVectorMap(params.multiplier_exponent_perchannel, dst_params.rows);
using SaturatingCastStageType =
typename GemmlowpSaturatingCastStage<DstScalar>::Type;
gemmlowp::OutputStageClamp clamp_stage;
clamp_stage.min = params.clamp_min;
clamp_stage.max = params.clamp_max;
SaturatingCastStageType saturating_cast_stage;
auto output_pipeline = std::make_tuple(
bias_addition_stage, scale_stage, clamp_stage, saturating_cast_stage);
using BitDepthParams = typename GemmlowpBitDepthParams<SrcScalar>::Type;
gemmlowp::GemmWithOutputPipeline<SrcScalar, DstScalar, BitDepthParams>(
context->gemmlowp_context(), gemmlowp_lhs, gemmlowp_rhs,
&gemmlowp_dst, -lhs_params.zero_point, -rhs_params.zero_point,
output_pipeline);
using ColVectorMap =
gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>;
ColVectorMap bias_vector(params.bias, lhs_params.rows);
gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
bias_addition_stage.bias_vector = bias_vector;
gemmlowp::OutputStageScaleInt32ByFixedPointAndExponentPC<
gemmlowp::VectorShape::Col>
scale_stage;
scale_stage.result_offset_after_shift = dst_params.zero_point;
scale_stage.result_fixedpoint_multiplier =
ColVectorMap(params.multiplier_fixedpoint_perchannel, dst_params.rows);
scale_stage.result_exponent =
ColVectorMap(params.multiplier_exponent_perchannel, dst_params.rows);
using SaturatingCastStageType =
typename GemmlowpSaturatingCastStage<DstScalar>::Type;
gemmlowp::OutputStageClamp clamp_stage;
clamp_stage.min = params.clamp_min;
clamp_stage.max = params.clamp_max;
SaturatingCastStageType saturating_cast_stage;
auto output_pipeline = std::make_tuple(bias_addition_stage, scale_stage,
clamp_stage, saturating_cast_stage);
using BitDepthParams = typename GemmlowpBitDepthParams<SrcScalar>::Type;
gemmlowp::GemmWithOutputPipeline<SrcScalar, DstScalar, BitDepthParams>(
context->gemmlowp_context(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst,
-lhs_params.zero_point, -rhs_params.zero_point, output_pipeline);
#else
GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar>::Run(
lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data,
params, context);
GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar,
QuantizationFlavor::kIntegerWithPerRowMultiplier>::
Run(lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data,
params, context);
#endif
} else {
gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::RowMajor>
gemmlowp_lhs(lhs_data, lhs_params.rows, lhs_params.cols);
gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::ColMajor>
gemmlowp_rhs(rhs_data, rhs_params.rows, rhs_params.cols);
gemmlowp::MatrixMap<DstScalar, gemmlowp::MapOrder::ColMajor> gemmlowp_dst(
dst_data, dst_params.rows, dst_params.cols);
using ColVectorMap =
gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>;
ColVectorMap bias_vector(params.bias, lhs_params.rows);
gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
bias_addition_stage.bias_vector = bias_vector;
gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent scale_stage;
scale_stage.result_offset_after_shift = dst_params.zero_point;
scale_stage.result_fixedpoint_multiplier = params.multiplier_fixedpoint;
scale_stage.result_exponent = params.multiplier_exponent;
using SaturatingCastStageType =
typename GemmlowpSaturatingCastStage<DstScalar>::Type;
gemmlowp::OutputStageClamp clamp_stage;
clamp_stage.min = params.clamp_min;
clamp_stage.max = params.clamp_max;
SaturatingCastStageType saturating_cast_stage;
auto output_pipeline = std::make_tuple(
bias_addition_stage, scale_stage, clamp_stage, saturating_cast_stage);
using BitDepthParams = typename GemmlowpBitDepthParams<SrcScalar>::Type;
gemmlowp::GemmWithOutputPipeline<SrcScalar, DstScalar, BitDepthParams>(
context->gemmlowp_context(), gemmlowp_lhs, gemmlowp_rhs,
&gemmlowp_dst, -lhs_params.zero_point, -rhs_params.zero_point,
output_pipeline);
}
}
};

View File

@ -49,12 +49,54 @@ struct MatrixParams {
Scalar zero_point = 0;
};
// Enumeration of broad categories of Gemm.
//
// The primary reason for this to exist is to allow Gemm to compile
// only uniform-quantized or only per-channel-quantized code paths.
// This is unneeded with ruy as the back-end, as this is only a runtime
// difference in ruy, but with gemmlowp these really are separate code
// paths and templatizing in a QuantizationFlavor is necessary to avoid
// compiling unused gemmlowp code. Indeed, TFLite currently uses
// uint8 with uniform quantization and int8 with per-channel quantization,
// and does not use uint8 with per-channel. We want to avoid compiling
// the gemmlowp uint8 per-channel path when gemmlowp is the back-end.
//
// It's possible to drop this in the future if gemmlowp goes away and no
// other then-relevant backend library handles quantized paths in a way that
// requires knowing this at compile-time.
enum class QuantizationFlavor {
// Floating-point Gemm: the accumulators are not multiplied by any
// 'multiplier'.
kFloatingPoint,
// Quantized Gemm using a single multiplier for all accumulators.
kIntegerWithUniformMultiplier,
// Quantized Gemm using separate multipliers for the accumulators of each
// row of the destination matrix. This is what is called 'per-channel'
// in GemmParams. Here we use the more specific 'per-row' terminology
// to allow for the possibility of 'per-column' in the future, and to
// allow for that to be a separate code path in some back-end such as
// gemmlowp.
kIntegerWithPerRowMultiplier
};
// Additional parameters that Gemm needs, beyond what falls into
// the MatrixParams that it takes. Compare to ruy::Spec.
//
// Decoupling AccumScalar from DstScalar (rather than deducing it from that)
// is useful future-proofing. Think of a float16 path using float32 accum.
template <typename AccumScalar, typename DstScalar>
//
// QuantizationFlavor is passed here even though it's technically not used
// in this class. This is so that we retain the ability in the future to
// specialize this class for quantization flavor, and this allows for
// Gemm to be templatized in quantization_flavor via the GemmParams that it
// takes, allowing for automatic template parameter deduction to take place,
// so that most call sites don't need to specify a QuantizationFlavor
// (only those that need perchannel quantization do).
template <typename AccumScalar, typename DstScalar,
QuantizationFlavor quantization_flavor =
std::is_floating_point<AccumScalar>::value
? QuantizationFlavor::kFloatingPoint
: QuantizationFlavor::kIntegerWithUniformMultiplier>
struct GemmParams {
// Only for non-floating-point cases. The fixed-point part (i.e. the mantissa)
// of the multiplier by which accumulators are multiplied before being casted
@ -104,41 +146,72 @@ using FloatGemmParams = GemmParams<float, float>;
// a release-build assertion. See b/131587258.
// Validates self-consistency of GemmParams.
template <typename AccumScalar, typename DstScalar>
void ValidateGemmParams(const GemmParams<AccumScalar, DstScalar>& params) {
template <typename AccumScalar, typename DstScalar,
QuantizationFlavor quantization_flavor>
void ValidateGemmParams(
const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params) {
// For now require a bias vector. Again, ruy does not rely on that requirement
// but the gemmlowp and Eigen path would require more code to handle it,
// and currently TFLite only uses the case where there is a bias vector.
TFLITE_DCHECK(params.bias);
// Guard consistency of the quantized multiplier fields.
if (std::is_floating_point<AccumScalar>::value) {
// Floating point case: must not have any quantized multipliers
if (quantization_flavor == QuantizationFlavor::kFloatingPoint) {
TFLITE_DCHECK(!params.multiplier_fixedpoint);
TFLITE_DCHECK(!params.multiplier_exponent);
TFLITE_DCHECK(!params.multiplier_fixedpoint_perchannel);
TFLITE_DCHECK(!params.multiplier_exponent_perchannel);
} else {
// Quantized case. Must have either uniform or perchannel multiplier,
// not both.
TFLITE_DCHECK((params.multiplier_fixedpoint == 0) !=
(params.multiplier_fixedpoint_perchannel == nullptr));
// Consistency of the two _perchannel fields.
TFLITE_DCHECK((params.multiplier_exponent_perchannel == nullptr) ==
(params.multiplier_fixedpoint_perchannel == nullptr));
} else if (quantization_flavor ==
QuantizationFlavor::kIntegerWithUniformMultiplier) {
TFLITE_DCHECK(params.multiplier_fixedpoint);
// Nothing to check about multiplier_exponent
TFLITE_DCHECK(!params.multiplier_fixedpoint_perchannel);
TFLITE_DCHECK(!params.multiplier_exponent_perchannel);
} else if (quantization_flavor ==
QuantizationFlavor::kIntegerWithPerRowMultiplier) {
TFLITE_DCHECK(!params.multiplier_fixedpoint);
TFLITE_DCHECK(!params.multiplier_exponent);
TFLITE_DCHECK(params.multiplier_fixedpoint_perchannel);
TFLITE_DCHECK(params.multiplier_exponent_perchannel);
}
}
// Validates overall consistency of all the parameters taken by a Gemm call:
// the 3 MatrixParams and the GemmParams. Even if currently these are
// checked only separately, it's good to have this validation done in one
// function taking all of these parameters at once, as in the future there
// may be mutual consistency requirements.
namespace detail {
// Compile-time validation that the scalar types of a Gemm call are
// consistent with its QuantizationFlavor. This primary template handles
// the quantized (integer) flavors: none of LHS, RHS, or accumulator may
// be a floating-point type. The kFloatingPoint flavor is handled by a
// separate specialization.
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
typename DstScalar, QuantizationFlavor quantization_flavor>
struct ValidateTypes {
// This generic implementation is for quantized flavors.
// kFloatingPoint will be a specialization below.
static_assert(!std::is_floating_point<LhsScalar>::value, "");
static_assert(!std::is_floating_point<RhsScalar>::value, "");
static_assert(!std::is_floating_point<AccumScalar>::value, "");
// No requirement on DstScalar --- we might in the future allow it
// to be floating point even in a quantized Gemm.
};
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
typename DstScalar>
void ValidateParams(const MatrixParams<LhsScalar>& lhs_params,
const MatrixParams<RhsScalar>& rhs_params,
const MatrixParams<DstScalar>& dst_params,
const GemmParams<AccumScalar, DstScalar>& params) {
struct ValidateTypes<LhsScalar, RhsScalar, AccumScalar, DstScalar,
QuantizationFlavor::kFloatingPoint> {
static_assert(std::is_floating_point<LhsScalar>::value, "");
static_assert(std::is_floating_point<RhsScalar>::value, "");
static_assert(std::is_floating_point<AccumScalar>::value, "");
static_assert(std::is_floating_point<DstScalar>::value, "");
};
} // namespace detail
// Validates overall consistency of all the parameters taken by a Gemm call:
// the 3 MatrixParams and the GemmParams.
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
typename DstScalar, QuantizationFlavor quantization_flavor>
void ValidateParams(
const MatrixParams<LhsScalar>& lhs_params,
const MatrixParams<RhsScalar>& rhs_params,
const MatrixParams<DstScalar>& dst_params,
const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params) {
(void)detail::ValidateTypes<LhsScalar, RhsScalar, AccumScalar, DstScalar,
quantization_flavor>();
ValidateGemmParams(params);
// For now, Gemm only supports this particular combination of storage orders.
// Actually the generic ruy path already supports all combinations (with
@ -151,6 +224,20 @@ void ValidateParams(const MatrixParams<LhsScalar>& lhs_params,
TFLITE_DCHECK(lhs_params.order == Order::kRowMajor);
TFLITE_DCHECK(rhs_params.order == Order::kColMajor);
TFLITE_DCHECK(dst_params.order == Order::kColMajor);
// Guard against the case when both LHS and RHS zero_point's are equal to
// the minimum representable value. In that case, padding with zero_point
// values will generate the bad case for fast int8 kernels on NEON
// (pre-dotprod) which attempt to multiply-accumulate two pairs of int8
// into a int16: this is safe except in the bad case -128*-128 + -128*-128.
// Accordingly, this is banned by gemmlowp and ruy. However, they were not
// guarding against that, which allowed a real bug to happen, b/131609283.
// Checking this here lets callers of this cpu_backend_gemm library be
// safe regardless of backends.
if (quantization_flavor != QuantizationFlavor::kFloatingPoint) {
TFLITE_DCHECK(
lhs_params.zero_point != std::numeric_limits<LhsScalar>::lowest() ||
rhs_params.zero_point != std::numeric_limits<RhsScalar>::lowest());
}
}
} // namespace cpu_backend_gemm

View File

@ -61,16 +61,14 @@ void MakeRuySpec(const GemmParamsType& params, RuySpecType* ruy_spec) {
}
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
typename DstScalar>
typename DstScalar, QuantizationFlavor quantization_flavor>
struct GemmImplUsingRuy {
static void Run(const MatrixParams<LhsScalar>& lhs_params,
const LhsScalar* lhs_data,
const MatrixParams<RhsScalar>& rhs_params,
const RhsScalar* rhs_data,
const MatrixParams<DstScalar>& dst_params,
DstScalar* dst_data,
const GemmParams<AccumScalar, DstScalar>& params,
CpuBackendContext* context) {
static void Run(
const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
CpuBackendContext* context) {
ruy::Matrix<LhsScalar> ruy_lhs;
ruy::Matrix<RhsScalar> ruy_rhs;
ruy::Matrix<DstScalar> ruy_dst;

View File

@ -26,6 +26,7 @@ limitations under the License.
#include <gtest/gtest.h>
#include "tensorflow/lite/experimental/ruy/ruy.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
namespace tflite {
@ -34,6 +35,7 @@ namespace {
using cpu_backend_gemm::Gemm;
using cpu_backend_gemm::GemmParams;
using cpu_backend_gemm::MatrixParams;
using cpu_backend_gemm::QuantizationFlavor;
template <typename Scalar>
std::string ToString(const std::vector<Scalar>& vector) {
@ -125,9 +127,11 @@ void Clamp(const std::vector<Scalar>& src, Scalar clamp_min, Scalar clamp_max,
}
}
template <typename AccumScalar, typename DstScalar>
void Clamp(const GemmParams<AccumScalar, DstScalar>& src, DstScalar clamp_min,
DstScalar clamp_max, GemmParams<AccumScalar, DstScalar>* dst) {
template <typename AccumScalar, typename DstScalar,
QuantizationFlavor quantization_flavor>
void Clamp(const GemmParams<AccumScalar, DstScalar, quantization_flavor>& src,
DstScalar clamp_min, DstScalar clamp_max,
GemmParams<AccumScalar, DstScalar, quantization_flavor>* dst) {
*dst = src;
dst->clamp_min = clamp_min;
dst->clamp_max = clamp_max;
@ -236,14 +240,14 @@ void CheckErrorForAccumulation(int accumulation_depth,
}
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
typename DstScalar>
typename DstScalar, QuantizationFlavor quantization_flavor>
void PerformGemmThenCompareResultsThenAgainWithClamping(
const MatrixParams<LhsScalar>& lhs_params,
const std::vector<LhsScalar>& lhs_data,
const MatrixParams<RhsScalar>& rhs_params,
const std::vector<RhsScalar>& rhs_data,
const MatrixParams<DstScalar>& dst_params, std::vector<DstScalar>* dst_data,
const GemmParams<AccumScalar, DstScalar>& params,
const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
const std::vector<DstScalar>& expected,
CpuBackendContext* cpu_backend_context) {
const int accumulation_depth = lhs_params.cols;
@ -253,7 +257,7 @@ void PerformGemmThenCompareResultsThenAgainWithClamping(
expected);
DstScalar expected_median = Median(expected);
std::vector<DstScalar> expected_with_clamp;
GemmParams<AccumScalar, DstScalar> params_with_clamp;
GemmParams<AccumScalar, DstScalar, quantization_flavor> params_with_clamp;
DstScalar clamp_min, clamp_max;
clamp_min = std::numeric_limits<DstScalar>::lowest();
@ -453,9 +457,14 @@ void TestSomeGemm(int rows, int depth, int cols,
rows, params.multiplier_fixedpoint);
std::vector<int> multiplier_exponent_perchannel(rows,
params.multiplier_exponent);
GemmParams<AccumScalar, DstScalar> params_perchannel = params;
params_perchannel.multiplier_fixedpoint = 0;
params_perchannel.multiplier_exponent = 0;
static constexpr QuantizationFlavor perchannel_flavor =
std::is_floating_point<AccumScalar>::value
? QuantizationFlavor::kFloatingPoint
: QuantizationFlavor::kIntegerWithPerRowMultiplier;
GemmParams<AccumScalar, DstScalar, perchannel_flavor> params_perchannel;
params_perchannel.bias = params.bias;
params_perchannel.clamp_min = params.clamp_min;
params_perchannel.clamp_max = params.clamp_max;
params_perchannel.multiplier_fixedpoint_perchannel =
multiplier_fixedpoint_perchannel.data();
params_perchannel.multiplier_exponent_perchannel =