TFLite selective registration: Add android target.
This includes build rule for - libtensorflowlite_jni.so - Android target - AAR target PiperOrigin-RevId: 322952184 Change-Id: If0fd97d9f50867dbce45a304758e3b9e9ac3eb0a
This commit is contained in:
		
							parent
							
								
									a8b690362f
								
							
						
					
					
						commit
						b0cec6bd77
					
				@ -7,6 +7,8 @@ load(
 | 
			
		||||
    "tf_cc_shared_object",
 | 
			
		||||
    "tf_cc_test",
 | 
			
		||||
)
 | 
			
		||||
load("//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni")
 | 
			
		||||
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
 | 
			
		||||
 | 
			
		||||
def tflite_copts():
 | 
			
		||||
    """Defines compile time flags."""
 | 
			
		||||
@ -732,7 +734,12 @@ def tflite_experimental_runtime_linkopts(if_eager = [], if_non_eager = [], if_no
 | 
			
		||||
        if_none = [] + if_none,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
def tflite_custom_cc_library(name, models = [], srcs = [], deps = [], visibility = ["//visibility:private"]):
 | 
			
		||||
def tflite_custom_cc_library(
 | 
			
		||||
        name,
 | 
			
		||||
        models = [],
 | 
			
		||||
        srcs = [],
 | 
			
		||||
        deps = [],
 | 
			
		||||
        visibility = ["//visibility:private"]):
 | 
			
		||||
    """Generates a tflite cc library, stripping off unused operators.
 | 
			
		||||
 | 
			
		||||
    This library includes the TfLite runtime as well as all operators needed for the given models.
 | 
			
		||||
@ -781,3 +788,62 @@ def tflite_custom_cc_library(name, models = [], srcs = [], deps = [], visibility
 | 
			
		||||
        ] + real_deps),
 | 
			
		||||
        visibility = visibility,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
def tflite_custom_android_library(
 | 
			
		||||
        name,
 | 
			
		||||
        models = [],
 | 
			
		||||
        srcs = [],
 | 
			
		||||
        deps = [],
 | 
			
		||||
        custom_package = "org.tensorflow.lite",
 | 
			
		||||
        visibility = ["//visibility:private"]):
 | 
			
		||||
    """Generates a tflite Android library, stripping off unused operators.
 | 
			
		||||
 | 
			
		||||
    Note that due to a limitation in the JNI Java wrapper, the compiled TfLite shared binary
 | 
			
		||||
    has to be named as tensorflowlite_jni.so so please make sure that there is no naming conflict.
 | 
			
		||||
    i.e. you can't call this rule multiple times in the same build file.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        name: Name of the target.
 | 
			
		||||
        models: List of models to be supported. This TFLite build will only include
 | 
			
		||||
            operators used in these models. If the list is empty, all builtin
 | 
			
		||||
            operators are included.
 | 
			
		||||
        srcs: List of files implementing custom operators if any.
 | 
			
		||||
        deps: Additional dependencies to build all the custom operators.
 | 
			
		||||
        custom_package: Name of the Java package. It is required by android_library in case
 | 
			
		||||
            the Java source file can't be inferred from the directory where this rule is used.
 | 
			
		||||
        visibility: Visibility setting for the generated target. Default to private.
 | 
			
		||||
    """
 | 
			
		||||
    tflite_custom_cc_library(name = "%s_cc" % name, models = models, srcs = srcs, deps = deps, visibility = visibility)
 | 
			
		||||
 | 
			
		||||
    # JNI wrapper expects a binary file called `libtensorflowlite_jni.so` in java path.
 | 
			
		||||
    tflite_jni_binary(
 | 
			
		||||
        name = "libtensorflowlite_jni.so",
 | 
			
		||||
        linkscript = "//tensorflow/lite/java:tflite_version_script.lds",
 | 
			
		||||
        deps = [
 | 
			
		||||
            ":%s_cc" % name,
 | 
			
		||||
            "//tensorflow/lite/java/src/main/native:native_framework_only",
 | 
			
		||||
        ],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    native.cc_library(
 | 
			
		||||
        name = "%s_jni" % name,
 | 
			
		||||
        srcs = ["libtensorflowlite_jni.so"],
 | 
			
		||||
        visibility = visibility,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    android_library(
 | 
			
		||||
        name = name,
 | 
			
		||||
        manifest = "//tensorflow/lite/java:AndroidManifest.xml",
 | 
			
		||||
        deps = [
 | 
			
		||||
            ":%s_jni" % name,
 | 
			
		||||
            "//tensorflow/lite/java:tensorflowlite_java",
 | 
			
		||||
            "@org_checkerframework_qual",
 | 
			
		||||
        ],
 | 
			
		||||
        custom_package = custom_package,
 | 
			
		||||
        visibility = visibility,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    aar_with_jni(
 | 
			
		||||
        name = "%s_aar" % name,
 | 
			
		||||
        android_library = name,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
@ -18,6 +18,7 @@ exports_files([
 | 
			
		||||
    "src/testdata/grace_hopper_224.jpg",
 | 
			
		||||
    "AndroidManifest.xml",
 | 
			
		||||
    "proguard.flags",
 | 
			
		||||
    "tflite_version_script.lds",
 | 
			
		||||
])
 | 
			
		||||
 | 
			
		||||
JAVA_SRCS = glob([
 | 
			
		||||
@ -340,6 +341,33 @@ java_test(
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
java_test(
 | 
			
		||||
    name = "InterpreterCustomizedAndroidBuildTest",
 | 
			
		||||
    size = "small",
 | 
			
		||||
    srcs = [
 | 
			
		||||
        "src/test/java/org/tensorflow/lite/InterpreterCustomizedAndroidBuildTest.java",
 | 
			
		||||
        "src/test/java/org/tensorflow/lite/TestUtils.java",
 | 
			
		||||
    ],
 | 
			
		||||
    data = [
 | 
			
		||||
        "//tensorflow/lite:testdata/add.bin",
 | 
			
		||||
        "//tensorflow/lite:testdata/test_model.bin",
 | 
			
		||||
    ],
 | 
			
		||||
    javacopts = JAVACOPTS,
 | 
			
		||||
    # Add customized libtensorflowlite_jni.so to java_path
 | 
			
		||||
    jvm_flags = ["-Djava.library.path=third_party/tensorflow/lite/testing"],
 | 
			
		||||
    tags = [
 | 
			
		||||
        "no_mac",  # TODO(b/122888913): libtensorflowlite_test_jni broke on mac.
 | 
			
		||||
        "v1only",
 | 
			
		||||
    ],
 | 
			
		||||
    test_class = "org.tensorflow.lite.InterpreterCustomizedAndroidBuildTest",
 | 
			
		||||
    visibility = ["//visibility:private"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//tensorflow/lite/testing:customtized_tflite_for_add_ops",
 | 
			
		||||
        "@com_google_truth",
 | 
			
		||||
        "@junit",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
# portable_tests includes files for running TFLite interpreter tests.
 | 
			
		||||
filegroup(
 | 
			
		||||
    name = "portable_tests",
 | 
			
		||||
 | 
			
		||||
@ -0,0 +1,63 @@
 | 
			
		||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
package org.tensorflow.lite;
 | 
			
		||||
 | 
			
		||||
import static com.google.common.truth.Truth.assertThat;
 | 
			
		||||
import static org.junit.Assert.fail;
 | 
			
		||||
 | 
			
		||||
import java.nio.ByteBuffer;
 | 
			
		||||
import org.junit.Test;
 | 
			
		||||
import org.junit.runner.RunWith;
 | 
			
		||||
import org.junit.runners.JUnit4;
 | 
			
		||||
 | 
			
		||||
/** Unit tests for {@link org.tensorflow.lite.Interpreter} with selective registration. */
 | 
			
		||||
@RunWith(JUnit4.class)
 | 
			
		||||
public final class InterpreterCustomizedAndroidBuildTest {
 | 
			
		||||
  // Supported model.
 | 
			
		||||
  private static final String SUPPORTED_MODEL_PATH = "tensorflow/lite/testdata/add.bin";
 | 
			
		||||
  private static final ByteBuffer SUPPORTED_MODEL_BUFFER =
 | 
			
		||||
      TestUtils.getTestFileAsBuffer(SUPPORTED_MODEL_PATH);
 | 
			
		||||
 | 
			
		||||
  // Model with unregistered operator.
 | 
			
		||||
  private static final String UNSUPPORTED_MODEL_PATH =
 | 
			
		||||
      "tensorflow/lite/testdata/test_model.bin";
 | 
			
		||||
  private static final ByteBuffer UNSUPPORTED_MODEL_BUFFER =
 | 
			
		||||
      TestUtils.getTestFileAsBuffer(UNSUPPORTED_MODEL_PATH);
 | 
			
		||||
 | 
			
		||||
  @Test
 | 
			
		||||
  public void testSupportedModel() throws Exception {
 | 
			
		||||
    try (Interpreter interpreter = new Interpreter(SUPPORTED_MODEL_BUFFER)) {
 | 
			
		||||
      assertThat(interpreter).isNotNull();
 | 
			
		||||
      float[] oneD = {1.23f, 6.54f, 7.81f};
 | 
			
		||||
      float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
 | 
			
		||||
      float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
 | 
			
		||||
      float[][][][] fourD = {threeD, threeD};
 | 
			
		||||
      float[][][][] parsedOutputs = new float[2][8][8][3];
 | 
			
		||||
      interpreter.run(fourD, parsedOutputs);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  @Test
 | 
			
		||||
  public void testUnsupportedModel() throws Exception {
 | 
			
		||||
    try (Interpreter interpreter = new Interpreter(UNSUPPORTED_MODEL_BUFFER)) {
 | 
			
		||||
      fail();
 | 
			
		||||
    } catch (IllegalArgumentException e) {
 | 
			
		||||
      assertThat(e)
 | 
			
		||||
          .hasMessageThat()
 | 
			
		||||
          .contains("Cannot create interpreter: Didn't find op for builtin opcode 'CONV_2D'");
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
@ -4,6 +4,7 @@ load(
 | 
			
		||||
    "gen_zipped_test_file",
 | 
			
		||||
    "generated_test_models_all",
 | 
			
		||||
    "merged_test_models",
 | 
			
		||||
    "tflite_custom_android_library",
 | 
			
		||||
    "tflite_custom_cc_library",
 | 
			
		||||
)
 | 
			
		||||
load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite")
 | 
			
		||||
@ -566,6 +567,12 @@ pybind_extension(
 | 
			
		||||
 | 
			
		||||
tflite_portable_test_suite()
 | 
			
		||||
 | 
			
		||||
tflite_custom_android_library(
 | 
			
		||||
    name = "customtized_tflite_for_add_ops",
 | 
			
		||||
    models = ["//tensorflow/lite:testdata/add.bin"],
 | 
			
		||||
    visibility = ["//visibility:public"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
edgetpu_ops = [
 | 
			
		||||
    "add",
 | 
			
		||||
    "avg_pool",
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user