Copy Android Inference Interface out of contrib.

Point tensorflow/examples/android at the copy.

PiperOrigin-RevId: 269423640
Mark Daoust 2019-09-16 14:38:24 -07:00 committed by TensorFlower Gardener
parent b056951d0d
commit b4d110caee
19 changed files with 1653 additions and 10 deletions


@ -53,7 +53,7 @@ cc_library(
name = "tensorflow_native_libs",
srcs = [
":libtensorflow_demo.so",
"//tensorflow/contrib/android:libtensorflow_inference.so",
"//tensorflow/tools/android/inference_interface:libtensorflow_inference.so",
],
tags = [
"manual",
@ -84,7 +84,7 @@ android_binary(
],
deps = [
":tensorflow_native_libs",
"//tensorflow/contrib/android:android_tensorflow_inference_java",
"//tensorflow/tools/android/inference_interface:android_tensorflow_inference_java",
],
)


@ -9,9 +9,9 @@ The demos in this folder are designed to give straightforward samples of using
TensorFlow in mobile applications.
Inference is done using the [TensorFlow Android Inference
Interface](../../../tensorflow/contrib/android), which may be built separately
if you want a standalone library to drop into your existing application. Object
tracking and efficient YUV -> RGB conversion are handled by
Interface](../../tools/android/inference_interface), which may be built
separately if you want a standalone library to drop into your existing
application. Object tracking and efficient YUV -> RGB conversion are handled by
`libtensorflow_demo.so`.
A device running Android 5.0 (API 21) or higher is required to run the demo due
@ -49,7 +49,7 @@ The fastest path to trying the demo is to download the [prebuilt demo APK](https
Also available are precompiled native libraries, and a JCenter package that you
may simply drop into your own applications. See
[tensorflow/contrib/android/README.md](../../../tensorflow/contrib/android/README.md)
[tensorflow/tools/android/inference_interface/README.md](../../tools/android/inference_interface/README.md)
for more details.
## Running the Demo
@ -89,7 +89,7 @@ For any project that does not include custom low level TensorFlow code, this is
likely sufficient.
For details on how to include this JCenter package in your own project see
[tensorflow/contrib/android/README.md](../../../tensorflow/contrib/android/README.md)
[tensorflow/tools/android/inference_interface/README.md](../../tools/android/inference_interface/README.md)
## Building the Demo with TensorFlow from Source
@ -212,4 +212,4 @@ NDK).
Full CMake support for the demo is coming soon, but for now it is possible to
build the TensorFlow Android Inference library using
[tensorflow/contrib/android/cmake](../../../tensorflow/contrib/android/cmake).
[tensorflow/tools/android/inference_interface/cmake](../../tools/android/inference_interface/cmake).


@ -41,6 +41,7 @@ filegroup(
visibility = [
"//tensorflow/contrib/android:__pkg__",
"//tensorflow/java:__pkg__",
"//tensorflow/tools/android/inference_interface:__pkg__",
],
)


@ -5,11 +5,12 @@
package(default_visibility = [
"//tensorflow/java:__pkg__",
# TODO(ashankar): Temporary hack for the Java API and
# //third_party/tensorflow/contrib/android:android_tensorflow_inference_jni
# //third_party/tensorflow/tools/android/inference_interface:android_tensorflow_inference_jni
# to co-exist in a single shared library. However, the hope is that
# //third_party/tensorflow/contrib/android:android_tensorflow_jni can be
# //third_party/tensorflow/tools/android/inference_interface:android_tensorflow_jni can be
# removed once the Java API provides feature parity with it.
"//tensorflow/contrib/android:__pkg__",
"//tensorflow/tools/android/inference_interface:__pkg__",
])
licenses(["notice"]) # Apache 2.0


@ -0,0 +1,4 @@
# Deprecated android inference interface.
WARNING: This directory contains the deprecated TF-Mobile Android inference
interface. Do not use it for anything new; use TFLite instead.


@ -0,0 +1,89 @@
# Description:
# JNI-based Java inference interface for TensorFlow.
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
load(
"//tensorflow:tensorflow.bzl",
"if_android",
"tf_cc_binary",
"tf_copts",
)
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
exports_files([
"LICENSE",
"jni/version_script.lds",
])
filegroup(
name = "android_tensorflow_inference_jni_srcs",
srcs = glob([
"**/*.cc",
"**/*.h",
]),
visibility = ["//visibility:public"],
)
cc_library(
name = "android_tensorflow_inference_jni",
srcs = if_android([":android_tensorflow_inference_jni_srcs"]),
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/java/src/main/native",
],
alwayslink = 1,
)
# JAR with Java bindings to TF.
android_library(
name = "android_tensorflow_inference_java",
srcs = glob(["java/**/*.java"]) + ["//tensorflow/java:java_sources"],
tags = [
"manual",
"notap",
],
)
# Build the native .so.
# bazel build //tensorflow/tools/android/inference_interface:libtensorflow_inference.so \
# --crosstool_top=//external:android/crosstool \
# --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
# --cpu=armeabi-v7a
LINKER_SCRIPT = "//tensorflow/tools/android/inference_interface:jni/version_script.lds"
# This fails to build if converted to tf_cc_binary.
cc_binary(
name = "libtensorflow_inference.so",
copts = tf_copts() + [
"-ffunction-sections",
"-fdata-sections",
],
linkopts = if_android([
"-landroid",
"-latomic",
"-ldl",
"-llog",
"-lm",
"-z defs",
"-s",
"-Wl,--gc-sections",
"-Wl,--version-script,$(location {})".format(LINKER_SCRIPT),
]),
linkshared = 1,
linkstatic = 1,
tags = [
"manual",
"notap",
],
deps = [
":android_tensorflow_inference_jni",
"//tensorflow/core:android_tensorflow_lib",
LINKER_SCRIPT,
],
)


@ -0,0 +1,95 @@
# Android TensorFlow support
This directory defines components (a native `.so` library and a Java JAR)
geared towards supporting TensorFlow on Android. This includes:
- The [TensorFlow Java API](../../java/README.md)
- A `TensorFlowInferenceInterface` class that provides a smaller API
surface suitable for inference and summarizing performance of model execution.
For example usage, see [TensorFlowImageClassifier.java](../../examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java)
in the [TensorFlow Android Demo](../../examples/android).
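For a quick feel for the API, here is a minimal sketch of the inference flow.
The model path, node names, and shapes below are hypothetical, and `getAssets()`
assumes the call is made from an Android `Context`; the `feed`/`run`/`fetch`
calls match `TensorFlowInferenceInterface`:

```
// Sketch only: "input", "output", model.pb, and all shapes are placeholders.
TensorFlowInferenceInterface inference =
    new TensorFlowInferenceInterface(getAssets(), "file:///android_asset/model.pb");
float[] input = new float[224 * 224 * 3];
float[] output = new float[1000];
inference.feed("input", input, 1, 224, 224, 3); // copy the array into the named input Tensor
inference.run(new String[] {"output"});         // execute the graph
inference.fetch("output", output);              // copy the result Tensor back out
inference.close();                              // release native resources
```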
For prebuilt libraries, see the
[nightly Android build artifacts](https://ci.tensorflow.org/view/Nightly/job/nightly-android/)
page for a recent build.
The TensorFlow Inference Interface is also available as a
[JCenter package](https://bintray.com/google/tensorflow/tensorflow)
(see the tensorflow-android directory) and can be included quite simply in your
Android project with a couple of lines in the project's `build.gradle` file:
```
allprojects {
repositories {
jcenter()
}
}
dependencies {
compile 'org.tensorflow:tensorflow-android:+'
}
```
This will tell Gradle to use the
[latest version](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
of the TensorFlow AAR that has been released to
[JCenter](https://jcenter.bintray.com/org/tensorflow/tensorflow-android/).
You may replace the `+` with an explicit version label if you wish to
use a specific release of TensorFlow in your app.
To build the libraries yourself (if, for example, you want to support custom
TensorFlow operators), pick your preferred approach below:
### Bazel
First follow the Bazel setup instructions described in
[tensorflow/examples/android/README.md](../../examples/android/README.md)
Then, to build the native TF library:
```
bazel build -c opt //tensorflow/tools/android/inference_interface:libtensorflow_inference.so \
--crosstool_top=//external:android/crosstool \
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
--cxxopt=-std=c++11 \
--cpu=armeabi-v7a
```
Replace `armeabi-v7a` with your desired target architecture.
The library will be located at:
```
bazel-bin/tensorflow/tools/android/inference_interface/libtensorflow_inference.so
```
To build the Java counterpart:
```
bazel build //tensorflow/tools/android/inference_interface:android_tensorflow_inference_java
```
You will find the JAR file at:
```
bazel-bin/tensorflow/tools/android/inference_interface/libandroid_tensorflow_inference_java.jar
```
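With both artifacts built, one common layout (illustrative only, assuming the
default Gradle source sets) is to copy them into your app module:

```
app/libs/libandroid_tensorflow_inference_java.jar
app/src/main/jniLibs/armeabi-v7a/libtensorflow_inference.so
```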
### CMake
For documentation on building a self-contained AAR file with cmake, see
[tensorflow/tools/android/inference_interface/cmake](cmake).
### Makefile
For documentation on building native TF libraries with make, including a CUDA-enabled variant for devices like the Nvidia Shield TV, see [tensorflow/contrib/makefile/README.md](../../../contrib/makefile/README.md)
## AssetManagerFileSystem
This directory also contains a TensorFlow filesystem supporting the Android
asset manager. This may be useful when writing native (C++) code that is tightly
coupled with TensorFlow. For typical usage, the library above will be
sufficient.


@ -0,0 +1,272 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/tools/android/inference_interface/asset_manager_filesystem.h"
#include <unistd.h>
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/file_system_helper.h"
namespace tensorflow {
namespace {
string RemoveSuffix(const string& name, const string& suffix) {
string output(name);
StringPiece piece(output);
absl::ConsumeSuffix(&piece, suffix);
return string(piece);
}
// Closes the given AAsset when variable is destructed.
class ScopedAsset {
public:
ScopedAsset(AAsset* asset) : asset_(asset) {}
~ScopedAsset() {
if (asset_ != nullptr) {
AAsset_close(asset_);
}
}
AAsset* get() const { return asset_; }
private:
AAsset* asset_;
};
// Closes the given AAssetDir when variable is destructed.
class ScopedAssetDir {
public:
ScopedAssetDir(AAssetDir* asset_dir) : asset_dir_(asset_dir) {}
~ScopedAssetDir() {
if (asset_dir_ != nullptr) {
AAssetDir_close(asset_dir_);
}
}
AAssetDir* get() const { return asset_dir_; }
private:
AAssetDir* asset_dir_;
};
class ReadOnlyMemoryRegionFromAsset : public ReadOnlyMemoryRegion {
public:
ReadOnlyMemoryRegionFromAsset(std::unique_ptr<char[]> data, uint64 length)
: data_(std::move(data)), length_(length) {}
~ReadOnlyMemoryRegionFromAsset() override = default;
const void* data() override { return reinterpret_cast<void*>(data_.get()); }
uint64 length() override { return length_; }
private:
std::unique_ptr<char[]> data_;
uint64 length_;
};
// Note that AAssets are not thread-safe and cannot be used across threads.
// However, AAssetManager is. Because RandomAccessFile must be thread-safe and
// used across threads, new AAssets must be created for every access.
// TODO(tylerrhodes): is there a more efficient way to do this?
class RandomAccessFileFromAsset : public RandomAccessFile {
public:
RandomAccessFileFromAsset(AAssetManager* asset_manager, const string& name)
: asset_manager_(asset_manager), file_name_(name) {}
~RandomAccessFileFromAsset() override = default;
Status Read(uint64 offset, size_t to_read, StringPiece* result,
char* scratch) const override {
auto asset = ScopedAsset(AAssetManager_open(
asset_manager_, file_name_.c_str(), AASSET_MODE_RANDOM));
if (asset.get() == nullptr) {
return errors::NotFound("File ", file_name_, " not found.");
}
off64_t new_offset = AAsset_seek64(asset.get(), offset, SEEK_SET);
off64_t length = AAsset_getLength64(asset.get());
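// AAsset_seek64 returns (off64_t)-1 on failure, e.g. for an offset past the
// end of the asset.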
if (new_offset < 0) {
*result = StringPiece(scratch, 0);
return errors::OutOfRange("Read after file end.");
}
const off64_t region_left =
std::min(length - new_offset, static_cast<off64_t>(to_read));
int read = AAsset_read(asset.get(), scratch, region_left);
if (read < 0) {
return errors::Internal("Error reading from asset.");
}
*result = StringPiece(scratch, region_left);
return (region_left == to_read)
? Status::OK()
: errors::OutOfRange("Read fewer bytes than requested.");
}
private:
AAssetManager* asset_manager_;
string file_name_;
};
} // namespace
AssetManagerFileSystem::AssetManagerFileSystem(AAssetManager* asset_manager,
const string& prefix)
: asset_manager_(asset_manager), prefix_(prefix) {}
Status AssetManagerFileSystem::FileExists(const string& fname) {
string path = RemoveAssetPrefix(fname);
auto asset = ScopedAsset(
AAssetManager_open(asset_manager_, path.c_str(), AASSET_MODE_RANDOM));
if (asset.get() == nullptr) {
return errors::NotFound("File ", fname, " not found.");
}
return Status::OK();
}
Status AssetManagerFileSystem::NewRandomAccessFile(
const string& fname, std::unique_ptr<RandomAccessFile>* result) {
string path = RemoveAssetPrefix(fname);
auto asset = ScopedAsset(
AAssetManager_open(asset_manager_, path.c_str(), AASSET_MODE_RANDOM));
if (asset.get() == nullptr) {
return errors::NotFound("File ", fname, " not found.");
}
result->reset(new RandomAccessFileFromAsset(asset_manager_, path));
return Status::OK();
}
Status AssetManagerFileSystem::NewReadOnlyMemoryRegionFromFile(
const string& fname, std::unique_ptr<ReadOnlyMemoryRegion>* result) {
string path = RemoveAssetPrefix(fname);
auto asset = ScopedAsset(
AAssetManager_open(asset_manager_, path.c_str(), AASSET_MODE_STREAMING));
if (asset.get() == nullptr) {
return errors::NotFound("File ", fname, " not found.");
}
off64_t start, length;
int fd = AAsset_openFileDescriptor64(asset.get(), &start, &length);
std::unique_ptr<char[]> data;
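// AAsset_openFileDescriptor64 yields a usable fd only when the asset is stored
// uncompressed in the APK; for compressed assets it fails and we fall back to
// copying from AAsset_getBuffer below.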
if (fd >= 0) {
data.reset(new char[length]);
ssize_t result = pread(fd, data.get(), length, start);
if (result < 0) {
return errors::Internal("Error reading from file ", fname,
" using 'read': ", result);
}
if (result != length) {
return errors::Internal("Expected size does not match size read: ",
"Expected ", length, " vs. read ", result);
}
close(fd);
} else {
length = AAsset_getLength64(asset.get());
data.reset(new char[length]);
const void* asset_buffer = AAsset_getBuffer(asset.get());
if (asset_buffer == nullptr) {
return errors::Internal("Error reading ", fname, " from asset manager.");
}
memcpy(data.get(), asset_buffer, length);
}
result->reset(new ReadOnlyMemoryRegionFromAsset(std::move(data), length));
return Status::OK();
}
Status AssetManagerFileSystem::GetChildren(const string& prefixed_dir,
std::vector<string>* r) {
std::string path = NormalizeDirectoryPath(prefixed_dir);
auto dir =
ScopedAssetDir(AAssetManager_openDir(asset_manager_, path.c_str()));
if (dir.get() == nullptr) {
return errors::NotFound("Directory ", prefixed_dir, " not found.");
}
const char* next_file = AAssetDir_getNextFileName(dir.get());
while (next_file != nullptr) {
r->push_back(next_file);
next_file = AAssetDir_getNextFileName(dir.get());
}
return Status::OK();
}
Status AssetManagerFileSystem::GetFileSize(const string& fname, uint64* s) {
// If fname corresponds to a directory, return early. It doesn't map to an
// AAsset, and would otherwise return NotFound.
if (DirectoryExists(fname)) {
*s = 0;
return Status::OK();
}
string path = RemoveAssetPrefix(fname);
auto asset = ScopedAsset(
AAssetManager_open(asset_manager_, path.c_str(), AASSET_MODE_RANDOM));
if (asset.get() == nullptr) {
return errors::NotFound("File ", fname, " not found.");
}
*s = AAsset_getLength64(asset.get());
return Status::OK();
}
Status AssetManagerFileSystem::Stat(const string& fname, FileStatistics* stat) {
uint64 size;
stat->is_directory = DirectoryExists(fname);
TF_RETURN_IF_ERROR(GetFileSize(fname, &size));
stat->length = size;
return Status::OK();
}
string AssetManagerFileSystem::NormalizeDirectoryPath(const string& fname) {
return RemoveSuffix(RemoveAssetPrefix(fname), "/");
}
string AssetManagerFileSystem::RemoveAssetPrefix(const string& name) {
StringPiece piece(name);
absl::ConsumePrefix(&piece, prefix_);
return string(piece);
}
bool AssetManagerFileSystem::DirectoryExists(const std::string& fname) {
std::string path = NormalizeDirectoryPath(fname);
auto dir =
ScopedAssetDir(AAssetManager_openDir(asset_manager_, path.c_str()));
// Note that openDir will return something even if the directory doesn't
// exist. Therefore, we need to ensure one file exists in the folder.
return AAssetDir_getNextFileName(dir.get()) != NULL;
}
Status AssetManagerFileSystem::GetMatchingPaths(const string& pattern,
std::vector<string>* results) {
return internal::GetMatchingPaths(this, Env::Default(), pattern, results);
}
Status AssetManagerFileSystem::NewWritableFile(
const string& fname, std::unique_ptr<WritableFile>* result) {
return errors::Unimplemented("Asset storage is read only.");
}
Status AssetManagerFileSystem::NewAppendableFile(
const string& fname, std::unique_ptr<WritableFile>* result) {
return errors::Unimplemented("Asset storage is read only.");
}
Status AssetManagerFileSystem::DeleteFile(const string& f) {
return errors::Unimplemented("Asset storage is read only.");
}
Status AssetManagerFileSystem::CreateDir(const string& d) {
return errors::Unimplemented("Asset storage is read only.");
}
Status AssetManagerFileSystem::DeleteDir(const string& d) {
return errors::Unimplemented("Asset storage is read only.");
}
Status AssetManagerFileSystem::RenameFile(const string& s, const string& t) {
return errors::Unimplemented("Asset storage is read only.");
}
} // namespace tensorflow


@ -0,0 +1,85 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_
#define TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_
#include <android/asset_manager.h>
#include <android/asset_manager_jni.h>
#include "tensorflow/core/platform/file_system.h"
namespace tensorflow {
// FileSystem that uses Android's AAssetManager. Once initialized with a given
// AAssetManager, files in the given AAssetManager can be accessed through the
// prefix given when registered with the TensorFlow Env.
// Note that because APK assets are immutable, any operation that tries to
// modify the FileSystem will return tensorflow::error::code::UNIMPLEMENTED.
class AssetManagerFileSystem : public FileSystem {
public:
// Initialize an AssetManagerFileSystem. Note that this does not register the
// file system with TensorFlow.
// asset_manager - Non-null Android AAssetManager that backs this file
// system. The asset manager is not owned by this file system, and must
// outlive this class.
// prefix - Common prefix to strip from all file URIs before passing them to
// the asset_manager. This is required because TensorFlow gives the entire
// file URI (file:///my_dir/my_file.txt) and AssetManager only knows paths
// relative to its base directory.
AssetManagerFileSystem(AAssetManager* asset_manager, const string& prefix);
~AssetManagerFileSystem() override = default;
Status FileExists(const string& fname) override;
Status NewRandomAccessFile(
const string& filename,
std::unique_ptr<RandomAccessFile>* result) override;
Status NewReadOnlyMemoryRegionFromFile(
const string& filename,
std::unique_ptr<ReadOnlyMemoryRegion>* result) override;
Status GetFileSize(const string& f, uint64* s) override;
// Currently just returns size.
Status Stat(const string& fname, FileStatistics* stat) override;
Status GetChildren(const string& dir, std::vector<string>* r) override;
// All these functions return Unimplemented error. Asset storage is
// read only.
Status NewWritableFile(const string& fname,
std::unique_ptr<WritableFile>* result) override;
Status NewAppendableFile(const string& fname,
std::unique_ptr<WritableFile>* result) override;
Status DeleteFile(const string& f) override;
Status CreateDir(const string& d) override;
Status DeleteDir(const string& d) override;
Status RenameFile(const string& s, const string& t) override;
Status GetMatchingPaths(const string& pattern,
std::vector<string>* results) override;
private:
string RemoveAssetPrefix(const string& name);
// Return a string path that can be passed into AAssetManager functions.
// For example, 'my_prefix://some/dir/' would return 'some/dir'.
string NormalizeDirectoryPath(const string& fname);
bool DirectoryExists(const std::string& fname);
AAssetManager* asset_manager_;
string prefix_;
};
} // namespace tensorflow
#endif // TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_


@ -0,0 +1,80 @@
#
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
cmake_minimum_required(VERSION 3.4.1)
include(ExternalProject)
# TENSORFLOW_ROOT_DIR:
# root directory of tensorflow repo
# used for shared source files and pre-built libs
get_filename_component(TENSORFLOW_ROOT_DIR ../../../../.. ABSOLUTE)
set(PREBUILT_DIR ${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/gen)
add_library(lib_proto STATIC IMPORTED )
set_target_properties(lib_proto PROPERTIES IMPORTED_LOCATION
${PREBUILT_DIR}/protobuf/lib/libprotobuf.a)
add_library(lib_nsync STATIC IMPORTED )
set_target_properties(lib_nsync PROPERTIES IMPORTED_LOCATION
${TARGET_NSYNC_LIB}/lib/libnsync.a)
add_library(lib_tf STATIC IMPORTED )
set_target_properties(lib_tf PROPERTIES IMPORTED_LOCATION
${PREBUILT_DIR}/lib/libtensorflow-core.a)
# Changes to these compile flags should be replicated in the Bazel build file.
# TODO: Consider options other than -O2 for binary size.
# e.g. -Os for gcc, and -Oz for clang.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DIS_SLIM_BUILD \
-std=c++11 -fno-rtti -fno-exceptions \
-O2 -Wno-narrowing -fomit-frame-pointer \
-mfpu=neon -mfloat-abi=softfp -fPIE -fPIC \
-ftemplate-depth=900 \
-DGOOGLE_PROTOBUF_NO_RTTI \
-DGOOGLE_PROTOBUF_NO_STATIC_INITIALIZER")
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} \
-Wl,--allow-multiple-definition \
-Wl,--whole-archive \
-fPIE -pie -v")
file(GLOB tensorflow_inference_sources
${CMAKE_CURRENT_SOURCE_DIR}/../jni/*.cc)
file(GLOB java_api_native_sources
${TENSORFLOW_ROOT_DIR}/tensorflow/java/src/main/native/*.cc)
add_library(tensorflow_inference SHARED
${tensorflow_inference_sources}
${TENSORFLOW_ROOT_DIR}/tensorflow/c/tf_status_helper.cc
${TENSORFLOW_ROOT_DIR}/tensorflow/c/checkpoint_reader.cc
${TENSORFLOW_ROOT_DIR}/tensorflow/c/test_op.cc
${TENSORFLOW_ROOT_DIR}/tensorflow/c/c_api.cc
${java_api_native_sources})
# Include libraries needed for hello-jni lib
target_link_libraries(tensorflow_inference
android
dl
log
m
z
lib_tf
lib_proto
lib_nsync)
include_directories(
${PREBUILT_DIR}/proto
${PREBUILT_DIR}/protobuf/include
${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/downloads/eigen
${TENSORFLOW_ROOT_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/..)


@ -0,0 +1,48 @@
TensorFlow-Android-Inference
============================
This directory contains CMake support for building the Android Java Inference
interface to the TensorFlow native APIs.
See [tensorflow/tools/android/inference_interface](..) for more details about
the library, and instructions for building with Bazel.
Usage
-----
Add TensorFlow-Android-Inference as a dependency of your Android application
* settings.gradle
```
include ':TensorFlow-Android-Inference'
findProject(":TensorFlow-Android-Inference").projectDir =
new File("${/path/to/tensorflow_repo}/tensorflow/tools/android/inference_interface/cmake")
```
* application's build.gradle (adding dependency):
```
debugCompile project(path: ':tensorflow_inference', configuration: 'debug')
releaseCompile project(path: ':tensorflow_inference', configuration: 'release')
```
Note: this makes native code in the lib traceable from your app.
Dependencies
------------
TensorFlow-Android-Inference depends on the TensorFlow static libs already built
in your local TensorFlow repo directory. On Linux/macOS, build_all_android.sh
is invoked from build.gradle to build them. Building the core libs DOES take
time, so by default that step is commented out to avoid confusion (otherwise
Android Studio would appear to hang while opening the project).
To enable it, refer to the comment in
* build.gradle
Output
------
- TensorFlow-Inference-debug.aar
- TensorFlow-Inference-release.aar
The file libtensorflow_inference.so is packed under jni/${ANDROID_ABI}/ in the
above AARs; this is transparent to the app, which accesses the native code
through the equivalent Java APIs.


@ -0,0 +1,105 @@
apply plugin: 'com.android.library'
// TensorFlow repo root dir on local machine
def TF_SRC_DIR = projectDir.toString() + "/../../../../.."
android {
compileSdkVersion 24
// Check local build_tools_version as this is liable to change within Android Studio.
buildToolsVersion '25.0.2'
// for debugging native code purpose
publishNonDefault true
defaultConfig {
archivesBaseName = "Tensorflow-Android-Inference"
minSdkVersion 21
targetSdkVersion 23
versionCode 1
versionName "1.0"
ndk {
abiFilters 'armeabi-v7a'
}
externalNativeBuild {
cmake {
arguments '-DANDROID_TOOLCHAIN=clang',
'-DANDROID_STL=c++_static'
}
}
}
sourceSets {
main {
java {
srcDir "${TF_SRC_DIR}/tensorflow/tools/android/inference_interface/java"
srcDir "${TF_SRC_DIR}/tensorflow/java/src/main/java"
exclude '**/examples/**'
}
}
}
externalNativeBuild {
cmake {
path 'CMakeLists.txt'
}
}
buildTypes {
release {
minifyEnabled false
proguardFiles getDefaultProguardFile('proguard-android.txt'),
'proguard-rules.pro'
}
}
}
// Build libtensorflow-core.a if necessary
// Note: the environment needs to be set up already
// [ such as installing autoconf, make, etc. ]
// How to use:
// 1) install all of the necessary tools to build libtensorflow-core.a
// 2) inside Android Studio IDE, uncomment buildTensorFlow in
// whenTaskAdded{...}
// 3) re-sync and re-build. It could take a long time if NOT building
// with multiple processes.
import org.apache.tools.ant.taskdefs.condition.Os
Properties properties = new Properties()
properties.load(project.rootProject.file('local.properties')
.newDataInputStream())
def ndkDir = properties.getProperty('ndk.dir')
if (ndkDir == null || ndkDir == "") {
ndkDir = System.getenv('ANDROID_NDK_HOME')
}
if (!Os.isFamily(Os.FAMILY_WINDOWS)) {
// This script is for non-Windows OSes. On Windows, MANUALLY build
// (or copy the built) libs/headers to
// ${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/gen;
// refer to CMakeLists.txt for details on the lib and header directories.
task buildTensorflow(type: Exec) {
group 'buildTensorflowLib'
workingDir getProjectDir().toString() + '/../../../../../'
environment PATH: '/opt/local/bin:/opt/local/sbin:' +
System.getenv('PATH')
environment NDK_ROOT: ndkDir
commandLine 'tensorflow/contrib/makefile/build_all_android.sh'
}
tasks.whenTaskAdded { task ->
group 'buildTensorflowLib'
if (task.name.toLowerCase().contains('sources')) {
def tensorflowTarget = new File(getProjectDir().toString() +
'/../../../../contrib/makefile/gen/lib/libtensorflow-core.a')
if (!tensorflowTarget.exists()) {
// Note: building the core libs can take a long time, so this dependency
// is commented out by default to avoid a false first impression.
// Uncomment the next line to enable it:
// task.dependsOn buildTensorflow
}
}
}
}
dependencies {
compile fileTree(dir: 'libs', include: ['*.jar'])
}

View File

@ -0,0 +1,13 @@
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="org.tensorflow.contrib.android">
<uses-sdk
android:minSdkVersion="4"
android:targetSdkVersion="19" />
<application android:allowBackup="true" android:label="@string/app_name"
android:supportsRtl="true">
</application>
</manifest>


@ -0,0 +1,3 @@
<resources>
<string name="app_name">TensorFlowInference</string>
</resources>


@ -0,0 +1,63 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.contrib.android;
/** Accumulate and analyze stats from metadata obtained from Session.Runner.run. */
public class RunStats implements AutoCloseable {
/**
* Options to be provided to a {@link org.tensorflow.Session.Runner} to enable stats accumulation.
*/
public static byte[] runOptions() {
return fullTraceRunOptions;
}
public RunStats() {
nativeHandle = allocate();
}
@Override
public void close() {
if (nativeHandle != 0) {
delete(nativeHandle);
}
nativeHandle = 0;
}
/** Accumulate stats obtained when executing a graph. */
public synchronized void add(byte[] runMetadata) {
add(nativeHandle, runMetadata);
}
/** Summary of the accumulated runtime stats. */
public synchronized String summary() {
return summary(nativeHandle);
}
private long nativeHandle;
// Hack: This is what a serialized RunOptions protocol buffer with trace_level: FULL_TRACE ends
// up as.
private static byte[] fullTraceRunOptions = new byte[] {0x08, 0x03};
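// (In protobuf wire format, 0x08 is the key for field 1, trace_level, encoded
// as a varint, and 0x03 is the TraceLevel.FULL_TRACE enum value.)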
private static native long allocate();
private static native void delete(long handle);
private static native void add(long handle, byte[] runMetadata);
private static native String summary(long handle);
}


@ -0,0 +1,650 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.contrib.android;
import android.content.res.AssetManager;
import android.os.Build.VERSION;
import android.os.Trace;
import android.text.TextUtils;
import android.util.Log;
import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.ArrayList;
import java.util.List;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.Tensors;
import org.tensorflow.types.UInt8;
/**
* Wrapper over the TensorFlow API ({@link Graph}, {@link Session}) providing a smaller API surface
* for inference.
*
* <p>See tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java for an
* example usage.
*/
public class TensorFlowInferenceInterface {
private static final String TAG = "TensorFlowInferenceInterface";
private static final String ASSET_FILE_PREFIX = "file:///android_asset/";
/*
* Load a TensorFlow model from the AssetManager or from disk if it is not an asset file.
*
* @param assetManager The AssetManager to use to load the model file.
* @param model The filepath to the GraphDef proto representing the model.
*/
public TensorFlowInferenceInterface(AssetManager assetManager, String model) {
prepareNativeRuntime();
this.modelName = model;
this.g = new Graph();
this.sess = new Session(g);
this.runner = sess.runner();
final boolean hasAssetPrefix = model.startsWith(ASSET_FILE_PREFIX);
InputStream is = null;
try {
String aname = hasAssetPrefix ? model.split(ASSET_FILE_PREFIX)[1] : model;
is = assetManager.open(aname);
} catch (IOException e) {
if (hasAssetPrefix) {
throw new RuntimeException("Failed to load model from '" + model + "'", e);
}
// Perhaps the model file is not an asset but is on disk.
try {
is = new FileInputStream(model);
} catch (IOException e2) {
throw new RuntimeException("Failed to load model from '" + model + "'", e);
}
}
try {
if (VERSION.SDK_INT >= 18) {
Trace.beginSection("initializeTensorFlow");
Trace.beginSection("readGraphDef");
}
// TODO(ashankar): Can we somehow mmap the contents instead of copying them?
byte[] graphDef = new byte[is.available()];
final int numBytesRead = is.read(graphDef);
if (numBytesRead != graphDef.length) {
throw new IOException(
"read error: read only "
+ numBytesRead
+ " of the graph, expected to read "
+ graphDef.length);
}
if (VERSION.SDK_INT >= 18) {
Trace.endSection(); // readGraphDef.
}
loadGraph(graphDef, g);
is.close();
Log.i(TAG, "Successfully loaded model from '" + model + "'");
if (VERSION.SDK_INT >= 18) {
Trace.endSection(); // initializeTensorFlow.
}
} catch (IOException e) {
throw new RuntimeException("Failed to load model from '" + model + "'", e);
}
}
/*
* Load a TensorFlow model from provided InputStream.
* Note: The InputStream will not be closed after loading model, users need to
* close it themselves.
*
* @param is The InputStream to use to load the model.
*/
public TensorFlowInferenceInterface(InputStream is) {
prepareNativeRuntime();
// modelName is not used when loading from an InputStream; it is set to an
// empty string only because the field is marked final.
this.modelName = "";
this.g = new Graph();
this.sess = new Session(g);
this.runner = sess.runner();
try {
if (VERSION.SDK_INT >= 18) {
Trace.beginSection("initializeTensorFlow");
Trace.beginSection("readGraphDef");
}
int baosInitSize = is.available() > 16384 ? is.available() : 16384;
ByteArrayOutputStream baos = new ByteArrayOutputStream(baosInitSize);
int numBytesRead;
byte[] buf = new byte[16384];
while ((numBytesRead = is.read(buf, 0, buf.length)) != -1) {
baos.write(buf, 0, numBytesRead);
}
byte[] graphDef = baos.toByteArray();
if (VERSION.SDK_INT >= 18) {
Trace.endSection(); // readGraphDef.
}
loadGraph(graphDef, g);
Log.i(TAG, "Successfully loaded model from the input stream");
if (VERSION.SDK_INT >= 18) {
Trace.endSection(); // initializeTensorFlow.
}
} catch (IOException e) {
throw new RuntimeException("Failed to load model from the input stream", e);
}
}
/*
* Construct a TensorFlowInferenceInterface with provided Graph
*
* @param g The Graph to use to construct this interface.
*/
public TensorFlowInferenceInterface(Graph g) {
prepareNativeRuntime();
// modelName is not used here; it is set to an empty string only because the
// field is marked final.
this.modelName = "";
this.g = g;
this.sess = new Session(g);
this.runner = sess.runner();
}
/**
* Runs inference between the previously registered input nodes (via feed*) and the requested
* output nodes. Output nodes can then be queried with the fetch* methods.
*
* @param outputNames A list of output nodes which should be filled by the inference pass.
*/
public void run(String[] outputNames) {
run(outputNames, false);
}
/**
* Runs inference between the previously registered input nodes (via feed*) and the requested
* output nodes. Output nodes can then be queried with the fetch* methods.
*
* @param outputNames A list of output nodes which should be filled by the inference pass.
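* @param enableStats Whether to collect per-node execution statistics for this run; the
*     accumulated summary can be retrieved via {@link #getStatString()}.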
*/
public void run(String[] outputNames, boolean enableStats) {
run(outputNames, enableStats, new String[] {});
}
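// Usage sketch (node name hypothetical): run with stats enabled, then log the summary:
//   inference.run(new String[] {"output"}, true);
//   Log.v(TAG, inference.getStatString());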
/** An overloaded version of {@link #run(String[], boolean)} that also allows supplying target node names. */
public void run(String[] outputNames, boolean enableStats, String[] targetNodeNames) {
// Release any Tensors from the previous run calls.
closeFetches();
// Add fetches.
for (String o : outputNames) {
fetchNames.add(o);
TensorId tid = TensorId.parse(o);
runner.fetch(tid.name, tid.outputIndex);
}
// Add targets.
for (String t : targetNodeNames) {
runner.addTarget(t);
}
// Run the session.
try {
if (enableStats) {
Session.Run r = runner.setOptions(RunStats.runOptions()).runAndFetchMetadata();
fetchTensors = r.outputs;
if (runStats == null) {
runStats = new RunStats();
}
runStats.add(r.metadata);
} else {
fetchTensors = runner.run();
}
} catch (RuntimeException e) {
// Log the failed inputs and outputs to aid debugging before rethrowing; this
// interface predates the TensorFlow Java API.
Log.e(
TAG,
"Failed to run TensorFlow inference with inputs:["
+ TextUtils.join(", ", feedNames)
+ "], outputs:["
+ TextUtils.join(", ", fetchNames)
+ "]");
throw e;
} finally {
// Always release the feeds (to save resources) and reset the runner, this run is
// over.
closeFeeds();
runner = sess.runner();
}
}
/** Returns a reference to the Graph describing the computation run during inference. */
public Graph graph() {
return g;
}
public Operation graphOperation(String operationName) {
final Operation operation = g.operation(operationName);
if (operation == null) {
throw new RuntimeException(
"Node '" + operationName + "' does not exist in model '" + modelName + "'");
}
return operation;
}
/** Returns the last stat summary string if logging is enabled. */
public String getStatString() {
return (runStats == null) ? "" : runStats.summary();
}
/**
* Cleans up the state associated with this Object.
*
* <p>The TensorFlowInferenceInterface object is no longer usable after this method returns.
*/
public void close() {
closeFeeds();
closeFetches();
sess.close();
g.close();
if (runStats != null) {
runStats.close();
}
runStats = null;
}
@Override
protected void finalize() throws Throwable {
try {
close();
} finally {
super.finalize();
}
}
// Methods for taking a native Tensor and filling it with values from Java arrays.
/**
* Given a source array with shape {@link dims} and content {@link src}, copy the contents into
* the input Tensor with name {@link inputName}. The source array {@link src} must have at least
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
public void feed(String inputName, boolean[] src, long... dims) {
byte[] b = new byte[src.length];
for (int i = 0; i < src.length; i++) {
b[i] = src[i] ? (byte) 1 : (byte) 0;
}
addFeed(inputName, Tensor.create(Boolean.class, dims, ByteBuffer.wrap(b)));
}
/**
* Given a source array with shape {@link dims} and content {@link src}, copy the contents into
* the input Tensor with name {@link inputName}. The source array {@link src} must have at least
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
public void feed(String inputName, float[] src, long... dims) {
addFeed(inputName, Tensor.create(dims, FloatBuffer.wrap(src)));
}
/**
* Given a source array with shape {@link dims} and content {@link src}, copy the contents into
* the input Tensor with name {@link inputName}. The source array {@link src} must have at least
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
public void feed(String inputName, int[] src, long... dims) {
addFeed(inputName, Tensor.create(dims, IntBuffer.wrap(src)));
}
/**
* Given a source array with shape {@link dims} and content {@link src}, copy the contents into
* the input Tensor with name {@link inputName}. The source array {@link src} must have at least
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
public void feed(String inputName, long[] src, long... dims) {
addFeed(inputName, Tensor.create(dims, LongBuffer.wrap(src)));
}
/**
* Given a source array with shape {@link dims} and content {@link src}, copy the contents into
* the input Tensor with name {@link inputName}. The source array {@link src} must have at least
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
public void feed(String inputName, double[] src, long... dims) {
addFeed(inputName, Tensor.create(dims, DoubleBuffer.wrap(src)));
}
/**
* Given a source array with shape {@link dims} and content {@link src}, copy the contents into
* the input Tensor with name {@link inputName}. The source array {@link src} must have at least
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
public void feed(String inputName, byte[] src, long... dims) {
addFeed(inputName, Tensor.create(UInt8.class, dims, ByteBuffer.wrap(src)));
}
/**
* Copy a byte sequence into the input Tensor with name {@link inputName} as a string-valued
* scalar tensor. In the TensorFlow type system, a "string" is an arbitrary sequence of bytes, not
* a Java {@code String} (which is a sequence of characters).
*/
public void feedString(String inputName, byte[] src) {
addFeed(inputName, Tensors.create(src));
}
/**
* Copy an array of byte sequences into the input Tensor with name {@link inputName} as a
* string-valued one-dimensional tensor (vector). In the TensorFlow type system, a "string" is an
* arbitrary sequence of bytes, not a Java {@code String} (which is a sequence of characters).
*/
public void feedString(String inputName, byte[][] src) {
addFeed(inputName, Tensors.create(src));
}
// Methods for taking a native Tensor and filling it with data from Java native IO buffers.
/**
* Given a source buffer with shape {@link dims} and content {@link src}, both stored as
* <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
* Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
* elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
public void feed(String inputName, FloatBuffer src, long... dims) {
addFeed(inputName, Tensor.create(dims, src));
}
/**
* Given a source buffer with shape {@link dims} and content {@link src}, both stored as
* <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
* Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
* elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
public void feed(String inputName, IntBuffer src, long... dims) {
addFeed(inputName, Tensor.create(dims, src));
}
/**
* Given a source buffer with shape {@link dims} and content {@link src}, both stored as
* <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
* Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
* elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
public void feed(String inputName, LongBuffer src, long... dims) {
addFeed(inputName, Tensor.create(dims, src));
}
/**
* Given a source buffer with shape {@link dims} and content {@link src}, both stored as
* <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
* Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
* elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
public void feed(String inputName, DoubleBuffer src, long... dims) {
addFeed(inputName, Tensor.create(dims, src));
}
/**
* Given a source buffer with shape {@link dims} and content {@link src}, both stored as
* <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
* Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
* elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
public void feed(String inputName, ByteBuffer src, long... dims) {
addFeed(inputName, Tensor.create(UInt8.class, dims, src));
}
/**
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
* dst} must have length greater than or equal to that of the source Tensor. This operation will
* not affect dst's content past the source Tensor's size.
*/
public void fetch(String outputName, float[] dst) {
fetch(outputName, FloatBuffer.wrap(dst));
}
/**
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
* dst} must have length greater than or equal to that of the source Tensor. This operation will
* not affect dst's content past the source Tensor's size.
*/
public void fetch(String outputName, int[] dst) {
fetch(outputName, IntBuffer.wrap(dst));
}
/**
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
* dst} must have length greater than or equal to that of the source Tensor. This operation will
* not affect dst's content past the source Tensor's size.
*/
public void fetch(String outputName, long[] dst) {
fetch(outputName, LongBuffer.wrap(dst));
}
/**
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
* dst} must have length greater than or equal to that of the source Tensor. This operation will
* not affect dst's content past the source Tensor's size.
*/
public void fetch(String outputName, double[] dst) {
fetch(outputName, DoubleBuffer.wrap(dst));
}
/**
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
* dst} must have length greater than or equal to that of the source Tensor. This operation will
* not affect dst's content past the source Tensor's size.
*/
public void fetch(String outputName, byte[] dst) {
fetch(outputName, ByteBuffer.wrap(dst));
}
/**
* Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
* or equal to that of the source Tensor. This operation will not affect dst's content past the
* source Tensor's size.
*/
public void fetch(String outputName, FloatBuffer dst) {
getTensor(outputName).writeTo(dst);
}
/**
* Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
* or equal to that of the source Tensor. This operation will not affect dst's content past the
* source Tensor's size.
*/
public void fetch(String outputName, IntBuffer dst) {
getTensor(outputName).writeTo(dst);
}
/**
* Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
* or equal to that of the source Tensor. This operation will not affect dst's content past the
* source Tensor's size.
*/
public void fetch(String outputName, LongBuffer dst) {
getTensor(outputName).writeTo(dst);
}
/**
* Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
* or equal to that of the source Tensor. This operation will not affect dst's content past the
* source Tensor's size.
*/
public void fetch(String outputName, DoubleBuffer dst) {
getTensor(outputName).writeTo(dst);
}
/**
* Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
* or equal to that of the source Tensor. This operation will not affect dst's content past the
* source Tensor's size.
*/
public void fetch(String outputName, ByteBuffer dst) {
getTensor(outputName).writeTo(dst);
}
private void prepareNativeRuntime() {
Log.i(TAG, "Checking to see if TensorFlow native methods are already loaded");
try {
// Hack to see if the native libraries have been loaded.
new RunStats();
Log.i(TAG, "TensorFlow native methods already loaded");
} catch (UnsatisfiedLinkError e1) {
Log.i(
TAG, "TensorFlow native methods not found, attempting to load via tensorflow_inference");
try {
System.loadLibrary("tensorflow_inference");
Log.i(TAG, "Successfully loaded TensorFlow native methods (RunStats error may be ignored)");
} catch (UnsatisfiedLinkError e2) {
throw new RuntimeException(
"Native TF methods not found; check that the correct native"
+ " libraries are present in the APK.");
}
}
}
private void loadGraph(byte[] graphDef, Graph g) throws IOException {
final long startMs = System.currentTimeMillis();
if (VERSION.SDK_INT >= 18) {
Trace.beginSection("importGraphDef");
}
try {
g.importGraphDef(graphDef);
} catch (IllegalArgumentException e) {
throw new IOException("Not a valid TensorFlow Graph serialization: " + e.getMessage());
}
if (VERSION.SDK_INT >= 18) {
Trace.endSection(); // importGraphDef.
}
final long endMs = System.currentTimeMillis();
Log.i(
TAG,
"Model load took " + (endMs - startMs) + "ms, TensorFlow version: " + TensorFlow.version());
}
private void addFeed(String inputName, Tensor<?> t) {
// The string format accepted by TensorFlowInferenceInterface is node_name[:output_index].
TensorId tid = TensorId.parse(inputName);
runner.feed(tid.name, tid.outputIndex, t);
feedNames.add(inputName);
feedTensors.add(t);
}
private static class TensorId {
String name;
int outputIndex;
// Parse output names into a TensorId.
//
// E.g., "foo" --> ("foo", 0), while "foo:1" --> ("foo", 1)
public static TensorId parse(String name) {
TensorId tid = new TensorId();
int colonIndex = name.lastIndexOf(':');
if (colonIndex < 0) {
tid.outputIndex = 0;
tid.name = name;
return tid;
}
try {
tid.outputIndex = Integer.parseInt(name.substring(colonIndex + 1));
tid.name = name.substring(0, colonIndex);
} catch (NumberFormatException e) {
tid.outputIndex = 0;
tid.name = name;
}
return tid;
}
}
private Tensor<?> getTensor(String outputName) {
int i = 0;
for (String n : fetchNames) {
if (n.equals(outputName)) {
return fetchTensors.get(i);
}
++i;
}
throw new RuntimeException(
"Node '" + outputName + "' was not provided to run(), so it cannot be read");
}
private void closeFeeds() {
for (Tensor<?> t : feedTensors) {
t.close();
}
feedTensors.clear();
feedNames.clear();
}
private void closeFetches() {
for (Tensor<?> t : fetchTensors) {
t.close();
}
fetchTensors.clear();
fetchNames.clear();
}
// Immutable state.
private final String modelName;
private final Graph g;
private final Session sess;
// State reset on every call to run.
private Session.Runner runner;
private List<String> feedNames = new ArrayList<String>();
private List<Tensor<?>> feedTensors = new ArrayList<Tensor<?>>();
private List<String> fetchNames = new ArrayList<String>();
private List<Tensor<?>> fetchTensors = new ArrayList<Tensor<?>>();
// Mutable state.
private RunStats runStats;
}


@ -0,0 +1,83 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/tools/android/inference_interface/jni/run_stats_jni.h"
#include <jni.h>
#include <sstream>
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/stat_summarizer.h"
using tensorflow::RunMetadata;
using tensorflow::StatSummarizer;
namespace {
StatSummarizer* requireHandle(JNIEnv* env, jlong handle) {
if (handle == 0) {
env->ThrowNew(env->FindClass("java/lang/IllegalStateException"),
"close() has been called on the RunStats object");
return nullptr;
}
return reinterpret_cast<StatSummarizer*>(handle);
}
} // namespace
#define RUN_STATS_METHOD(name) \
JNICALL Java_org_tensorflow_contrib_android_RunStats_##name
JNIEXPORT jlong RUN_STATS_METHOD(allocate)(JNIEnv* env, jclass clazz) {
static_assert(sizeof(jlong) >= sizeof(StatSummarizer*),
"Cannot package C++ object pointers as a Java long");
tensorflow::StatSummarizerOptions opts;
return reinterpret_cast<jlong>(new StatSummarizer(opts));
}
JNIEXPORT void RUN_STATS_METHOD(delete)(JNIEnv* env, jclass clazz,
jlong handle) {
if (handle == 0) return;
delete reinterpret_cast<StatSummarizer*>(handle);
}
JNIEXPORT void RUN_STATS_METHOD(add)(JNIEnv* env, jclass clazz, jlong handle,
jbyteArray run_metadata) {
StatSummarizer* s = requireHandle(env, handle);
if (s == nullptr) return;
jbyte* data = env->GetByteArrayElements(run_metadata, nullptr);
int size = static_cast<int>(env->GetArrayLength(run_metadata));
tensorflow::RunMetadata proto;
if (!proto.ParseFromArray(data, size)) {
env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"),
"runMetadata does not seem to be a serialized RunMetadata "
"protocol message");
} else if (proto.has_step_stats()) {
s->ProcessStepStats(proto.step_stats());
}
env->ReleaseByteArrayElements(run_metadata, data, JNI_ABORT);
}
JNIEXPORT jstring RUN_STATS_METHOD(summary)(JNIEnv* env, jclass clazz,
jlong handle) {
StatSummarizer* s = requireHandle(env, handle);
if (s == nullptr) return nullptr;
std::stringstream ret;
ret << s->GetStatsByMetric("Top 10 CPU", tensorflow::StatsCalculator::BY_TIME,
10)
<< s->GetStatsByNodeType() << s->ShortSummary();
return env->NewStringUTF(ret.str().c_str());
}
#undef RUN_STATS_METHOD


@ -0,0 +1,40 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef ORG_TENSORFLOW_JNI_RUN_STATS_JNI_H_
#define ORG_TENSORFLOW_JNI_RUN_STATS_JNI_H_
#include <jni.h>
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
#define RUN_STATS_METHOD(name) \
Java_org_tensorflow_contrib_android_RunStats_##name
JNIEXPORT JNICALL jlong RUN_STATS_METHOD(allocate)(JNIEnv*, jclass);
JNIEXPORT JNICALL void RUN_STATS_METHOD(delete)(JNIEnv*, jclass, jlong);
JNIEXPORT JNICALL void RUN_STATS_METHOD(add)(JNIEnv*, jclass, jlong,
jbyteArray);
JNIEXPORT JNICALL jstring RUN_STATS_METHOD(summary)(JNIEnv*, jclass, jlong);
#undef RUN_STATS_METHOD
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
#endif // ORG_TENSORFLOW_JNI_RUN_STATS_JNI_H_


@ -0,0 +1,11 @@
VERS_1.0 {
# Export JNI symbols.
global:
Java_*;
JNI_OnLoad;
JNI_OnUnload;
# Hide everything else.
local:
*;
};