1. Simply used ',' instead of '#' as the separator between input layer name and its value range. This is to avoid cases where '#' exists in the input layer name.

For example: if the model has two input tensors named as "input1", "input2" and the "input1" requires the value to be initialized between 0 and 5 (inclusive), we could specify "--input_layer=input1,input2 --input_layer_shape=1,128:1,64 --input_layer_value_range=input2,0,5:input1,-12,15" to achieve this.

2. Added a test to check the parsing such input-related parameters.

PiperOrigin-RevId: 274694297
This commit is contained in:
Chao Mei 2019-10-14 17:09:30 -07:00 committed by TensorFlower Gardener
parent 025e871a4a
commit 1079ed61d8
3 changed files with 38 additions and 18 deletions

View File

@ -110,6 +110,8 @@ cc_library(
":benchmark_model_lib", ":benchmark_model_lib",
":benchmark_utils", ":benchmark_utils",
":logging", ":logging",
"@com_google_absl//absl/strings",
"@gemmlowp",
"//tensorflow/lite:framework", "//tensorflow/lite:framework",
"//tensorflow/lite:string_util", "//tensorflow/lite:string_util",
"//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/kernels:builtin_ops",
@ -117,8 +119,12 @@ cc_library(
"//tensorflow/lite/profiling:profile_summarizer", "//tensorflow/lite/profiling:profile_summarizer",
"//tensorflow/lite/profiling:profiler", "//tensorflow/lite/profiling:profiler",
"//tensorflow/lite/tools/evaluation:utils", "//tensorflow/lite/tools/evaluation:utils",
"@gemmlowp", ] + select({
"//tensorflow:android": [
"//tensorflow/lite/delegates/gpu:delegate",
], ],
"//conditions:default": [],
}),
) )
cc_library( cc_library(

View File

@ -85,6 +85,21 @@ TEST(BenchmarkTest, DoesntCrash) {
benchmark.Run(); benchmark.Run();
} }
TEST(BenchmarkTest, DoesntCrashWithExplicitInput) {
ASSERT_THAT(g_model_path, testing::NotNull());
// Note: the following input-related params are *specific* to model
// 'g_model_path' which is specified as 'lite:testdata/multi_add.bin for the
// test.
BenchmarkParams params = CreateParams();
params.Set<std::string>("input_layer", "a,b,c,d");
params.Set<std::string>("input_layer_shape",
"1,8,8,3:1,8,8,3:1,8,8,3:1,8,8,3");
params.Set<std::string>("input_layer_value_range", "d,1,10:b,0,100");
BenchmarkTfLiteModel benchmark(std::move(params));
benchmark.Run();
}
class MaxDurationWorksTestListener : public BenchmarkListener { class MaxDurationWorksTestListener : public BenchmarkListener {
void OnBenchmarkEnd(const BenchmarkResults& results) override { void OnBenchmarkEnd(const BenchmarkResults& results) override {
const int64_t num_actul_runs = results.inference_time_us().count(); const int64_t num_actul_runs = results.inference_time_us().count();

View File

@ -23,6 +23,8 @@ limitations under the License.
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "absl/strings/numbers.h"
#if defined(__ANDROID__) #if defined(__ANDROID__)
#include "tensorflow/lite/delegates/gpu/delegate.h" #include "tensorflow/lite/delegates/gpu/delegate.h"
#include "tensorflow/lite/nnapi/nnapi_util.h" #include "tensorflow/lite/nnapi/nnapi_util.h"
@ -191,8 +193,8 @@ TfLiteStatus PopulateInputLayerInfo(
std::vector<std::string> value_ranges = Split(value_ranges_string, ':'); std::vector<std::string> value_ranges = Split(value_ranges_string, ':');
std::vector<int> tmp_range; std::vector<int> tmp_range;
for (const auto val : value_ranges) { for (const auto val : value_ranges) {
std::vector<std::string> name_range = Split(val, '#'); std::vector<std::string> name_range = Split(val, ',');
if (name_range.size() != 2) { if (name_range.size() != 3) {
TFLITE_LOG(FATAL) << "Wrong input value range item specified: " << val; TFLITE_LOG(FATAL) << "Wrong input value range item specified: " << val;
} }
@ -210,20 +212,18 @@ TfLiteStatus PopulateInputLayerInfo(
<< ") in --input_layer as " << names_string; << ") in --input_layer as " << names_string;
// Parse the range value. // Parse the range value.
const std::string& input_range_str = name_range[1]; int low, high;
tmp_range.clear(); bool has_low = absl::SimpleAtoi(name_range[1], &low);
TFLITE_BENCHMARK_CHECK( bool has_high = absl::SimpleAtoi(name_range[2], &high);
util::SplitAndParse(input_range_str, ',', &tmp_range)) if (!has_low || !has_high || low > high) {
<< "Incorrect input value range string specified: " << input_range_str;
if (tmp_range.size() != 2 && tmp_range[0] > tmp_range[1]) {
TFLITE_LOG(FATAL) TFLITE_LOG(FATAL)
<< "Wrong low and high value of the input value range specified: " << "Wrong low and high value of the input value range specified: "
<< input_range_str; << val;
} }
info->at(layer_info_idx).has_value_range = true; info->at(layer_info_idx).has_value_range = true;
info->at(layer_info_idx).low = tmp_range[0]; info->at(layer_info_idx).low = low;
info->at(layer_info_idx).high = tmp_range[1]; info->at(layer_info_idx).high = high;
} }
return kTfLiteOk; return kTfLiteOk;
@ -299,11 +299,10 @@ std::vector<Flag> BenchmarkTfLiteModel::GetFlags() {
CreateFlag<std::string>("input_layer_shape", &params_, "input layer shape"), CreateFlag<std::string>("input_layer_shape", &params_, "input layer shape"),
CreateFlag<std::string>( CreateFlag<std::string>(
"input_layer_value_range", &params_, "input_layer_value_range", &params_,
"A map-like string representing value range for integer input layers. " "A map-like string representing value range for *integer* input "
"Each item is separated by ':', and the item value is a pair of input " "layers. Each item is separated by ':', and the item value consists of "
"layer name and integer-only range values (both low and high are " "input layer name and integer-only range values (both low and high are "
"inclusive), the name and the range is separated by '#', the low/high " "inclusive) separated by ',', e.g. input1,1,2:input2,0,254"),
"are separated by ',' e.g. input1#1,2:input2#0,254"),
CreateFlag<bool>("use_nnapi", &params_, "use nnapi delegate api"), CreateFlag<bool>("use_nnapi", &params_, "use nnapi delegate api"),
CreateFlag<std::string>( CreateFlag<std::string>(
"nnapi_execution_preference", &params_, "nnapi_execution_preference", &params_,