Adapt to upcoming ruy::MulParams change where only the fields meaningful in each case will exist.

PiperOrigin-RevId: 318640066
Change-Id: Ic9bee87edde887859c9e6cbc53ebe3291978c77c
This commit is contained in:
Benoit Jacob 2020-06-27 12:07:56 -07:00 committed by TensorFlower Gardener
parent 7585042327
commit 9008adb5e2
2 changed files with 60 additions and 16 deletions

View File

@ -361,6 +361,7 @@ cc_library(
"@ruy//ruy",
"@ruy//ruy:matrix",
"@ruy//ruy:path",
"@ruy//ruy:mul_params",
"@ruy//ruy/profiler:instrumentation",
# We only need to depend on gemmlowp and Eigen when tflite_with_ruy
# is false, but putting these dependencies in a select() seems to

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_
#include "ruy/matrix.h" // from @ruy
#include "ruy/mul_params.h" // from @ruy
#include "ruy/ruy.h" // from @ruy
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
@ -57,23 +58,65 @@ void MakeRuyMatrix(const MatrixParams<Scalar>& params, DataPointer data_ptr,
}
}
template <typename GemmParamsType, typename RuySpecType>
void MakeRuyMulParams(const GemmParamsType& params,
RuySpecType* ruy_mul_params) {
// This validation has already been performed by the Gemm API entry point,
// but it doesn't hurt to test specifically this again here, where it's
// being used.
ValidateGemmParams(params);
// Floating-point case.
template <typename AccumScalar, typename DstScalar,
QuantizationFlavor quantization_flavor>
struct MakeRuyMulParamsImpl final {
static void Run(
const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
ruy::MulParams<AccumScalar, DstScalar>* ruy_mul_params) {
static_assert(quantization_flavor == QuantizationFlavor::kFloatingPoint,
"");
ruy_mul_params->set_bias(params.bias);
ruy_mul_params->set_clamp_min(params.clamp_min);
ruy_mul_params->set_clamp_max(params.clamp_max);
}
};
ruy_mul_params->set_multiplier_fixedpoint(params.multiplier_fixedpoint);
ruy_mul_params->set_multiplier_exponent(params.multiplier_exponent);
ruy_mul_params->set_multiplier_fixedpoint_perchannel(
params.multiplier_fixedpoint_perchannel);
ruy_mul_params->set_multiplier_exponent_perchannel(
params.multiplier_exponent_perchannel);
ruy_mul_params->set_bias(params.bias);
ruy_mul_params->set_clamp_min(params.clamp_min);
ruy_mul_params->set_clamp_max(params.clamp_max);
// Integer-quantized case with destination type narrower than int32
template <typename DstScalar, QuantizationFlavor quantization_flavor>
struct MakeRuyMulParamsImpl<std::int32_t, DstScalar, quantization_flavor>
final {
static void Run(
const GemmParams<std::int32_t, DstScalar, quantization_flavor>& params,
ruy::MulParams<std::int32_t, DstScalar>* ruy_mul_params) {
static_assert(sizeof(DstScalar) < sizeof(std::int32_t), "");
if (quantization_flavor ==
QuantizationFlavor::kIntegerWithUniformMultiplier) {
ruy_mul_params->set_multiplier_fixedpoint(params.multiplier_fixedpoint);
ruy_mul_params->set_multiplier_exponent(params.multiplier_exponent);
}
if (quantization_flavor ==
QuantizationFlavor::kIntegerWithPerRowMultiplier) {
ruy_mul_params->set_multiplier_fixedpoint_perchannel(
params.multiplier_fixedpoint_perchannel);
ruy_mul_params->set_multiplier_exponent_perchannel(
params.multiplier_exponent_perchannel);
}
ruy_mul_params->set_bias(params.bias);
ruy_mul_params->set_clamp_min(params.clamp_min);
ruy_mul_params->set_clamp_max(params.clamp_max);
}
};
// Raw-integer case with destination type int32.
template <QuantizationFlavor quantization_flavor>
struct MakeRuyMulParamsImpl<std::int32_t, std::int32_t, quantization_flavor>
final {
static void Run(
const GemmParams<std::int32_t, std::int32_t, quantization_flavor>& params,
ruy::MulParams<std::int32_t, std::int32_t>* ruy_mul_params) {
ruy_mul_params->set_bias(params.bias);
}
};
template <typename AccumScalar, typename DstScalar,
QuantizationFlavor quantization_flavor>
void MakeRuyMulParams(
const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
ruy::MulParams<AccumScalar, DstScalar>* ruy_mul_params) {
MakeRuyMulParamsImpl<AccumScalar, DstScalar, quantization_flavor>::Run(
params, ruy_mul_params);
}
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,