STT-tensorflow/tensorflow/lite/testing/tflite_model_test.bzl
Jared Duke 298ec44da3 Add a tflite_model_test build rule
This test runs an automated diff comparison of TF vs TFLite for a given
source model. It can also be used to run comparisons on-device with
delegates.

Also fix the tf_driver/tflite_diff tool to allow execution on mobile devices.

PiperOrigin-RevId: 284293992
Change-Id: Ia64927b4d76a195924e5dc2f16b7f4aa53481c0e
2019-12-06 17:35:14 -08:00

153 lines
5.3 KiB
Python

"""Definition for tflite_model_test rule that runs a TF Lite model accuracy test.
This macro generates targets that run a diff-based (TF vs. TFLite) model
accuracy test against synthetic, random inputs. Future work will allow injection
of "golden" inputs, as well as more robust execution on mobile devices.
Example usage:
tflite_model_test(
name = "simple_diff_test",
tensorflow_model_file = "//tensorflow/lite:testdata/multi_add.pb",
input_layer = "a,b,c,d",
input_layer_shape = "1,8,8,3:1,8,8,3:1,8,8,3:1,8,8,3",
input_layer_type = "float,float,float,float",
output_layer = "x,y",
)
"""
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
def tflite_model_test(
        name,
        tensorflow_model_file,
        input_layer,
        input_layer_type,
        input_layer_shape,
        output_layer,
        inference_type = "float",
        extra_conversion_flags = [],
        num_runs = 20,
        tags = [],
        size = "large"):
    """Creates a test target validating TFLite model execution relative to TF.

    The TensorFlow GraphDef is first converted to a TFLite flatbuffer via
    `make_tflite_files`. A `tf_cc_test` named `name` then runs the
    tflite_diff test on both models, using `num_runs` synthetic inputs per pass.

    Args:
      name: Name of the generated test target. An auxiliary conversion target
        named "tflite_<name>_model" is also created.
      tensorflow_model_file: Label of the binary GraphDef proto to convert and
        compare against.
      input_layer: Comma-separated names of the input tensors (e.g. "a,b").
      input_layer_type: Comma-separated data types of the input tensors
        (e.g. "float,float").
      input_layer_shape: Colon-separated input shapes, each a comma-separated
        list of dimensions (e.g. "1,8,8,3:1,8,8,3").
      output_layer: Comma-separated names of the output tensors to compare
        (e.g. "x,y").
      inference_type: The data type for inference and output; must be
        "float" or "quantized".
      extra_conversion_flags: Extra flags to append to those used for
        converting the model to the tflite format.
      num_runs: Number of synthetic test cases to run per pass.
      tags: Extra tags to apply to the test target.
      size: The test size to use.
    """

    # The same layer/shape strings drive both the converter and the diff tool,
    # so the converted model's I/O matches what the test feeds it.
    conversion_flags = [
        "--input_shapes=%s" % input_layer_shape,
        "--input_arrays=%s" % input_layer,
        "--output_arrays=%s" % output_layer,
    ] + extra_conversion_flags
    tflite_model_file = make_tflite_files(
        target_name = "tflite_" + name + "_model",
        model_file = tensorflow_model_file,
        conversion_flags = conversion_flags,
        inference_type = inference_type,
    )
    diff_args = [
        # TODO(b/134772701): Find a better way to extract the absolute path from
        # a target without relying on $(location), which doesn't work with some
        # mobile test variants. For now we use $(location), but something like
        # the following is what we want for mobile tests:
        # "--tensorflow_model=%s" % tensorflow_model_file.replace("//", "").replace(":", "/"),
        # "--tflite_model=%s" % tflite_model_file.replace("//", "").replace(":", "/"),
        "--tensorflow_model=$(location %s)" % tensorflow_model_file,
        "--tflite_model=$(location %s)" % tflite_model_file,
        "--input_layer=%s" % input_layer,
        "--input_layer_type=%s" % input_layer_type,
        "--input_layer_shape=%s" % input_layer_shape,
        "--output_layer=%s" % output_layer,
        "--num_runs_per_pass=%s" % num_runs,
    ]
    tf_cc_test(
        name = name,
        size = size,
        srcs = ["//tensorflow/lite/testing:tflite_diff_example_test.cc"],
        args = diff_args,
        # Both models must be runfiles so the $(location) paths above resolve.
        data = [
            tensorflow_model_file,
            tflite_model_file,
        ],
        tags = tags,
        deps = [
            "//tensorflow/lite/testing:init_tensorflow",
            "//tensorflow/lite/testing:tflite_diff_flags",
            "//tensorflow/lite/testing:tflite_diff_util",
        ],
    )
def make_tflite_files(
        target_name,
        model_file,
        conversion_flags,
        inference_type):
    """Uses TFLite to convert an input GraphDef proto to tflite flatbuffer format.

    Creates a `genrule` named `target_name` (tagged "manual") that runs
    `tflite_convert` on `model_file` and writes `<target_name>.fb`.

    Args:
      target_name: Generated target name.
      model_file: Label of the input TensorFlow GraphDef file.
      conversion_flags: Parameters to pass to tflite_convert for conversion.
      inference_type: The data type for inference and output; must be
        "float" or "quantized".

    Returns:
      The label of the generated .fb file.
    """

    # Copy so the caller's list is never mutated by the += below.
    flags = list(conversion_flags)
    if inference_type == "float":
        flags += [
            "--inference_type=FLOAT",
            "--inference_input_type=FLOAT",
        ]
    elif inference_type == "quantized":
        flags += [
            "--inference_type=QUANTIZED_UINT8",
            "--inference_input_type=QUANTIZED_UINT8",
        ]
    else:
        fail("Invalid inference type (%s). Expected 'float' or 'quantized'" % inference_type)

    # Rewrite flag spellings before handing them to tflite_convert
    # (std_value -> std_dev_value, quantize_weights=true -> quantize_weights).
    # NOTE(review): presumably these match tflite_convert's accepted flag names
    # — confirm against the converter's CLI.
    flag_str = (" ".join(flags)
        .replace("std_value", "std_dev_value")
        .replace("quantize_weights=true", "quantize_weights"))

    # Convert from Tensorflow graphdef to tflite model.
    output_file = target_name + ".fb"
    tool = "//tensorflow/lite/python:tflite_convert"

    # Format the fixed template first and append the caller-supplied flags
    # afterwards, so a literal "%" inside a flag cannot break the formatting.
    cmd = ("$(location %s) " +
           " --graph_def_file=$(location %s)" +
           " --output_file=$(location %s)" +
           " --input_format=TENSORFLOW_GRAPHDEF" +
           " --output_format=TFLITE ") % (tool, model_file, output_file)
    cmd += flag_str

    native.genrule(
        name = target_name,
        srcs = [model_file],
        # Only built on demand as a dependency of the test target.
        tags = ["manual"],
        outs = [output_file],
        cmd = cmd,
        tools = [tool],
        visibility = ["//tensorflow/lite/testing:__subpackages__"],
    )
    return "//%s:%s" % (native.package_name(), output_file)