Add cacheable
flag to Ruy Matrix so that caller "opts in" to cache behavior on a per-call basis
PiperOrigin-RevId: 286241942 Change-Id: Ie1320c17f6a50468a03dad2664a1c8645e09f3ce
This commit is contained in:
parent
ddb75d9d1e
commit
112e2b3fc7
@ -382,13 +382,15 @@ struct CompileTimeEnabledReferenceMul</*ReferenceMulIsEnabled=*/false> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
inline void HandlePrepackedCaching(TrMulParams* params, Context* context) {
|
inline void HandlePrepackedCaching(TrMulParams* params,
|
||||||
|
const SidePair<bool>& cacheable,
|
||||||
|
Context* context) {
|
||||||
if (context->cache_policy == CachePolicy::kNoCache) {
|
if (context->cache_policy == CachePolicy::kNoCache) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (context->cache_policy == CachePolicy::kCacheLHSOnGemV) {
|
if (context->cache_policy == CachePolicy::kCacheLHSOnGemV) {
|
||||||
if (params->dst.layout.cols != 1) {
|
if (!cacheable[Side::kLhs] || params->dst.layout.cols != 1) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
PrepackedCache* prepacked_cache = context->GetPrepackedCache();
|
PrepackedCache* prepacked_cache = context->GetPrepackedCache();
|
||||||
@ -465,7 +467,8 @@ void DispatchMul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
|
|||||||
TrMulParams params;
|
TrMulParams params;
|
||||||
CreateTrMulParams<TrMulCompiledPaths>(transposed_lhs, rhs, spec, context, dst,
|
CreateTrMulParams<TrMulCompiledPaths>(transposed_lhs, rhs, spec, context, dst,
|
||||||
the_path, &params);
|
the_path, &params);
|
||||||
HandlePrepackedCaching(&params, context);
|
SidePair<bool> cacheable(lhs.cacheable, rhs.cacheable);
|
||||||
|
HandlePrepackedCaching(&params, cacheable, context);
|
||||||
TrMul(&params, context);
|
TrMul(&params, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -108,6 +108,7 @@ template <typename Scalar>
|
|||||||
struct Matrix final {
|
struct Matrix final {
|
||||||
Matrix& operator=(const Matrix& other) {
|
Matrix& operator=(const Matrix& other) {
|
||||||
data = other.data;
|
data = other.data;
|
||||||
|
cacheable = other.cacheable;
|
||||||
layout = other.layout;
|
layout = other.layout;
|
||||||
zero_point = other.zero_point;
|
zero_point = other.zero_point;
|
||||||
return *this;
|
return *this;
|
||||||
@ -120,6 +121,10 @@ struct Matrix final {
|
|||||||
// The zero_point, i.e. which Scalar value is to be interpreted as zero.
|
// The zero_point, i.e. which Scalar value is to be interpreted as zero.
|
||||||
// When Scalar is floating-point, this must be 0.
|
// When Scalar is floating-point, this must be 0.
|
||||||
Scalar zero_point = 0;
|
Scalar zero_point = 0;
|
||||||
|
// Clients of Ruy must set this flag to enable any caching behavior. Doesn't
|
||||||
|
// impact numerical results, but caching can impact observable metrics like
|
||||||
|
// latency, memory usage, power, etc.
|
||||||
|
bool cacheable = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
inline void MakeSimpleLayout(int rows, int cols, Order order, Layout* layout) {
|
inline void MakeSimpleLayout(int rows, int cols, Order order, Layout* layout) {
|
||||||
|
@ -25,7 +25,6 @@ namespace ruy {
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
TEST(PrepackedCacheTest, TestCacheEjection) {
|
TEST(PrepackedCacheTest, TestCacheEjection) {
|
||||||
ruy::Context* context = new ruy::Context();
|
|
||||||
// Create the cache.
|
// Create the cache.
|
||||||
PrepackedCache prepacked_cache(32);
|
PrepackedCache prepacked_cache(32);
|
||||||
// Allocate the prepacked matrix.
|
// Allocate the prepacked matrix.
|
||||||
@ -54,11 +53,9 @@ TEST(PrepackedCacheTest, TestCacheEjection) {
|
|||||||
// The cache size was exceeded by inserting mat2. Ensure that mat1 was
|
// The cache size was exceeded by inserting mat2. Ensure that mat1 was
|
||||||
// ejected.
|
// ejected.
|
||||||
EXPECT_EQ(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend());
|
EXPECT_EQ(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend());
|
||||||
delete context;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(PrepackedCacheTest, TestCacheBasic) {
|
TEST(PrepackedCacheTest, TestCacheBasic) {
|
||||||
ruy::Context* context = new ruy::Context();
|
|
||||||
// Create the cache.
|
// Create the cache.
|
||||||
PrepackedCache prepacked_cache(48);
|
PrepackedCache prepacked_cache(48);
|
||||||
// Allocate the prepacked matrix.
|
// Allocate the prepacked matrix.
|
||||||
@ -83,11 +80,9 @@ TEST(PrepackedCacheTest, TestCacheBasic) {
|
|||||||
// The cache size was not exceeded by inserting mat2. Ensure that mat1 was not
|
// The cache size was not exceeded by inserting mat2. Ensure that mat1 was not
|
||||||
// ejected.
|
// ejected.
|
||||||
EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend());
|
EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend());
|
||||||
delete context;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(PrepackedCacheTest, TestCacheEjection2) {
|
TEST(PrepackedCacheTest, TestCacheEjection2) {
|
||||||
ruy::Context* context = new ruy::Context();
|
|
||||||
// Create the cache.
|
// Create the cache.
|
||||||
PrepackedCache prepacked_cache(73);
|
PrepackedCache prepacked_cache(73);
|
||||||
// Allocate the prepacked matrix 1.
|
// Allocate the prepacked matrix 1.
|
||||||
@ -137,7 +132,39 @@ TEST(PrepackedCacheTest, TestCacheEjection2) {
|
|||||||
EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key3), prepacked_cache.cend());
|
EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key3), prepacked_cache.cend());
|
||||||
EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend());
|
EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend());
|
||||||
EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key4), prepacked_cache.cend());
|
EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key4), prepacked_cache.cend());
|
||||||
delete context;
|
}
|
||||||
|
|
||||||
|
TEST(PrepackedCacheTest, TestCacheOnCacheable) {
|
||||||
|
// Create context and set the cache policy
|
||||||
|
ruy::Context context;
|
||||||
|
context.cache_policy = ruy::kCacheLHSOnGemV;
|
||||||
|
PrepackedCache* cache = context.GetPrepackedCache();
|
||||||
|
EXPECT_EQ(cache->TotalSize(), 0);
|
||||||
|
|
||||||
|
const float lhs_data[] = {1, 2, 3, 4};
|
||||||
|
const float rhs_data[] = {1, 2};
|
||||||
|
float dst_data[4];
|
||||||
|
|
||||||
|
ruy::Matrix<float> lhs;
|
||||||
|
ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout);
|
||||||
|
lhs.data = lhs_data;
|
||||||
|
ruy::Matrix<float> rhs;
|
||||||
|
ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &rhs.layout);
|
||||||
|
rhs.data = rhs_data;
|
||||||
|
ruy::Matrix<float> dst;
|
||||||
|
ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &dst.layout);
|
||||||
|
dst.data = dst_data;
|
||||||
|
|
||||||
|
ruy::BasicSpec<float, float> spec;
|
||||||
|
// Perform the multiplication and confirm no caching occurred.
|
||||||
|
ruy::Mul<ruy::kAllPaths>(lhs, rhs, spec, &context, &dst);
|
||||||
|
EXPECT_EQ(cache->TotalSize(), 0);
|
||||||
|
|
||||||
|
// Set cacheable for the LHS, repeat the multiplication, and see
|
||||||
|
// that caching did occur.
|
||||||
|
lhs.cacheable = true;
|
||||||
|
ruy::Mul<ruy::kAllPaths>(lhs, rhs, spec, &context, &dst);
|
||||||
|
EXPECT_NE(cache->TotalSize(), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
Loading…
Reference in New Issue
Block a user