Refactor Elementwise Test.

PiperOrigin-RevId: 318354380
Change-Id: I4129008725bc447e304028cb337d08ce1e12d6b7
This commit is contained in:
Nat Jeffries 2020-06-25 14:57:27 -07:00 committed by TensorFlower Gardener
parent 71172f7322
commit e25272e743

View File

@ -23,13 +23,12 @@ namespace tflite {
namespace testing {
void TestElementwiseFloat(tflite::BuiltinOperator op,
std::initializer_list<int> input_dims_data,
std::initializer_list<float> input_data,
std::initializer_list<int> output_dims_data,
std::initializer_list<float> expected_output_data,
const int* input_dims_data, const float* input_data,
const int* output_dims_data,
const float* expected_output_data,
float* output_data) {
TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
const int output_dims_count = ElementCount(*output_dims);
constexpr int input_size = 1;
@ -54,9 +53,9 @@ void TestElementwiseFloat(tflite::BuiltinOperator op,
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
int inputs_array_data[] = {1, 0};
static int inputs_array_data[] = {1, 0};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
static int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node;
@ -77,19 +76,15 @@ void TestElementwiseFloat(tflite::BuiltinOperator op,
registration->free(&context, user_data);
}
for (int i = 0; i < output_dims_count; ++i) {
TF_LITE_MICRO_EXPECT_NEAR(expected_output_data.begin()[i], output_data[i],
1e-5f);
TF_LITE_MICRO_EXPECT_NEAR(expected_output_data[i], output_data[i], 1e-5f);
}
}
void TestElementwiseBool(tflite::BuiltinOperator op,
std::initializer_list<int> input_dims_data,
std::initializer_list<bool> input_data,
std::initializer_list<int> output_dims_data,
std::initializer_list<bool> expected_output_data,
bool* output_data) {
TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
void TestElementwiseBool(tflite::BuiltinOperator op, const int* input_dims_data,
const bool* input_data, const int* output_dims_data,
const bool* expected_output_data, bool* output_data) {
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
const int output_dims_count = ElementCount(*output_dims);
constexpr int input_size = 1;
@ -115,9 +110,9 @@ void TestElementwiseBool(tflite::BuiltinOperator op,
user_data = registration->init(&context, nullptr, 0);
}
int inputs_array_data[] = {1, 0};
const int inputs_array_data[] = {1, 0};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
const int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node;
@ -138,7 +133,7 @@ void TestElementwiseBool(tflite::BuiltinOperator op,
registration->free(&context, user_data);
}
for (int i = 0; i < output_dims_count; ++i) {
TF_LITE_MICRO_EXPECT_EQ(expected_output_data.begin()[i], output_data[i]);
TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]);
}
}
@ -149,110 +144,83 @@ TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(Abs) {
constexpr int output_dims_count = 4;
const int shape[] = {2, 2, 2};
const float input[] = {0.01, -0.01, 10, -10};
const float golden[] = {0.01, 0.01, 10, 10};
float output_data[output_dims_count];
tflite::testing::TestElementwiseFloat(
tflite::BuiltinOperator_ABS, // ABS operator
{2, 2, 2}, // Input shape
{0.01f, -0.01f, 10.0f, -10.0f}, // Input values
{2, 2, 2}, // Output shape
{0.01f, 0.01f, 10.0f, 10.0f}, // Output values
output_data);
tflite::testing::TestElementwiseFloat(tflite::BuiltinOperator_ABS, shape,
input, shape, golden, output_data);
}
TF_LITE_MICRO_TEST(Sin) {
constexpr int output_dims_count = 4;
const int shape[] = {2, 2, 2};
const float input[] = {0, 3.1415926, -3.1415926, 1};
const float golden[] = {0, 0, 0, 0.84147};
float output_data[output_dims_count];
tflite::testing::TestElementwiseFloat(
tflite::BuiltinOperator_SIN, // SIN operator
{2, 2, 2}, // Input shape
{0.0f, 3.1415926f, -3.1415926f, 1.0f}, // Input values
{2, 2, 2}, // Output shape
{0.0f, 0.0f, 0.0f, 0.84147f}, // Output values
output_data);
tflite::testing::TestElementwiseFloat(tflite::BuiltinOperator_SIN, shape,
input, shape, golden, output_data);
}
TF_LITE_MICRO_TEST(Cos) {
constexpr int output_dims_count = 4;
const int shape[] = {2, 2, 2};
const float input[] = {0, 3.1415926, -3.1415926, 1};
const float golden[] = {1, -1, -1, 0.54030};
float output_data[output_dims_count];
tflite::testing::TestElementwiseFloat(
tflite::BuiltinOperator_COS, // COS operator
{2, 2, 2}, // Input shape
{0.0f, 3.1415926f, -3.1415926f, 1.0f}, // Input values
{2, 2, 2}, // Output shape
{1.0f, -1.0f, -1.0f, 0.54030f}, // Output values
output_data);
tflite::testing::TestElementwiseFloat(tflite::BuiltinOperator_COS, shape,
input, shape, golden, output_data);
}
TF_LITE_MICRO_TEST(Log) {
constexpr int output_dims_count = 4;
const int shape[] = {2, 2, 2};
const float input[] = {1, 2.7182818, 0.5, 2};
const float golden[] = {0, 1, -0.6931472, 0.6931472};
float output_data[output_dims_count];
tflite::testing::TestElementwiseFloat(
tflite::BuiltinOperator_LOG, // LOG operator
{2, 2, 2}, // Input shape
{1.0f, 2.7182818f, 0.5f, 2.0f}, // Input values
{2, 2, 2}, // Output shape
{0.0f, 1.0f, -0.6931472f, 0.6931472f}, // Output values
output_data);
tflite::testing::TestElementwiseFloat(tflite::BuiltinOperator_LOG, shape,
input, shape, golden, output_data);
}
TF_LITE_MICRO_TEST(Sqrt) {
constexpr int output_dims_count = 4;
const int shape[] = {2, 2, 2};
const float input[] = {0, 1, 2, 4};
const float golden[] = {0, 1, 1.41421, 2};
float output_data[output_dims_count];
tflite::testing::TestElementwiseFloat(
tflite::BuiltinOperator_SQRT, // SQRT operator
{2, 2, 2}, // Input shape
{0.0f, 1.0f, 2.0f, 4.0f}, // Input values
{2, 2, 2}, // Output shape
{0.0f, 1.0f, 1.41421f, 2.0f}, // Output values
output_data);
tflite::testing::TestElementwiseFloat(tflite::BuiltinOperator_SQRT, shape,
input, shape, golden, output_data);
}
TF_LITE_MICRO_TEST(Rsqrt) {
constexpr int output_dims_count = 4;
const int shape[] = {2, 2, 2};
const float input[] = {1, 2, 4, 9};
const float golden[] = {1, 0.7071, 0.5, 0.33333};
float output_data[output_dims_count];
tflite::testing::TestElementwiseFloat(
tflite::BuiltinOperator_RSQRT, // RSQRT operator
{2, 2, 2}, // Input shape
{1.0f, 2.0f, 4.0f, 9.0f}, // Input values
{2, 2, 2}, // Output shape
{1.0f, 0.7071f, 0.5f, 0.33333f}, // Output values
output_data);
tflite::testing::TestElementwiseFloat(tflite::BuiltinOperator_RSQRT, shape,
input, shape, golden, output_data);
}
TF_LITE_MICRO_TEST(Square) {
constexpr int output_dims_count = 4;
const int shape[] = {2, 2, 2};
const float input[] = {1, 2, 0.5, -3.0};
const float golden[] = {1, 4.0, 0.25, 9.0};
float output_data[output_dims_count];
tflite::testing::TestElementwiseFloat(
tflite::BuiltinOperator_SQUARE, // SQARE operator
{2, 2, 2}, // Input shape
{1.0f, 2.0f, 0.5f, -3.0f}, // Input values
{2, 2, 2}, // Output shape
{1.0f, 4.0f, 0.25f, 9.0f}, // Output values
output_data);
tflite::testing::TestElementwiseFloat(tflite::BuiltinOperator_SQUARE, shape,
input, shape, golden, output_data);
}
TF_LITE_MICRO_TEST(LogicalNot) {
constexpr int output_dims_count = 4;
const int shape[] = {2, 2, 2};
const bool input[] = {true, false, false, true};
const bool golden[] = {false, true, true, false};
bool output_data[output_dims_count];
tflite::testing::TestElementwiseBool(
tflite::BuiltinOperator_LOGICAL_NOT, // Logical NOT operator
{2, 2, 2}, // Input shape
{true, false, false, true}, // Input values
{2, 2, 2}, // Output shape
{false, true, true, false}, // Output values
output_data);
}
TF_LITE_MICRO_TEST(TANH) {
constexpr int output_dims_count = 4;
float output_data[output_dims_count];
tflite::testing::TestElementwiseFloat(
tflite::BuiltinOperator_TANH, // TANH operator
{2, 2, 2}, // Input shape
{0.0f, 50.0f, 0.5f, -50.0f}, // Input values
{2, 2, 2}, // Output shape
{0.0f, 1.0f, 0.462117f, -1.0f}, // Output values
output_data);
tflite::testing::TestElementwiseBool(tflite::BuiltinOperator_LOGICAL_NOT,
shape, input, shape, golden,
output_data);
}
TF_LITE_MICRO_TESTS_END