Adapt to upcoming ruy::MulParams change where only the fields meaningful in each case will exist.

PiperOrigin-RevId: 318640066
Change-Id: Ic9bee87edde887859c9e6cbc53ebe3291978c77c
This commit is contained in:
Benoit Jacob 2020-06-27 12:07:56 -07:00 committed by TensorFlower Gardener
parent 7585042327
commit 9008adb5e2
2 changed files with 60 additions and 16 deletions
tensorflow/lite/kernels

View File

@ -361,6 +361,7 @@ cc_library(
"@ruy//ruy", "@ruy//ruy",
"@ruy//ruy:matrix", "@ruy//ruy:matrix",
"@ruy//ruy:path", "@ruy//ruy:path",
"@ruy//ruy:mul_params",
"@ruy//ruy/profiler:instrumentation", "@ruy//ruy/profiler:instrumentation",
# We only need to depend on gemmlowp and Eigen when tflite_with_ruy # We only need to depend on gemmlowp and Eigen when tflite_with_ruy
# is false, but putting these dependencies in a select() seems to # is false, but putting these dependencies in a select() seems to

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_ #define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_
#include "ruy/matrix.h" // from @ruy #include "ruy/matrix.h" // from @ruy
#include "ruy/mul_params.h" // from @ruy
#include "ruy/ruy.h" // from @ruy #include "ruy/ruy.h" // from @ruy
#include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
@ -57,23 +58,65 @@ void MakeRuyMatrix(const MatrixParams<Scalar>& params, DataPointer data_ptr,
} }
} }
template <typename GemmParamsType, typename RuySpecType> // Floating-point case.
void MakeRuyMulParams(const GemmParamsType& params, template <typename AccumScalar, typename DstScalar,
RuySpecType* ruy_mul_params) { QuantizationFlavor quantization_flavor>
// This validation has already been performed by the Gemm API entry point, struct MakeRuyMulParamsImpl final {
// but it doesn't hurt to test specifically this again here, where it's static void Run(
// being used. const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
ValidateGemmParams(params); ruy::MulParams<AccumScalar, DstScalar>* ruy_mul_params) {
static_assert(quantization_flavor == QuantizationFlavor::kFloatingPoint,
"");
ruy_mul_params->set_bias(params.bias);
ruy_mul_params->set_clamp_min(params.clamp_min);
ruy_mul_params->set_clamp_max(params.clamp_max);
}
};
ruy_mul_params->set_multiplier_fixedpoint(params.multiplier_fixedpoint); // Integer-quantized case with destination type narrower than int32
ruy_mul_params->set_multiplier_exponent(params.multiplier_exponent); template <typename DstScalar, QuantizationFlavor quantization_flavor>
ruy_mul_params->set_multiplier_fixedpoint_perchannel( struct MakeRuyMulParamsImpl<std::int32_t, DstScalar, quantization_flavor>
params.multiplier_fixedpoint_perchannel); final {
ruy_mul_params->set_multiplier_exponent_perchannel( static void Run(
params.multiplier_exponent_perchannel); const GemmParams<std::int32_t, DstScalar, quantization_flavor>& params,
ruy_mul_params->set_bias(params.bias); ruy::MulParams<std::int32_t, DstScalar>* ruy_mul_params) {
ruy_mul_params->set_clamp_min(params.clamp_min); static_assert(sizeof(DstScalar) < sizeof(std::int32_t), "");
ruy_mul_params->set_clamp_max(params.clamp_max); if (quantization_flavor ==
QuantizationFlavor::kIntegerWithUniformMultiplier) {
ruy_mul_params->set_multiplier_fixedpoint(params.multiplier_fixedpoint);
ruy_mul_params->set_multiplier_exponent(params.multiplier_exponent);
}
if (quantization_flavor ==
QuantizationFlavor::kIntegerWithPerRowMultiplier) {
ruy_mul_params->set_multiplier_fixedpoint_perchannel(
params.multiplier_fixedpoint_perchannel);
ruy_mul_params->set_multiplier_exponent_perchannel(
params.multiplier_exponent_perchannel);
}
ruy_mul_params->set_bias(params.bias);
ruy_mul_params->set_clamp_min(params.clamp_min);
ruy_mul_params->set_clamp_max(params.clamp_max);
}
};
// Raw-integer case with destination type int32.
template <QuantizationFlavor quantization_flavor>
struct MakeRuyMulParamsImpl<std::int32_t, std::int32_t, quantization_flavor>
final {
static void Run(
const GemmParams<std::int32_t, std::int32_t, quantization_flavor>& params,
ruy::MulParams<std::int32_t, std::int32_t>* ruy_mul_params) {
ruy_mul_params->set_bias(params.bias);
}
};
template <typename AccumScalar, typename DstScalar,
QuantizationFlavor quantization_flavor>
void MakeRuyMulParams(
const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params,
ruy::MulParams<AccumScalar, DstScalar>* ruy_mul_params) {
MakeRuyMulParamsImpl<AccumScalar, DstScalar, quantization_flavor>::Run(
params, ruy_mul_params);
} }
template <typename LhsScalar, typename RhsScalar, typename AccumScalar, template <typename LhsScalar, typename RhsScalar, typename AccumScalar,