Add a tflite_model_test build rule
This test runs an automated diff comparison of TF vs TFLite for a given source model. It can also be used to run comparisons on-device with delegates. Also fix the tf_driver/tflite_diff tool to allow execution on mobile devices. PiperOrigin-RevId: 284293992 Change-Id: Ia64927b4d76a195924e5dc2f16b7f4aa53481c0e
This commit is contained in:
parent
92f61576fa
commit
298ec44da3
@ -6,6 +6,7 @@ load(
|
||||
"merged_test_models",
|
||||
)
|
||||
load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite")
|
||||
load("//tensorflow/lite/testing:tflite_model_test.bzl", "tflite_model_test")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
@ -21,7 +22,10 @@ package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
exports_files(["generated_examples_zip_test.cc"])
|
||||
exports_files([
|
||||
"generated_examples_zip_test.cc",
|
||||
"tflite_diff_example_test.cc",
|
||||
])
|
||||
|
||||
[gen_zip_test(
|
||||
name = "zip_test_%s" % test_name,
|
||||
@ -309,13 +313,22 @@ cc_library(
|
||||
":join",
|
||||
":split",
|
||||
":test_runner",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:tensorflow",
|
||||
"//tensorflow/lite:string_util",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
"//tensorflow/lite:string_util",
|
||||
] + select({
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:tensorflow",
|
||||
],
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib",
|
||||
],
|
||||
"//tensorflow:ios": [
|
||||
"//tensorflow/core:ios_tensorflow_lib",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
@ -342,9 +355,18 @@ cc_library(
|
||||
":join",
|
||||
":split",
|
||||
":tf_driver",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/lite:string",
|
||||
],
|
||||
] + select({
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib",
|
||||
],
|
||||
"//tensorflow:ios": [
|
||||
"//tensorflow/core:ios_tensorflow_lib",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
@ -404,6 +426,7 @@ cc_library(
|
||||
":split",
|
||||
":tflite_diff_util",
|
||||
":tflite_driver",
|
||||
"@com_google_absl//absl/strings",
|
||||
] + select({
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:framework_internal",
|
||||
@ -453,6 +476,20 @@ tf_cc_binary(
|
||||
],
|
||||
)
|
||||
|
||||
tflite_model_test(
|
||||
name = "tflite_model_example_test",
|
||||
input_layer = "a,b,c,d",
|
||||
input_layer_shape = "1,8,8,3:1,8,8,3:1,8,8,3:1,8,8,3",
|
||||
input_layer_type = "float,float,float,float",
|
||||
output_layer = "x,y",
|
||||
tags = [
|
||||
"no_cuda_on_cpu_tap",
|
||||
"no_oss", # needs test data
|
||||
"tflite_not_portable", # TODO(b/134772701): Enable after making this a proper GTest.
|
||||
],
|
||||
tensorflow_model_file = "//tensorflow/lite:testdata/multi_add.pb",
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "string_util_lib",
|
||||
srcs = ["string_util.cc"],
|
||||
|
@ -13,34 +13,31 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "tensorflow/lite/testing/generate_testspec.h"
|
||||
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/lite/testing/join.h"
|
||||
#include "tensorflow/lite/testing/split.h"
|
||||
#include "tensorflow/lite/testing/tf_driver.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace testing {
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
void GenerateCsv(const std::vector<int>& shape, float min, float max,
|
||||
string* out) {
|
||||
auto random_float = [](float min, float max) {
|
||||
static unsigned int seed;
|
||||
return min + (max - min) * static_cast<float>(rand_r(&seed)) / RAND_MAX;
|
||||
};
|
||||
|
||||
std::function<T(int)> random_t = [&](int) {
|
||||
return static_cast<T>(random_float(min, max));
|
||||
};
|
||||
std::vector<T> data = GenerateRandomTensor(shape, random_t);
|
||||
template <typename T, typename RandomEngine, typename RandomDistribution>
|
||||
void GenerateCsv(const std::vector<int>& shape, RandomEngine* engine,
|
||||
RandomDistribution distribution, string* out) {
|
||||
std::vector<T> data =
|
||||
GenerateRandomTensor<T>(shape, [&]() { return distribution(*engine); });
|
||||
*out = Join(data.data(), data.size(), ",");
|
||||
}
|
||||
|
||||
template <typename RandomEngine>
|
||||
std::vector<string> GenerateInputValues(
|
||||
const std::vector<string>& input_layer,
|
||||
RandomEngine* engine, const std::vector<string>& input_layer,
|
||||
const std::vector<string>& input_layer_type,
|
||||
const std::vector<string>& input_layer_shape) {
|
||||
std::vector<string> input_values;
|
||||
@ -52,19 +49,29 @@ std::vector<string> GenerateInputValues(
|
||||
|
||||
switch (type) {
|
||||
case tensorflow::DT_FLOAT:
|
||||
GenerateCsv<float>(shape, -0.5, 0.5, &input_values[i]);
|
||||
GenerateCsv<float>(shape, engine,
|
||||
std::uniform_real_distribution<float>(-0.5, 0.5),
|
||||
&input_values[i]);
|
||||
break;
|
||||
case tensorflow::DT_UINT8:
|
||||
GenerateCsv<uint8_t>(shape, 0, 255, &input_values[i]);
|
||||
GenerateCsv<uint8_t>(shape, engine,
|
||||
std::uniform_int_distribution<uint8_t>(0, 255),
|
||||
&input_values[i]);
|
||||
break;
|
||||
case tensorflow::DT_INT32:
|
||||
GenerateCsv<int32_t>(shape, -100, 100, &input_values[i]);
|
||||
GenerateCsv<int32_t>(shape, engine,
|
||||
std::uniform_int_distribution<int32_t>(-100, 100),
|
||||
&input_values[i]);
|
||||
break;
|
||||
case tensorflow::DT_INT64:
|
||||
GenerateCsv<int64_t>(shape, -100, 100, &input_values[i]);
|
||||
GenerateCsv<int64_t>(shape, engine,
|
||||
std::uniform_int_distribution<int64_t>(-100, 100),
|
||||
&input_values[i]);
|
||||
break;
|
||||
case tensorflow::DT_BOOL:
|
||||
GenerateCsv<int>(shape, 0.01, 1.99, &input_values[i]);
|
||||
GenerateCsv<int>(shape, engine,
|
||||
std::uniform_int_distribution<int>(0, 1),
|
||||
&input_values[i]);
|
||||
break;
|
||||
default:
|
||||
fprintf(stderr, "Unsupported type %d (%s) when generating testspec.\n",
|
||||
@ -76,6 +83,8 @@ std::vector<string> GenerateInputValues(
|
||||
return input_values;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool GenerateTestSpecFromTensorflowModel(
|
||||
std::iostream& stream, const string& tensorflow_model_path,
|
||||
const string& tflite_model_path, int num_invocations,
|
||||
@ -109,11 +118,12 @@ bool GenerateTestSpecFromTensorflowModel(
|
||||
stream << "}\n";
|
||||
|
||||
// Generate inputs.
|
||||
std::mt19937 random_engine;
|
||||
for (int i = 0; i < num_invocations; ++i) {
|
||||
// Note that the input values are random, so each invocation will have a
|
||||
// different set.
|
||||
std::vector<string> input_values =
|
||||
GenerateInputValues(input_layer, input_layer_type, input_layer_shape);
|
||||
std::vector<string> input_values = GenerateInputValues(
|
||||
&random_engine, input_layer, input_layer_type, input_layer_shape);
|
||||
if (input_values.empty()) {
|
||||
std::cerr << "Unable to generate input values for the TensorFlow model. "
|
||||
"Make sure the correct values are defined for "
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_LITE_TESTING_GENERATE_TESTSPEC_H_
|
||||
#define TENSORFLOW_LITE_TESTING_GENERATE_TESTSPEC_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
@ -46,19 +47,16 @@ bool GenerateTestSpecFromTensorflowModel(
|
||||
const std::vector<string>& output_layer);
|
||||
|
||||
// Generates random values that are filled into the tensor.
|
||||
// random_func returns the generated random element at given index.
|
||||
template <typename T>
|
||||
template <typename T, typename RandomFunction>
|
||||
std::vector<T> GenerateRandomTensor(const std::vector<int>& shape,
|
||||
const std::function<T(int)>& random_func) {
|
||||
RandomFunction random_func) {
|
||||
int64_t num_elements = 1;
|
||||
for (const int dim : shape) {
|
||||
num_elements *= dim;
|
||||
}
|
||||
|
||||
std::vector<T> result(num_elements);
|
||||
for (int i = 0; i < num_elements; i++) {
|
||||
result[i] = random_func(i);
|
||||
}
|
||||
std::generate_n(result.data(), num_elements, random_func);
|
||||
return result;
|
||||
}
|
||||
|
||||
|
@ -14,6 +14,8 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/testing/generate_testspec.h"
|
||||
|
||||
#include <random>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
@ -22,16 +24,16 @@ namespace testing {
|
||||
namespace {
|
||||
|
||||
TEST(GenerateRandomTensor, FloatValue) {
|
||||
static unsigned int seed = 0;
|
||||
std::function<float(int)> float_rand = [](int idx) {
|
||||
return static_cast<float>(rand_r(&seed)) / RAND_MAX - 0.5f;
|
||||
std::mt19937 random_engine;
|
||||
auto random_func = [&]() {
|
||||
return std::uniform_real_distribution<float>(-0.5, 0.5)(random_engine);
|
||||
};
|
||||
|
||||
std::set<float> values;
|
||||
float sum_x_square = 0.0f;
|
||||
float sum_x = 0.0f;
|
||||
for (int i = 0; i < 100; i++) {
|
||||
const auto& data = GenerateRandomTensor<float>({1, 3, 4}, float_rand);
|
||||
const auto& data = GenerateRandomTensor<float>({1, 3, 4}, random_func);
|
||||
for (float value : data) {
|
||||
values.insert(value);
|
||||
sum_x_square += value * value;
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
#include "tensorflow/lite/testing/split.h"
|
||||
#include "tensorflow/lite/testing/tflite_diff_util.h"
|
||||
@ -76,11 +77,11 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
|
||||
|
||||
TfLiteDriver::DelegateType delegate = TfLiteDriver::DelegateType::kNone;
|
||||
if (!values.delegate_name.empty()) {
|
||||
if (delegate_name == "NNAPI") {
|
||||
if (absl::EqualsIgnoreCase(values.delegate_name, "nnapi")) {
|
||||
delegate = TfLiteDriver::DelegateType::kNnapi;
|
||||
} else if (values.delegate_name == "GPU") {
|
||||
} else if (absl::EqualsIgnoreCase(values.delegate_name, "gpu")) {
|
||||
delegate = TfLiteDriver::DelegateType::kGpu;
|
||||
} else if (values.delegate_name == "FLEX") {
|
||||
} else if (absl::EqualsIgnoreCase(values.delegate_name, "flex")) {
|
||||
delegate = TfLiteDriver::DelegateType::kFlex;
|
||||
} else {
|
||||
fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
|
||||
|
152
tensorflow/lite/testing/tflite_model_test.bzl
Normal file
152
tensorflow/lite/testing/tflite_model_test.bzl
Normal file
@ -0,0 +1,152 @@
|
||||
"""Definition for tflite_model_test rule that runs a TF Lite model accuracy test.
|
||||
|
||||
This rule generates targets to run a diff-based model accuracy test against
|
||||
synthetic, random inputs. Future work will allow injection of "golden" inputs,
|
||||
as well as more robust execution on mobile devices.
|
||||
|
||||
Example usage:
|
||||
|
||||
tflite_model_test(
|
||||
name = "simple_diff_test",
|
||||
tensorflow_model_file = "//tensorflow/lite:testdata/multi_add.pb",
|
||||
input_layer = "a,b,c,d",
|
||||
input_layer_shape = "1,8,8,3:1,8,8,3:1,8,8,3:1,8,8,3",
|
||||
input_layer_type = "float,float,float,float",
|
||||
output_layer = "x,y",
|
||||
)
|
||||
"""
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
|
||||
def tflite_model_test(
|
||||
name,
|
||||
tensorflow_model_file,
|
||||
input_layer,
|
||||
input_layer_type,
|
||||
input_layer_shape,
|
||||
output_layer,
|
||||
inference_type = "float",
|
||||
extra_conversion_flags = [],
|
||||
num_runs = 20,
|
||||
tags = [],
|
||||
size = "large"):
|
||||
"""Create test targets for validating TFLite model execution relative to TF.
|
||||
|
||||
Args:
|
||||
name: Generated test target name. Note that multiple targets may be
|
||||
created if `delegates` are provided.
|
||||
tensorflow_model_file: The binary GraphDef proto to run the benchmark on.
|
||||
input_layer: A list of input tensors to use in the test.
|
||||
input_layer_shape: The shape of the input layer in csv format.
|
||||
input_layer_type: The data type of the input layer(s) (int, float, etc).
|
||||
output_layer: The layer that output should be read from.
|
||||
inference_type: The data type for inference and output.
|
||||
extra_conversion_flags: Extra flags to append to those used for converting
|
||||
models to the tflite format.
|
||||
num_runs: Number of synthetic test cases to run.
|
||||
tags: Extra tags to apply to the test targets.
|
||||
size: The test size to use.
|
||||
"""
|
||||
|
||||
conversion_flags = [
|
||||
"--input_shapes=%s" % input_layer_shape,
|
||||
"--input_arrays=%s" % input_layer,
|
||||
"--output_arrays=%s" % output_layer,
|
||||
] + extra_conversion_flags
|
||||
|
||||
tflite_model_file = make_tflite_files(
|
||||
target_name = "tflite_" + name + "_model",
|
||||
model_file = tensorflow_model_file,
|
||||
conversion_flags = conversion_flags,
|
||||
inference_type = inference_type,
|
||||
)
|
||||
|
||||
diff_args = [
|
||||
# TODO(b/134772701): Find a better way to extract the absolute path from
|
||||
# a target without relying on $(location), which doesn't work with some
|
||||
# mobile test variants. For now we use $(location), but something like
|
||||
# the following is what we want for mobile tests:
|
||||
# "--tensorflow_model=%s" % tensorflow_model_file.replace("//", "").replace(":", "/"),
|
||||
# "--tflite_model=%s" % tflite_model_file.replace("//", "").replace(":", "/"),
|
||||
"--tensorflow_model=$(location %s)" % tensorflow_model_file,
|
||||
"--tflite_model=$(location %s)" % tflite_model_file,
|
||||
"--input_layer=%s" % input_layer,
|
||||
"--input_layer_type=%s" % input_layer_type,
|
||||
"--input_layer_shape=%s" % input_layer_shape,
|
||||
"--output_layer=%s" % output_layer,
|
||||
"--num_runs_per_pass=%s" % num_runs,
|
||||
]
|
||||
|
||||
tf_cc_test(
|
||||
name = name,
|
||||
size = size,
|
||||
srcs = ["//tensorflow/lite/testing:tflite_diff_example_test.cc"],
|
||||
args = diff_args,
|
||||
data = [
|
||||
tensorflow_model_file,
|
||||
tflite_model_file,
|
||||
],
|
||||
tags = tags,
|
||||
deps = [
|
||||
"//tensorflow/lite/testing:init_tensorflow",
|
||||
"//tensorflow/lite/testing:tflite_diff_flags",
|
||||
"//tensorflow/lite/testing:tflite_diff_util",
|
||||
],
|
||||
)
|
||||
|
||||
def make_tflite_files(
|
||||
target_name,
|
||||
model_file,
|
||||
conversion_flags,
|
||||
inference_type):
|
||||
"""Uses TFLite to convert and input proto to tflite flatbuffer format.
|
||||
|
||||
Args:
|
||||
target_name: Generated target name.
|
||||
model_file: the path to the input file.
|
||||
conversion_flags: parameters to pass to tflite for conversion.
|
||||
inference_type: The data type for inference and output.
|
||||
Returns:
|
||||
The name of the generated file.
|
||||
"""
|
||||
flags = [] + conversion_flags
|
||||
if inference_type == "float":
|
||||
flags += [
|
||||
"--inference_type=FLOAT",
|
||||
"--inference_input_type=FLOAT",
|
||||
]
|
||||
elif inference_type == "quantized":
|
||||
flags += [
|
||||
"--inference_type=QUANTIZED_UINT8",
|
||||
"--inference_input_type=QUANTIZED_UINT8",
|
||||
]
|
||||
else:
|
||||
fail("Invalid inference type (%s). Expected 'float' or 'quantized'" % inference_type)
|
||||
|
||||
srcs = [model_file]
|
||||
|
||||
# Convert from Tensorflow graphdef to tflite model.
|
||||
output_file = target_name + ".fb"
|
||||
|
||||
tool = "//tensorflow/lite/python:tflite_convert"
|
||||
cmd = ("$(location %s) " +
|
||||
" --graph_def_file=$(location %s)" +
|
||||
" --output_file=$(location %s)" +
|
||||
" --input_format=TENSORFLOW_GRAPHDEF" +
|
||||
" --output_format=TFLITE " +
|
||||
" ".join(flags)
|
||||
.replace("std_value", "std_dev_value")
|
||||
.replace("quantize_weights=true", "quantize_weights"))
|
||||
|
||||
native.genrule(
|
||||
name = target_name,
|
||||
srcs = srcs,
|
||||
tags = ["manual"],
|
||||
outs = [
|
||||
output_file,
|
||||
],
|
||||
cmd = cmd % (tool, model_file, output_file),
|
||||
tools = [tool],
|
||||
visibility = ["//tensorflow/lite/testing:__subpackages__"],
|
||||
)
|
||||
return "//%s:%s" % (native.package_name(), output_file)
|
Loading…
Reference in New Issue
Block a user