Adapt to upcoming ruy::MulParams change where only the fields meaningful in each case will exist.
PiperOrigin-RevId: 318640066 Change-Id: Ic9bee87edde887859c9e6cbc53ebe3291978c77c
This commit is contained in:
parent
7585042327
commit
9008adb5e2
tensorflow/lite/kernels
@ -361,6 +361,7 @@ cc_library(
|
|||||||
"@ruy//ruy",
|
"@ruy//ruy",
|
||||||
"@ruy//ruy:matrix",
|
"@ruy//ruy:matrix",
|
||||||
"@ruy//ruy:path",
|
"@ruy//ruy:path",
|
||||||
|
"@ruy//ruy:mul_params",
|
||||||
"@ruy//ruy/profiler:instrumentation",
|
"@ruy//ruy/profiler:instrumentation",
|
||||||
# We only need to depend on gemmlowp and Eigen when tflite_with_ruy
|
# We only need to depend on gemmlowp and Eigen when tflite_with_ruy
|
||||||
# is false, but putting these dependencies in a select() seems to
|
# is false, but putting these dependencies in a select() seems to
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_
|
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_
|
||||||
|
|
||||||
#include "ruy/matrix.h" // from @ruy
|
#include "ruy/matrix.h" // from @ruy
|
||||||
|
#include "ruy/mul_params.h" // from @ruy
|
||||||
#include "ruy/ruy.h" // from @ruy
|
#include "ruy/ruy.h" // from @ruy
|
||||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||||
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
|
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
|
||||||
@ -57,24 +58,66 @@ void MakeRuyMatrix(const MatrixParams<Scalar>& params, DataPointer data_ptr,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename GemmParamsType, typename RuySpecType>
|
// Floating-point case.
|
||||||
void MakeRuyMulParams(const GemmParamsType& params,
|
template <typename AccumScalar, typename DstScalar,
|
||||||
RuySpecType* ruy_mul_params) {
|
QuantizationFlavor quantization_flavor>
|
||||||
// This validation has already been performed by the Gemm API entry point,
|
struct MakeRuyMulParamsImpl final {
|
||||||
// but it doesn't hurt to test specifically this again here, where it's
|
static void Run(
|
||||||
// being used.
|
const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
|
||||||
ValidateGemmParams(params);
|
ruy::MulParams<AccumScalar, DstScalar>* ruy_mul_params) {
|
||||||
|
static_assert(quantization_flavor == QuantizationFlavor::kFloatingPoint,
|
||||||
|
"");
|
||||||
|
ruy_mul_params->set_bias(params.bias);
|
||||||
|
ruy_mul_params->set_clamp_min(params.clamp_min);
|
||||||
|
ruy_mul_params->set_clamp_max(params.clamp_max);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Integer-quantized case with destination type narrower than int32
|
||||||
|
template <typename DstScalar, QuantizationFlavor quantization_flavor>
|
||||||
|
struct MakeRuyMulParamsImpl<std::int32_t, DstScalar, quantization_flavor>
|
||||||
|
final {
|
||||||
|
static void Run(
|
||||||
|
const GemmParams<std::int32_t, DstScalar, quantization_flavor>& params,
|
||||||
|
ruy::MulParams<std::int32_t, DstScalar>* ruy_mul_params) {
|
||||||
|
static_assert(sizeof(DstScalar) < sizeof(std::int32_t), "");
|
||||||
|
if (quantization_flavor ==
|
||||||
|
QuantizationFlavor::kIntegerWithUniformMultiplier) {
|
||||||
ruy_mul_params->set_multiplier_fixedpoint(params.multiplier_fixedpoint);
|
ruy_mul_params->set_multiplier_fixedpoint(params.multiplier_fixedpoint);
|
||||||
ruy_mul_params->set_multiplier_exponent(params.multiplier_exponent);
|
ruy_mul_params->set_multiplier_exponent(params.multiplier_exponent);
|
||||||
|
}
|
||||||
|
if (quantization_flavor ==
|
||||||
|
QuantizationFlavor::kIntegerWithPerRowMultiplier) {
|
||||||
ruy_mul_params->set_multiplier_fixedpoint_perchannel(
|
ruy_mul_params->set_multiplier_fixedpoint_perchannel(
|
||||||
params.multiplier_fixedpoint_perchannel);
|
params.multiplier_fixedpoint_perchannel);
|
||||||
ruy_mul_params->set_multiplier_exponent_perchannel(
|
ruy_mul_params->set_multiplier_exponent_perchannel(
|
||||||
params.multiplier_exponent_perchannel);
|
params.multiplier_exponent_perchannel);
|
||||||
|
}
|
||||||
ruy_mul_params->set_bias(params.bias);
|
ruy_mul_params->set_bias(params.bias);
|
||||||
ruy_mul_params->set_clamp_min(params.clamp_min);
|
ruy_mul_params->set_clamp_min(params.clamp_min);
|
||||||
ruy_mul_params->set_clamp_max(params.clamp_max);
|
ruy_mul_params->set_clamp_max(params.clamp_max);
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Raw-integer case with destination type int32.
|
||||||
|
template <QuantizationFlavor quantization_flavor>
|
||||||
|
struct MakeRuyMulParamsImpl<std::int32_t, std::int32_t, quantization_flavor>
|
||||||
|
final {
|
||||||
|
static void Run(
|
||||||
|
const GemmParams<std::int32_t, std::int32_t, quantization_flavor>& params,
|
||||||
|
ruy::MulParams<std::int32_t, std::int32_t>* ruy_mul_params) {
|
||||||
|
ruy_mul_params->set_bias(params.bias);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename AccumScalar, typename DstScalar,
|
||||||
|
QuantizationFlavor quantization_flavor>
|
||||||
|
void MakeRuyMulParams(
|
||||||
|
const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
|
||||||
|
ruy::MulParams<AccumScalar, DstScalar>* ruy_mul_params) {
|
||||||
|
MakeRuyMulParamsImpl<AccumScalar, DstScalar, quantization_flavor>::Run(
|
||||||
|
params, ruy_mul_params);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
|
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
|
||||||
typename DstScalar, QuantizationFlavor quantization_flavor>
|
typename DstScalar, QuantizationFlavor quantization_flavor>
|
||||||
|
Loading…
Reference in New Issue
Block a user