Add cacheable flag to Ruy Matrix so that the caller "opts in" to cache behavior on a per-call basis

PiperOrigin-RevId: 286241942
Change-Id: Ie1320c17f6a50468a03dad2664a1c8645e09f3ce
T.J. Alumbaugh 2019-12-18 12:21:42 -08:00 committed by TensorFlower Gardener
parent ddb75d9d1e
commit 112e2b3fc7
3 changed files with 44 additions and 9 deletions
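
For reference, a minimal caller-side sketch of the new opt-in flag, assembled from the TestCacheOnCacheable test added below. The include path and the GEMV-shaped problem are assumptions for illustration, not prescribed by this commit:

#include "tensorflow/lite/experimental/ruy/ruy.h"  // assumed include path

void CachedGemvExample() {
  // Enable LHS caching for GEMV-shaped multiplications on this context.
  ruy::Context context;
  context.cache_policy = ruy::kCacheLHSOnGemV;

  const float lhs_data[] = {1, 2, 3, 4};  // 2x2 weights, reused across calls
  const float rhs_data[] = {1, 2};        // 2x1 activations
  float dst_data[2];

  ruy::Matrix<float> lhs;
  ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout);
  lhs.data = lhs_data;
  lhs.cacheable = true;  // opt in: the prepacked LHS may be cached and reused

  ruy::Matrix<float> rhs;
  ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &rhs.layout);
  rhs.data = rhs_data;

  ruy::Matrix<float> dst;
  ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &dst.layout);
  dst.data = dst_data;

  ruy::BasicSpec<float, float> spec;
  // Without lhs.cacheable = true this Mul computes the same result,
  // but nothing is kept in the context's prepacked cache.
  ruy::Mul<ruy::kAllPaths>(lhs, rhs, spec, &context, &dst);
}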

View File

@@ -382,13 +382,15 @@ struct CompileTimeEnabledReferenceMul</*ReferenceMulIsEnabled=*/false> {
   }
 };
 
-inline void HandlePrepackedCaching(TrMulParams* params, Context* context) {
+inline void HandlePrepackedCaching(TrMulParams* params,
+                                   const SidePair<bool>& cacheable,
+                                   Context* context) {
   if (context->cache_policy == CachePolicy::kNoCache) {
     return;
   }
 
   if (context->cache_policy == CachePolicy::kCacheLHSOnGemV) {
-    if (params->dst.layout.cols != 1) {
+    if (!cacheable[Side::kLhs] || params->dst.layout.cols != 1) {
       return;
     }
     PrepackedCache* prepacked_cache = context->GetPrepackedCache();
@@ -465,7 +467,8 @@ void DispatchMul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
   TrMulParams params;
   CreateTrMulParams<TrMulCompiledPaths>(transposed_lhs, rhs, spec, context, dst,
                                         the_path, &params);
-  HandlePrepackedCaching(&params, context);
+  SidePair<bool> cacheable(lhs.cacheable, rhs.cacheable);
+  HandlePrepackedCaching(&params, cacheable, context);
   TrMul(&params, context);
 }
 

View File

@@ -108,6 +108,7 @@ template <typename Scalar>
 struct Matrix final {
   Matrix& operator=(const Matrix& other) {
     data = other.data;
+    cacheable = other.cacheable;
     layout = other.layout;
     zero_point = other.zero_point;
     return *this;
@@ -120,6 +121,10 @@ struct Matrix final {
   // The zero_point, i.e. which Scalar value is to be interpreted as zero.
   // When Scalar is floating-point, this must be 0.
   Scalar zero_point = 0;
+  // Clients of Ruy must set this flag to enable any caching behavior. Doesn't
+  // impact numerical results, but caching can impact observable metrics like
+  // latency, memory usage, power, etc.
+  bool cacheable = false;
 };
 
 inline void MakeSimpleLayout(int rows, int cols, Order order, Layout* layout) {

View File

@@ -25,7 +25,6 @@ namespace ruy {
 namespace {
 
 TEST(PrepackedCacheTest, TestCacheEjection) {
-  ruy::Context* context = new ruy::Context();
   // Create the cache.
   PrepackedCache prepacked_cache(32);
   // Allocate the prepacked matrix.
@@ -54,11 +53,9 @@ TEST(PrepackedCacheTest, TestCacheEjection) {
   // The cache size was exceeded by inserting mat2. Ensure that mat1 was
   // ejected.
   EXPECT_EQ(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend());
-  delete context;
 }
 
 TEST(PrepackedCacheTest, TestCacheBasic) {
-  ruy::Context* context = new ruy::Context();
   // Create the cache.
   PrepackedCache prepacked_cache(48);
   // Allocate the prepacked matrix.
@@ -83,11 +80,9 @@ TEST(PrepackedCacheTest, TestCacheBasic) {
   // The cache size was not exceeded by inserting mat2. Ensure that mat1 was not
   // ejected.
   EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend());
-  delete context;
 }
 
 TEST(PrepackedCacheTest, TestCacheEjection2) {
-  ruy::Context* context = new ruy::Context();
   // Create the cache.
   PrepackedCache prepacked_cache(73);
   // Allocate the prepacked matrix 1.
@@ -137,7 +132,39 @@ TEST(PrepackedCacheTest, TestCacheEjection2) {
   EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key3), prepacked_cache.cend());
   EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend());
   EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key4), prepacked_cache.cend());
-  delete context;
 }
 
+TEST(PrepackedCacheTest, TestCacheOnCacheable) {
+  // Create context and set the cache policy
+  ruy::Context context;
+  context.cache_policy = ruy::kCacheLHSOnGemV;
+  PrepackedCache* cache = context.GetPrepackedCache();
+  EXPECT_EQ(cache->TotalSize(), 0);
+
+  const float lhs_data[] = {1, 2, 3, 4};
+  const float rhs_data[] = {1, 2};
+  float dst_data[4];
+
+  ruy::Matrix<float> lhs;
+  ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout);
+  lhs.data = lhs_data;
+  ruy::Matrix<float> rhs;
+  ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &rhs.layout);
+  rhs.data = rhs_data;
+  ruy::Matrix<float> dst;
+  ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &dst.layout);
+  dst.data = dst_data;
+
+  ruy::BasicSpec<float, float> spec;
+  // Perform the multiplication and confirm no caching occurred.
+  ruy::Mul<ruy::kAllPaths>(lhs, rhs, spec, &context, &dst);
+  EXPECT_EQ(cache->TotalSize(), 0);
+
+  // Set cacheable for the LHS, repeat the multiplication, and see
+  // that caching did occur.
+  lhs.cacheable = true;
+  ruy::Mul<ruy::kAllPaths>(lhs, rhs, spec, &context, &dst);
+  EXPECT_NE(cache->TotalSize(), 0);
+}
+
 }  // namespace