Add script to generate trimmed tflite aar files

The script can be used as follows:
bash tensorflow/lite/java/build_customized_aar_for_models.sh \
  --customize_for_models=model1,model2 \
  --target_archs=x86,x86_64,arm64-v8a,armeabi-v7a

it will generate the tensorflow-lite.aar and tensorflow-lite-select-tf-ops.aar if needed.

PiperOrigin-RevId: 325980230
Change-Id: I9c5216108ae87ac22a69cbb88b37bb473bd4d37f
This commit is contained in:
Thai Nguyen 2020-08-11 02:08:22 -07:00 committed by TensorFlower Gardener
parent 7169ab7935
commit cb09f04f33
6 changed files with 425 additions and 3 deletions

View File

@ -205,8 +205,19 @@ bazel build -c opt --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \
This will generate an AAR file in `bazel-bin/tensorflow/lite/java/`. Note
that this builds a "fat" AAR with several different architectures; if you don't
need all of them, use the subset appropriate for your deployment environment.
From there, there are several approaches you can take to use the .aar in your
Android Studio project.
Caution: Following feature is experimental and only available at HEAD. You can
build smaller AAR files targeting only a set of models as follows:
```sh
bash tensorflow/lite/tools/build_aar.sh \
--input_models=model1,model2 \
--target_archs=x86,x86_64,arm64-v8a,armeabi-v7a
```
Above script will generate the `tensorflow-lite.aar` file and optionally the
`tensorflow-lite-select-tf-ops.aar` file if one of the models is using
Tensorflow ops.
##### Add AAR directly to project

View File

@ -296,6 +296,27 @@ cc_library(
],
)
cc_library(
name = "list_flex_ops_no_kernel",
srcs = ["list_flex_ops_no_kernel.cc"],
hdrs = ["list_flex_ops.h"],
deps = [
"//tensorflow/lite:framework",
"@com_google_absl//absl/strings",
],
)
tf_cc_binary(
name = "list_flex_ops_no_kernel_main",
srcs = ["list_flex_ops_main.cc"],
visibility = ["//visibility:public"],
deps = [
":list_flex_ops_no_kernel",
"//tensorflow/lite/tools:command_line_flags",
"@com_google_absl//absl/strings",
],
)
tf_cc_test(
name = "list_flex_ops_test",
srcs = ["list_flex_ops_test.cc"],

View File

@ -0,0 +1,214 @@
#!/bin/bash
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
set -e
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
ROOT_DIR="$(cd "${SCRIPT_DIR}/../../../" && pwd)"
function print_usage {
echo "Usage:"
echo " $(basename ${BASH_SOURCE}) \\"
echo " --input_models=model1.tflite,model2.tflite \\"
echo " --target_archs=x86,x86_64,arm64-v8a,armeabi-v7a \\"
echo " --tflite_custom_ops_srcs=file1.cc,file2.h \\"
echo " --tflite_custom_ops_deps=dep1,dep2"
echo ""
echo "Where: "
echo " --input_models: Supported TFLite models. "
echo " --target_archs: Supported arches included in the aar file."
echo " --tflite_custom_ops_srcs: The src files for building additional TFLite custom ops if any."
echo " --tflite_custom_ops_deps: Dependencies for building additional TFLite custom ops if any."
echo ""
exit 1
}
function generate_list_field {
local name="$1"
local list_string="$2"
local list=(${list_string//,/ })
local message+=("$name=[")
for item in "${list[@]}"
do
message+=("\"$item\",")
done
message+=('],')
printf '%s' "${message[@]}"
}
function print_output {
echo "Output can be found here:"
for i in "$@"
do
# Check if the file exist.
ls -1a ${ROOT_DIR}/$i
done
}
function generate_tflite_aar {
pushd ${TMP_DIR} > /dev/null
# Generate the BUILD file.
message=(
'load("//tensorflow/lite:build_def.bzl", "tflite_custom_android_library")'
'load("//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni")'
''
'tflite_custom_android_library('
' name = "custom_tensorflowlite",'
)
message+=(' '$(generate_list_field "models" $MODEL_NAMES))
message+=(' '$(generate_list_field "srcs" $TFLITE_OPS_SRCS))
message+=(' '$(generate_list_field "deps" $FLAG_TFLITE_OPS_DEPS))
message+=(
')'
''
'aar_with_jni('
' name = "tensorflow-lite",'
' android_library = ":custom_tensorflowlite",'
')'
''
)
printf '%s\n' "${message[@]}" >> BUILD
# Build the aar package.
popd > /dev/null
bazel build -c opt --cxxopt='--std=c++14' \
--fat_apk_cpu=${TARGET_ARCHS} \
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
//tmp:tensorflow-lite
OUT_FILES="${OUT_FILES} bazel-bin/tmp/tensorflow-lite.aar"
}
function generate_flex_aar {
pushd ${TMP_DIR}
# Generating the BUILD file.
message=(
'load("//tensorflow/lite/delegates/flex:build_def.bzl", "tflite_flex_android_library")'
'load("//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni")'
''
'tflite_flex_android_library('
' name = "custom_tensorflowlite_flex",'
)
message+=(' '$(generate_list_field "models" $MODEL_NAMES))
message+=(
')'
''
'aar_with_jni('
' name = "tensorflow-lite-select-tf-ops",'
' android_library = ":custom_tensorflowlite_flex",'
')'
)
printf '%s\n' "${message[@]}" >> BUILD
cp ${ROOT_DIR}/tensorflow/lite/java/AndroidManifest.xml .
cp ${ROOT_DIR}/tensorflow/lite/java/proguard.flags .
popd
# Build the aar package.
bazel build -c opt --cxxopt='--std=c++14' \
--fat_apk_cpu=${TARGET_ARCHS} \
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
//tmp:tensorflow-lite-select-tf-ops
OUT_FILES="${OUT_FILES} bazel-bin/tmp/tensorflow-lite-select-tf-ops.aar"
}
# Check command line flags.
TARGET_ARCHS=x86,x86_64,arm64-v8a,armeabi-v7a
if [ "$#" -gt 4 ]; then
echo "ERROR: Too many arguments."
print_usage
fi
for i in "$@"
do
case $i in
--input_models=*)
FLAG_MODELS="${i#*=}"
shift;;
--target_archs=*)
TARGET_ARCHS="${i#*=}"
shift;;
--tflite_custom_ops_srcs=*)
FLAG_TFLITE_OPS_SRCS="${i#*=}"
shift;;
--tflite_custom_ops_deps=*)
FLAG_TFLITE_OPS_DEPS="${i#*=}"
shift;;
*)
echo "ERROR: Unrecognized argument: ${i}"
print_usage;;
esac
done
# Check if users already run configure
cd $ROOT_DIR
if [ ! -f "$ROOT_DIR/.tf_configure.bazelrc" ]; then
echo "ERROR: Please run ./configure first."
exit 1
else
if ! grep -q ANDROID_SDK_HOME "$ROOT_DIR/.tf_configure.bazelrc"; then
echo "ERROR: Please run ./configure with Android config."
exit 1
fi
fi
# Build the standard aar package of no models provided.
if [ -z ${FLAG_MODELS} ]; then
bazel build -c opt --cxxopt='--std=c++14' \
--fat_apk_cpu=${TARGET_ARCHS} \
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
//tensorflow/lite/java:tensorflow-lite
print_output bazel-bin/tensorflow/lite/java/tensorflow-lite.aar
exit 0
fi
# Prepare the tmp directory.
TMP_DIR="${ROOT_DIR}/tmp/"
rm -rf ${TMP_DIR} && mkdir -p ${TMP_DIR}
# Copy models to tmp directory.
MODEL_NAMES=""
for model in $(echo ${FLAG_MODELS} | sed "s/,/ /g")
do
cp ${model} ${TMP_DIR}
MODEL_NAMES="${MODEL_NAMES},$(basename ${model})"
done
# Copy srcs of additional tflite ops to tmp directory.
TFLITE_OPS_SRCS=""
for src_file in $(echo ${FLAG_TFLITE_OPS_SRCS} | sed "s/,/ /g")
do
cp ${src_file} ${TMP_DIR}
TFLITE_OPS_SRCS="${TFLITE_OPS_SRCS},$(basename ${src_file})"
done
# Build the custom aar package.
generate_tflite_aar
# Build flex aar if one of the models contain flex ops.
bazel build -c opt --config=monolithic //tensorflow/lite/tools:list_flex_ops_no_kernel_main
bazel-bin/tensorflow/lite/tools/list_flex_ops_no_kernel_main --graphs=${FLAG_MODELS} > ${TMP_DIR}/ops_list.txt
if [[ `cat ${TMP_DIR}/ops_list.txt` != "[]" ]]; then
generate_flex_aar
fi
# List the output files.
rm -rf ${TMP_DIR}
print_output ${OUT_FILES}

View File

@ -0,0 +1,115 @@
#!/bin/bash
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
set -e
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
function print_usage {
echo "Usage:"
echo " $(basename ${BASH_SOURCE}) \\"
echo " --input_models=model1.tflite,model2.tflite \\"
echo " --target_archs=x86,x86_64,arm64-v8a,armeabi-v7a \\"
echo " --checkpoint=master"
echo ""
echo "Where: "
echo " --input_models: Supported TFLite models. "
echo " --target_archs: Supported arches included in the aar file."
echo " --checkpoint: Checkpoint of the github repo, could be a branch, a commit or a tag. Default: master"
echo ""
exit 1
}
# Check command line flags.
ARGUMENTS=$@
TARGET_ARCHS=x86,x86_64,arm64-v8a,armeabi-v7a
FLAG_CHECKPOINT="master"
if [ "$#" -gt 3 ]; then
echo "ERROR: Too many arguments."
print_usage
fi
for i in "$@"
do
case $i in
--input_models=*)
FLAG_MODELS="${i#*=}"
shift;;
--target_archs=*)
TARGET_ARCHS="${i#*=}"
shift;;
--checkpoint=*)
FLAG_CHECKPOINT="${i#*=}"
shift;;
*)
echo "ERROR: Unrecognized argument: ${i}"
print_usage;;
esac
done
if [ ! -d /tensorflow_src ]; then
# Running on host.
for model in $(echo ${FLAG_MODELS} | sed "s/,/ /g")
do
FLAG_DIR="${FLAG_DIR} -v ${model}:${model}"
done
docker run --rm -it -v $PWD:/tmp -v ${SCRIPT_DIR}:/script_dir ${FLAG_DIR} \
--entrypoint /script_dir/build_aar_with_docker.sh tflite-builder \
${ARGUMENTS}
exit 0
else
# Running inside docker container, download the SDK first.
android update sdk --no-ui -a \
--filter tools,platform-tools,android-${ANDROID_API_LEVEL},build-tools-${ANDROID_BUILD_TOOLS_VERSION}
cd /tensorflow_src
# Run configure.
configs=(
'/usr/bin/python3'
'/usr/lib/python3/dist-packages'
'N'
'N'
'N'
'N'
'-march=native -Wno-sign-compare'
'y'
'/android/sdk'
)
printf '%s\n' "${configs[@]}" | ./configure
# Pull the latest code from tensorflow.
git pull -a
git checkout ${FLAG_CHECKPOINT}
# Building with bazel.
bash /tensorflow_src/tensorflow/lite/tools/build_aar.sh ${ARGUMENTS}
# Copy the output files from docker container.
clear
OUT_FILES="/tensorflow_src/bazel-bin/tmp/tensorflow-lite.aar"
OUT_FILES="${OUT_FILES} /tensorflow_src/bazel-bin/tmp/tensorflow-lite-select-tf-ops.aar"
echo "Output can be found here:"
for i in ${OUT_FILES}
do
if [ -f $i ]; then
cp $i /tmp
basename $i
fi
done
fi

View File

@ -42,7 +42,7 @@ struct OpKernelCompare {
using OpKernelSet = std::set<OpKernel, OpKernelCompare>;
// Find flex ops and its kernel classes inside a TFLite model and add them to
// the map flex_ops. The map stores
// the map flex_ops.
void AddFlexOpsFromModel(const tflite::Model* model, OpKernelSet* flex_ops);
// Serialize the list op of to a json string. If flex_ops is empty, return an

View File

@ -0,0 +1,61 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/lite/tools/list_flex_ops.h"
namespace tflite {
namespace flex {
std::string OpListToJSONString(const OpKernelSet& flex_ops) {
return absl::StrCat("[",
absl::StrJoin(flex_ops, ",\n",
[](std::string* out, const OpKernel& op) {
absl::StrAppend(out, "\"", op.op_name,
"\"");
}),
"]");
}
void AddFlexOpsFromModel(const tflite::Model* model, OpKernelSet* flex_ops) {
auto* subgraphs = model->subgraphs();
if (!subgraphs) return;
for (int subgraph_index = 0; subgraph_index < subgraphs->size();
++subgraph_index) {
const tflite::SubGraph* subgraph = subgraphs->Get(subgraph_index);
auto* operators = subgraph->operators();
auto* opcodes = model->operator_codes();
if (!operators || !opcodes) continue;
for (int i = 0; i < operators->size(); ++i) {
const tflite::Operator* op = operators->Get(i);
const tflite::OperatorCode* opcode = opcodes->Get(op->opcode_index());
if (opcode->builtin_code() != tflite::BuiltinOperator_CUSTOM ||
!tflite::IsFlexOp(opcode->custom_code()->c_str())) {
continue;
}
// Remove the "Flex" prefix from op name.
std::string flex_op_name(opcode->custom_code()->c_str());
std::string tf_op_name =
flex_op_name.substr(strlen(tflite::kFlexCustomCodePrefix));
flex_ops->insert({tf_op_name, ""});
}
}
}
} // namespace flex
} // namespace tflite