Add `cacheable` flag to Ruy Matrix so that the caller "opts in" to cache behavior on a per-call basis
PiperOrigin-RevId: 286241942 Change-Id: Ie1320c17f6a50468a03dad2664a1c8645e09f3ce
This commit is contained in:
parent
ddb75d9d1e
commit
112e2b3fc7
@ -382,13 +382,15 @@ struct CompileTimeEnabledReferenceMul</*ReferenceMulIsEnabled=*/false> {
|
||||
}
|
||||
};
|
||||
|
||||
inline void HandlePrepackedCaching(TrMulParams* params, Context* context) {
|
||||
inline void HandlePrepackedCaching(TrMulParams* params,
|
||||
const SidePair<bool>& cacheable,
|
||||
Context* context) {
|
||||
if (context->cache_policy == CachePolicy::kNoCache) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (context->cache_policy == CachePolicy::kCacheLHSOnGemV) {
|
||||
if (params->dst.layout.cols != 1) {
|
||||
if (!cacheable[Side::kLhs] || params->dst.layout.cols != 1) {
|
||||
return;
|
||||
}
|
||||
PrepackedCache* prepacked_cache = context->GetPrepackedCache();
|
||||
@ -465,7 +467,8 @@ void DispatchMul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
|
||||
TrMulParams params;
|
||||
CreateTrMulParams<TrMulCompiledPaths>(transposed_lhs, rhs, spec, context, dst,
|
||||
the_path, &params);
|
||||
HandlePrepackedCaching(&params, context);
|
||||
SidePair<bool> cacheable(lhs.cacheable, rhs.cacheable);
|
||||
HandlePrepackedCaching(&params, cacheable, context);
|
||||
TrMul(&params, context);
|
||||
}
|
||||
|
||||
|
@ -108,6 +108,7 @@ template <typename Scalar>
|
||||
struct Matrix final {
|
||||
Matrix& operator=(const Matrix& other) {
|
||||
data = other.data;
|
||||
cacheable = other.cacheable;
|
||||
layout = other.layout;
|
||||
zero_point = other.zero_point;
|
||||
return *this;
|
||||
@ -120,6 +121,10 @@ struct Matrix final {
|
||||
// The zero_point, i.e. which Scalar value is to be interpreted as zero.
|
||||
// When Scalar is floating-point, this must be 0.
|
||||
Scalar zero_point = 0;
|
||||
// Clients of Ruy must set this flag to enable any caching behavior. Doesn't
|
||||
// impact numerical results, but caching can impact observable metrics like
|
||||
// latency, memory usage, power, etc.
|
||||
bool cacheable = false;
|
||||
};
|
||||
|
||||
inline void MakeSimpleLayout(int rows, int cols, Order order, Layout* layout) {
|
||||
|
@ -25,7 +25,6 @@ namespace ruy {
|
||||
namespace {
|
||||
|
||||
TEST(PrepackedCacheTest, TestCacheEjection) {
|
||||
ruy::Context* context = new ruy::Context();
|
||||
// Create the cache.
|
||||
PrepackedCache prepacked_cache(32);
|
||||
// Allocate the prepacked matrix.
|
||||
@ -54,11 +53,9 @@ TEST(PrepackedCacheTest, TestCacheEjection) {
|
||||
// The cache size was exceeded by inserting mat2. Ensure that mat1 was
|
||||
// ejected.
|
||||
EXPECT_EQ(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend());
|
||||
delete context;
|
||||
}
|
||||
|
||||
TEST(PrepackedCacheTest, TestCacheBasic) {
|
||||
ruy::Context* context = new ruy::Context();
|
||||
// Create the cache.
|
||||
PrepackedCache prepacked_cache(48);
|
||||
// Allocate the prepacked matrix.
|
||||
@ -83,11 +80,9 @@ TEST(PrepackedCacheTest, TestCacheBasic) {
|
||||
// The cache size was not exceeded by inserting mat2. Ensure that mat1 was not
|
||||
// ejected.
|
||||
EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend());
|
||||
delete context;
|
||||
}
|
||||
|
||||
TEST(PrepackedCacheTest, TestCacheEjection2) {
|
||||
ruy::Context* context = new ruy::Context();
|
||||
// Create the cache.
|
||||
PrepackedCache prepacked_cache(73);
|
||||
// Allocate the prepacked matrix 1.
|
||||
@ -137,7 +132,39 @@ TEST(PrepackedCacheTest, TestCacheEjection2) {
|
||||
EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key3), prepacked_cache.cend());
|
||||
EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend());
|
||||
EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key4), prepacked_cache.cend());
|
||||
delete context;
|
||||
}
|
||||
|
||||
TEST(PrepackedCacheTest, TestCacheOnCacheable) {
|
||||
// Create context and set the cache policy
|
||||
ruy::Context context;
|
||||
context.cache_policy = ruy::kCacheLHSOnGemV;
|
||||
PrepackedCache* cache = context.GetPrepackedCache();
|
||||
EXPECT_EQ(cache->TotalSize(), 0);
|
||||
|
||||
const float lhs_data[] = {1, 2, 3, 4};
|
||||
const float rhs_data[] = {1, 2};
|
||||
float dst_data[4];
|
||||
|
||||
ruy::Matrix<float> lhs;
|
||||
ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout);
|
||||
lhs.data = lhs_data;
|
||||
ruy::Matrix<float> rhs;
|
||||
ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &rhs.layout);
|
||||
rhs.data = rhs_data;
|
||||
ruy::Matrix<float> dst;
|
||||
ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &dst.layout);
|
||||
dst.data = dst_data;
|
||||
|
||||
ruy::BasicSpec<float, float> spec;
|
||||
// Perform the multiplication and confirm no caching occurred.
|
||||
ruy::Mul<ruy::kAllPaths>(lhs, rhs, spec, &context, &dst);
|
||||
EXPECT_EQ(cache->TotalSize(), 0);
|
||||
|
||||
// Set cacheable for the LHS, repeat the multiplication, and see
|
||||
// that caching did occur.
|
||||
lhs.cacheable = true;
|
||||
ruy::Mul<ruy::kAllPaths>(lhs, rhs, spec, &context, &dst);
|
||||
EXPECT_NE(cache->TotalSize(), 0);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user