Templatize Gemm in a 'quantization flavor' distinguishing per-channel vs
uniform quantization.

This isn't our preferred design, and it is unnecessary as far as ruy is
concerned: in ruy, this is just a runtime switch. But we also want to avoid
code-size regressions in the currently default case where gemmlowp, not ruy,
is used, and in gemmlowp this is a compile-time switch. Moreover, the actual
instantiations that TFLite needs are such that there is a real overall
binary-size regression unless this is templatized in this way. That is
because TFLite uses exactly two combinations of {input type, quantization
flavor}:

1. {uint8, uniform quantization}
2. {int8, per-channel quantization}

TFLite does not use other combinations such as {uint8, per-channel}, so if
we force gemmlowp to instantiate that case, we regress binary size overall.

PiperOrigin-RevId: 246070679
This commit is contained in:
parent a5acee6458
commit 7114b99543
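To make the size argument concrete, here is a hedged sketch (the aliases are hypothetical and for illustration only; QuantizationFlavor and GemmParams are introduced in the diff below) of the only two combinations TFLite instantiates:

// Hypothetical aliases, illustration only: the two {input type,
// quantization flavor} combinations that TFLite actually instantiates.
using TfLiteUint8UniformParams =
    GemmParams<std::int32_t, std::uint8_t,
               QuantizationFlavor::kIntegerWithUniformMultiplier>;
using TfLiteInt8PerChannelParams =
    GemmParams<std::int32_t, std::int8_t,
               QuantizationFlavor::kIntegerWithPerRowMultiplier>;
// The {uint8, per-channel} combination is never named anywhere, so with the
// flavor as a template parameter the corresponding gemmlowp code path is
// never emitted into the binary.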
@@ -36,60 +36,65 @@ namespace cpu_backend_gemm {
 */
 
 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
-          typename DstScalar>
-struct GemmImpl
-    : detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar> {};
+          typename DstScalar, QuantizationFlavor quantization_flavor>
+struct GemmImpl : detail::GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar,
+                                           DstScalar, quantization_flavor> {};
 
 #ifndef TFLITE_WITH_RUY
 
 /* Specializations using gemmlowp */
 
-template <typename SrcScalar, typename DstScalar>
-struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, DstScalar>
+template <typename SrcScalar, typename DstScalar,
+          QuantizationFlavor quantization_flavor>
+struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, DstScalar,
+                quantization_flavor>
     : detail::GemmImplUsingGemmlowp<SrcScalar, SrcScalar, std::int32_t,
-                                    DstScalar> {};
+                                    DstScalar, quantization_flavor> {};
 
 // When SrcScalar=int8 or DstScalar=int8, gemmlowp fails to compile
 // outside of NEON. We avoid the compilation failure by subspecializing these
 // cases, rerouting it back to ruy.
 #ifndef GEMMLOWP_NEON
-template <typename SrcScalar>
-struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, std::int8_t>
-    : detail::GemmImplUsingRuy<SrcScalar, SrcScalar, std::int32_t,
-                               std::int8_t> {};
+template <typename SrcScalar, QuantizationFlavor quantization_flavor>
+struct GemmImpl<SrcScalar, SrcScalar, std::int32_t, std::int8_t,
+                quantization_flavor>
+    : detail::GemmImplUsingRuy<SrcScalar, SrcScalar, std::int32_t, std::int8_t,
+                               quantization_flavor> {};
 
-template <typename DstScalar>
-struct GemmImpl<std::int8_t, std::int8_t, std::int32_t, DstScalar>
+template <typename DstScalar, QuantizationFlavor quantization_flavor>
+struct GemmImpl<std::int8_t, std::int8_t, std::int32_t, DstScalar,
+                quantization_flavor>
     : detail::GemmImplUsingRuy<std::int8_t, std::int8_t, std::int32_t,
-                               DstScalar> {};
+                               DstScalar, quantization_flavor> {};
 
-template <>
-struct GemmImpl<std::int8_t, std::int8_t, std::int32_t, std::int8_t>
+template <QuantizationFlavor quantization_flavor>
+struct GemmImpl<std::int8_t, std::int8_t, std::int32_t, std::int8_t,
+                quantization_flavor>
     : detail::GemmImplUsingRuy<std::int8_t, std::int8_t, std::int32_t,
-                               std::int8_t> {};
+                               std::int8_t, quantization_flavor> {};
 #endif  // not GEMMLOWP_NEON
 
 /* Specializations using Eigen */
 
 template <>
-struct GemmImpl<float, float, float, float>
-    : detail::GemmImplUsingEigen<float, float, float, float> {};
+struct GemmImpl<float, float, float, float, QuantizationFlavor::kFloatingPoint>
+    : detail::GemmImplUsingEigen<float> {};
 
 #endif  // not TFLITE_WITH_RUY
 
 /* Public entry point */
 
 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
-          typename DstScalar>
+          typename DstScalar, QuantizationFlavor quantization_flavor>
 void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
           const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
           const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
-          const GemmParams<AccumScalar, DstScalar>& params,
+          const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
           CpuBackendContext* context) {
   ValidateParams(lhs_params, rhs_params, dst_params, params);
-  GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar>::Run(
-      lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data, params,
-      context);
+  GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar,
+           quantization_flavor>::Run(lhs_params, lhs_data, rhs_params, rhs_data,
+                                     dst_params, dst_data, params, context);
 }
 
 }  // namespace cpu_backend_gemm
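As a usage sketch of the new entry point (a hypothetical call site; lhs_data, rhs_data, dst_data, and context are assumed to exist and be filled in elsewhere), the flavor is deduced from the GemmParams argument, so typical call sites are unchanged:

// Hedged call-site sketch: quantization_flavor is deduced from params.
// GemmParams<std::int32_t, std::uint8_t> defaults to
// kIntegerWithUniformMultiplier (see the params header changes below),
// so this call compiles exactly as it did before the templatization.
MatrixParams<std::uint8_t> lhs_params, rhs_params, dst_params;
GemmParams<std::int32_t, std::uint8_t> params;
// ... fill in shapes, zero points, bias and multipliers ...
Gemm(lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data,
     params, context);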
@@ -24,13 +24,9 @@ namespace tflite {
 namespace cpu_backend_gemm {
 namespace detail {
 
-template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
-          typename DstScalar>
+template <typename Scalar>
 struct GemmImplUsingEigen {
-  static_assert(std::is_same<LhsScalar, float>::value, "");
-  static_assert(std::is_same<RhsScalar, float>::value, "");
-  static_assert(std::is_same<AccumScalar, float>::value, "");
-  static_assert(std::is_same<DstScalar, float>::value, "");
+  static_assert(std::is_same<Scalar, float>::value, "");
   static void Run(const MatrixParams<float>& lhs_params, const float* lhs_data,
                   const MatrixParams<float>& rhs_params, const float* rhs_data,
                   const MatrixParams<float>& dst_params, float* dst_data,
@@ -59,94 +59,115 @@ struct GemmlowpBitDepthParams<std::int8_t> {
   using Type = gemmlowp::SignedL8R8WithLhsNonzeroBitDepthParams;
 };
 
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
+          typename DstScalar, QuantizationFlavor quantization_flavor>
+struct GemmImplUsingGemmlowp {};
+
 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
           typename DstScalar>
-struct GemmImplUsingGemmlowp {
+struct GemmImplUsingGemmlowp<
+    LhsScalar, RhsScalar, AccumScalar, DstScalar,
+    QuantizationFlavor::kIntegerWithUniformMultiplier> {
   static_assert(std::is_same<LhsScalar, RhsScalar>::value, "");
   static_assert(std::is_same<AccumScalar, std::int32_t>::value, "");
   using SrcScalar = LhsScalar;
 
-  static void Run(const MatrixParams<SrcScalar>& lhs_params,
-                  const SrcScalar* lhs_data,
-                  const MatrixParams<SrcScalar>& rhs_params,
-                  const SrcScalar* rhs_data,
-                  const MatrixParams<DstScalar>& dst_params,
-                  DstScalar* dst_data,
-                  const GemmParams<std::int32_t, DstScalar>& params,
-                  CpuBackendContext* context) {
-    if (params.multiplier_exponent_perchannel) {
-      // gemmlowp support for this per-channel path is limited to NEON.
-      // We fall back to ruy outside of NEON.
+  static void Run(
+      const MatrixParams<SrcScalar>& lhs_params, const SrcScalar* lhs_data,
+      const MatrixParams<SrcScalar>& rhs_params, const SrcScalar* rhs_data,
+      const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
+      const GemmParams<std::int32_t, DstScalar,
+                       QuantizationFlavor::kIntegerWithUniformMultiplier>&
+          params,
+      CpuBackendContext* context) {
+    gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::RowMajor>
+        gemmlowp_lhs(lhs_data, lhs_params.rows, lhs_params.cols);
+    gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::ColMajor>
+        gemmlowp_rhs(rhs_data, rhs_params.rows, rhs_params.cols);
+    gemmlowp::MatrixMap<DstScalar, gemmlowp::MapOrder::ColMajor> gemmlowp_dst(
+        dst_data, dst_params.rows, dst_params.cols);
+
+    using ColVectorMap =
+        gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>;
+    ColVectorMap bias_vector(params.bias, lhs_params.rows);
+    gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
+    bias_addition_stage.bias_vector = bias_vector;
+    gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent scale_stage;
+    scale_stage.result_offset_after_shift = dst_params.zero_point;
+    scale_stage.result_fixedpoint_multiplier = params.multiplier_fixedpoint;
+    scale_stage.result_exponent = params.multiplier_exponent;
+    using SaturatingCastStageType =
+        typename GemmlowpSaturatingCastStage<DstScalar>::Type;
+    gemmlowp::OutputStageClamp clamp_stage;
+    clamp_stage.min = params.clamp_min;
+    clamp_stage.max = params.clamp_max;
+    SaturatingCastStageType saturating_cast_stage;
+    auto output_pipeline = std::make_tuple(bias_addition_stage, scale_stage,
+                                           clamp_stage, saturating_cast_stage);
+    using BitDepthParams = typename GemmlowpBitDepthParams<SrcScalar>::Type;
+    gemmlowp::GemmWithOutputPipeline<SrcScalar, DstScalar, BitDepthParams>(
+        context->gemmlowp_context(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst,
+        -lhs_params.zero_point, -rhs_params.zero_point, output_pipeline);
+  }
+};
+
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
+          typename DstScalar>
+struct GemmImplUsingGemmlowp<LhsScalar, RhsScalar, AccumScalar, DstScalar,
+                             QuantizationFlavor::kIntegerWithPerRowMultiplier> {
+  static_assert(std::is_same<LhsScalar, RhsScalar>::value, "");
+  static_assert(std::is_same<AccumScalar, std::int32_t>::value, "");
+  using SrcScalar = LhsScalar;
+
+  static void Run(
+      const MatrixParams<SrcScalar>& lhs_params, const SrcScalar* lhs_data,
+      const MatrixParams<SrcScalar>& rhs_params, const SrcScalar* rhs_data,
+      const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
+      const GemmParams<std::int32_t, DstScalar,
+                       QuantizationFlavor::kIntegerWithPerRowMultiplier>&
+          params,
+      CpuBackendContext* context) {
+    // gemmlowp support for this per-channel path is limited to NEON.
+    // We fall back to ruy outside of NEON.
 #ifdef GEMMLOWP_NEON
     gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::RowMajor>
         gemmlowp_lhs(lhs_data, lhs_params.rows, lhs_params.cols);
     gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::ColMajor>
         gemmlowp_rhs(rhs_data, rhs_params.rows, rhs_params.cols);
     gemmlowp::MatrixMap<DstScalar, gemmlowp::MapOrder::ColMajor> gemmlowp_dst(
         dst_data, dst_params.rows, dst_params.cols);
 
     using ColVectorMap =
         gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>;
     ColVectorMap bias_vector(params.bias, lhs_params.rows);
     gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
     bias_addition_stage.bias_vector = bias_vector;
     gemmlowp::OutputStageScaleInt32ByFixedPointAndExponentPC<
         gemmlowp::VectorShape::Col>
         scale_stage;
     scale_stage.result_offset_after_shift = dst_params.zero_point;
-    scale_stage.result_fixedpoint_multiplier = ColVectorMap(
-        params.multiplier_fixedpoint_perchannel, dst_params.rows);
+    scale_stage.result_fixedpoint_multiplier =
+        ColVectorMap(params.multiplier_fixedpoint_perchannel, dst_params.rows);
     scale_stage.result_exponent =
         ColVectorMap(params.multiplier_exponent_perchannel, dst_params.rows);
     using SaturatingCastStageType =
         typename GemmlowpSaturatingCastStage<DstScalar>::Type;
     gemmlowp::OutputStageClamp clamp_stage;
     clamp_stage.min = params.clamp_min;
     clamp_stage.max = params.clamp_max;
     SaturatingCastStageType saturating_cast_stage;
-    auto output_pipeline = std::make_tuple(
-        bias_addition_stage, scale_stage, clamp_stage, saturating_cast_stage);
+    auto output_pipeline = std::make_tuple(bias_addition_stage, scale_stage,
+                                           clamp_stage, saturating_cast_stage);
     using BitDepthParams = typename GemmlowpBitDepthParams<SrcScalar>::Type;
     gemmlowp::GemmWithOutputPipeline<SrcScalar, DstScalar, BitDepthParams>(
-        context->gemmlowp_context(), gemmlowp_lhs, gemmlowp_rhs,
-        &gemmlowp_dst, -lhs_params.zero_point, -rhs_params.zero_point,
-        output_pipeline);
+        context->gemmlowp_context(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst,
+        -lhs_params.zero_point, -rhs_params.zero_point, output_pipeline);
 #else
-      GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar>::Run(
-          lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data,
-          params, context);
+    GemmImplUsingRuy<LhsScalar, RhsScalar, AccumScalar, DstScalar,
+                     QuantizationFlavor::kIntegerWithPerRowMultiplier>::
+        Run(lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data,
+            params, context);
 #endif
-    } else {
-      gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::RowMajor>
-          gemmlowp_lhs(lhs_data, lhs_params.rows, lhs_params.cols);
-      gemmlowp::MatrixMap<const SrcScalar, gemmlowp::MapOrder::ColMajor>
-          gemmlowp_rhs(rhs_data, rhs_params.rows, rhs_params.cols);
-      gemmlowp::MatrixMap<DstScalar, gemmlowp::MapOrder::ColMajor> gemmlowp_dst(
-          dst_data, dst_params.rows, dst_params.cols);
-
-      using ColVectorMap =
-          gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>;
-      ColVectorMap bias_vector(params.bias, lhs_params.rows);
-      gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
-      bias_addition_stage.bias_vector = bias_vector;
-      gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent scale_stage;
-      scale_stage.result_offset_after_shift = dst_params.zero_point;
-      scale_stage.result_fixedpoint_multiplier = params.multiplier_fixedpoint;
-      scale_stage.result_exponent = params.multiplier_exponent;
-      using SaturatingCastStageType =
-          typename GemmlowpSaturatingCastStage<DstScalar>::Type;
-      gemmlowp::OutputStageClamp clamp_stage;
-      clamp_stage.min = params.clamp_min;
-      clamp_stage.max = params.clamp_max;
-      SaturatingCastStageType saturating_cast_stage;
-      auto output_pipeline = std::make_tuple(
-          bias_addition_stage, scale_stage, clamp_stage, saturating_cast_stage);
-      using BitDepthParams = typename GemmlowpBitDepthParams<SrcScalar>::Type;
-      gemmlowp::GemmWithOutputPipeline<SrcScalar, DstScalar, BitDepthParams>(
-          context->gemmlowp_context(), gemmlowp_lhs, gemmlowp_rhs,
-          &gemmlowp_dst, -lhs_params.zero_point, -rhs_params.zero_point,
-          output_pipeline);
-    }
   }
 };
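For readers unfamiliar with gemmlowp's output pipeline, the four stages assembled above amount to roughly the following per-accumulator scalar computation. This is a simplified sketch using int64 arithmetic in place of gemmlowp's SaturatingRoundingDoublingHighMul and rounding shifts, so it is illustrative rather than bit-exact:

#include <algorithm>
#include <cstdint>

// Simplified scalar reference for the pipeline: bias addition, scaling by
// multiplier_fixedpoint * 2^multiplier_exponent, clamping, saturating cast.
// multiplier_fixedpoint is a Q31 fixed-point value in [1/2, 1).
std::int8_t OutputPipelineSketch(std::int32_t accum, std::int32_t bias,
                                 std::int32_t multiplier_fixedpoint,
                                 int multiplier_exponent,
                                 std::int32_t dst_zero_point,
                                 std::int8_t clamp_min, std::int8_t clamp_max) {
  std::int64_t acc = static_cast<std::int64_t>(accum) + bias;
  acc = (acc * multiplier_fixedpoint) >> 31;  // apply the Q31 mantissa
  acc = multiplier_exponent >= 0 ? acc << multiplier_exponent
                                 : acc >> -multiplier_exponent;
  acc += dst_zero_point;  // result_offset_after_shift
  acc = std::min<std::int64_t>(std::max<std::int64_t>(acc, clamp_min),
                               clamp_max);
  return static_cast<std::int8_t>(acc);  // safe after clamping
}

The per-channel variant differs only in that multiplier_fixedpoint and multiplier_exponent are read per destination row instead of being single scalars.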
@@ -49,12 +49,54 @@ struct MatrixParams {
   Scalar zero_point = 0;
 };
 
+// Enumeration of broad categories of Gemm.
+//
+// The primary reason for this to exist is to allow Gemm to compile
+// only uniform-quantized or only per-channel-quantized code paths.
+// This is unneeded with ruy as the back-end, as this is only a runtime
+// difference in ruy, but with gemmlowp these really are separate code
+// paths and templatizing in a QuantizationFlavor is necessary to avoid
+// compiling unused gemmlowp code. Indeed, TFLite currently uses
+// uint8 with uniform quantization and int8 with per-channel quantization,
+// and does not use uint8 with per-channel. We want to avoid compiling
+// the gemmlowp uint8 per-channel path when gemmlowp is the back-end.
+//
+// It's possible to drop this in the future if gemmlowp goes away and no
+// other then-relevant backend library handles quantized paths in a way that
+// requires knowing this at compile-time.
+enum class QuantizationFlavor {
+  // Floating-point Gemm: the accumulators are not multiplied by any
+  // 'multiplier'.
+  kFloatingPoint,
+  // Quantized Gemm using a single multiplier for all accumulators.
+  kIntegerWithUniformMultiplier,
+  // Quantized Gemm using a separate multipliers for accumulators of each
+  // row of the destination matrix. This is what is called 'per-channel'
+  // in GemmParams. Here we use the more specific 'per-row' terminology
+  // to allow for the possibility of 'per-column' in the future, and to
+  // allow for that to be a separate code path in some back-end such as
+  // gemmlowp.
+  kIntegerWithPerRowMultiplier
+};
+
 // Additional parameters that Gemm needs, beyond what falls into
 // the MatrixParams that it takes. Compare to ruy::Spec.
 //
 // Decoupling AccumScalar from DstScalar (rather than deducing it from that)
 // is useful future-proofing. Think of a float16 path using float32 accum.
-template <typename AccumScalar, typename DstScalar>
+//
+// QuantizationFlavor is passed here even though it's technically not used
+// in this class. This is so that we retain the ability in the future to
+// specialize this class for quantization flavor, and this allows for
+// Gemm to be templatized in quantization_flavor via the GemmParams that it
+// takes, allowing for automatic template parameter deduction to take place,
+// so that most call sites don't need to specify a QuantizationFlavor
+// (only those that need perchannel quantization do).
+template <typename AccumScalar, typename DstScalar,
+          QuantizationFlavor quantization_flavor =
+              std::is_floating_point<AccumScalar>::value
+                  ? QuantizationFlavor::kFloatingPoint
+                  : QuantizationFlavor::kIntegerWithUniformMultiplier>
 struct GemmParams {
   // Only for non-floating-point cases. The fixed-point part (i.e. the mantissa)
   // of the multiplier by which accumulators are multiplied before being casted
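The default template argument means the existing two-parameter spellings keep working and resolve to a sensible flavor. A couple of static_asserts (illustrative only, not part of the commit; `<type_traits>` assumed) show the rule:

// Illustration only: how the defaulted quantization_flavor resolves.
static_assert(
    std::is_same<GemmParams<float, float>,
                 GemmParams<float, float,
                            QuantizationFlavor::kFloatingPoint>>::value,
    "float accumulators default to the floating-point flavor");
static_assert(
    std::is_same<
        GemmParams<std::int32_t, std::int8_t>,
        GemmParams<std::int32_t, std::int8_t,
                   QuantizationFlavor::kIntegerWithUniformMultiplier>>::value,
    "integer accumulators default to the uniform-multiplier flavor");
// Per-row callers must name the flavor explicitly:
// GemmParams<std::int32_t, std::int8_t,
//            QuantizationFlavor::kIntegerWithPerRowMultiplier>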
@@ -104,41 +146,72 @@ using FloatGemmParams = GemmParams<float, float>;
 // a release-build assertion. See b/131587258.
 
 // Validates self-consistency of GemmParams.
-template <typename AccumScalar, typename DstScalar>
-void ValidateGemmParams(const GemmParams<AccumScalar, DstScalar>& params) {
+template <typename AccumScalar, typename DstScalar,
+          QuantizationFlavor quantization_flavor>
+void ValidateGemmParams(
+    const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params) {
   // For now require a bias vector. Again, ruy does not rely on that requirement
   // but the gemmlowp and Eigen path would require more code to handle it,
   // and currently TFLite only uses the case where there is a bias vector.
   TFLITE_DCHECK(params.bias);
   // Guard consistency of the quantized multiplier fields.
-  if (std::is_floating_point<AccumScalar>::value) {
-    // Floating point case: must not have any quantized multipliers
+  if (quantization_flavor == QuantizationFlavor::kFloatingPoint) {
     TFLITE_DCHECK(!params.multiplier_fixedpoint);
     TFLITE_DCHECK(!params.multiplier_exponent);
     TFLITE_DCHECK(!params.multiplier_fixedpoint_perchannel);
     TFLITE_DCHECK(!params.multiplier_exponent_perchannel);
-  } else {
-    // Quantized case. Must have either uniform or perchannel multiplier,
-    // not both.
-    TFLITE_DCHECK((params.multiplier_fixedpoint == 0) !=
-                  (params.multiplier_fixedpoint_perchannel == nullptr));
-    // Consistency of the two _perchannel fields.
-    TFLITE_DCHECK((params.multiplier_exponent_perchannel == nullptr) ==
-                  (params.multiplier_fixedpoint_perchannel == nullptr));
+  } else if (quantization_flavor ==
+             QuantizationFlavor::kIntegerWithUniformMultiplier) {
+    TFLITE_DCHECK(params.multiplier_fixedpoint);
+    // Nothing to check about multiplier_exponent
+    TFLITE_DCHECK(!params.multiplier_fixedpoint_perchannel);
+    TFLITE_DCHECK(!params.multiplier_exponent_perchannel);
+  } else if (quantization_flavor ==
+             QuantizationFlavor::kIntegerWithPerRowMultiplier) {
+    TFLITE_DCHECK(!params.multiplier_fixedpoint);
+    TFLITE_DCHECK(!params.multiplier_exponent);
+    TFLITE_DCHECK(params.multiplier_fixedpoint_perchannel);
+    TFLITE_DCHECK(params.multiplier_exponent_perchannel);
   }
 }
 
-// Validates overall consistency of all the parameters taken by a Gemm call:
-// the 3 MatrixParams and the GemmParams. Even if currently these are
-// checked only separately, it's good to have this validation done in one
-// function taking all of these parameters at once, as in the future there
-// may be mutual consistency requirements.
+namespace detail {
+
 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
-          typename DstScalar>
-void ValidateParams(const MatrixParams<LhsScalar>& lhs_params,
-                    const MatrixParams<RhsScalar>& rhs_params,
-                    const MatrixParams<DstScalar>& dst_params,
-                    const GemmParams<AccumScalar, DstScalar>& params) {
+          typename DstScalar, QuantizationFlavor quantization_flavor>
+struct ValidateTypes {
+  // This generic implementation is for quantized flavors.
+  // kFloatingPoint will be a specialization below.
+  static_assert(!std::is_floating_point<LhsScalar>::value, "");
+  static_assert(!std::is_floating_point<RhsScalar>::value, "");
+  static_assert(!std::is_floating_point<AccumScalar>::value, "");
+  // No requirement on DstScalar --- we might in the future allow it
+  // to be floating point even in a quantized Gemm.
+};
+
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
+          typename DstScalar>
+struct ValidateTypes<LhsScalar, RhsScalar, AccumScalar, DstScalar,
+                     QuantizationFlavor::kFloatingPoint> {
+  static_assert(std::is_floating_point<LhsScalar>::value, "");
+  static_assert(std::is_floating_point<RhsScalar>::value, "");
+  static_assert(std::is_floating_point<AccumScalar>::value, "");
+  static_assert(std::is_floating_point<DstScalar>::value, "");
+};
+
+}  // namespace detail
+
+// Validates overall consistency of all the parameters taken by a Gemm call:
+// the 3 MatrixParams and the GemmParams.
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
+          typename DstScalar, QuantizationFlavor quantization_flavor>
+void ValidateParams(
+    const MatrixParams<LhsScalar>& lhs_params,
+    const MatrixParams<RhsScalar>& rhs_params,
+    const MatrixParams<DstScalar>& dst_params,
+    const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params) {
+  (void)detail::ValidateTypes<LhsScalar, RhsScalar, AccumScalar, DstScalar,
+                              quantization_flavor>();
   ValidateGemmParams(params);
   // For now, Gemm only supports this particular combination of storage orders.
   // Actually the generic ruy path already supports all combinations (with
|
|||||||
TFLITE_DCHECK(lhs_params.order == Order::kRowMajor);
|
TFLITE_DCHECK(lhs_params.order == Order::kRowMajor);
|
||||||
TFLITE_DCHECK(rhs_params.order == Order::kColMajor);
|
TFLITE_DCHECK(rhs_params.order == Order::kColMajor);
|
||||||
TFLITE_DCHECK(dst_params.order == Order::kColMajor);
|
TFLITE_DCHECK(dst_params.order == Order::kColMajor);
|
||||||
|
// Guard against the case when both LHS and RHS zero_point's are equal to
|
||||||
|
// the minimum representable value. In that case, padding with zero_point
|
||||||
|
// values will generate the bad case for fast int8 kernels on NEON
|
||||||
|
// (pre-dotprod) which attempt to multiply-accumulate two pairs of int8
|
||||||
|
// into a int16: this is safe except in the bad case -128*-128 + -128*-128.
|
||||||
|
// Accordingly, this is banned by gemmlowp and ruy. However, they were not
|
||||||
|
// guarding against that, which allowed a real bug to happen, b/131609283.
|
||||||
|
// Checking this here lets callers of this cpu_backend_gemm library be
|
||||||
|
// safe regardless of backends.
|
||||||
|
if (quantization_flavor != QuantizationFlavor::kFloatingPoint) {
|
||||||
|
TFLITE_DCHECK(
|
||||||
|
lhs_params.zero_point != std::numeric_limits<LhsScalar>::lowest() ||
|
||||||
|
rhs_params.zero_point != std::numeric_limits<RhsScalar>::lowest());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace cpu_backend_gemm
|
} // namespace cpu_backend_gemm
|
||||||
|
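The banned corner case is easy to verify numerically (illustrative static_asserts, not part of the commit):

#include <cstdint>
#include <limits>

// Two int8*int8 products accumulated into an int16: the worst case is
// (-128)*(-128) + (-128)*(-128) = 32768, one past INT16_MAX = 32767.
// Any other combination stays within range, which is why forbidding
// zero_point == -128 on both sides at once is sufficient.
static_assert((-128) * (-128) + (-128) * (-128) == 32768, "");
static_assert(std::numeric_limits<std::int16_t>::max() == 32767, "");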
@@ -61,16 +61,14 @@ void MakeRuySpec(const GemmParamsType& params, RuySpecType* ruy_spec) {
 }
 
 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
-          typename DstScalar>
+          typename DstScalar, QuantizationFlavor quantization_flavor>
 struct GemmImplUsingRuy {
-  static void Run(const MatrixParams<LhsScalar>& lhs_params,
-                  const LhsScalar* lhs_data,
-                  const MatrixParams<RhsScalar>& rhs_params,
-                  const RhsScalar* rhs_data,
-                  const MatrixParams<DstScalar>& dst_params,
-                  DstScalar* dst_data,
-                  const GemmParams<AccumScalar, DstScalar>& params,
-                  CpuBackendContext* context) {
+  static void Run(
+      const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
+      const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
+      const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
+      const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
+      CpuBackendContext* context) {
     ruy::Matrix<LhsScalar> ruy_lhs;
     ruy::Matrix<RhsScalar> ruy_rhs;
     ruy::Matrix<DstScalar> ruy_dst;
@@ -26,6 +26,7 @@ limitations under the License.
 
 #include <gtest/gtest.h>
 #include "tensorflow/lite/experimental/ruy/ruy.h"
+#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
 
 namespace tflite {
 
@@ -34,6 +35,7 @@ namespace {
 using cpu_backend_gemm::Gemm;
 using cpu_backend_gemm::GemmParams;
 using cpu_backend_gemm::MatrixParams;
+using cpu_backend_gemm::QuantizationFlavor;
 
 template <typename Scalar>
 std::string ToString(const std::vector<Scalar>& vector) {
@@ -125,9 +127,11 @@ void Clamp(const std::vector<Scalar>& src, Scalar clamp_min, Scalar clamp_max,
   }
 }
 
-template <typename AccumScalar, typename DstScalar>
-void Clamp(const GemmParams<AccumScalar, DstScalar>& src, DstScalar clamp_min,
-           DstScalar clamp_max, GemmParams<AccumScalar, DstScalar>* dst) {
+template <typename AccumScalar, typename DstScalar,
+          QuantizationFlavor quantization_flavor>
+void Clamp(const GemmParams<AccumScalar, DstScalar, quantization_flavor>& src,
+           DstScalar clamp_min, DstScalar clamp_max,
+           GemmParams<AccumScalar, DstScalar, quantization_flavor>* dst) {
   *dst = src;
   dst->clamp_min = clamp_min;
   dst->clamp_max = clamp_max;
@@ -236,14 +240,14 @@ void CheckErrorForAccumulation(int accumulation_depth,
 }
 
 template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
-          typename DstScalar>
+          typename DstScalar, QuantizationFlavor quantization_flavor>
 void PerformGemmThenCompareResultsThenAgainWithClamping(
     const MatrixParams<LhsScalar>& lhs_params,
     const std::vector<LhsScalar>& lhs_data,
     const MatrixParams<RhsScalar>& rhs_params,
     const std::vector<RhsScalar>& rhs_data,
     const MatrixParams<DstScalar>& dst_params, std::vector<DstScalar>* dst_data,
-    const GemmParams<AccumScalar, DstScalar>& params,
+    const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
     const std::vector<DstScalar>& expected,
     CpuBackendContext* cpu_backend_context) {
   const int accumulation_depth = lhs_params.cols;
@@ -253,7 +257,7 @@ void PerformGemmThenCompareResultsThenAgainWithClamping(
       expected);
   DstScalar expected_median = Median(expected);
   std::vector<DstScalar> expected_with_clamp;
-  GemmParams<AccumScalar, DstScalar> params_with_clamp;
+  GemmParams<AccumScalar, DstScalar, quantization_flavor> params_with_clamp;
   DstScalar clamp_min, clamp_max;
 
   clamp_min = std::numeric_limits<DstScalar>::lowest();
@@ -453,9 +457,14 @@ void TestSomeGemm(int rows, int depth, int cols,
       rows, params.multiplier_fixedpoint);
   std::vector<int> multiplier_exponent_perchannel(rows,
                                                   params.multiplier_exponent);
-  GemmParams<AccumScalar, DstScalar> params_perchannel = params;
-  params_perchannel.multiplier_fixedpoint = 0;
-  params_perchannel.multiplier_exponent = 0;
+  static constexpr QuantizationFlavor perchannel_flavor =
+      std::is_floating_point<AccumScalar>::value
+          ? QuantizationFlavor::kFloatingPoint
+          : QuantizationFlavor::kIntegerWithPerRowMultiplier;
+  GemmParams<AccumScalar, DstScalar, perchannel_flavor> params_perchannel;
+  params_perchannel.bias = params.bias;
+  params_perchannel.clamp_min = params.clamp_min;
+  params_perchannel.clamp_max = params.clamp_max;
   params_perchannel.multiplier_fixedpoint_perchannel =
       multiplier_fixedpoint_perchannel.data();
   params_perchannel.multiplier_exponent_perchannel =