Introduce a :cpu_backend_gemm library allowing to perform all

types of GEMM that TFLite needs, given a CpuBackendContext.

Ruy provides a generic implementation. A subsequent CL
will introduce a config_setting, `with_ruy`, false by default.
When with_ruy is false, the generic ruy-based implementation
will be overridden by a gemmlowp/eigen implementation.

The rationale for having ruy as the generic implementation is
that the current pre-ruy landscape is that some paths have
spotty support, with the new int8/perchannel stuff only supported
in gemmlowp on NEON, forcing call sites across TFLite to have
complicated #if logic. Thanks to ruy providing a generic
implementation in all cases, such call sites won't have to
use #if's to avoid build issues. Call sites may still want to avoid
a code size penalty from generic paths which they don't call, but
they will be able to get that by letting the compiler elide
statically unused branches.

PiperOrigin-RevId: 245451095
This commit is contained in:
Benoit Jacob 2019-04-26 11:03:38 -07:00 committed by TensorFlower Gardener
parent ba4db912ec
commit 71ad025fc4
7 changed files with 841 additions and 6 deletions

View File

@ -256,7 +256,11 @@ cc_library(
"dispatch.h", "dispatch.h",
"impl.h", "impl.h",
], ],
hdrs = ["ruy.h"], hdrs = [
"matrix.h",
"path.h",
"ruy.h",
],
visibility = ruy_visibility(), visibility = ruy_visibility(),
deps = [ deps = [
":allocator", ":allocator",
@ -265,10 +269,8 @@ cc_library(
":common", ":common",
":context", ":context",
":kernel", ":kernel",
":matrix",
":opt_set", ":opt_set",
":pack", ":pack",
":path",
":size_util", ":size_util",
":spec", ":spec",
":thread_pool", ":thread_pool",

View File

@ -3,4 +3,6 @@ Control of ruy visibility
""" """
def ruy_visibility(): def ruy_visibility():
return [] return [
"//tensorflow/lite/kernels:__subpackages__",
]

View File

@ -90,10 +90,40 @@ cc_library(
copts = tflite_copts(), copts = tflite_copts(),
deps = [ deps = [
":op_macros", ":op_macros",
# For now this unconditionally depends on both ruy and gemmlowp.
# See the comment inside class CpuBackendContext on the
# gemmlowp_context_ and ruy_context_ members.
"//tensorflow/lite/experimental/ruy:context",
"@gemmlowp", "@gemmlowp",
], ],
) )
cc_library(
name = "cpu_backend_gemm",
hdrs = [
"cpu_backend_gemm.h",
],
deps = [
":op_macros",
":cpu_backend_context",
# Depend on ruy regardless of `with_ruy`. See the comment in
# cpu_backend_gemm.h about why ruy is the generic path.
"//tensorflow/lite/experimental/ruy",
],
)
cc_test(
name = "cpu_backend_gemm_test",
srcs = ["cpu_backend_gemm_test.cc"],
deps = [
":cpu_backend_gemm",
"@com_google_googletest//:gtest",
# ruy's reference path provides the reference implementation
# that this test compares against.
"//tensorflow/lite/experimental/ruy",
],
)
cc_library( cc_library(
name = "cpu_backend_support", name = "cpu_backend_support",
srcs = [ srcs = [

View File

@ -16,11 +16,13 @@ limitations under the License.
#include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "public/gemmlowp.h" #include "public/gemmlowp.h"
#include "tensorflow/lite/experimental/ruy/context.h"
namespace tflite { namespace tflite {
CpuBackendContext::CpuBackendContext() CpuBackendContext::CpuBackendContext()
: gemmlowp_context_(new gemmlowp::GemmContext) {} : ruy_context_(new ruy::Context),
gemmlowp_context_(new gemmlowp::GemmContext) {}
CpuBackendContext::~CpuBackendContext() {} CpuBackendContext::~CpuBackendContext() {}

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <memory> #include <memory>
#include "public/gemmlowp.h" #include "public/gemmlowp.h"
#include "tensorflow/lite/experimental/ruy/context.h"
namespace tflite { namespace tflite {
@ -27,6 +28,8 @@ class CpuBackendContext final {
CpuBackendContext(); CpuBackendContext();
~CpuBackendContext(); ~CpuBackendContext();
ruy::Context* ruy_context() const { return ruy_context_.get(); }
gemmlowp::GemmContext* gemmlowp_context() const { gemmlowp::GemmContext* gemmlowp_context() const {
return gemmlowp_context_.get(); return gemmlowp_context_.get();
} }
@ -34,7 +37,13 @@ class CpuBackendContext final {
void set_max_num_threads(int max_num_threads); void set_max_num_threads(int max_num_threads);
private: private:
// gemmlowp context used to implement this CpuBackendContext. // To enable a smooth transition from the current direct usage
// of the underlying gemmlowp context to going through abstractions
// (see :cpu_backend_gemm), for now a CpuBackendContext always
// stores both a gemmlowp context and a ruy context.
// TODO(b/131416458): Once call sites all go through abstractions,
// elide what can be elided based on TFLITE_WITH_RUY.
const std::unique_ptr<ruy::Context> ruy_context_;
const std::unique_ptr<gemmlowp::GemmContext> gemmlowp_context_; const std::unique_ptr<gemmlowp::GemmContext> gemmlowp_context_;
CpuBackendContext(const CpuBackendContext&) = delete; CpuBackendContext(const CpuBackendContext&) = delete;

View File

@ -0,0 +1,187 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_H_
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_H_
#include <cstdint>
#include <limits>
#include <type_traits>
#include "tensorflow/lite/experimental/ruy/path.h"
#include "tensorflow/lite/experimental/ruy/ruy.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
namespace tflite {
namespace cpu_backend_gemm {
// Matrix storage order: column-major or row-major.
enum class Order { kColMajor, kRowMajor };
// MatrixParams encapsulates the parameters that Gemm needs about each
// matrix, besides the buffer data pointer.
// Compare to ruy::Matrix, which also encapsulates the data pointer.
// Rationale for leaving the data pointer out of here: doing so
// requires complicated const-correctness mechanics. See
// ruy::ConstCheckingPtr.
template <typename Scalar>
struct MatrixParams {
// Storage layout order. For now we only do plain linear non-strided
// layout. It would be easy to support a stride if needed.
Order order = Order::kColMajor;
// Number of rows of the matrix.
int rows = 0;
// Number of columns of the matrix.
int cols = 0;
// The zero_point, i.e. which Scalar value is to be interpreted as zero.
// When Scalar is floating-point, this must be 0.
Scalar zero_point = 0;
};
// Additional parameters that Gemm needs, beyond what falls into
// the MatrixParams that it takes. Compare to ruy::Spec.
//
// Decoupling AccumScalar from DstScalar (rather than deducing it from that)
// is useful future-proofing. Think of a float16 path using float32 accum.
template <typename AccumScalar, typename DstScalar>
struct GemmParams {
// Only for non-floating-point cases. The fixed-point part (i.e. the mantissa)
// of the multiplier by which accumulators are multiplied before being casted
// to the destination type.
AccumScalar multiplier_fixedpoint = 0;
// Only for non-floating-point cases. The exponent part of the aforementioned
// multiplier.
int multiplier_exponent = 0;
// Per-channel variant of multiplier_fixedpoint. If not nullptr, this must
// point to a buffer of as many values as there are rows in the destination
// matrix. Each row of the destination matrix will use the corresponding
// buffer element instead of multiplier_fixedpoint.
const AccumScalar* multiplier_fixedpoint_perchannel = nullptr;
// Per-channel variant of multiplier_exponent. If not nullptr, this must
// point to a buffer of as many values as there are rows in the destination
// matrix. Each row of the destination matrix will use the corresponding
// buffer element instead of multiplier_exponent.
//
// Either none or both of multiplier_exponent_perchannel and
// multiplier_fixedpoint_perchannel must be nullptr.
const int* multiplier_exponent_perchannel = nullptr;
// The bias vector data, if not null.
const AccumScalar* bias = nullptr;
// min clamp bound of destination values.
DstScalar clamp_min = std::numeric_limits<DstScalar>::lowest();
// max clamp bound of destination values.
DstScalar clamp_max = std::numeric_limits<DstScalar>::max();
};
/* Convenience typedefs */
template <typename DstScalar>
using QuantizedGemmParams = GemmParams<std::int32_t, DstScalar>;
using FloatGemmParams = GemmParams<float, float>;
/* Generic implementation using ruy */
namespace detail {
template <typename Scalar, typename DataPointer>
void MakeRuyMatrix(const MatrixParams<Scalar>& params, DataPointer data_ptr,
ruy::Matrix<Scalar>* dst) {
dst->layout.rows = params.rows;
dst->layout.cols = params.cols;
if (params.order == Order::kColMajor) {
dst->layout.order = ruy::Order::kColMajor;
dst->layout.stride = params.rows;
} else {
dst->layout.order = ruy::Order::kRowMajor;
dst->layout.stride = params.cols;
}
// Note that ruy::Matrix::data is a ConstCheckingPtr, not a plain pointer.
// It does care whether we assign to it a Scalar* or a const Scalar*.
dst->data = data_ptr;
dst->zero_point = params.zero_point;
}
template <typename GemmParamsType, typename RuySpecType>
void MakeRuySpec(const GemmParamsType& params, RuySpecType* ruy_spec) {
if (std::is_floating_point<typename RuySpecType::AccumScalar>::value) {
TF_LITE_ASSERT(!params.multiplier_fixedpoint);
TF_LITE_ASSERT(!params.multiplier_exponent);
TF_LITE_ASSERT(!params.multiplier_fixedpoint_perchannel);
TF_LITE_ASSERT(!params.multiplier_exponent_perchannel);
} else {
TF_LITE_ASSERT((params.multiplier_fixedpoint == 0) !=
(params.multiplier_fixedpoint_perchannel == nullptr));
}
ruy_spec->multiplier_fixedpoint = params.multiplier_fixedpoint;
ruy_spec->multiplier_exponent = params.multiplier_exponent;
ruy_spec->multiplier_fixedpoint_perchannel =
params.multiplier_fixedpoint_perchannel;
ruy_spec->multiplier_exponent_perchannel =
params.multiplier_exponent_perchannel;
ruy_spec->bias = params.bias;
ruy_spec->clamp_min = params.clamp_min;
ruy_spec->clamp_max = params.clamp_max;
}
} // namespace detail
// Non-ruy implementation will be partial specializations of this template.
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
typename DstScalar>
struct GemmImpl {
static void Run(const MatrixParams<LhsScalar>& lhs_params,
const LhsScalar* lhs_data,
const MatrixParams<RhsScalar>& rhs_params,
const RhsScalar* rhs_data,
const MatrixParams<DstScalar>& dst_params,
DstScalar* dst_data,
const GemmParams<AccumScalar, DstScalar>& params,
CpuBackendContext* context) {
ruy::Matrix<LhsScalar> ruy_lhs;
ruy::Matrix<RhsScalar> ruy_rhs;
ruy::Matrix<DstScalar> ruy_dst;
detail::MakeRuyMatrix(lhs_params, lhs_data, &ruy_lhs);
detail::MakeRuyMatrix(rhs_params, rhs_data, &ruy_rhs);
detail::MakeRuyMatrix(dst_params, dst_data, &ruy_dst);
ruy::BasicSpec<AccumScalar, DstScalar> ruy_spec;
detail::MakeRuySpec(params, &ruy_spec);
ruy::Mul<ruy::kAllPaths>(ruy_lhs, ruy_rhs, ruy_spec, context->ruy_context(),
&ruy_dst);
}
};
/* Public entry point */
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
typename DstScalar>
void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
const GemmParams<AccumScalar, DstScalar>& params,
CpuBackendContext* context) {
GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar>::Run(
lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data, params,
context);
}
} // namespace cpu_backend_gemm
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_H_

View File

@ -0,0 +1,603 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/cpu_backend_gemm.h"
#include <algorithm>
#include <cstdarg>
#include <limits>
#include <random>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
#include <gtest/gtest.h>
#include "tensorflow/lite/experimental/ruy/ruy.h"
namespace tflite {
namespace {
using cpu_backend_gemm::Gemm;
using cpu_backend_gemm::GemmParams;
using cpu_backend_gemm::MatrixParams;
template <typename Scalar>
std::string ToString(const std::vector<Scalar>& vector) {
std::stringstream s;
if (vector.empty()) {
s << "{}";
} else {
s << "{ " << static_cast<double>(vector[0]);
for (int i = 1; i < vector.size(); i++) {
s << ", " << static_cast<double>(vector[i]);
}
s << "}";
}
return s.str();
}
template <typename Scalar>
void MakeDeterministicPseudoRandomVector(int size,
std::vector<Scalar>* vector) {
// Intentionally create a new local random_engine in each invocation,
// so pseudorandom values don't depend on invocation order.
// Otherwise, test results would be affecting by e.g. filtering.
std::default_random_engine random_engine;
(void)random_engine();
// Do not use std::uniform*_distribution: the values that it
// generates are implementation-defined.
const double random_min = static_cast<double>(random_engine.min());
const double random_max = static_cast<double>(random_engine.max());
const double result_min =
std::is_floating_point<Scalar>::value
? -1.0
: std::max(-256., static_cast<double>(
std::numeric_limits<Scalar>::lowest()));
const double result_max =
std::is_floating_point<Scalar>::value
? 1.0
: std::min(256.,
static_cast<double>(std::numeric_limits<Scalar>::max()));
const double random_scale =
(result_max - result_min) / (random_max - random_min);
vector->resize(size);
for (int i = 0; i < size; i++) {
double val = random_scale * (random_engine() - random_min);
val = std::max(val,
static_cast<double>(std::numeric_limits<Scalar>::lowest()));
val =
std::min(val, static_cast<double>(std::numeric_limits<Scalar>::max()));
(*vector)[i] = static_cast<Scalar>(val);
}
}
template <typename Scalar>
void MakeVectorFilledWithConsecutiveInts(int size,
std::vector<Scalar>* vector) {
vector->resize(size);
EXPECT_LE(size, std::numeric_limits<Scalar>::max());
for (int i = 0; i < size; i++) {
(*vector)[i] = static_cast<Scalar>(i + 1);
}
}
template <typename Scalar>
Scalar Median(const std::vector<Scalar>& vector) {
EXPECT_GT(vector.size(), 0);
std::vector<Scalar> vector_copy = vector;
std::sort(std::begin(vector_copy), std::end(vector_copy));
return vector_copy[vector_copy.size() / 2];
}
template <typename Scalar>
double MedianAbs(const std::vector<Scalar>& vector) {
EXPECT_GT(vector.size(), 0);
std::vector<double> vector_abs;
vector_abs.resize(vector.size());
for (int i = 0; i < vector.size(); i++) {
vector_abs[i] = std::abs(static_cast<double>(vector[i]));
}
std::sort(std::begin(vector_abs), std::end(vector_abs));
return vector_abs[vector_abs.size() / 2];
}
template <typename Scalar>
void Clamp(const std::vector<Scalar>& src, Scalar clamp_min, Scalar clamp_max,
std::vector<Scalar>* dst) {
dst->resize(src.size());
for (int i = 0; i < src.size(); i++) {
(*dst)[i] = std::max(std::min(src[i], clamp_max), clamp_min);
}
}
template <typename AccumScalar, typename DstScalar>
void Clamp(const GemmParams<AccumScalar, DstScalar>& src, DstScalar clamp_min,
DstScalar clamp_max, GemmParams<AccumScalar, DstScalar>* dst) {
*dst = src;
dst->clamp_min = clamp_min;
dst->clamp_max = clamp_max;
}
struct ErrorStats {
int size;
double scale_factor;
double max_abs_diff;
double mean_abs_diff;
double abs_mean_diff;
};
template <typename Scalar>
void ComputeErrorStats(const std::vector<Scalar>& actual,
const std::vector<Scalar>& expected,
ErrorStats* error_stats) {
double max_abs_diff = 0;
double sum_abs_diff = 0;
double sum_diff = 0;
double max_abs_expected = 0;
EXPECT_EQ(actual.size(), expected.size());
for (int i = 0; i < actual.size(); i++) {
double actual_val = static_cast<double>(actual[i]);
double expected_val = static_cast<double>(expected[i]);
double diff = actual_val - expected_val;
max_abs_expected = std::max(max_abs_expected, std::abs(expected_val));
sum_diff += diff;
sum_abs_diff += std::abs(diff);
max_abs_diff = std::max(max_abs_diff, std::abs(diff));
}
error_stats->scale_factor = max_abs_expected;
error_stats->max_abs_diff = max_abs_diff;
error_stats->mean_abs_diff = sum_abs_diff / actual.size();
error_stats->abs_mean_diff = std::abs(sum_diff / actual.size());
error_stats->size = actual.size();
}
template <typename AccumScalar, typename DstScalar>
bool CheckErrorStats(const ErrorStats& error_stats, int accumulation_depth) {
double tolerated_relative_max_abs_diff = 0;
double tolerated_relative_mean_abs_diff = 0;
double tolerated_relative_abs_mean_diff = 0;
double inverse_size = 1. / error_stats.size;
if (std::is_floating_point<AccumScalar>::value) {
// Somewhat naive requirement: the worst case should be epsilons
// adding up towards the same direction, on values of same magnitude.
tolerated_relative_max_abs_diff =
accumulation_depth * std::numeric_limits<DstScalar>::epsilon();
// Naive interpretation of the Central Limit Theorem is the rationale
// for the sqrt here. We haven't even worked out the correct scale factor,
// or how applicable that theorem is here (the random variables being added
// might not be mutually independent).
tolerated_relative_mean_abs_diff =
std::sqrt(static_cast<double>(accumulation_depth)) *
std::numeric_limits<DstScalar>::epsilon();
// Unbiasing requirement: we require the bias, abs_mean_diff, to be much
// smaller than the mean_abs_diff, except when there are very few values.
tolerated_relative_abs_mean_diff =
tolerated_relative_mean_abs_diff * std::sqrt(inverse_size);
} else {
// In quantized arithmetic, tolerate minor rounding differences, resulting
// in off-by-one errors (tolerated_relative_max_abs_diff = 1), as long
// as they are rare (tolerated_relative_mean_abs_diff) and unbiased
// (tolerated_relative_abs_mean_diff).
tolerated_relative_max_abs_diff = 1;
// Naively require mean_abs_diff and abs_mean_diff to converge to zero
// as size gets large. We don't know at all how quick that convergence
// should be: this is just based on trial-and-error and striking a
// compromise between something that works and something that's simple
// enough code that doesn't feel too ad-hoc. As above in the float path,
// abs_mean_diff is subject to a stricter requirement as it is a bias.
tolerated_relative_mean_abs_diff = std::sqrt(inverse_size);
tolerated_relative_abs_mean_diff = inverse_size;
}
double tolerated_max_abs_diff =
tolerated_relative_max_abs_diff * error_stats.scale_factor;
double tolerated_mean_abs_diff =
tolerated_relative_mean_abs_diff * error_stats.scale_factor;
double tolerated_abs_mean_diff =
tolerated_relative_abs_mean_diff * error_stats.scale_factor;
EXPECT_LE(error_stats.max_abs_diff, tolerated_max_abs_diff);
EXPECT_LE(error_stats.mean_abs_diff, tolerated_mean_abs_diff);
EXPECT_LE(error_stats.abs_mean_diff, tolerated_abs_mean_diff);
return error_stats.max_abs_diff <= tolerated_max_abs_diff &&
error_stats.mean_abs_diff <= tolerated_mean_abs_diff &&
error_stats.abs_mean_diff <= tolerated_abs_mean_diff;
}
template <typename AccumScalar, typename DstScalar>
void CheckErrorForAccumulation(int accumulation_depth,
const std::vector<DstScalar>& actual,
const std::vector<DstScalar>& expected) {
ErrorStats error_stats;
ComputeErrorStats(actual, expected, &error_stats);
bool success =
CheckErrorStats<AccumScalar, DstScalar>(error_stats, accumulation_depth);
EXPECT_TRUE(success) << "Actual vector\n"
<< ToString(actual) << "\ndiffers from expected vector\n"
<< ToString(expected) << "\n";
}
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
typename DstScalar>
void PerformGemmThenCompareResultsThenAgainWithClamping(
const MatrixParams<LhsScalar>& lhs_params,
const std::vector<LhsScalar>& lhs_data,
const MatrixParams<RhsScalar>& rhs_params,
const std::vector<RhsScalar>& rhs_data,
const MatrixParams<DstScalar>& dst_params, std::vector<DstScalar>* dst_data,
const GemmParams<AccumScalar, DstScalar>& params,
const std::vector<DstScalar>& expected,
CpuBackendContext* cpu_backend_context) {
const int accumulation_depth = lhs_params.cols;
Gemm(lhs_params, lhs_data.data(), rhs_params, rhs_data.data(), dst_params,
dst_data->data(), params, cpu_backend_context);
CheckErrorForAccumulation<AccumScalar>(accumulation_depth, *dst_data,
expected);
DstScalar expected_median = Median(expected);
std::vector<DstScalar> expected_with_clamp;
GemmParams<AccumScalar, DstScalar> params_with_clamp;
DstScalar clamp_min, clamp_max;
clamp_min = std::numeric_limits<DstScalar>::lowest();
clamp_max = expected_median;
Clamp(expected, clamp_min, clamp_max, &expected_with_clamp);
Clamp(params, clamp_min, clamp_max, &params_with_clamp);
Gemm(lhs_params, lhs_data.data(), rhs_params, rhs_data.data(), dst_params,
dst_data->data(), params_with_clamp, cpu_backend_context);
CheckErrorForAccumulation<AccumScalar>(accumulation_depth, *dst_data,
expected_with_clamp);
clamp_min = expected_median;
clamp_max = std::numeric_limits<DstScalar>::max();
Clamp(expected, clamp_min, clamp_max, &expected_with_clamp);
Clamp(params, clamp_min, clamp_max, &params_with_clamp);
Gemm(lhs_params, lhs_data.data(), rhs_params, rhs_data.data(), dst_params,
dst_data->data(), params_with_clamp, cpu_backend_context);
CheckErrorForAccumulation<AccumScalar>(accumulation_depth, *dst_data,
expected_with_clamp);
}
// When generating testcases for a quantized GEMM, it's not trivial to
// pick multiplier exponents: a too low value will result in too many zeros,
// a too high value will result in too many large clamped values, in both
// cases testing coverage is harmed. Therefore to ensure good testing coverage
// we must find a multiplier exponent that's just right. It would be possible
// to do so by analysis of the random distribution of values in the result
// matrix. That however would require some mathematical work that we haven't
// done so far. Until that is done, the best that we can do is to search for
// a good exponent value by trial-and-error. This is expensive, as each try
// requires computing a whole GEMM. This is thus probably a major contribution
// to the overall latency of this tesat. To partially mitigate that,
// we use a bisection to reduce the required number of tries.
//
// This function is recursive. The bisect_min and bisect_max arguments
// are the current bisection bounds. It performs a Gemm with the mid-point,
// named bisect_mid, as the multiplier exponent. Based on whether the values
// in the resulting matrix are rather too low or too large in absolute
// value, it then recurses into the corresponding half of the bisection range.
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
typename DstScalar>
int BisectReasonableMultiplierExponent(
int bisect_min, int bisect_max, const MatrixParams<LhsScalar>& lhs_params,
const std::vector<LhsScalar>& lhs_data,
const MatrixParams<RhsScalar>& rhs_params,
const std::vector<RhsScalar>& rhs_data,
const MatrixParams<DstScalar>& dst_params, std::vector<DstScalar>* dst_data,
const GemmParams<AccumScalar, DstScalar>& params,
CpuBackendContext* cpu_backend_context) {
if (bisect_min == bisect_max) {
return bisect_min;
}
// Compute the midpoint as the floor of the average of bisect_min and
// bisect_max. As C++ integer division is rounding towards zero and our values
// may be of any sign, it is not trivial to implement this using only integer
// arithmetic.
int bisect_mid =
static_cast<int>(std::floor(0.5 * (bisect_min + bisect_max)));
GemmParams<AccumScalar, DstScalar> params_copy(params);
params_copy.multiplier_exponent = bisect_mid;
double clamp_abs = std::max(std::abs(static_cast<double>(params.clamp_min)),
std::abs(static_cast<double>(params.clamp_max)));
Gemm(lhs_params, lhs_data.data(), rhs_params, rhs_data.data(), dst_params,
dst_data->data(), params_copy, cpu_backend_context);
double median_abs = MedianAbs(*dst_data);
if (median_abs < 0.25 * clamp_abs) {
return BisectReasonableMultiplierExponent(
bisect_mid + 1, bisect_max, lhs_params, lhs_data, rhs_params, rhs_data,
dst_params, dst_data, params_copy, cpu_backend_context);
} else {
return BisectReasonableMultiplierExponent(
bisect_min, bisect_mid, lhs_params, lhs_data, rhs_params, rhs_data,
dst_params, dst_data, params_copy, cpu_backend_context);
}
}
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
typename DstScalar>
void ReferenceGemm(const MatrixParams<LhsScalar>& lhs_params,
const LhsScalar* lhs_data,
const MatrixParams<RhsScalar>& rhs_params,
const RhsScalar* rhs_data,
const MatrixParams<DstScalar>& dst_params,
DstScalar* dst_data,
const GemmParams<AccumScalar, DstScalar>& params,
CpuBackendContext* context) {
ruy::Matrix<LhsScalar> ruy_lhs;
ruy::Matrix<RhsScalar> ruy_rhs;
ruy::Matrix<DstScalar> ruy_dst;
cpu_backend_gemm::detail::MakeRuyMatrix(lhs_params, lhs_data, &ruy_lhs);
cpu_backend_gemm::detail::MakeRuyMatrix(rhs_params, rhs_data, &ruy_rhs);
cpu_backend_gemm::detail::MakeRuyMatrix(dst_params, dst_data, &ruy_dst);
ruy::BasicSpec<AccumScalar, DstScalar> ruy_spec;
cpu_backend_gemm::detail::MakeRuySpec(params, &ruy_spec);
ruy::Mul<ruy::Path::kReference>(ruy_lhs, ruy_rhs, ruy_spec,
context->ruy_context(), &ruy_dst);
}
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
typename DstScalar>
void TestSomeGemm(int rows, int depth, int cols,
const std::vector<DstScalar>& golden) {
CpuBackendContext cpu_backend_context;
std::default_random_engine random_engine;
cpu_backend_context.set_max_num_threads(1 + (random_engine() % 8));
const bool use_golden = !golden.empty();
std::vector<LhsScalar> lhs_data;
std::vector<RhsScalar> rhs_data;
std::vector<AccumScalar> bias_data;
std::vector<DstScalar> dst_data;
if (use_golden) {
MakeVectorFilledWithConsecutiveInts(rows * depth, &lhs_data);
MakeVectorFilledWithConsecutiveInts(depth * cols, &rhs_data);
MakeVectorFilledWithConsecutiveInts(rows, &bias_data);
} else {
MakeDeterministicPseudoRandomVector(rows * depth, &lhs_data);
MakeDeterministicPseudoRandomVector(depth * cols, &rhs_data);
MakeDeterministicPseudoRandomVector(rows, &bias_data);
}
MakeDeterministicPseudoRandomVector(rows * cols, &dst_data);
MatrixParams<LhsScalar> lhs_params;
lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
lhs_params.rows = rows;
lhs_params.cols = depth;
if (!use_golden && !std::is_floating_point<LhsScalar>::value) {
lhs_params.zero_point = random_engine() % 8;
}
MatrixParams<RhsScalar> rhs_params;
rhs_params.order = cpu_backend_gemm::Order::kColMajor;
rhs_params.rows = depth;
rhs_params.cols = cols;
if (!use_golden && !std::is_floating_point<RhsScalar>::value) {
rhs_params.zero_point = random_engine() % 8;
}
MatrixParams<DstScalar> dst_params;
dst_params.order = cpu_backend_gemm::Order::kColMajor;
dst_params.rows = rows;
dst_params.cols = cols;
if (!use_golden && !std::is_floating_point<DstScalar>::value) {
dst_params.zero_point = random_engine() % 8;
}
GemmParams<AccumScalar, DstScalar> params;
params.bias = bias_data.data();
if (!std::is_floating_point<AccumScalar>::value) {
// some large int32 value. Not being a multiple of a large
// power of two helps testing rounding behavior.
params.multiplier_fixedpoint = 1234567890;
// Now find a suitable value for multiplier_exponent.
// It needs to be low enough for a substantial amount of dst values
// to avoid getting clamped.
int bisect_min = -8 * static_cast<int>(sizeof(AccumScalar));
// We don't increase test coverage by using positive multipliers,
// and using very large positive multipliers may at the moment
// result in overflow in some paths.
// TODO(benoitjacob): fix that.
int bisect_max = 0;
params.multiplier_exponent = BisectReasonableMultiplierExponent(
bisect_min, bisect_max, lhs_params, lhs_data, rhs_params, rhs_data,
dst_params, &dst_data, params, &cpu_backend_context);
}
std::vector<DstScalar> expected;
if (use_golden) {
EXPECT_EQ(golden.size(), dst_data.size());
expected = golden;
} else {
expected.resize(dst_data.size());
ReferenceGemm(lhs_params, lhs_data.data(), rhs_params, rhs_data.data(),
dst_params, expected.data(), params, &cpu_backend_context);
}
PerformGemmThenCompareResultsThenAgainWithClamping(
lhs_params, lhs_data, rhs_params, rhs_data, dst_params, &dst_data, params,
expected, &cpu_backend_context);
if (!std::is_floating_point<AccumScalar>::value) {
// Try with per-channel quantized multipliers. Just a naive check
// duplicating the same multiplier --- would already catch most bugs.
std::vector<AccumScalar> multiplier_fixedpoint_perchannel(
rows, params.multiplier_fixedpoint);
std::vector<int> multiplier_exponent_perchannel(rows,
params.multiplier_exponent);
GemmParams<AccumScalar, DstScalar> params_perchannel = params;
params_perchannel.multiplier_fixedpoint = 0;
params_perchannel.multiplier_exponent = 0;
params_perchannel.multiplier_fixedpoint_perchannel =
multiplier_fixedpoint_perchannel.data();
params_perchannel.multiplier_exponent_perchannel =
multiplier_exponent_perchannel.data();
PerformGemmThenCompareResultsThenAgainWithClamping(
lhs_params, lhs_data, rhs_params, rhs_data, dst_params, &dst_data,
params_perchannel, expected, &cpu_backend_context);
}
}
TEST(CpuBackendGemmSimpleTestAgainstGolden, Float) {
TestSomeGemm<float, float, float, float>(2, 3, 4,
{15, 34, 33, 79, 51, 124, 69, 169});
}
TEST(CpuBackendGemmSimpleTestAgainstGolden, Uint8) {
TestSomeGemm<std::uint8_t, std::uint8_t, std::int32_t, std::uint8_t>(
5, 2, 3, {3, 7, 11, 16, 20, 7, 16, 24, 33, 41, 10, 24, 37, 50, 63});
}
TEST(CpuBackendGemmSimpleTestAgainstGolden, Int8) {
TestSomeGemm<std::int8_t, std::int8_t, std::int32_t, std::int8_t>(
2, 6, 3, {13, 32, 31, 81, 50, 127});
}
TEST(CpuBackendGemmSimpleTestAgainstGolden, Int8Int16) {
TestSomeGemm<std::int8_t, std::int8_t, std::int32_t, std::int16_t>(
3, 5, 4, {32, 76, 120, 75, 191, 306, 118, 306, 493, 162, 421, 680});
}
template <typename tLhsScalar, typename tRhsScalar, typename tAccumScalar,
typename tDstScalar>
struct TypesTuple {
using LhsScalar = tLhsScalar;
using RhsScalar = tRhsScalar;
using AccumScalar = tAccumScalar;
using DstScalar = tDstScalar;
};
template <typename TypesTupleType>
void TestRandomGemms(const std::vector<std::tuple<int, int, int>>& shapes) {
using LhsScalar = typename TypesTupleType::LhsScalar;
using RhsScalar = typename TypesTupleType::RhsScalar;
using AccumScalar = typename TypesTupleType::AccumScalar;
using DstScalar = typename TypesTupleType::DstScalar;
for (const auto& shape : shapes) {
int rows = std::get<0>(shape);
int depth = std::get<1>(shape);
int cols = std::get<2>(shape);
TestSomeGemm<LhsScalar, RhsScalar, AccumScalar, DstScalar>(rows, depth,
cols, {});
}
}
template <typename TypesTupleType>
class CpuBackendGemmTest : public testing::Test {};
TYPED_TEST_SUITE_P(CpuBackendGemmTest);
typedef ::testing::Types<
TypesTuple<float, float, float, float>,
TypesTuple<std::uint8_t, std::uint8_t, std::int32_t, std::uint8_t>,
TypesTuple<std::int8_t, std::int8_t, std::int32_t, std::int8_t>,
TypesTuple<std::int8_t, std::int8_t, std::int32_t, std::int16_t>>
CpuBackendGemmTestInstantiations;
TYPED_TEST_SUITE(CpuBackendGemmTest, CpuBackendGemmTestInstantiations);
TYPED_TEST(CpuBackendGemmTest, Square) {
std::vector<std::tuple<int, int, int>> shapes;
for (int size = 1; size < 50; size++) {
shapes.push_back(std::make_tuple(size, size, size));
}
TestRandomGemms<TypeParam>(shapes);
}
TYPED_TEST(CpuBackendGemmTest, SquarePowerOfTwo) {
std::vector<std::tuple<int, int, int>> shapes;
for (int size = 64; size <= 128; size++) {
shapes.push_back(std::make_tuple(size, size, size));
}
TestRandomGemms<TypeParam>(shapes);
}
TYPED_TEST(CpuBackendGemmTest, MatrixTimesVector) {
std::vector<std::tuple<int, int, int>> shapes;
for (int size = 1; size < 200; size++) {
shapes.push_back(std::make_tuple(size, size, 1));
}
TestRandomGemms<TypeParam>(shapes);
}
TYPED_TEST(CpuBackendGemmTest, VectorTimesMatrix) {
std::vector<std::tuple<int, int, int>> shapes;
for (int size = 1; size < 200; size++) {
shapes.push_back(std::make_tuple(1, size, size));
}
TestRandomGemms<TypeParam>(shapes);
}
TYPED_TEST(CpuBackendGemmTest, MatrixTimesNarrow) {
std::vector<std::tuple<int, int, int>> shapes;
for (int size = 1; size < 100; size++) {
shapes.push_back(std::make_tuple(size, size, 2));
shapes.push_back(std::make_tuple(size, size, 3));
shapes.push_back(std::make_tuple(size, size, 4));
shapes.push_back(std::make_tuple(size, size, 8));
}
TestRandomGemms<TypeParam>(shapes);
}
TYPED_TEST(CpuBackendGemmTest, Rectangular) {
std::vector<std::tuple<int, int, int>> shapes;
for (int size = 1; size < 50; size++) {
shapes.push_back(std::make_tuple(size, size + 5, size + 1));
shapes.push_back(std::make_tuple(size + 10, size + 2, size));
}
TestRandomGemms<TypeParam>(shapes);
}
TYPED_TEST(CpuBackendGemmTest, HighlyRectangular) {
std::vector<std::tuple<int, int, int>> shapes;
for (int size = 1; size <= 10000; size *= 10) {
shapes.push_back(std::make_tuple(size, 10, 10));
shapes.push_back(std::make_tuple(10, size, 10));
shapes.push_back(std::make_tuple(10, 10, size));
}
TestRandomGemms<TypeParam>(shapes);
}
TYPED_TEST(CpuBackendGemmTest, InnerProduct) {
std::vector<std::tuple<int, int, int>> shapes;
for (int size = 1; size < 200; size++) {
shapes.push_back(std::make_tuple(1, size, 1));
}
TestRandomGemms<TypeParam>(shapes);
}
TYPED_TEST(CpuBackendGemmTest, OuterProduct) {
std::vector<std::tuple<int, int, int>> shapes;
for (int size = 1; size < 200; size++) {
shapes.push_back(std::make_tuple(size, 1, size));
}
TestRandomGemms<TypeParam>(shapes);
}
} // namespace
} // namespace tflite
int main(int argc, char** argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}