Introduce a :cpu_backend_gemm library that can perform every type of GEMM
that TFLite needs, given a CpuBackendContext. Ruy provides the generic
implementation.

A subsequent CL will introduce a config_setting, `with_ruy`, false by
default. When with_ruy is false, the generic ruy-based implementation will
be overridden by a gemmlowp/eigen implementation.

The rationale for making ruy the generic implementation is that, in the
current pre-ruy landscape, some paths have spotty support: the new
int8/per-channel scheme is only supported by gemmlowp on NEON, forcing call
sites across TFLite to carry complicated #if logic. Since ruy provides a
generic implementation in all cases, such call sites won't need #if's to
avoid build issues. Call sites may still want to avoid the code size
penalty of generic paths that they never call, but they will be able to get
that by letting the compiler elide statically unused branches.

PiperOrigin-RevId: 245451095
parent ba4db912ec
commit 71ad025fc4
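
To illustrate the last point of the commit message, here is a hypothetical call-site sketch (not code from this commit; the op and flag names are made up). Because the ruy-backed Gemm compiles on every platform, a call site can select a path with an ordinary compile-time-constant condition instead of preprocessor guards, and the statically untaken branch is dead code the compiler can elide:

    #include "tensorflow/lite/kernels/cpu_backend_gemm.h"

    // Hypothetical constant: on builds that never use the int8/per-channel
    // path, this is false, the branch below is statically dead, and its
    // code size cost can be elided without any #if around the call site.
    constexpr bool kWantInt8PerChannel = false;

    void RunSomeOp(tflite::CpuBackendContext* context) {
      if (kWantInt8PerChannel) {
        // ... call cpu_backend_gemm::Gemm with per-channel quantized params ...
      } else {
        // ... call cpu_backend_gemm::Gemm with plain params ...
      }
    }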
@@ -256,7 +256,11 @@ cc_library(
         "dispatch.h",
         "impl.h",
     ],
-    hdrs = ["ruy.h"],
+    hdrs = [
+        "matrix.h",
+        "path.h",
+        "ruy.h",
+    ],
     visibility = ruy_visibility(),
     deps = [
         ":allocator",
@@ -265,10 +269,8 @@ cc_library(
         ":common",
         ":context",
         ":kernel",
-        ":matrix",
         ":opt_set",
         ":pack",
-        ":path",
         ":size_util",
         ":spec",
         ":thread_pool",
@@ -3,4 +3,6 @@ Control of ruy visibility
 """
 
 def ruy_visibility():
-    return []
+    return [
+        "//tensorflow/lite/kernels:__subpackages__",
+    ]
@@ -90,10 +90,40 @@ cc_library(
     copts = tflite_copts(),
     deps = [
         ":op_macros",
+        # For now this unconditionally depends on both ruy and gemmlowp.
+        # See the comment inside class CpuBackendContext on the
+        # gemmlowp_context_ and ruy_context_ members.
+        "//tensorflow/lite/experimental/ruy:context",
         "@gemmlowp",
     ],
 )
 
+cc_library(
+    name = "cpu_backend_gemm",
+    hdrs = [
+        "cpu_backend_gemm.h",
+    ],
+    deps = [
+        ":op_macros",
+        ":cpu_backend_context",
+        # Depend on ruy regardless of `with_ruy`. See the comment in
+        # cpu_backend_gemm.h about why ruy is the generic path.
+        "//tensorflow/lite/experimental/ruy",
+    ],
+)
+
+cc_test(
+    name = "cpu_backend_gemm_test",
+    srcs = ["cpu_backend_gemm_test.cc"],
+    deps = [
+        ":cpu_backend_gemm",
+        "@com_google_googletest//:gtest",
+        # ruy's reference path provides the reference implementation
+        # that this test compares against.
+        "//tensorflow/lite/experimental/ruy",
+    ],
+)
+
 cc_library(
     name = "cpu_backend_support",
     srcs = [
@@ -16,11 +16,13 @@ limitations under the License.
 #include "tensorflow/lite/kernels/cpu_backend_context.h"
 
 #include "public/gemmlowp.h"
+#include "tensorflow/lite/experimental/ruy/context.h"
 
 namespace tflite {
 
 CpuBackendContext::CpuBackendContext()
-    : gemmlowp_context_(new gemmlowp::GemmContext) {}
+    : ruy_context_(new ruy::Context),
+      gemmlowp_context_(new gemmlowp::GemmContext) {}
 
 CpuBackendContext::~CpuBackendContext() {}
 
@@ -19,6 +19,7 @@ limitations under the License.
 #include <memory>
 
 #include "public/gemmlowp.h"
+#include "tensorflow/lite/experimental/ruy/context.h"
 
 namespace tflite {
 
@@ -27,6 +28,8 @@ class CpuBackendContext final {
   CpuBackendContext();
   ~CpuBackendContext();
 
+  ruy::Context* ruy_context() const { return ruy_context_.get(); }
+
   gemmlowp::GemmContext* gemmlowp_context() const {
     return gemmlowp_context_.get();
   }
@@ -34,7 +37,13 @@ class CpuBackendContext final {
   void set_max_num_threads(int max_num_threads);
 
  private:
-  // gemmlowp context used to implement this CpuBackendContext.
+  // To enable a smooth transition from the current direct usage
+  // of the underlying gemmlowp context to going through abstractions
+  // (see :cpu_backend_gemm), for now a CpuBackendContext always
+  // stores both a gemmlowp context and a ruy context.
+  // TODO(b/131416458): Once call sites all go through abstractions,
+  // elide what can be elided based on TFLITE_WITH_RUY.
+  const std::unique_ptr<ruy::Context> ruy_context_;
   const std::unique_ptr<gemmlowp::GemmContext> gemmlowp_context_;
 
   CpuBackendContext(const CpuBackendContext&) = delete;
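
As an aside (illustrative sketch, not part of the diff): during the transition, both backends are reachable from the same CpuBackendContext through the accessors shown above:

    void UseEitherBackend(tflite::CpuBackendContext* context) {
      // New abstractions (see :cpu_backend_gemm) go through ruy:
      ruy::Context* ruy_ctx = context->ruy_context();
      // Pre-existing call sites keep using gemmlowp directly for now:
      gemmlowp::GemmContext* gemmlowp_ctx = context->gemmlowp_context();
      (void)ruy_ctx;
      (void)gemmlowp_ctx;
    }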
tensorflow/lite/kernels/cpu_backend_gemm.h (new file, 187 lines)
@@ -0,0 +1,187 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_H_
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_H_

#include <cstdint>
#include <limits>
#include <type_traits>

#include "tensorflow/lite/experimental/ruy/path.h"
#include "tensorflow/lite/experimental/ruy/ruy.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {

namespace cpu_backend_gemm {

// Matrix storage order: column-major or row-major.
enum class Order { kColMajor, kRowMajor };

// MatrixParams encapsulates the parameters that Gemm needs about each
// matrix, besides the buffer data pointer.
// Compare to ruy::Matrix, which also encapsulates the data pointer.
// Rationale for leaving the data pointer out of here: doing so
// requires complicated const-correctness mechanics. See
// ruy::ConstCheckingPtr.
template <typename Scalar>
struct MatrixParams {
  // Storage layout order. For now we only do plain linear non-strided
  // layout. It would be easy to support a stride if needed.
  Order order = Order::kColMajor;
  // Number of rows of the matrix.
  int rows = 0;
  // Number of columns of the matrix.
  int cols = 0;
  // The zero_point, i.e. which Scalar value is to be interpreted as zero.
  // When Scalar is floating-point, this must be 0.
  Scalar zero_point = 0;
};

// Additional parameters that Gemm needs, beyond what falls into
// the MatrixParams that it takes. Compare to ruy::Spec.
//
// Decoupling AccumScalar from DstScalar (rather than deducing it from
// DstScalar) is useful future-proofing. Think of a float16 path using
// float32 accum.
template <typename AccumScalar, typename DstScalar>
struct GemmParams {
  // Only for non-floating-point cases. The fixed-point part (i.e. the
  // mantissa) of the multiplier by which accumulators are multiplied before
  // being cast to the destination type.
  AccumScalar multiplier_fixedpoint = 0;
  // Only for non-floating-point cases. The exponent part of the aforementioned
  // multiplier.
  int multiplier_exponent = 0;
  // Per-channel variant of multiplier_fixedpoint. If not nullptr, this must
  // point to a buffer of as many values as there are rows in the destination
  // matrix. Each row of the destination matrix will use the corresponding
  // buffer element instead of multiplier_fixedpoint.
  const AccumScalar* multiplier_fixedpoint_perchannel = nullptr;
  // Per-channel variant of multiplier_exponent. If not nullptr, this must
  // point to a buffer of as many values as there are rows in the destination
  // matrix. Each row of the destination matrix will use the corresponding
  // buffer element instead of multiplier_exponent.
  //
  // Either none or both of multiplier_exponent_perchannel and
  // multiplier_fixedpoint_perchannel must be nullptr.
  const int* multiplier_exponent_perchannel = nullptr;
  // The bias vector data, if not null.
  const AccumScalar* bias = nullptr;
  // min clamp bound of destination values.
  DstScalar clamp_min = std::numeric_limits<DstScalar>::lowest();
  // max clamp bound of destination values.
  DstScalar clamp_max = std::numeric_limits<DstScalar>::max();
};

/* Convenience typedefs */

template <typename DstScalar>
using QuantizedGemmParams = GemmParams<std::int32_t, DstScalar>;

using FloatGemmParams = GemmParams<float, float>;

/* Generic implementation using ruy */

namespace detail {

template <typename Scalar, typename DataPointer>
void MakeRuyMatrix(const MatrixParams<Scalar>& params, DataPointer data_ptr,
                   ruy::Matrix<Scalar>* dst) {
  dst->layout.rows = params.rows;
  dst->layout.cols = params.cols;
  if (params.order == Order::kColMajor) {
    dst->layout.order = ruy::Order::kColMajor;
    dst->layout.stride = params.rows;
  } else {
    dst->layout.order = ruy::Order::kRowMajor;
    dst->layout.stride = params.cols;
  }
  // Note that ruy::Matrix::data is a ConstCheckingPtr, not a plain pointer.
  // It does care whether we assign to it a Scalar* or a const Scalar*.
  dst->data = data_ptr;
  dst->zero_point = params.zero_point;
}

template <typename GemmParamsType, typename RuySpecType>
void MakeRuySpec(const GemmParamsType& params, RuySpecType* ruy_spec) {
  if (std::is_floating_point<typename RuySpecType::AccumScalar>::value) {
    TF_LITE_ASSERT(!params.multiplier_fixedpoint);
    TF_LITE_ASSERT(!params.multiplier_exponent);
    TF_LITE_ASSERT(!params.multiplier_fixedpoint_perchannel);
    TF_LITE_ASSERT(!params.multiplier_exponent_perchannel);
  } else {
    TF_LITE_ASSERT((params.multiplier_fixedpoint == 0) !=
                   (params.multiplier_fixedpoint_perchannel == nullptr));
  }
  ruy_spec->multiplier_fixedpoint = params.multiplier_fixedpoint;
  ruy_spec->multiplier_exponent = params.multiplier_exponent;
  ruy_spec->multiplier_fixedpoint_perchannel =
      params.multiplier_fixedpoint_perchannel;
  ruy_spec->multiplier_exponent_perchannel =
      params.multiplier_exponent_perchannel;
  ruy_spec->bias = params.bias;
  ruy_spec->clamp_min = params.clamp_min;
  ruy_spec->clamp_max = params.clamp_max;
}

}  // namespace detail

// Non-ruy implementation will be partial specializations of this template.
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar>
struct GemmImpl {
  static void Run(const MatrixParams<LhsScalar>& lhs_params,
                  const LhsScalar* lhs_data,
                  const MatrixParams<RhsScalar>& rhs_params,
                  const RhsScalar* rhs_data,
                  const MatrixParams<DstScalar>& dst_params,
                  DstScalar* dst_data,
                  const GemmParams<AccumScalar, DstScalar>& params,
                  CpuBackendContext* context) {
    ruy::Matrix<LhsScalar> ruy_lhs;
    ruy::Matrix<RhsScalar> ruy_rhs;
    ruy::Matrix<DstScalar> ruy_dst;
    detail::MakeRuyMatrix(lhs_params, lhs_data, &ruy_lhs);
    detail::MakeRuyMatrix(rhs_params, rhs_data, &ruy_rhs);
    detail::MakeRuyMatrix(dst_params, dst_data, &ruy_dst);

    ruy::BasicSpec<AccumScalar, DstScalar> ruy_spec;
    detail::MakeRuySpec(params, &ruy_spec);

    ruy::Mul<ruy::kAllPaths>(ruy_lhs, ruy_rhs, ruy_spec, context->ruy_context(),
                             &ruy_dst);
  }
};

/* Public entry point */

template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar>
void Gemm(const MatrixParams<LhsScalar>& lhs_params, const LhsScalar* lhs_data,
          const MatrixParams<RhsScalar>& rhs_params, const RhsScalar* rhs_data,
          const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data,
          const GemmParams<AccumScalar, DstScalar>& params,
          CpuBackendContext* context) {
  GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar>::Run(
      lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data, params,
      context);
}

}  // namespace cpu_backend_gemm

}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_H_
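
For orientation, a minimal usage sketch of the new entry point (illustrative only, not part of this commit; all names besides ExampleFloatGemm come from the header above). On the float path, GemmParams keeps its defaults: no multipliers, no bias, widest clamp bounds:

    #include "tensorflow/lite/kernels/cpu_backend_gemm.h"

    void ExampleFloatGemm(tflite::CpuBackendContext* context) {
      using tflite::cpu_backend_gemm::Gemm;
      using tflite::cpu_backend_gemm::GemmParams;
      using tflite::cpu_backend_gemm::MatrixParams;
      using tflite::cpu_backend_gemm::Order;

      const float lhs_data[] = {1, 2, 3, 4};  // 2x2 row-major: [[1,2],[3,4]]
      const float rhs_data[] = {5, 6, 7, 8};  // 2x2 col-major: [[5,7],[6,8]]
      float dst_data[4] = {};

      MatrixParams<float> lhs_params;
      lhs_params.order = Order::kRowMajor;
      lhs_params.rows = 2;
      lhs_params.cols = 2;

      MatrixParams<float> rhs_params;
      rhs_params.order = Order::kColMajor;
      rhs_params.rows = 2;
      rhs_params.cols = 2;

      MatrixParams<float> dst_params;
      dst_params.order = Order::kColMajor;
      dst_params.rows = 2;
      dst_params.cols = 2;

      GemmParams<float, float> params;  // defaults suffice for float
      Gemm(lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data,
           params, context);
      // dst_data is now {17, 39, 23, 53}, i.e. [[17,23],[39,53]] col-major.
    }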
tensorflow/lite/kernels/cpu_backend_gemm_test.cc (new file, 603 lines)
@@ -0,0 +1,603 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/cpu_backend_gemm.h"

#include <algorithm>
#include <cmath>
#include <cstdarg>
#include <limits>
#include <random>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
#include <vector>

#include <gtest/gtest.h>
#include "tensorflow/lite/experimental/ruy/ruy.h"

namespace tflite {

namespace {

using cpu_backend_gemm::Gemm;
using cpu_backend_gemm::GemmParams;
using cpu_backend_gemm::MatrixParams;

template <typename Scalar>
std::string ToString(const std::vector<Scalar>& vector) {
  std::stringstream s;
  if (vector.empty()) {
    s << "{}";
  } else {
    s << "{ " << static_cast<double>(vector[0]);
    for (int i = 1; i < vector.size(); i++) {
      s << ", " << static_cast<double>(vector[i]);
    }
    s << "}";
  }
  return s.str();
}

template <typename Scalar>
void MakeDeterministicPseudoRandomVector(int size,
                                         std::vector<Scalar>* vector) {
  // Intentionally create a new local random_engine in each invocation,
  // so pseudorandom values don't depend on invocation order.
  // Otherwise, test results would be affected by e.g. filtering.
  std::default_random_engine random_engine;
  (void)random_engine();
  // Do not use std::uniform*_distribution: the values that it
  // generates are implementation-defined.
  const double random_min = static_cast<double>(random_engine.min());
  const double random_max = static_cast<double>(random_engine.max());
  const double result_min =
      std::is_floating_point<Scalar>::value
          ? -1.0
          : std::max(-256., static_cast<double>(
                                std::numeric_limits<Scalar>::lowest()));
  const double result_max =
      std::is_floating_point<Scalar>::value
          ? 1.0
          : std::min(256.,
                     static_cast<double>(std::numeric_limits<Scalar>::max()));
  const double random_scale =
      (result_max - result_min) / (random_max - random_min);

  vector->resize(size);
  for (int i = 0; i < size; i++) {
    double val = random_scale * (random_engine() - random_min);
    val = std::max(val,
                   static_cast<double>(std::numeric_limits<Scalar>::lowest()));
    val =
        std::min(val, static_cast<double>(std::numeric_limits<Scalar>::max()));
    (*vector)[i] = static_cast<Scalar>(val);
  }
}

template <typename Scalar>
void MakeVectorFilledWithConsecutiveInts(int size,
                                         std::vector<Scalar>* vector) {
  vector->resize(size);
  EXPECT_LE(size, std::numeric_limits<Scalar>::max());
  for (int i = 0; i < size; i++) {
    (*vector)[i] = static_cast<Scalar>(i + 1);
  }
}

template <typename Scalar>
Scalar Median(const std::vector<Scalar>& vector) {
  EXPECT_GT(vector.size(), 0);
  std::vector<Scalar> vector_copy = vector;
  std::sort(std::begin(vector_copy), std::end(vector_copy));
  return vector_copy[vector_copy.size() / 2];
}

template <typename Scalar>
double MedianAbs(const std::vector<Scalar>& vector) {
  EXPECT_GT(vector.size(), 0);
  std::vector<double> vector_abs;
  vector_abs.resize(vector.size());
  for (int i = 0; i < vector.size(); i++) {
    vector_abs[i] = std::abs(static_cast<double>(vector[i]));
  }
  std::sort(std::begin(vector_abs), std::end(vector_abs));
  return vector_abs[vector_abs.size() / 2];
}

template <typename Scalar>
void Clamp(const std::vector<Scalar>& src, Scalar clamp_min, Scalar clamp_max,
           std::vector<Scalar>* dst) {
  dst->resize(src.size());
  for (int i = 0; i < src.size(); i++) {
    (*dst)[i] = std::max(std::min(src[i], clamp_max), clamp_min);
  }
}

template <typename AccumScalar, typename DstScalar>
void Clamp(const GemmParams<AccumScalar, DstScalar>& src, DstScalar clamp_min,
           DstScalar clamp_max, GemmParams<AccumScalar, DstScalar>* dst) {
  *dst = src;
  dst->clamp_min = clamp_min;
  dst->clamp_max = clamp_max;
}

struct ErrorStats {
  int size;
  double scale_factor;
  double max_abs_diff;
  double mean_abs_diff;
  double abs_mean_diff;
};

template <typename Scalar>
void ComputeErrorStats(const std::vector<Scalar>& actual,
                       const std::vector<Scalar>& expected,
                       ErrorStats* error_stats) {
  double max_abs_diff = 0;
  double sum_abs_diff = 0;
  double sum_diff = 0;
  double max_abs_expected = 0;
  EXPECT_EQ(actual.size(), expected.size());
  for (int i = 0; i < actual.size(); i++) {
    double actual_val = static_cast<double>(actual[i]);
    double expected_val = static_cast<double>(expected[i]);
    double diff = actual_val - expected_val;
    max_abs_expected = std::max(max_abs_expected, std::abs(expected_val));
    sum_diff += diff;
    sum_abs_diff += std::abs(diff);
    max_abs_diff = std::max(max_abs_diff, std::abs(diff));
  }
  error_stats->scale_factor = max_abs_expected;
  error_stats->max_abs_diff = max_abs_diff;
  error_stats->mean_abs_diff = sum_abs_diff / actual.size();
  error_stats->abs_mean_diff = std::abs(sum_diff / actual.size());
  error_stats->size = actual.size();
}

template <typename AccumScalar, typename DstScalar>
bool CheckErrorStats(const ErrorStats& error_stats, int accumulation_depth) {
  double tolerated_relative_max_abs_diff = 0;
  double tolerated_relative_mean_abs_diff = 0;
  double tolerated_relative_abs_mean_diff = 0;

  double inverse_size = 1. / error_stats.size;

  if (std::is_floating_point<AccumScalar>::value) {
    // Somewhat naive requirement: the worst case should be epsilons
    // adding up towards the same direction, on values of same magnitude.
    tolerated_relative_max_abs_diff =
        accumulation_depth * std::numeric_limits<DstScalar>::epsilon();
    // Naive interpretation of the Central Limit Theorem is the rationale
    // for the sqrt here. We haven't even worked out the correct scale factor,
    // or how applicable that theorem is here (the random variables being added
    // might not be mutually independent).
    tolerated_relative_mean_abs_diff =
        std::sqrt(static_cast<double>(accumulation_depth)) *
        std::numeric_limits<DstScalar>::epsilon();
    // Unbiasing requirement: we require the bias, abs_mean_diff, to be much
    // smaller than the mean_abs_diff, except when there are very few values.
    tolerated_relative_abs_mean_diff =
        tolerated_relative_mean_abs_diff * std::sqrt(inverse_size);
  } else {
    // In quantized arithmetic, tolerate minor rounding differences, resulting
    // in off-by-one errors (tolerated_relative_max_abs_diff = 1), as long
    // as they are rare (tolerated_relative_mean_abs_diff) and unbiased
    // (tolerated_relative_abs_mean_diff).
    tolerated_relative_max_abs_diff = 1;
    // Naively require mean_abs_diff and abs_mean_diff to converge to zero
    // as size gets large. We don't know at all how quick that convergence
    // should be: this is just based on trial-and-error and striking a
    // compromise between something that works and something that's simple
    // enough code that doesn't feel too ad-hoc. As above in the float path,
    // abs_mean_diff is subject to a stricter requirement as it is a bias.
    tolerated_relative_mean_abs_diff = std::sqrt(inverse_size);
    tolerated_relative_abs_mean_diff = inverse_size;
  }

  double tolerated_max_abs_diff =
      tolerated_relative_max_abs_diff * error_stats.scale_factor;
  double tolerated_mean_abs_diff =
      tolerated_relative_mean_abs_diff * error_stats.scale_factor;
  double tolerated_abs_mean_diff =
      tolerated_relative_abs_mean_diff * error_stats.scale_factor;

  EXPECT_LE(error_stats.max_abs_diff, tolerated_max_abs_diff);
  EXPECT_LE(error_stats.mean_abs_diff, tolerated_mean_abs_diff);
  EXPECT_LE(error_stats.abs_mean_diff, tolerated_abs_mean_diff);

  return error_stats.max_abs_diff <= tolerated_max_abs_diff &&
         error_stats.mean_abs_diff <= tolerated_mean_abs_diff &&
         error_stats.abs_mean_diff <= tolerated_abs_mean_diff;
}

template <typename AccumScalar, typename DstScalar>
void CheckErrorForAccumulation(int accumulation_depth,
                               const std::vector<DstScalar>& actual,
                               const std::vector<DstScalar>& expected) {
  ErrorStats error_stats;
  ComputeErrorStats(actual, expected, &error_stats);
  bool success =
      CheckErrorStats<AccumScalar, DstScalar>(error_stats, accumulation_depth);
  EXPECT_TRUE(success) << "Actual vector\n"
                       << ToString(actual) << "\ndiffers from expected vector\n"
                       << ToString(expected) << "\n";
}

template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar>
void PerformGemmThenCompareResultsThenAgainWithClamping(
    const MatrixParams<LhsScalar>& lhs_params,
    const std::vector<LhsScalar>& lhs_data,
    const MatrixParams<RhsScalar>& rhs_params,
    const std::vector<RhsScalar>& rhs_data,
    const MatrixParams<DstScalar>& dst_params, std::vector<DstScalar>* dst_data,
    const GemmParams<AccumScalar, DstScalar>& params,
    const std::vector<DstScalar>& expected,
    CpuBackendContext* cpu_backend_context) {
  const int accumulation_depth = lhs_params.cols;
  Gemm(lhs_params, lhs_data.data(), rhs_params, rhs_data.data(), dst_params,
       dst_data->data(), params, cpu_backend_context);
  CheckErrorForAccumulation<AccumScalar>(accumulation_depth, *dst_data,
                                         expected);
  DstScalar expected_median = Median(expected);
  std::vector<DstScalar> expected_with_clamp;
  GemmParams<AccumScalar, DstScalar> params_with_clamp;
  DstScalar clamp_min, clamp_max;

  clamp_min = std::numeric_limits<DstScalar>::lowest();
  clamp_max = expected_median;
  Clamp(expected, clamp_min, clamp_max, &expected_with_clamp);
  Clamp(params, clamp_min, clamp_max, &params_with_clamp);
  Gemm(lhs_params, lhs_data.data(), rhs_params, rhs_data.data(), dst_params,
       dst_data->data(), params_with_clamp, cpu_backend_context);
  CheckErrorForAccumulation<AccumScalar>(accumulation_depth, *dst_data,
                                         expected_with_clamp);

  clamp_min = expected_median;
  clamp_max = std::numeric_limits<DstScalar>::max();
  Clamp(expected, clamp_min, clamp_max, &expected_with_clamp);
  Clamp(params, clamp_min, clamp_max, &params_with_clamp);
  Gemm(lhs_params, lhs_data.data(), rhs_params, rhs_data.data(), dst_params,
       dst_data->data(), params_with_clamp, cpu_backend_context);
  CheckErrorForAccumulation<AccumScalar>(accumulation_depth, *dst_data,
                                         expected_with_clamp);
}

// When generating testcases for a quantized GEMM, it's not trivial to
// pick multiplier exponents: a too low value will result in too many zeros,
// a too high value will result in too many large clamped values, in both
// cases testing coverage is harmed. Therefore to ensure good testing coverage
// we must find a multiplier exponent that's just right. It would be possible
// to do so by analysis of the random distribution of values in the result
// matrix. That however would require some mathematical work that we haven't
// done so far. Until that is done, the best that we can do is to search for
// a good exponent value by trial-and-error. This is expensive, as each try
// requires computing a whole GEMM. This is thus probably a major contribution
// to the overall latency of this test. To partially mitigate that,
// we use a bisection to reduce the required number of tries.
//
// This function is recursive. The bisect_min and bisect_max arguments
// are the current bisection bounds. It performs a Gemm with the mid-point,
// named bisect_mid, as the multiplier exponent. Based on whether the values
// in the resulting matrix are rather too low or too large in absolute
// value, it then recurses into the corresponding half of the bisection range.
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar>
int BisectReasonableMultiplierExponent(
    int bisect_min, int bisect_max, const MatrixParams<LhsScalar>& lhs_params,
    const std::vector<LhsScalar>& lhs_data,
    const MatrixParams<RhsScalar>& rhs_params,
    const std::vector<RhsScalar>& rhs_data,
    const MatrixParams<DstScalar>& dst_params, std::vector<DstScalar>* dst_data,
    const GemmParams<AccumScalar, DstScalar>& params,
    CpuBackendContext* cpu_backend_context) {
  if (bisect_min == bisect_max) {
    return bisect_min;
  }
  // Compute the midpoint as the floor of the average of bisect_min and
  // bisect_max. As C++ integer division is rounding towards zero and our
  // values may be of any sign, it is not trivial to implement this using only
  // integer arithmetic.
  int bisect_mid =
      static_cast<int>(std::floor(0.5 * (bisect_min + bisect_max)));
  GemmParams<AccumScalar, DstScalar> params_copy(params);
  params_copy.multiplier_exponent = bisect_mid;
  double clamp_abs = std::max(std::abs(static_cast<double>(params.clamp_min)),
                              std::abs(static_cast<double>(params.clamp_max)));
  Gemm(lhs_params, lhs_data.data(), rhs_params, rhs_data.data(), dst_params,
       dst_data->data(), params_copy, cpu_backend_context);
  double median_abs = MedianAbs(*dst_data);
  if (median_abs < 0.25 * clamp_abs) {
    return BisectReasonableMultiplierExponent(
        bisect_mid + 1, bisect_max, lhs_params, lhs_data, rhs_params, rhs_data,
        dst_params, dst_data, params_copy, cpu_backend_context);
  } else {
    return BisectReasonableMultiplierExponent(
        bisect_min, bisect_mid, lhs_params, lhs_data, rhs_params, rhs_data,
        dst_params, dst_data, params_copy, cpu_backend_context);
  }
}

template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar>
void ReferenceGemm(const MatrixParams<LhsScalar>& lhs_params,
                   const LhsScalar* lhs_data,
                   const MatrixParams<RhsScalar>& rhs_params,
                   const RhsScalar* rhs_data,
                   const MatrixParams<DstScalar>& dst_params,
                   DstScalar* dst_data,
                   const GemmParams<AccumScalar, DstScalar>& params,
                   CpuBackendContext* context) {
  ruy::Matrix<LhsScalar> ruy_lhs;
  ruy::Matrix<RhsScalar> ruy_rhs;
  ruy::Matrix<DstScalar> ruy_dst;
  cpu_backend_gemm::detail::MakeRuyMatrix(lhs_params, lhs_data, &ruy_lhs);
  cpu_backend_gemm::detail::MakeRuyMatrix(rhs_params, rhs_data, &ruy_rhs);
  cpu_backend_gemm::detail::MakeRuyMatrix(dst_params, dst_data, &ruy_dst);

  ruy::BasicSpec<AccumScalar, DstScalar> ruy_spec;
  cpu_backend_gemm::detail::MakeRuySpec(params, &ruy_spec);

  ruy::Mul<ruy::Path::kReference>(ruy_lhs, ruy_rhs, ruy_spec,
                                  context->ruy_context(), &ruy_dst);
}

template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar>
void TestSomeGemm(int rows, int depth, int cols,
                  const std::vector<DstScalar>& golden) {
  CpuBackendContext cpu_backend_context;
  std::default_random_engine random_engine;
  cpu_backend_context.set_max_num_threads(1 + (random_engine() % 8));

  const bool use_golden = !golden.empty();

  std::vector<LhsScalar> lhs_data;
  std::vector<RhsScalar> rhs_data;
  std::vector<AccumScalar> bias_data;
  std::vector<DstScalar> dst_data;
  if (use_golden) {
    MakeVectorFilledWithConsecutiveInts(rows * depth, &lhs_data);
    MakeVectorFilledWithConsecutiveInts(depth * cols, &rhs_data);
    MakeVectorFilledWithConsecutiveInts(rows, &bias_data);
  } else {
    MakeDeterministicPseudoRandomVector(rows * depth, &lhs_data);
    MakeDeterministicPseudoRandomVector(depth * cols, &rhs_data);
    MakeDeterministicPseudoRandomVector(rows, &bias_data);
  }
  MakeDeterministicPseudoRandomVector(rows * cols, &dst_data);

  MatrixParams<LhsScalar> lhs_params;
  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
  lhs_params.rows = rows;
  lhs_params.cols = depth;
  if (!use_golden && !std::is_floating_point<LhsScalar>::value) {
    lhs_params.zero_point = random_engine() % 8;
  }

  MatrixParams<RhsScalar> rhs_params;
  rhs_params.order = cpu_backend_gemm::Order::kColMajor;
  rhs_params.rows = depth;
  rhs_params.cols = cols;
  if (!use_golden && !std::is_floating_point<RhsScalar>::value) {
    rhs_params.zero_point = random_engine() % 8;
  }

  MatrixParams<DstScalar> dst_params;
  dst_params.order = cpu_backend_gemm::Order::kColMajor;
  dst_params.rows = rows;
  dst_params.cols = cols;
  if (!use_golden && !std::is_floating_point<DstScalar>::value) {
    dst_params.zero_point = random_engine() % 8;
  }

  GemmParams<AccumScalar, DstScalar> params;
  params.bias = bias_data.data();
  if (!std::is_floating_point<AccumScalar>::value) {
    // Some large int32 value. Not being a multiple of a large
    // power of two helps testing rounding behavior.
    params.multiplier_fixedpoint = 1234567890;
    // Now find a suitable value for multiplier_exponent.
    // It needs to be low enough for a substantial amount of dst values
    // to avoid getting clamped.
    int bisect_min = -8 * static_cast<int>(sizeof(AccumScalar));
    // We don't increase test coverage by using positive multipliers,
    // and using very large positive multipliers may at the moment
    // result in overflow in some paths.
    // TODO(benoitjacob): fix that.
    int bisect_max = 0;
    params.multiplier_exponent = BisectReasonableMultiplierExponent(
        bisect_min, bisect_max, lhs_params, lhs_data, rhs_params, rhs_data,
        dst_params, &dst_data, params, &cpu_backend_context);
  }

  std::vector<DstScalar> expected;
  if (use_golden) {
    EXPECT_EQ(golden.size(), dst_data.size());
    expected = golden;
  } else {
    expected.resize(dst_data.size());
    ReferenceGemm(lhs_params, lhs_data.data(), rhs_params, rhs_data.data(),
                  dst_params, expected.data(), params, &cpu_backend_context);
  }

  PerformGemmThenCompareResultsThenAgainWithClamping(
      lhs_params, lhs_data, rhs_params, rhs_data, dst_params, &dst_data, params,
      expected, &cpu_backend_context);

  if (!std::is_floating_point<AccumScalar>::value) {
    // Try with per-channel quantized multipliers. Just a naive check
    // duplicating the same multiplier --- would already catch most bugs.
    std::vector<AccumScalar> multiplier_fixedpoint_perchannel(
        rows, params.multiplier_fixedpoint);
    std::vector<int> multiplier_exponent_perchannel(rows,
                                                    params.multiplier_exponent);
    GemmParams<AccumScalar, DstScalar> params_perchannel = params;
    params_perchannel.multiplier_fixedpoint = 0;
    params_perchannel.multiplier_exponent = 0;
    params_perchannel.multiplier_fixedpoint_perchannel =
        multiplier_fixedpoint_perchannel.data();
    params_perchannel.multiplier_exponent_perchannel =
        multiplier_exponent_perchannel.data();
    PerformGemmThenCompareResultsThenAgainWithClamping(
        lhs_params, lhs_data, rhs_params, rhs_data, dst_params, &dst_data,
        params_perchannel, expected, &cpu_backend_context);
  }
}

TEST(CpuBackendGemmSimpleTestAgainstGolden, Float) {
  TestSomeGemm<float, float, float, float>(2, 3, 4,
                                           {15, 34, 33, 79, 51, 124, 69, 169});
}

TEST(CpuBackendGemmSimpleTestAgainstGolden, Uint8) {
  TestSomeGemm<std::uint8_t, std::uint8_t, std::int32_t, std::uint8_t>(
      5, 2, 3, {3, 7, 11, 16, 20, 7, 16, 24, 33, 41, 10, 24, 37, 50, 63});
}

TEST(CpuBackendGemmSimpleTestAgainstGolden, Int8) {
  TestSomeGemm<std::int8_t, std::int8_t, std::int32_t, std::int8_t>(
      2, 6, 3, {13, 32, 31, 81, 50, 127});
}

TEST(CpuBackendGemmSimpleTestAgainstGolden, Int8Int16) {
  TestSomeGemm<std::int8_t, std::int8_t, std::int32_t, std::int16_t>(
      3, 5, 4, {32, 76, 120, 75, 191, 306, 118, 306, 493, 162, 421, 680});
}

template <typename tLhsScalar, typename tRhsScalar, typename tAccumScalar,
          typename tDstScalar>
struct TypesTuple {
  using LhsScalar = tLhsScalar;
  using RhsScalar = tRhsScalar;
  using AccumScalar = tAccumScalar;
  using DstScalar = tDstScalar;
};

template <typename TypesTupleType>
void TestRandomGemms(const std::vector<std::tuple<int, int, int>>& shapes) {
  using LhsScalar = typename TypesTupleType::LhsScalar;
  using RhsScalar = typename TypesTupleType::RhsScalar;
  using AccumScalar = typename TypesTupleType::AccumScalar;
  using DstScalar = typename TypesTupleType::DstScalar;
  for (const auto& shape : shapes) {
    int rows = std::get<0>(shape);
    int depth = std::get<1>(shape);
    int cols = std::get<2>(shape);
    TestSomeGemm<LhsScalar, RhsScalar, AccumScalar, DstScalar>(rows, depth,
                                                               cols, {});
  }
}

template <typename TypesTupleType>
class CpuBackendGemmTest : public testing::Test {};

TYPED_TEST_SUITE_P(CpuBackendGemmTest);

typedef ::testing::Types<
    TypesTuple<float, float, float, float>,
    TypesTuple<std::uint8_t, std::uint8_t, std::int32_t, std::uint8_t>,
    TypesTuple<std::int8_t, std::int8_t, std::int32_t, std::int8_t>,
    TypesTuple<std::int8_t, std::int8_t, std::int32_t, std::int16_t>>
    CpuBackendGemmTestInstantiations;

TYPED_TEST_SUITE(CpuBackendGemmTest, CpuBackendGemmTestInstantiations);

TYPED_TEST(CpuBackendGemmTest, Square) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size < 50; size++) {
    shapes.push_back(std::make_tuple(size, size, size));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, SquarePowerOfTwo) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 64; size <= 128; size++) {
    shapes.push_back(std::make_tuple(size, size, size));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, MatrixTimesVector) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size < 200; size++) {
    shapes.push_back(std::make_tuple(size, size, 1));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, VectorTimesMatrix) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size < 200; size++) {
    shapes.push_back(std::make_tuple(1, size, size));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, MatrixTimesNarrow) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size < 100; size++) {
    shapes.push_back(std::make_tuple(size, size, 2));
    shapes.push_back(std::make_tuple(size, size, 3));
    shapes.push_back(std::make_tuple(size, size, 4));
    shapes.push_back(std::make_tuple(size, size, 8));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, Rectangular) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size < 50; size++) {
    shapes.push_back(std::make_tuple(size, size + 5, size + 1));
    shapes.push_back(std::make_tuple(size + 10, size + 2, size));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, HighlyRectangular) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size <= 10000; size *= 10) {
    shapes.push_back(std::make_tuple(size, 10, 10));
    shapes.push_back(std::make_tuple(10, size, 10));
    shapes.push_back(std::make_tuple(10, 10, size));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, InnerProduct) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size < 200; size++) {
    shapes.push_back(std::make_tuple(1, size, 1));
  }
  TestRandomGemms<TypeParam>(shapes);
}

TYPED_TEST(CpuBackendGemmTest, OuterProduct) {
  std::vector<std::tuple<int, int, int>> shapes;
  for (int size = 1; size < 200; size++) {
    shapes.push_back(std::make_tuple(size, 1, size));
  }
  TestRandomGemms<TypeParam>(shapes);
}

}  // namespace

}  // namespace tflite

int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}
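
Assuming a standard TensorFlow checkout, the cc_test target declared in the BUILD hunk above can be run with the usual Bazel invocation:

    bazel test //tensorflow/lite/kernels:cpu_backend_gemm_test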