TFLite selective registration: Add android target.
This includes build rule for - libtensorflowlite_jni.so - Android target - AAR target PiperOrigin-RevId: 322952184 Change-Id: If0fd97d9f50867dbce45a304758e3b9e9ac3eb0a
This commit is contained in:
parent
a8b690362f
commit
b0cec6bd77
@ -7,6 +7,8 @@ load(
|
|||||||
"tf_cc_shared_object",
|
"tf_cc_shared_object",
|
||||||
"tf_cc_test",
|
"tf_cc_test",
|
||||||
)
|
)
|
||||||
|
load("//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni")
|
||||||
|
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
|
||||||
|
|
||||||
def tflite_copts():
|
def tflite_copts():
|
||||||
"""Defines compile time flags."""
|
"""Defines compile time flags."""
|
||||||
@ -732,7 +734,12 @@ def tflite_experimental_runtime_linkopts(if_eager = [], if_non_eager = [], if_no
|
|||||||
if_none = [] + if_none,
|
if_none = [] + if_none,
|
||||||
)
|
)
|
||||||
|
|
||||||
def tflite_custom_cc_library(name, models = [], srcs = [], deps = [], visibility = ["//visibility:private"]):
|
def tflite_custom_cc_library(
|
||||||
|
name,
|
||||||
|
models = [],
|
||||||
|
srcs = [],
|
||||||
|
deps = [],
|
||||||
|
visibility = ["//visibility:private"]):
|
||||||
"""Generates a tflite cc library, stripping off unused operators.
|
"""Generates a tflite cc library, stripping off unused operators.
|
||||||
|
|
||||||
This library includes the TfLite runtime as well as all operators needed for the given models.
|
This library includes the TfLite runtime as well as all operators needed for the given models.
|
||||||
@ -781,3 +788,62 @@ def tflite_custom_cc_library(name, models = [], srcs = [], deps = [], visibility
|
|||||||
] + real_deps),
|
] + real_deps),
|
||||||
visibility = visibility,
|
visibility = visibility,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def tflite_custom_android_library(
|
||||||
|
name,
|
||||||
|
models = [],
|
||||||
|
srcs = [],
|
||||||
|
deps = [],
|
||||||
|
custom_package = "org.tensorflow.lite",
|
||||||
|
visibility = ["//visibility:private"]):
|
||||||
|
"""Generates a tflite Android library, stripping off unused operators.
|
||||||
|
|
||||||
|
Note that due to a limitation in the JNI Java wrapper, the compiled TfLite shared binary
|
||||||
|
has to be named as tensorflowlite_jni.so so please make sure that there is no naming conflict.
|
||||||
|
i.e. you can't call this rule multiple times in the same build file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Name of the target.
|
||||||
|
models: List of models to be supported. This TFLite build will only include
|
||||||
|
operators used in these models. If the list is empty, all builtin
|
||||||
|
operators are included.
|
||||||
|
srcs: List of files implementing custom operators if any.
|
||||||
|
deps: Additional dependencies to build all the custom operators.
|
||||||
|
custom_package: Name of the Java package. It is required by android_library in case
|
||||||
|
the Java source file can't be inferred from the directory where this rule is used.
|
||||||
|
visibility: Visibility setting for the generated target. Default to private.
|
||||||
|
"""
|
||||||
|
tflite_custom_cc_library(name = "%s_cc" % name, models = models, srcs = srcs, deps = deps, visibility = visibility)
|
||||||
|
|
||||||
|
# JNI wrapper expects a binary file called `libtensorflowlite_jni.so` in java path.
|
||||||
|
tflite_jni_binary(
|
||||||
|
name = "libtensorflowlite_jni.so",
|
||||||
|
linkscript = "//tensorflow/lite/java:tflite_version_script.lds",
|
||||||
|
deps = [
|
||||||
|
":%s_cc" % name,
|
||||||
|
"//tensorflow/lite/java/src/main/native:native_framework_only",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
native.cc_library(
|
||||||
|
name = "%s_jni" % name,
|
||||||
|
srcs = ["libtensorflowlite_jni.so"],
|
||||||
|
visibility = visibility,
|
||||||
|
)
|
||||||
|
|
||||||
|
android_library(
|
||||||
|
name = name,
|
||||||
|
manifest = "//tensorflow/lite/java:AndroidManifest.xml",
|
||||||
|
deps = [
|
||||||
|
":%s_jni" % name,
|
||||||
|
"//tensorflow/lite/java:tensorflowlite_java",
|
||||||
|
"@org_checkerframework_qual",
|
||||||
|
],
|
||||||
|
custom_package = custom_package,
|
||||||
|
visibility = visibility,
|
||||||
|
)
|
||||||
|
|
||||||
|
aar_with_jni(
|
||||||
|
name = "%s_aar" % name,
|
||||||
|
android_library = name,
|
||||||
|
)
|
||||||
|
@ -18,6 +18,7 @@ exports_files([
|
|||||||
"src/testdata/grace_hopper_224.jpg",
|
"src/testdata/grace_hopper_224.jpg",
|
||||||
"AndroidManifest.xml",
|
"AndroidManifest.xml",
|
||||||
"proguard.flags",
|
"proguard.flags",
|
||||||
|
"tflite_version_script.lds",
|
||||||
])
|
])
|
||||||
|
|
||||||
JAVA_SRCS = glob([
|
JAVA_SRCS = glob([
|
||||||
@ -340,6 +341,33 @@ java_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
java_test(
|
||||||
|
name = "InterpreterCustomizedAndroidBuildTest",
|
||||||
|
size = "small",
|
||||||
|
srcs = [
|
||||||
|
"src/test/java/org/tensorflow/lite/InterpreterCustomizedAndroidBuildTest.java",
|
||||||
|
"src/test/java/org/tensorflow/lite/TestUtils.java",
|
||||||
|
],
|
||||||
|
data = [
|
||||||
|
"//tensorflow/lite:testdata/add.bin",
|
||||||
|
"//tensorflow/lite:testdata/test_model.bin",
|
||||||
|
],
|
||||||
|
javacopts = JAVACOPTS,
|
||||||
|
# Add customized libtensorflowlite_jni.so to java_path
|
||||||
|
jvm_flags = ["-Djava.library.path=third_party/tensorflow/lite/testing"],
|
||||||
|
tags = [
|
||||||
|
"no_mac", # TODO(b/122888913): libtensorflowlite_test_jni broke on mac.
|
||||||
|
"v1only",
|
||||||
|
],
|
||||||
|
test_class = "org.tensorflow.lite.InterpreterCustomizedAndroidBuildTest",
|
||||||
|
visibility = ["//visibility:private"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/lite/testing:customtized_tflite_for_add_ops",
|
||||||
|
"@com_google_truth",
|
||||||
|
"@junit",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# portable_tests includes files for running TFLite interpreter tests.
|
# portable_tests includes files for running TFLite interpreter tests.
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "portable_tests",
|
name = "portable_tests",
|
||||||
|
@ -0,0 +1,63 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package org.tensorflow.lite;
|
||||||
|
|
||||||
|
import static com.google.common.truth.Truth.assertThat;
|
||||||
|
import static org.junit.Assert.fail;
|
||||||
|
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.junit.runners.JUnit4;
|
||||||
|
|
||||||
|
/** Unit tests for {@link org.tensorflow.lite.Interpreter} with selective registration. */
|
||||||
|
@RunWith(JUnit4.class)
|
||||||
|
public final class InterpreterCustomizedAndroidBuildTest {
|
||||||
|
// Supported model.
|
||||||
|
private static final String SUPPORTED_MODEL_PATH = "tensorflow/lite/testdata/add.bin";
|
||||||
|
private static final ByteBuffer SUPPORTED_MODEL_BUFFER =
|
||||||
|
TestUtils.getTestFileAsBuffer(SUPPORTED_MODEL_PATH);
|
||||||
|
|
||||||
|
// Model with unregistered operator.
|
||||||
|
private static final String UNSUPPORTED_MODEL_PATH =
|
||||||
|
"tensorflow/lite/testdata/test_model.bin";
|
||||||
|
private static final ByteBuffer UNSUPPORTED_MODEL_BUFFER =
|
||||||
|
TestUtils.getTestFileAsBuffer(UNSUPPORTED_MODEL_PATH);
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testSupportedModel() throws Exception {
|
||||||
|
try (Interpreter interpreter = new Interpreter(SUPPORTED_MODEL_BUFFER)) {
|
||||||
|
assertThat(interpreter).isNotNull();
|
||||||
|
float[] oneD = {1.23f, 6.54f, 7.81f};
|
||||||
|
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
|
||||||
|
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
|
||||||
|
float[][][][] fourD = {threeD, threeD};
|
||||||
|
float[][][][] parsedOutputs = new float[2][8][8][3];
|
||||||
|
interpreter.run(fourD, parsedOutputs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testUnsupportedModel() throws Exception {
|
||||||
|
try (Interpreter interpreter = new Interpreter(UNSUPPORTED_MODEL_BUFFER)) {
|
||||||
|
fail();
|
||||||
|
} catch (IllegalArgumentException e) {
|
||||||
|
assertThat(e)
|
||||||
|
.hasMessageThat()
|
||||||
|
.contains("Cannot create interpreter: Didn't find op for builtin opcode 'CONV_2D'");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -4,6 +4,7 @@ load(
|
|||||||
"gen_zipped_test_file",
|
"gen_zipped_test_file",
|
||||||
"generated_test_models_all",
|
"generated_test_models_all",
|
||||||
"merged_test_models",
|
"merged_test_models",
|
||||||
|
"tflite_custom_android_library",
|
||||||
"tflite_custom_cc_library",
|
"tflite_custom_cc_library",
|
||||||
)
|
)
|
||||||
load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite")
|
load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite")
|
||||||
@ -566,6 +567,12 @@ pybind_extension(
|
|||||||
|
|
||||||
tflite_portable_test_suite()
|
tflite_portable_test_suite()
|
||||||
|
|
||||||
|
tflite_custom_android_library(
|
||||||
|
name = "customtized_tflite_for_add_ops",
|
||||||
|
models = ["//tensorflow/lite:testdata/add.bin"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
)
|
||||||
|
|
||||||
edgetpu_ops = [
|
edgetpu_ops = [
|
||||||
"add",
|
"add",
|
||||||
"avg_pool",
|
"avg_pool",
|
||||||
|
Loading…
Reference in New Issue
Block a user