Add tests for user's defined op in Flex delegate

PiperOrigin-RevId: 344928044
Change-Id: I80f50c1db4b01403670d7ac75c7cc282da74c4d9
This commit is contained in:
Thai Nguyen 2020-11-30 18:19:31 -08:00 committed by TensorFlower Gardener
parent 3666f6bc53
commit b9344903bd
6 changed files with 134 additions and 6 deletions

View File

@ -25,12 +25,14 @@ load("//tensorflow/lite:special_rules.bzl", "flex_portable_tensorflow_deps")
def generate_flex_kernel_header(
name,
models,
testonly = 0,
additional_deps = []):
"""A rule to generate a header file listing only used operators.
Args:
name: Name of the generated library.
models: TFLite models to interpret.
testonly: Should be marked as true if additional_deps is testonly.
additional_deps: Dependencies for additional TF ops.
Returns:
@ -55,6 +57,7 @@ def generate_flex_kernel_header(
deps = [
clean_dep("//tensorflow/lite/tools:list_flex_ops_main_lib"),
] + additional_deps,
testonly = testonly,
)
list_ops_tool = ":%s_list_flex_ops_main" % name
native.genrule(
@ -65,6 +68,7 @@ def generate_flex_kernel_header(
message = "Listing flex ops from %s..." % ",".join(models),
cmd = ("$(location " + list_ops_tool + ")" +
model_file_args + " > \"$@\""),
testonly = testonly,
)
# Generate the kernel registration header file from list of flex ops.
@ -86,6 +90,7 @@ def tflite_flex_cc_library(
name,
models = [],
additional_deps = [],
testonly = 0,
visibility = ["//visibility:public"]):
"""A rule to generate a flex delegate with only ops to run listed models.
@ -95,6 +100,7 @@ def tflite_flex_cc_library(
to support these models. If empty, the library will include all Tensorflow
ops and kernels.
additional_deps: Dependencies for additional TF ops.
testonly: Mark this library as testonly if true.
visibility: visibility of the generated rules.
"""
portable_tensorflow_lib = clean_dep("//tensorflow/core:portable_tensorflow_lib")
@ -103,6 +109,7 @@ def tflite_flex_cc_library(
name = "%s_tf_op_headers" % name,
models = models,
additional_deps = additional_deps,
testonly = testonly,
)
# Define a custom tensorflow_lib with selective registration.
@ -138,6 +145,7 @@ def tflite_flex_cc_library(
clean_dep("//tensorflow/lite/delegates/flex:portable_images_lib"),
],
alwayslink = 1,
testonly = testonly,
)
portable_tensorflow_lib = ":%s_tensorflow_lib" % name
@ -164,6 +172,7 @@ def tflite_flex_cc_library(
clean_dep("//tensorflow/lite/c:common"),
],
}) + additional_deps,
testonly = testonly,
alwayslink = 1,
)
@ -171,6 +180,7 @@ def tflite_flex_jni_library(
name,
models = [],
additional_deps = [],
testonly = 0,
visibility = ["//visibility:private"]):
"""A rule to generate a jni library listing only used operators.
@ -183,6 +193,7 @@ def tflite_flex_jni_library(
to support these models. If empty, the library will include all Tensorflow
ops and kernels.
additional_deps: Dependencies for additional TF ops.
testonly: Mark this library as testonly if true.
visibility: visibility of the generated rules.
"""
@ -192,6 +203,7 @@ def tflite_flex_jni_library(
name = "%s_flex_delegate" % name,
models = models,
additional_deps = additional_deps,
testonly = testonly,
visibility = visibility,
)
@ -204,6 +216,7 @@ def tflite_flex_jni_library(
clean_dep("//tensorflow/lite/delegates/flex/java/src/main/native:flex_delegate_jni.cc"),
],
copts = tflite_copts(),
testonly = testonly,
visibility = visibility,
deps = [
":%s_flex_delegate" % name,
@ -224,6 +237,7 @@ def tflite_flex_jni_library(
tflite_jni_binary(
name = "libtensorflowlite_flex_jni.so",
linkopts = tflite_jni_linkopts(),
testonly = testonly,
deps = [
":%s_flex_native" % name,
],
@ -234,6 +248,7 @@ def tflite_flex_android_library(
models = [],
additional_deps = [],
custom_package = "org.tensorflow.lite.flex",
testonly = 0,
visibility = ["//visibility:private"]):
"""A rule to generate an android library based on the selective-built jni library.
@ -244,18 +259,21 @@ def tflite_flex_android_library(
Tensorflow ops and kernels.
additional_deps: Dependencies for additional TF ops.
custom_package: Java package for which java sources will be generated.
testonly: Mark this library as testonly if true.
visibility: visibility of the generated rules.
"""
tflite_flex_jni_library(
name = name,
models = models,
additional_deps = additional_deps,
testonly = testonly,
visibility = visibility,
)
native.cc_library(
name = "%s_native" % name,
srcs = ["libtensorflowlite_flex_jni.so"],
testonly = testonly,
visibility = visibility,
)
@ -265,6 +283,7 @@ def tflite_flex_android_library(
manifest = clean_dep("//tensorflow/lite/java:AndroidManifest.xml"),
proguard_specs = [clean_dep("//tensorflow/lite/java:proguard.flags")],
custom_package = custom_package,
testonly = testonly,
deps = [
":%s_native" % name,
clean_dep("//tensorflow/lite/java:tensorflowlite_java"),

View File

@ -13,8 +13,11 @@ package(
tflite_flex_jni_library(
name = "test",
testonly = 1,
additional_deps = ["//tensorflow/lite/python/testdata:double_op_and_kernels"],
models = [
"//tensorflow/lite:testdata/multi_add_flex.bin",
"//tensorflow/lite:testdata/double_flex.bin",
],
)
@ -59,3 +62,32 @@ java_test(
"@junit",
],
)
java_test(
name = "SelectiveBuiltInterpreterFlexWithCustomOpsTest",
size = "small",
srcs = [
"//tensorflow/lite/java:portable_flex_with_custom_ops_tests",
"//tensorflow/lite/java:portable_test_utils",
],
data = [
"//tensorflow/lite:testdata/double_flex.bin",
],
javacopts = JAVACOPTS,
tags = [
"no_cuda_on_cpu_tap", # CUDA + flex is not officially supported.
"no_gpu", # GPU + flex is not officially supported.
"no_oss", # Currently requires --config=monolithic, b/118895218.
# TODO(b/121204962): Re-enable test after fixing memory leaks.
"noasan",
"notsan", # TODO(b/158651814) Re-enable after fixing racing condition.
],
test_class = "org.tensorflow.lite.InterpreterFlexWithCustomOpsTest",
visibility = ["//visibility:private"],
deps = [
":test_tensorflowlitelib_flex",
"//tensorflow/lite/java:tensorflowlitelib",
"@com_google_truth",
"@junit",
],
)

View File

@ -391,6 +391,16 @@ filegroup(
visibility = ["//visibility:public"],
)
# portable_flex_with_custom_ops_tests includes files for testing Flex delegate
# with models containing user's defined ops.
filegroup(
name = "portable_flex_with_custom_ops_tests",
srcs = [
"src/test/java/org/tensorflow/lite/InterpreterFlexWithCustomOpsTest.java",
],
visibility = ["//visibility:public"],
)
# portable_test_utils include utilities for loading files and processing images.
filegroup(
name = "portable_test_utils",

View File

@ -0,0 +1,53 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite;
import static com.google.common.truth.Truth.assertThat;
import java.nio.ByteBuffer;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.tensorflow.lite.flex.FlexDelegate;
/**
* Unit tests for {@link org.tensorflow.lite.Interpreter} that validate execution with models that
* have user's defined TensorFlow ops.
*/
@RunWith(JUnit4.class)
public final class InterpreterFlexWithCustomOpsTest {
private static final ByteBuffer DOUBLE_MODEL_BUFFER =
TestUtils.getTestFileAsBuffer("tensorflow/lite/testdata/double_flex.bin");
/** Smoke test validating that flex model with a user's defined TF op. */
@Test
public void testFlexModelWithUsersDefinedOp() throws Exception {
try (Interpreter interpreter = new Interpreter(DOUBLE_MODEL_BUFFER)) {
int[] oneD = {1, 2, 3, 4};
int[][] twoD = {oneD};
int[][] parsedOutputs = new int[1][4];
interpreter.run(twoD, parsedOutputs);
int[] outputOneD = parsedOutputs[0];
int[] expected = {2, 4, 6, 8};
assertThat(outputOneD).isEqualTo(expected);
}
}
static {
FlexDelegate.initTensorFlowForTesting();
}
}

View File

@ -1,6 +1,11 @@
load("//tensorflow/lite:build_def.bzl", "tf_to_tflite")
load("//tensorflow/lite:build_def.bzl", "tf_to_tflite", "tflite_copts")
load("//tensorflow:tensorflow.bzl", "pybind_extension", "tf_custom_op_py_library")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_library", "tf_gen_op_wrapper_py")
load(
"//tensorflow:tensorflow.bzl",
"tf_custom_op_library",
"tf_gen_op_wrapper_py",
"tf_opts_nortti_if_android",
)
package(
default_visibility = ["//tensorflow:internal"],
@ -91,10 +96,19 @@ cc_library(
name = "double_op_and_kernels",
testonly = 1,
srcs = ["double_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
copts = tflite_copts() + tf_opts_nortti_if_android(),
deps = select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//tensorflow:ios": [
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
}),
alwayslink = 1,
)

BIN
tensorflow/lite/testdata/double_flex.bin vendored Normal file

Binary file not shown.