TFLite selective registration: Add android target.
This includes build rule for - libtensorflowlite_jni.so - Android target - AAR target PiperOrigin-RevId: 322952184 Change-Id: If0fd97d9f50867dbce45a304758e3b9e9ac3eb0a
This commit is contained in:
parent
a8b690362f
commit
b0cec6bd77
@ -7,6 +7,8 @@ load(
|
||||
"tf_cc_shared_object",
|
||||
"tf_cc_test",
|
||||
)
|
||||
load("//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni")
|
||||
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
|
||||
|
||||
def tflite_copts():
|
||||
"""Defines compile time flags."""
|
||||
@ -732,7 +734,12 @@ def tflite_experimental_runtime_linkopts(if_eager = [], if_non_eager = [], if_no
|
||||
if_none = [] + if_none,
|
||||
)
|
||||
|
||||
def tflite_custom_cc_library(name, models = [], srcs = [], deps = [], visibility = ["//visibility:private"]):
|
||||
def tflite_custom_cc_library(
|
||||
name,
|
||||
models = [],
|
||||
srcs = [],
|
||||
deps = [],
|
||||
visibility = ["//visibility:private"]):
|
||||
"""Generates a tflite cc library, stripping off unused operators.
|
||||
|
||||
This library includes the TfLite runtime as well as all operators needed for the given models.
|
||||
@ -781,3 +788,62 @@ def tflite_custom_cc_library(name, models = [], srcs = [], deps = [], visibility
|
||||
] + real_deps),
|
||||
visibility = visibility,
|
||||
)
|
||||
|
||||
def tflite_custom_android_library(
|
||||
name,
|
||||
models = [],
|
||||
srcs = [],
|
||||
deps = [],
|
||||
custom_package = "org.tensorflow.lite",
|
||||
visibility = ["//visibility:private"]):
|
||||
"""Generates a tflite Android library, stripping off unused operators.
|
||||
|
||||
Note that due to a limitation in the JNI Java wrapper, the compiled TfLite shared binary
|
||||
has to be named as tensorflowlite_jni.so so please make sure that there is no naming conflict.
|
||||
i.e. you can't call this rule multiple times in the same build file.
|
||||
|
||||
Args:
|
||||
name: Name of the target.
|
||||
models: List of models to be supported. This TFLite build will only include
|
||||
operators used in these models. If the list is empty, all builtin
|
||||
operators are included.
|
||||
srcs: List of files implementing custom operators if any.
|
||||
deps: Additional dependencies to build all the custom operators.
|
||||
custom_package: Name of the Java package. It is required by android_library in case
|
||||
the Java source file can't be inferred from the directory where this rule is used.
|
||||
visibility: Visibility setting for the generated target. Default to private.
|
||||
"""
|
||||
tflite_custom_cc_library(name = "%s_cc" % name, models = models, srcs = srcs, deps = deps, visibility = visibility)
|
||||
|
||||
# JNI wrapper expects a binary file called `libtensorflowlite_jni.so` in java path.
|
||||
tflite_jni_binary(
|
||||
name = "libtensorflowlite_jni.so",
|
||||
linkscript = "//tensorflow/lite/java:tflite_version_script.lds",
|
||||
deps = [
|
||||
":%s_cc" % name,
|
||||
"//tensorflow/lite/java/src/main/native:native_framework_only",
|
||||
],
|
||||
)
|
||||
|
||||
native.cc_library(
|
||||
name = "%s_jni" % name,
|
||||
srcs = ["libtensorflowlite_jni.so"],
|
||||
visibility = visibility,
|
||||
)
|
||||
|
||||
android_library(
|
||||
name = name,
|
||||
manifest = "//tensorflow/lite/java:AndroidManifest.xml",
|
||||
deps = [
|
||||
":%s_jni" % name,
|
||||
"//tensorflow/lite/java:tensorflowlite_java",
|
||||
"@org_checkerframework_qual",
|
||||
],
|
||||
custom_package = custom_package,
|
||||
visibility = visibility,
|
||||
)
|
||||
|
||||
aar_with_jni(
|
||||
name = "%s_aar" % name,
|
||||
android_library = name,
|
||||
)
|
||||
|
@ -18,6 +18,7 @@ exports_files([
|
||||
"src/testdata/grace_hopper_224.jpg",
|
||||
"AndroidManifest.xml",
|
||||
"proguard.flags",
|
||||
"tflite_version_script.lds",
|
||||
])
|
||||
|
||||
JAVA_SRCS = glob([
|
||||
@ -340,6 +341,33 @@ java_test(
|
||||
],
|
||||
)
|
||||
|
||||
java_test(
|
||||
name = "InterpreterCustomizedAndroidBuildTest",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"src/test/java/org/tensorflow/lite/InterpreterCustomizedAndroidBuildTest.java",
|
||||
"src/test/java/org/tensorflow/lite/TestUtils.java",
|
||||
],
|
||||
data = [
|
||||
"//tensorflow/lite:testdata/add.bin",
|
||||
"//tensorflow/lite:testdata/test_model.bin",
|
||||
],
|
||||
javacopts = JAVACOPTS,
|
||||
# Add customized libtensorflowlite_jni.so to java_path
|
||||
jvm_flags = ["-Djava.library.path=third_party/tensorflow/lite/testing"],
|
||||
tags = [
|
||||
"no_mac", # TODO(b/122888913): libtensorflowlite_test_jni broke on mac.
|
||||
"v1only",
|
||||
],
|
||||
test_class = "org.tensorflow.lite.InterpreterCustomizedAndroidBuildTest",
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
"//tensorflow/lite/testing:customtized_tflite_for_add_ops",
|
||||
"@com_google_truth",
|
||||
"@junit",
|
||||
],
|
||||
)
|
||||
|
||||
# portable_tests includes files for running TFLite interpreter tests.
|
||||
filegroup(
|
||||
name = "portable_tests",
|
||||
|
@ -0,0 +1,63 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
import static org.junit.Assert.fail;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
|
||||
/** Unit tests for {@link org.tensorflow.lite.Interpreter} with selective registration. */
|
||||
@RunWith(JUnit4.class)
|
||||
public final class InterpreterCustomizedAndroidBuildTest {
|
||||
// Supported model.
|
||||
private static final String SUPPORTED_MODEL_PATH = "tensorflow/lite/testdata/add.bin";
|
||||
private static final ByteBuffer SUPPORTED_MODEL_BUFFER =
|
||||
TestUtils.getTestFileAsBuffer(SUPPORTED_MODEL_PATH);
|
||||
|
||||
// Model with unregistered operator.
|
||||
private static final String UNSUPPORTED_MODEL_PATH =
|
||||
"tensorflow/lite/testdata/test_model.bin";
|
||||
private static final ByteBuffer UNSUPPORTED_MODEL_BUFFER =
|
||||
TestUtils.getTestFileAsBuffer(UNSUPPORTED_MODEL_PATH);
|
||||
|
||||
@Test
|
||||
public void testSupportedModel() throws Exception {
|
||||
try (Interpreter interpreter = new Interpreter(SUPPORTED_MODEL_BUFFER)) {
|
||||
assertThat(interpreter).isNotNull();
|
||||
float[] oneD = {1.23f, 6.54f, 7.81f};
|
||||
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
|
||||
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
|
||||
float[][][][] fourD = {threeD, threeD};
|
||||
float[][][][] parsedOutputs = new float[2][8][8][3];
|
||||
interpreter.run(fourD, parsedOutputs);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testUnsupportedModel() throws Exception {
|
||||
try (Interpreter interpreter = new Interpreter(UNSUPPORTED_MODEL_BUFFER)) {
|
||||
fail();
|
||||
} catch (IllegalArgumentException e) {
|
||||
assertThat(e)
|
||||
.hasMessageThat()
|
||||
.contains("Cannot create interpreter: Didn't find op for builtin opcode 'CONV_2D'");
|
||||
}
|
||||
}
|
||||
}
|
@ -4,6 +4,7 @@ load(
|
||||
"gen_zipped_test_file",
|
||||
"generated_test_models_all",
|
||||
"merged_test_models",
|
||||
"tflite_custom_android_library",
|
||||
"tflite_custom_cc_library",
|
||||
)
|
||||
load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite")
|
||||
@ -566,6 +567,12 @@ pybind_extension(
|
||||
|
||||
tflite_portable_test_suite()
|
||||
|
||||
tflite_custom_android_library(
|
||||
name = "customtized_tflite_for_add_ops",
|
||||
models = ["//tensorflow/lite:testdata/add.bin"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
edgetpu_ops = [
|
||||
"add",
|
||||
"avg_pool",
|
||||
|
Loading…
Reference in New Issue
Block a user