Fix msvc execution of TFLite C/Kernel tests
PiperOrigin-RevId: 307488714 Change-Id: Ia68e3479e5c8a8f6ac2222264638d2980e92127c
This commit is contained in:
parent
de6c0ec676
commit
5a674e06a9
@ -2065,7 +2065,13 @@ cc_library(
|
||||
"//tensorflow/core/platform/default:logging.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
linkopts = ["-ldl"],
|
||||
linkopts = select({
|
||||
"//tensorflow:freebsd": [],
|
||||
"//tensorflow:windows": [],
|
||||
"//conditions:default": [
|
||||
"-ldl",
|
||||
],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":platform_base",
|
||||
|
@ -87,6 +87,7 @@ cc_test(
|
||||
name = "c_api_test",
|
||||
size = "small",
|
||||
srcs = ["c_api_test.cc"],
|
||||
copts = tflite_copts(),
|
||||
data = [
|
||||
"//tensorflow/lite:testdata/add.bin",
|
||||
"//tensorflow/lite:testdata/add_quantized.bin",
|
||||
@ -103,6 +104,7 @@ cc_test(
|
||||
name = "c_api_experimental_test",
|
||||
size = "small",
|
||||
srcs = ["c_api_experimental_test.cc"],
|
||||
copts = tflite_copts(),
|
||||
data = ["//tensorflow/lite:testdata/add.bin"],
|
||||
deps = [
|
||||
":c_api",
|
||||
|
@ -25,11 +25,10 @@ namespace {
|
||||
|
||||
TfLiteRegistration* GetDummyRegistration() {
|
||||
static TfLiteRegistration registration = {
|
||||
.init = nullptr,
|
||||
.free = nullptr,
|
||||
.prepare = nullptr,
|
||||
.invoke = [](TfLiteContext*, TfLiteNode*) { return kTfLiteOk; },
|
||||
};
|
||||
/*init=*/nullptr,
|
||||
/*free=*/nullptr,
|
||||
/*prepare=*/nullptr,
|
||||
/*invoke=*/[](TfLiteContext*, TfLiteNode*) { return kTfLiteOk; }};
|
||||
return ®istration;
|
||||
}
|
||||
|
||||
|
@ -26,8 +26,8 @@ namespace {
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
enum class TestType {
|
||||
CONST = 0,
|
||||
DYNAMIC = 1,
|
||||
kConst = 0,
|
||||
kDynamic = 1,
|
||||
};
|
||||
|
||||
template <typename InputType>
|
||||
@ -36,7 +36,7 @@ class ExpandDimsOpModel : public SingleOpModel {
|
||||
ExpandDimsOpModel(int axis, std::initializer_list<int> input_shape,
|
||||
std::initializer_list<InputType> input_data,
|
||||
TestType input_tensor_types) {
|
||||
if (input_tensor_types == TestType::DYNAMIC) {
|
||||
if (input_tensor_types == TestType::kDynamic) {
|
||||
input_ = AddInput(GetTensorType<InputType>());
|
||||
axis_ = AddInput(TensorType_INT32);
|
||||
} else {
|
||||
@ -50,7 +50,7 @@ class ExpandDimsOpModel : public SingleOpModel {
|
||||
|
||||
BuildInterpreter({input_shape, {1}});
|
||||
|
||||
if (input_tensor_types == TestType::DYNAMIC) {
|
||||
if (input_tensor_types == TestType::kDynamic) {
|
||||
PopulateTensor<InputType>(input_, input_data);
|
||||
PopulateTensor<int32_t>(axis_, {axis});
|
||||
}
|
||||
@ -69,18 +69,18 @@ class ExpandDimsOpModel : public SingleOpModel {
|
||||
template <typename T>
|
||||
class ExpandDimsOpTest : public ::testing::Test {
|
||||
public:
|
||||
static std::vector<TestType> _range_;
|
||||
static std::vector<TestType> range_;
|
||||
};
|
||||
|
||||
template <>
|
||||
std::vector<TestType> ExpandDimsOpTest<TestType>::_range_{TestType::CONST,
|
||||
TestType::DYNAMIC};
|
||||
std::vector<TestType> ExpandDimsOpTest<TestType>::range_{TestType::kConst,
|
||||
TestType::kDynamic};
|
||||
|
||||
using DataTypes = ::testing::Types<float, int8_t, int16_t, int32_t>;
|
||||
TYPED_TEST_SUITE(ExpandDimsOpTest, DataTypes);
|
||||
|
||||
TYPED_TEST(ExpandDimsOpTest, PositiveAxis) {
|
||||
for (TestType test_type : ExpandDimsOpTest<TestType>::_range_) {
|
||||
for (TestType test_type : ExpandDimsOpTest<TestType>::range_) {
|
||||
std::initializer_list<TypeParam> values = {-1, 1, -2, 2};
|
||||
|
||||
ExpandDimsOpModel<TypeParam> axis_0(0, {2, 2}, values, test_type);
|
||||
@ -101,7 +101,7 @@ TYPED_TEST(ExpandDimsOpTest, PositiveAxis) {
|
||||
}
|
||||
|
||||
TYPED_TEST(ExpandDimsOpTest, NegativeAxis) {
|
||||
for (TestType test_type : ExpandDimsOpTest<TestType>::_range_) {
|
||||
for (TestType test_type : ExpandDimsOpTest<TestType>::range_) {
|
||||
std::initializer_list<TypeParam> values = {-1, 1, -2, 2};
|
||||
|
||||
ExpandDimsOpModel<TypeParam> m(-1, {2, 2}, values, test_type);
|
||||
@ -115,7 +115,7 @@ TEST(ExpandDimsOpTest, StrTensor) {
|
||||
std::initializer_list<std::string> values = {"abc", "de", "fghi"};
|
||||
|
||||
// this test will fail on TestType::CONST
|
||||
ExpandDimsOpModel<std::string> m(0, {3}, values, TestType::DYNAMIC);
|
||||
ExpandDimsOpModel<std::string> m(0, {3}, values, TestType::kDynamic);
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetValues(), ElementsAreArray(values));
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3}));
|
||||
|
@ -713,11 +713,13 @@ void SimpleTestQuantizedInt16OutputCase(
|
||||
/*activation_func=*/ActivationFunctionType_NONE, weights_format);
|
||||
|
||||
std::mt19937 random_engine;
|
||||
std::uniform_int_distribution<uint8_t> weights_dist;
|
||||
// Some compilers don't support uint8_t for uniform_distribution.
|
||||
std::uniform_int_distribution<uint32_t> weights_dist(
|
||||
0, std::numeric_limits<uint8_t>::max());
|
||||
|
||||
std::vector<float> weights_data(input_depth * output_depth);
|
||||
for (auto& w : weights_data) {
|
||||
uint8_t q = weights_dist(random_engine);
|
||||
uint8_t q = static_cast<uint8_t>(weights_dist(random_engine));
|
||||
w = (q - kWeightsZeroPoint) * kWeightsScale;
|
||||
}
|
||||
|
||||
@ -739,10 +741,12 @@ void SimpleTestQuantizedInt16OutputCase(
|
||||
LOG(FATAL) << "Unhandled weights format";
|
||||
}
|
||||
|
||||
std::uniform_int_distribution<uint8_t> input_dist;
|
||||
// Some compilers don't support uint8_t for uniform_distribution.
|
||||
std::uniform_int_distribution<uint32_t> input_dist(
|
||||
0, std::numeric_limits<uint8_t>::max());
|
||||
std::vector<float> input_data(input_depth * batches);
|
||||
for (auto& i : input_data) {
|
||||
uint8_t q = input_dist(random_engine);
|
||||
uint8_t q = static_cast<uint8_t>(input_dist(random_engine));
|
||||
i = (q - kInputZeroPoint) * kInputScale;
|
||||
}
|
||||
|
||||
|
@ -105,6 +105,7 @@ float ExponentialRandomPositiveFloat(float percentile, float percentile_val,
|
||||
|
||||
void FillRandom(std::vector<float>* vec, float min, float max) {
|
||||
std::uniform_real_distribution<float> dist(min, max);
|
||||
// TODO(b/154540105): use std::ref to avoid copying the random engine.
|
||||
auto gen = std::bind(dist, RandomEngine());
|
||||
std::generate(std::begin(*vec), std::end(*vec), gen);
|
||||
}
|
||||
|
@ -59,12 +59,22 @@ float ExponentialRandomPositiveFloat(float percentile, float percentile_val,
|
||||
// Fills a vector with random floats between |min| and |max|.
|
||||
void FillRandom(std::vector<float>* vec, float min, float max);
|
||||
|
||||
template <typename T>
|
||||
void FillRandom(typename std::vector<T>::iterator begin_it,
|
||||
typename std::vector<T>::iterator end_it, T min, T max) {
|
||||
// Workaround for compilers that don't support (u)int8_t uniform_distribution.
|
||||
typedef typename std::conditional<sizeof(T) >= sizeof(int16_t), T,
|
||||
std::int16_t>::type rand_type;
|
||||
std::uniform_int_distribution<rand_type> dist(min, max);
|
||||
// TODO(b/154540105): use std::ref to avoid copying the random engine.
|
||||
auto gen = std::bind(dist, RandomEngine());
|
||||
std::generate(begin_it, end_it, [&gen] { return static_cast<T>(gen()); });
|
||||
}
|
||||
|
||||
// Fills a vector with random numbers between |min| and |max|.
|
||||
template <typename T>
|
||||
void FillRandom(std::vector<T>* vec, T min, T max) {
|
||||
std::uniform_int_distribution<T> dist(min, max);
|
||||
auto gen = std::bind(dist, RandomEngine());
|
||||
std::generate(std::begin(*vec), std::end(*vec), gen);
|
||||
return FillRandom(std::begin(*vec), std::end(*vec), min, max);
|
||||
}
|
||||
|
||||
// Fills a vector with random numbers.
|
||||
@ -73,14 +83,6 @@ void FillRandom(std::vector<T>* vec) {
|
||||
FillRandom(vec, std::numeric_limits<T>::min(), std::numeric_limits<T>::max());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void FillRandom(typename std::vector<T>::iterator begin_it,
|
||||
typename std::vector<T>::iterator end_it, T min, T max) {
|
||||
std::uniform_int_distribution<T> dist(min, max);
|
||||
auto gen = std::bind(dist, RandomEngine());
|
||||
std::generate(begin_it, end_it, gen);
|
||||
}
|
||||
|
||||
// Fill with a "skyscraper" pattern, in which there is a central section (across
|
||||
// the depth) with higher values than the surround.
|
||||
template <typename T>
|
||||
|
@ -25,8 +25,8 @@ using ::testing::ElementsAreArray;
|
||||
using uint8 = std::uint8_t;
|
||||
|
||||
enum class TestType {
|
||||
CONST = 0,
|
||||
DYNAMIC = 1,
|
||||
kConst = 0,
|
||||
kDynamic = 1,
|
||||
};
|
||||
|
||||
class ResizeBilinearOpModel : public SingleOpModel {
|
||||
@ -35,7 +35,7 @@ class ResizeBilinearOpModel : public SingleOpModel {
|
||||
std::initializer_list<int> size_data,
|
||||
TestType test_type,
|
||||
bool half_pixel_centers = false) {
|
||||
bool const_size = (test_type == TestType::CONST);
|
||||
bool const_size = (test_type == TestType::kConst);
|
||||
|
||||
input_ = AddInput(input);
|
||||
if (const_size) {
|
||||
@ -332,7 +332,7 @@ TEST_P(ResizeBilinearOpTest, ThreeDimensionalResizeInt8) {
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(ResizeBilinearOpTest, ResizeBilinearOpTest,
|
||||
testing::Values(TestType::CONST, TestType::DYNAMIC));
|
||||
testing::Values(TestType::kConst, TestType::kDynamic));
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
@ -25,8 +25,8 @@ using ::testing::ElementsAreArray;
|
||||
using uint8 = std::uint8_t;
|
||||
|
||||
enum class TestType {
|
||||
CONST = 0,
|
||||
DYNAMIC = 1,
|
||||
kConst = 0,
|
||||
kDynamic = 1,
|
||||
};
|
||||
|
||||
class ResizeNearestNeighborOpModel : public SingleOpModel {
|
||||
@ -34,7 +34,7 @@ class ResizeNearestNeighborOpModel : public SingleOpModel {
|
||||
explicit ResizeNearestNeighborOpModel(const TensorData& input,
|
||||
std::initializer_list<int> size_data,
|
||||
TestType test_type) {
|
||||
bool const_size = (test_type == TestType::CONST);
|
||||
bool const_size = (test_type == TestType::kConst);
|
||||
|
||||
input_ = AddInput(input);
|
||||
if (const_size) {
|
||||
@ -264,7 +264,7 @@ TEST_P(ResizeNearestNeighborOpTest, ThreeDimensionalResizeInt8) {
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(ResizeNearestNeighborOpTest,
|
||||
ResizeNearestNeighborOpTest,
|
||||
testing::Values(TestType::CONST, TestType::DYNAMIC));
|
||||
testing::Values(TestType::kConst, TestType::kDynamic));
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
@ -24,8 +24,8 @@ namespace {
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
enum class TestType {
|
||||
CONST = 0,
|
||||
DYNAMIC = 1,
|
||||
kConst = 0,
|
||||
kDynamic = 1,
|
||||
};
|
||||
|
||||
template <typename input_type, typename index_type>
|
||||
@ -39,7 +39,7 @@ class SliceOpModel : public SingleOpModel {
|
||||
TensorType tensor_index_type, TensorType tensor_input_type,
|
||||
TestType input_tensor_types) {
|
||||
input_ = AddInput(tensor_input_type);
|
||||
if (input_tensor_types == TestType::DYNAMIC) {
|
||||
if (input_tensor_types == TestType::kDynamic) {
|
||||
begin_ = AddInput(tensor_index_type);
|
||||
size_ = AddInput(tensor_index_type);
|
||||
} else {
|
||||
@ -52,7 +52,7 @@ class SliceOpModel : public SingleOpModel {
|
||||
CreateSliceOptions(builder_).Union());
|
||||
BuildInterpreter({input_shape, begin_shape, size_shape});
|
||||
|
||||
if (input_tensor_types == TestType::DYNAMIC) {
|
||||
if (input_tensor_types == TestType::kDynamic) {
|
||||
PopulateTensor<index_type>(begin_, begin_data);
|
||||
PopulateTensor<index_type>(size_, size_data);
|
||||
}
|
||||
@ -239,7 +239,8 @@ TEST_P(SliceOpTest, SliceString) {
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(SliceOpTest, SliceOpTest,
|
||||
::testing::Values(TestType::CONST, TestType::DYNAMIC));
|
||||
::testing::Values(TestType::kConst,
|
||||
TestType::kDynamic));
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
@ -26,8 +26,8 @@ using ::testing::ElementsAreArray;
|
||||
constexpr int kAxisIsATensor = -1000;
|
||||
|
||||
enum class TestType {
|
||||
CONST = 0,
|
||||
DYNAMIC = 1,
|
||||
kConst = 0,
|
||||
kDynamic = 1,
|
||||
};
|
||||
|
||||
class SplitOpModel : public SingleOpModel {
|
||||
@ -83,7 +83,7 @@ void Check(TestType test_type, int axis, int num_splits,
|
||||
<< " and num_splits=" << num_splits;
|
||||
return ss.str();
|
||||
};
|
||||
if (test_type == TestType::DYNAMIC) {
|
||||
if (test_type == TestType::kDynamic) {
|
||||
SplitOpModel m({type, input_shape}, num_splits);
|
||||
m.SetInput(input_data);
|
||||
m.SetAxis(axis);
|
||||
@ -110,18 +110,18 @@ void Check(TestType test_type, int axis, int num_splits,
|
||||
template <typename T>
|
||||
class SplitOpTest : public ::testing::Test {
|
||||
public:
|
||||
static std::vector<TestType> _range_;
|
||||
static std::vector<TestType> range_;
|
||||
};
|
||||
|
||||
template <>
|
||||
std::vector<TestType> SplitOpTest<TestType>::_range_{TestType::CONST,
|
||||
TestType::DYNAMIC};
|
||||
std::vector<TestType> SplitOpTest<TestType>::range_{TestType::kConst,
|
||||
TestType::kDynamic};
|
||||
|
||||
using DataTypes = ::testing::Types<float, int8_t, int16_t>;
|
||||
TYPED_TEST_SUITE(SplitOpTest, DataTypes);
|
||||
|
||||
TYPED_TEST(SplitOpTest, FourDimensional) {
|
||||
for (TestType test_type : SplitOpTest<TestType>::_range_) {
|
||||
for (TestType test_type : SplitOpTest<TestType>::range_) {
|
||||
Check<TypeParam>(/*axis_as_tensor*/ test_type,
|
||||
/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
||||
@ -158,7 +158,7 @@ TYPED_TEST(SplitOpTest, FourDimensional) {
|
||||
}
|
||||
|
||||
TYPED_TEST(SplitOpTest, FourDimensionalInt8) {
|
||||
for (TestType test_type : SplitOpTest<TestType>::_range_) {
|
||||
for (TestType test_type : SplitOpTest<TestType>::range_) {
|
||||
Check<TypeParam>(/*axis_as_tensor*/ test_type,
|
||||
/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
||||
@ -195,7 +195,7 @@ TYPED_TEST(SplitOpTest, FourDimensionalInt8) {
|
||||
}
|
||||
|
||||
TYPED_TEST(SplitOpTest, FourDimensionalInt32) {
|
||||
for (TestType test_type : SplitOpTest<TestType>::_range_) {
|
||||
for (TestType test_type : SplitOpTest<TestType>::range_) {
|
||||
Check<TypeParam>(/*axis_as_tensor*/ test_type,
|
||||
/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
||||
@ -232,7 +232,7 @@ TYPED_TEST(SplitOpTest, FourDimensionalInt32) {
|
||||
}
|
||||
|
||||
TYPED_TEST(SplitOpTest, OneDimensional) {
|
||||
for (TestType test_type : SplitOpTest<TestType>::_range_) {
|
||||
for (TestType test_type : SplitOpTest<TestType>::range_) {
|
||||
Check<TypeParam>(
|
||||
/*axis_as_tensor*/ test_type,
|
||||
/*axis=*/0, /*num_splits=*/8, {8}, {1}, {1, 2, 3, 4, 5, 6, 7, 8},
|
||||
@ -241,7 +241,7 @@ TYPED_TEST(SplitOpTest, OneDimensional) {
|
||||
}
|
||||
|
||||
TYPED_TEST(SplitOpTest, NegativeAxis) {
|
||||
for (TestType test_type : SplitOpTest<TestType>::_range_) {
|
||||
for (TestType test_type : SplitOpTest<TestType>::range_) {
|
||||
Check<TypeParam>(/*axis_as_tensor*/ test_type,
|
||||
/*axis=*/-4, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
||||
|
@ -26,8 +26,8 @@ namespace {
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
enum class TestType {
|
||||
CONST = 0,
|
||||
DYNAMIC = 1,
|
||||
kConst = 0,
|
||||
kDynamic = 1,
|
||||
};
|
||||
|
||||
template <typename InputType>
|
||||
@ -36,7 +36,7 @@ class TopKV2OpModel : public SingleOpModel {
|
||||
TopKV2OpModel(int top_k, std::initializer_list<int> input_shape,
|
||||
std::initializer_list<InputType> input_data,
|
||||
TestType input_tensor_types) {
|
||||
if (input_tensor_types == TestType::DYNAMIC) {
|
||||
if (input_tensor_types == TestType::kDynamic) {
|
||||
input_ = AddInput(GetTensorType<InputType>());
|
||||
top_k_ = AddInput(TensorType_INT32);
|
||||
} else {
|
||||
@ -49,7 +49,7 @@ class TopKV2OpModel : public SingleOpModel {
|
||||
SetBuiltinOp(BuiltinOperator_TOPK_V2, BuiltinOptions_TopKV2Options, 0);
|
||||
BuildInterpreter({input_shape, {1}});
|
||||
|
||||
if (input_tensor_types == TestType::DYNAMIC) {
|
||||
if (input_tensor_types == TestType::kDynamic) {
|
||||
PopulateTensor<InputType>(input_, input_data);
|
||||
PopulateTensor<int32_t>(top_k_, {top_k});
|
||||
}
|
||||
@ -119,7 +119,8 @@ TEST_P(TopKV2OpTest, TypeInt32) {
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TopKV2OpTest, TopKV2OpTest,
|
||||
::testing::Values(TestType::CONST, TestType::DYNAMIC));
|
||||
::testing::Values(TestType::kConst,
|
||||
TestType::kDynamic));
|
||||
|
||||
// Check that uint8_t works.
|
||||
TEST_P(TopKV2OpTest, TypeUint8) {
|
||||
|
@ -37,8 +37,8 @@ namespace {
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
enum class TestType {
|
||||
CONST = 0,
|
||||
DYNAMIC = 1,
|
||||
kConst = 0,
|
||||
kDynamic = 1,
|
||||
};
|
||||
|
||||
template <typename InputType>
|
||||
@ -54,7 +54,7 @@ class BaseTransposeConvOpModel : public SingleOpModel {
|
||||
// Just to be confusing, transpose_conv has an _input_ named "output_shape"
|
||||
// that sets the shape of the output tensor of the op :). It must always be
|
||||
// an int32 1D four element tensor.
|
||||
if (test_type == TestType::DYNAMIC) {
|
||||
if (test_type == TestType::kDynamic) {
|
||||
output_shape_ = AddInput({TensorType_INT32, {4}});
|
||||
filter_ = AddInput(filter);
|
||||
} else {
|
||||
@ -74,7 +74,7 @@ class BaseTransposeConvOpModel : public SingleOpModel {
|
||||
BuildInterpreter(
|
||||
{GetShape(output_shape_), GetShape(filter_), GetShape(input_)});
|
||||
|
||||
if (test_type == TestType::DYNAMIC) {
|
||||
if (test_type == TestType::kDynamic) {
|
||||
PopulateTensor<int32_t>(output_shape_, output_shape_data);
|
||||
PopulateTensor<InputType>(filter_, filter_data);
|
||||
}
|
||||
@ -445,7 +445,7 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
TransposeConvOpTest, TransposeConvOpTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)),
|
||||
::testing::Values(TestType::CONST, TestType::DYNAMIC)));
|
||||
::testing::Values(TestType::kConst, TestType::kDynamic)));
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
@ -161,7 +161,7 @@ Flag CreateFlag(const char* name, BenchmarkParams* params,
|
||||
const std::string& usage) {
|
||||
return Flag(
|
||||
name, [params, name](const T& val) { params->Set<T>(name, val); },
|
||||
params->Get<T>(name), usage, Flag::OPTIONAL);
|
||||
params->Get<T>(name), usage, Flag::kOptional);
|
||||
}
|
||||
|
||||
// Benchmarks a model.
|
||||
|
@ -58,7 +58,7 @@ class DelegateProvider {
|
||||
const std::string& usage) const {
|
||||
return Flag(
|
||||
name, [params, name](const T& val) { params->Set<T>(name, val); },
|
||||
default_params_.Get<T>(name), usage, Flag::OPTIONAL);
|
||||
default_params_.Get<T>(name), usage, Flag::kOptional);
|
||||
}
|
||||
BenchmarkParams default_params_;
|
||||
};
|
||||
|
@ -142,7 +142,7 @@ Flag::Flag(const char* name,
|
||||
flag_type_(flag_type) {}
|
||||
|
||||
bool Flag::Parse(const std::string& arg, bool* value_parsing_ok) const {
|
||||
return ParseFlag(arg, name_, flag_type_ == POSITIONAL, value_hook_,
|
||||
return ParseFlag(arg, name_, flag_type_ == kPositional, value_hook_,
|
||||
value_parsing_ok);
|
||||
}
|
||||
|
||||
@ -195,7 +195,7 @@ std::string Flag::GetTypeName() const {
|
||||
result = false;
|
||||
}
|
||||
continue;
|
||||
} else if (flag.flag_type_ == Flag::REQUIRED) {
|
||||
} else if (flag.flag_type_ == Flag::kRequired) {
|
||||
TFLITE_LOG(ERROR) << "Required flag not provided: " << flag.name_;
|
||||
// If the required flag isn't found, we immediately stop the whole flag
|
||||
// parsing.
|
||||
@ -205,7 +205,7 @@ std::string Flag::GetTypeName() const {
|
||||
}
|
||||
|
||||
// Parses positional flags.
|
||||
if (flag.flag_type_ == Flag::POSITIONAL) {
|
||||
if (flag.flag_type_ == Flag::kPositional) {
|
||||
if (++positional_count >= *argc) {
|
||||
TFLITE_LOG(ERROR) << "Too few command line arguments.";
|
||||
return false;
|
||||
@ -245,7 +245,7 @@ std::string Flag::GetTypeName() const {
|
||||
|
||||
// The flag isn't found, do some bookkeeping work.
|
||||
processed_flags[flag.name_] = -1;
|
||||
if (flag.flag_type_ == Flag::REQUIRED) {
|
||||
if (flag.flag_type_ == Flag::kRequired) {
|
||||
TFLITE_LOG(ERROR) << "Required flag not provided: " << flag.name_;
|
||||
result = false;
|
||||
// If the required flag isn't found, we immediately stop the whole flag
|
||||
@ -280,7 +280,7 @@ std::string Flag::GetTypeName() const {
|
||||
// Prints usage for positional flag.
|
||||
for (int i = 0; i < sorted_idx.size(); ++i) {
|
||||
const Flag& flag = flag_list[sorted_idx[i]];
|
||||
if (flag.flag_type_ == Flag::POSITIONAL) {
|
||||
if (flag.flag_type_ == Flag::kPositional) {
|
||||
positional_count++;
|
||||
usage_text << " <" << flag.name_ << ">";
|
||||
} else {
|
||||
@ -295,7 +295,7 @@ std::string Flag::GetTypeName() const {
|
||||
std::vector<std::string> name_column(flag_list.size());
|
||||
for (int i = 0; i < sorted_idx.size(); ++i) {
|
||||
const Flag& flag = flag_list[sorted_idx[i]];
|
||||
if (flag.flag_type_ != Flag::POSITIONAL) {
|
||||
if (flag.flag_type_ != Flag::kPositional) {
|
||||
name_column[i] += "--";
|
||||
name_column[i] += flag.name_;
|
||||
name_column[i] += "=";
|
||||
@ -320,7 +320,8 @@ std::string Flag::GetTypeName() const {
|
||||
usage_text << "\t";
|
||||
usage_text << std::left << std::setw(max_name_width) << name_column[i];
|
||||
usage_text << "\t" << type_name << "\t";
|
||||
usage_text << (flag.flag_type_ != Flag::OPTIONAL ? "required" : "optional");
|
||||
usage_text << (flag.flag_type_ != Flag::kOptional ? "required"
|
||||
: "optional");
|
||||
usage_text << "\t" << flag.usage_text_ << "\n";
|
||||
}
|
||||
return usage_text.str();
|
||||
|
@ -65,16 +65,16 @@ namespace tflite {
|
||||
class Flag {
|
||||
public:
|
||||
enum FlagType {
|
||||
POSITIONAL = 0,
|
||||
REQUIRED,
|
||||
OPTIONAL,
|
||||
kPositional = 0,
|
||||
kRequired,
|
||||
kOptional,
|
||||
};
|
||||
|
||||
// The order of the positional flags is the same as they are added.
|
||||
// Positional flags are supposed to be required.
|
||||
template <typename T>
|
||||
static Flag CreateFlag(const char* name, T* val, const char* usage,
|
||||
FlagType flag_type = OPTIONAL) {
|
||||
FlagType flag_type = kOptional) {
|
||||
return Flag(
|
||||
name, [val](const T& v) { *val = v; }, *val, usage, flag_type);
|
||||
}
|
||||
|
@ -55,8 +55,10 @@ TEST(CommandLineFlagsTest, BasicUsage) {
|
||||
Flag::CreateFlag("some_numeric_bool", &some_numeric_bool,
|
||||
"some numeric bool"),
|
||||
Flag::CreateFlag("some_int1", &some_int1, "some int"),
|
||||
Flag::CreateFlag("some_int2", &some_int2, "some int", Flag::REQUIRED),
|
||||
Flag::CreateFlag("float_1", &float_1, "some float", Flag::POSITIONAL),
|
||||
Flag::CreateFlag("some_int2", &some_int2, "some int",
|
||||
Flag::kRequired),
|
||||
Flag::CreateFlag("float_1", &float_1, "some float",
|
||||
Flag::kPositional),
|
||||
});
|
||||
|
||||
EXPECT_TRUE(parsed_ok);
|
||||
@ -131,7 +133,7 @@ TEST(CommandLineFlagsTest, RequiredFlagNotFound) {
|
||||
const char* argv_strings[] = {"program_name", "--flag=12"};
|
||||
bool parsed_ok = Flags::Parse(
|
||||
&argc, reinterpret_cast<const char**>(argv_strings),
|
||||
{Flag::CreateFlag("some_flag", &some_float, "", Flag::REQUIRED)});
|
||||
{Flag::CreateFlag("some_flag", &some_float, "", Flag::kRequired)});
|
||||
|
||||
EXPECT_FALSE(parsed_ok);
|
||||
EXPECT_NEAR(-23.23f, some_float, 1e-5f);
|
||||
@ -144,7 +146,7 @@ TEST(CommandLineFlagsTest, NoArguments) {
|
||||
const char* argv_strings[] = {"program_name"};
|
||||
bool parsed_ok = Flags::Parse(
|
||||
&argc, reinterpret_cast<const char**>(argv_strings),
|
||||
{Flag::CreateFlag("some_flag", &some_float, "", Flag::REQUIRED)});
|
||||
{Flag::CreateFlag("some_flag", &some_float, "", Flag::kRequired)});
|
||||
|
||||
EXPECT_FALSE(parsed_ok);
|
||||
EXPECT_NEAR(-23.23f, some_float, 1e-5f);
|
||||
@ -157,7 +159,7 @@ TEST(CommandLineFlagsTest, NotEnoughArguments) {
|
||||
const char* argv_strings[] = {"program_name"};
|
||||
bool parsed_ok = Flags::Parse(
|
||||
&argc, reinterpret_cast<const char**>(argv_strings),
|
||||
{Flag::CreateFlag("some_flag", &some_float, "", Flag::POSITIONAL)});
|
||||
{Flag::CreateFlag("some_flag", &some_float, "", Flag::kPositional)});
|
||||
|
||||
EXPECT_FALSE(parsed_ok);
|
||||
EXPECT_NEAR(-23.23f, some_float, 1e-5f);
|
||||
@ -170,7 +172,7 @@ TEST(CommandLineFlagsTest, PositionalFlagFailed) {
|
||||
const char* argv_strings[] = {"program_name", "string"};
|
||||
bool parsed_ok = Flags::Parse(
|
||||
&argc, reinterpret_cast<const char**>(argv_strings),
|
||||
{Flag::CreateFlag("some_flag", &some_float, "", Flag::POSITIONAL)});
|
||||
{Flag::CreateFlag("some_flag", &some_float, "", Flag::kPositional)});
|
||||
|
||||
EXPECT_FALSE(parsed_ok);
|
||||
EXPECT_NEAR(-23.23f, some_float, 1e-5f);
|
||||
@ -213,9 +215,9 @@ TEST(CommandLineFlagsTest, UsageString) {
|
||||
{Flag::CreateFlag("some_int", &some_int, "some int"),
|
||||
Flag::CreateFlag("some_int64", &some_int64, "some int64"),
|
||||
Flag::CreateFlag("some_switch", &some_switch, "some switch"),
|
||||
Flag::CreateFlag("some_name", &some_name, "some name", Flag::REQUIRED),
|
||||
Flag::CreateFlag("some_name", &some_name, "some name", Flag::kRequired),
|
||||
Flag::CreateFlag("some_int2", &some_int2, "some int",
|
||||
Flag::POSITIONAL)});
|
||||
Flag::kPositional)});
|
||||
// Match the usage message, being sloppy about whitespace.
|
||||
const char* expected_usage =
|
||||
" usage: some_tool_name <some_int2> <flags>\n"
|
||||
@ -307,8 +309,8 @@ TEST(CommandLineFlagsTest, DuplicateFlagsNotFound) {
|
||||
const char* argv_strings[] = {"program_name", "--some_float=1.0"};
|
||||
bool parsed_ok = Flags::Parse(
|
||||
&argc, reinterpret_cast<const char**>(argv_strings),
|
||||
{Flag::CreateFlag("some_int", &some_int1, "some int1", Flag::OPTIONAL),
|
||||
Flag::CreateFlag("some_int", &some_int2, "some int2", Flag::REQUIRED)});
|
||||
{Flag::CreateFlag("some_int", &some_int1, "some int1", Flag::kOptional),
|
||||
Flag::CreateFlag("some_int", &some_int2, "some int2", Flag::kRequired)});
|
||||
|
||||
EXPECT_FALSE(parsed_ok);
|
||||
EXPECT_EQ(-23, some_int1);
|
||||
|
Loading…
Reference in New Issue
Block a user