Add tests for user's defined op in Flex delegate
PiperOrigin-RevId: 344928044 Change-Id: I80f50c1db4b01403670d7ac75c7cc282da74c4d9
This commit is contained in:
parent
3666f6bc53
commit
b9344903bd
@ -25,12 +25,14 @@ load("//tensorflow/lite:special_rules.bzl", "flex_portable_tensorflow_deps")
|
||||
def generate_flex_kernel_header(
|
||||
name,
|
||||
models,
|
||||
testonly = 0,
|
||||
additional_deps = []):
|
||||
"""A rule to generate a header file listing only used operators.
|
||||
|
||||
Args:
|
||||
name: Name of the generated library.
|
||||
models: TFLite models to interpret.
|
||||
testonly: Should be marked as true if additional_deps is testonly.
|
||||
additional_deps: Dependencies for additional TF ops.
|
||||
|
||||
Returns:
|
||||
@ -55,6 +57,7 @@ def generate_flex_kernel_header(
|
||||
deps = [
|
||||
clean_dep("//tensorflow/lite/tools:list_flex_ops_main_lib"),
|
||||
] + additional_deps,
|
||||
testonly = testonly,
|
||||
)
|
||||
list_ops_tool = ":%s_list_flex_ops_main" % name
|
||||
native.genrule(
|
||||
@ -65,6 +68,7 @@ def generate_flex_kernel_header(
|
||||
message = "Listing flex ops from %s..." % ",".join(models),
|
||||
cmd = ("$(location " + list_ops_tool + ")" +
|
||||
model_file_args + " > \"$@\""),
|
||||
testonly = testonly,
|
||||
)
|
||||
|
||||
# Generate the kernel registration header file from list of flex ops.
|
||||
@ -86,6 +90,7 @@ def tflite_flex_cc_library(
|
||||
name,
|
||||
models = [],
|
||||
additional_deps = [],
|
||||
testonly = 0,
|
||||
visibility = ["//visibility:public"]):
|
||||
"""A rule to generate a flex delegate with only ops to run listed models.
|
||||
|
||||
@ -95,6 +100,7 @@ def tflite_flex_cc_library(
|
||||
to support these models. If empty, the library will include all Tensorflow
|
||||
ops and kernels.
|
||||
additional_deps: Dependencies for additional TF ops.
|
||||
testonly: Mark this library as testonly if true.
|
||||
visibility: visibility of the generated rules.
|
||||
"""
|
||||
portable_tensorflow_lib = clean_dep("//tensorflow/core:portable_tensorflow_lib")
|
||||
@ -103,6 +109,7 @@ def tflite_flex_cc_library(
|
||||
name = "%s_tf_op_headers" % name,
|
||||
models = models,
|
||||
additional_deps = additional_deps,
|
||||
testonly = testonly,
|
||||
)
|
||||
|
||||
# Define a custom tensorflow_lib with selective registration.
|
||||
@ -138,6 +145,7 @@ def tflite_flex_cc_library(
|
||||
clean_dep("//tensorflow/lite/delegates/flex:portable_images_lib"),
|
||||
],
|
||||
alwayslink = 1,
|
||||
testonly = testonly,
|
||||
)
|
||||
portable_tensorflow_lib = ":%s_tensorflow_lib" % name
|
||||
|
||||
@ -164,6 +172,7 @@ def tflite_flex_cc_library(
|
||||
clean_dep("//tensorflow/lite/c:common"),
|
||||
],
|
||||
}) + additional_deps,
|
||||
testonly = testonly,
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
@ -171,6 +180,7 @@ def tflite_flex_jni_library(
|
||||
name,
|
||||
models = [],
|
||||
additional_deps = [],
|
||||
testonly = 0,
|
||||
visibility = ["//visibility:private"]):
|
||||
"""A rule to generate a jni library listing only used operators.
|
||||
|
||||
@ -183,6 +193,7 @@ def tflite_flex_jni_library(
|
||||
to support these models. If empty, the library will include all Tensorflow
|
||||
ops and kernels.
|
||||
additional_deps: Dependencies for additional TF ops.
|
||||
testonly: Mark this library as testonly if true.
|
||||
visibility: visibility of the generated rules.
|
||||
"""
|
||||
|
||||
@ -192,6 +203,7 @@ def tflite_flex_jni_library(
|
||||
name = "%s_flex_delegate" % name,
|
||||
models = models,
|
||||
additional_deps = additional_deps,
|
||||
testonly = testonly,
|
||||
visibility = visibility,
|
||||
)
|
||||
|
||||
@ -204,6 +216,7 @@ def tflite_flex_jni_library(
|
||||
clean_dep("//tensorflow/lite/delegates/flex/java/src/main/native:flex_delegate_jni.cc"),
|
||||
],
|
||||
copts = tflite_copts(),
|
||||
testonly = testonly,
|
||||
visibility = visibility,
|
||||
deps = [
|
||||
":%s_flex_delegate" % name,
|
||||
@ -224,6 +237,7 @@ def tflite_flex_jni_library(
|
||||
tflite_jni_binary(
|
||||
name = "libtensorflowlite_flex_jni.so",
|
||||
linkopts = tflite_jni_linkopts(),
|
||||
testonly = testonly,
|
||||
deps = [
|
||||
":%s_flex_native" % name,
|
||||
],
|
||||
@ -234,6 +248,7 @@ def tflite_flex_android_library(
|
||||
models = [],
|
||||
additional_deps = [],
|
||||
custom_package = "org.tensorflow.lite.flex",
|
||||
testonly = 0,
|
||||
visibility = ["//visibility:private"]):
|
||||
"""A rule to generate an android library based on the selective-built jni library.
|
||||
|
||||
@ -244,18 +259,21 @@ def tflite_flex_android_library(
|
||||
Tensorflow ops and kernels.
|
||||
additional_deps: Dependencies for additional TF ops.
|
||||
custom_package: Java package for which java sources will be generated.
|
||||
testonly: Mark this library as testonly if true.
|
||||
visibility: visibility of the generated rules.
|
||||
"""
|
||||
tflite_flex_jni_library(
|
||||
name = name,
|
||||
models = models,
|
||||
additional_deps = additional_deps,
|
||||
testonly = testonly,
|
||||
visibility = visibility,
|
||||
)
|
||||
|
||||
native.cc_library(
|
||||
name = "%s_native" % name,
|
||||
srcs = ["libtensorflowlite_flex_jni.so"],
|
||||
testonly = testonly,
|
||||
visibility = visibility,
|
||||
)
|
||||
|
||||
@ -265,6 +283,7 @@ def tflite_flex_android_library(
|
||||
manifest = clean_dep("//tensorflow/lite/java:AndroidManifest.xml"),
|
||||
proguard_specs = [clean_dep("//tensorflow/lite/java:proguard.flags")],
|
||||
custom_package = custom_package,
|
||||
testonly = testonly,
|
||||
deps = [
|
||||
":%s_native" % name,
|
||||
clean_dep("//tensorflow/lite/java:tensorflowlite_java"),
|
||||
|
@ -13,8 +13,11 @@ package(
|
||||
|
||||
tflite_flex_jni_library(
|
||||
name = "test",
|
||||
testonly = 1,
|
||||
additional_deps = ["//tensorflow/lite/python/testdata:double_op_and_kernels"],
|
||||
models = [
|
||||
"//tensorflow/lite:testdata/multi_add_flex.bin",
|
||||
"//tensorflow/lite:testdata/double_flex.bin",
|
||||
],
|
||||
)
|
||||
|
||||
@ -59,3 +62,32 @@ java_test(
|
||||
"@junit",
|
||||
],
|
||||
)
|
||||
|
||||
java_test(
|
||||
name = "SelectiveBuiltInterpreterFlexWithCustomOpsTest",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"//tensorflow/lite/java:portable_flex_with_custom_ops_tests",
|
||||
"//tensorflow/lite/java:portable_test_utils",
|
||||
],
|
||||
data = [
|
||||
"//tensorflow/lite:testdata/double_flex.bin",
|
||||
],
|
||||
javacopts = JAVACOPTS,
|
||||
tags = [
|
||||
"no_cuda_on_cpu_tap", # CUDA + flex is not officially supported.
|
||||
"no_gpu", # GPU + flex is not officially supported.
|
||||
"no_oss", # Currently requires --config=monolithic, b/118895218.
|
||||
# TODO(b/121204962): Re-enable test after fixing memory leaks.
|
||||
"noasan",
|
||||
"notsan", # TODO(b/158651814) Re-enable after fixing racing condition.
|
||||
],
|
||||
test_class = "org.tensorflow.lite.InterpreterFlexWithCustomOpsTest",
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
":test_tensorflowlitelib_flex",
|
||||
"//tensorflow/lite/java:tensorflowlitelib",
|
||||
"@com_google_truth",
|
||||
"@junit",
|
||||
],
|
||||
)
|
||||
|
@ -391,6 +391,16 @@ filegroup(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# portable_flex_with_custom_ops_tests includes files for testing Flex delegate
|
||||
# with models containing user's defined ops.
|
||||
filegroup(
|
||||
name = "portable_flex_with_custom_ops_tests",
|
||||
srcs = [
|
||||
"src/test/java/org/tensorflow/lite/InterpreterFlexWithCustomOpsTest.java",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# portable_test_utils include utilities for loading files and processing images.
|
||||
filegroup(
|
||||
name = "portable_test_utils",
|
||||
|
@ -0,0 +1,53 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
import org.tensorflow.lite.flex.FlexDelegate;
|
||||
|
||||
/**
|
||||
* Unit tests for {@link org.tensorflow.lite.Interpreter} that validate execution with models that
|
||||
* have user's defined TensorFlow ops.
|
||||
*/
|
||||
@RunWith(JUnit4.class)
|
||||
public final class InterpreterFlexWithCustomOpsTest {
|
||||
|
||||
private static final ByteBuffer DOUBLE_MODEL_BUFFER =
|
||||
TestUtils.getTestFileAsBuffer("tensorflow/lite/testdata/double_flex.bin");
|
||||
|
||||
/** Smoke test validating that flex model with a user's defined TF op. */
|
||||
@Test
|
||||
public void testFlexModelWithUsersDefinedOp() throws Exception {
|
||||
try (Interpreter interpreter = new Interpreter(DOUBLE_MODEL_BUFFER)) {
|
||||
int[] oneD = {1, 2, 3, 4};
|
||||
int[][] twoD = {oneD};
|
||||
int[][] parsedOutputs = new int[1][4];
|
||||
interpreter.run(twoD, parsedOutputs);
|
||||
int[] outputOneD = parsedOutputs[0];
|
||||
int[] expected = {2, 4, 6, 8};
|
||||
assertThat(outputOneD).isEqualTo(expected);
|
||||
}
|
||||
}
|
||||
|
||||
static {
|
||||
FlexDelegate.initTensorFlowForTesting();
|
||||
}
|
||||
}
|
26
tensorflow/lite/python/testdata/BUILD
vendored
26
tensorflow/lite/python/testdata/BUILD
vendored
@ -1,6 +1,11 @@
|
||||
load("//tensorflow/lite:build_def.bzl", "tf_to_tflite")
|
||||
load("//tensorflow/lite:build_def.bzl", "tf_to_tflite", "tflite_copts")
|
||||
load("//tensorflow:tensorflow.bzl", "pybind_extension", "tf_custom_op_py_library")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_custom_op_library", "tf_gen_op_wrapper_py")
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_custom_op_library",
|
||||
"tf_gen_op_wrapper_py",
|
||||
"tf_opts_nortti_if_android",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = ["//tensorflow:internal"],
|
||||
@ -91,10 +96,19 @@ cc_library(
|
||||
name = "double_op_and_kernels",
|
||||
testonly = 1,
|
||||
srcs = ["double_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
copts = tflite_copts() + tf_opts_nortti_if_android(),
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//tensorflow:ios": [
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
}),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
|
BIN
tensorflow/lite/testdata/double_flex.bin
vendored
Normal file
BIN
tensorflow/lite/testdata/double_flex.bin
vendored
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user