Add script to generate trimmed tflite aar files
The script can be used as follows:

    bash tensorflow/lite/java/build_customized_aar_for_models.sh \
      --customize_for_models=model1,model2 \
      --target_archs=x86,x86_64,arm64-v8a,armeabi-v7a

It will generate the tensorflow-lite.aar and, if needed, the
tensorflow-lite-select-tf-ops.aar.

PiperOrigin-RevId: 325980230
Change-Id: I9c5216108ae87ac22a69cbb88b37bb473bd4d37f
This commit is contained in:
parent 7169ab7935
commit cb09f04f33

tensorflow/lite
@@ -205,8 +205,19 @@ bazel build -c opt --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \
This will generate an AAR file in `bazel-bin/tensorflow/lite/java/`. Note
that this builds a "fat" AAR with several different architectures; if you don't
need all of them, use the subset appropriate for your deployment environment.
From there, there are several approaches you can take to use the .aar in your
Android Studio project.

Caution: The following feature is experimental and only available at HEAD. You
can build smaller AAR files targeting only a set of models as follows:

```sh
bash tensorflow/lite/tools/build_aar.sh \
  --input_models=model1,model2 \
  --target_archs=x86,x86_64,arm64-v8a,armeabi-v7a
```

The above script will generate the `tensorflow-lite.aar` file and, optionally,
the `tensorflow-lite-select-tf-ops.aar` file if one of the models is using
TensorFlow ops.
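
For example (output paths as printed by the script; the destination directory
here is only illustrative), you can then copy the generated AAR(s) into your
app module:

```sh
cp bazel-bin/tmp/tensorflow-lite.aar <your_app>/app/libs/
# Only present if one of the models uses TensorFlow ops:
cp bazel-bin/tmp/tensorflow-lite-select-tf-ops.aar <your_app>/app/libs/
```
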
##### Add AAR directly to project

@@ -296,6 +296,27 @@ cc_library(
    ],
)

cc_library(
    name = "list_flex_ops_no_kernel",
    srcs = ["list_flex_ops_no_kernel.cc"],
    hdrs = ["list_flex_ops.h"],
    deps = [
        "//tensorflow/lite:framework",
        "@com_google_absl//absl/strings",
    ],
)

tf_cc_binary(
    name = "list_flex_ops_no_kernel_main",
    srcs = ["list_flex_ops_main.cc"],
    visibility = ["//visibility:public"],
    deps = [
        ":list_flex_ops_no_kernel",
        "//tensorflow/lite/tools:command_line_flags",
        "@com_google_absl//absl/strings",
    ],
)

tf_cc_test(
    name = "list_flex_ops_test",
    srcs = ["list_flex_ops_test.cc"],

214  tensorflow/lite/tools/build_aar.sh  Executable file
@@ -0,0 +1,214 @@
#!/bin/bash
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

set -e

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
ROOT_DIR="$(cd "${SCRIPT_DIR}/../../../" && pwd)"

function print_usage {
  echo "Usage:"
  echo "  $(basename ${BASH_SOURCE}) \\"
  echo "    --input_models=model1.tflite,model2.tflite \\"
  echo "    --target_archs=x86,x86_64,arm64-v8a,armeabi-v7a \\"
  echo "    --tflite_custom_ops_srcs=file1.cc,file2.h \\"
  echo "    --tflite_custom_ops_deps=dep1,dep2"
  echo ""
  echo "Where: "
  echo "  --input_models: Supported TFLite models."
  echo "  --target_archs: Supported arches included in the aar file."
  echo "  --tflite_custom_ops_srcs: The src files for building additional TFLite custom ops if any."
  echo "  --tflite_custom_ops_deps: Dependencies for building additional TFLite custom ops if any."
  echo ""
  exit 1
}

function generate_list_field {
  local name="$1"
  local list_string="$2"
  local list=(${list_string//,/ })

  local message+=("$name=[")
  for item in "${list[@]}"
  do
    message+=("\"$item\",")
  done
  message+=('],')
  printf '%s' "${message[@]}"
}
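
# Example (hypothetical input): `generate_list_field "models" "a.tflite,b.tflite"`
# prints `models=["a.tflite","b.tflite",],`, which is spliced into the BUILD
# file generated below.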

function print_output {
  echo "Output can be found here:"
  for i in "$@"
  do
    # Check if the file exists.
    ls -1a ${ROOT_DIR}/$i
  done
}

function generate_tflite_aar {
  pushd ${TMP_DIR} > /dev/null
  # Generate the BUILD file.
  message=(
    'load("//tensorflow/lite:build_def.bzl", "tflite_custom_android_library")'
    'load("//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni")'
    ''
    'tflite_custom_android_library('
    '    name = "custom_tensorflowlite",'
  )
  message+=('    '$(generate_list_field "models" $MODEL_NAMES))
  message+=('    '$(generate_list_field "srcs" $TFLITE_OPS_SRCS))
  message+=('    '$(generate_list_field "deps" $FLAG_TFLITE_OPS_DEPS))
  message+=(
    ')'
    ''
    'aar_with_jni('
    '    name = "tensorflow-lite",'
    '    android_library = ":custom_tensorflowlite",'
    ')'
    ''
  )
  printf '%s\n' "${message[@]}" >> BUILD

  # Build the aar package.
  popd > /dev/null
  bazel build -c opt --cxxopt='--std=c++14' \
    --fat_apk_cpu=${TARGET_ARCHS} \
    --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
    //tmp:tensorflow-lite

  OUT_FILES="${OUT_FILES} bazel-bin/tmp/tensorflow-lite.aar"
}
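
# For a single hypothetical model passed as --input_models=mobilenet.tflite,
# the generated tmp/BUILD ends up roughly as:
#   tflite_custom_android_library(
#       name = "custom_tensorflowlite",
#       models=["mobilenet.tflite",],
#       srcs=[],
#       deps=[],
#   )
#   aar_with_jni(
#       name = "tensorflow-lite",
#       android_library = ":custom_tensorflowlite",
#   )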

function generate_flex_aar {
  pushd ${TMP_DIR}
  # Generate the BUILD file.
  message=(
    'load("//tensorflow/lite/delegates/flex:build_def.bzl", "tflite_flex_android_library")'
    'load("//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni")'
    ''
    'tflite_flex_android_library('
    '    name = "custom_tensorflowlite_flex",'
  )
  message+=('    '$(generate_list_field "models" $MODEL_NAMES))
  message+=(
    ')'
    ''
    'aar_with_jni('
    '    name = "tensorflow-lite-select-tf-ops",'
    '    android_library = ":custom_tensorflowlite_flex",'
    ')'
  )
  printf '%s\n' "${message[@]}" >> BUILD

  cp ${ROOT_DIR}/tensorflow/lite/java/AndroidManifest.xml .
  cp ${ROOT_DIR}/tensorflow/lite/java/proguard.flags .
  popd

  # Build the aar package.
  bazel build -c opt --cxxopt='--std=c++14' \
    --fat_apk_cpu=${TARGET_ARCHS} \
    --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
    //tmp:tensorflow-lite-select-tf-ops

  OUT_FILES="${OUT_FILES} bazel-bin/tmp/tensorflow-lite-select-tf-ops.aar"
}

# Check command line flags.
TARGET_ARCHS=x86,x86_64,arm64-v8a,armeabi-v7a

if [ "$#" -gt 4 ]; then
  echo "ERROR: Too many arguments."
  print_usage
fi

for i in "$@"
do
  case $i in
    --input_models=*)
      FLAG_MODELS="${i#*=}"
      shift;;
    --target_archs=*)
      TARGET_ARCHS="${i#*=}"
      shift;;
    --tflite_custom_ops_srcs=*)
      FLAG_TFLITE_OPS_SRCS="${i#*=}"
      shift;;
    --tflite_custom_ops_deps=*)
      FLAG_TFLITE_OPS_DEPS="${i#*=}"
      shift;;
    *)
      echo "ERROR: Unrecognized argument: ${i}"
      print_usage;;
  esac
done

# Check whether the user has already run ./configure.
cd $ROOT_DIR
if [ ! -f "$ROOT_DIR/.tf_configure.bazelrc" ]; then
  echo "ERROR: Please run ./configure first."
  exit 1
else
  if ! grep -q ANDROID_SDK_HOME "$ROOT_DIR/.tf_configure.bazelrc"; then
    echo "ERROR: Please run ./configure with Android config."
    exit 1
  fi
fi

# Build the standard aar package if no models are provided.
if [ -z ${FLAG_MODELS} ]; then
  bazel build -c opt --cxxopt='--std=c++14' \
    --fat_apk_cpu=${TARGET_ARCHS} \
    --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
    //tensorflow/lite/java:tensorflow-lite

  print_output bazel-bin/tensorflow/lite/java/tensorflow-lite.aar
  exit 0
fi

# Prepare the tmp directory.
TMP_DIR="${ROOT_DIR}/tmp/"
rm -rf ${TMP_DIR} && mkdir -p ${TMP_DIR}

# Copy models to the tmp directory.
MODEL_NAMES=""
for model in $(echo ${FLAG_MODELS} | sed "s/,/ /g")
do
  cp ${model} ${TMP_DIR}
  MODEL_NAMES="${MODEL_NAMES},$(basename ${model})"
done

# Copy srcs of additional tflite ops to the tmp directory.
TFLITE_OPS_SRCS=""
for src_file in $(echo ${FLAG_TFLITE_OPS_SRCS} | sed "s/,/ /g")
do
  cp ${src_file} ${TMP_DIR}
  TFLITE_OPS_SRCS="${TFLITE_OPS_SRCS},$(basename ${src_file})"
done

# Build the custom aar package.
generate_tflite_aar

# Build the flex aar if any of the models contains flex ops.
bazel build -c opt --config=monolithic //tensorflow/lite/tools:list_flex_ops_no_kernel_main
bazel-bin/tensorflow/lite/tools/list_flex_ops_no_kernel_main --graphs=${FLAG_MODELS} > ${TMP_DIR}/ops_list.txt
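
# ops_list.txt holds a JSON array of the TF op names found in the models
# (e.g. ["TensorListReserve","While"]; names here are illustrative); an empty
# array "[]" means none of the models needs the select TF ops aar.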
if [[ `cat ${TMP_DIR}/ops_list.txt` != "[]" ]]; then
  generate_flex_aar
fi

# List the output files.
rm -rf ${TMP_DIR}
print_output ${OUT_FILES}

115  tensorflow/lite/tools/build_aar_with_docker.sh  Executable file
@@ -0,0 +1,115 @@
#!/bin/bash
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

set -e

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

function print_usage {
  echo "Usage:"
  echo "  $(basename ${BASH_SOURCE}) \\"
  echo "    --input_models=model1.tflite,model2.tflite \\"
  echo "    --target_archs=x86,x86_64,arm64-v8a,armeabi-v7a \\"
  echo "    --checkpoint=master"
  echo ""
  echo "Where: "
  echo "  --input_models: Supported TFLite models."
  echo "  --target_archs: Supported arches included in the aar file."
  echo "  --checkpoint: Checkpoint of the GitHub repo; can be a branch, a commit, or a tag. Default: master"
  echo ""
  exit 1
}

# Check command line flags.
ARGUMENTS=$@
TARGET_ARCHS=x86,x86_64,arm64-v8a,armeabi-v7a
FLAG_CHECKPOINT="master"

if [ "$#" -gt 3 ]; then
  echo "ERROR: Too many arguments."
  print_usage
fi

for i in "$@"
do
  case $i in
    --input_models=*)
      FLAG_MODELS="${i#*=}"
      shift;;
    --target_archs=*)
      TARGET_ARCHS="${i#*=}"
      shift;;
    --checkpoint=*)
      FLAG_CHECKPOINT="${i#*=}"
      shift;;
    *)
      echo "ERROR: Unrecognized argument: ${i}"
      print_usage;;
  esac
done

if [ ! -d /tensorflow_src ]; then
  # Running on the host.
  for model in $(echo ${FLAG_MODELS} | sed "s/,/ /g")
  do
    FLAG_DIR="${FLAG_DIR} -v ${model}:${model}"
  done
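
  # Note: this assumes a Docker image tagged "tflite-builder" is available
  # locally (e.g. built beforehand from TFLite's Android Dockerfile).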
  docker run --rm -it -v $PWD:/tmp -v ${SCRIPT_DIR}:/script_dir ${FLAG_DIR} \
    --entrypoint /script_dir/build_aar_with_docker.sh tflite-builder \
    ${ARGUMENTS}
  exit 0
else
  # Running inside the Docker container; download the SDK first.
  android update sdk --no-ui -a \
    --filter tools,platform-tools,android-${ANDROID_API_LEVEL},build-tools-${ANDROID_BUILD_TOOLS_VERSION}

  cd /tensorflow_src

  # Run configure.
  configs=(
    '/usr/bin/python3'
    '/usr/lib/python3/dist-packages'
    'N'
    'N'
    'N'
    'N'
    '-march=native -Wno-sign-compare'
    'y'
    '/android/sdk'
  )
  printf '%s\n' "${configs[@]}" | ./configure
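  # (The piped answers above correspond, in order, to the ./configure prompts:
  # python binary, python library path, a few optional features declined with
  # 'N', optimization flags, 'y' to configure an Android workspace, and the
  # SDK path inside the container; exact prompts may vary between versions.)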

  # Pull the latest code from TensorFlow and check out the requested revision.
  git pull -a
  git checkout ${FLAG_CHECKPOINT}

  # Build with bazel.
  bash /tensorflow_src/tensorflow/lite/tools/build_aar.sh ${ARGUMENTS}

  # Copy the output files out of the Docker container.
  clear
  OUT_FILES="/tensorflow_src/bazel-bin/tmp/tensorflow-lite.aar"
  OUT_FILES="${OUT_FILES} /tensorflow_src/bazel-bin/tmp/tensorflow-lite-select-tf-ops.aar"
  echo "Output can be found here:"
  for i in ${OUT_FILES}
  do
    if [ -f $i ]; then
      cp $i /tmp
      basename $i
    fi
  done
fi

@@ -42,7 +42,7 @@ struct OpKernelCompare {
using OpKernelSet = std::set<OpKernel, OpKernelCompare>;

// Find flex ops and their kernel classes inside a TFLite model and add them to
// the map flex_ops. The map stores
// the map flex_ops.
void AddFlexOpsFromModel(const tflite::Model* model, OpKernelSet* flex_ops);

// Serialize the list of ops to a json string. If flex_ops is empty, return an

61  tensorflow/lite/tools/list_flex_ops_no_kernel.cc  Normal file
@@ -0,0 +1,61 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/lite/tools/list_flex_ops.h"

namespace tflite {
namespace flex {
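
// For example, a set holding the ops "Where" and "While" serializes to
// "[\"Where\",\n\"While\"]", while an empty set serializes to "[]" (the value
// build_aar.sh checks for to decide whether a select-TF-ops aar is needed).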
std::string OpListToJSONString(const OpKernelSet& flex_ops) {
  return absl::StrCat("[",
                      absl::StrJoin(flex_ops, ",\n",
                                    [](std::string* out, const OpKernel& op) {
                                      absl::StrAppend(out, "\"", op.op_name,
                                                      "\"");
                                    }),
                      "]");
}

void AddFlexOpsFromModel(const tflite::Model* model, OpKernelSet* flex_ops) {
  auto* subgraphs = model->subgraphs();
  if (!subgraphs) return;

  for (int subgraph_index = 0; subgraph_index < subgraphs->size();
       ++subgraph_index) {
    const tflite::SubGraph* subgraph = subgraphs->Get(subgraph_index);
    auto* operators = subgraph->operators();
    auto* opcodes = model->operator_codes();
    if (!operators || !opcodes) continue;

    for (int i = 0; i < operators->size(); ++i) {
      const tflite::Operator* op = operators->Get(i);
      const tflite::OperatorCode* opcode = opcodes->Get(op->opcode_index());
      if (opcode->builtin_code() != tflite::BuiltinOperator_CUSTOM ||
          !tflite::IsFlexOp(opcode->custom_code()->c_str())) {
        continue;
      }

      // Remove the "Flex" prefix from op name.
      std::string flex_op_name(opcode->custom_code()->c_str());
      std::string tf_op_name =
          flex_op_name.substr(strlen(tflite::kFlexCustomCodePrefix));

      flex_ops->insert({tf_op_name, ""});
    }
  }
}
} // namespace flex
} // namespace tflite