Add a unit test covering GetBlockByIndex. This is where traversal orders are implemented. A mistake there would not be caught in matrix multiplication tests as it would be a performance-only bug (or even a memory-locality-only bug not necessarily affecting latencies).
PiperOrigin-RevId: 292199511 Change-Id: Ie1c4e35ffc75a9d22395b7867ac82439821f3ee6
This commit is contained in:
parent
5177a9ff11
commit
0322cb8e1d
@ -207,6 +207,7 @@ cc_test(
|
|||||||
":block_map",
|
":block_map",
|
||||||
":cpu_cache_size",
|
":cpu_cache_size",
|
||||||
":path",
|
":path",
|
||||||
|
":side_pair",
|
||||||
"@com_google_googletest//:gtest",
|
"@com_google_googletest//:gtest",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -17,11 +17,14 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
#include <cstdlib>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include "tensorflow/lite/experimental/ruy/cpu_cache_size.h"
|
#include "tensorflow/lite/experimental/ruy/cpu_cache_size.h"
|
||||||
#include "tensorflow/lite/experimental/ruy/path.h"
|
#include "tensorflow/lite/experimental/ruy/path.h"
|
||||||
|
#include "tensorflow/lite/experimental/ruy/side_pair.h"
|
||||||
|
|
||||||
namespace ruy {
|
namespace ruy {
|
||||||
namespace {
|
namespace {
|
||||||
@ -142,6 +145,115 @@ TEST(BlockMapTest, MakeBlockMapTuningTestRectangular) {
|
|||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
int L1Distance(const SidePair<int>& a, const SidePair<int>& b) {
|
||||||
|
return std::abs(a[Side::kLhs] - b[Side::kLhs]) +
|
||||||
|
std::abs(a[Side::kRhs] - b[Side::kRhs]);
|
||||||
|
}
|
||||||
|
|
||||||
|
void GetBlockByIndexSquareTest(int num_blocks_base_log2,
|
||||||
|
BlockMapTraversalOrder traversal_order) {
|
||||||
|
// Arbitrary, does not affect this test. 3 is just a typical value.
|
||||||
|
constexpr int kKernelSizeLog2 = 3;
|
||||||
|
|
||||||
|
const int size_log2 = num_blocks_base_log2 + kKernelSizeLog2;
|
||||||
|
BlockMap block_map;
|
||||||
|
block_map.thread_count = 1;
|
||||||
|
block_map.traversal_order = traversal_order;
|
||||||
|
block_map.num_blocks_base_log2 = num_blocks_base_log2;
|
||||||
|
for (Side side : {Side::kLhs, Side::kRhs}) {
|
||||||
|
block_map.dims[side] = 1 << size_log2;
|
||||||
|
block_map.rectangularness_log2[side] = 0;
|
||||||
|
block_map.kernel_dims[side] = 1 << kKernelSizeLog2;
|
||||||
|
block_map.small_block_dims[side] = block_map.kernel_dims[side];
|
||||||
|
block_map.large_blocks[side] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int num_blocks_per_side = 1 << num_blocks_base_log2;
|
||||||
|
const int num_blocks = num_blocks_per_side * num_blocks_per_side;
|
||||||
|
EXPECT_EQ(num_blocks, NumBlocks(block_map));
|
||||||
|
|
||||||
|
// Perform a full traversal of all blocks, as if computing a whole matrix
|
||||||
|
// multiplication.
|
||||||
|
//
|
||||||
|
// Used to record how many times each block was hit by the traversal.
|
||||||
|
std::vector<int> block_hit_counts(num_blocks);
|
||||||
|
// Here we guard an assumption that all traversal orders start at (0, 0).
|
||||||
|
SidePair<int> previous_block_coords(0, 0);
|
||||||
|
// Sum of L1 norm of the coordinate change at every step of the traversal.
|
||||||
|
std::int64_t total_l1_distance = 0;
|
||||||
|
// Number of jumps i.e. traversal steps with a L1 norm greater than 1.
|
||||||
|
int discontinuity_count = 0;
|
||||||
|
for (int block_index = 0; block_index < num_blocks; block_index++) {
|
||||||
|
SidePair<int> block_coords;
|
||||||
|
GetBlockByIndex(block_map, block_index, &block_coords);
|
||||||
|
++block_hit_counts[block_coords[Side::kLhs] +
|
||||||
|
num_blocks_per_side * block_coords[Side::kRhs]];
|
||||||
|
int distance = L1Distance(block_coords, previous_block_coords);
|
||||||
|
total_l1_distance += distance;
|
||||||
|
discontinuity_count += (distance > 1);
|
||||||
|
previous_block_coords = block_coords;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that each block was traversed exactly once.
|
||||||
|
for (int l = 0; l < num_blocks_per_side; l++) {
|
||||||
|
for (int r = 0; r < num_blocks_per_side; r++) {
|
||||||
|
EXPECT_EQ(block_hit_counts[l + num_blocks_per_side * r], 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that the discontinuity_count and total_l1_distance are as expected
|
||||||
|
// for the given traversal_order.
|
||||||
|
switch (traversal_order) {
|
||||||
|
case BlockMapTraversalOrder::kFractalHilbert:
|
||||||
|
// No discontinuity at all with this space-filling continuous curve!
|
||||||
|
EXPECT_EQ(discontinuity_count, 0);
|
||||||
|
// Therefore, total_l1_distance has to be the number of blocks minus one.
|
||||||
|
EXPECT_EQ(total_l1_distance, num_blocks - 1);
|
||||||
|
break;
|
||||||
|
case BlockMapTraversalOrder::kLinear:
|
||||||
|
EXPECT_EQ(discontinuity_count, num_blocks_per_side - 1);
|
||||||
|
EXPECT_EQ(total_l1_distance,
|
||||||
|
2 * num_blocks_per_side * (num_blocks_per_side - 1));
|
||||||
|
break;
|
||||||
|
case BlockMapTraversalOrder::kFractalZ:
|
||||||
|
EXPECT_EQ(discontinuity_count, num_blocks > 1 ? (num_blocks / 2 - 1) : 0);
|
||||||
|
EXPECT_EQ(total_l1_distance,
|
||||||
|
2 * num_blocks_per_side * (num_blocks_per_side - 1));
|
||||||
|
break;
|
||||||
|
case BlockMapTraversalOrder::kFractalU: {
|
||||||
|
if (num_blocks_base_log2 == 0) {
|
||||||
|
EXPECT_EQ(discontinuity_count, 0);
|
||||||
|
EXPECT_EQ(total_l1_distance, 0);
|
||||||
|
} else {
|
||||||
|
int expected_discontinuity_count = 0;
|
||||||
|
int expected_total_l1_distance = 3;
|
||||||
|
for (int i = 2; i <= num_blocks_base_log2; i++) {
|
||||||
|
expected_discontinuity_count = 4 * expected_discontinuity_count + 2;
|
||||||
|
expected_total_l1_distance =
|
||||||
|
4 * expected_total_l1_distance + (1 << (i + 1)) - 1;
|
||||||
|
}
|
||||||
|
EXPECT_EQ(discontinuity_count, expected_discontinuity_count);
|
||||||
|
EXPECT_EQ(total_l1_distance, expected_total_l1_distance);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Runs GetBlockByIndexSquareTest for every grid size from 1x1 blocks up to
// 1024x1024 blocks (num_blocks_base_log2 in [0, 10]) and, for each size, for
// each of the four traversal orders.
TEST(BlockMapTest, GetBlockByIndexSquare) {
  constexpr int kMaxNumBlocksBaseLog2 = 10;
  const BlockMapTraversalOrder kTraversalOrders[] = {
      BlockMapTraversalOrder::kLinear,
      BlockMapTraversalOrder::kFractalZ,
      BlockMapTraversalOrder::kFractalU,
      BlockMapTraversalOrder::kFractalHilbert,
  };
  int base_log2 = 0;
  while (base_log2 <= kMaxNumBlocksBaseLog2) {
    for (const BlockMapTraversalOrder order : kTraversalOrders) {
      GetBlockByIndexSquareTest(base_log2, order);
    }
    ++base_log2;
  }
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace ruy
|
} // namespace ruy
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user