Add low-level pre-packing API in ruy_advanced.h

Also, extend benchmark/test infrastructure to use it.

This also changes TestSet::results to be a vector of unique_ptr's: we
shouldn't be copying StorageMatrix, and especially not the new Allocator
that TestResult now holds for its prepacked matrices.

This also introduces ruy_advanced.h, which exposes Ruy's "advanced" API. Currently, the only "advanced" interface is the low-level pre-packing interface.
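
A minimal usage sketch of the new API, mirroring the added example_advanced.cc
(lhs/rhs/dst, spec, context and alloc_fn set up as in that example):

    ruy::PrepackedMatrix prepacked_rhs;
    ruy::PrePackForMul<ruy::kAllPaths>(lhs, rhs, spec, context, &dst,
                                       /*prepacked_lhs=*/nullptr,
                                       &prepacked_rhs, alloc_fn);
    // ... then, any number of times:
    ruy::MulWithPrepacked<ruy::kAllPaths>(lhs, rhs, spec, context, &dst,
                                          /*prepacked_lhs=*/nullptr,
                                          &prepacked_rhs);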

PiperOrigin-RevId: 247669742
Sean Silva 2019-05-10 13:51:38 -07:00 committed by TensorFlower Gardener
parent e617fa4f6d
commit 5a75d8bd0f
10 changed files with 611 additions and 104 deletions

View File

@ -265,11 +265,13 @@ cc_library(
srcs = [
"dispatch.h",
"impl.h",
"prepack.h",
],
hdrs = [
"matrix.h",
"path.h",
"ruy.h",
"ruy_advanced.h",
],
visibility = ruy_visibility(),
deps = [
@ -290,7 +292,7 @@ cc_library(
],
)
# Just a usage example.
# Usage examples.
cc_binary(
name = "example",
srcs = ["example.cc"],
@ -299,6 +301,15 @@ cc_binary(
],
)
# Usage examples of the advanced API.
cc_binary(
name = "example_advanced",
srcs = ["example_advanced.cc"],
deps = [
":ruy",
],
)
# Small library to query PMU counters, for benchmark only
cc_library(
name = "pmu",

View File

@ -36,7 +36,8 @@ struct BenchmarkShape {
};
template <typename TestSetType>
std::vector<TestResult<DstScalar>> BenchmarkRCC(const BenchmarkShape& shape) {
std::vector<std::unique_ptr<TestResult<DstScalar>>> BenchmarkRCC(
const BenchmarkShape& shape) {
TestSetType test_set;
test_set.rows = shape.rows;
test_set.depth = shape.depth;
@ -52,8 +53,10 @@ std::vector<TestResult<DstScalar>> BenchmarkRCC(const BenchmarkShape& shape) {
test_set.rhs_zero_point = SymmetricZeroPoint<RhsScalar>() + asymmetry_rhs;
test_set.use_specified_zero_points = true;
test_set.perchannel = GetBoolEnvVarOrFalse("PERCHANNEL");
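// When set, PREPACK_LHS / PREPACK_RHS make the benchmark exercise the new
// pre-packing path (see MakePrepackedMatrices and DoMul in test.h).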
test_set.benchmark_prepack_lhs = GetBoolEnvVarOrFalse("PREPACK_LHS");
test_set.benchmark_prepack_rhs = GetBoolEnvVarOrFalse("PREPACK_RHS");
test_set.Run();
return test_set.results;
return std::move(test_set.results);
}
void Benchmark() {
@ -108,7 +111,7 @@ void Benchmark() {
if (benchmark_cubic) {
printf("size");
for (const auto& result : results) {
printf(",%s", PathName(result).c_str());
printf(",%s", PathName(*result).c_str());
}
printf("\n");
} else {
@ -119,27 +122,28 @@ void Benchmark() {
if (benchmark_cubic) {
printf("%d", shape.rows);
for (const auto& result : results) {
printf(",%.4g",
2.0e-9 * shape.rows * shape.cols * shape.depth / result.latency);
printf(",%.4g", 2.0e-9 * shape.rows * shape.cols * shape.depth /
result->latency);
if (getenv("RUY_BENCHMARK_PMU")) {
printf(",%.3g,%.3g,%.3g,%.3g,%.3g,%.3g", result.l1_refill_rate,
result.l2_refill_rate, result.l3_refill_rate,
result.mispred_rate, result.frontend_stall_rate,
result.backend_stall_rate);
printf(",%.3g,%.3g,%.3g,%.3g,%.3g,%.3g", result->l1_refill_rate,
result->l2_refill_rate, result->l3_refill_rate,
result->mispred_rate, result->frontend_stall_rate,
result->backend_stall_rate);
}
}
printf("\n");
fflush(stdout);
} else {
for (const auto& result : results) {
printf("%s,%dx%dx%d,%.4g", PathName(result).c_str(), shape.rows,
shape.depth, shape.cols,
2.0e-9 * shape.rows * shape.cols * shape.depth / result.latency);
printf(
"%s,%dx%dx%d,%.4g", PathName(*result).c_str(), shape.rows,
shape.depth, shape.cols,
2.0e-9 * shape.rows * shape.cols * shape.depth / result->latency);
if (getenv("RUY_BENCHMARK_PMU")) {
printf(",%.3g,%.3g,%.3g,%.3g,%.3g,%.3g", result.l1_refill_rate,
result.l2_refill_rate, result.l3_refill_rate,
result.mispred_rate, result.frontend_stall_rate,
result.backend_stall_rate);
printf(",%.3g,%.3g,%.3g,%.3g,%.3g,%.3g", result->l1_refill_rate,
result->l2_refill_rate, result->l3_refill_rate,
result->mispred_rate, result->frontend_stall_rate,
result->backend_stall_rate);
}
printf("\n");
}

View File

@ -0,0 +1,80 @@
/* Copyright 2019 Google LLC. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>
#include "tensorflow/lite/experimental/ruy/ruy_advanced.h"
// Simple allocator for allocating pre-packed matrices.
class SimpleAllocator {
public:
void* AllocateBytes(std::size_t num_bytes) {
char* p = new char[num_bytes];
buffers_.emplace_back(p);
return static_cast<void*>(p);
}
private:
std::vector<std::unique_ptr<char[]>> buffers_;
};
void ExamplePrepack(ruy::Context* context) {
const float lhs_data[] = {1, 2, 3, 4};
const float rhs_data[] = {1, 2, 3, 4};
float dst_data[4];
// Set up the matrix layouts and spec.
ruy::Matrix<float> lhs;
ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout);
ruy::Matrix<float> rhs;
ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout);
ruy::Matrix<float> dst;
ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout);
ruy::BasicSpec<float, float> spec;
SimpleAllocator allocator;
auto alloc_fn = [&allocator](std::size_t num_bytes) -> void* {
return allocator.AllocateBytes(num_bytes);
};
// In this example, we pre-pack only the RHS, but either will work.
// Note that we only need to set the data pointer for the matrix we are
// pre-packing.
ruy::PrepackedMatrix prepacked_rhs;
rhs.data = rhs_data;
ruy::PrePackForMul<ruy::kAllPaths>(lhs, rhs, spec, context, &dst,
/*prepacked_lhs=*/nullptr, &prepacked_rhs,
alloc_fn);
// No data will be read from the RHS input matrix when using a pre-packed RHS.
rhs.data = nullptr;
lhs.data = lhs_data;
dst.data = dst_data;
ruy::MulWithPrepacked<ruy::kAllPaths>(lhs, rhs, spec, context, &dst,
/*prepacked_lhs=*/nullptr,
&prepacked_rhs);
rhs.data = rhs_data;
// Print out the results.
std::cout << "Example Mul with pre-packing RHS, float:\n";
std::cout << "LHS:\n" << lhs;
std::cout << "RHS:\n" << rhs;
std::cout << "Result:\n" << dst << "\n";
}
int main() {
ruy::Context context;
ExamplePrepack(&context);
}

View File

@ -59,6 +59,8 @@ struct TrMulParams {
DMatrix dst;
PMatrix packed_lhs;
PMatrix packed_rhs;
bool lhs_is_prepacked = false;
bool rhs_is_prepacked = false;
// Type-erased Spec.
void* spec = nullptr;
@ -239,14 +241,23 @@ inline void TrMul(TrMulParams* params, Context* context) {
const auto loop_structure = GetLoopStructure(thread_count, rows, cols, depth);
const Tuning tuning = GetTuning(context);
Allocator* allocator = context->GetMainAllocator();
AllocatePMatrix(allocator, &packed_lhs);
AllocatePMatrix(allocator, &packed_rhs);
if (!params->lhs_is_prepacked) {
AllocatePMatrix(allocator, &packed_lhs);
}
if (!params->rhs_is_prepacked) {
AllocatePMatrix(allocator, &packed_rhs);
}
if (loop_structure == LoopStructure::kSimple) {
gemmlowp::ScopedProfilingLabel label_simple("TrMulImpl, simple loop");
params->LhsRunPack(tuning, 0, rows_rounded_up);
params->RhsRunPack(tuning, 0, cols_rounded_up);
if (!params->lhs_is_prepacked) {
params->LhsRunPack(tuning, 0, rows_rounded_up);
}
if (!params->rhs_is_prepacked) {
params->RhsRunPack(tuning, 0, cols_rounded_up);
}
params->RunKernel(tuning, 0, 0, rows_rounded_up, cols_rounded_up);
allocator->FreeAll();
@ -277,21 +288,29 @@ inline void TrMul(TrMulParams* params, Context* context) {
}
// Allocate memory.
std::atomic<bool>* lhs_packed;
allocator->Allocate(num_blocks_of_rows, &lhs_packed);
std::atomic<bool>* rhs_packed;
allocator->Allocate(num_blocks_of_cols, &rhs_packed);
std::atomic<bool>* lhs_packed = nullptr;
if (!params->lhs_is_prepacked) {
allocator->Allocate(num_blocks_of_rows, &lhs_packed);
}
std::atomic<bool>* rhs_packed = nullptr;
if (!params->rhs_is_prepacked) {
allocator->Allocate(num_blocks_of_cols, &rhs_packed);
}
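// When a side is pre-packed, its flags array stays null: there is no packing
// work to coordinate for that side, and the initialization loops below skip
// null arrays.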
std::atomic<std::uint32_t>* atomic_n;
allocator->Allocate(1, &atomic_n);
TrMulTask* tasks;
allocator->Allocate(thread_count, &tasks);
// Initialize allocated data.
for (int i = 0; i < num_blocks_of_rows; i++) {
lhs_packed[i].store(false, std::memory_order_release);
if (lhs_packed != nullptr) {
for (int i = 0; i < num_blocks_of_rows; i++) {
lhs_packed[i].store(false, std::memory_order_release);
}
}
for (int i = 0; i < num_blocks_of_cols; i++) {
rhs_packed[i].store(false, std::memory_order_release);
if (rhs_packed != nullptr) {
for (int i = 0; i < num_blocks_of_cols; i++) {
rhs_packed[i].store(false, std::memory_order_release);
}
}
atomic_n->store(thread_count);

View File

@ -21,8 +21,9 @@ limitations under the License.
// TODO(silvasean): Put parts of this architecture description somewhere more
// prominent.
//
// The 4 different matrix types are:
// - Matrix<T>: This is a user-facing type on Ruy's external API boundary.
// The 4 main matrix types are:
// - Matrix<T>: This is a user-facing type on Ruy's external API boundary. It is
// also used internally.
// - DMatrix: This is a type-erased version of Matrix<T>. "D" = "dynamic".
// - PMatrix: This represents a packed matrix, which requires tracking kernel
// layout and row/column sums for quantization. It is type-erased.
@ -71,10 +72,20 @@ limitations under the License.
//
// To present another structured view of our various matrix types, here's a
// table:
//                  User matrices    Packed matrices
//                  Plain matrices   Packed matrices
//                +----------------------------------
//    Templated   | Matrix<T>        PackedMatrix<T>
//    Type-erased | DMatrix          PMatrix
//
//
// There is 1 additional matrix type not mentioned above, due to its low
// importance:
// - PrepackedMatrix: This is a user-facing version of PMatrix. It has the bare
// minimum of fields needed for representing the raw data and sums buffers of a
// packed matrix for the "advanced" explicit pre-packing API. This type plays no
// role in Ruy's internals and can generally be ignored. The only reason it
// exists is so that PMatrix is not exposed to users -- we prefer to keep the
// internal matrix types hidden, even from "advanced" users.
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_INTERNAL_MATRIX_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_INTERNAL_MATRIX_H_

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_MATRIX_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_MATRIX_H_
#include <cstddef>
#include <cstdint>
#include <type_traits>
@ -52,8 +53,18 @@ class ConstCheckingPtr final {
using element_type = T;
// Convenience methods. Most `set` calls go through these.
void operator=(T* ptr) { set(ptr); }
void operator=(const T* ptr) { set(ptr); }
ConstCheckingPtr& operator=(T* ptr) {
set(ptr);
return *this;
}
ConstCheckingPtr& operator=(const T* ptr) {
set(ptr);
return *this;
}
ConstCheckingPtr& operator=(std::nullptr_t) {
set(static_cast<T*>(nullptr));
return *this;
}
// Core accessors. These encapsulate the main logic:
// - for `set`, the constness of the argument determines whether internal
@ -117,6 +128,15 @@ inline void MakeSimpleLayout(int rows, int cols, Order order, Layout* layout) {
layout->stride = order == Order::kColMajor ? rows : cols;
}
// Opaque data structure representing a pre-packed matrix, as obtained from
// Ruy's advanced API.
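// The data and sums buffers are allocated by the caller, through the
// allocation callback passed to PrePackForMul; Ruy fills them during
// PrePackForMul, reads them during MulWithPrepacked, and never frees them.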
struct PrepackedMatrix {
void* data = nullptr;
std::size_t data_size = 0;
void* sums = nullptr;
std::size_t sums_size = 0;
};
template <typename StreamType, typename Scalar>
StreamType& operator<<(StreamType& stream, const Matrix<Scalar>& mat) {
for (int row = 0; row < mat.layout.rows; row++) {

View File

@ -13,6 +13,73 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// # What is "packing"?
//
// Before feeding data to the gemm kernels (the parts of Ruy that do lots
// of multiply-add operations), Ruy first performs a data transformation (which
// we call "packing") on the input matrices. This transformation has two main
// goals:
// - rearrange data into blocks that are a convenient size/layout for the gemm
// kernels to consume. This helps make the memory access pattern of the gemm
// kernel simpler and more contiguous, and puts the data in a layout most
// convenient for specific arithmetic instructions in the gemm kernel.
// - compute row/column sums needed for handling quantization with non-symmetric
// zero points.
//
// # Simplified algorithmic analysis of packing
//
// Packing is a relatively simple transformation which does a small constant
// amount of work on each element of an input matrix, and hence for an NxM
// matrix performs O(N*M) work. If N and M are of the same order, then this is
// O(N^2) work.
//
// A NxKxM matrix multiplication requires N*K*M multiply-accumulate operations.
// Note that if N, K, and M are all the same order, then the number of
// multiply-accumulate operations is O(N^3).
//
// Thus, the O(N^2) cost of packing is small compared to the O(N^3) work, in the
// case of all dimensions being roughly the same order.
//
// # Packing cost can be significant
//
// When matrix * matrix multiplications begin to look more like matrix * vector
// multiplications, packing cost can become significant. We sometimes call these
// cases "gemv-like".
//
// Continuing the algorithmic analysis above, if we consider a case where an
// NxKxM matrix multiplication has either N = O(1) or M = O(1), then the
// situation is different. In this case, the multiply-accumulate work is only
// quadratic, so the quadratic cost of packing can become significant.
//
// Another way to say this is that the cost of packing an input matrix (either
// the LHS or RHS) is amortized across the non-depth dimension of the opposite
// input matrix. Thus, when the LHS has very few rows or the RHS has very few
// columns, the cost of packing the opposite input matrix can become
// significant.
//
// As a rough rule of thumb, the cost of packing starts to become significant
// when either N or M is below 32 (and other dimensions are hundreds), with very
// significant packing costs at 8 or below. This varies by data type, Path, and
// tuning, so these numbers are only rough guides.
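//
// As a hypothetical illustration of that rule of thumb: for an 8x1024x1024
// multiplication (N=8, K=M=1024), the kernel performs 8*1024*1024 ~= 8.4e6
// multiply-accumulates, while packing the RHS touches all 1024*1024 ~= 1.0e6
// RHS elements, which is no longer a negligible fraction of the total work.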
//
// One practical use case that is affected by this is inference of
// fully connected neural network layers with a low batch size. The weight
// matrix (which is a constant for inference) is the one affected by significant
// packing cost.
//
// Ruy provides an API in ruy_advanced.h for advanced users to pre-pack
// input matrices that are affected by significant packing costs.
//
// # Implementation notes
//
// Ruy's packing routines always operate on a range of columns and can be
// applied to either the LHS or RHS. This is possible because Ruy internally
// implements a TrMul, so the accumulation along depth is done along columns of
// both the LHS and RHS (whereas for a normal Mul the accumulation along depth
// for the LHS is along rows). As another example, we are always computing
// column sums for quantization (and never row sums, since the LHS is
// transposed).
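//
// Concretely (a sketch of the above): for dst = LHS * RHS, Ruy transposes the
// LHS and performs a TrMul, so the packed LHS is LHS^T and the depth dimension
// runs down the columns of both packed operands. The column sums computed
// while packing LHS^T are the row sums of the original LHS, which is exactly
// what zero-point correction needs.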
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_H_

View File

@ -0,0 +1,107 @@
/* Copyright 2019 Google LLC. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Implementation of low-level pre-packing API.
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PREPACK_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PREPACK_H_
#include <functional>
#include "tensorflow/lite/experimental/ruy/context.h"
#include "tensorflow/lite/experimental/ruy/dispatch.h"
#include "tensorflow/lite/experimental/ruy/matrix.h"
#include "tensorflow/lite/experimental/ruy/path.h"
#include "tensorflow/lite/experimental/ruy/spec.h"
namespace ruy {
template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
typename DstScalar, typename Spec>
void PrePackForMulInternal(const Matrix<LhsScalar>& lhs,
const Matrix<RhsScalar>& rhs, const Spec& spec,
Context* context, Matrix<DstScalar>* dst,
PrepackedMatrix* prepacked_lhs,
PrepackedMatrix* prepacked_rhs,
std::function<void*(std::size_t)> alloc_fn) {
gemmlowp::ScopedProfilingLabel label("PrePackForMul");
Path the_path = context->GetPathToTake<CompiledPaths>();
RUY_CHECK(the_path != Path::kReference);
constexpr Path TrMulCompiledPaths = CompiledPaths & ~Path::kReference;
Matrix<LhsScalar> transposed_lhs(lhs);
Transpose(&transposed_lhs);
TrMulParams params;
CreateTrMulParams<TrMulCompiledPaths>(transposed_lhs, rhs, spec, context, dst,
the_path, &params);
Tuning tuning = GetTuning(context);
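// For each side the caller asked to pre-pack: query the required buffer
// sizes, obtain storage through the caller-provided alloc_fn, point the
// type-erased packed matrix at that storage, and run the packing routine
// over all of its columns (Ruy packs along columns on both sides; see
// pack.h).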
if (prepacked_lhs) {
prepacked_lhs->data_size = DataSize(params.packed_lhs);
prepacked_lhs->sums_size = SumsSize(params.packed_lhs);
prepacked_lhs->data = alloc_fn(prepacked_lhs->data_size);
prepacked_lhs->sums = alloc_fn(prepacked_lhs->sums_size);
params.packed_lhs.data = prepacked_lhs->data;
params.packed_lhs.sums = prepacked_lhs->sums;
params.LhsRunPack(tuning, 0, params.packed_lhs.layout.cols);
}
if (prepacked_rhs) {
prepacked_rhs->data_size = DataSize(params.packed_rhs);
prepacked_rhs->sums_size = SumsSize(params.packed_rhs);
prepacked_rhs->data = alloc_fn(prepacked_rhs->data_size);
prepacked_rhs->sums = alloc_fn(prepacked_rhs->sums_size);
params.packed_rhs.data = prepacked_rhs->data;
params.packed_rhs.sums = prepacked_rhs->sums;
params.RhsRunPack(tuning, 0, params.packed_rhs.layout.cols);
}
}
template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
typename DstScalar, typename Spec>
void MulWithPrepackedInternal(const Matrix<LhsScalar>& lhs,
const Matrix<RhsScalar>& rhs, const Spec& spec,
Context* context, Matrix<DstScalar>* dst,
PrepackedMatrix* prepacked_lhs,
PrepackedMatrix* prepacked_rhs) {
gemmlowp::ScopedProfilingLabel label("MulWithPrepacked");
EnforceLayoutSupport<Spec>(lhs.layout, rhs.layout, dst->layout);
EnforceZeroPointSupport<Spec>(lhs.zero_point, rhs.zero_point,
dst->zero_point);
Path the_path = context->GetPathToTake<CompiledPaths>();
RUY_CHECK(the_path != Path::kReference);
constexpr Path TrMulCompiledPaths = CompiledPaths & ~Path::kReference;
Matrix<LhsScalar> transposed_lhs(lhs);
Transpose(&transposed_lhs);
TrMulParams params;
CreateTrMulParams<TrMulCompiledPaths>(transposed_lhs, rhs, spec, context, dst,
the_path, &params);
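// For each pre-packed side, point the packed matrix at the caller-provided
// buffers and mark it as pre-packed, so that TrMul skips allocating and
// packing that side.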
if (prepacked_lhs) {
params.packed_lhs.data = prepacked_lhs->data;
params.packed_lhs.sums = prepacked_lhs->sums;
params.lhs_is_prepacked = true;
}
if (prepacked_rhs) {
params.packed_rhs.data = prepacked_rhs->data;
params.packed_rhs.sums = prepacked_rhs->sums;
params.rhs_is_prepacked = true;
}
TrMul(&params, context);
}
} // namespace ruy
#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PREPACK_H_

View File

@ -0,0 +1,60 @@
/* Copyright 2019 Google LLC. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_ADVANCED_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_ADVANCED_H_
#include "tensorflow/lite/experimental/ruy/prepack.h"
namespace ruy {
// Low-level, explicit pre-packing API.
//
// The cost of packing an input matrix (either the LHS or RHS) is amortized
// across the non-depth dimension of the opposite input matrix. Thus, when the
// LHS has very few rows or the RHS has very few columns, the cost of packing
// the opposite input matrix can become significant. See pack.h for further
// information on packing.
//
// This file provides an API allowing a user to explicitly pack a matrix and
// reuse the pre-packed matrix, avoiding that cost.
//
// See example_advanced.cc for example usage.
template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
typename DstScalar, typename Spec>
void PrePackForMul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
const Spec& spec, Context* context, Matrix<DstScalar>* dst,
PrepackedMatrix* prepacked_lhs,
PrepackedMatrix* prepacked_rhs,
std::function<void*(std::size_t)> alloc_fn) {
PrePackForMulInternal<CompiledPaths>(lhs, rhs, spec, context, dst,
prepacked_lhs, prepacked_rhs, alloc_fn);
}
template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
typename DstScalar, typename Spec>
void MulWithPrepacked(const Matrix<LhsScalar>& lhs,
const Matrix<RhsScalar>& rhs, const Spec& spec,
Context* context, Matrix<DstScalar>* dst,
PrepackedMatrix* prepacked_lhs,
PrepackedMatrix* prepacked_rhs) {
MulWithPrepackedInternal<CompiledPaths>(lhs, rhs, spec, context, dst,
prepacked_lhs, prepacked_rhs);
}
} // namespace ruy
#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_ADVANCED_H_

View File

@ -31,6 +31,7 @@ limitations under the License.
#include <gtest/gtest.h>
#include "tensorflow/lite/experimental/ruy/pmu.h"
#include "tensorflow/lite/experimental/ruy/ruy.h"
#include "tensorflow/lite/experimental/ruy/ruy_advanced.h"
#include "tensorflow/lite/experimental/ruy/time.h"
#ifdef RUY_TEST_EXTERNAL_PATHS
@ -38,8 +39,8 @@ limitations under the License.
#define EIGEN_USE_CUSTOM_THREAD_POOL
#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "third_party/lapack/blas.h"
#include "public/gemmlowp.h"
#include "third_party/lapack/blas.h"
#endif
#ifdef GEMMLOWP_PROFILING
@ -318,6 +319,9 @@ void MakeLayout(int rows, int cols, Order order, LayoutStyle layout_style,
template <typename Scalar>
struct StorageMatrix {
StorageMatrix() = default;
StorageMatrix(const StorageMatrix&) = delete;
void operator=(const StorageMatrix&) = delete;
std::vector<Scalar> data;
Matrix<Scalar> matrix;
};
@ -350,6 +354,8 @@ void MakeRandom(int rows, int cols, Order order, Scalar zero_point,
template <typename Scalar>
struct TestResult {
void operator=(const TestResult&) = delete;
void operator=(const TestResult&&) = delete;
StorageMatrix<Scalar> storage_matrix;
Path path = Path::kNone;
Tuning tuning = Tuning::kAuto;
@ -361,6 +367,14 @@ struct TestResult {
float mispred_rate;
float frontend_stall_rate;
float backend_stall_rate;
// Per-path data for pre-packing.
// This is not used by external paths or by Path::kReference.
Allocator allocator;
PrepackedMatrix prepacked_lhs;
PrepackedMatrix prepacked_rhs;
bool use_prepacked_lhs = false;
bool use_prepacked_rhs = false;
};
template <typename Scalar>
@ -389,6 +403,7 @@ struct TestSet final {
using AccumScalar = typename SpecType::AccumScalar;
using DstScalar = typename SpecType::DstScalar;
using Spec = SpecType;
using TestResultType = TestResult<DstScalar>;
void Run() {
MakeZeroPoints();
@ -396,6 +411,7 @@ struct TestSet final {
MakeSpec();
MakeOtherParams();
MakeResultPaths();
MakePrepackedMatrices();
Eval();
Verify();
}
@ -405,13 +421,16 @@ struct TestSet final {
void MakeLhsRhs();
void MakeSpec();
void MakeResultPaths();
void MakePrepackedMatrices();
void MakeOtherParams();
void EvalAndVerify();
void Eval();
void Verify();
void EvalResult(TestResult<DstScalar>* result);
void Benchmark(TestResult<DstScalar>* result);
void EvalResult(TestResultType* result);
void EvalRuy(TestResultType* result);
void DoMul(TestResultType* result);
void Benchmark(TestResultType* result);
void VerifyTestResults() const;
void VerifyNonTrivial() const;
@ -423,6 +442,7 @@ struct TestSet final {
kHasSpec,
kHasOtherParams,
kHasResultPaths,
kHasPrepackedMatrices,
kEvaluated,
kFinal
};
@ -455,7 +475,7 @@ struct TestSet final {
StorageMatrix<RhsScalar> rhs;
Spec spec;
std::vector<AccumScalar> bias_data;
std::vector<TestResult<DstScalar>> results;
std::vector<std::unique_ptr<TestResultType>> results;
std::vector<Path> paths;
std::vector<ExternalPath> external_paths;
@ -463,6 +483,8 @@ struct TestSet final {
bool benchmark = false;
bool perchannel = false;
int max_num_threads = 0;
bool benchmark_prepack_lhs = false;
bool benchmark_prepack_rhs = false;
};
Context& GlobalContext() {
@ -479,13 +501,40 @@ Context& GlobalContext() {
#endif
#endif // defined(__has_feature)
template <typename LhsScalar, typename RhsScalar, typename DstScalar,
typename Spec>
void EvalRuy(Path path, Tuning tuning, const Matrix<LhsScalar>& lhs,
const Matrix<RhsScalar>& rhs, const Spec& spec,
Matrix<DstScalar>* dst, ExpectedOutcome expected_outcome,
bool benchmark, int max_num_threads) {
GlobalContext().explicit_tuning = tuning;
template <typename LhsScalar, typename RhsScalar, typename SpecType>
void TestSet<LhsScalar, RhsScalar, SpecType>::DoMul(TestResultType* result) {
Context* context = &GlobalContext();
if (!result->use_prepacked_lhs && !result->use_prepacked_rhs) {
Mul<kAllPaths>(lhs.matrix, rhs.matrix, spec, context,
&result->storage_matrix.matrix);
return;
}
// If we prepacked an input matrix, null out its data pointer to check
// that we don't access any data through it.
Matrix<LhsScalar> null_data_lhs = lhs.matrix;
Matrix<RhsScalar> null_data_rhs = rhs.matrix;
if (result->use_prepacked_lhs) {
null_data_lhs.data = nullptr;
}
if (result->use_prepacked_rhs) {
null_data_rhs.data = nullptr;
}
// Do the multiplication with pre-packed matrices.
PrepackedMatrix* prepacked_lhs_ptr =
result->use_prepacked_lhs ? &result->prepacked_lhs : nullptr;
PrepackedMatrix* prepacked_rhs_ptr =
result->use_prepacked_rhs ? &result->prepacked_rhs : nullptr;
MulWithPrepacked<kAllPaths>(null_data_lhs, null_data_rhs, spec, context,
&result->storage_matrix.matrix, prepacked_lhs_ptr,
prepacked_rhs_ptr);
}
template <typename LhsScalar, typename RhsScalar, typename SpecType>
void TestSet<LhsScalar, RhsScalar, SpecType>::EvalRuy(TestResultType* result) {
GlobalContext().explicit_tuning = result->tuning;
if (max_num_threads) {
GlobalContext().max_num_threads = max_num_threads;
} else if (benchmark) {
@ -493,15 +542,15 @@ void EvalRuy(Path path, Tuning tuning, const Matrix<LhsScalar>& lhs,
} else {
GlobalContext().max_num_threads = 1 + global_random_engine()() % 8;
}
GlobalContext().SetRuntimeEnabledPaths(path);
GlobalContext().SetRuntimeEnabledPaths(result->path);
if (expected_outcome == ExpectedOutcome::kSuccess) {
Mul<kAllPaths>(lhs, rhs, spec, &GlobalContext(), dst);
RUY_CHECK(GlobalContext().last_taken_path == path);
DoMul(result);
RUY_CHECK(GlobalContext().last_taken_path == result->path);
} else if (expected_outcome == ExpectedOutcome::kDeath) {
// TODO(benoitjacob) TSan and ASan seem to be breaking ASSERT_DEATH.
// Report a bug?
#if (!defined NDEBUG) && (!defined RUY_ASAN) && (!defined RUY_TSAN)
ASSERT_DEATH(Mul<kAllPaths>(lhs, rhs, spec, &GlobalContext(), dst), "");
ASSERT_DEATH(DoMul(result), "");
#endif
} else {
RUY_CHECK(false);
@ -1194,9 +1243,9 @@ struct ErrorAnalysis {
template <typename TestSetType>
void AnalyzeTestError(const TestSetType& test_set, int first_bad_result_index,
ErrorAnalysis* error_analysis) {
const auto& good_matrix = test_set.results[0].storage_matrix.matrix;
const auto& good_matrix = test_set.results[0]->storage_matrix.matrix;
const auto& bad_matrix =
test_set.results[first_bad_result_index].storage_matrix.matrix;
test_set.results[first_bad_result_index]->storage_matrix.matrix;
GetMatrixStats(good_matrix, &error_analysis->stats_good);
GetMatrixStats(bad_matrix, &error_analysis->stats_bad);
bool found_first_error = false;
@ -1498,6 +1547,55 @@ std::vector<Tuning> EnumerateTuningsForPath(Path path, bool benchmark) {
return {Tuning::kAuto};
}
template <typename LhsScalar, typename RhsScalar, typename SpecType>
void TestSet<LhsScalar, RhsScalar, SpecType>::MakePrepackedMatrices() {
RUY_CHECK(life_stage == LifeStage::kHasResultPaths);
// Prepacked matrices are Path-dependent, so create them for each test result.
for (auto& result : results) {
// If this result uses an external path, then skip this entirely.
if (result->path == Path::kNone) {
continue;
}
// Pre-packing doesn't make sense for Path::kReference.
// TODO(silvasean): Make Path::kReference an ExternalPath?
if (result->path == Path::kReference) {
continue;
}
// Determine whether we should create/use prepacked matrices.
if (benchmark) {
// For benchmarking, do as requested.
result->use_prepacked_lhs = benchmark_prepack_lhs;
result->use_prepacked_rhs = benchmark_prepack_rhs;
} else {
// When testing, randomly pre-pack sometimes. But don't do it too often.
result->use_prepacked_lhs = (global_random_engine()() & 7) == 0;
result->use_prepacked_rhs = (global_random_engine()() & 7) == 0;
}
// Create the pre-packed matrices.
PrepackedMatrix* prepacked_lhs_ptr =
result->use_prepacked_lhs ? &result->prepacked_lhs : nullptr;
PrepackedMatrix* prepacked_rhs_ptr =
result->use_prepacked_rhs ? &result->prepacked_rhs : nullptr;
auto alloc_fn = [&result](std::size_t num_bytes) {
return result->allocator.AllocateBytes(num_bytes);
};
// Use a dst with a null data pointer to check that the pre-packing
// invocation doesn't write into it.
Matrix<DstScalar> null_data_dst = result->storage_matrix.matrix;
null_data_dst.data = nullptr;
GlobalContext().SetRuntimeEnabledPaths(result->path);
PrePackForMul<kAllPaths>(lhs.matrix, rhs.matrix, spec, &GlobalContext(),
&null_data_dst, prepacked_lhs_ptr,
prepacked_rhs_ptr, alloc_fn);
RUY_CHECK(GlobalContext().last_taken_path == result->path);
}
life_stage = LifeStage::kHasPrepackedMatrices;
}
template <typename LhsScalar, typename RhsScalar, typename SpecType>
void TestSet<LhsScalar, RhsScalar, SpecType>::MakeResultPaths() {
RUY_CHECK(life_stage == LifeStage::kHasOtherParams);
@ -1515,6 +1613,7 @@ void TestSet<LhsScalar, RhsScalar, SpecType>::MakeResultPaths() {
// to allow specifying e.g. ffff to mean 'all paths' regardless of whether all
// those bits exist as actual paths.
paths_bitfield = paths_bitfield & kAllPaths;
RUY_CHECK(paths_bitfield != Path::kNone);
paths = PathsBitfieldAsVector(paths_bitfield);
#ifdef RUY_TEST_EXTERNAL_PATHS
@ -1554,8 +1653,8 @@ void TestSet<LhsScalar, RhsScalar, SpecType>::MakeResultPaths() {
for (Path path : paths) {
for (Tuning tuning : EnumerateTuningsForPath(path, benchmark)) {
results.emplace_back();
TestResult<DstScalar>& result = results.back();
results.emplace_back(new TestResultType);
TestResultType& result = *results.back();
result.path = path;
result.tuning = tuning;
MakeRandom(rows, cols, dst_order, dst_zero_point, layout_style,
@ -1564,8 +1663,8 @@ void TestSet<LhsScalar, RhsScalar, SpecType>::MakeResultPaths() {
}
for (ExternalPath external_path : external_paths) {
results.emplace_back();
TestResult<DstScalar>& result = results.back();
results.emplace_back(new TestResultType);
TestResultType& result = *results.back();
result.external_path = external_path;
MakeRandom(rows, cols, dst_order, dst_zero_point, layout_style,
RandomRange::kGeneral, &result.storage_matrix);
@ -1580,9 +1679,7 @@ void TestSet<LhsScalar, RhsScalar, SpecType>::EvalResult(
RUY_CHECK(result->path != Path::kNone ||
result->external_path != ExternalPath::kNone);
if (result->path != Path::kNone) {
EvalRuy(result->path, result->tuning, lhs.matrix, rhs.matrix, spec,
&result->storage_matrix.matrix, expected_outcome, benchmark,
max_num_threads);
EvalRuy(result);
} else {
#ifdef RUY_TEST_EXTERNAL_PATHS
using TestSetType = TestSet<LhsScalar, RhsScalar, SpecType>;
@ -1645,16 +1742,35 @@ int StorageSize(const Matrix<Scalar>& matrix) {
return sizeof(Scalar) * FlatSize(matrix.layout);
}
template <typename Scalar>
void MakeColdData(int num_copies, const Matrix<Scalar>& matrix,
std::vector<Scalar>* cold_data) {
const int flatsize = FlatSize(matrix.layout);
cold_data->resize(num_copies * flatsize);
for (int i = 0; i < num_copies; i++) {
memcpy(cold_data->data() + i * flatsize, matrix.data.get(),
sizeof(Scalar) * flatsize);
// Helper that replicates a buffer and gives out pointers to the replicas.
// This is useful when one wants to traverse data so that it is cold in cache.
// By having a sufficiently large value of num_repeats, one can ensure that the
// working set covered by the replicas is greater than the cache size.
template <typename T>
class RepeatedBuffer {
public:
RepeatedBuffer() = default;
void Init(const T* elems, std::size_t num_elems, int num_repeats) {
buffers_.clear();
allocator_.FreeAll();
for (int i = 0; i < num_repeats; i++) {
T* p;
allocator_.Allocate(num_elems, &p);
memcpy(p, elems, num_elems * sizeof(T));
buffers_.push_back(p);
}
}
}
T* Next() {
T* ret = buffers_[current_];
current_ = (current_ + 1) % buffers_.size();
return ret;
}
private:
Allocator allocator_;
std::vector<T*> buffers_;
int current_ = 0;
};
template <typename LhsScalar, typename RhsScalar, typename SpecType>
void TestSet<LhsScalar, RhsScalar, SpecType>::Benchmark(
@ -1662,14 +1778,20 @@ void TestSet<LhsScalar, RhsScalar, SpecType>::Benchmark(
using DstScalar = typename SpecType::DstScalar;
const bool cold = getenv("RUY_BENCHMARK_COLD");
const LhsScalar* orig_lhs_data = nullptr;
const RhsScalar* orig_rhs_data = nullptr;
DstScalar* orig_dst_data = nullptr;
std::vector<LhsScalar> cold_lhs_data;
std::vector<RhsScalar> cold_rhs_data;
std::vector<DstScalar> cold_dst_data;
LhsScalar* orig_lhs_data = lhs.matrix.data.get();
RhsScalar* orig_rhs_data = rhs.matrix.data.get();
DstScalar* orig_dst_data = result->storage_matrix.matrix.data.get();
void* orig_prepacked_lhs_data = result->prepacked_lhs.data;
void* orig_prepacked_rhs_data = result->prepacked_rhs.data;
int num_matmul_sets = 0;
RepeatedBuffer<LhsScalar> cold_lhs;
RepeatedBuffer<RhsScalar> cold_rhs;
RepeatedBuffer<DstScalar> cold_dst;
RepeatedBuffer<char> cold_prepacked_lhs;
RepeatedBuffer<char> cold_prepacked_rhs;
if (cold) {
const int kWorkingSetSize = 100 << 20;
const int each_matmul_set_size = StorageSize(lhs.matrix) +
@ -1678,14 +1800,21 @@ void TestSet<LhsScalar, RhsScalar, SpecType>::Benchmark(
num_matmul_sets =
(kWorkingSetSize + each_matmul_set_size - 1) / each_matmul_set_size;
MakeColdData(num_matmul_sets, lhs.matrix, &cold_lhs_data);
MakeColdData(num_matmul_sets, rhs.matrix, &cold_rhs_data);
MakeColdData(num_matmul_sets, result->storage_matrix.matrix,
&cold_dst_data);
orig_lhs_data = lhs.matrix.data.get();
orig_rhs_data = rhs.matrix.data.get();
orig_dst_data = result->storage_matrix.matrix.data.get();
cold_lhs.Init(lhs.matrix.data.get(), FlatSize(lhs.matrix.layout),
num_matmul_sets);
cold_rhs.Init(rhs.matrix.data.get(), FlatSize(rhs.matrix.layout),
num_matmul_sets);
cold_dst.Init(result->storage_matrix.matrix.data.get(),
FlatSize(result->storage_matrix.matrix.layout),
num_matmul_sets);
if (benchmark_prepack_lhs) {
cold_prepacked_lhs.Init(static_cast<char*>(result->prepacked_lhs.data),
result->prepacked_lhs.data_size, num_matmul_sets);
}
if (benchmark_prepack_rhs) {
cold_prepacked_rhs.Init(static_cast<char*>(result->prepacked_rhs.data),
result->prepacked_rhs.data_size, num_matmul_sets);
}
}
int kRepeats = 4;
const double kBenchmarkMinSecs = 0.5;
@ -1704,7 +1833,6 @@ void TestSet<LhsScalar, RhsScalar, SpecType>::Benchmark(
#endif
double latency = std::numeric_limits<double>::infinity();
int data_index = 0;
const bool record_pmu = getenv("RUY_BENCHMARK_PMU");
for (int repeat = 0; repeat < kRepeats; repeat++) {
PmuEvents pmu_events;
@ -1718,16 +1846,14 @@ void TestSet<LhsScalar, RhsScalar, SpecType>::Benchmark(
while (ToSeconds(t - time_start) < kBenchmarkMinSecs) {
for (int i = 0; i < iters_at_a_time; i++) {
if (cold) {
lhs.matrix.data =
cold_lhs_data.data() + data_index * FlatSize(lhs.matrix.layout);
rhs.matrix.data =
cold_rhs_data.data() + data_index * FlatSize(rhs.matrix.layout);
result->storage_matrix.matrix.data =
cold_dst_data.data() +
data_index * FlatSize(result->storage_matrix.matrix.layout);
data_index++;
if (data_index == num_matmul_sets) {
data_index = 0;
lhs.matrix.data = cold_lhs.Next();
rhs.matrix.data = cold_rhs.Next();
result->storage_matrix.matrix.data = cold_dst.Next();
if (benchmark_prepack_lhs) {
result->prepacked_lhs.data = cold_prepacked_lhs.Next();
}
if (benchmark_prepack_rhs) {
result->prepacked_rhs.data = cold_prepacked_rhs.Next();
}
}
EvalResult(result);
@ -1763,19 +1889,21 @@ void TestSet<LhsScalar, RhsScalar, SpecType>::Benchmark(
lhs.matrix.data = orig_lhs_data;
rhs.matrix.data = orig_rhs_data;
memcpy(orig_dst_data, result->storage_matrix.matrix.data.get(),
sizeof(DstScalar) * FlatSize(result->storage_matrix.matrix.layout));
StorageSize(result->storage_matrix.matrix));
result->storage_matrix.matrix.data = orig_dst_data;
result->prepacked_lhs.data = orig_prepacked_lhs_data;
result->prepacked_rhs.data = orig_prepacked_rhs_data;
}
}
template <typename LhsScalar, typename RhsScalar, typename SpecType>
void TestSet<LhsScalar, RhsScalar, SpecType>::Eval() {
RUY_CHECK(life_stage == LifeStage::kHasResultPaths);
RUY_CHECK(life_stage == LifeStage::kHasPrepackedMatrices);
for (auto& result : results) {
if (benchmark) {
Benchmark(&result);
Benchmark(result.get());
} else {
EvalResult(&result);
EvalResult(result.get());
}
}
life_stage = LifeStage::kEvaluated;
@ -1803,16 +1931,16 @@ template <typename LhsScalar, typename RhsScalar, typename SpecType>
void TestSet<LhsScalar, RhsScalar, SpecType>::VerifyTestResults() const {
const int depth = lhs.matrix.layout.cols;
for (int i = 0; i < results.size() - 1; i++) {
if (!Agree(results[i], results[i + 1], depth)) {
if (!Agree(*results[i], *results[i + 1], depth)) {
std::string paths_in_agreement;
paths_in_agreement.append(PathName(results[0]));
paths_in_agreement.append(PathName(*results[0]));
for (int j = 1; j <= i; j++) {
paths_in_agreement.append(", ");
paths_in_agreement.append(PathName(results[j]));
paths_in_agreement.append(PathName(*results[j]));
}
ErrorAnalysis error_analysis;
AnalyzeTestError(*this, i + 1, &error_analysis);
std::cerr << "Error: path (" << PathName(results[i + 1])
std::cerr << "Error: path (" << PathName(*results[i + 1])
<< ") disagrees with the other paths (" << paths_in_agreement
<< "), which agree with each other." << std::endl;
std::cerr << "Shape: rows = " << rows << ", cols = " << cols
@ -1841,12 +1969,12 @@ void TestSet<LhsScalar, RhsScalar, SpecType>::VerifyTestResults() const {
std::cerr << "Bad value : " << error_analysis.first_error_bad_value
<< std::endl;
std::cerr << "Region of Good result matrix around first error:\n\n"
<< DumpRegion(results[0].storage_matrix.matrix,
<< DumpRegion(results[0]->storage_matrix.matrix,
error_analysis.row_of_first_error,
error_analysis.col_of_first_error)
<< std::endl;
std::cerr << "Region of Bad result matrix around first error:\n\n"
<< DumpRegion(results[i + 1].storage_matrix.matrix,
<< DumpRegion(results[i + 1]->storage_matrix.matrix,
error_analysis.row_of_first_error,
error_analysis.col_of_first_error)
<< std::endl;
@ -1860,12 +1988,12 @@ void TestSet<LhsScalar, RhsScalar, SpecType>::VerifyNonTrivial() const {
if (getenv("QUICK_BENCHMARK")) {
return;
}
if (results.front().path != Path::kReference) {
if (results.front()->path != Path::kReference) {
return;
}
Context context;
context.SetRuntimeEnabledPaths(Path::kReference);
const auto& dst_storage = results.front().storage_matrix;
const auto& dst_storage = results.front()->storage_matrix;
const Matrix<DstScalar>& dst = dst_storage.matrix;
Matrix<DstScalar> unclamped_dst;
unclamped_dst.layout = dst.layout;