Add check to ensure SVDF input size is a multiple of two. This is required for optimized kernels for some platforms with alignment constraints.
PiperOrigin-RevId: 305312483 Change-Id: Idab2dd7fe7dbc3abbd5e24c3e016f4e40ee3c00c
This commit is contained in:
parent
2617230e61
commit
157209a692
@ -25,7 +25,8 @@ namespace tflite {
|
||||
namespace testing {
|
||||
namespace {
|
||||
|
||||
static float svdf_input[] = {
|
||||
// naming as follows: svdf_<tensor name>_<input size>x<batch size>x<batch count>
|
||||
static float svdf_input_3x2x10[] = {
|
||||
0.12609188, -0.46347019, -0.89598465,
|
||||
0.35867718, 0.36897406, 0.73463392,
|
||||
|
||||
@ -57,7 +58,49 @@ static float svdf_input[] = {
|
||||
-0.6230064, 0.29819036, 1.06939757,
|
||||
};
|
||||
|
||||
static float svdf_golden_output_rank_1[] = {
|
||||
static float svdf_input_2x2x10[] = {
|
||||
0.12609188, -0.46347019, 0.35867718, 0.36897406,
|
||||
|
||||
0.14278367, -1.64410412, -0.57290924, 0.12729003,
|
||||
|
||||
0.49837467, 0.19278903, 0.17660543, 0.52949083,
|
||||
|
||||
-0.11186574, 0.13164264, -0.72674477, -0.5683046,
|
||||
|
||||
-0.68892461, 0.37783599, -0.63690937, 0.44483393,
|
||||
|
||||
-0.81299269, -0.86831826, -0.95760226, 1.82078898,
|
||||
|
||||
-1.45006323, -0.82251364, -1.65087092, -1.89238167,
|
||||
|
||||
0.03966608, -0.24936394, 2.06740379, -1.51439476,
|
||||
|
||||
0.11771342, -0.23761693, 0.31088525, -1.55601168,
|
||||
|
||||
-0.89477462, 1.67204106, -0.6230064, 0.29819036,
|
||||
};
|
||||
|
||||
static float svdf_golden_output_2x2x30_rank_1[] = {
|
||||
-0.044205, -0.013757, 0.050369, -0.018447, 0.073010, 0.025142, -0.021154,
|
||||
0.013551, -0.209613, -0.062421, 0.150209, -0.108334, 0.028256, -0.006950,
|
||||
-0.030885, 0.009603, -0.076800, -0.037075, -0.087198, -0.155183, 0.091069,
|
||||
0.098446, -0.016083, 0.106475, -0.082123, -0.162238, -0.084434, -0.141074,
|
||||
-0.029340, -0.090685, 0.053302, -0.030604, -0.201440, 0.088424, 0.139877,
|
||||
0.012416, -0.113212, 0.103893, -0.100842, 0.122780, -0.166632, -0.116705,
|
||||
0.175298, -0.047163, 0.313077, -0.166485, -0.285860, 0.129069, -0.625911,
|
||||
0.046134, 0.138081, -0.129581, -0.521455, -0.061579, 0.230289, 0.114963,
|
||||
-0.216693, -0.161643, -0.179177, -0.052599, -0.213239, 0.029502, 0.260858,
|
||||
0.275045, -0.213689, -0.323608, -0.285635, -0.317687, -0.324092, -0.317972,
|
||||
-0.208450, -0.462504, -0.255126, -0.218576, -0.041528, 0.179421, -0.440583,
|
||||
0.072127, -0.284136, 0.241570, -0.582490, 0.253004, 0.156972, 0.132266,
|
||||
-0.175340, -0.269495, -0.005782, -0.125683, -0.461215, 0.257511, 0.340125,
|
||||
0.140569, -0.866940, -0.075565, 0.484422, 0.018665, 0.059312, -0.006378,
|
||||
-0.465532, 0.291374, -0.182749, 0.232608, 0.479811, 0.541274, 0.286369,
|
||||
-0.188810, -0.011561, 0.022947, 0.451862, 0.214710, -0.367849, -0.722380,
|
||||
-0.072298, -0.270524, -0.083401, -0.038342, -0.035884, -0.565247, -0.427794,
|
||||
0.015071};
|
||||
|
||||
static float svdf_golden_output_3x2x10_rank_1[] = {
|
||||
0.014899, -0.0517661, -0.143725, -0.00271883,
|
||||
-0.03004015, 0.09565311, 0.1587342, 0.00784263,
|
||||
|
||||
@ -89,7 +132,7 @@ static float svdf_golden_output_rank_1[] = {
|
||||
0.17012937, -0.34447709, 0.38505614, -0.28158101,
|
||||
};
|
||||
|
||||
static float svdf_golden_output_rank_2[] = {
|
||||
static float svdf_golden_output_3x2x10_rank_2[] = {
|
||||
-0.09623547, -0.10193135, 0.11083051, -0.0347917,
|
||||
0.1141196, 0.12965347, -0.12652366, 0.01007236,
|
||||
|
||||
@ -160,7 +203,12 @@ void ValidateSVDFGoldens(const int batch_size, const int num_units,
|
||||
node.custom_initial_data_size = 0;
|
||||
node.delegate = nullptr;
|
||||
if (registration->prepare) {
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
|
||||
TfLiteStatus prepare_status = registration->prepare(&context, &node);
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, prepare_status);
|
||||
// Abort early to make it clear prepare failed.
|
||||
if (prepare_status != kTfLiteOk) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
|
||||
|
||||
@ -171,14 +219,18 @@ void ValidateSVDFGoldens(const int batch_size, const int num_units,
|
||||
float* input_batch_end = input_batch_start + input_size * batch_size;
|
||||
|
||||
PopulateFloatTensor(&tensors[0], input_batch_start, input_batch_end);
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
|
||||
TfLiteStatus status = registration->invoke(&context, &node);
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, status);
|
||||
|
||||
int output_idx = 0;
|
||||
int golden_idx = i * batch_size * num_units;
|
||||
for (int j = golden_idx; j < golden_idx + batch_size * num_units; ++j) {
|
||||
TF_LITE_MICRO_EXPECT_NEAR(expected_output[j], output_data[output_idx],
|
||||
tolerance);
|
||||
output_idx++;
|
||||
// Only validate outputs when invoke has succeeded.
|
||||
if (status == kTfLiteOk) {
|
||||
int output_idx = 0;
|
||||
int golden_idx = i * batch_size * num_units;
|
||||
for (int j = golden_idx; j < golden_idx + batch_size * num_units; ++j) {
|
||||
TF_LITE_MICRO_EXPECT_NEAR(expected_output[j], output_data[output_idx],
|
||||
tolerance);
|
||||
output_idx++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -226,7 +278,12 @@ void ValidateIntegerSVDFGoldens(const int batch_size, const int num_units,
|
||||
node.delegate = nullptr;
|
||||
|
||||
if (registration->prepare) {
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
|
||||
TfLiteStatus prepare_status = registration->prepare(&context, &node);
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, prepare_status);
|
||||
// Abort early to make it clear prepare failed.
|
||||
if (prepare_status != kTfLiteOk) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
|
||||
|
||||
@ -398,7 +455,7 @@ inline void TestIntegerSVDF(
|
||||
|
||||
TF_LITE_MICRO_TESTS_BEGIN
|
||||
|
||||
TF_LITE_MICRO_TEST(BlackBoxTestRank1) {
|
||||
TF_LITE_MICRO_TEST(SvdfFloatInputSize3Rank1ShouldMatchGolden) {
|
||||
constexpr int batch_size = 2;
|
||||
constexpr int num_units = 4;
|
||||
constexpr int input_size = 3;
|
||||
@ -440,12 +497,12 @@ TF_LITE_MICRO_TEST(BlackBoxTestRank1) {
|
||||
tflite::testing::TestSVDF(
|
||||
batch_size, num_units, input_size, memory_size, rank, input_data,
|
||||
weights_feature_data, weights_time_data, activation_state_data,
|
||||
scratch_data, output_data, tflite::testing::svdf_input,
|
||||
sizeof(tflite::testing::svdf_input),
|
||||
tflite::testing::svdf_golden_output_rank_1);
|
||||
scratch_data, output_data, tflite::testing::svdf_input_3x2x10,
|
||||
sizeof(tflite::testing::svdf_input_3x2x10),
|
||||
tflite::testing::svdf_golden_output_3x2x10_rank_1);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(BlackBoxTestRank2) {
|
||||
TF_LITE_MICRO_TEST(SvdfFloatInputSize3Rank2ShouldMatchGolden) {
|
||||
constexpr int batch_size = 2;
|
||||
constexpr int num_units = 4;
|
||||
constexpr int input_size = 3;
|
||||
@ -500,15 +557,75 @@ TF_LITE_MICRO_TEST(BlackBoxTestRank2) {
|
||||
tflite::testing::TestSVDF(
|
||||
batch_size, num_units, input_size, memory_size, rank, input_data,
|
||||
weights_feature_data, weights_time_data, activation_state_data,
|
||||
scratch_data, output_data, tflite::testing::svdf_input,
|
||||
sizeof(tflite::testing::svdf_input),
|
||||
tflite::testing::svdf_golden_output_rank_2);
|
||||
scratch_data, output_data, tflite::testing::svdf_input_3x2x10,
|
||||
sizeof(tflite::testing::svdf_input_3x2x10),
|
||||
tflite::testing::svdf_golden_output_3x2x10_rank_2);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(BlackBoxTestIntegerRank1) {
|
||||
TF_LITE_MICRO_TEST(SvdfFloatInputSize2Rank1ShouldMatchGolden) {
|
||||
constexpr int batch_size = 2;
|
||||
constexpr int num_units = 4;
|
||||
constexpr int input_size = 3;
|
||||
constexpr int input_size = 2;
|
||||
constexpr int memory_size = 10;
|
||||
constexpr int rank = 2;
|
||||
constexpr int num_filters = num_units * rank;
|
||||
|
||||
float weights_feature_data[] = {
|
||||
-0.31930989, 0.0079667, 0.39296314, 0.37613347, 0.12416199,
|
||||
0.15785322, 0.27901134, 0.3905206, 0.21931258, -0.36137494,
|
||||
-0.10640851, 0.31053296, -0.36118156, -0.0976817, -0.36916667,
|
||||
0.22197971, 0.15294972, 0.38031587, 0.27557442, 0.39635518,
|
||||
-0.21580373, -0.06634006, -0.02702999, 0.27072677};
|
||||
|
||||
float weights_time_data[] = {
|
||||
-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156,
|
||||
0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199,
|
||||
|
||||
0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518,
|
||||
-0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296,
|
||||
|
||||
-0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236,
|
||||
0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846,
|
||||
|
||||
-0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
|
||||
-0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657,
|
||||
|
||||
-0.14884081, 0.19931212, -0.36002168, 0.34663299, -0.11405486,
|
||||
0.12672701, 0.39463779, -0.07886535, -0.06384811, 0.08249187,
|
||||
|
||||
-0.26816407, -0.19905911, 0.29211238, 0.31264046, -0.28664589,
|
||||
0.05698794, 0.11613581, 0.14078894, 0.02187902, -0.21781836,
|
||||
|
||||
-0.15567942, 0.08693647, -0.38256618, 0.36580828, -0.22922277,
|
||||
-0.0226903, 0.12878349, -0.28122205, -0.10850525, -0.11955214,
|
||||
|
||||
0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326,
|
||||
0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763};
|
||||
|
||||
const int input_size_dims_count = batch_size * input_size;
|
||||
float input_data[input_size_dims_count];
|
||||
|
||||
const int activation_state_dims_count =
|
||||
batch_size * memory_size * num_filters;
|
||||
float activation_state_data[activation_state_dims_count];
|
||||
const int scratch_dims_count = batch_size * num_filters;
|
||||
float scratch_data[scratch_dims_count];
|
||||
|
||||
const int output_dims_count = batch_size * num_units;
|
||||
float output_data[output_dims_count];
|
||||
|
||||
tflite::testing::TestSVDF(
|
||||
batch_size, num_units, input_size, memory_size, rank, input_data,
|
||||
weights_feature_data, weights_time_data, activation_state_data,
|
||||
scratch_data, output_data, tflite::testing::svdf_input_2x2x10,
|
||||
sizeof(tflite::testing::svdf_input_2x2x10),
|
||||
tflite::testing::svdf_golden_output_2x2x30_rank_1);
|
||||
}
|
||||
|
||||
TF_LITE_MICRO_TEST(SvdfIntegerInputSize2Rank1ShouldMatchGolden) {
|
||||
constexpr int batch_size = 2;
|
||||
constexpr int num_units = 4;
|
||||
constexpr int input_size = 2;
|
||||
constexpr int memory_size = 10;
|
||||
constexpr int rank = 1;
|
||||
constexpr int num_filters = num_units * rank;
|
||||
@ -537,13 +654,17 @@ TF_LITE_MICRO_TEST(BlackBoxTestIntegerRank1) {
|
||||
};
|
||||
|
||||
int8_t expected_output[] = {
|
||||
-9, 24, 31, 1, -10, 10, -3, 0, 2, 4, -44, -7, -10, 32,
|
||||
52, 1, 12, -17, 9, -8, 7, 16, -11, -8, -26, 29, 28, 16,
|
||||
-23, 26, 30, -6, -8, -25, -86, -5, -44, 59, 81, 15, 62, -16,
|
||||
-37, 3, 27, 14, 34, -10, 1, 24, -25, 23, 31, 61, 67, 11,
|
||||
-64, -65, -128, -25, -53, 59, 127, 20, 20, -29, -20, -15, -28, 0,
|
||||
8, -27, 54, 61, -67, 38, 38, 64, 115, 0, -44, -75, -128, -20,
|
||||
-19, 93, 101, 35, -5, -56, 30, -18, -40, -9, -8, -31,
|
||||
-9, 9, 18, -2, -6, 8, 13, -2, 2, -16, 2, 5, 2, -7,
|
||||
0, 3, 7, 0, 5, 7, -11, 18, 30, 0, -9, -24, 14, -12,
|
||||
-1, 1, -20, 2, -19, -20, 20, -13, -1, -10, 50, 4, 26, 32,
|
||||
2, -12, -12, 11, -10, -29, 50, -61, 4, 15, 19, -39, 13, 19,
|
||||
-56, 49, 12, 13, 29, -3, -4, -22, -76, -29, -14, 38, -30, -30,
|
||||
27, 0, 39, 16, 49, -14, -18, 28, -35, 11, 45, 0, -13, -61,
|
||||
34, -80, 37, 26, 15, -23, 12, 15, 18, 83, -28, -21, -27, -48,
|
||||
17, 2, -113, -52, 9, 48, -4, -1, 15, -7, 39, 16, 49, -14,
|
||||
-18, 28, -35, 11, 45, 0, -13, -61, 34, -80, 37, 26, 15, -23,
|
||||
12, 15, 18, 83, -28, -21, -27, -48, 17, 2, -113, -52, 9, 48,
|
||||
-4, -1, 15, -7,
|
||||
};
|
||||
|
||||
const int input_size_dims_count = batch_size * input_size;
|
||||
|
@ -385,6 +385,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
const int rank = params->rank;
|
||||
const int input_size = input->dims->data[1];
|
||||
const int batch_size = input->dims->data[0];
|
||||
// Ensure the input size is a multiple of two. This is necessary since
|
||||
// optimized kernels access the memory in chunks of two, and all accesses
|
||||
// must be aligned to 16 bits.
|
||||
// TODO(b/153202598): Remove when padding is allowed in TFLite tensors.
|
||||
TF_LITE_ENSURE_EQ(context, input_size % 2, 0);
|
||||
|
||||
const int num_filters = weights_feature->dims->data[0];
|
||||
TF_LITE_ENSURE_EQ(context, num_filters % rank, 0);
|
||||
const int num_units = num_filters / rank;
|
||||
|
@ -282,6 +282,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
const int rank = params->rank;
|
||||
const int input_size = input->dims->data[1];
|
||||
const int batch_size = input->dims->data[0];
|
||||
// Ensure the input size is a multiple of two. This is necessary since
|
||||
// optimized kernels access the memory in chunks of two, and all accesses
|
||||
// must be aligned to 16 bits.
|
||||
// TODO(b/153202598): Remove when padding is allowed in TFLite tensors.
|
||||
TF_LITE_ENSURE_EQ(context, input_size % 2, 0);
|
||||
|
||||
const int num_filters = weights_feature->dims->data[0];
|
||||
TF_LITE_ENSURE_EQ(context, num_filters % rank, 0);
|
||||
const int num_units = num_filters / rank;
|
||||
|
Loading…
Reference in New Issue
Block a user