Op level diff test.

PiperOrigin-RevId: 233089842
This commit is contained in:
Yunlu Li 2019-02-08 10:59:29 -08:00 committed by TensorFlower Gardener
parent d815b49f4c
commit d5fb65ff33
18 changed files with 947 additions and 1 deletions

View File

@ -11,6 +11,8 @@ load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite")
exports_files(glob([
"testdata/*.bin",
"testdata/*.pb",
"testdata/*.tflite",
"testdata/*.csv",
"models/testdata/*",
]))

View File

@ -12,6 +12,7 @@ cc_library(
"c_api_internal.h",
],
visibility = [
"//learning/brain/mobile/kernel_test:__subpackages__",
"//tensorflow/lite:__subpackages__",
],
)

View File

@ -0,0 +1 @@
1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1

View File

@ -159,6 +159,7 @@ cc_library(
srcs = ["tflite_driver.cc"],
hdrs = ["tflite_driver.h"],
deps = [
":join",
":split",
":test_runner",
"//tensorflow/lite:builtin_op_data",

View File

@ -0,0 +1,124 @@
package(default_visibility = [
"//visibility:public",
])
licenses(["notice"]) # Apache 2.0
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_binary",
"tf_cc_test",
)
cc_library(
name = "util",
hdrs = ["util.h"],
deps = [
":input_generator",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/testing:split",
"//tensorflow/lite/testing:tflite_driver",
] + select({
"//conditions:default": [
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib",
],
}),
)
tf_cc_test(
name = "util_test",
size = "small",
srcs = ["util_test.cc"],
data = [
"//tensorflow/lite:testdata/add.bin",
"//tensorflow/lite:testdata/test_input.csv",
],
tags = [
"no_oss",
],
deps = [
":util",
"//tensorflow/lite/testing:tflite_driver",
"@com_google_googletest//:gtest_main",
],
)
tf_cc_binary(
name = "tflite_kernel_runner",
srcs = ["tflite_kernel_runner.cc"],
deps = [
":util",
],
)
tf_cc_binary(
name = "generate_diff_report",
srcs = ["generate_diff_report.cc"],
deps = [
":diff_analyzer",
"//tensorflow/core:framework_internal",
],
)
cc_library(
name = "input_generator",
srcs = ["input_generator.cc"],
hdrs = ["input_generator.h"],
deps = [
"//tensorflow/lite:framework",
"//tensorflow/lite:string",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/testing:join",
"//tensorflow/lite/testing:split",
],
)
tf_cc_test(
name = "input_generator_test",
size = "small",
srcs = ["input_generator_test.cc"],
data = [
"//tensorflow/lite:testdata/multi_add.bin",
"//tensorflow/lite:testdata/test_input.csv",
],
tags = [
"no_oss",
],
deps = [
":input_generator",
"@com_google_googletest//:gtest_main",
],
)
cc_library(
name = "diff_analyzer",
srcs = ["diff_analyzer.cc"],
hdrs = ["diff_analyzer.h"],
deps = [
"//tensorflow/lite:string",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/testing:split",
],
)
tf_cc_test(
name = "diff_analyzer_test",
size = "small",
srcs = ["diff_analyzer_test.cc"],
data = [
"//tensorflow/lite:testdata/test_input.csv",
],
tags = [
"no_oss",
],
deps = [
":diff_analyzer",
"//tensorflow/core:lib",
"@com_google_googletest//:gtest_main",
],
)

View File

@ -0,0 +1,115 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/testing/kernel_test/diff_analyzer.h"
#include <cmath>
#include <fstream>
#include "tensorflow/lite/testing/split.h"
namespace tflite {
namespace testing {
namespace {
float CalculateNormalizedMaxDiff(const std::vector<float>& base,
const std::vector<float>& test) {
float diff = 0;
// For numerical stability in case the tensor is all 0.
float base_max = 1e-6;
for (int i = 0; i < base.size(); i++) {
diff = std::max(diff, std::abs(base[i] - test[i]));
base_max = std::max(base_max, base[i]);
}
return diff / base_max;
}
float CalculateNormalizedL2Norm(const std::vector<float>& base,
const std::vector<float>& test) {
float l2_error = 0;
// For numerical stability in case the tensor is all 0.
float base_max = 1e-6;
for (int i = 0; i < base.size(); i++) {
float diff = base[i] - test[i];
l2_error += diff * diff;
base_max = std::max(base_max, base[i]);
}
l2_error /= base.size();
return std::sqrt(l2_error) / base_max;
}
TfLiteStatus Populate(const string& filename,
std::vector<std::vector<float>>* tensors) {
if (filename.empty()) {
fprintf(stderr, "Empty input file name.");
return kTfLiteError;
}
std::ifstream file(filename);
string content;
while (std::getline(file, content, '\n')) {
tensors->push_back(Split<float>(content, ","));
}
file.close();
return kTfLiteOk;
}
} // namespace
TfLiteStatus DiffAnalyzer::ReadFiles(const string& base, const string& test) {
TF_LITE_ENSURE_STATUS(Populate(base, &base_tensors_));
TF_LITE_ENSURE_STATUS(Populate(test, &test_tensors_));
if (base_tensors_.size() != test_tensors_.size()) {
fprintf(stderr, "Golden and test tensor dimensions don't match.");
return kTfLiteError;
}
return kTfLiteOk;
}
TfLiteStatus DiffAnalyzer::WriteReport(const string& filename) {
if (filename.empty()) {
fprintf(stderr, "Empty output file name.");
return kTfLiteError;
}
std::ofstream output_file;
output_file.open(filename, std::fstream::out | std::fstream::trunc);
if (!output_file) {
fprintf(stderr, "Failed to open output file %s.", filename.c_str());
return kTfLiteError;
}
output_file << "Normalized L2 Error"
<< ","
<< "Normalized Max Diff"
<< "\n";
for (int i = 0; i < base_tensors_.size(); i++) {
float l2_error =
CalculateNormalizedL2Norm(base_tensors_[i], test_tensors_[i]);
float max_diff =
CalculateNormalizedMaxDiff(base_tensors_[i], test_tensors_[i]);
output_file << l2_error << "," << max_diff << "\n";
}
output_file.close();
return kTfLiteOk;
}
} // namespace testing
} // namespace tflite

View File

@ -0,0 +1,42 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_TESTING_KERNEL_TEST_DIFF_ANALYZER_H_
#define TENSORFLOW_LITE_TESTING_KERNEL_TEST_DIFF_ANALYZER_H_
#include <vector>
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/string.h"
namespace tflite {
namespace testing {
// Reads the baseline and test files with output tensor values, and calculates
// the diff metrics.
class DiffAnalyzer {
public:
DiffAnalyzer() = default;
TfLiteStatus ReadFiles(const string& base, const string& test);
TfLiteStatus WriteReport(const string& filename);
private:
std::vector<std::vector<float>> base_tensors_;
std::vector<std::vector<float>> test_tensors_;
};
} // namespace testing
} // namespace tflite
#endif // TENSORFLOW_LITE_TESTING_KERNEL_TEST_DIFF_ANALYZER_H_

View File

@ -0,0 +1,47 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/testing/kernel_test/diff_analyzer.h"
#include <fstream>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/core/lib/io/path.h"
namespace tflite {
namespace testing {
namespace {
TEST(DiffAnalyzerTest, ZeroDiff) {
DiffAnalyzer diff_analyzer;
string filename = "third_party/tensorflow/lite/testdata/test_input.csv";
ASSERT_EQ(diff_analyzer.ReadFiles(filename, filename), kTfLiteOk);
string output_file =
tensorflow::io::JoinPath(FLAGS_test_tmpdir + "diff_report.csv");
ASSERT_EQ(diff_analyzer.WriteReport(output_file), kTfLiteOk);
std::string content;
std::ifstream file(output_file);
std::getline(file, content);
std::getline(file, content);
ASSERT_EQ(content, "0,0");
}
} // namespace
} // namespace testing
} // namespace tflite

View File

@ -0,0 +1,34 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <vector>
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/lite/testing/kernel_test/diff_analyzer.h"
int main(int argc, char** argv) {
string base, test, output;
std::vector<tensorflow::Flag> flag_list = {
tensorflow::Flag("base", &base, "Path to the base serialized tensor."),
tensorflow::Flag("test", &test, "Path to the test serialized tensor."),
tensorflow::Flag("output", &output, "Path to the output file."),
};
tensorflow::Flags::Parse(&argc, argv, flag_list);
tflite::testing::DiffAnalyzer diff_analyzer;
diff_analyzer.ReadFiles(base, test);
diff_analyzer.WriteReport(output);
return 0;
}

View File

@ -0,0 +1,208 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/testing/kernel_test/input_generator.h"
#include <fstream>
#include <limits>
#include <random>
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/testing/join.h"
#include "tensorflow/lite/testing/split.h"
namespace tflite {
namespace testing {
namespace {
template <typename T>
std::vector<T> GenerateRandomTensor(TfLiteIntArray* dims,
const std::function<T(int)>& random_func) {
int64_t num_elements = 1;
for (int i = 0; i < dims->size; i++) {
num_elements *= dims->data[i];
}
std::vector<T> result(num_elements);
for (int i = 0; i < num_elements; i++) {
result[i] = random_func(i);
}
return result;
}
template <typename T>
std::vector<T> GenerateUniform(TfLiteIntArray* dims, float min, float max) {
auto random_float = [](float min, float max) {
// TODO(yunluli): Change seed for each invocation if needed.
static unsigned int seed;
return min + (max - min) * static_cast<float>(rand_r(&seed)) / RAND_MAX;
};
std::function<T(int)> random_t = [&](int) {
return static_cast<T>(random_float(min, max));
};
std::vector<T> data = GenerateRandomTensor(dims, random_t);
return data;
}
template <typename T>
std::vector<T> GenerateGaussian(TfLiteIntArray* dims, float min, float max) {
auto random_float = [](float min, float max) {
static std::default_random_engine generator;
// We generate a float number within [0, 1) following a mormal distribution
// with mean = 0.5 and stddev = 1/3, and use it to scale the final random
// number into the desired range.
static std::normal_distribution<double> distribution(0.5, 1.0 / 3);
auto rand_n = distribution(generator);
while (rand_n < 0 || rand_n >= 1) {
rand_n = distribution(generator);
}
return min + (max - min) * static_cast<float>(rand_n);
};
std::function<T(int)> random_t = [&](int) {
return static_cast<T>(random_float(min, max));
};
std::vector<T> data = GenerateRandomTensor(dims, random_t);
return data;
}
} // namespace
TfLiteStatus InputGenerator::LoadModel(const string& model_dir) {
model_ = FlatBufferModel::BuildFromFile(model_dir.c_str());
if (!model_) {
fprintf(stderr, "Cannot load model %s", model_dir.c_str());
return kTfLiteError;
}
::tflite::ops::builtin::BuiltinOpResolver builtin_ops;
InterpreterBuilder(*model_, builtin_ops)(&interpreter_);
if (!interpreter_) {
fprintf(stderr, "Failed to build interpreter.");
return kTfLiteError;
}
return kTfLiteOk;
}
TfLiteStatus InputGenerator::ReadInputsFromFile(const string& filename) {
if (filename.empty()) {
fprintf(stderr, "Empty input file name.");
return kTfLiteError;
}
std::ifstream input_file(filename);
string input;
while (std::getline(input_file, input, '\n')) {
inputs_.push_back(input);
}
input_file.close();
return kTfLiteOk;
}
TfLiteStatus InputGenerator::WriteInputsToFile(const string& filename) {
if (filename.empty()) {
fprintf(stderr, "Empty input file name.");
return kTfLiteError;
}
std::ofstream output_file;
output_file.open(filename, std::fstream::out | std::fstream::trunc);
if (!output_file) {
fprintf(stderr, "Failed to open output file %s.", filename.c_str());
return kTfLiteError;
}
for (const auto& input : inputs_) {
output_file << input << "\n";
}
output_file.close();
return kTfLiteOk;
}
// TODO(yunluli): Support more tensor types when needed.
TfLiteStatus InputGenerator::GenerateInput(const string& distribution) {
auto input_tensor_ids = interpreter_->inputs();
for (auto id : input_tensor_ids) {
auto* tensor = interpreter_->tensor(id);
if (distribution == "UNIFORM") {
switch (tensor->type) {
case kTfLiteInt8: {
auto data = GenerateUniform<int8_t>(
tensor->dims, std::numeric_limits<int8_t>::min(),
std::numeric_limits<int8_t>::max());
inputs_.push_back(Join(data.data(), data.size(), ","));
break;
}
case kTfLiteUInt8: {
auto data = GenerateUniform<uint8_t>(
tensor->dims, std::numeric_limits<uint8_t>::min(),
std::numeric_limits<uint8_t>::max());
inputs_.push_back(Join(data.data(), data.size(), ","));
break;
}
case kTfLiteFloat32: {
auto data = GenerateUniform<float>(tensor->dims, -1, 1);
inputs_.push_back(JoinDefault(data.data(), data.size(), ","));
break;
}
default:
fprintf(stderr, "Unsupported input tensor type %s.",
TfLiteTypeGetName(tensor->type));
break;
}
} else if (distribution == "GAUSSIAN") {
switch (tensor->type) {
case kTfLiteInt8: {
auto data = GenerateGaussian<int8_t>(
tensor->dims, std::numeric_limits<int8_t>::min(),
std::numeric_limits<int8_t>::max());
inputs_.push_back(Join(data.data(), data.size(), ","));
break;
}
case kTfLiteUInt8: {
auto data = GenerateGaussian<uint8_t>(
tensor->dims, std::numeric_limits<uint8_t>::min(),
std::numeric_limits<uint8_t>::max());
inputs_.push_back(Join(data.data(), data.size(), ","));
break;
}
case kTfLiteFloat32: {
auto data = GenerateGaussian<float>(tensor->dims, -1, 1);
inputs_.push_back(JoinDefault(data.data(), data.size(), ","));
break;
}
default:
fprintf(stderr, "Unsupported input tensor type %s.",
TfLiteTypeGetName(tensor->type));
break;
}
} else {
fprintf(stderr, "Unsupported distribution %s.", distribution.c_str());
return kTfLiteError;
}
}
return kTfLiteOk;
}
std::vector<string> InputGenerator::GetInputs() { return inputs_; }
} // namespace testing
} // namespace tflite

View File

@ -0,0 +1,50 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_TESTING_KERNEL_TEST_INPUT_GENERATOR_H_
#define TENSORFLOW_LITE_TESTING_KERNEL_TEST_INPUT_GENERATOR_H_
#include <memory>
#include <vector>
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/string.h"
namespace tflite {
namespace testing {
// Generate random input, or read input from a file for kernel diff test.
// Needs to load the tflite graph to get information like tensor shape and
// data type.
class InputGenerator {
public:
InputGenerator() = default;
TfLiteStatus LoadModel(const string& model_dir);
TfLiteStatus ReadInputsFromFile(const string& filename);
TfLiteStatus GenerateInput(const string& distribution);
std::vector<string> GetInputs();
TfLiteStatus WriteInputsToFile(const string& filename);
private:
std::unique_ptr<FlatBufferModel> model_;
std::unique_ptr<Interpreter> interpreter_;
std::vector<string> inputs_;
};
} // namespace testing
} // namespace tflite
#endif // TENSORFLOW_LITE_TESTING_KERNEL_TEST_INPUT_GENERATOR_H_

View File

@ -0,0 +1,81 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/testing/kernel_test/input_generator.h"
#include <fstream>
#include <map>
#include <gmock/gmock.h>
#include "testing/base/public/googletest.h"
#include <gtest/gtest.h>
namespace tflite {
namespace testing {
namespace {
TEST(InputGeneratorTest, LoadModel) {
InputGenerator input_generator;
ASSERT_EQ(input_generator.LoadModel(
"third_party/tensorflow/lite/testdata/multi_add.bin"),
kTfLiteOk);
}
TEST(InputGeneratorTest, ReadWriteSimpleFile) {
InputGenerator input_generator;
ASSERT_EQ(input_generator.ReadInputsFromFile(
"third_party/tensorflow/lite/testdata/test_input.csv"),
kTfLiteOk);
std::vector<string> inputs;
std::string content = "1";
for (int i = 0; i < 1 * 8 * 8 * 3 - 1; i++) {
content.append(",1");
}
inputs.push_back(content);
ASSERT_EQ(input_generator.GetInputs(), inputs);
auto output_filename = FLAGS_test_tmpdir + "/out.csv";
ASSERT_EQ(input_generator.WriteInputsToFile(output_filename), kTfLiteOk);
std::ifstream in(output_filename);
std::string out;
std::getline(in, out, '\n');
ASSERT_EQ(out, content);
}
TEST(InputGeneratorTest, GenerateUniformInput) {
InputGenerator input_generator;
ASSERT_EQ(input_generator.LoadModel(
"third_party/tensorflow/lite/testdata/multi_add.bin"),
kTfLiteOk);
input_generator.GenerateInput("UNIFORM");
auto inputs = input_generator.GetInputs();
ASSERT_EQ(inputs.size(), 4);
}
TEST(InputGeneratorTest, GenerateGaussianInput) {
InputGenerator input_generator;
ASSERT_EQ(input_generator.LoadModel(
"third_party/tensorflow/lite/testdata/multi_add.bin"),
kTfLiteOk);
input_generator.GenerateInput("GAUSSIAN");
auto inputs = input_generator.GetInputs();
ASSERT_EQ(inputs.size(), 4);
}
} // namespace
} // namespace testing
} // namespace tflite

View File

@ -0,0 +1,32 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/testing/kernel_test/util.h"
int main(int argc, char** argv) {
tflite::testing::kernel_test::TestOptions options =
tflite::testing::kernel_test::ParseTfliteKernelTestFlags(&argc, argv);
const bool run_reference_kernel = options.kernel_type == "REFERENCE";
const bool use_nnapi = options.kernel_type == "NNAPI";
auto runner = absl::make_unique<tflite::testing::TfLiteDriver>(
use_nnapi, "", run_reference_kernel);
if (tflite::testing::kernel_test::RunKernelTest(options, runner.get()) ==
kTfLiteOk) {
return 0;
}
return -1;
}

View File

@ -0,0 +1,122 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_TESTING_KERNEL_TEST_UTIL_H_
#define TENSORFLOW_LITE_TESTING_KERNEL_TEST_UTIL_H_
#include <fstream>
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/testing/kernel_test/input_generator.h"
#include "tensorflow/lite/testing/split.h"
#include "tensorflow/lite/testing/tflite_driver.h"
namespace tflite {
namespace testing {
namespace kernel_test {
struct TestOptions {
// Path of tensorflow lite model.
string tflite_model;
// Path of the input file. If empty, generate at runtime.
string read_input_from_file;
// Path to dump the input file.
string dump_input_to_file;
// Path to dump the output.
string dump_output_to_file;
// Input distribution.
string input_distribution;
// Kernel type.
string kernel_type;
};
TestOptions ParseTfliteKernelTestFlags(int* argc, char** argv) {
TestOptions options;
std::vector<tensorflow::Flag> flags = {
tensorflow::Flag("tflite_model", &options.tflite_model,
"Path of tensorflow lite model."),
tensorflow::Flag("read_input_from_file", &options.read_input_from_file,
"File to read input data from. If empty, generates "
"input at runtime."),
tensorflow::Flag("dump_input_to_file", &options.dump_input_to_file,
"File to dump randomly generated input."),
tensorflow::Flag("dump_output_to_file", &options.dump_output_to_file,
"File to dump output."),
tensorflow::Flag("input_distribution", &options.input_distribution,
"Input distribution. Default: Gaussian."),
tensorflow::Flag("kernel_type", &options.kernel_type, "Kernel type."),
};
tensorflow::Flags::Parse(argc, argv, flags);
return options;
}
TfLiteStatus RunKernelTest(const kernel_test::TestOptions& options,
TestRunner* runner) {
InputGenerator input_generator;
if (options.read_input_from_file.empty()) {
TF_LITE_ENSURE_STATUS(input_generator.LoadModel(options.tflite_model));
TF_LITE_ENSURE_STATUS(
input_generator.GenerateInput(options.input_distribution));
} else {
TF_LITE_ENSURE_STATUS(
input_generator.ReadInputsFromFile(options.read_input_from_file));
}
runner->LoadModel(options.tflite_model);
runner->AllocateTensors();
if (!runner->IsValid()) return kTfLiteError;
auto input_tensor_ids = runner->GetInputs();
auto inputs = input_generator.GetInputs();
if (inputs.size() != input_tensor_ids.size()) {
fprintf(stderr,
"Number of input tensors generated doesn't match what the model "
"asks for.");
}
for (int i = 0; i < inputs.size(); i++) {
runner->SetInput(input_tensor_ids[i], inputs[i]);
}
runner->Invoke();
if (!options.dump_input_to_file.empty()) {
TF_LITE_ENSURE_STATUS(
input_generator.WriteInputsToFile(options.dump_input_to_file));
}
if (!options.dump_output_to_file.empty()) {
std::ofstream output_file;
output_file.open(options.dump_output_to_file,
std::fstream::out | std::fstream::trunc);
if (!output_file) {
return kTfLiteError;
}
for (auto id : runner->GetOutputs()) {
output_file << runner->ReadOutput(id) << "\n";
}
output_file.close();
}
return kTfLiteOk;
}
} // namespace kernel_test
} // namespace testing
} // namespace tflite
#endif // TENSORFLOW_LITE_TESTING_KERNEL_TEST_UTIL_H_

View File

@ -0,0 +1,52 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/testing/kernel_test/util.h"
#include <fstream>
#include <memory>
#include <gmock/gmock.h>
#include "testing/base/public/googletest.h"
#include <gtest/gtest.h>
#include "tensorflow/lite/testing/tflite_driver.h"
namespace tflite {
namespace testing {
namespace kernel_test {
namespace {
TEST(UtilTest, SimpleE2ETest) {
TestOptions options;
options.tflite_model = "third_party/tensorflow/lite/testdata/add.bin";
options.read_input_from_file =
"third_party/tensorflow/lite/testdata/test_input.csv";
options.dump_output_to_file = FLAGS_test_tmpdir + "/test_out.csv";
options.kernel_type = "REFERENCE";
std::unique_ptr<TestRunner> runner(new TfLiteDriver(false, "", true));
RunKernelTest(options, runner.get());
std::string expected = "3";
for (int i = 0; i < 1 * 8 * 8 * 3 - 1; i++) {
expected.append(",3");
}
std::string content;
std::ifstream file(options.dump_output_to_file);
std::getline(file, content);
EXPECT_EQ(content, expected);
}
} // namespace
} // namespace kernel_test
} // namespace testing
} // namespace tflite

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/register_ref.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/testing/join.h"
#include "tensorflow/lite/testing/split.h"
namespace tflite {
@ -383,5 +384,34 @@ void TfLiteDriver::ResetLSTMStateTensors() {
interpreter_->ResetVariableTensors();
}
string TfLiteDriver::ReadOutput(int id) {
auto* tensor = interpreter_->tensor(id);
int num_elements = 1;
for (int i = 0; i < tensor->dims->size; ++i) {
num_elements *= tensor->dims->data[i];
}
switch (tensor->type) {
case kTfLiteFloat32:
return JoinDefault(tensor->data.f, num_elements, ",");
case kTfLiteInt32:
return JoinDefault(tensor->data.i32, num_elements, ",");
case kTfLiteInt64:
return JoinDefault(tensor->data.i64, num_elements, ",");
case kTfLiteUInt8:
return Join(tensor->data.uint8, num_elements, ",");
case kTfLiteInt8:
return JoinDefault(tensor->data.int8, num_elements, ",");
case kTfLiteBool:
return JoinDefault(tensor->data.b, num_elements, ",");
default:
Invalidate(absl::StrCat("Unsupported tensor type ",
TfLiteTypeGetName(tensor->type),
" in TfLiteDriver::ReadOutput"));
return "";
}
}
} // namespace testing
} // namespace tflite

View File

@ -49,7 +49,7 @@ class TfLiteDriver : public TestRunner {
void SetExpectation(int id, const string& csv_values) override;
void Invoke() override;
bool CheckResults() override;
string ReadOutput(int id) override { return "no-op"; }
string ReadOutput(int id) override;
private:
void DeallocateStringTensor(TfLiteTensor* t) {

View File

@ -54,6 +54,8 @@ TEST(TfliteDriverTest, SimpleTest) {
ASSERT_TRUE(runner->IsValid());
ASSERT_TRUE(runner->CheckResults());
EXPECT_EQ(runner->ReadOutput(5), "0.101,0.202,0.303,0.404");
EXPECT_EQ(runner->ReadOutput(6), "0.011,0.022,0.033,0.044");
}
TEST(TfliteDriverTest, SingleAddOpTest) {
@ -88,6 +90,8 @@ TEST(TfliteDriverTest, SingleAddOpTest) {
ASSERT_TRUE(runner->IsValid());
ASSERT_TRUE(runner->CheckResults());
EXPECT_EQ(runner->ReadOutput(5), "0.101,0.202,0.303,0.404");
EXPECT_EQ(runner->ReadOutput(6), "0.011,0.022,0.033,0.044");
}
} // namespace