Adapt to upcoming ruy::MulParams change where only the fields meaningful in each case will exist.
PiperOrigin-RevId: 318640066 Change-Id: Ic9bee87edde887859c9e6cbc53ebe3291978c77c
This commit is contained in:
parent
7585042327
commit
9008adb5e2
@ -361,6 +361,7 @@ cc_library(
|
||||
"@ruy//ruy",
|
||||
"@ruy//ruy:matrix",
|
||||
"@ruy//ruy:path",
|
||||
"@ruy//ruy:mul_params",
|
||||
"@ruy//ruy/profiler:instrumentation",
|
||||
# We only need to depend on gemmlowp and Eigen when tflite_with_ruy
|
||||
# is false, but putting these dependencies in a select() seems to
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_
|
||||
|
||||
#include "ruy/matrix.h" // from @ruy
|
||||
#include "ruy/mul_params.h" // from @ruy
|
||||
#include "ruy/ruy.h" // from @ruy
|
||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
|
||||
@ -57,23 +58,65 @@ void MakeRuyMatrix(const MatrixParams<Scalar>& params, DataPointer data_ptr,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GemmParamsType, typename RuySpecType>
|
||||
void MakeRuyMulParams(const GemmParamsType& params,
|
||||
RuySpecType* ruy_mul_params) {
|
||||
// This validation has already been performed by the Gemm API entry point,
|
||||
// but it doesn't hurt to test specifically this again here, where it's
|
||||
// being used.
|
||||
ValidateGemmParams(params);
|
||||
// Floating-point case.
|
||||
template <typename AccumScalar, typename DstScalar,
|
||||
QuantizationFlavor quantization_flavor>
|
||||
struct MakeRuyMulParamsImpl final {
|
||||
static void Run(
|
||||
const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
|
||||
ruy::MulParams<AccumScalar, DstScalar>* ruy_mul_params) {
|
||||
static_assert(quantization_flavor == QuantizationFlavor::kFloatingPoint,
|
||||
"");
|
||||
ruy_mul_params->set_bias(params.bias);
|
||||
ruy_mul_params->set_clamp_min(params.clamp_min);
|
||||
ruy_mul_params->set_clamp_max(params.clamp_max);
|
||||
}
|
||||
};
|
||||
|
||||
ruy_mul_params->set_multiplier_fixedpoint(params.multiplier_fixedpoint);
|
||||
ruy_mul_params->set_multiplier_exponent(params.multiplier_exponent);
|
||||
ruy_mul_params->set_multiplier_fixedpoint_perchannel(
|
||||
params.multiplier_fixedpoint_perchannel);
|
||||
ruy_mul_params->set_multiplier_exponent_perchannel(
|
||||
params.multiplier_exponent_perchannel);
|
||||
ruy_mul_params->set_bias(params.bias);
|
||||
ruy_mul_params->set_clamp_min(params.clamp_min);
|
||||
ruy_mul_params->set_clamp_max(params.clamp_max);
|
||||
// Integer-quantized case with destination type narrower than int32
|
||||
template <typename DstScalar, QuantizationFlavor quantization_flavor>
|
||||
struct MakeRuyMulParamsImpl<std::int32_t, DstScalar, quantization_flavor>
|
||||
final {
|
||||
static void Run(
|
||||
const GemmParams<std::int32_t, DstScalar, quantization_flavor>& params,
|
||||
ruy::MulParams<std::int32_t, DstScalar>* ruy_mul_params) {
|
||||
static_assert(sizeof(DstScalar) < sizeof(std::int32_t), "");
|
||||
if (quantization_flavor ==
|
||||
QuantizationFlavor::kIntegerWithUniformMultiplier) {
|
||||
ruy_mul_params->set_multiplier_fixedpoint(params.multiplier_fixedpoint);
|
||||
ruy_mul_params->set_multiplier_exponent(params.multiplier_exponent);
|
||||
}
|
||||
if (quantization_flavor ==
|
||||
QuantizationFlavor::kIntegerWithPerRowMultiplier) {
|
||||
ruy_mul_params->set_multiplier_fixedpoint_perchannel(
|
||||
params.multiplier_fixedpoint_perchannel);
|
||||
ruy_mul_params->set_multiplier_exponent_perchannel(
|
||||
params.multiplier_exponent_perchannel);
|
||||
}
|
||||
ruy_mul_params->set_bias(params.bias);
|
||||
ruy_mul_params->set_clamp_min(params.clamp_min);
|
||||
ruy_mul_params->set_clamp_max(params.clamp_max);
|
||||
}
|
||||
};
|
||||
|
||||
// Raw-integer case with destination type int32.
|
||||
template <QuantizationFlavor quantization_flavor>
|
||||
struct MakeRuyMulParamsImpl<std::int32_t, std::int32_t, quantization_flavor>
|
||||
final {
|
||||
static void Run(
|
||||
const GemmParams<std::int32_t, std::int32_t, quantization_flavor>& params,
|
||||
ruy::MulParams<std::int32_t, std::int32_t>* ruy_mul_params) {
|
||||
ruy_mul_params->set_bias(params.bias);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename AccumScalar, typename DstScalar,
|
||||
QuantizationFlavor quantization_flavor>
|
||||
void MakeRuyMulParams(
|
||||
const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
|
||||
ruy::MulParams<AccumScalar, DstScalar>* ruy_mul_params) {
|
||||
MakeRuyMulParamsImpl<AccumScalar, DstScalar, quantization_flavor>::Run(
|
||||
params, ruy_mul_params);
|
||||
}
|
||||
|
||||
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
|
||||
|
Loading…
Reference in New Issue
Block a user