TFLite selective registration: Add android target.

This includes build rule for
- libtensorflowlite_jni.so
- Android target
- AAR target

PiperOrigin-RevId: 322952184
Change-Id: If0fd97d9f50867dbce45a304758e3b9e9ac3eb0a
This commit is contained in:
Tiezhen WANG 2020-07-24 00:59:02 -07:00 committed by TensorFlower Gardener
parent a8b690362f
commit b0cec6bd77
4 changed files with 165 additions and 1 deletions

View File

@ -7,6 +7,8 @@ load(
"tf_cc_shared_object",
"tf_cc_test",
)
load("//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni")
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
def tflite_copts():
"""Defines compile time flags."""
@ -732,7 +734,12 @@ def tflite_experimental_runtime_linkopts(if_eager = [], if_non_eager = [], if_no
if_none = [] + if_none,
)
def tflite_custom_cc_library(name, models = [], srcs = [], deps = [], visibility = ["//visibility:private"]):
def tflite_custom_cc_library(
name,
models = [],
srcs = [],
deps = [],
visibility = ["//visibility:private"]):
"""Generates a tflite cc library, stripping off unused operators.
This library includes the TfLite runtime as well as all operators needed for the given models.
@ -781,3 +788,62 @@ def tflite_custom_cc_library(name, models = [], srcs = [], deps = [], visibility
] + real_deps),
visibility = visibility,
)
def tflite_custom_android_library(
name,
models = [],
srcs = [],
deps = [],
custom_package = "org.tensorflow.lite",
visibility = ["//visibility:private"]):
"""Generates a tflite Android library, stripping off unused operators.
Note that due to a limitation in the JNI Java wrapper, the compiled TfLite shared binary
has to be named as tensorflowlite_jni.so so please make sure that there is no naming conflict.
i.e. you can't call this rule multiple times in the same build file.
Args:
name: Name of the target.
models: List of models to be supported. This TFLite build will only include
operators used in these models. If the list is empty, all builtin
operators are included.
srcs: List of files implementing custom operators if any.
deps: Additional dependencies to build all the custom operators.
custom_package: Name of the Java package. It is required by android_library in case
the Java source file can't be inferred from the directory where this rule is used.
visibility: Visibility setting for the generated target. Default to private.
"""
tflite_custom_cc_library(name = "%s_cc" % name, models = models, srcs = srcs, deps = deps, visibility = visibility)
# JNI wrapper expects a binary file called `libtensorflowlite_jni.so` in java path.
tflite_jni_binary(
name = "libtensorflowlite_jni.so",
linkscript = "//tensorflow/lite/java:tflite_version_script.lds",
deps = [
":%s_cc" % name,
"//tensorflow/lite/java/src/main/native:native_framework_only",
],
)
native.cc_library(
name = "%s_jni" % name,
srcs = ["libtensorflowlite_jni.so"],
visibility = visibility,
)
android_library(
name = name,
manifest = "//tensorflow/lite/java:AndroidManifest.xml",
deps = [
":%s_jni" % name,
"//tensorflow/lite/java:tensorflowlite_java",
"@org_checkerframework_qual",
],
custom_package = custom_package,
visibility = visibility,
)
aar_with_jni(
name = "%s_aar" % name,
android_library = name,
)

View File

@ -18,6 +18,7 @@ exports_files([
"src/testdata/grace_hopper_224.jpg",
"AndroidManifest.xml",
"proguard.flags",
"tflite_version_script.lds",
])
JAVA_SRCS = glob([
@ -340,6 +341,33 @@ java_test(
],
)
java_test(
name = "InterpreterCustomizedAndroidBuildTest",
size = "small",
srcs = [
"src/test/java/org/tensorflow/lite/InterpreterCustomizedAndroidBuildTest.java",
"src/test/java/org/tensorflow/lite/TestUtils.java",
],
data = [
"//tensorflow/lite:testdata/add.bin",
"//tensorflow/lite:testdata/test_model.bin",
],
javacopts = JAVACOPTS,
# Add customized libtensorflowlite_jni.so to java_path
jvm_flags = ["-Djava.library.path=third_party/tensorflow/lite/testing"],
tags = [
"no_mac", # TODO(b/122888913): libtensorflowlite_test_jni broke on mac.
"v1only",
],
test_class = "org.tensorflow.lite.InterpreterCustomizedAndroidBuildTest",
visibility = ["//visibility:private"],
deps = [
"//tensorflow/lite/testing:customtized_tflite_for_add_ops",
"@com_google_truth",
"@junit",
],
)
# portable_tests includes files for running TFLite interpreter tests.
filegroup(
name = "portable_tests",

View File

@ -0,0 +1,63 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;
import java.nio.ByteBuffer;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link org.tensorflow.lite.Interpreter} with selective registration. */
@RunWith(JUnit4.class)
public final class InterpreterCustomizedAndroidBuildTest {
// Supported model.
private static final String SUPPORTED_MODEL_PATH = "tensorflow/lite/testdata/add.bin";
private static final ByteBuffer SUPPORTED_MODEL_BUFFER =
TestUtils.getTestFileAsBuffer(SUPPORTED_MODEL_PATH);
// Model with unregistered operator.
private static final String UNSUPPORTED_MODEL_PATH =
"tensorflow/lite/testdata/test_model.bin";
private static final ByteBuffer UNSUPPORTED_MODEL_BUFFER =
TestUtils.getTestFileAsBuffer(UNSUPPORTED_MODEL_PATH);
@Test
public void testSupportedModel() throws Exception {
try (Interpreter interpreter = new Interpreter(SUPPORTED_MODEL_BUFFER)) {
assertThat(interpreter).isNotNull();
float[] oneD = {1.23f, 6.54f, 7.81f};
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
float[][][][] parsedOutputs = new float[2][8][8][3];
interpreter.run(fourD, parsedOutputs);
}
}
@Test
public void testUnsupportedModel() throws Exception {
try (Interpreter interpreter = new Interpreter(UNSUPPORTED_MODEL_BUFFER)) {
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
.contains("Cannot create interpreter: Didn't find op for builtin opcode 'CONV_2D'");
}
}
}

View File

@ -4,6 +4,7 @@ load(
"gen_zipped_test_file",
"generated_test_models_all",
"merged_test_models",
"tflite_custom_android_library",
"tflite_custom_cc_library",
)
load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite")
@ -566,6 +567,12 @@ pybind_extension(
tflite_portable_test_suite()
tflite_custom_android_library(
name = "customtized_tflite_for_add_ops",
models = ["//tensorflow/lite:testdata/add.bin"],
visibility = ["//visibility:public"],
)
edgetpu_ops = [
"add",
"avg_pool",