Add a tflite_model_test build rule

This test runs an automated diff comparison of TF vs TFLite for a given
source model. It can also be used to run comparisons on-device with
delegates.

Also fix the tf_driver/tflite_diff tool to allow execution on mobile devices.
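
As a usage sketch, the rule is invoked from a BUILD file roughly as in the example target added to the testing BUILD file below (model path, layer names, and shapes are taken from that example):

    load("//tensorflow/lite/testing:tflite_model_test.bzl", "tflite_model_test")

    # Diff-compares TF vs TFLite outputs for multi_add.pb on random inputs.
    tflite_model_test(
        name = "tflite_model_example_test",
        tensorflow_model_file = "//tensorflow/lite:testdata/multi_add.pb",
        input_layer = "a,b,c,d",
        input_layer_shape = "1,8,8,3:1,8,8,3:1,8,8,3:1,8,8,3",
        input_layer_type = "float,float,float,float",
        output_layer = "x,y",
    )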

PiperOrigin-RevId: 284293992
Change-Id: Ia64927b4d76a195924e5dc2f16b7f4aa53481c0e
Author: Jared Duke, 2019-12-06 17:30:23 -08:00 (committed by TensorFlower Gardener)
Commit: 298ec44da3 (parent: 92f61576fa)
6 changed files with 245 additions and 45 deletions


@@ -6,6 +6,7 @@ load(
"merged_test_models",
)
load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite")
load("//tensorflow/lite/testing:tflite_model_test.bzl", "tflite_model_test")
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
load(
"//tensorflow:tensorflow.bzl",
@@ -21,7 +22,10 @@ package(
licenses = ["notice"], # Apache 2.0
)
exports_files(["generated_examples_zip_test.cc"])
exports_files([
"generated_examples_zip_test.cc",
"tflite_diff_example_test.cc",
])
[gen_zip_test(
name = "zip_test_%s" % test_name,
@@ -309,13 +313,22 @@ cc_library(
":join",
":split",
":test_runner",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:tensorflow",
"//tensorflow/lite:string_util",
"@com_google_absl//absl/strings",
],
"//tensorflow/lite:string_util",
] + select({
"//conditions:default": [
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:tensorflow",
],
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib",
],
"//tensorflow:ios": [
"//tensorflow/core:ios_tensorflow_lib",
],
}),
)
tf_cc_test(
@@ -342,9 +355,18 @@ cc_library(
":join",
":split",
":tf_driver",
"//tensorflow/core:framework",
"//tensorflow/lite:string",
],
] + select({
"//conditions:default": [
"//tensorflow/core:framework",
],
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib",
],
"//tensorflow:ios": [
"//tensorflow/core:ios_tensorflow_lib",
],
}),
)
tf_cc_test(
@@ -404,6 +426,7 @@ cc_library(
":split",
":tflite_diff_util",
":tflite_driver",
"@com_google_absl//absl/strings",
] + select({
"//conditions:default": [
"//tensorflow/core:framework_internal",
@@ -453,6 +476,20 @@ tf_cc_binary(
],
)
tflite_model_test(
name = "tflite_model_example_test",
input_layer = "a,b,c,d",
input_layer_shape = "1,8,8,3:1,8,8,3:1,8,8,3:1,8,8,3",
input_layer_type = "float,float,float,float",
output_layer = "x,y",
tags = [
"no_cuda_on_cpu_tap",
"no_oss", # needs test data
"tflite_not_portable", # TODO(b/134772701): Enable after making this a proper GTest.
],
tensorflow_model_file = "//tensorflow/lite:testdata/multi_add.pb",
)
cc_library(
name = "string_util_lib",
srcs = ["string_util.cc"],


@@ -13,34 +13,31 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <iostream>
#include "tensorflow/lite/testing/generate_testspec.h"
#include <iostream>
#include <random>
#include "tensorflow/core/framework/types.h"
#include "tensorflow/lite/testing/join.h"
#include "tensorflow/lite/testing/split.h"
#include "tensorflow/lite/testing/tf_driver.h"
#include "tensorflow/core/framework/types.h"
namespace tflite {
namespace testing {
namespace {
template <typename T>
void GenerateCsv(const std::vector<int>& shape, float min, float max,
string* out) {
auto random_float = [](float min, float max) {
static unsigned int seed;
return min + (max - min) * static_cast<float>(rand_r(&seed)) / RAND_MAX;
};
std::function<T(int)> random_t = [&](int) {
return static_cast<T>(random_float(min, max));
};
std::vector<T> data = GenerateRandomTensor(shape, random_t);
template <typename T, typename RandomEngine, typename RandomDistribution>
void GenerateCsv(const std::vector<int>& shape, RandomEngine* engine,
RandomDistribution distribution, string* out) {
std::vector<T> data =
GenerateRandomTensor<T>(shape, [&]() { return distribution(*engine); });
*out = Join(data.data(), data.size(), ",");
}
template <typename RandomEngine>
std::vector<string> GenerateInputValues(
const std::vector<string>& input_layer,
RandomEngine* engine, const std::vector<string>& input_layer,
const std::vector<string>& input_layer_type,
const std::vector<string>& input_layer_shape) {
std::vector<string> input_values;
@@ -52,19 +49,29 @@ std::vector<string> GenerateInputValues(
switch (type) {
case tensorflow::DT_FLOAT:
GenerateCsv<float>(shape, -0.5, 0.5, &input_values[i]);
GenerateCsv<float>(shape, engine,
std::uniform_real_distribution<float>(-0.5, 0.5),
&input_values[i]);
break;
case tensorflow::DT_UINT8:
GenerateCsv<uint8_t>(shape, 0, 255, &input_values[i]);
GenerateCsv<uint8_t>(shape, engine,
std::uniform_int_distribution<uint8_t>(0, 255),
&input_values[i]);
break;
case tensorflow::DT_INT32:
GenerateCsv<int32_t>(shape, -100, 100, &input_values[i]);
GenerateCsv<int32_t>(shape, engine,
std::uniform_int_distribution<int32_t>(-100, 100),
&input_values[i]);
break;
case tensorflow::DT_INT64:
GenerateCsv<int64_t>(shape, -100, 100, &input_values[i]);
GenerateCsv<int64_t>(shape, engine,
std::uniform_int_distribution<int64_t>(-100, 100),
&input_values[i]);
break;
case tensorflow::DT_BOOL:
GenerateCsv<int>(shape, 0.01, 1.99, &input_values[i]);
GenerateCsv<int>(shape, engine,
std::uniform_int_distribution<int>(0, 1),
&input_values[i]);
break;
default:
fprintf(stderr, "Unsupported type %d (%s) when generating testspec.\n",
@@ -76,6 +83,8 @@ std::vector<string> GenerateInputValues(
return input_values;
}
} // namespace
bool GenerateTestSpecFromTensorflowModel(
std::iostream& stream, const string& tensorflow_model_path,
const string& tflite_model_path, int num_invocations,
@@ -109,11 +118,12 @@ bool GenerateTestSpecFromTensorflowModel(
stream << "}\n";
// Generate inputs.
std::mt19937 random_engine;
for (int i = 0; i < num_invocations; ++i) {
// Note that the input values are random, so each invocation will have a
// different set.
std::vector<string> input_values =
GenerateInputValues(input_layer, input_layer_type, input_layer_shape);
std::vector<string> input_values = GenerateInputValues(
&random_engine, input_layer, input_layer_type, input_layer_shape);
if (input_values.empty()) {
std::cerr << "Unable to generate input values for the TensorFlow model. "
"Make sure the correct values are defined for "


@@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_TESTING_GENERATE_TESTSPEC_H_
#define TENSORFLOW_LITE_TESTING_GENERATE_TESTSPEC_H_
#include <algorithm>
#include <functional>
#include <iostream>
#include <vector>
@@ -46,19 +47,16 @@ bool GenerateTestSpecFromTensorflowModel(
const std::vector<string>& output_layer);
// Generates random values that are filled into the tensor.
// random_func returns each generated random element.
template <typename T>
template <typename T, typename RandomFunction>
std::vector<T> GenerateRandomTensor(const std::vector<int>& shape,
const std::function<T(int)>& random_func) {
RandomFunction random_func) {
int64_t num_elements = 1;
for (const int dim : shape) {
num_elements *= dim;
}
std::vector<T> result(num_elements);
for (int i = 0; i < num_elements; i++) {
result[i] = random_func(i);
}
std::generate_n(result.data(), num_elements, random_func);
return result;
}


@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/lite/testing/generate_testspec.h"
#include <random>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
@@ -22,16 +24,16 @@ namespace testing {
namespace {
TEST(GenerateRandomTensor, FloatValue) {
static unsigned int seed = 0;
std::function<float(int)> float_rand = [](int idx) {
return static_cast<float>(rand_r(&seed)) / RAND_MAX - 0.5f;
std::mt19937 random_engine;
auto random_func = [&]() {
return std::uniform_real_distribution<float>(-0.5, 0.5)(random_engine);
};
std::set<float> values;
float sum_x_square = 0.0f;
float sum_x = 0.0f;
for (int i = 0; i < 100; i++) {
const auto& data = GenerateRandomTensor<float>({1, 3, 4}, float_rand);
const auto& data = GenerateRandomTensor<float>({1, 3, 4}, random_func);
for (float value : data) {
values.insert(value);
sum_x_square += value * value;


@@ -17,6 +17,7 @@ limitations under the License.
#include <cstring>
#include "absl/strings/match.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/lite/testing/split.h"
#include "tensorflow/lite/testing/tflite_diff_util.h"
@@ -76,11 +77,11 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
TfLiteDriver::DelegateType delegate = TfLiteDriver::DelegateType::kNone;
if (!values.delegate_name.empty()) {
if (values.delegate_name == "NNAPI") {
if (absl::EqualsIgnoreCase(values.delegate_name, "nnapi")) {
delegate = TfLiteDriver::DelegateType::kNnapi;
} else if (values.delegate_name == "GPU") {
} else if (absl::EqualsIgnoreCase(values.delegate_name, "gpu")) {
delegate = TfLiteDriver::DelegateType::kGpu;
} else if (values.delegate_name == "FLEX") {
} else if (absl::EqualsIgnoreCase(values.delegate_name, "flex")) {
delegate = TfLiteDriver::DelegateType::kFlex;
} else {
fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());


@@ -0,0 +1,152 @@
"""Definition for tflite_model_test rule that runs a TF Lite model accuracy test.
This rule generates targets to run a diff-based model accuracy test against
synthetic, random inputs. Future work will allow injection of "golden" inputs,
as well as more robust execution on mobile devices.
Example usage:
tflite_model_test(
name = "simple_diff_test",
tensorflow_model_file = "//tensorflow/lite:testdata/multi_add.pb",
input_layer = "a,b,c,d",
input_layer_shape = "1,8,8,3:1,8,8,3:1,8,8,3:1,8,8,3",
input_layer_type = "float,float,float,float",
output_layer = "x,y",
)
"""
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
def tflite_model_test(
name,
tensorflow_model_file,
input_layer,
input_layer_type,
input_layer_shape,
output_layer,
inference_type = "float",
extra_conversion_flags = [],
num_runs = 20,
tags = [],
size = "large"):
"""Create test targets for validating TFLite model execution relative to TF.
Args:
name: Generated test target name. Note that multiple targets may be
created if `delegates` are provided.
tensorflow_model_file: The binary GraphDef proto to run the benchmark on.
input_layer: A list of input tensors to use in the test.
input_layer_shape: The shape of the input layer in csv format.
input_layer_type: The data type of the input layer(s) (int, float, etc).
output_layer: The layer that output should be read from.
inference_type: The data type for inference and output.
extra_conversion_flags: Extra flags to append to those used for converting
models to the tflite format.
num_runs: Number of synthetic test cases to run.
tags: Extra tags to apply to the test targets.
size: The test size to use.
"""
conversion_flags = [
"--input_shapes=%s" % input_layer_shape,
"--input_arrays=%s" % input_layer,
"--output_arrays=%s" % output_layer,
] + extra_conversion_flags
tflite_model_file = make_tflite_files(
target_name = "tflite_" + name + "_model",
model_file = tensorflow_model_file,
conversion_flags = conversion_flags,
inference_type = inference_type,
)
diff_args = [
# TODO(b/134772701): Find a better way to extract the absolute path from
# a target without relying on $(location), which doesn't work with some
# mobile test variants. For now we use $(location), but something like
# the following is what we want for mobile tests:
# "--tensorflow_model=%s" % tensorflow_model_file.replace("//", "").replace(":", "/"),
# "--tflite_model=%s" % tflite_model_file.replace("//", "").replace(":", "/"),
"--tensorflow_model=$(location %s)" % tensorflow_model_file,
"--tflite_model=$(location %s)" % tflite_model_file,
"--input_layer=%s" % input_layer,
"--input_layer_type=%s" % input_layer_type,
"--input_layer_shape=%s" % input_layer_shape,
"--output_layer=%s" % output_layer,
"--num_runs_per_pass=%s" % num_runs,
]
tf_cc_test(
name = name,
size = size,
srcs = ["//tensorflow/lite/testing:tflite_diff_example_test.cc"],
args = diff_args,
data = [
tensorflow_model_file,
tflite_model_file,
],
tags = tags,
deps = [
"//tensorflow/lite/testing:init_tensorflow",
"//tensorflow/lite/testing:tflite_diff_flags",
"//tensorflow/lite/testing:tflite_diff_util",
],
)
def make_tflite_files(
target_name,
model_file,
conversion_flags,
inference_type):
"""Uses TFLite to convert and input proto to tflite flatbuffer format.
Args:
target_name: Generated target name.
model_file: the path to the input file.
conversion_flags: parameters to pass to tflite for conversion.
inference_type: The data type for inference and output.
Returns:
The name of the generated file.
"""
flags = [] + conversion_flags
if inference_type == "float":
flags += [
"--inference_type=FLOAT",
"--inference_input_type=FLOAT",
]
elif inference_type == "quantized":
flags += [
"--inference_type=QUANTIZED_UINT8",
"--inference_input_type=QUANTIZED_UINT8",
]
else:
fail("Invalid inference type (%s). Expected 'float' or 'quantized'" % inference_type)
srcs = [model_file]
# Convert from Tensorflow graphdef to tflite model.
output_file = target_name + ".fb"
tool = "//tensorflow/lite/python:tflite_convert"
cmd = ("$(location %s) " +
" --graph_def_file=$(location %s)" +
" --output_file=$(location %s)" +
" --input_format=TENSORFLOW_GRAPHDEF" +
" --output_format=TFLITE " +
" ".join(flags)
.replace("std_value", "std_dev_value")
.replace("quantize_weights=true", "quantize_weights"))
native.genrule(
name = target_name,
srcs = srcs,
tags = ["manual"],
outs = [
output_file,
],
cmd = cmd % (tool, model_file, output_file),
tools = [tool],
visibility = ["//tensorflow/lite/testing:__subpackages__"],
)
return "//%s:%s" % (native.package_name(), output_file)