Add rule to generate selective-built android library for flex delegate

The rule can be used as follow:
tflite_flex_android_library(
    name = "tensorflowlite_flex",
    models = [model1, model2],
)
The size of tensorflow-lite-select-tf-ops.aar built for android_arm:
Full: 9.6MB
Only contain kernel for Add op: 962KB

The default "custom_" prefix for intermediate name is removed, relying on users to add it if the wish.

PiperOrigin-RevId: 317041928
Change-Id: Ia03b15045d9719256a893af886cec525fdd96952
This commit is contained in:
Thai Nguyen 2020-06-17 23:19:58 -07:00 committed by TensorFlower Gardener
parent 85bf166899
commit ecc9976c06
3 changed files with 84 additions and 51 deletions

View File

@ -247,7 +247,6 @@ tflite_flex_jni_library(
models = [ models = [
"//tensorflow/lite:testdata/multi_add_flex.bin", "//tensorflow/lite:testdata/multi_add_flex.bin",
], ],
visibility = ["//tensorflow/lite/android:__subpackages__"],
) )
java_library( java_library(

View File

@ -1,4 +1,4 @@
"""Generate custom library flex delegate.""" """Generate custom flex delegate library."""
load( load(
"//tensorflow:tensorflow.bzl", "//tensorflow:tensorflow.bzl",
@ -17,6 +17,7 @@ load(
"tflite_jni_binary", "tflite_jni_binary",
"tflite_jni_linkopts", "tflite_jni_linkopts",
) )
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
def generate_flex_kernel_header( def generate_flex_kernel_header(
name, name,
@ -44,7 +45,7 @@ def generate_flex_kernel_header(
list_ops_output = include_path + "/list_flex_ops" list_ops_output = include_path + "/list_flex_ops"
list_ops_tool = "//tensorflow/lite/tools:list_flex_ops_main" list_ops_tool = "//tensorflow/lite/tools:list_flex_ops_main"
native.genrule( native.genrule(
name = "%s_custom_list_flex_ops" % name, name = "%s_list_flex_ops" % name,
srcs = models, srcs = models,
outs = [list_ops_output], outs = [list_ops_output],
tools = [list_ops_tool], tools = [list_ops_tool],
@ -56,7 +57,7 @@ def generate_flex_kernel_header(
# Generate the kernel registration header file from list of flex ops. # Generate the kernel registration header file from list of flex ops.
tool = "//tensorflow/python/tools:print_selective_registration_header" tool = "//tensorflow/python/tools:print_selective_registration_header"
native.genrule( native.genrule(
name = "%s_custom_kernel_registration" % name, name = "%s_kernel_registration" % name,
srcs = [list_ops_output], srcs = [list_ops_output],
outs = [header], outs = [header],
tools = [tool], tools = [tool],
@ -72,10 +73,10 @@ def tflite_flex_cc_library(
name, name,
portable_tensorflow_lib = "//tensorflow/core:portable_tensorflow_lib", portable_tensorflow_lib = "//tensorflow/core:portable_tensorflow_lib",
visibility = ["//visibility:public"]): visibility = ["//visibility:public"]):
"""A rule to generate a flex delegate with custom android and ios tensorflow libs. """A rule to generate a flex delegate with custom portable tensorflow lib.
These libs should be a custom version of android_tensorflow_lib and ios_tensorflow_lib This lib should be a custom version of portable_tensorflow_lib and contains ops
and contain ops registrations and kernels. If not defined, the default libs will be used. registrations and kernels. If not defined, the default libs will be used.
Args: Args:
name: Name of the generated rule. name: Name of the generated rule.
@ -110,7 +111,7 @@ def tflite_flex_cc_library(
def tflite_flex_jni_library( def tflite_flex_jni_library(
name, name,
models, models = [],
visibility = ["//visibility:private"]): visibility = ["//visibility:private"]):
"""A rule to generate a jni library listing only used operators. """A rule to generate a jni library listing only used operators.
@ -118,24 +119,23 @@ def tflite_flex_jni_library(
Java wrapper, so please make sure there is no naming conflicts. Java wrapper, so please make sure there is no naming conflicts.
Args: Args:
name: Name of the generated library. name: Prefix of the generated libraries.
models: TFLite models to interpret. models: TFLite models to interpret. The library will only include ops and kernels
to support these models. If empty, the library will include all Tensorflow
ops and kernels.
visibility: visibility of the generated rules. visibility: visibility of the generated rules.
Returns:
Generate a jni library support flex ops.
""" """
portable_tensorflow_lib = "//tensorflow/core:portable_tensorflow_lib" portable_tensorflow_lib = "//tensorflow/core:portable_tensorflow_lib"
if models: if models:
CUSTOM_KERNEL_HEADER = generate_flex_kernel_header( CUSTOM_KERNEL_HEADER = generate_flex_kernel_header(
name = "%s_custom_tf_op_headers" % name, name = "%s_tf_op_headers" % name,
models = models, models = models,
) )
# Define a custom_tensorflow_lib with selective registration. # Define a custom tensorflow_lib with selective registration.
# The library will only contain ops exist in provided models. # The library will only contain ops exist in provided models.
native.cc_library( native.cc_library(
name = "%s_custom_tensorflow_lib" % name, name = "%s_tensorflow_lib" % name,
srcs = if_mobile([ srcs = if_mobile([
"//tensorflow/core:portable_op_registrations_and_gradients", "//tensorflow/core:portable_op_registrations_and_gradients",
"//tensorflow/core/kernels:android_all_ops", "//tensorflow/core/kernels:android_all_ops",
@ -168,12 +168,12 @@ def tflite_flex_jni_library(
], ],
alwayslink = 1, alwayslink = 1,
) )
portable_tensorflow_lib = ":%s_custom_tensorflow_lib" % name portable_tensorflow_lib = ":%s_tensorflow_lib" % name
# Define a custom_init_tensorflow that depends on the custom_tensorflow_lib. # Define a custom init_tensorflow that depends on the above tensorflow_lib.
# This will avoid the symbols re-definition errors. # This will avoid the symbols re-definition errors.
native.cc_library( native.cc_library(
name = "%s_custom_init_tensorflow" % name, name = "%s_init_tensorflow" % name,
srcs = [ srcs = [
"//tensorflow/lite/testing:init_tensorflow.cc", "//tensorflow/lite/testing:init_tensorflow.cc",
], ],
@ -194,37 +194,78 @@ def tflite_flex_jni_library(
}), }),
) )
# Define a custom_flex_delegate that depends on custom_tensorflow_lib. # Define a custom flex_delegate that depends on above tensorflow_lib.
# This will reduce the binary size comparing to the original flex delegate. # This will reduce the binary size comparing to the original flex delegate.
tflite_flex_cc_library( tflite_flex_cc_library(
name = "%s_custom_flex_delegate" % name, name = "%s_flex_delegate" % name,
portable_tensorflow_lib = portable_tensorflow_lib, portable_tensorflow_lib = portable_tensorflow_lib,
visibility = visibility, visibility = visibility,
) )
# Define a custom_flex_native that depends on custom_flex_delegate and custom_init_tensorflow. # Define a custom flex_native that depends on above flex_delegate and init_tensorflow.
native.cc_library( native.cc_library(
name = "%s_custom_flex_native" % name, name = "%s_flex_native" % name,
srcs = [ srcs = [
"//tensorflow/lite/delegates/flex/java/src/main/native:flex_delegate_jni.cc", "//tensorflow/lite/delegates/flex/java/src/main/native:flex_delegate_jni.cc",
], ],
copts = tflite_copts(), copts = tflite_copts(),
visibility = visibility, visibility = visibility,
deps = [ deps = [
":%s_custom_flex_delegate" % name, ":%s_flex_delegate" % name,
"%s_custom_init_tensorflow" % name, ":%s_init_tensorflow" % name,
"//tensorflow/lite/java/jni", "//tensorflow/lite/java/jni",
"//tensorflow/lite/delegates/utils:simple_delegate", "//tensorflow/lite/delegates/utils:simple_delegate",
], ],
alwayslink = 1, alwayslink = 1,
) )
# Build the jni binary based on the custom_flex_native. # Build the jni binary based on the above flex_native.
# The library name is fixed as libtensorflowlite_flex_jni.so in FlexDelegate.java. # The library name is fixed as libtensorflowlite_flex_jni.so in FlexDelegate.java.
tflite_jni_binary( tflite_jni_binary(
name = "libtensorflowlite_flex_jni.so", name = "libtensorflowlite_flex_jni.so",
linkopts = tflite_jni_linkopts(), linkopts = tflite_jni_linkopts(),
deps = [ deps = [
":%s_custom_flex_native" % name, ":%s_flex_native" % name,
], ],
) )
def tflite_flex_android_library(
name,
models = [],
custom_package = "org.tensorflow.lite.flex",
visibility = ["//visibility:private"]):
"""A rule to generate an android library based on the selective-built jni library.
Args:
name: name of android library.
models: TFLite models used for selective build. The library will only include ops
and kernels to support these models. If empty, the library will include all
Tensorflow ops and kernels.
custom_package: Java package for which java sources will be generated.
visibility: visibility of the generated rules.
"""
tflite_flex_jni_library(
name = name,
models = models,
visibility = visibility,
)
native.cc_library(
name = "%s_native" % name,
srcs = ["libtensorflowlite_flex_jni.so"],
visibility = visibility,
)
android_library(
name = name,
srcs = ["//tensorflow/lite/delegates/flex/java/src/main/java/org/tensorflow/lite/flex:flex_delegate"],
manifest = "//tensorflow/lite/java:AndroidManifest.xml",
proguard_specs = ["//tensorflow/lite/java:proguard.flags"],
custom_package = custom_package,
deps = [
":%s_native" % name,
"//tensorflow/lite/java:tensorflowlite_java",
"@org_checkerframework_qual",
],
visibility = visibility,
)

View File

@ -5,6 +5,7 @@ load("@build_bazel_rules_android//android:rules.bzl", "android_library")
load("//tensorflow/java:build_defs.bzl", "JAVACOPTS") load("//tensorflow/java:build_defs.bzl", "JAVACOPTS")
load("//tensorflow/lite:build_def.bzl", "tflite_jni_binary") load("//tensorflow/lite:build_def.bzl", "tflite_jni_binary")
load("//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni") load("//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni")
load("//tensorflow/lite/delegates/flex:build_def.bzl", "tflite_flex_android_library")
package( package(
default_visibility = ["//visibility:public"], default_visibility = ["//visibility:public"],
@ -15,6 +16,8 @@ exports_files([
"src/testdata/add.bin", "src/testdata/add.bin",
"src/testdata/add_unknown_dimensions.bin", "src/testdata/add_unknown_dimensions.bin",
"src/testdata/grace_hopper_224.jpg", "src/testdata/grace_hopper_224.jpg",
"AndroidManifest.xml",
"proguard.flags",
]) ])
JAVA_SRCS = glob([ JAVA_SRCS = glob([
@ -70,16 +73,20 @@ android_library(
# EXPERIMENTAL: Android target that supports TensorFlow op execution with TFLite. # EXPERIMENTAL: Android target that supports TensorFlow op execution with TFLite.
# Note that this library contains *only* the Flex delegate and its Java wrapper for using # Note that this library contains *only* the Flex delegate and its Java wrapper for using
# select TF ops; clients must also include the core `tensorflowlite` runtime. # select TF ops; clients must also include the core `tensorflowlite` runtime.
android_library( #
# The library is generated by tflite_flex_android_library rule. This rule can also be used
# to generate trimmed library that only contain kernels for flex ops used in
# a set of models by listing them in the models parameter. Ex:
# tflite_flex_android_library(
# name = "tensorflowlite_flex",
# models = [model1, model2],
# )
#
# The tflite_flex_android_library rule also generate the libtensorflowlite_flex_jni.so as
# an intermidiate target.
tflite_flex_android_library(
name = "tensorflowlite_flex", name = "tensorflowlite_flex",
srcs = ["//tensorflow/lite/delegates/flex/java/src/main/java/org/tensorflow/lite/flex:flex_delegate"], visibility = ["//visibility:public"],
manifest = "AndroidManifest.xml",
proguard_specs = ["proguard.flags"],
deps = [
":tensorflowlite_java",
":tensorflowlite_native_flex",
"@org_checkerframework_qual",
],
) )
# EXPERIMENTAL: Android target target for GPU acceleration. Note that this # EXPERIMENTAL: Android target target for GPU acceleration. Note that this
@ -131,7 +138,7 @@ java_library(
srcs = ["//tensorflow/lite/delegates/flex/java/src/main/java/org/tensorflow/lite/flex:flex_delegate"], srcs = ["//tensorflow/lite/delegates/flex/java/src/main/java/org/tensorflow/lite/flex:flex_delegate"],
javacopts = JAVACOPTS, javacopts = JAVACOPTS,
deps = [ deps = [
":libtensorflowlite_flex_jni.so", ":libtensorflowlite_flex_jni.so", # Generated by tflite_flex_android_library rule.
":tensorflowlitelib", ":tensorflowlitelib",
"@org_checkerframework_qual", "@org_checkerframework_qual",
], ],
@ -387,12 +394,6 @@ cc_library(
visibility = ["//visibility:private"], visibility = ["//visibility:private"],
) )
cc_library(
name = "tensorflowlite_native_flex",
srcs = ["libtensorflowlite_flex_jni.so"],
visibility = ["//visibility:private"],
)
cc_library( cc_library(
name = "tensorflowlite_native_gpu", name = "tensorflowlite_native_gpu",
srcs = ["libtensorflowlite_gpu_jni.so"], srcs = ["libtensorflowlite_gpu_jni.so"],
@ -413,14 +414,6 @@ tflite_jni_binary(
], ],
) )
# EXPERIMENTAL: Native target that supports TensorFlow op execution with TFLite.
tflite_jni_binary(
name = "libtensorflowlite_flex_jni.so",
deps = [
"//tensorflow/lite/delegates/flex/java/src/main/native",
],
)
# EXPERIMENTAL: Native target that supports GPU acceleration. # EXPERIMENTAL: Native target that supports GPU acceleration.
tflite_jni_binary( tflite_jni_binary(
name = "libtensorflowlite_gpu_jni.so", name = "libtensorflowlite_gpu_jni.so",