Fix msvc execution of TFLite C/Kernel tests

PiperOrigin-RevId: 307488714
Change-Id: Ia68e3479e5c8a8f6ac2222264638d2980e92127c
This commit is contained in:
Jared Duke 2020-04-20 15:27:38 -07:00 committed by TensorFlower Gardener
parent de6c0ec676
commit 5a674e06a9
18 changed files with 107 additions and 88 deletions

View File

@ -2065,7 +2065,13 @@ cc_library(
"//tensorflow/core/platform/default:logging.h", "//tensorflow/core/platform/default:logging.h",
], ],
copts = tf_copts(), copts = tf_copts(),
linkopts = ["-ldl"], linkopts = select({
"//tensorflow:freebsd": [],
"//tensorflow:windows": [],
"//conditions:default": [
"-ldl",
],
}),
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":platform_base", ":platform_base",

View File

@ -87,6 +87,7 @@ cc_test(
name = "c_api_test", name = "c_api_test",
size = "small", size = "small",
srcs = ["c_api_test.cc"], srcs = ["c_api_test.cc"],
copts = tflite_copts(),
data = [ data = [
"//tensorflow/lite:testdata/add.bin", "//tensorflow/lite:testdata/add.bin",
"//tensorflow/lite:testdata/add_quantized.bin", "//tensorflow/lite:testdata/add_quantized.bin",
@ -103,6 +104,7 @@ cc_test(
name = "c_api_experimental_test", name = "c_api_experimental_test",
size = "small", size = "small",
srcs = ["c_api_experimental_test.cc"], srcs = ["c_api_experimental_test.cc"],
copts = tflite_copts(),
data = ["//tensorflow/lite:testdata/add.bin"], data = ["//tensorflow/lite:testdata/add.bin"],
deps = [ deps = [
":c_api", ":c_api",

View File

@ -25,11 +25,10 @@ namespace {
TfLiteRegistration* GetDummyRegistration() { TfLiteRegistration* GetDummyRegistration() {
static TfLiteRegistration registration = { static TfLiteRegistration registration = {
.init = nullptr, /*init=*/nullptr,
.free = nullptr, /*free=*/nullptr,
.prepare = nullptr, /*prepare=*/nullptr,
.invoke = [](TfLiteContext*, TfLiteNode*) { return kTfLiteOk; }, /*invoke=*/[](TfLiteContext*, TfLiteNode*) { return kTfLiteOk; }};
};
return &registration; return &registration;
} }

View File

@ -26,8 +26,8 @@ namespace {
using ::testing::ElementsAreArray; using ::testing::ElementsAreArray;
enum class TestType { enum class TestType {
CONST = 0, kConst = 0,
DYNAMIC = 1, kDynamic = 1,
}; };
template <typename InputType> template <typename InputType>
@ -36,7 +36,7 @@ class ExpandDimsOpModel : public SingleOpModel {
ExpandDimsOpModel(int axis, std::initializer_list<int> input_shape, ExpandDimsOpModel(int axis, std::initializer_list<int> input_shape,
std::initializer_list<InputType> input_data, std::initializer_list<InputType> input_data,
TestType input_tensor_types) { TestType input_tensor_types) {
if (input_tensor_types == TestType::DYNAMIC) { if (input_tensor_types == TestType::kDynamic) {
input_ = AddInput(GetTensorType<InputType>()); input_ = AddInput(GetTensorType<InputType>());
axis_ = AddInput(TensorType_INT32); axis_ = AddInput(TensorType_INT32);
} else { } else {
@ -50,7 +50,7 @@ class ExpandDimsOpModel : public SingleOpModel {
BuildInterpreter({input_shape, {1}}); BuildInterpreter({input_shape, {1}});
if (input_tensor_types == TestType::DYNAMIC) { if (input_tensor_types == TestType::kDynamic) {
PopulateTensor<InputType>(input_, input_data); PopulateTensor<InputType>(input_, input_data);
PopulateTensor<int32_t>(axis_, {axis}); PopulateTensor<int32_t>(axis_, {axis});
} }
@ -69,18 +69,18 @@ class ExpandDimsOpModel : public SingleOpModel {
template <typename T> template <typename T>
class ExpandDimsOpTest : public ::testing::Test { class ExpandDimsOpTest : public ::testing::Test {
public: public:
static std::vector<TestType> _range_; static std::vector<TestType> range_;
}; };
template <> template <>
std::vector<TestType> ExpandDimsOpTest<TestType>::_range_{TestType::CONST, std::vector<TestType> ExpandDimsOpTest<TestType>::range_{TestType::kConst,
TestType::DYNAMIC}; TestType::kDynamic};
using DataTypes = ::testing::Types<float, int8_t, int16_t, int32_t>; using DataTypes = ::testing::Types<float, int8_t, int16_t, int32_t>;
TYPED_TEST_SUITE(ExpandDimsOpTest, DataTypes); TYPED_TEST_SUITE(ExpandDimsOpTest, DataTypes);
TYPED_TEST(ExpandDimsOpTest, PositiveAxis) { TYPED_TEST(ExpandDimsOpTest, PositiveAxis) {
for (TestType test_type : ExpandDimsOpTest<TestType>::_range_) { for (TestType test_type : ExpandDimsOpTest<TestType>::range_) {
std::initializer_list<TypeParam> values = {-1, 1, -2, 2}; std::initializer_list<TypeParam> values = {-1, 1, -2, 2};
ExpandDimsOpModel<TypeParam> axis_0(0, {2, 2}, values, test_type); ExpandDimsOpModel<TypeParam> axis_0(0, {2, 2}, values, test_type);
@ -101,7 +101,7 @@ TYPED_TEST(ExpandDimsOpTest, PositiveAxis) {
} }
TYPED_TEST(ExpandDimsOpTest, NegativeAxis) { TYPED_TEST(ExpandDimsOpTest, NegativeAxis) {
for (TestType test_type : ExpandDimsOpTest<TestType>::_range_) { for (TestType test_type : ExpandDimsOpTest<TestType>::range_) {
std::initializer_list<TypeParam> values = {-1, 1, -2, 2}; std::initializer_list<TypeParam> values = {-1, 1, -2, 2};
ExpandDimsOpModel<TypeParam> m(-1, {2, 2}, values, test_type); ExpandDimsOpModel<TypeParam> m(-1, {2, 2}, values, test_type);
@ -115,7 +115,7 @@ TEST(ExpandDimsOpTest, StrTensor) {
std::initializer_list<std::string> values = {"abc", "de", "fghi"}; std::initializer_list<std::string> values = {"abc", "de", "fghi"};
// this test will fail on TestType::CONST // this test will fail on TestType::CONST
ExpandDimsOpModel<std::string> m(0, {3}, values, TestType::DYNAMIC); ExpandDimsOpModel<std::string> m(0, {3}, values, TestType::kDynamic);
m.Invoke(); m.Invoke();
EXPECT_THAT(m.GetValues(), ElementsAreArray(values)); EXPECT_THAT(m.GetValues(), ElementsAreArray(values));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3})); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3}));

View File

@ -713,11 +713,13 @@ void SimpleTestQuantizedInt16OutputCase(
/*activation_func=*/ActivationFunctionType_NONE, weights_format); /*activation_func=*/ActivationFunctionType_NONE, weights_format);
std::mt19937 random_engine; std::mt19937 random_engine;
std::uniform_int_distribution<uint8_t> weights_dist; // Some compilers don't support uint8_t for uniform_distribution.
std::uniform_int_distribution<uint32_t> weights_dist(
0, std::numeric_limits<uint8_t>::max());
std::vector<float> weights_data(input_depth * output_depth); std::vector<float> weights_data(input_depth * output_depth);
for (auto& w : weights_data) { for (auto& w : weights_data) {
uint8_t q = weights_dist(random_engine); uint8_t q = static_cast<uint8_t>(weights_dist(random_engine));
w = (q - kWeightsZeroPoint) * kWeightsScale; w = (q - kWeightsZeroPoint) * kWeightsScale;
} }
@ -739,10 +741,12 @@ void SimpleTestQuantizedInt16OutputCase(
LOG(FATAL) << "Unhandled weights format"; LOG(FATAL) << "Unhandled weights format";
} }
std::uniform_int_distribution<uint8_t> input_dist; // Some compilers don't support uint8_t for uniform_distribution.
std::uniform_int_distribution<uint32_t> input_dist(
0, std::numeric_limits<uint8_t>::max());
std::vector<float> input_data(input_depth * batches); std::vector<float> input_data(input_depth * batches);
for (auto& i : input_data) { for (auto& i : input_data) {
uint8_t q = input_dist(random_engine); uint8_t q = static_cast<uint8_t>(input_dist(random_engine));
i = (q - kInputZeroPoint) * kInputScale; i = (q - kInputZeroPoint) * kInputScale;
} }

View File

@ -105,6 +105,7 @@ float ExponentialRandomPositiveFloat(float percentile, float percentile_val,
void FillRandom(std::vector<float>* vec, float min, float max) { void FillRandom(std::vector<float>* vec, float min, float max) {
std::uniform_real_distribution<float> dist(min, max); std::uniform_real_distribution<float> dist(min, max);
// TODO(b/154540105): use std::ref to avoid copying the random engine.
auto gen = std::bind(dist, RandomEngine()); auto gen = std::bind(dist, RandomEngine());
std::generate(std::begin(*vec), std::end(*vec), gen); std::generate(std::begin(*vec), std::end(*vec), gen);
} }

View File

@ -59,12 +59,22 @@ float ExponentialRandomPositiveFloat(float percentile, float percentile_val,
// Fills a vector with random floats between |min| and |max|. // Fills a vector with random floats between |min| and |max|.
void FillRandom(std::vector<float>* vec, float min, float max); void FillRandom(std::vector<float>* vec, float min, float max);
template <typename T>
void FillRandom(typename std::vector<T>::iterator begin_it,
typename std::vector<T>::iterator end_it, T min, T max) {
// Workaround for compilers that don't support (u)int8_t uniform_distribution.
typedef typename std::conditional<sizeof(T) >= sizeof(int16_t), T,
std::int16_t>::type rand_type;
std::uniform_int_distribution<rand_type> dist(min, max);
// TODO(b/154540105): use std::ref to avoid copying the random engine.
auto gen = std::bind(dist, RandomEngine());
std::generate(begin_it, end_it, [&gen] { return static_cast<T>(gen()); });
}
// Fills a vector with random numbers between |min| and |max|. // Fills a vector with random numbers between |min| and |max|.
template <typename T> template <typename T>
void FillRandom(std::vector<T>* vec, T min, T max) { void FillRandom(std::vector<T>* vec, T min, T max) {
std::uniform_int_distribution<T> dist(min, max); return FillRandom(std::begin(*vec), std::end(*vec), min, max);
auto gen = std::bind(dist, RandomEngine());
std::generate(std::begin(*vec), std::end(*vec), gen);
} }
// Fills a vector with random numbers. // Fills a vector with random numbers.
@ -73,14 +83,6 @@ void FillRandom(std::vector<T>* vec) {
FillRandom(vec, std::numeric_limits<T>::min(), std::numeric_limits<T>::max()); FillRandom(vec, std::numeric_limits<T>::min(), std::numeric_limits<T>::max());
} }
template <typename T>
void FillRandom(typename std::vector<T>::iterator begin_it,
typename std::vector<T>::iterator end_it, T min, T max) {
std::uniform_int_distribution<T> dist(min, max);
auto gen = std::bind(dist, RandomEngine());
std::generate(begin_it, end_it, gen);
}
// Fill with a "skyscraper" pattern, in which there is a central section (across // Fill with a "skyscraper" pattern, in which there is a central section (across
// the depth) with higher values than the surround. // the depth) with higher values than the surround.
template <typename T> template <typename T>

View File

@ -25,8 +25,8 @@ using ::testing::ElementsAreArray;
using uint8 = std::uint8_t; using uint8 = std::uint8_t;
enum class TestType { enum class TestType {
CONST = 0, kConst = 0,
DYNAMIC = 1, kDynamic = 1,
}; };
class ResizeBilinearOpModel : public SingleOpModel { class ResizeBilinearOpModel : public SingleOpModel {
@ -35,7 +35,7 @@ class ResizeBilinearOpModel : public SingleOpModel {
std::initializer_list<int> size_data, std::initializer_list<int> size_data,
TestType test_type, TestType test_type,
bool half_pixel_centers = false) { bool half_pixel_centers = false) {
bool const_size = (test_type == TestType::CONST); bool const_size = (test_type == TestType::kConst);
input_ = AddInput(input); input_ = AddInput(input);
if (const_size) { if (const_size) {
@ -332,7 +332,7 @@ TEST_P(ResizeBilinearOpTest, ThreeDimensionalResizeInt8) {
} }
INSTANTIATE_TEST_SUITE_P(ResizeBilinearOpTest, ResizeBilinearOpTest, INSTANTIATE_TEST_SUITE_P(ResizeBilinearOpTest, ResizeBilinearOpTest,
testing::Values(TestType::CONST, TestType::DYNAMIC)); testing::Values(TestType::kConst, TestType::kDynamic));
} // namespace } // namespace
} // namespace tflite } // namespace tflite

View File

@ -25,8 +25,8 @@ using ::testing::ElementsAreArray;
using uint8 = std::uint8_t; using uint8 = std::uint8_t;
enum class TestType { enum class TestType {
CONST = 0, kConst = 0,
DYNAMIC = 1, kDynamic = 1,
}; };
class ResizeNearestNeighborOpModel : public SingleOpModel { class ResizeNearestNeighborOpModel : public SingleOpModel {
@ -34,7 +34,7 @@ class ResizeNearestNeighborOpModel : public SingleOpModel {
explicit ResizeNearestNeighborOpModel(const TensorData& input, explicit ResizeNearestNeighborOpModel(const TensorData& input,
std::initializer_list<int> size_data, std::initializer_list<int> size_data,
TestType test_type) { TestType test_type) {
bool const_size = (test_type == TestType::CONST); bool const_size = (test_type == TestType::kConst);
input_ = AddInput(input); input_ = AddInput(input);
if (const_size) { if (const_size) {
@ -264,7 +264,7 @@ TEST_P(ResizeNearestNeighborOpTest, ThreeDimensionalResizeInt8) {
} }
INSTANTIATE_TEST_SUITE_P(ResizeNearestNeighborOpTest, INSTANTIATE_TEST_SUITE_P(ResizeNearestNeighborOpTest,
ResizeNearestNeighborOpTest, ResizeNearestNeighborOpTest,
testing::Values(TestType::CONST, TestType::DYNAMIC)); testing::Values(TestType::kConst, TestType::kDynamic));
} // namespace } // namespace
} // namespace tflite } // namespace tflite

View File

@ -24,8 +24,8 @@ namespace {
using ::testing::ElementsAreArray; using ::testing::ElementsAreArray;
enum class TestType { enum class TestType {
CONST = 0, kConst = 0,
DYNAMIC = 1, kDynamic = 1,
}; };
template <typename input_type, typename index_type> template <typename input_type, typename index_type>
@ -39,7 +39,7 @@ class SliceOpModel : public SingleOpModel {
TensorType tensor_index_type, TensorType tensor_input_type, TensorType tensor_index_type, TensorType tensor_input_type,
TestType input_tensor_types) { TestType input_tensor_types) {
input_ = AddInput(tensor_input_type); input_ = AddInput(tensor_input_type);
if (input_tensor_types == TestType::DYNAMIC) { if (input_tensor_types == TestType::kDynamic) {
begin_ = AddInput(tensor_index_type); begin_ = AddInput(tensor_index_type);
size_ = AddInput(tensor_index_type); size_ = AddInput(tensor_index_type);
} else { } else {
@ -52,7 +52,7 @@ class SliceOpModel : public SingleOpModel {
CreateSliceOptions(builder_).Union()); CreateSliceOptions(builder_).Union());
BuildInterpreter({input_shape, begin_shape, size_shape}); BuildInterpreter({input_shape, begin_shape, size_shape});
if (input_tensor_types == TestType::DYNAMIC) { if (input_tensor_types == TestType::kDynamic) {
PopulateTensor<index_type>(begin_, begin_data); PopulateTensor<index_type>(begin_, begin_data);
PopulateTensor<index_type>(size_, size_data); PopulateTensor<index_type>(size_, size_data);
} }
@ -239,7 +239,8 @@ TEST_P(SliceOpTest, SliceString) {
} }
INSTANTIATE_TEST_SUITE_P(SliceOpTest, SliceOpTest, INSTANTIATE_TEST_SUITE_P(SliceOpTest, SliceOpTest,
::testing::Values(TestType::CONST, TestType::DYNAMIC)); ::testing::Values(TestType::kConst,
TestType::kDynamic));
} // namespace } // namespace
} // namespace tflite } // namespace tflite

View File

@ -26,8 +26,8 @@ using ::testing::ElementsAreArray;
constexpr int kAxisIsATensor = -1000; constexpr int kAxisIsATensor = -1000;
enum class TestType { enum class TestType {
CONST = 0, kConst = 0,
DYNAMIC = 1, kDynamic = 1,
}; };
class SplitOpModel : public SingleOpModel { class SplitOpModel : public SingleOpModel {
@ -83,7 +83,7 @@ void Check(TestType test_type, int axis, int num_splits,
<< " and num_splits=" << num_splits; << " and num_splits=" << num_splits;
return ss.str(); return ss.str();
}; };
if (test_type == TestType::DYNAMIC) { if (test_type == TestType::kDynamic) {
SplitOpModel m({type, input_shape}, num_splits); SplitOpModel m({type, input_shape}, num_splits);
m.SetInput(input_data); m.SetInput(input_data);
m.SetAxis(axis); m.SetAxis(axis);
@ -110,18 +110,18 @@ void Check(TestType test_type, int axis, int num_splits,
template <typename T> template <typename T>
class SplitOpTest : public ::testing::Test { class SplitOpTest : public ::testing::Test {
public: public:
static std::vector<TestType> _range_; static std::vector<TestType> range_;
}; };
template <> template <>
std::vector<TestType> SplitOpTest<TestType>::_range_{TestType::CONST, std::vector<TestType> SplitOpTest<TestType>::range_{TestType::kConst,
TestType::DYNAMIC}; TestType::kDynamic};
using DataTypes = ::testing::Types<float, int8_t, int16_t>; using DataTypes = ::testing::Types<float, int8_t, int16_t>;
TYPED_TEST_SUITE(SplitOpTest, DataTypes); TYPED_TEST_SUITE(SplitOpTest, DataTypes);
TYPED_TEST(SplitOpTest, FourDimensional) { TYPED_TEST(SplitOpTest, FourDimensional) {
for (TestType test_type : SplitOpTest<TestType>::_range_) { for (TestType test_type : SplitOpTest<TestType>::range_) {
Check<TypeParam>(/*axis_as_tensor*/ test_type, Check<TypeParam>(/*axis_as_tensor*/ test_type,
/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2}, /*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
@ -158,7 +158,7 @@ TYPED_TEST(SplitOpTest, FourDimensional) {
} }
TYPED_TEST(SplitOpTest, FourDimensionalInt8) { TYPED_TEST(SplitOpTest, FourDimensionalInt8) {
for (TestType test_type : SplitOpTest<TestType>::_range_) { for (TestType test_type : SplitOpTest<TestType>::range_) {
Check<TypeParam>(/*axis_as_tensor*/ test_type, Check<TypeParam>(/*axis_as_tensor*/ test_type,
/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2}, /*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
@ -195,7 +195,7 @@ TYPED_TEST(SplitOpTest, FourDimensionalInt8) {
} }
TYPED_TEST(SplitOpTest, FourDimensionalInt32) { TYPED_TEST(SplitOpTest, FourDimensionalInt32) {
for (TestType test_type : SplitOpTest<TestType>::_range_) { for (TestType test_type : SplitOpTest<TestType>::range_) {
Check<TypeParam>(/*axis_as_tensor*/ test_type, Check<TypeParam>(/*axis_as_tensor*/ test_type,
/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2}, /*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
@ -232,7 +232,7 @@ TYPED_TEST(SplitOpTest, FourDimensionalInt32) {
} }
TYPED_TEST(SplitOpTest, OneDimensional) { TYPED_TEST(SplitOpTest, OneDimensional) {
for (TestType test_type : SplitOpTest<TestType>::_range_) { for (TestType test_type : SplitOpTest<TestType>::range_) {
Check<TypeParam>( Check<TypeParam>(
/*axis_as_tensor*/ test_type, /*axis_as_tensor*/ test_type,
/*axis=*/0, /*num_splits=*/8, {8}, {1}, {1, 2, 3, 4, 5, 6, 7, 8}, /*axis=*/0, /*num_splits=*/8, {8}, {1}, {1, 2, 3, 4, 5, 6, 7, 8},
@ -241,7 +241,7 @@ TYPED_TEST(SplitOpTest, OneDimensional) {
} }
TYPED_TEST(SplitOpTest, NegativeAxis) { TYPED_TEST(SplitOpTest, NegativeAxis) {
for (TestType test_type : SplitOpTest<TestType>::_range_) { for (TestType test_type : SplitOpTest<TestType>::range_) {
Check<TypeParam>(/*axis_as_tensor*/ test_type, Check<TypeParam>(/*axis_as_tensor*/ test_type,
/*axis=*/-4, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2}, /*axis=*/-4, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},

View File

@ -26,8 +26,8 @@ namespace {
using ::testing::ElementsAreArray; using ::testing::ElementsAreArray;
enum class TestType { enum class TestType {
CONST = 0, kConst = 0,
DYNAMIC = 1, kDynamic = 1,
}; };
template <typename InputType> template <typename InputType>
@ -36,7 +36,7 @@ class TopKV2OpModel : public SingleOpModel {
TopKV2OpModel(int top_k, std::initializer_list<int> input_shape, TopKV2OpModel(int top_k, std::initializer_list<int> input_shape,
std::initializer_list<InputType> input_data, std::initializer_list<InputType> input_data,
TestType input_tensor_types) { TestType input_tensor_types) {
if (input_tensor_types == TestType::DYNAMIC) { if (input_tensor_types == TestType::kDynamic) {
input_ = AddInput(GetTensorType<InputType>()); input_ = AddInput(GetTensorType<InputType>());
top_k_ = AddInput(TensorType_INT32); top_k_ = AddInput(TensorType_INT32);
} else { } else {
@ -49,7 +49,7 @@ class TopKV2OpModel : public SingleOpModel {
SetBuiltinOp(BuiltinOperator_TOPK_V2, BuiltinOptions_TopKV2Options, 0); SetBuiltinOp(BuiltinOperator_TOPK_V2, BuiltinOptions_TopKV2Options, 0);
BuildInterpreter({input_shape, {1}}); BuildInterpreter({input_shape, {1}});
if (input_tensor_types == TestType::DYNAMIC) { if (input_tensor_types == TestType::kDynamic) {
PopulateTensor<InputType>(input_, input_data); PopulateTensor<InputType>(input_, input_data);
PopulateTensor<int32_t>(top_k_, {top_k}); PopulateTensor<int32_t>(top_k_, {top_k});
} }
@ -119,7 +119,8 @@ TEST_P(TopKV2OpTest, TypeInt32) {
} }
INSTANTIATE_TEST_SUITE_P(TopKV2OpTest, TopKV2OpTest, INSTANTIATE_TEST_SUITE_P(TopKV2OpTest, TopKV2OpTest,
::testing::Values(TestType::CONST, TestType::DYNAMIC)); ::testing::Values(TestType::kConst,
TestType::kDynamic));
// Check that uint8_t works. // Check that uint8_t works.
TEST_P(TopKV2OpTest, TypeUint8) { TEST_P(TopKV2OpTest, TypeUint8) {

View File

@ -37,8 +37,8 @@ namespace {
using ::testing::ElementsAreArray; using ::testing::ElementsAreArray;
enum class TestType { enum class TestType {
CONST = 0, kConst = 0,
DYNAMIC = 1, kDynamic = 1,
}; };
template <typename InputType> template <typename InputType>
@ -54,7 +54,7 @@ class BaseTransposeConvOpModel : public SingleOpModel {
// Just to be confusing, transpose_conv has an _input_ named "output_shape" // Just to be confusing, transpose_conv has an _input_ named "output_shape"
// that sets the shape of the output tensor of the op :). It must always be // that sets the shape of the output tensor of the op :). It must always be
// an int32 1D four element tensor. // an int32 1D four element tensor.
if (test_type == TestType::DYNAMIC) { if (test_type == TestType::kDynamic) {
output_shape_ = AddInput({TensorType_INT32, {4}}); output_shape_ = AddInput({TensorType_INT32, {4}});
filter_ = AddInput(filter); filter_ = AddInput(filter);
} else { } else {
@ -74,7 +74,7 @@ class BaseTransposeConvOpModel : public SingleOpModel {
BuildInterpreter( BuildInterpreter(
{GetShape(output_shape_), GetShape(filter_), GetShape(input_)}); {GetShape(output_shape_), GetShape(filter_), GetShape(input_)});
if (test_type == TestType::DYNAMIC) { if (test_type == TestType::kDynamic) {
PopulateTensor<int32_t>(output_shape_, output_shape_data); PopulateTensor<int32_t>(output_shape_, output_shape_data);
PopulateTensor<InputType>(filter_, filter_data); PopulateTensor<InputType>(filter_, filter_data);
} }
@ -445,7 +445,7 @@ INSTANTIATE_TEST_SUITE_P(
TransposeConvOpTest, TransposeConvOpTest, TransposeConvOpTest, TransposeConvOpTest,
::testing::Combine( ::testing::Combine(
::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)), ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)),
::testing::Values(TestType::CONST, TestType::DYNAMIC))); ::testing::Values(TestType::kConst, TestType::kDynamic)));
} // namespace } // namespace
} // namespace tflite } // namespace tflite

View File

@ -161,7 +161,7 @@ Flag CreateFlag(const char* name, BenchmarkParams* params,
const std::string& usage) { const std::string& usage) {
return Flag( return Flag(
name, [params, name](const T& val) { params->Set<T>(name, val); }, name, [params, name](const T& val) { params->Set<T>(name, val); },
params->Get<T>(name), usage, Flag::OPTIONAL); params->Get<T>(name), usage, Flag::kOptional);
} }
// Benchmarks a model. // Benchmarks a model.

View File

@ -58,7 +58,7 @@ class DelegateProvider {
const std::string& usage) const { const std::string& usage) const {
return Flag( return Flag(
name, [params, name](const T& val) { params->Set<T>(name, val); }, name, [params, name](const T& val) { params->Set<T>(name, val); },
default_params_.Get<T>(name), usage, Flag::OPTIONAL); default_params_.Get<T>(name), usage, Flag::kOptional);
} }
BenchmarkParams default_params_; BenchmarkParams default_params_;
}; };

View File

@ -142,7 +142,7 @@ Flag::Flag(const char* name,
flag_type_(flag_type) {} flag_type_(flag_type) {}
bool Flag::Parse(const std::string& arg, bool* value_parsing_ok) const { bool Flag::Parse(const std::string& arg, bool* value_parsing_ok) const {
return ParseFlag(arg, name_, flag_type_ == POSITIONAL, value_hook_, return ParseFlag(arg, name_, flag_type_ == kPositional, value_hook_,
value_parsing_ok); value_parsing_ok);
} }
@ -195,7 +195,7 @@ std::string Flag::GetTypeName() const {
result = false; result = false;
} }
continue; continue;
} else if (flag.flag_type_ == Flag::REQUIRED) { } else if (flag.flag_type_ == Flag::kRequired) {
TFLITE_LOG(ERROR) << "Required flag not provided: " << flag.name_; TFLITE_LOG(ERROR) << "Required flag not provided: " << flag.name_;
// If the required flag isn't found, we immediately stop the whole flag // If the required flag isn't found, we immediately stop the whole flag
// parsing. // parsing.
@ -205,7 +205,7 @@ std::string Flag::GetTypeName() const {
} }
// Parses positional flags. // Parses positional flags.
if (flag.flag_type_ == Flag::POSITIONAL) { if (flag.flag_type_ == Flag::kPositional) {
if (++positional_count >= *argc) { if (++positional_count >= *argc) {
TFLITE_LOG(ERROR) << "Too few command line arguments."; TFLITE_LOG(ERROR) << "Too few command line arguments.";
return false; return false;
@ -245,7 +245,7 @@ std::string Flag::GetTypeName() const {
// The flag isn't found, do some bookkeeping work. // The flag isn't found, do some bookkeeping work.
processed_flags[flag.name_] = -1; processed_flags[flag.name_] = -1;
if (flag.flag_type_ == Flag::REQUIRED) { if (flag.flag_type_ == Flag::kRequired) {
TFLITE_LOG(ERROR) << "Required flag not provided: " << flag.name_; TFLITE_LOG(ERROR) << "Required flag not provided: " << flag.name_;
result = false; result = false;
// If the required flag isn't found, we immediately stop the whole flag // If the required flag isn't found, we immediately stop the whole flag
@ -280,7 +280,7 @@ std::string Flag::GetTypeName() const {
// Prints usage for positional flag. // Prints usage for positional flag.
for (int i = 0; i < sorted_idx.size(); ++i) { for (int i = 0; i < sorted_idx.size(); ++i) {
const Flag& flag = flag_list[sorted_idx[i]]; const Flag& flag = flag_list[sorted_idx[i]];
if (flag.flag_type_ == Flag::POSITIONAL) { if (flag.flag_type_ == Flag::kPositional) {
positional_count++; positional_count++;
usage_text << " <" << flag.name_ << ">"; usage_text << " <" << flag.name_ << ">";
} else { } else {
@ -295,7 +295,7 @@ std::string Flag::GetTypeName() const {
std::vector<std::string> name_column(flag_list.size()); std::vector<std::string> name_column(flag_list.size());
for (int i = 0; i < sorted_idx.size(); ++i) { for (int i = 0; i < sorted_idx.size(); ++i) {
const Flag& flag = flag_list[sorted_idx[i]]; const Flag& flag = flag_list[sorted_idx[i]];
if (flag.flag_type_ != Flag::POSITIONAL) { if (flag.flag_type_ != Flag::kPositional) {
name_column[i] += "--"; name_column[i] += "--";
name_column[i] += flag.name_; name_column[i] += flag.name_;
name_column[i] += "="; name_column[i] += "=";
@ -320,7 +320,8 @@ std::string Flag::GetTypeName() const {
usage_text << "\t"; usage_text << "\t";
usage_text << std::left << std::setw(max_name_width) << name_column[i]; usage_text << std::left << std::setw(max_name_width) << name_column[i];
usage_text << "\t" << type_name << "\t"; usage_text << "\t" << type_name << "\t";
usage_text << (flag.flag_type_ != Flag::OPTIONAL ? "required" : "optional"); usage_text << (flag.flag_type_ != Flag::kOptional ? "required"
: "optional");
usage_text << "\t" << flag.usage_text_ << "\n"; usage_text << "\t" << flag.usage_text_ << "\n";
} }
return usage_text.str(); return usage_text.str();

View File

@ -65,16 +65,16 @@ namespace tflite {
class Flag { class Flag {
public: public:
enum FlagType { enum FlagType {
POSITIONAL = 0, kPositional = 0,
REQUIRED, kRequired,
OPTIONAL, kOptional,
}; };
// The order of the positional flags is the same as they are added. // The order of the positional flags is the same as they are added.
// Positional flags are supposed to be required. // Positional flags are supposed to be required.
template <typename T> template <typename T>
static Flag CreateFlag(const char* name, T* val, const char* usage, static Flag CreateFlag(const char* name, T* val, const char* usage,
FlagType flag_type = OPTIONAL) { FlagType flag_type = kOptional) {
return Flag( return Flag(
name, [val](const T& v) { *val = v; }, *val, usage, flag_type); name, [val](const T& v) { *val = v; }, *val, usage, flag_type);
} }

View File

@ -55,8 +55,10 @@ TEST(CommandLineFlagsTest, BasicUsage) {
Flag::CreateFlag("some_numeric_bool", &some_numeric_bool, Flag::CreateFlag("some_numeric_bool", &some_numeric_bool,
"some numeric bool"), "some numeric bool"),
Flag::CreateFlag("some_int1", &some_int1, "some int"), Flag::CreateFlag("some_int1", &some_int1, "some int"),
Flag::CreateFlag("some_int2", &some_int2, "some int", Flag::REQUIRED), Flag::CreateFlag("some_int2", &some_int2, "some int",
Flag::CreateFlag("float_1", &float_1, "some float", Flag::POSITIONAL), Flag::kRequired),
Flag::CreateFlag("float_1", &float_1, "some float",
Flag::kPositional),
}); });
EXPECT_TRUE(parsed_ok); EXPECT_TRUE(parsed_ok);
@ -131,7 +133,7 @@ TEST(CommandLineFlagsTest, RequiredFlagNotFound) {
const char* argv_strings[] = {"program_name", "--flag=12"}; const char* argv_strings[] = {"program_name", "--flag=12"};
bool parsed_ok = Flags::Parse( bool parsed_ok = Flags::Parse(
&argc, reinterpret_cast<const char**>(argv_strings), &argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_flag", &some_float, "", Flag::REQUIRED)}); {Flag::CreateFlag("some_flag", &some_float, "", Flag::kRequired)});
EXPECT_FALSE(parsed_ok); EXPECT_FALSE(parsed_ok);
EXPECT_NEAR(-23.23f, some_float, 1e-5f); EXPECT_NEAR(-23.23f, some_float, 1e-5f);
@ -144,7 +146,7 @@ TEST(CommandLineFlagsTest, NoArguments) {
const char* argv_strings[] = {"program_name"}; const char* argv_strings[] = {"program_name"};
bool parsed_ok = Flags::Parse( bool parsed_ok = Flags::Parse(
&argc, reinterpret_cast<const char**>(argv_strings), &argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_flag", &some_float, "", Flag::REQUIRED)}); {Flag::CreateFlag("some_flag", &some_float, "", Flag::kRequired)});
EXPECT_FALSE(parsed_ok); EXPECT_FALSE(parsed_ok);
EXPECT_NEAR(-23.23f, some_float, 1e-5f); EXPECT_NEAR(-23.23f, some_float, 1e-5f);
@ -157,7 +159,7 @@ TEST(CommandLineFlagsTest, NotEnoughArguments) {
const char* argv_strings[] = {"program_name"}; const char* argv_strings[] = {"program_name"};
bool parsed_ok = Flags::Parse( bool parsed_ok = Flags::Parse(
&argc, reinterpret_cast<const char**>(argv_strings), &argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_flag", &some_float, "", Flag::POSITIONAL)}); {Flag::CreateFlag("some_flag", &some_float, "", Flag::kPositional)});
EXPECT_FALSE(parsed_ok); EXPECT_FALSE(parsed_ok);
EXPECT_NEAR(-23.23f, some_float, 1e-5f); EXPECT_NEAR(-23.23f, some_float, 1e-5f);
@ -170,7 +172,7 @@ TEST(CommandLineFlagsTest, PositionalFlagFailed) {
const char* argv_strings[] = {"program_name", "string"}; const char* argv_strings[] = {"program_name", "string"};
bool parsed_ok = Flags::Parse( bool parsed_ok = Flags::Parse(
&argc, reinterpret_cast<const char**>(argv_strings), &argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_flag", &some_float, "", Flag::POSITIONAL)}); {Flag::CreateFlag("some_flag", &some_float, "", Flag::kPositional)});
EXPECT_FALSE(parsed_ok); EXPECT_FALSE(parsed_ok);
EXPECT_NEAR(-23.23f, some_float, 1e-5f); EXPECT_NEAR(-23.23f, some_float, 1e-5f);
@ -213,9 +215,9 @@ TEST(CommandLineFlagsTest, UsageString) {
{Flag::CreateFlag("some_int", &some_int, "some int"), {Flag::CreateFlag("some_int", &some_int, "some int"),
Flag::CreateFlag("some_int64", &some_int64, "some int64"), Flag::CreateFlag("some_int64", &some_int64, "some int64"),
Flag::CreateFlag("some_switch", &some_switch, "some switch"), Flag::CreateFlag("some_switch", &some_switch, "some switch"),
Flag::CreateFlag("some_name", &some_name, "some name", Flag::REQUIRED), Flag::CreateFlag("some_name", &some_name, "some name", Flag::kRequired),
Flag::CreateFlag("some_int2", &some_int2, "some int", Flag::CreateFlag("some_int2", &some_int2, "some int",
Flag::POSITIONAL)}); Flag::kPositional)});
// Match the usage message, being sloppy about whitespace. // Match the usage message, being sloppy about whitespace.
const char* expected_usage = const char* expected_usage =
" usage: some_tool_name <some_int2> <flags>\n" " usage: some_tool_name <some_int2> <flags>\n"
@ -307,8 +309,8 @@ TEST(CommandLineFlagsTest, DuplicateFlagsNotFound) {
const char* argv_strings[] = {"program_name", "--some_float=1.0"}; const char* argv_strings[] = {"program_name", "--some_float=1.0"};
bool parsed_ok = Flags::Parse( bool parsed_ok = Flags::Parse(
&argc, reinterpret_cast<const char**>(argv_strings), &argc, reinterpret_cast<const char**>(argv_strings),
{Flag::CreateFlag("some_int", &some_int1, "some int1", Flag::OPTIONAL), {Flag::CreateFlag("some_int", &some_int1, "some int1", Flag::kOptional),
Flag::CreateFlag("some_int", &some_int2, "some int2", Flag::REQUIRED)}); Flag::CreateFlag("some_int", &some_int2, "some int2", Flag::kRequired)});
EXPECT_FALSE(parsed_ok); EXPECT_FALSE(parsed_ok);
EXPECT_EQ(-23, some_int1); EXPECT_EQ(-23, some_int1);