From 112e2b3fc70b46571d1eb83d16cd71d9b1c015f8 Mon Sep 17 00:00:00 2001 From: "T.J. Alumbaugh" Date: Wed, 18 Dec 2019 12:21:42 -0800 Subject: [PATCH] Add `cacheable` flag to Ruy Matrix so that caller "opts in" to cache behavior on a per-call basis PiperOrigin-RevId: 286241942 Change-Id: Ie1320c17f6a50468a03dad2664a1c8645e09f3ce --- tensorflow/lite/experimental/ruy/dispatch.h | 9 +++-- tensorflow/lite/experimental/ruy/matrix.h | 5 +++ .../experimental/ruy/prepacked_cache_test.cc | 39 ++++++++++++++++--- 3 files changed, 44 insertions(+), 9 deletions(-) diff --git a/tensorflow/lite/experimental/ruy/dispatch.h b/tensorflow/lite/experimental/ruy/dispatch.h index 0aaaccafb2e..de5f3c3e9b4 100644 --- a/tensorflow/lite/experimental/ruy/dispatch.h +++ b/tensorflow/lite/experimental/ruy/dispatch.h @@ -382,13 +382,15 @@ struct CompileTimeEnabledReferenceMul { } }; -inline void HandlePrepackedCaching(TrMulParams* params, Context* context) { +inline void HandlePrepackedCaching(TrMulParams* params, + const SidePair& cacheable, + Context* context) { if (context->cache_policy == CachePolicy::kNoCache) { return; } if (context->cache_policy == CachePolicy::kCacheLHSOnGemV) { - if (params->dst.layout.cols != 1) { + if (!cacheable[Side::kLhs] || params->dst.layout.cols != 1) { return; } PrepackedCache* prepacked_cache = context->GetPrepackedCache(); @@ -465,7 +467,8 @@ void DispatchMul(const Matrix& lhs, const Matrix& rhs, TrMulParams params; CreateTrMulParams(transposed_lhs, rhs, spec, context, dst, the_path, ¶ms); - HandlePrepackedCaching(¶ms, context); + SidePair cacheable(lhs.cacheable, rhs.cacheable); + HandlePrepackedCaching(¶ms, cacheable, context); TrMul(¶ms, context); } diff --git a/tensorflow/lite/experimental/ruy/matrix.h b/tensorflow/lite/experimental/ruy/matrix.h index bd11248c8c1..978714c353e 100644 --- a/tensorflow/lite/experimental/ruy/matrix.h +++ b/tensorflow/lite/experimental/ruy/matrix.h @@ -108,6 +108,7 @@ template struct Matrix final { Matrix& operator=(const 
Matrix& other) { data = other.data; + cacheable = other.cacheable; layout = other.layout; zero_point = other.zero_point; return *this; @@ -120,6 +121,10 @@ struct Matrix final { // The zero_point, i.e. which Scalar value is to be interpreted as zero. // When Scalar is floating-point, this must be 0. Scalar zero_point = 0; + // Clients of Ruy must set this flag to enable any caching behavior. Doesn't + // impact numerical results, but caching can impact observable metrics like + // latency, memory usage, power, etc. + bool cacheable = false; }; inline void MakeSimpleLayout(int rows, int cols, Order order, Layout* layout) { diff --git a/tensorflow/lite/experimental/ruy/prepacked_cache_test.cc b/tensorflow/lite/experimental/ruy/prepacked_cache_test.cc index efb6f2b358c..e4b1379b43a 100644 --- a/tensorflow/lite/experimental/ruy/prepacked_cache_test.cc +++ b/tensorflow/lite/experimental/ruy/prepacked_cache_test.cc @@ -25,7 +25,6 @@ namespace ruy { namespace { TEST(PrepackedCacheTest, TestCacheEjection) { - ruy::Context* context = new ruy::Context(); // Create the cache. PrepackedCache prepacked_cache(32); // Allocate the prepacked matrix. @@ -54,11 +53,9 @@ TEST(PrepackedCacheTest, TestCacheEjection) { // The cache size was exceeded by inserting mat2. Ensure that mat1 was // ejected. EXPECT_EQ(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend()); - delete context; } TEST(PrepackedCacheTest, TestCacheBasic) { - ruy::Context* context = new ruy::Context(); // Create the cache. PrepackedCache prepacked_cache(48); // Allocate the prepacked matrix. @@ -83,11 +80,9 @@ TEST(PrepackedCacheTest, TestCacheBasic) { // The cache size was not exceeded by inserting mat2. Ensure that mat1 was not // ejected. EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend()); - delete context; } TEST(PrepackedCacheTest, TestCacheEjection2) { - ruy::Context* context = new ruy::Context(); // Create the cache. 
PrepackedCache prepacked_cache(73); // Allocate the prepacked matrix 1. @@ -137,7 +132,39 @@ TEST(PrepackedCacheTest, TestCacheEjection2) { EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key3), prepacked_cache.cend()); EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend()); EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key4), prepacked_cache.cend()); - delete context; +} + +TEST(PrepackedCacheTest, TestCacheOnCacheable) { + // Create context and set the cache policy + ruy::Context context; + context.cache_policy = ruy::kCacheLHSOnGemV; + PrepackedCache* cache = context.GetPrepackedCache(); + EXPECT_EQ(cache->TotalSize(), 0); + + const float lhs_data[] = {1, 2, 3, 4}; + const float rhs_data[] = {1, 2}; + float dst_data[4]; + + ruy::Matrix lhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); + lhs.data = lhs_data; + ruy::Matrix rhs; + ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &rhs.layout); + rhs.data = rhs_data; + ruy::Matrix dst; + ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &dst.layout); + dst.data = dst_data; + + ruy::BasicSpec spec; + // Perform the multiplication and confirm no caching occurred. + ruy::Mul(lhs, rhs, spec, &context, &dst); + EXPECT_EQ(cache->TotalSize(), 0); + + // Set cacheable for the LHS, repeat the multiplication, and see + // that caching did occur. + lhs.cacheable = true; + ruy::Mul(lhs, rhs, spec, &context, &dst); + EXPECT_NE(cache->TotalSize(), 0); } } // namespace