Copy Android Inference Interface out of contrib.
Point tensorflow/examples/android at the copy. PiperOrigin-RevId: 269423640
This commit is contained in:
parent
b056951d0d
commit
b4d110caee
@ -53,7 +53,7 @@ cc_library(
|
|||||||
name = "tensorflow_native_libs",
|
name = "tensorflow_native_libs",
|
||||||
srcs = [
|
srcs = [
|
||||||
":libtensorflow_demo.so",
|
":libtensorflow_demo.so",
|
||||||
"//tensorflow/contrib/android:libtensorflow_inference.so",
|
"//tensorflow/tools/android/inference_interface:libtensorflow_inference.so",
|
||||||
],
|
],
|
||||||
tags = [
|
tags = [
|
||||||
"manual",
|
"manual",
|
||||||
@ -84,7 +84,7 @@ android_binary(
|
|||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":tensorflow_native_libs",
|
":tensorflow_native_libs",
|
||||||
"//tensorflow/contrib/android:android_tensorflow_inference_java",
|
"//tensorflow/tools/android/inference_interface:android_tensorflow_inference_java",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -9,9 +9,9 @@ The demos in this folder are designed to give straightforward samples of using
|
|||||||
TensorFlow in mobile applications.
|
TensorFlow in mobile applications.
|
||||||
|
|
||||||
Inference is done using the [TensorFlow Android Inference
|
Inference is done using the [TensorFlow Android Inference
|
||||||
Interface](../../../tensorflow/contrib/android), which may be built separately
|
Interface](../../tools/android/inference_interface), which may be built
|
||||||
if you want a standalone library to drop into your existing application. Object
|
separately if you want a standalone library to drop into your existing
|
||||||
tracking and efficient YUV -> RGB conversion are handled by
|
application. Object tracking and efficient YUV -> RGB conversion are handled by
|
||||||
`libtensorflow_demo.so`.
|
`libtensorflow_demo.so`.
|
||||||
|
|
||||||
A device running Android 5.0 (API 21) or higher is required to run the demo due
|
A device running Android 5.0 (API 21) or higher is required to run the demo due
|
||||||
@ -49,7 +49,7 @@ The fastest path to trying the demo is to download the [prebuilt demo APK](https
|
|||||||
|
|
||||||
Also available are precompiled native libraries, and a jcenter package that you
|
Also available are precompiled native libraries, and a jcenter package that you
|
||||||
may simply drop into your own applications. See
|
may simply drop into your own applications. See
|
||||||
[tensorflow/contrib/android/README.md](../../../tensorflow/contrib/android/README.md)
|
[tensorflow/tools/android/inference_interface/README.md](../../tools/android/inference_interface/README.md)
|
||||||
for more details.
|
for more details.
|
||||||
|
|
||||||
## Running the Demo
|
## Running the Demo
|
||||||
@ -89,7 +89,7 @@ For any project that does not include custom low level TensorFlow code, this is
|
|||||||
likely sufficient.
|
likely sufficient.
|
||||||
|
|
||||||
For details on how to include this JCenter package in your own project see
|
For details on how to include this JCenter package in your own project see
|
||||||
[tensorflow/contrib/android/README.md](../../../tensorflow/contrib/android/README.md)
|
[tensorflow/tools/android/inference_interface/README.md](../../tools/android/inference_interface/README.md)
|
||||||
|
|
||||||
## Building the Demo with TensorFlow from Source
|
## Building the Demo with TensorFlow from Source
|
||||||
|
|
||||||
@ -212,4 +212,4 @@ NDK).
|
|||||||
|
|
||||||
Full CMake support for the demo is coming soon, but for now it is possible to
|
Full CMake support for the demo is coming soon, but for now it is possible to
|
||||||
build the TensorFlow Android Inference library using
|
build the TensorFlow Android Inference library using
|
||||||
[tensorflow/contrib/android/cmake](../../../tensorflow/contrib/android/cmake).
|
[tensorflow/tools/android/inference_interface/cmake](../../tools/android/inference_interface/cmake).
|
||||||
|
@ -41,6 +41,7 @@ filegroup(
|
|||||||
visibility = [
|
visibility = [
|
||||||
"//tensorflow/contrib/android:__pkg__",
|
"//tensorflow/contrib/android:__pkg__",
|
||||||
"//tensorflow/java:__pkg__",
|
"//tensorflow/java:__pkg__",
|
||||||
|
"//tensorflow/tools/android/inference_interface:__pkg__",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -5,11 +5,12 @@
|
|||||||
package(default_visibility = [
|
package(default_visibility = [
|
||||||
"//tensorflow/java:__pkg__",
|
"//tensorflow/java:__pkg__",
|
||||||
# TODO(ashankar): Temporary hack for the Java API and
|
# TODO(ashankar): Temporary hack for the Java API and
|
||||||
# //third_party/tensorflow/contrib/android:android_tensorflow_inference_jni
|
# //third_party/tensorflow/tools/android/inference_interface:android_tensorflow_inference_jni
|
||||||
# to co-exist in a single shared library. However, the hope is that
|
# to co-exist in a single shared library. However, the hope is that
|
||||||
# //third_party/tensorflow/contrib/android:android_tensorflow_jni can be
|
# //third_party/tensorflow/tools/android/inference_interface:android_tensorflow_jni can be
|
||||||
# removed once the Java API provides feature parity with it.
|
# removed once the Java API provides feature parity with it.
|
||||||
"//tensorflow/contrib/android:__pkg__",
|
"//tensorflow/contrib/android:__pkg__",
|
||||||
|
"//tensorflow/tools/android/inference_interface:__pkg__",
|
||||||
])
|
])
|
||||||
|
|
||||||
licenses(["notice"]) # Apache 2.0
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
4
tensorflow/tools/android/README.md
Normal file
4
tensorflow/tools/android/README.md
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
# Deprecated android inference interface.
|
||||||
|
|
||||||
|
WARNING: This directory contains deprecated tf-mobile android inference
|
||||||
|
interface do not use this for anything new. Use TFLite.
|
89
tensorflow/tools/android/inference_interface/BUILD
Normal file
89
tensorflow/tools/android/inference_interface/BUILD
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
# Description:
|
||||||
|
# JNI-based Java inference interface for TensorFlow.
|
||||||
|
|
||||||
|
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
|
||||||
|
load(
|
||||||
|
"//tensorflow:tensorflow.bzl",
|
||||||
|
"if_android",
|
||||||
|
"tf_cc_binary",
|
||||||
|
"tf_copts",
|
||||||
|
)
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = ["//visibility:public"],
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
exports_files([
|
||||||
|
"LICENSE",
|
||||||
|
"jni/version_script.lds",
|
||||||
|
])
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "android_tensorflow_inference_jni_srcs",
|
||||||
|
srcs = glob([
|
||||||
|
"**/*.cc",
|
||||||
|
"**/*.h",
|
||||||
|
]),
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "android_tensorflow_inference_jni",
|
||||||
|
srcs = if_android([":android_tensorflow_inference_jni_srcs"]),
|
||||||
|
copts = tf_copts(),
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||||
|
"//tensorflow/java/src/main/native",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# JAR with Java bindings to TF.
|
||||||
|
android_library(
|
||||||
|
name = "android_tensorflow_inference_java",
|
||||||
|
srcs = glob(["java/**/*.java"]) + ["//tensorflow/java:java_sources"],
|
||||||
|
tags = [
|
||||||
|
"manual",
|
||||||
|
"notap",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build the native .so.
|
||||||
|
# bazel build //tensorflow/tools/android/inference_interface:libtensorflow_inference.so \
|
||||||
|
# --crosstool_top=//external:android/crosstool \
|
||||||
|
# --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
|
||||||
|
# --cpu=armeabi-v7a
|
||||||
|
LINKER_SCRIPT = "//tensorflow/tools/android/inference_interface:jni/version_script.lds"
|
||||||
|
|
||||||
|
# This fails to buiild if converted to tf_cc_binary.
|
||||||
|
cc_binary(
|
||||||
|
name = "libtensorflow_inference.so",
|
||||||
|
copts = tf_copts() + [
|
||||||
|
"-ffunction-sections",
|
||||||
|
"-fdata-sections",
|
||||||
|
],
|
||||||
|
linkopts = if_android([
|
||||||
|
"-landroid",
|
||||||
|
"-latomic",
|
||||||
|
"-ldl",
|
||||||
|
"-llog",
|
||||||
|
"-lm",
|
||||||
|
"-z defs",
|
||||||
|
"-s",
|
||||||
|
"-Wl,--gc-sections",
|
||||||
|
"-Wl,--version-script,$(location {})".format(LINKER_SCRIPT),
|
||||||
|
]),
|
||||||
|
linkshared = 1,
|
||||||
|
linkstatic = 1,
|
||||||
|
tags = [
|
||||||
|
"manual",
|
||||||
|
"notap",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":android_tensorflow_inference_jni",
|
||||||
|
"//tensorflow/core:android_tensorflow_lib",
|
||||||
|
LINKER_SCRIPT,
|
||||||
|
],
|
||||||
|
)
|
95
tensorflow/tools/android/inference_interface/README.md
Normal file
95
tensorflow/tools/android/inference_interface/README.md
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
# Android TensorFlow support
|
||||||
|
|
||||||
|
This directory defines components (a native `.so` library and a Java JAR)
|
||||||
|
geared towards supporting TensorFlow on Android. This includes:
|
||||||
|
|
||||||
|
- The [TensorFlow Java API](../../java/README.md)
|
||||||
|
- A `TensorFlowInferenceInterface` class that provides a smaller API
|
||||||
|
surface suitable for inference and summarizing performance of model execution.
|
||||||
|
|
||||||
|
For example usage, see [TensorFlowImageClassifier.java](../../examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java)
|
||||||
|
in the [TensorFlow Android Demo](../../examples/android).
|
||||||
|
|
||||||
|
For prebuilt libraries, see the
|
||||||
|
[nightly Android build artifacts](https://ci.tensorflow.org/view/Nightly/job/nightly-android/)
|
||||||
|
page for a recent build.
|
||||||
|
|
||||||
|
The TensorFlow Inference Interface is also available as a
|
||||||
|
[JCenter package](https://bintray.com/google/tensorflow/tensorflow)
|
||||||
|
(see the tensorflow-android directory) and can be included quite simply in your
|
||||||
|
android project with a couple of lines in the project's `build.gradle` file:
|
||||||
|
|
||||||
|
```
|
||||||
|
allprojects {
|
||||||
|
repositories {
|
||||||
|
jcenter()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dependencies {
|
||||||
|
compile 'org.tensorflow:tensorflow-android:+'
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
This will tell Gradle to use the
|
||||||
|
[latest version](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
|
||||||
|
of the TensorFlow AAR that has been released to
|
||||||
|
[JCenter](https://jcenter.bintray.com/org/tensorflow/tensorflow-android/).
|
||||||
|
You may replace the `+` with an explicit version label if you wish to
|
||||||
|
use a specific release of TensorFlow in your app.
|
||||||
|
|
||||||
|
To build the libraries yourself (if, for example, you want to support custom
|
||||||
|
TensorFlow operators), pick your preferred approach below:
|
||||||
|
|
||||||
|
### Bazel
|
||||||
|
|
||||||
|
First follow the Bazel setup instructions described in
|
||||||
|
[tensorflow/examples/android/README.md](../../examples/android/README.md)
|
||||||
|
|
||||||
|
Then, to build the native TF library:
|
||||||
|
|
||||||
|
```
|
||||||
|
bazel build -c opt //tensorflow/tools/android/inference_interface:libtensorflow_inference.so \
|
||||||
|
--crosstool_top=//external:android/crosstool \
|
||||||
|
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
|
||||||
|
--cxxopt=-std=c++11 \
|
||||||
|
--cpu=armeabi-v7a
|
||||||
|
```
|
||||||
|
|
||||||
|
Replacing `armeabi-v7a` with your desired target architecture.
|
||||||
|
|
||||||
|
The library will be located at:
|
||||||
|
|
||||||
|
```
|
||||||
|
bazel-bin/tensorflow/tools/android/inference_interface/libtensorflow_inference.so
|
||||||
|
```
|
||||||
|
|
||||||
|
To build the Java counterpart:
|
||||||
|
|
||||||
|
```
|
||||||
|
bazel build //tensorflow/tools/android/inference_interface:android_tensorflow_inference_java
|
||||||
|
```
|
||||||
|
|
||||||
|
You will find the JAR file at:
|
||||||
|
|
||||||
|
```
|
||||||
|
bazel-bin/tensorflow/tools/android/inference_interface/libandroid_tensorflow_inference_java.jar
|
||||||
|
```
|
||||||
|
|
||||||
|
### CMake
|
||||||
|
|
||||||
|
For documentation on building a self-contained AAR file with cmake, see
|
||||||
|
[tensorflow/tools/android/inference_interface/cmake](cmake).
|
||||||
|
|
||||||
|
|
||||||
|
### Makefile
|
||||||
|
|
||||||
|
For documentation on building native TF libraries with make, including a CUDA-enabled variant for devices like the Nvidia Shield TV, see [tensorflow/contrib/makefile/README.md](../makefile/README.md)
|
||||||
|
|
||||||
|
|
||||||
|
## AssetManagerFileSystem
|
||||||
|
|
||||||
|
This directory also contains a TensorFlow filesystem supporting the Android
|
||||||
|
asset manager. This may be useful when writing native (C++) code that is tightly
|
||||||
|
coupled with TensorFlow. For typical usage, the library above will be
|
||||||
|
sufficient.
|
@ -0,0 +1,272 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/tools/android/inference_interface/asset_manager_filesystem.h"
|
||||||
|
|
||||||
|
#include <unistd.h>
|
||||||
|
|
||||||
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
#include "tensorflow/core/platform/file_system_helper.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
string RemoveSuffix(const string& name, const string& suffix) {
|
||||||
|
string output(name);
|
||||||
|
StringPiece piece(output);
|
||||||
|
absl::ConsumeSuffix(&piece, suffix);
|
||||||
|
return string(piece);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Closes the given AAsset when variable is destructed.
|
||||||
|
class ScopedAsset {
|
||||||
|
public:
|
||||||
|
ScopedAsset(AAsset* asset) : asset_(asset) {}
|
||||||
|
~ScopedAsset() {
|
||||||
|
if (asset_ != nullptr) {
|
||||||
|
AAsset_close(asset_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
AAsset* get() const { return asset_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
AAsset* asset_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Closes the given AAssetDir when variable is destructed.
|
||||||
|
class ScopedAssetDir {
|
||||||
|
public:
|
||||||
|
ScopedAssetDir(AAssetDir* asset_dir) : asset_dir_(asset_dir) {}
|
||||||
|
~ScopedAssetDir() {
|
||||||
|
if (asset_dir_ != nullptr) {
|
||||||
|
AAssetDir_close(asset_dir_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
AAssetDir* get() const { return asset_dir_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
AAssetDir* asset_dir_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class ReadOnlyMemoryRegionFromAsset : public ReadOnlyMemoryRegion {
|
||||||
|
public:
|
||||||
|
ReadOnlyMemoryRegionFromAsset(std::unique_ptr<char[]> data, uint64 length)
|
||||||
|
: data_(std::move(data)), length_(length) {}
|
||||||
|
~ReadOnlyMemoryRegionFromAsset() override = default;
|
||||||
|
|
||||||
|
const void* data() override { return reinterpret_cast<void*>(data_.get()); }
|
||||||
|
uint64 length() override { return length_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<char[]> data_;
|
||||||
|
uint64 length_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Note that AAssets are not thread-safe and cannot be used across threads.
|
||||||
|
// However, AAssetManager is. Because RandomAccessFile must be thread-safe and
|
||||||
|
// used across threads, new AAssets must be created for every access.
|
||||||
|
// TODO(tylerrhodes): is there a more efficient way to do this?
|
||||||
|
class RandomAccessFileFromAsset : public RandomAccessFile {
|
||||||
|
public:
|
||||||
|
RandomAccessFileFromAsset(AAssetManager* asset_manager, const string& name)
|
||||||
|
: asset_manager_(asset_manager), file_name_(name) {}
|
||||||
|
~RandomAccessFileFromAsset() override = default;
|
||||||
|
|
||||||
|
Status Read(uint64 offset, size_t to_read, StringPiece* result,
|
||||||
|
char* scratch) const override {
|
||||||
|
auto asset = ScopedAsset(AAssetManager_open(
|
||||||
|
asset_manager_, file_name_.c_str(), AASSET_MODE_RANDOM));
|
||||||
|
if (asset.get() == nullptr) {
|
||||||
|
return errors::NotFound("File ", file_name_, " not found.");
|
||||||
|
}
|
||||||
|
|
||||||
|
off64_t new_offset = AAsset_seek64(asset.get(), offset, SEEK_SET);
|
||||||
|
off64_t length = AAsset_getLength64(asset.get());
|
||||||
|
if (new_offset < 0) {
|
||||||
|
*result = StringPiece(scratch, 0);
|
||||||
|
return errors::OutOfRange("Read after file end.");
|
||||||
|
}
|
||||||
|
const off64_t region_left =
|
||||||
|
std::min(length - new_offset, static_cast<off64_t>(to_read));
|
||||||
|
int read = AAsset_read(asset.get(), scratch, region_left);
|
||||||
|
if (read < 0) {
|
||||||
|
return errors::Internal("Error reading from asset.");
|
||||||
|
}
|
||||||
|
*result = StringPiece(scratch, region_left);
|
||||||
|
return (region_left == to_read)
|
||||||
|
? Status::OK()
|
||||||
|
: errors::OutOfRange("Read less bytes than requested.");
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
AAssetManager* asset_manager_;
|
||||||
|
string file_name_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
AssetManagerFileSystem::AssetManagerFileSystem(AAssetManager* asset_manager,
|
||||||
|
const string& prefix)
|
||||||
|
: asset_manager_(asset_manager), prefix_(prefix) {}
|
||||||
|
|
||||||
|
Status AssetManagerFileSystem::FileExists(const string& fname) {
|
||||||
|
string path = RemoveAssetPrefix(fname);
|
||||||
|
auto asset = ScopedAsset(
|
||||||
|
AAssetManager_open(asset_manager_, path.c_str(), AASSET_MODE_RANDOM));
|
||||||
|
if (asset.get() == nullptr) {
|
||||||
|
return errors::NotFound("File ", fname, " not found.");
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status AssetManagerFileSystem::NewRandomAccessFile(
|
||||||
|
const string& fname, std::unique_ptr<RandomAccessFile>* result) {
|
||||||
|
string path = RemoveAssetPrefix(fname);
|
||||||
|
auto asset = ScopedAsset(
|
||||||
|
AAssetManager_open(asset_manager_, path.c_str(), AASSET_MODE_RANDOM));
|
||||||
|
if (asset.get() == nullptr) {
|
||||||
|
return errors::NotFound("File ", fname, " not found.");
|
||||||
|
}
|
||||||
|
result->reset(new RandomAccessFileFromAsset(asset_manager_, path));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status AssetManagerFileSystem::NewReadOnlyMemoryRegionFromFile(
|
||||||
|
const string& fname, std::unique_ptr<ReadOnlyMemoryRegion>* result) {
|
||||||
|
string path = RemoveAssetPrefix(fname);
|
||||||
|
auto asset = ScopedAsset(
|
||||||
|
AAssetManager_open(asset_manager_, path.c_str(), AASSET_MODE_STREAMING));
|
||||||
|
if (asset.get() == nullptr) {
|
||||||
|
return errors::NotFound("File ", fname, " not found.");
|
||||||
|
}
|
||||||
|
|
||||||
|
off64_t start, length;
|
||||||
|
int fd = AAsset_openFileDescriptor64(asset.get(), &start, &length);
|
||||||
|
std::unique_ptr<char[]> data;
|
||||||
|
if (fd >= 0) {
|
||||||
|
data.reset(new char[length]);
|
||||||
|
ssize_t result = pread(fd, data.get(), length, start);
|
||||||
|
if (result < 0) {
|
||||||
|
return errors::Internal("Error reading from file ", fname,
|
||||||
|
" using 'read': ", result);
|
||||||
|
}
|
||||||
|
if (result != length) {
|
||||||
|
return errors::Internal("Expected size does not match size read: ",
|
||||||
|
"Expected ", length, " vs. read ", result);
|
||||||
|
}
|
||||||
|
close(fd);
|
||||||
|
} else {
|
||||||
|
length = AAsset_getLength64(asset.get());
|
||||||
|
data.reset(new char[length]);
|
||||||
|
const void* asset_buffer = AAsset_getBuffer(asset.get());
|
||||||
|
if (asset_buffer == nullptr) {
|
||||||
|
return errors::Internal("Error reading ", fname, " from asset manager.");
|
||||||
|
}
|
||||||
|
memcpy(data.get(), asset_buffer, length);
|
||||||
|
}
|
||||||
|
result->reset(new ReadOnlyMemoryRegionFromAsset(std::move(data), length));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status AssetManagerFileSystem::GetChildren(const string& prefixed_dir,
|
||||||
|
std::vector<string>* r) {
|
||||||
|
std::string path = NormalizeDirectoryPath(prefixed_dir);
|
||||||
|
auto dir =
|
||||||
|
ScopedAssetDir(AAssetManager_openDir(asset_manager_, path.c_str()));
|
||||||
|
if (dir.get() == nullptr) {
|
||||||
|
return errors::NotFound("Directory ", prefixed_dir, " not found.");
|
||||||
|
}
|
||||||
|
const char* next_file = AAssetDir_getNextFileName(dir.get());
|
||||||
|
while (next_file != nullptr) {
|
||||||
|
r->push_back(next_file);
|
||||||
|
next_file = AAssetDir_getNextFileName(dir.get());
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status AssetManagerFileSystem::GetFileSize(const string& fname, uint64* s) {
|
||||||
|
// If fname corresponds to a directory, return early. It doesn't map to an
|
||||||
|
// AAsset, and would otherwise return NotFound.
|
||||||
|
if (DirectoryExists(fname)) {
|
||||||
|
*s = 0;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
string path = RemoveAssetPrefix(fname);
|
||||||
|
auto asset = ScopedAsset(
|
||||||
|
AAssetManager_open(asset_manager_, path.c_str(), AASSET_MODE_RANDOM));
|
||||||
|
if (asset.get() == nullptr) {
|
||||||
|
return errors::NotFound("File ", fname, " not found.");
|
||||||
|
}
|
||||||
|
*s = AAsset_getLength64(asset.get());
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status AssetManagerFileSystem::Stat(const string& fname, FileStatistics* stat) {
|
||||||
|
uint64 size;
|
||||||
|
stat->is_directory = DirectoryExists(fname);
|
||||||
|
TF_RETURN_IF_ERROR(GetFileSize(fname, &size));
|
||||||
|
stat->length = size;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
string AssetManagerFileSystem::NormalizeDirectoryPath(const string& fname) {
|
||||||
|
return RemoveSuffix(RemoveAssetPrefix(fname), "/");
|
||||||
|
}
|
||||||
|
|
||||||
|
string AssetManagerFileSystem::RemoveAssetPrefix(const string& name) {
|
||||||
|
StringPiece piece(name);
|
||||||
|
absl::ConsumePrefix(&piece, prefix_);
|
||||||
|
return string(piece);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AssetManagerFileSystem::DirectoryExists(const std::string& fname) {
|
||||||
|
std::string path = NormalizeDirectoryPath(fname);
|
||||||
|
auto dir =
|
||||||
|
ScopedAssetDir(AAssetManager_openDir(asset_manager_, path.c_str()));
|
||||||
|
// Note that openDir will return something even if the directory doesn't
|
||||||
|
// exist. Therefore, we need to ensure one file exists in the folder.
|
||||||
|
return AAssetDir_getNextFileName(dir.get()) != NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status AssetManagerFileSystem::GetMatchingPaths(const string& pattern,
|
||||||
|
std::vector<string>* results) {
|
||||||
|
return internal::GetMatchingPaths(this, Env::Default(), pattern, results);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status AssetManagerFileSystem::NewWritableFile(
|
||||||
|
const string& fname, std::unique_ptr<WritableFile>* result) {
|
||||||
|
return errors::Unimplemented("Asset storage is read only.");
|
||||||
|
}
|
||||||
|
Status AssetManagerFileSystem::NewAppendableFile(
|
||||||
|
const string& fname, std::unique_ptr<WritableFile>* result) {
|
||||||
|
return errors::Unimplemented("Asset storage is read only.");
|
||||||
|
}
|
||||||
|
Status AssetManagerFileSystem::DeleteFile(const string& f) {
|
||||||
|
return errors::Unimplemented("Asset storage is read only.");
|
||||||
|
}
|
||||||
|
Status AssetManagerFileSystem::CreateDir(const string& d) {
|
||||||
|
return errors::Unimplemented("Asset storage is read only.");
|
||||||
|
}
|
||||||
|
Status AssetManagerFileSystem::DeleteDir(const string& d) {
|
||||||
|
return errors::Unimplemented("Asset storage is read only.");
|
||||||
|
}
|
||||||
|
Status AssetManagerFileSystem::RenameFile(const string& s, const string& t) {
|
||||||
|
return errors::Unimplemented("Asset storage is read only.");
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
@ -0,0 +1,85 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_
|
||||||
|
#define TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_
|
||||||
|
|
||||||
|
#include <android/asset_manager.h>
|
||||||
|
#include <android/asset_manager_jni.h>
|
||||||
|
|
||||||
|
#include "tensorflow/core/platform/file_system.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// FileSystem that uses Android's AAssetManager. Once initialized with a given
|
||||||
|
// AAssetManager, files in the given AAssetManager can be accessed through the
|
||||||
|
// prefix given when registered with the TensorFlow Env.
|
||||||
|
// Note that because APK assets are immutable, any operation that tries to
|
||||||
|
// modify the FileSystem will return tensorflow::error::code::UNIMPLEMENTED.
|
||||||
|
class AssetManagerFileSystem : public FileSystem {
|
||||||
|
public:
|
||||||
|
// Initialize an AssetManagerFileSystem. Note that this does not register the
|
||||||
|
// file system with TensorFlow.
|
||||||
|
// asset_manager - Non-null Android AAssetManager that backs this file
|
||||||
|
// system. The asset manager is not owned by this file system, and must
|
||||||
|
// outlive this class.
|
||||||
|
// prefix - Common prefix to strip from all file URIs before passing them to
|
||||||
|
// the asset_manager. This is required because TensorFlow gives the entire
|
||||||
|
// file URI (file:///my_dir/my_file.txt) and AssetManager only knows paths
|
||||||
|
// relative to its base directory.
|
||||||
|
AssetManagerFileSystem(AAssetManager* asset_manager, const string& prefix);
|
||||||
|
~AssetManagerFileSystem() override = default;
|
||||||
|
|
||||||
|
Status FileExists(const string& fname) override;
|
||||||
|
Status NewRandomAccessFile(
|
||||||
|
const string& filename,
|
||||||
|
std::unique_ptr<RandomAccessFile>* result) override;
|
||||||
|
Status NewReadOnlyMemoryRegionFromFile(
|
||||||
|
const string& filename,
|
||||||
|
std::unique_ptr<ReadOnlyMemoryRegion>* result) override;
|
||||||
|
|
||||||
|
Status GetFileSize(const string& f, uint64* s) override;
|
||||||
|
// Currently just returns size.
|
||||||
|
Status Stat(const string& fname, FileStatistics* stat) override;
|
||||||
|
Status GetChildren(const string& dir, std::vector<string>* r) override;
|
||||||
|
|
||||||
|
// All these functions return Unimplemented error. Asset storage is
|
||||||
|
// read only.
|
||||||
|
Status NewWritableFile(const string& fname,
|
||||||
|
std::unique_ptr<WritableFile>* result) override;
|
||||||
|
Status NewAppendableFile(const string& fname,
|
||||||
|
std::unique_ptr<WritableFile>* result) override;
|
||||||
|
Status DeleteFile(const string& f) override;
|
||||||
|
Status CreateDir(const string& d) override;
|
||||||
|
Status DeleteDir(const string& d) override;
|
||||||
|
Status RenameFile(const string& s, const string& t) override;
|
||||||
|
|
||||||
|
Status GetMatchingPaths(const string& pattern,
|
||||||
|
std::vector<string>* results) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
string RemoveAssetPrefix(const string& name);
|
||||||
|
|
||||||
|
// Return a string path that can be passed into AAssetManager functions.
|
||||||
|
// For example, 'my_prefix://some/dir/' would return 'some/dir'.
|
||||||
|
string NormalizeDirectoryPath(const string& fname);
|
||||||
|
bool DirectoryExists(const std::string& fname);
|
||||||
|
|
||||||
|
AAssetManager* asset_manager_;
|
||||||
|
string prefix_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
#endif // TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_
|
@ -0,0 +1,80 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
cmake_minimum_required(VERSION 3.4.1)
|
||||||
|
include(ExternalProject)
|
||||||
|
|
||||||
|
# TENSORFLOW_ROOT_DIR:
|
||||||
|
# root directory of tensorflow repo
|
||||||
|
# used for shared source files and pre-built libs
|
||||||
|
get_filename_component(TENSORFLOW_ROOT_DIR ../../../.. ABSOLUTE)
|
||||||
|
set(PREBUILT_DIR ${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/gen)
|
||||||
|
|
||||||
|
add_library(lib_proto STATIC IMPORTED )
|
||||||
|
set_target_properties(lib_proto PROPERTIES IMPORTED_LOCATION
|
||||||
|
${PREBUILT_DIR}/protobuf/lib/libprotobuf.a)
|
||||||
|
|
||||||
|
add_library(lib_nsync STATIC IMPORTED )
|
||||||
|
set_target_properties(lib_nsync PROPERTIES IMPORTED_LOCATION
|
||||||
|
${TARGET_NSYNC_LIB}/lib/libnsync.a)
|
||||||
|
|
||||||
|
add_library(lib_tf STATIC IMPORTED )
|
||||||
|
set_target_properties(lib_tf PROPERTIES IMPORTED_LOCATION
|
||||||
|
${PREBUILT_DIR}/lib/libtensorflow-core.a)
|
||||||
|
# Change to compile flags should be replicated into bazel build file
|
||||||
|
# TODO: Consider options other than -O2 for binary size.
|
||||||
|
# e.g. -Os for gcc, and -Oz for clang.
|
||||||
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DIS_SLIM_BUILD \
|
||||||
|
-std=c++11 -fno-rtti -fno-exceptions \
|
||||||
|
-O2 -Wno-narrowing -fomit-frame-pointer \
|
||||||
|
-mfpu=neon -mfloat-abi=softfp -fPIE -fPIC \
|
||||||
|
-ftemplate-depth=900 \
|
||||||
|
-DGOOGLE_PROTOBUF_NO_RTTI \
|
||||||
|
-DGOOGLE_PROTOBUF_NO_STATIC_INITIALIZER")
|
||||||
|
|
||||||
|
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} \
|
||||||
|
-Wl,--allow-multiple-definition \
|
||||||
|
-Wl,--whole-archive \
|
||||||
|
-fPIE -pie -v")
|
||||||
|
file(GLOB tensorflow_inference_sources
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/../jni/*.cc)
|
||||||
|
file(GLOB java_api_native_sources
|
||||||
|
${TENSORFLOW_ROOT_DIR}/tensorflow/java/src/main/native/*.cc)
|
||||||
|
|
||||||
|
add_library(tensorflow_inference SHARED
|
||||||
|
${tensorflow_inference_sources}
|
||||||
|
${TENSORFLOW_ROOT_DIR}/tensorflow/c/tf_status_helper.cc
|
||||||
|
${TENSORFLOW_ROOT_DIR}/tensorflow/c/checkpoint_reader.cc
|
||||||
|
${TENSORFLOW_ROOT_DIR}/tensorflow/c/test_op.cc
|
||||||
|
${TENSORFLOW_ROOT_DIR}/tensorflow/c/c_api.cc
|
||||||
|
${java_api_native_sources})
|
||||||
|
|
||||||
|
# Include libraries needed for hello-jni lib
|
||||||
|
target_link_libraries(tensorflow_inference
|
||||||
|
android
|
||||||
|
dl
|
||||||
|
log
|
||||||
|
m
|
||||||
|
z
|
||||||
|
lib_tf
|
||||||
|
lib_proto
|
||||||
|
lib_nsync)
|
||||||
|
|
||||||
|
include_directories(
|
||||||
|
${PREBUILT_DIR}/proto
|
||||||
|
${PREBUILT_DIR}/protobuf/include
|
||||||
|
${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/downloads/eigen
|
||||||
|
${TENSORFLOW_ROOT_DIR}
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/..)
|
48
tensorflow/tools/android/inference_interface/cmake/README.md
Normal file
48
tensorflow/tools/android/inference_interface/cmake/README.md
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
TensorFlow-Android-Inference
|
||||||
|
============================
|
||||||
|
This directory contains CMake support for building the Android Java Inference
|
||||||
|
interface to the TensorFlow native APIs.
|
||||||
|
|
||||||
|
See [tensorflow/tools/android/inference_interface](..) for more details about
|
||||||
|
the library, and instructions for building with Bazel.
|
||||||
|
|
||||||
|
Usage
|
||||||
|
-----
|
||||||
|
Add TensorFlow-Android-Inference as a dependency of your Android application
|
||||||
|
|
||||||
|
* settings.gradle
|
||||||
|
|
||||||
|
```
|
||||||
|
include ':TensorFlow-Android-Inference'
|
||||||
|
findProject(":TensorFlow-Android-Inference").projectDir =
|
||||||
|
new File("${/path/to/tensorflow_repo}/examples/android_inference_interface/cmake")
|
||||||
|
```
|
||||||
|
|
||||||
|
* application's build.gradle (adding dependency):
|
||||||
|
|
||||||
|
```
|
||||||
|
debugCompile project(path: ':tensorflow_inference', configuration: 'debug')
|
||||||
|
releaseCompile project(path: ':tensorflow_inference', configuration: 'release')
|
||||||
|
```
|
||||||
|
Note: this makes native code in the lib traceable from your app.
|
||||||
|
|
||||||
|
Dependencies
|
||||||
|
------------
|
||||||
|
TensorFlow-Android-Inference depends on the TensorFlow static libs already built
|
||||||
|
in your local TensorFlow repo directory. For Linux/Mac OS, build_all_android.sh
|
||||||
|
is used in build.gradle to build it. It DOES take time to build the core libs;
|
||||||
|
so, by default, it is commented out to avoid confusion (otherwise
|
||||||
|
Android Studio would appear to hang during opening the project).
|
||||||
|
To enable it, refer to the comment in
|
||||||
|
|
||||||
|
* build.gradle
|
||||||
|
|
||||||
|
Output
|
||||||
|
------
|
||||||
|
- TensorFlow-Inference-debug.aar
|
||||||
|
- TensorFlow-Inference-release.aar
|
||||||
|
|
||||||
|
File libtensorflow_inference.so should be packed under jni/${ANDROID_ABI}/
|
||||||
|
in the above aar, and it is transparent to the app as it will access them via
|
||||||
|
equivalent java APIs.
|
||||||
|
|
105
tensorflow/tools/android/inference_interface/cmake/build.gradle
Normal file
105
tensorflow/tools/android/inference_interface/cmake/build.gradle
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
apply plugin: 'com.android.library'
|
||||||
|
|
||||||
|
// TensorFlow repo root dir on local machine
|
||||||
|
def TF_SRC_DIR = projectDir.toString() + "/../../../.."
|
||||||
|
|
||||||
|
android {
|
||||||
|
compileSdkVersion 24
|
||||||
|
// Check local build_tools_version as this is liable to change within Android Studio.
|
||||||
|
buildToolsVersion '25.0.2'
|
||||||
|
|
||||||
|
// for debugging native code purpose
|
||||||
|
publishNonDefault true
|
||||||
|
|
||||||
|
defaultConfig {
|
||||||
|
archivesBaseName = "Tensorflow-Android-Inference"
|
||||||
|
minSdkVersion 21
|
||||||
|
targetSdkVersion 23
|
||||||
|
versionCode 1
|
||||||
|
versionName "1.0"
|
||||||
|
ndk {
|
||||||
|
abiFilters 'armeabi-v7a'
|
||||||
|
}
|
||||||
|
externalNativeBuild {
|
||||||
|
cmake {
|
||||||
|
arguments '-DANDROID_TOOLCHAIN=clang',
|
||||||
|
'-DANDROID_STL=c++_static'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sourceSets {
|
||||||
|
main {
|
||||||
|
java {
|
||||||
|
srcDir "${TF_SRC_DIR}/tensorflow/tools/android/inference_interface/java"
|
||||||
|
srcDir "${TF_SRC_DIR}/tensorflow/java/src/main/java"
|
||||||
|
exclude '**/examples/**'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
externalNativeBuild {
|
||||||
|
cmake {
|
||||||
|
path 'CMakeLists.txt'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
buildTypes {
|
||||||
|
release {
|
||||||
|
minifyEnabled false
|
||||||
|
proguardFiles getDefaultProguardFile('proguard-android.txt'),
|
||||||
|
'proguard-rules.pro'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build libtensorflow-core.a if necessary
|
||||||
|
// Note: the environment needs to be set up already
|
||||||
|
// [ such as installing autoconfig, make, etc ]
|
||||||
|
// How to use:
|
||||||
|
// 1) install all of the necessary tools to build libtensorflow-core.a
|
||||||
|
// 2) inside Android Studio IDE, uncomment buildTensorFlow in
|
||||||
|
// whenTaskAdded{...}
|
||||||
|
// 3) re-sync and re-build. It could take a long time if NOT building
|
||||||
|
// with multiple processes.
|
||||||
|
import org.apache.tools.ant.taskdefs.condition.Os
|
||||||
|
|
||||||
|
Properties properties = new Properties()
|
||||||
|
properties.load(project.rootProject.file('local.properties')
|
||||||
|
.newDataInputStream())
|
||||||
|
def ndkDir = properties.getProperty('ndk.dir')
|
||||||
|
if (ndkDir == null || ndkDir == "") {
|
||||||
|
ndkDir = System.getenv('ANDROID_NDK_HOME')
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!Os.isFamily(Os.FAMILY_WINDOWS)) {
|
||||||
|
// This script is for non-Windows OS. For Windows OS, MANUALLY build
|
||||||
|
// (or copy the built) libs/headers to the
|
||||||
|
// ${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/gen
|
||||||
|
// refer to CMakeLists.txt about lib and header directories for details
|
||||||
|
task buildTensorflow(type: Exec) {
|
||||||
|
group 'buildTensorflowLib'
|
||||||
|
workingDir getProjectDir().toString() + '/../../../../'
|
||||||
|
environment PATH: '/opt/local/bin:/opt/local/sbin:' +
|
||||||
|
System.getenv('PATH')
|
||||||
|
environment NDK_ROOT: ndkDir
|
||||||
|
commandLine 'tensorflow/contrib/makefile/build_all_android.sh'
|
||||||
|
}
|
||||||
|
|
||||||
|
tasks.whenTaskAdded { task ->
|
||||||
|
group 'buildTensorflowLib'
|
||||||
|
if (task.name.toLowerCase().contains('sources')) {
|
||||||
|
def tensorflowTarget = new File(getProjectDir().toString() +
|
||||||
|
'/../../makefile/gen/lib/libtensorflow-core.a')
|
||||||
|
if (!tensorflowTarget.exists()) {
|
||||||
|
// Note:
|
||||||
|
// just uncomment this line to use it:
|
||||||
|
// it can take long time to build by default
|
||||||
|
// it is disabled to avoid false first impression
|
||||||
|
task.dependsOn buildTensorflow
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dependencies {
|
||||||
|
compile fileTree(dir: 'libs', include: ['*.jar'])
|
||||||
|
}
|
@ -0,0 +1,13 @@
|
|||||||
|
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||||
|
package="org.tensorflow.contrib.android">
|
||||||
|
|
||||||
|
<uses-sdk
|
||||||
|
android:minSdkVersion="4"
|
||||||
|
android:targetSdkVersion="19" />
|
||||||
|
|
||||||
|
<application android:allowBackup="true" android:label="@string/app_name"
|
||||||
|
android:supportsRtl="true">
|
||||||
|
|
||||||
|
</application>
|
||||||
|
|
||||||
|
</manifest>
|
@ -0,0 +1,3 @@
|
|||||||
|
<resources>
|
||||||
|
<string name="app_name">TensorFlowInference</string>
|
||||||
|
</resources>
|
@ -0,0 +1,63 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package org.tensorflow.contrib.android;
|
||||||
|
|
||||||
|
/** Accumulate and analyze stats from metadata obtained from Session.Runner.run. */
|
||||||
|
public class RunStats implements AutoCloseable {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Options to be provided to a {@link org.tensorflow.Session.Runner} to enable stats accumulation.
|
||||||
|
*/
|
||||||
|
public static byte[] runOptions() {
|
||||||
|
return fullTraceRunOptions;
|
||||||
|
}
|
||||||
|
|
||||||
|
public RunStats() {
|
||||||
|
nativeHandle = allocate();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close() {
|
||||||
|
if (nativeHandle != 0) {
|
||||||
|
delete(nativeHandle);
|
||||||
|
}
|
||||||
|
nativeHandle = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Accumulate stats obtained when executing a graph. */
|
||||||
|
public synchronized void add(byte[] runMetadata) {
|
||||||
|
add(nativeHandle, runMetadata);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Summary of the accumulated runtime stats. */
|
||||||
|
public synchronized String summary() {
|
||||||
|
return summary(nativeHandle);
|
||||||
|
}
|
||||||
|
|
||||||
|
private long nativeHandle;
|
||||||
|
|
||||||
|
// Hack: This is what a serialized RunOptions protocol buffer with trace_level: FULL_TRACE ends
|
||||||
|
// up as.
|
||||||
|
private static byte[] fullTraceRunOptions = new byte[] {0x08, 0x03};
|
||||||
|
|
||||||
|
private static native long allocate();
|
||||||
|
|
||||||
|
private static native void delete(long handle);
|
||||||
|
|
||||||
|
private static native void add(long handle, byte[] runMetadata);
|
||||||
|
|
||||||
|
private static native String summary(long handle);
|
||||||
|
}
|
@ -0,0 +1,650 @@
|
|||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
package org.tensorflow.contrib.android;
|
||||||
|
|
||||||
|
import android.content.res.AssetManager;
|
||||||
|
import android.os.Build.VERSION;
|
||||||
|
import android.os.Trace;
|
||||||
|
import android.text.TextUtils;
|
||||||
|
import android.util.Log;
|
||||||
|
import java.io.ByteArrayOutputStream;
|
||||||
|
import java.io.FileInputStream;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.nio.DoubleBuffer;
|
||||||
|
import java.nio.FloatBuffer;
|
||||||
|
import java.nio.IntBuffer;
|
||||||
|
import java.nio.LongBuffer;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import org.tensorflow.Graph;
|
||||||
|
import org.tensorflow.Operation;
|
||||||
|
import org.tensorflow.Session;
|
||||||
|
import org.tensorflow.Tensor;
|
||||||
|
import org.tensorflow.TensorFlow;
|
||||||
|
import org.tensorflow.Tensors;
|
||||||
|
import org.tensorflow.types.UInt8;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wrapper over the TensorFlow API ({@link Graph}, {@link Session}) providing a smaller API surface
|
||||||
|
* for inference.
|
||||||
|
*
|
||||||
|
* <p>See tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java for an
|
||||||
|
* example usage.
|
||||||
|
*/
|
||||||
|
public class TensorFlowInferenceInterface {
|
||||||
|
private static final String TAG = "TensorFlowInferenceInterface";
|
||||||
|
private static final String ASSET_FILE_PREFIX = "file:///android_asset/";
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Load a TensorFlow model from the AssetManager or from disk if it is not an asset file.
|
||||||
|
*
|
||||||
|
* @param assetManager The AssetManager to use to load the model file.
|
||||||
|
* @param model The filepath to the GraphDef proto representing the model.
|
||||||
|
*/
|
||||||
|
public TensorFlowInferenceInterface(AssetManager assetManager, String model) {
|
||||||
|
prepareNativeRuntime();
|
||||||
|
|
||||||
|
this.modelName = model;
|
||||||
|
this.g = new Graph();
|
||||||
|
this.sess = new Session(g);
|
||||||
|
this.runner = sess.runner();
|
||||||
|
|
||||||
|
final boolean hasAssetPrefix = model.startsWith(ASSET_FILE_PREFIX);
|
||||||
|
InputStream is = null;
|
||||||
|
try {
|
||||||
|
String aname = hasAssetPrefix ? model.split(ASSET_FILE_PREFIX)[1] : model;
|
||||||
|
is = assetManager.open(aname);
|
||||||
|
} catch (IOException e) {
|
||||||
|
if (hasAssetPrefix) {
|
||||||
|
throw new RuntimeException("Failed to load model from '" + model + "'", e);
|
||||||
|
}
|
||||||
|
// Perhaps the model file is not an asset but is on disk.
|
||||||
|
try {
|
||||||
|
is = new FileInputStream(model);
|
||||||
|
} catch (IOException e2) {
|
||||||
|
throw new RuntimeException("Failed to load model from '" + model + "'", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
if (VERSION.SDK_INT >= 18) {
|
||||||
|
Trace.beginSection("initializeTensorFlow");
|
||||||
|
Trace.beginSection("readGraphDef");
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(ashankar): Can we somehow mmap the contents instead of copying them?
|
||||||
|
byte[] graphDef = new byte[is.available()];
|
||||||
|
final int numBytesRead = is.read(graphDef);
|
||||||
|
if (numBytesRead != graphDef.length) {
|
||||||
|
throw new IOException(
|
||||||
|
"read error: read only "
|
||||||
|
+ numBytesRead
|
||||||
|
+ " of the graph, expected to read "
|
||||||
|
+ graphDef.length);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (VERSION.SDK_INT >= 18) {
|
||||||
|
Trace.endSection(); // readGraphDef.
|
||||||
|
}
|
||||||
|
|
||||||
|
loadGraph(graphDef, g);
|
||||||
|
is.close();
|
||||||
|
Log.i(TAG, "Successfully loaded model from '" + model + "'");
|
||||||
|
|
||||||
|
if (VERSION.SDK_INT >= 18) {
|
||||||
|
Trace.endSection(); // initializeTensorFlow.
|
||||||
|
}
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new RuntimeException("Failed to load model from '" + model + "'", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Load a TensorFlow model from provided InputStream.
|
||||||
|
* Note: The InputStream will not be closed after loading model, users need to
|
||||||
|
* close it themselves.
|
||||||
|
*
|
||||||
|
* @param is The InputStream to use to load the model.
|
||||||
|
*/
|
||||||
|
public TensorFlowInferenceInterface(InputStream is) {
|
||||||
|
prepareNativeRuntime();
|
||||||
|
|
||||||
|
// modelName is redundant for model loading from input stream, here is for
|
||||||
|
// avoiding error in initialization as modelName is marked final.
|
||||||
|
this.modelName = "";
|
||||||
|
this.g = new Graph();
|
||||||
|
this.sess = new Session(g);
|
||||||
|
this.runner = sess.runner();
|
||||||
|
|
||||||
|
try {
|
||||||
|
if (VERSION.SDK_INT >= 18) {
|
||||||
|
Trace.beginSection("initializeTensorFlow");
|
||||||
|
Trace.beginSection("readGraphDef");
|
||||||
|
}
|
||||||
|
|
||||||
|
int baosInitSize = is.available() > 16384 ? is.available() : 16384;
|
||||||
|
ByteArrayOutputStream baos = new ByteArrayOutputStream(baosInitSize);
|
||||||
|
int numBytesRead;
|
||||||
|
byte[] buf = new byte[16384];
|
||||||
|
while ((numBytesRead = is.read(buf, 0, buf.length)) != -1) {
|
||||||
|
baos.write(buf, 0, numBytesRead);
|
||||||
|
}
|
||||||
|
byte[] graphDef = baos.toByteArray();
|
||||||
|
|
||||||
|
if (VERSION.SDK_INT >= 18) {
|
||||||
|
Trace.endSection(); // readGraphDef.
|
||||||
|
}
|
||||||
|
|
||||||
|
loadGraph(graphDef, g);
|
||||||
|
Log.i(TAG, "Successfully loaded model from the input stream");
|
||||||
|
|
||||||
|
if (VERSION.SDK_INT >= 18) {
|
||||||
|
Trace.endSection(); // initializeTensorFlow.
|
||||||
|
}
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new RuntimeException("Failed to load model from the input stream", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Construct a TensorFlowInferenceInterface with provided Graph
|
||||||
|
*
|
||||||
|
* @param g The Graph to use to construct this interface.
|
||||||
|
*/
|
||||||
|
public TensorFlowInferenceInterface(Graph g) {
|
||||||
|
prepareNativeRuntime();
|
||||||
|
|
||||||
|
// modelName is redundant here, here is for
|
||||||
|
// avoiding error in initialization as modelName is marked final.
|
||||||
|
this.modelName = "";
|
||||||
|
this.g = g;
|
||||||
|
this.sess = new Session(g);
|
||||||
|
this.runner = sess.runner();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Runs inference between the previously registered input nodes (via feed*) and the requested
|
||||||
|
* output nodes. Output nodes can then be queried with the fetch* methods.
|
||||||
|
*
|
||||||
|
* @param outputNames A list of output nodes which should be filled by the inference pass.
|
||||||
|
*/
|
||||||
|
public void run(String[] outputNames) {
|
||||||
|
run(outputNames, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Runs inference between the previously registered input nodes (via feed*) and the requested
|
||||||
|
* output nodes. Output nodes can then be queried with the fetch* methods.
|
||||||
|
*
|
||||||
|
* @param outputNames A list of output nodes which should be filled by the inference pass.
|
||||||
|
*/
|
||||||
|
public void run(String[] outputNames, boolean enableStats) {
|
||||||
|
run(outputNames, enableStats, new String[] {});
|
||||||
|
}
|
||||||
|
|
||||||
|
/** An overloaded version of runInference that allows supplying targetNodeNames as well */
|
||||||
|
public void run(String[] outputNames, boolean enableStats, String[] targetNodeNames) {
|
||||||
|
// Release any Tensors from the previous run calls.
|
||||||
|
closeFetches();
|
||||||
|
|
||||||
|
// Add fetches.
|
||||||
|
for (String o : outputNames) {
|
||||||
|
fetchNames.add(o);
|
||||||
|
TensorId tid = TensorId.parse(o);
|
||||||
|
runner.fetch(tid.name, tid.outputIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add targets.
|
||||||
|
for (String t : targetNodeNames) {
|
||||||
|
runner.addTarget(t);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run the session.
|
||||||
|
try {
|
||||||
|
if (enableStats) {
|
||||||
|
Session.Run r = runner.setOptions(RunStats.runOptions()).runAndFetchMetadata();
|
||||||
|
fetchTensors = r.outputs;
|
||||||
|
|
||||||
|
if (runStats == null) {
|
||||||
|
runStats = new RunStats();
|
||||||
|
}
|
||||||
|
runStats.add(r.metadata);
|
||||||
|
} else {
|
||||||
|
fetchTensors = runner.run();
|
||||||
|
}
|
||||||
|
} catch (RuntimeException e) {
|
||||||
|
// Ideally the exception would have been let through, but since this interface predates the
|
||||||
|
// TensorFlow Java API, must return -1.
|
||||||
|
Log.e(
|
||||||
|
TAG,
|
||||||
|
"Failed to run TensorFlow inference with inputs:["
|
||||||
|
+ TextUtils.join(", ", feedNames)
|
||||||
|
+ "], outputs:["
|
||||||
|
+ TextUtils.join(", ", fetchNames)
|
||||||
|
+ "]");
|
||||||
|
throw e;
|
||||||
|
} finally {
|
||||||
|
// Always release the feeds (to save resources) and reset the runner, this run is
|
||||||
|
// over.
|
||||||
|
closeFeeds();
|
||||||
|
runner = sess.runner();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Returns a reference to the Graph describing the computation run during inference. */
|
||||||
|
public Graph graph() {
|
||||||
|
return g;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Operation graphOperation(String operationName) {
|
||||||
|
final Operation operation = g.operation(operationName);
|
||||||
|
if (operation == null) {
|
||||||
|
throw new RuntimeException(
|
||||||
|
"Node '" + operationName + "' does not exist in model '" + modelName + "'");
|
||||||
|
}
|
||||||
|
return operation;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Returns the last stat summary string if logging is enabled. */
|
||||||
|
public String getStatString() {
|
||||||
|
return (runStats == null) ? "" : runStats.summary();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Cleans up the state associated with this Object.
|
||||||
|
*
|
||||||
|
* <p>The TenosrFlowInferenceInterface object is no longer usable after this method returns.
|
||||||
|
*/
|
||||||
|
public void close() {
|
||||||
|
closeFeeds();
|
||||||
|
closeFetches();
|
||||||
|
sess.close();
|
||||||
|
g.close();
|
||||||
|
if (runStats != null) {
|
||||||
|
runStats.close();
|
||||||
|
}
|
||||||
|
runStats = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void finalize() throws Throwable {
|
||||||
|
try {
|
||||||
|
close();
|
||||||
|
} finally {
|
||||||
|
super.finalize();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Methods for taking a native Tensor and filling it with values from Java arrays.
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Given a source array with shape {@link dims} and content {@link src}, copy the contents into
|
||||||
|
* the input Tensor with name {@link inputName}. The source array {@link src} must have at least
|
||||||
|
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||||
|
* destination has capacity, the copy is truncated.
|
||||||
|
*/
|
||||||
|
public void feed(String inputName, boolean[] src, long... dims) {
|
||||||
|
byte[] b = new byte[src.length];
|
||||||
|
|
||||||
|
for (int i = 0; i < src.length; i++) {
|
||||||
|
b[i] = src[i] ? (byte) 1 : (byte) 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
addFeed(inputName, Tensor.create(Boolean.class, dims, ByteBuffer.wrap(b)));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Given a source array with shape {@link dims} and content {@link src}, copy the contents into
|
||||||
|
* the input Tensor with name {@link inputName}. The source array {@link src} must have at least
|
||||||
|
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||||
|
* destination has capacity, the copy is truncated.
|
||||||
|
*/
|
||||||
|
public void feed(String inputName, float[] src, long... dims) {
|
||||||
|
addFeed(inputName, Tensor.create(dims, FloatBuffer.wrap(src)));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Given a source array with shape {@link dims} and content {@link src}, copy the contents into
|
||||||
|
* the input Tensor with name {@link inputName}. The source array {@link src} must have at least
|
||||||
|
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||||
|
* destination has capacity, the copy is truncated.
|
||||||
|
*/
|
||||||
|
public void feed(String inputName, int[] src, long... dims) {
|
||||||
|
addFeed(inputName, Tensor.create(dims, IntBuffer.wrap(src)));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Given a source array with shape {@link dims} and content {@link src}, copy the contents into
|
||||||
|
* the input Tensor with name {@link inputName}. The source array {@link src} must have at least
|
||||||
|
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||||
|
* destination has capacity, the copy is truncated.
|
||||||
|
*/
|
||||||
|
public void feed(String inputName, long[] src, long... dims) {
|
||||||
|
addFeed(inputName, Tensor.create(dims, LongBuffer.wrap(src)));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Given a source array with shape {@link dims} and content {@link src}, copy the contents into
|
||||||
|
* the input Tensor with name {@link inputName}. The source array {@link src} must have at least
|
||||||
|
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||||
|
* destination has capacity, the copy is truncated.
|
||||||
|
*/
|
||||||
|
public void feed(String inputName, double[] src, long... dims) {
|
||||||
|
addFeed(inputName, Tensor.create(dims, DoubleBuffer.wrap(src)));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Given a source array with shape {@link dims} and content {@link src}, copy the contents into
|
||||||
|
* the input Tensor with name {@link inputName}. The source array {@link src} must have at least
|
||||||
|
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||||
|
* destination has capacity, the copy is truncated.
|
||||||
|
*/
|
||||||
|
public void feed(String inputName, byte[] src, long... dims) {
|
||||||
|
addFeed(inputName, Tensor.create(UInt8.class, dims, ByteBuffer.wrap(src)));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Copy a byte sequence into the input Tensor with name {@link inputName} as a string-valued
|
||||||
|
* scalar tensor. In the TensorFlow type system, a "string" is an arbitrary sequence of bytes, not
|
||||||
|
* a Java {@code String} (which is a sequence of characters).
|
||||||
|
*/
|
||||||
|
public void feedString(String inputName, byte[] src) {
|
||||||
|
addFeed(inputName, Tensors.create(src));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Copy an array of byte sequences into the input Tensor with name {@link inputName} as a
|
||||||
|
* string-valued one-dimensional tensor (vector). In the TensorFlow type system, a "string" is an
|
||||||
|
* arbitrary sequence of bytes, not a Java {@code String} (which is a sequence of characters).
|
||||||
|
*/
|
||||||
|
public void feedString(String inputName, byte[][] src) {
|
||||||
|
addFeed(inputName, Tensors.create(src));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Methods for taking a native Tensor and filling it with src from Java native IO buffers.
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Given a source buffer with shape {@link dims} and content {@link src}, both stored as
|
||||||
|
* <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
|
||||||
|
* Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
|
||||||
|
* elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||||
|
* destination has capacity, the copy is truncated.
|
||||||
|
*/
|
||||||
|
public void feed(String inputName, FloatBuffer src, long... dims) {
|
||||||
|
addFeed(inputName, Tensor.create(dims, src));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Given a source buffer with shape {@link dims} and content {@link src}, both stored as
|
||||||
|
* <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
|
||||||
|
* Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
|
||||||
|
* elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||||
|
* destination has capacity, the copy is truncated.
|
||||||
|
*/
|
||||||
|
public void feed(String inputName, IntBuffer src, long... dims) {
|
||||||
|
addFeed(inputName, Tensor.create(dims, src));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Given a source buffer with shape {@link dims} and content {@link src}, both stored as
|
||||||
|
* <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
|
||||||
|
* Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
|
||||||
|
* elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||||
|
* destination has capacity, the copy is truncated.
|
||||||
|
*/
|
||||||
|
public void feed(String inputName, LongBuffer src, long... dims) {
|
||||||
|
addFeed(inputName, Tensor.create(dims, src));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Given a source buffer with shape {@link dims} and content {@link src}, both stored as
|
||||||
|
* <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
|
||||||
|
* Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
|
||||||
|
* elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||||
|
* destination has capacity, the copy is truncated.
|
||||||
|
*/
|
||||||
|
public void feed(String inputName, DoubleBuffer src, long... dims) {
|
||||||
|
addFeed(inputName, Tensor.create(dims, src));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Given a source buffer with shape {@link dims} and content {@link src}, both stored as
|
||||||
|
* <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
|
||||||
|
* Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
|
||||||
|
* elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||||
|
* destination has capacity, the copy is truncated.
|
||||||
|
*/
|
||||||
|
public void feed(String inputName, ByteBuffer src, long... dims) {
|
||||||
|
addFeed(inputName, Tensor.create(UInt8.class, dims, src));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
|
||||||
|
* dst} must have length greater than or equal to that of the source Tensor. This operation will
|
||||||
|
* not affect dst's content past the source Tensor's size.
|
||||||
|
*/
|
||||||
|
public void fetch(String outputName, float[] dst) {
|
||||||
|
fetch(outputName, FloatBuffer.wrap(dst));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
|
||||||
|
* dst} must have length greater than or equal to that of the source Tensor. This operation will
|
||||||
|
* not affect dst's content past the source Tensor's size.
|
||||||
|
*/
|
||||||
|
public void fetch(String outputName, int[] dst) {
|
||||||
|
fetch(outputName, IntBuffer.wrap(dst));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
|
||||||
|
* dst} must have length greater than or equal to that of the source Tensor. This operation will
|
||||||
|
* not affect dst's content past the source Tensor's size.
|
||||||
|
*/
|
||||||
|
public void fetch(String outputName, long[] dst) {
|
||||||
|
fetch(outputName, LongBuffer.wrap(dst));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
|
||||||
|
* dst} must have length greater than or equal to that of the source Tensor. This operation will
|
||||||
|
* not affect dst's content past the source Tensor's size.
|
||||||
|
*/
|
||||||
|
public void fetch(String outputName, double[] dst) {
|
||||||
|
fetch(outputName, DoubleBuffer.wrap(dst));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
|
||||||
|
* dst} must have length greater than or equal to that of the source Tensor. This operation will
|
||||||
|
* not affect dst's content past the source Tensor's size.
|
||||||
|
*/
|
||||||
|
public void fetch(String outputName, byte[] dst) {
|
||||||
|
fetch(outputName, ByteBuffer.wrap(dst));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
|
||||||
|
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
|
||||||
|
* or equal to that of the source Tensor. This operation will not affect dst's content past the
|
||||||
|
* source Tensor's size.
|
||||||
|
*/
|
||||||
|
public void fetch(String outputName, FloatBuffer dst) {
|
||||||
|
getTensor(outputName).writeTo(dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
|
||||||
|
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
|
||||||
|
* or equal to that of the source Tensor. This operation will not affect dst's content past the
|
||||||
|
* source Tensor's size.
|
||||||
|
*/
|
||||||
|
public void fetch(String outputName, IntBuffer dst) {
|
||||||
|
getTensor(outputName).writeTo(dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
|
||||||
|
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
|
||||||
|
* or equal to that of the source Tensor. This operation will not affect dst's content past the
|
||||||
|
* source Tensor's size.
|
||||||
|
*/
|
||||||
|
public void fetch(String outputName, LongBuffer dst) {
|
||||||
|
getTensor(outputName).writeTo(dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
|
||||||
|
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
|
||||||
|
* or equal to that of the source Tensor. This operation will not affect dst's content past the
|
||||||
|
* source Tensor's size.
|
||||||
|
*/
|
||||||
|
public void fetch(String outputName, DoubleBuffer dst) {
|
||||||
|
getTensor(outputName).writeTo(dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
|
||||||
|
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
|
||||||
|
* or equal to that of the source Tensor. This operation will not affect dst's content past the
|
||||||
|
* source Tensor's size.
|
||||||
|
*/
|
||||||
|
public void fetch(String outputName, ByteBuffer dst) {
|
||||||
|
getTensor(outputName).writeTo(dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void prepareNativeRuntime() {
|
||||||
|
Log.i(TAG, "Checking to see if TensorFlow native methods are already loaded");
|
||||||
|
try {
|
||||||
|
// Hack to see if the native libraries have been loaded.
|
||||||
|
new RunStats();
|
||||||
|
Log.i(TAG, "TensorFlow native methods already loaded");
|
||||||
|
} catch (UnsatisfiedLinkError e1) {
|
||||||
|
Log.i(
|
||||||
|
TAG, "TensorFlow native methods not found, attempting to load via tensorflow_inference");
|
||||||
|
try {
|
||||||
|
System.loadLibrary("tensorflow_inference");
|
||||||
|
Log.i(TAG, "Successfully loaded TensorFlow native methods (RunStats error may be ignored)");
|
||||||
|
} catch (UnsatisfiedLinkError e2) {
|
||||||
|
throw new RuntimeException(
|
||||||
|
"Native TF methods not found; check that the correct native"
|
||||||
|
+ " libraries are present in the APK.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void loadGraph(byte[] graphDef, Graph g) throws IOException {
|
||||||
|
final long startMs = System.currentTimeMillis();
|
||||||
|
|
||||||
|
if (VERSION.SDK_INT >= 18) {
|
||||||
|
Trace.beginSection("importGraphDef");
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
g.importGraphDef(graphDef);
|
||||||
|
} catch (IllegalArgumentException e) {
|
||||||
|
throw new IOException("Not a valid TensorFlow Graph serialization: " + e.getMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (VERSION.SDK_INT >= 18) {
|
||||||
|
Trace.endSection(); // importGraphDef.
|
||||||
|
}
|
||||||
|
|
||||||
|
final long endMs = System.currentTimeMillis();
|
||||||
|
Log.i(
|
||||||
|
TAG,
|
||||||
|
"Model load took " + (endMs - startMs) + "ms, TensorFlow version: " + TensorFlow.version());
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addFeed(String inputName, Tensor<?> t) {
|
||||||
|
// The string format accepted by TensorFlowInferenceInterface is node_name[:output_index].
|
||||||
|
TensorId tid = TensorId.parse(inputName);
|
||||||
|
runner.feed(tid.name, tid.outputIndex, t);
|
||||||
|
feedNames.add(inputName);
|
||||||
|
feedTensors.add(t);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class TensorId {
|
||||||
|
String name;
|
||||||
|
int outputIndex;
|
||||||
|
|
||||||
|
// Parse output names into a TensorId.
|
||||||
|
//
|
||||||
|
// E.g., "foo" --> ("foo", 0), while "foo:1" --> ("foo", 1)
|
||||||
|
public static TensorId parse(String name) {
|
||||||
|
TensorId tid = new TensorId();
|
||||||
|
int colonIndex = name.lastIndexOf(':');
|
||||||
|
if (colonIndex < 0) {
|
||||||
|
tid.outputIndex = 0;
|
||||||
|
tid.name = name;
|
||||||
|
return tid;
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
tid.outputIndex = Integer.parseInt(name.substring(colonIndex + 1));
|
||||||
|
tid.name = name.substring(0, colonIndex);
|
||||||
|
} catch (NumberFormatException e) {
|
||||||
|
tid.outputIndex = 0;
|
||||||
|
tid.name = name;
|
||||||
|
}
|
||||||
|
return tid;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private Tensor<?> getTensor(String outputName) {
|
||||||
|
int i = 0;
|
||||||
|
for (String n : fetchNames) {
|
||||||
|
if (n.equals(outputName)) {
|
||||||
|
return fetchTensors.get(i);
|
||||||
|
}
|
||||||
|
++i;
|
||||||
|
}
|
||||||
|
throw new RuntimeException(
|
||||||
|
"Node '" + outputName + "' was not provided to run(), so it cannot be read");
|
||||||
|
}
|
||||||
|
|
||||||
|
private void closeFeeds() {
|
||||||
|
for (Tensor<?> t : feedTensors) {
|
||||||
|
t.close();
|
||||||
|
}
|
||||||
|
feedTensors.clear();
|
||||||
|
feedNames.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
private void closeFetches() {
|
||||||
|
for (Tensor<?> t : fetchTensors) {
|
||||||
|
t.close();
|
||||||
|
}
|
||||||
|
fetchTensors.clear();
|
||||||
|
fetchNames.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Immutable state.
|
||||||
|
private final String modelName;
|
||||||
|
private final Graph g;
|
||||||
|
private final Session sess;
|
||||||
|
|
||||||
|
// State reset on every call to run.
|
||||||
|
private Session.Runner runner;
|
||||||
|
private List<String> feedNames = new ArrayList<String>();
|
||||||
|
private List<Tensor<?>> feedTensors = new ArrayList<Tensor<?>>();
|
||||||
|
private List<String> fetchNames = new ArrayList<String>();
|
||||||
|
private List<Tensor<?>> fetchTensors = new ArrayList<Tensor<?>>();
|
||||||
|
|
||||||
|
// Mutable state.
|
||||||
|
private RunStats runStats;
|
||||||
|
}
|
@ -0,0 +1,83 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/tools/android/inference_interface/jni/run_stats_jni.h"
|
||||||
|
|
||||||
|
#include <jni.h>
|
||||||
|
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "tensorflow/core/protobuf/config.pb.h"
|
||||||
|
#include "tensorflow/core/util/stat_summarizer.h"
|
||||||
|
|
||||||
|
using tensorflow::RunMetadata;
|
||||||
|
using tensorflow::StatSummarizer;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
StatSummarizer* requireHandle(JNIEnv* env, jlong handle) {
|
||||||
|
if (handle == 0) {
|
||||||
|
env->ThrowNew(env->FindClass("java/lang/IllegalStateException"),
|
||||||
|
"close() has been called on the RunStats object");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return reinterpret_cast<StatSummarizer*>(handle);
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
#define RUN_STATS_METHOD(name) \
|
||||||
|
JNICALL Java_org_tensorflow_contrib_android_RunStats_##name
|
||||||
|
|
||||||
|
JNIEXPORT jlong RUN_STATS_METHOD(allocate)(JNIEnv* env, jclass clazz) {
|
||||||
|
static_assert(sizeof(jlong) >= sizeof(StatSummarizer*),
|
||||||
|
"Cannot package C++ object pointers as a Java long");
|
||||||
|
tensorflow::StatSummarizerOptions opts;
|
||||||
|
return reinterpret_cast<jlong>(new StatSummarizer(opts));
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void RUN_STATS_METHOD(delete)(JNIEnv* env, jclass clazz,
|
||||||
|
jlong handle) {
|
||||||
|
if (handle == 0) return;
|
||||||
|
delete reinterpret_cast<StatSummarizer*>(handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT void RUN_STATS_METHOD(add)(JNIEnv* env, jclass clazz, jlong handle,
|
||||||
|
jbyteArray run_metadata) {
|
||||||
|
StatSummarizer* s = requireHandle(env, handle);
|
||||||
|
if (s == nullptr) return;
|
||||||
|
jbyte* data = env->GetByteArrayElements(run_metadata, nullptr);
|
||||||
|
int size = static_cast<int>(env->GetArrayLength(run_metadata));
|
||||||
|
tensorflow::RunMetadata proto;
|
||||||
|
if (!proto.ParseFromArray(data, size)) {
|
||||||
|
env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"),
|
||||||
|
"runMetadata does not seem to be a serialized RunMetadata "
|
||||||
|
"protocol message");
|
||||||
|
} else if (proto.has_step_stats()) {
|
||||||
|
s->ProcessStepStats(proto.step_stats());
|
||||||
|
}
|
||||||
|
env->ReleaseByteArrayElements(run_metadata, data, JNI_ABORT);
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jstring RUN_STATS_METHOD(summary)(JNIEnv* env, jclass clazz,
|
||||||
|
jlong handle) {
|
||||||
|
StatSummarizer* s = requireHandle(env, handle);
|
||||||
|
if (s == nullptr) return nullptr;
|
||||||
|
std::stringstream ret;
|
||||||
|
ret << s->GetStatsByMetric("Top 10 CPU", tensorflow::StatsCalculator::BY_TIME,
|
||||||
|
10)
|
||||||
|
<< s->GetStatsByNodeType() << s->ShortSummary();
|
||||||
|
return env->NewStringUTF(ret.str().c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
#undef RUN_STATS_METHOD
|
@ -0,0 +1,40 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef ORG_TENSORFLOW_JNI_RUN_STATS_JNI_H_
|
||||||
|
#define ORG_TENSORFLOW_JNI_RUN_STATS_JNI_H_
|
||||||
|
|
||||||
|
#include <jni.h>
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif // __cplusplus
|
||||||
|
|
||||||
|
#define RUN_STATS_METHOD(name) \
|
||||||
|
Java_org_tensorflow_contrib_android_RunStats_##name
|
||||||
|
|
||||||
|
JNIEXPORT JNICALL jlong RUN_STATS_METHOD(allocate)(JNIEnv*, jclass);
|
||||||
|
JNIEXPORT JNICALL void RUN_STATS_METHOD(delete)(JNIEnv*, jclass, jlong);
|
||||||
|
JNIEXPORT JNICALL void RUN_STATS_METHOD(add)(JNIEnv*, jclass, jlong,
|
||||||
|
jbyteArray);
|
||||||
|
JNIEXPORT JNICALL jstring RUN_STATS_METHOD(summary)(JNIEnv*, jclass, jlong);
|
||||||
|
|
||||||
|
#undef RUN_STATS_METHOD
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
} // extern "C"
|
||||||
|
#endif // __cplusplus
|
||||||
|
|
||||||
|
#endif // ORG_TENSORFLOW_JNI_RUN_STATS_JNI_H_
|
@ -0,0 +1,11 @@
|
|||||||
|
VERS_1.0 {
|
||||||
|
# Export JNI symbols.
|
||||||
|
global:
|
||||||
|
Java_*;
|
||||||
|
JNI_OnLoad;
|
||||||
|
JNI_OnUnload;
|
||||||
|
|
||||||
|
# Hide everything else.
|
||||||
|
local:
|
||||||
|
*;
|
||||||
|
};
|
Loading…
Reference in New Issue
Block a user