Copy Android Inference Interface out of contrib.
Point tensorflow/examples/android at the copy. PiperOrigin-RevId: 269423640
This commit is contained in:
parent
b056951d0d
commit
b4d110caee
@ -53,7 +53,7 @@ cc_library(
|
||||
name = "tensorflow_native_libs",
|
||||
srcs = [
|
||||
":libtensorflow_demo.so",
|
||||
"//tensorflow/contrib/android:libtensorflow_inference.so",
|
||||
"//tensorflow/tools/android/inference_interface:libtensorflow_inference.so",
|
||||
],
|
||||
tags = [
|
||||
"manual",
|
||||
@ -84,7 +84,7 @@ android_binary(
|
||||
],
|
||||
deps = [
|
||||
":tensorflow_native_libs",
|
||||
"//tensorflow/contrib/android:android_tensorflow_inference_java",
|
||||
"//tensorflow/tools/android/inference_interface:android_tensorflow_inference_java",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -9,9 +9,9 @@ The demos in this folder are designed to give straightforward samples of using
|
||||
TensorFlow in mobile applications.
|
||||
|
||||
Inference is done using the [TensorFlow Android Inference
|
||||
Interface](../../../tensorflow/contrib/android), which may be built separately
|
||||
if you want a standalone library to drop into your existing application. Object
|
||||
tracking and efficient YUV -> RGB conversion are handled by
|
||||
Interface](../../tools/android/inference_interface), which may be built
|
||||
separately if you want a standalone library to drop into your existing
|
||||
application. Object tracking and efficient YUV -> RGB conversion are handled by
|
||||
`libtensorflow_demo.so`.
|
||||
|
||||
A device running Android 5.0 (API 21) or higher is required to run the demo due
|
||||
@ -49,7 +49,7 @@ The fastest path to trying the demo is to download the [prebuilt demo APK](https
|
||||
|
||||
Also available are precompiled native libraries, and a jcenter package that you
|
||||
may simply drop into your own applications. See
|
||||
[tensorflow/contrib/android/README.md](../../../tensorflow/contrib/android/README.md)
|
||||
[tensorflow/tools/android/inference_interface/README.md](../../tools/android/inference_interface/README.md)
|
||||
for more details.
|
||||
|
||||
## Running the Demo
|
||||
@ -89,7 +89,7 @@ For any project that does not include custom low level TensorFlow code, this is
|
||||
likely sufficient.
|
||||
|
||||
For details on how to include this JCenter package in your own project see
|
||||
[tensorflow/contrib/android/README.md](../../../tensorflow/contrib/android/README.md)
|
||||
[tensorflow/tools/android/inference_interface/README.md](../../tools/android/inference_interface/README.md)
|
||||
|
||||
## Building the Demo with TensorFlow from Source
|
||||
|
||||
@ -212,4 +212,4 @@ NDK).
|
||||
|
||||
Full CMake support for the demo is coming soon, but for now it is possible to
|
||||
build the TensorFlow Android Inference library using
|
||||
[tensorflow/contrib/android/cmake](../../../tensorflow/contrib/android/cmake).
|
||||
[tensorflow/tools/android/inference_interface/cmake](../../tools/android/inference_interface/cmake).
|
||||
|
@ -41,6 +41,7 @@ filegroup(
|
||||
visibility = [
|
||||
"//tensorflow/contrib/android:__pkg__",
|
||||
"//tensorflow/java:__pkg__",
|
||||
"//tensorflow/tools/android/inference_interface:__pkg__",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -5,11 +5,12 @@
|
||||
package(default_visibility = [
|
||||
"//tensorflow/java:__pkg__",
|
||||
# TODO(ashankar): Temporary hack for the Java API and
|
||||
# //third_party/tensorflow/contrib/android:android_tensorflow_inference_jni
|
||||
# //third_party/tensorflow/tools/android/inference_interface:android_tensorflow_inference_jni
|
||||
# to co-exist in a single shared library. However, the hope is that
|
||||
# //third_party/tensorflow/contrib/android:android_tensorflow_jni can be
|
||||
# //third_party/tensorflow/tools/android/inference_interface:android_tensorflow_jni can be
|
||||
# removed once the Java API provides feature parity with it.
|
||||
"//tensorflow/contrib/android:__pkg__",
|
||||
"//tensorflow/tools/android/inference_interface:__pkg__",
|
||||
])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
4
tensorflow/tools/android/README.md
Normal file
4
tensorflow/tools/android/README.md
Normal file
@ -0,0 +1,4 @@
|
||||
# Deprecated android inference interface.
|
||||
|
||||
WARNING: This directory contains deprecated tf-mobile android inference
|
||||
interface do not use this for anything new. Use TFLite.
|
89
tensorflow/tools/android/inference_interface/BUILD
Normal file
89
tensorflow/tools/android/inference_interface/BUILD
Normal file
@ -0,0 +1,89 @@
|
||||
# Description:
|
||||
# JNI-based Java inference interface for TensorFlow.
|
||||
|
||||
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"if_android",
|
||||
"tf_cc_binary",
|
||||
"tf_copts",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
exports_files([
|
||||
"LICENSE",
|
||||
"jni/version_script.lds",
|
||||
])
|
||||
|
||||
filegroup(
|
||||
name = "android_tensorflow_inference_jni_srcs",
|
||||
srcs = glob([
|
||||
"**/*.cc",
|
||||
"**/*.h",
|
||||
]),
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "android_tensorflow_inference_jni",
|
||||
srcs = if_android([":android_tensorflow_inference_jni_srcs"]),
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/java/src/main/native",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# JAR with Java bindings to TF.
|
||||
android_library(
|
||||
name = "android_tensorflow_inference_java",
|
||||
srcs = glob(["java/**/*.java"]) + ["//tensorflow/java:java_sources"],
|
||||
tags = [
|
||||
"manual",
|
||||
"notap",
|
||||
],
|
||||
)
|
||||
|
||||
# Build the native .so.
|
||||
# bazel build //tensorflow/tools/android/inference_interface:libtensorflow_inference.so \
|
||||
# --crosstool_top=//external:android/crosstool \
|
||||
# --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
|
||||
# --cpu=armeabi-v7a
|
||||
LINKER_SCRIPT = "//tensorflow/tools/android/inference_interface:jni/version_script.lds"
|
||||
|
||||
# This fails to buiild if converted to tf_cc_binary.
|
||||
cc_binary(
|
||||
name = "libtensorflow_inference.so",
|
||||
copts = tf_copts() + [
|
||||
"-ffunction-sections",
|
||||
"-fdata-sections",
|
||||
],
|
||||
linkopts = if_android([
|
||||
"-landroid",
|
||||
"-latomic",
|
||||
"-ldl",
|
||||
"-llog",
|
||||
"-lm",
|
||||
"-z defs",
|
||||
"-s",
|
||||
"-Wl,--gc-sections",
|
||||
"-Wl,--version-script,$(location {})".format(LINKER_SCRIPT),
|
||||
]),
|
||||
linkshared = 1,
|
||||
linkstatic = 1,
|
||||
tags = [
|
||||
"manual",
|
||||
"notap",
|
||||
],
|
||||
deps = [
|
||||
":android_tensorflow_inference_jni",
|
||||
"//tensorflow/core:android_tensorflow_lib",
|
||||
LINKER_SCRIPT,
|
||||
],
|
||||
)
|
95
tensorflow/tools/android/inference_interface/README.md
Normal file
95
tensorflow/tools/android/inference_interface/README.md
Normal file
@ -0,0 +1,95 @@
|
||||
# Android TensorFlow support
|
||||
|
||||
This directory defines components (a native `.so` library and a Java JAR)
|
||||
geared towards supporting TensorFlow on Android. This includes:
|
||||
|
||||
- The [TensorFlow Java API](../../java/README.md)
|
||||
- A `TensorFlowInferenceInterface` class that provides a smaller API
|
||||
surface suitable for inference and summarizing performance of model execution.
|
||||
|
||||
For example usage, see [TensorFlowImageClassifier.java](../../examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java)
|
||||
in the [TensorFlow Android Demo](../../examples/android).
|
||||
|
||||
For prebuilt libraries, see the
|
||||
[nightly Android build artifacts](https://ci.tensorflow.org/view/Nightly/job/nightly-android/)
|
||||
page for a recent build.
|
||||
|
||||
The TensorFlow Inference Interface is also available as a
|
||||
[JCenter package](https://bintray.com/google/tensorflow/tensorflow)
|
||||
(see the tensorflow-android directory) and can be included quite simply in your
|
||||
android project with a couple of lines in the project's `build.gradle` file:
|
||||
|
||||
```
|
||||
allprojects {
|
||||
repositories {
|
||||
jcenter()
|
||||
}
|
||||
}
|
||||
|
||||
dependencies {
|
||||
compile 'org.tensorflow:tensorflow-android:+'
|
||||
}
|
||||
```
|
||||
|
||||
This will tell Gradle to use the
|
||||
[latest version](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
|
||||
of the TensorFlow AAR that has been released to
|
||||
[JCenter](https://jcenter.bintray.com/org/tensorflow/tensorflow-android/).
|
||||
You may replace the `+` with an explicit version label if you wish to
|
||||
use a specific release of TensorFlow in your app.
|
||||
|
||||
To build the libraries yourself (if, for example, you want to support custom
|
||||
TensorFlow operators), pick your preferred approach below:
|
||||
|
||||
### Bazel
|
||||
|
||||
First follow the Bazel setup instructions described in
|
||||
[tensorflow/examples/android/README.md](../../examples/android/README.md)
|
||||
|
||||
Then, to build the native TF library:
|
||||
|
||||
```
|
||||
bazel build -c opt //tensorflow/tools/android/inference_interface:libtensorflow_inference.so \
|
||||
--crosstool_top=//external:android/crosstool \
|
||||
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
|
||||
--cxxopt=-std=c++11 \
|
||||
--cpu=armeabi-v7a
|
||||
```
|
||||
|
||||
Replacing `armeabi-v7a` with your desired target architecture.
|
||||
|
||||
The library will be located at:
|
||||
|
||||
```
|
||||
bazel-bin/tensorflow/tools/android/inference_interface/libtensorflow_inference.so
|
||||
```
|
||||
|
||||
To build the Java counterpart:
|
||||
|
||||
```
|
||||
bazel build //tensorflow/tools/android/inference_interface:android_tensorflow_inference_java
|
||||
```
|
||||
|
||||
You will find the JAR file at:
|
||||
|
||||
```
|
||||
bazel-bin/tensorflow/tools/android/inference_interface/libandroid_tensorflow_inference_java.jar
|
||||
```
|
||||
|
||||
### CMake
|
||||
|
||||
For documentation on building a self-contained AAR file with cmake, see
|
||||
[tensorflow/tools/android/inference_interface/cmake](cmake).
|
||||
|
||||
|
||||
### Makefile
|
||||
|
||||
For documentation on building native TF libraries with make, including a CUDA-enabled variant for devices like the Nvidia Shield TV, see [tensorflow/contrib/makefile/README.md](../makefile/README.md)
|
||||
|
||||
|
||||
## AssetManagerFileSystem
|
||||
|
||||
This directory also contains a TensorFlow filesystem supporting the Android
|
||||
asset manager. This may be useful when writing native (C++) code that is tightly
|
||||
coupled with TensorFlow. For typical usage, the library above will be
|
||||
sufficient.
|
@ -0,0 +1,272 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/tools/android/inference_interface/asset_manager_filesystem.h"
|
||||
|
||||
#include <unistd.h>
|
||||
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/file_system_helper.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
string RemoveSuffix(const string& name, const string& suffix) {
|
||||
string output(name);
|
||||
StringPiece piece(output);
|
||||
absl::ConsumeSuffix(&piece, suffix);
|
||||
return string(piece);
|
||||
}
|
||||
|
||||
// Closes the given AAsset when variable is destructed.
|
||||
class ScopedAsset {
|
||||
public:
|
||||
ScopedAsset(AAsset* asset) : asset_(asset) {}
|
||||
~ScopedAsset() {
|
||||
if (asset_ != nullptr) {
|
||||
AAsset_close(asset_);
|
||||
}
|
||||
}
|
||||
|
||||
AAsset* get() const { return asset_; }
|
||||
|
||||
private:
|
||||
AAsset* asset_;
|
||||
};
|
||||
|
||||
// Closes the given AAssetDir when variable is destructed.
|
||||
class ScopedAssetDir {
|
||||
public:
|
||||
ScopedAssetDir(AAssetDir* asset_dir) : asset_dir_(asset_dir) {}
|
||||
~ScopedAssetDir() {
|
||||
if (asset_dir_ != nullptr) {
|
||||
AAssetDir_close(asset_dir_);
|
||||
}
|
||||
}
|
||||
|
||||
AAssetDir* get() const { return asset_dir_; }
|
||||
|
||||
private:
|
||||
AAssetDir* asset_dir_;
|
||||
};
|
||||
|
||||
class ReadOnlyMemoryRegionFromAsset : public ReadOnlyMemoryRegion {
|
||||
public:
|
||||
ReadOnlyMemoryRegionFromAsset(std::unique_ptr<char[]> data, uint64 length)
|
||||
: data_(std::move(data)), length_(length) {}
|
||||
~ReadOnlyMemoryRegionFromAsset() override = default;
|
||||
|
||||
const void* data() override { return reinterpret_cast<void*>(data_.get()); }
|
||||
uint64 length() override { return length_; }
|
||||
|
||||
private:
|
||||
std::unique_ptr<char[]> data_;
|
||||
uint64 length_;
|
||||
};
|
||||
|
||||
// Note that AAssets are not thread-safe and cannot be used across threads.
|
||||
// However, AAssetManager is. Because RandomAccessFile must be thread-safe and
|
||||
// used across threads, new AAssets must be created for every access.
|
||||
// TODO(tylerrhodes): is there a more efficient way to do this?
|
||||
class RandomAccessFileFromAsset : public RandomAccessFile {
|
||||
public:
|
||||
RandomAccessFileFromAsset(AAssetManager* asset_manager, const string& name)
|
||||
: asset_manager_(asset_manager), file_name_(name) {}
|
||||
~RandomAccessFileFromAsset() override = default;
|
||||
|
||||
Status Read(uint64 offset, size_t to_read, StringPiece* result,
|
||||
char* scratch) const override {
|
||||
auto asset = ScopedAsset(AAssetManager_open(
|
||||
asset_manager_, file_name_.c_str(), AASSET_MODE_RANDOM));
|
||||
if (asset.get() == nullptr) {
|
||||
return errors::NotFound("File ", file_name_, " not found.");
|
||||
}
|
||||
|
||||
off64_t new_offset = AAsset_seek64(asset.get(), offset, SEEK_SET);
|
||||
off64_t length = AAsset_getLength64(asset.get());
|
||||
if (new_offset < 0) {
|
||||
*result = StringPiece(scratch, 0);
|
||||
return errors::OutOfRange("Read after file end.");
|
||||
}
|
||||
const off64_t region_left =
|
||||
std::min(length - new_offset, static_cast<off64_t>(to_read));
|
||||
int read = AAsset_read(asset.get(), scratch, region_left);
|
||||
if (read < 0) {
|
||||
return errors::Internal("Error reading from asset.");
|
||||
}
|
||||
*result = StringPiece(scratch, region_left);
|
||||
return (region_left == to_read)
|
||||
? Status::OK()
|
||||
: errors::OutOfRange("Read less bytes than requested.");
|
||||
}
|
||||
|
||||
private:
|
||||
AAssetManager* asset_manager_;
|
||||
string file_name_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
AssetManagerFileSystem::AssetManagerFileSystem(AAssetManager* asset_manager,
|
||||
const string& prefix)
|
||||
: asset_manager_(asset_manager), prefix_(prefix) {}
|
||||
|
||||
Status AssetManagerFileSystem::FileExists(const string& fname) {
|
||||
string path = RemoveAssetPrefix(fname);
|
||||
auto asset = ScopedAsset(
|
||||
AAssetManager_open(asset_manager_, path.c_str(), AASSET_MODE_RANDOM));
|
||||
if (asset.get() == nullptr) {
|
||||
return errors::NotFound("File ", fname, " not found.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AssetManagerFileSystem::NewRandomAccessFile(
|
||||
const string& fname, std::unique_ptr<RandomAccessFile>* result) {
|
||||
string path = RemoveAssetPrefix(fname);
|
||||
auto asset = ScopedAsset(
|
||||
AAssetManager_open(asset_manager_, path.c_str(), AASSET_MODE_RANDOM));
|
||||
if (asset.get() == nullptr) {
|
||||
return errors::NotFound("File ", fname, " not found.");
|
||||
}
|
||||
result->reset(new RandomAccessFileFromAsset(asset_manager_, path));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AssetManagerFileSystem::NewReadOnlyMemoryRegionFromFile(
|
||||
const string& fname, std::unique_ptr<ReadOnlyMemoryRegion>* result) {
|
||||
string path = RemoveAssetPrefix(fname);
|
||||
auto asset = ScopedAsset(
|
||||
AAssetManager_open(asset_manager_, path.c_str(), AASSET_MODE_STREAMING));
|
||||
if (asset.get() == nullptr) {
|
||||
return errors::NotFound("File ", fname, " not found.");
|
||||
}
|
||||
|
||||
off64_t start, length;
|
||||
int fd = AAsset_openFileDescriptor64(asset.get(), &start, &length);
|
||||
std::unique_ptr<char[]> data;
|
||||
if (fd >= 0) {
|
||||
data.reset(new char[length]);
|
||||
ssize_t result = pread(fd, data.get(), length, start);
|
||||
if (result < 0) {
|
||||
return errors::Internal("Error reading from file ", fname,
|
||||
" using 'read': ", result);
|
||||
}
|
||||
if (result != length) {
|
||||
return errors::Internal("Expected size does not match size read: ",
|
||||
"Expected ", length, " vs. read ", result);
|
||||
}
|
||||
close(fd);
|
||||
} else {
|
||||
length = AAsset_getLength64(asset.get());
|
||||
data.reset(new char[length]);
|
||||
const void* asset_buffer = AAsset_getBuffer(asset.get());
|
||||
if (asset_buffer == nullptr) {
|
||||
return errors::Internal("Error reading ", fname, " from asset manager.");
|
||||
}
|
||||
memcpy(data.get(), asset_buffer, length);
|
||||
}
|
||||
result->reset(new ReadOnlyMemoryRegionFromAsset(std::move(data), length));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AssetManagerFileSystem::GetChildren(const string& prefixed_dir,
|
||||
std::vector<string>* r) {
|
||||
std::string path = NormalizeDirectoryPath(prefixed_dir);
|
||||
auto dir =
|
||||
ScopedAssetDir(AAssetManager_openDir(asset_manager_, path.c_str()));
|
||||
if (dir.get() == nullptr) {
|
||||
return errors::NotFound("Directory ", prefixed_dir, " not found.");
|
||||
}
|
||||
const char* next_file = AAssetDir_getNextFileName(dir.get());
|
||||
while (next_file != nullptr) {
|
||||
r->push_back(next_file);
|
||||
next_file = AAssetDir_getNextFileName(dir.get());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AssetManagerFileSystem::GetFileSize(const string& fname, uint64* s) {
|
||||
// If fname corresponds to a directory, return early. It doesn't map to an
|
||||
// AAsset, and would otherwise return NotFound.
|
||||
if (DirectoryExists(fname)) {
|
||||
*s = 0;
|
||||
return Status::OK();
|
||||
}
|
||||
string path = RemoveAssetPrefix(fname);
|
||||
auto asset = ScopedAsset(
|
||||
AAssetManager_open(asset_manager_, path.c_str(), AASSET_MODE_RANDOM));
|
||||
if (asset.get() == nullptr) {
|
||||
return errors::NotFound("File ", fname, " not found.");
|
||||
}
|
||||
*s = AAsset_getLength64(asset.get());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AssetManagerFileSystem::Stat(const string& fname, FileStatistics* stat) {
|
||||
uint64 size;
|
||||
stat->is_directory = DirectoryExists(fname);
|
||||
TF_RETURN_IF_ERROR(GetFileSize(fname, &size));
|
||||
stat->length = size;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
string AssetManagerFileSystem::NormalizeDirectoryPath(const string& fname) {
|
||||
return RemoveSuffix(RemoveAssetPrefix(fname), "/");
|
||||
}
|
||||
|
||||
string AssetManagerFileSystem::RemoveAssetPrefix(const string& name) {
|
||||
StringPiece piece(name);
|
||||
absl::ConsumePrefix(&piece, prefix_);
|
||||
return string(piece);
|
||||
}
|
||||
|
||||
bool AssetManagerFileSystem::DirectoryExists(const std::string& fname) {
|
||||
std::string path = NormalizeDirectoryPath(fname);
|
||||
auto dir =
|
||||
ScopedAssetDir(AAssetManager_openDir(asset_manager_, path.c_str()));
|
||||
// Note that openDir will return something even if the directory doesn't
|
||||
// exist. Therefore, we need to ensure one file exists in the folder.
|
||||
return AAssetDir_getNextFileName(dir.get()) != NULL;
|
||||
}
|
||||
|
||||
Status AssetManagerFileSystem::GetMatchingPaths(const string& pattern,
|
||||
std::vector<string>* results) {
|
||||
return internal::GetMatchingPaths(this, Env::Default(), pattern, results);
|
||||
}
|
||||
|
||||
Status AssetManagerFileSystem::NewWritableFile(
|
||||
const string& fname, std::unique_ptr<WritableFile>* result) {
|
||||
return errors::Unimplemented("Asset storage is read only.");
|
||||
}
|
||||
Status AssetManagerFileSystem::NewAppendableFile(
|
||||
const string& fname, std::unique_ptr<WritableFile>* result) {
|
||||
return errors::Unimplemented("Asset storage is read only.");
|
||||
}
|
||||
Status AssetManagerFileSystem::DeleteFile(const string& f) {
|
||||
return errors::Unimplemented("Asset storage is read only.");
|
||||
}
|
||||
Status AssetManagerFileSystem::CreateDir(const string& d) {
|
||||
return errors::Unimplemented("Asset storage is read only.");
|
||||
}
|
||||
Status AssetManagerFileSystem::DeleteDir(const string& d) {
|
||||
return errors::Unimplemented("Asset storage is read only.");
|
||||
}
|
||||
Status AssetManagerFileSystem::RenameFile(const string& s, const string& t) {
|
||||
return errors::Unimplemented("Asset storage is read only.");
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,85 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_
|
||||
#define TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_
|
||||
|
||||
#include <android/asset_manager.h>
|
||||
#include <android/asset_manager_jni.h>
|
||||
|
||||
#include "tensorflow/core/platform/file_system.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// FileSystem that uses Android's AAssetManager. Once initialized with a given
|
||||
// AAssetManager, files in the given AAssetManager can be accessed through the
|
||||
// prefix given when registered with the TensorFlow Env.
|
||||
// Note that because APK assets are immutable, any operation that tries to
|
||||
// modify the FileSystem will return tensorflow::error::code::UNIMPLEMENTED.
|
||||
class AssetManagerFileSystem : public FileSystem {
|
||||
public:
|
||||
// Initialize an AssetManagerFileSystem. Note that this does not register the
|
||||
// file system with TensorFlow.
|
||||
// asset_manager - Non-null Android AAssetManager that backs this file
|
||||
// system. The asset manager is not owned by this file system, and must
|
||||
// outlive this class.
|
||||
// prefix - Common prefix to strip from all file URIs before passing them to
|
||||
// the asset_manager. This is required because TensorFlow gives the entire
|
||||
// file URI (file:///my_dir/my_file.txt) and AssetManager only knows paths
|
||||
// relative to its base directory.
|
||||
AssetManagerFileSystem(AAssetManager* asset_manager, const string& prefix);
|
||||
~AssetManagerFileSystem() override = default;
|
||||
|
||||
Status FileExists(const string& fname) override;
|
||||
Status NewRandomAccessFile(
|
||||
const string& filename,
|
||||
std::unique_ptr<RandomAccessFile>* result) override;
|
||||
Status NewReadOnlyMemoryRegionFromFile(
|
||||
const string& filename,
|
||||
std::unique_ptr<ReadOnlyMemoryRegion>* result) override;
|
||||
|
||||
Status GetFileSize(const string& f, uint64* s) override;
|
||||
// Currently just returns size.
|
||||
Status Stat(const string& fname, FileStatistics* stat) override;
|
||||
Status GetChildren(const string& dir, std::vector<string>* r) override;
|
||||
|
||||
// All these functions return Unimplemented error. Asset storage is
|
||||
// read only.
|
||||
Status NewWritableFile(const string& fname,
|
||||
std::unique_ptr<WritableFile>* result) override;
|
||||
Status NewAppendableFile(const string& fname,
|
||||
std::unique_ptr<WritableFile>* result) override;
|
||||
Status DeleteFile(const string& f) override;
|
||||
Status CreateDir(const string& d) override;
|
||||
Status DeleteDir(const string& d) override;
|
||||
Status RenameFile(const string& s, const string& t) override;
|
||||
|
||||
Status GetMatchingPaths(const string& pattern,
|
||||
std::vector<string>* results) override;
|
||||
|
||||
private:
|
||||
string RemoveAssetPrefix(const string& name);
|
||||
|
||||
// Return a string path that can be passed into AAssetManager functions.
|
||||
// For example, 'my_prefix://some/dir/' would return 'some/dir'.
|
||||
string NormalizeDirectoryPath(const string& fname);
|
||||
bool DirectoryExists(const std::string& fname);
|
||||
|
||||
AAssetManager* asset_manager_;
|
||||
string prefix_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
#endif // TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_
|
@ -0,0 +1,80 @@
|
||||
#
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
cmake_minimum_required(VERSION 3.4.1)
|
||||
include(ExternalProject)
|
||||
|
||||
# TENSORFLOW_ROOT_DIR:
|
||||
# root directory of tensorflow repo
|
||||
# used for shared source files and pre-built libs
|
||||
get_filename_component(TENSORFLOW_ROOT_DIR ../../../.. ABSOLUTE)
|
||||
set(PREBUILT_DIR ${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/gen)
|
||||
|
||||
add_library(lib_proto STATIC IMPORTED )
|
||||
set_target_properties(lib_proto PROPERTIES IMPORTED_LOCATION
|
||||
${PREBUILT_DIR}/protobuf/lib/libprotobuf.a)
|
||||
|
||||
add_library(lib_nsync STATIC IMPORTED )
|
||||
set_target_properties(lib_nsync PROPERTIES IMPORTED_LOCATION
|
||||
${TARGET_NSYNC_LIB}/lib/libnsync.a)
|
||||
|
||||
add_library(lib_tf STATIC IMPORTED )
|
||||
set_target_properties(lib_tf PROPERTIES IMPORTED_LOCATION
|
||||
${PREBUILT_DIR}/lib/libtensorflow-core.a)
|
||||
# Change to compile flags should be replicated into bazel build file
|
||||
# TODO: Consider options other than -O2 for binary size.
|
||||
# e.g. -Os for gcc, and -Oz for clang.
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DIS_SLIM_BUILD \
|
||||
-std=c++11 -fno-rtti -fno-exceptions \
|
||||
-O2 -Wno-narrowing -fomit-frame-pointer \
|
||||
-mfpu=neon -mfloat-abi=softfp -fPIE -fPIC \
|
||||
-ftemplate-depth=900 \
|
||||
-DGOOGLE_PROTOBUF_NO_RTTI \
|
||||
-DGOOGLE_PROTOBUF_NO_STATIC_INITIALIZER")
|
||||
|
||||
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} \
|
||||
-Wl,--allow-multiple-definition \
|
||||
-Wl,--whole-archive \
|
||||
-fPIE -pie -v")
|
||||
file(GLOB tensorflow_inference_sources
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../jni/*.cc)
|
||||
file(GLOB java_api_native_sources
|
||||
${TENSORFLOW_ROOT_DIR}/tensorflow/java/src/main/native/*.cc)
|
||||
|
||||
add_library(tensorflow_inference SHARED
|
||||
${tensorflow_inference_sources}
|
||||
${TENSORFLOW_ROOT_DIR}/tensorflow/c/tf_status_helper.cc
|
||||
${TENSORFLOW_ROOT_DIR}/tensorflow/c/checkpoint_reader.cc
|
||||
${TENSORFLOW_ROOT_DIR}/tensorflow/c/test_op.cc
|
||||
${TENSORFLOW_ROOT_DIR}/tensorflow/c/c_api.cc
|
||||
${java_api_native_sources})
|
||||
|
||||
# Include libraries needed for hello-jni lib
|
||||
target_link_libraries(tensorflow_inference
|
||||
android
|
||||
dl
|
||||
log
|
||||
m
|
||||
z
|
||||
lib_tf
|
||||
lib_proto
|
||||
lib_nsync)
|
||||
|
||||
include_directories(
|
||||
${PREBUILT_DIR}/proto
|
||||
${PREBUILT_DIR}/protobuf/include
|
||||
${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/downloads/eigen
|
||||
${TENSORFLOW_ROOT_DIR}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/..)
|
48
tensorflow/tools/android/inference_interface/cmake/README.md
Normal file
48
tensorflow/tools/android/inference_interface/cmake/README.md
Normal file
@ -0,0 +1,48 @@
|
||||
TensorFlow-Android-Inference
|
||||
============================
|
||||
This directory contains CMake support for building the Android Java Inference
|
||||
interface to the TensorFlow native APIs.
|
||||
|
||||
See [tensorflow/tools/android/inference_interface](..) for more details about
|
||||
the library, and instructions for building with Bazel.
|
||||
|
||||
Usage
|
||||
-----
|
||||
Add TensorFlow-Android-Inference as a dependency of your Android application
|
||||
|
||||
* settings.gradle
|
||||
|
||||
```
|
||||
include ':TensorFlow-Android-Inference'
|
||||
findProject(":TensorFlow-Android-Inference").projectDir =
|
||||
new File("${/path/to/tensorflow_repo}/examples/android_inference_interface/cmake")
|
||||
```
|
||||
|
||||
* application's build.gradle (adding dependency):
|
||||
|
||||
```
|
||||
debugCompile project(path: ':tensorflow_inference', configuration: 'debug')
|
||||
releaseCompile project(path: ':tensorflow_inference', configuration: 'release')
|
||||
```
|
||||
Note: this makes native code in the lib traceable from your app.
|
||||
|
||||
Dependencies
|
||||
------------
|
||||
TensorFlow-Android-Inference depends on the TensorFlow static libs already built
|
||||
in your local TensorFlow repo directory. For Linux/Mac OS, build_all_android.sh
|
||||
is used in build.gradle to build it. It DOES take time to build the core libs;
|
||||
so, by default, it is commented out to avoid confusion (otherwise
|
||||
Android Studio would appear to hang during opening the project).
|
||||
To enable it, refer to the comment in
|
||||
|
||||
* build.gradle
|
||||
|
||||
Output
|
||||
------
|
||||
- TensorFlow-Inference-debug.aar
|
||||
- TensorFlow-Inference-release.aar
|
||||
|
||||
File libtensorflow_inference.so should be packed under jni/${ANDROID_ABI}/
|
||||
in the above aar, and it is transparent to the app as it will access them via
|
||||
equivalent java APIs.
|
||||
|
105
tensorflow/tools/android/inference_interface/cmake/build.gradle
Normal file
105
tensorflow/tools/android/inference_interface/cmake/build.gradle
Normal file
@ -0,0 +1,105 @@
|
||||
apply plugin: 'com.android.library'
|
||||
|
||||
// TensorFlow repo root dir on local machine
|
||||
def TF_SRC_DIR = projectDir.toString() + "/../../../.."
|
||||
|
||||
android {
|
||||
compileSdkVersion 24
|
||||
// Check local build_tools_version as this is liable to change within Android Studio.
|
||||
buildToolsVersion '25.0.2'
|
||||
|
||||
// for debugging native code purpose
|
||||
publishNonDefault true
|
||||
|
||||
defaultConfig {
|
||||
archivesBaseName = "Tensorflow-Android-Inference"
|
||||
minSdkVersion 21
|
||||
targetSdkVersion 23
|
||||
versionCode 1
|
||||
versionName "1.0"
|
||||
ndk {
|
||||
abiFilters 'armeabi-v7a'
|
||||
}
|
||||
externalNativeBuild {
|
||||
cmake {
|
||||
arguments '-DANDROID_TOOLCHAIN=clang',
|
||||
'-DANDROID_STL=c++_static'
|
||||
}
|
||||
}
|
||||
}
|
||||
sourceSets {
|
||||
main {
|
||||
java {
|
||||
srcDir "${TF_SRC_DIR}/tensorflow/tools/android/inference_interface/java"
|
||||
srcDir "${TF_SRC_DIR}/tensorflow/java/src/main/java"
|
||||
exclude '**/examples/**'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
externalNativeBuild {
|
||||
cmake {
|
||||
path 'CMakeLists.txt'
|
||||
}
|
||||
}
|
||||
buildTypes {
|
||||
release {
|
||||
minifyEnabled false
|
||||
proguardFiles getDefaultProguardFile('proguard-android.txt'),
|
||||
'proguard-rules.pro'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build libtensorflow-core.a if necessary
|
||||
// Note: the environment needs to be set up already
|
||||
// [ such as installing autoconfig, make, etc ]
|
||||
// How to use:
|
||||
// 1) install all of the necessary tools to build libtensorflow-core.a
|
||||
// 2) inside Android Studio IDE, uncomment buildTensorFlow in
|
||||
// whenTaskAdded{...}
|
||||
// 3) re-sync and re-build. It could take a long time if NOT building
|
||||
// with multiple processes.
|
||||
import org.apache.tools.ant.taskdefs.condition.Os
|
||||
|
||||
Properties properties = new Properties()
|
||||
properties.load(project.rootProject.file('local.properties')
|
||||
.newDataInputStream())
|
||||
def ndkDir = properties.getProperty('ndk.dir')
|
||||
if (ndkDir == null || ndkDir == "") {
|
||||
ndkDir = System.getenv('ANDROID_NDK_HOME')
|
||||
}
|
||||
|
||||
if (!Os.isFamily(Os.FAMILY_WINDOWS)) {
|
||||
// This script is for non-Windows OS. For Windows OS, MANUALLY build
|
||||
// (or copy the built) libs/headers to the
|
||||
// ${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/gen
|
||||
// refer to CMakeLists.txt about lib and header directories for details
|
||||
task buildTensorflow(type: Exec) {
|
||||
group 'buildTensorflowLib'
|
||||
workingDir getProjectDir().toString() + '/../../../../'
|
||||
environment PATH: '/opt/local/bin:/opt/local/sbin:' +
|
||||
System.getenv('PATH')
|
||||
environment NDK_ROOT: ndkDir
|
||||
commandLine 'tensorflow/contrib/makefile/build_all_android.sh'
|
||||
}
|
||||
|
||||
tasks.whenTaskAdded { task ->
|
||||
group 'buildTensorflowLib'
|
||||
if (task.name.toLowerCase().contains('sources')) {
|
||||
def tensorflowTarget = new File(getProjectDir().toString() +
|
||||
'/../../makefile/gen/lib/libtensorflow-core.a')
|
||||
if (!tensorflowTarget.exists()) {
|
||||
// Note:
|
||||
// just uncomment this line to use it:
|
||||
// it can take long time to build by default
|
||||
// it is disabled to avoid false first impression
|
||||
task.dependsOn buildTensorflow
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dependencies {
|
||||
compile fileTree(dir: 'libs', include: ['*.jar'])
|
||||
}
|
@ -0,0 +1,13 @@
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
package="org.tensorflow.contrib.android">
|
||||
|
||||
<uses-sdk
|
||||
android:minSdkVersion="4"
|
||||
android:targetSdkVersion="19" />
|
||||
|
||||
<application android:allowBackup="true" android:label="@string/app_name"
|
||||
android:supportsRtl="true">
|
||||
|
||||
</application>
|
||||
|
||||
</manifest>
|
@ -0,0 +1,3 @@
|
||||
<resources>
|
||||
<string name="app_name">TensorFlowInference</string>
|
||||
</resources>
|
@ -0,0 +1,63 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.contrib.android;
|
||||
|
||||
/** Accumulate and analyze stats from metadata obtained from Session.Runner.run. */
|
||||
public class RunStats implements AutoCloseable {
|
||||
|
||||
/**
|
||||
* Options to be provided to a {@link org.tensorflow.Session.Runner} to enable stats accumulation.
|
||||
*/
|
||||
public static byte[] runOptions() {
|
||||
return fullTraceRunOptions;
|
||||
}
|
||||
|
||||
public RunStats() {
|
||||
nativeHandle = allocate();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
if (nativeHandle != 0) {
|
||||
delete(nativeHandle);
|
||||
}
|
||||
nativeHandle = 0;
|
||||
}
|
||||
|
||||
/** Accumulate stats obtained when executing a graph. */
|
||||
public synchronized void add(byte[] runMetadata) {
|
||||
add(nativeHandle, runMetadata);
|
||||
}
|
||||
|
||||
/** Summary of the accumulated runtime stats. */
|
||||
public synchronized String summary() {
|
||||
return summary(nativeHandle);
|
||||
}
|
||||
|
||||
private long nativeHandle;
|
||||
|
||||
// Hack: This is what a serialized RunOptions protocol buffer with trace_level: FULL_TRACE ends
|
||||
// up as.
|
||||
private static byte[] fullTraceRunOptions = new byte[] {0x08, 0x03};
|
||||
|
||||
private static native long allocate();
|
||||
|
||||
private static native void delete(long handle);
|
||||
|
||||
private static native void add(long handle, byte[] runMetadata);
|
||||
|
||||
private static native String summary(long handle);
|
||||
}
|
@ -0,0 +1,650 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.contrib.android;
|
||||
|
||||
import android.content.res.AssetManager;
|
||||
import android.os.Build.VERSION;
|
||||
import android.os.Trace;
|
||||
import android.text.TextUtils;
|
||||
import android.util.Log;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.DoubleBuffer;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.nio.IntBuffer;
|
||||
import java.nio.LongBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.tensorflow.Graph;
|
||||
import org.tensorflow.Operation;
|
||||
import org.tensorflow.Session;
|
||||
import org.tensorflow.Tensor;
|
||||
import org.tensorflow.TensorFlow;
|
||||
import org.tensorflow.Tensors;
|
||||
import org.tensorflow.types.UInt8;
|
||||
|
||||
/**
|
||||
* Wrapper over the TensorFlow API ({@link Graph}, {@link Session}) providing a smaller API surface
|
||||
* for inference.
|
||||
*
|
||||
* <p>See tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java for an
|
||||
* example usage.
|
||||
*/
|
||||
public class TensorFlowInferenceInterface {
|
||||
private static final String TAG = "TensorFlowInferenceInterface";
|
||||
private static final String ASSET_FILE_PREFIX = "file:///android_asset/";
|
||||
|
||||
/*
|
||||
* Load a TensorFlow model from the AssetManager or from disk if it is not an asset file.
|
||||
*
|
||||
* @param assetManager The AssetManager to use to load the model file.
|
||||
* @param model The filepath to the GraphDef proto representing the model.
|
||||
*/
|
||||
public TensorFlowInferenceInterface(AssetManager assetManager, String model) {
|
||||
prepareNativeRuntime();
|
||||
|
||||
this.modelName = model;
|
||||
this.g = new Graph();
|
||||
this.sess = new Session(g);
|
||||
this.runner = sess.runner();
|
||||
|
||||
final boolean hasAssetPrefix = model.startsWith(ASSET_FILE_PREFIX);
|
||||
InputStream is = null;
|
||||
try {
|
||||
String aname = hasAssetPrefix ? model.split(ASSET_FILE_PREFIX)[1] : model;
|
||||
is = assetManager.open(aname);
|
||||
} catch (IOException e) {
|
||||
if (hasAssetPrefix) {
|
||||
throw new RuntimeException("Failed to load model from '" + model + "'", e);
|
||||
}
|
||||
// Perhaps the model file is not an asset but is on disk.
|
||||
try {
|
||||
is = new FileInputStream(model);
|
||||
} catch (IOException e2) {
|
||||
throw new RuntimeException("Failed to load model from '" + model + "'", e);
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
if (VERSION.SDK_INT >= 18) {
|
||||
Trace.beginSection("initializeTensorFlow");
|
||||
Trace.beginSection("readGraphDef");
|
||||
}
|
||||
|
||||
// TODO(ashankar): Can we somehow mmap the contents instead of copying them?
|
||||
byte[] graphDef = new byte[is.available()];
|
||||
final int numBytesRead = is.read(graphDef);
|
||||
if (numBytesRead != graphDef.length) {
|
||||
throw new IOException(
|
||||
"read error: read only "
|
||||
+ numBytesRead
|
||||
+ " of the graph, expected to read "
|
||||
+ graphDef.length);
|
||||
}
|
||||
|
||||
if (VERSION.SDK_INT >= 18) {
|
||||
Trace.endSection(); // readGraphDef.
|
||||
}
|
||||
|
||||
loadGraph(graphDef, g);
|
||||
is.close();
|
||||
Log.i(TAG, "Successfully loaded model from '" + model + "'");
|
||||
|
||||
if (VERSION.SDK_INT >= 18) {
|
||||
Trace.endSection(); // initializeTensorFlow.
|
||||
}
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Failed to load model from '" + model + "'", e);
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Load a TensorFlow model from provided InputStream.
|
||||
* Note: The InputStream will not be closed after loading model, users need to
|
||||
* close it themselves.
|
||||
*
|
||||
* @param is The InputStream to use to load the model.
|
||||
*/
|
||||
public TensorFlowInferenceInterface(InputStream is) {
|
||||
prepareNativeRuntime();
|
||||
|
||||
// modelName is redundant for model loading from input stream, here is for
|
||||
// avoiding error in initialization as modelName is marked final.
|
||||
this.modelName = "";
|
||||
this.g = new Graph();
|
||||
this.sess = new Session(g);
|
||||
this.runner = sess.runner();
|
||||
|
||||
try {
|
||||
if (VERSION.SDK_INT >= 18) {
|
||||
Trace.beginSection("initializeTensorFlow");
|
||||
Trace.beginSection("readGraphDef");
|
||||
}
|
||||
|
||||
int baosInitSize = is.available() > 16384 ? is.available() : 16384;
|
||||
ByteArrayOutputStream baos = new ByteArrayOutputStream(baosInitSize);
|
||||
int numBytesRead;
|
||||
byte[] buf = new byte[16384];
|
||||
while ((numBytesRead = is.read(buf, 0, buf.length)) != -1) {
|
||||
baos.write(buf, 0, numBytesRead);
|
||||
}
|
||||
byte[] graphDef = baos.toByteArray();
|
||||
|
||||
if (VERSION.SDK_INT >= 18) {
|
||||
Trace.endSection(); // readGraphDef.
|
||||
}
|
||||
|
||||
loadGraph(graphDef, g);
|
||||
Log.i(TAG, "Successfully loaded model from the input stream");
|
||||
|
||||
if (VERSION.SDK_INT >= 18) {
|
||||
Trace.endSection(); // initializeTensorFlow.
|
||||
}
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Failed to load model from the input stream", e);
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Construct a TensorFlowInferenceInterface with provided Graph
|
||||
*
|
||||
* @param g The Graph to use to construct this interface.
|
||||
*/
|
||||
public TensorFlowInferenceInterface(Graph g) {
|
||||
prepareNativeRuntime();
|
||||
|
||||
// modelName is redundant here, here is for
|
||||
// avoiding error in initialization as modelName is marked final.
|
||||
this.modelName = "";
|
||||
this.g = g;
|
||||
this.sess = new Session(g);
|
||||
this.runner = sess.runner();
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs inference between the previously registered input nodes (via feed*) and the requested
|
||||
* output nodes. Output nodes can then be queried with the fetch* methods.
|
||||
*
|
||||
* @param outputNames A list of output nodes which should be filled by the inference pass.
|
||||
*/
|
||||
public void run(String[] outputNames) {
|
||||
run(outputNames, false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs inference between the previously registered input nodes (via feed*) and the requested
|
||||
* output nodes. Output nodes can then be queried with the fetch* methods.
|
||||
*
|
||||
* @param outputNames A list of output nodes which should be filled by the inference pass.
|
||||
*/
|
||||
public void run(String[] outputNames, boolean enableStats) {
|
||||
run(outputNames, enableStats, new String[] {});
|
||||
}
|
||||
|
||||
/** An overloaded version of runInference that allows supplying targetNodeNames as well */
|
||||
public void run(String[] outputNames, boolean enableStats, String[] targetNodeNames) {
|
||||
// Release any Tensors from the previous run calls.
|
||||
closeFetches();
|
||||
|
||||
// Add fetches.
|
||||
for (String o : outputNames) {
|
||||
fetchNames.add(o);
|
||||
TensorId tid = TensorId.parse(o);
|
||||
runner.fetch(tid.name, tid.outputIndex);
|
||||
}
|
||||
|
||||
// Add targets.
|
||||
for (String t : targetNodeNames) {
|
||||
runner.addTarget(t);
|
||||
}
|
||||
|
||||
// Run the session.
|
||||
try {
|
||||
if (enableStats) {
|
||||
Session.Run r = runner.setOptions(RunStats.runOptions()).runAndFetchMetadata();
|
||||
fetchTensors = r.outputs;
|
||||
|
||||
if (runStats == null) {
|
||||
runStats = new RunStats();
|
||||
}
|
||||
runStats.add(r.metadata);
|
||||
} else {
|
||||
fetchTensors = runner.run();
|
||||
}
|
||||
} catch (RuntimeException e) {
|
||||
// Ideally the exception would have been let through, but since this interface predates the
|
||||
// TensorFlow Java API, must return -1.
|
||||
Log.e(
|
||||
TAG,
|
||||
"Failed to run TensorFlow inference with inputs:["
|
||||
+ TextUtils.join(", ", feedNames)
|
||||
+ "], outputs:["
|
||||
+ TextUtils.join(", ", fetchNames)
|
||||
+ "]");
|
||||
throw e;
|
||||
} finally {
|
||||
// Always release the feeds (to save resources) and reset the runner, this run is
|
||||
// over.
|
||||
closeFeeds();
|
||||
runner = sess.runner();
|
||||
}
|
||||
}
|
||||
|
||||
/** Returns a reference to the Graph describing the computation run during inference. */
|
||||
public Graph graph() {
|
||||
return g;
|
||||
}
|
||||
|
||||
public Operation graphOperation(String operationName) {
|
||||
final Operation operation = g.operation(operationName);
|
||||
if (operation == null) {
|
||||
throw new RuntimeException(
|
||||
"Node '" + operationName + "' does not exist in model '" + modelName + "'");
|
||||
}
|
||||
return operation;
|
||||
}
|
||||
|
||||
/** Returns the last stat summary string if logging is enabled. */
|
||||
public String getStatString() {
|
||||
return (runStats == null) ? "" : runStats.summary();
|
||||
}
|
||||
|
||||
/**
|
||||
* Cleans up the state associated with this Object.
|
||||
*
|
||||
* <p>The TenosrFlowInferenceInterface object is no longer usable after this method returns.
|
||||
*/
|
||||
public void close() {
|
||||
closeFeeds();
|
||||
closeFetches();
|
||||
sess.close();
|
||||
g.close();
|
||||
if (runStats != null) {
|
||||
runStats.close();
|
||||
}
|
||||
runStats = null;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void finalize() throws Throwable {
|
||||
try {
|
||||
close();
|
||||
} finally {
|
||||
super.finalize();
|
||||
}
|
||||
}
|
||||
|
||||
// Methods for taking a native Tensor and filling it with values from Java arrays.
|
||||
|
||||
/**
|
||||
* Given a source array with shape {@link dims} and content {@link src}, copy the contents into
|
||||
* the input Tensor with name {@link inputName}. The source array {@link src} must have at least
|
||||
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||
* destination has capacity, the copy is truncated.
|
||||
*/
|
||||
public void feed(String inputName, boolean[] src, long... dims) {
|
||||
byte[] b = new byte[src.length];
|
||||
|
||||
for (int i = 0; i < src.length; i++) {
|
||||
b[i] = src[i] ? (byte) 1 : (byte) 0;
|
||||
}
|
||||
|
||||
addFeed(inputName, Tensor.create(Boolean.class, dims, ByteBuffer.wrap(b)));
|
||||
}
|
||||
|
||||
/**
|
||||
* Given a source array with shape {@link dims} and content {@link src}, copy the contents into
|
||||
* the input Tensor with name {@link inputName}. The source array {@link src} must have at least
|
||||
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||
* destination has capacity, the copy is truncated.
|
||||
*/
|
||||
public void feed(String inputName, float[] src, long... dims) {
|
||||
addFeed(inputName, Tensor.create(dims, FloatBuffer.wrap(src)));
|
||||
}
|
||||
|
||||
/**
|
||||
* Given a source array with shape {@link dims} and content {@link src}, copy the contents into
|
||||
* the input Tensor with name {@link inputName}. The source array {@link src} must have at least
|
||||
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||
* destination has capacity, the copy is truncated.
|
||||
*/
|
||||
public void feed(String inputName, int[] src, long... dims) {
|
||||
addFeed(inputName, Tensor.create(dims, IntBuffer.wrap(src)));
|
||||
}
|
||||
|
||||
/**
|
||||
* Given a source array with shape {@link dims} and content {@link src}, copy the contents into
|
||||
* the input Tensor with name {@link inputName}. The source array {@link src} must have at least
|
||||
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||
* destination has capacity, the copy is truncated.
|
||||
*/
|
||||
public void feed(String inputName, long[] src, long... dims) {
|
||||
addFeed(inputName, Tensor.create(dims, LongBuffer.wrap(src)));
|
||||
}
|
||||
|
||||
/**
|
||||
* Given a source array with shape {@link dims} and content {@link src}, copy the contents into
|
||||
* the input Tensor with name {@link inputName}. The source array {@link src} must have at least
|
||||
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||
* destination has capacity, the copy is truncated.
|
||||
*/
|
||||
public void feed(String inputName, double[] src, long... dims) {
|
||||
addFeed(inputName, Tensor.create(dims, DoubleBuffer.wrap(src)));
|
||||
}
|
||||
|
||||
/**
|
||||
* Given a source array with shape {@link dims} and content {@link src}, copy the contents into
|
||||
* the input Tensor with name {@link inputName}. The source array {@link src} must have at least
|
||||
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||
* destination has capacity, the copy is truncated.
|
||||
*/
|
||||
public void feed(String inputName, byte[] src, long... dims) {
|
||||
addFeed(inputName, Tensor.create(UInt8.class, dims, ByteBuffer.wrap(src)));
|
||||
}
|
||||
|
||||
/**
|
||||
* Copy a byte sequence into the input Tensor with name {@link inputName} as a string-valued
|
||||
* scalar tensor. In the TensorFlow type system, a "string" is an arbitrary sequence of bytes, not
|
||||
* a Java {@code String} (which is a sequence of characters).
|
||||
*/
|
||||
public void feedString(String inputName, byte[] src) {
|
||||
addFeed(inputName, Tensors.create(src));
|
||||
}
|
||||
|
||||
/**
|
||||
* Copy an array of byte sequences into the input Tensor with name {@link inputName} as a
|
||||
* string-valued one-dimensional tensor (vector). In the TensorFlow type system, a "string" is an
|
||||
* arbitrary sequence of bytes, not a Java {@code String} (which is a sequence of characters).
|
||||
*/
|
||||
public void feedString(String inputName, byte[][] src) {
|
||||
addFeed(inputName, Tensors.create(src));
|
||||
}
|
||||
|
||||
// Methods for taking a native Tensor and filling it with src from Java native IO buffers.
|
||||
|
||||
/**
|
||||
* Given a source buffer with shape {@link dims} and content {@link src}, both stored as
|
||||
* <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
|
||||
* Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
|
||||
* elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||
* destination has capacity, the copy is truncated.
|
||||
*/
|
||||
public void feed(String inputName, FloatBuffer src, long... dims) {
|
||||
addFeed(inputName, Tensor.create(dims, src));
|
||||
}
|
||||
|
||||
/**
|
||||
* Given a source buffer with shape {@link dims} and content {@link src}, both stored as
|
||||
* <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
|
||||
* Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
|
||||
* elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||
* destination has capacity, the copy is truncated.
|
||||
*/
|
||||
public void feed(String inputName, IntBuffer src, long... dims) {
|
||||
addFeed(inputName, Tensor.create(dims, src));
|
||||
}
|
||||
|
||||
/**
|
||||
* Given a source buffer with shape {@link dims} and content {@link src}, both stored as
|
||||
* <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
|
||||
* Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
|
||||
* elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||
* destination has capacity, the copy is truncated.
|
||||
*/
|
||||
public void feed(String inputName, LongBuffer src, long... dims) {
|
||||
addFeed(inputName, Tensor.create(dims, src));
|
||||
}
|
||||
|
||||
/**
|
||||
* Given a source buffer with shape {@link dims} and content {@link src}, both stored as
|
||||
* <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
|
||||
* Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
|
||||
* elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||
* destination has capacity, the copy is truncated.
|
||||
*/
|
||||
public void feed(String inputName, DoubleBuffer src, long... dims) {
|
||||
addFeed(inputName, Tensor.create(dims, src));
|
||||
}
|
||||
|
||||
/**
|
||||
* Given a source buffer with shape {@link dims} and content {@link src}, both stored as
|
||||
* <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
|
||||
* Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
|
||||
* elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||
* destination has capacity, the copy is truncated.
|
||||
*/
|
||||
public void feed(String inputName, ByteBuffer src, long... dims) {
|
||||
addFeed(inputName, Tensor.create(UInt8.class, dims, src));
|
||||
}
|
||||
|
||||
/**
|
||||
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
|
||||
* dst} must have length greater than or equal to that of the source Tensor. This operation will
|
||||
* not affect dst's content past the source Tensor's size.
|
||||
*/
|
||||
public void fetch(String outputName, float[] dst) {
|
||||
fetch(outputName, FloatBuffer.wrap(dst));
|
||||
}
|
||||
|
||||
/**
|
||||
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
|
||||
* dst} must have length greater than or equal to that of the source Tensor. This operation will
|
||||
* not affect dst's content past the source Tensor's size.
|
||||
*/
|
||||
public void fetch(String outputName, int[] dst) {
|
||||
fetch(outputName, IntBuffer.wrap(dst));
|
||||
}
|
||||
|
||||
/**
|
||||
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
|
||||
* dst} must have length greater than or equal to that of the source Tensor. This operation will
|
||||
* not affect dst's content past the source Tensor's size.
|
||||
*/
|
||||
public void fetch(String outputName, long[] dst) {
|
||||
fetch(outputName, LongBuffer.wrap(dst));
|
||||
}
|
||||
|
||||
/**
|
||||
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
|
||||
* dst} must have length greater than or equal to that of the source Tensor. This operation will
|
||||
* not affect dst's content past the source Tensor's size.
|
||||
*/
|
||||
public void fetch(String outputName, double[] dst) {
|
||||
fetch(outputName, DoubleBuffer.wrap(dst));
|
||||
}
|
||||
|
||||
/**
|
||||
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
|
||||
* dst} must have length greater than or equal to that of the source Tensor. This operation will
|
||||
* not affect dst's content past the source Tensor's size.
|
||||
*/
|
||||
public void fetch(String outputName, byte[] dst) {
|
||||
fetch(outputName, ByteBuffer.wrap(dst));
|
||||
}
|
||||
|
||||
/**
|
||||
* Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
|
||||
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
|
||||
* or equal to that of the source Tensor. This operation will not affect dst's content past the
|
||||
* source Tensor's size.
|
||||
*/
|
||||
public void fetch(String outputName, FloatBuffer dst) {
|
||||
getTensor(outputName).writeTo(dst);
|
||||
}
|
||||
|
||||
/**
|
||||
* Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
|
||||
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
|
||||
* or equal to that of the source Tensor. This operation will not affect dst's content past the
|
||||
* source Tensor's size.
|
||||
*/
|
||||
public void fetch(String outputName, IntBuffer dst) {
|
||||
getTensor(outputName).writeTo(dst);
|
||||
}
|
||||
|
||||
/**
|
||||
* Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
|
||||
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
|
||||
* or equal to that of the source Tensor. This operation will not affect dst's content past the
|
||||
* source Tensor's size.
|
||||
*/
|
||||
public void fetch(String outputName, LongBuffer dst) {
|
||||
getTensor(outputName).writeTo(dst);
|
||||
}
|
||||
|
||||
/**
|
||||
* Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
|
||||
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
|
||||
* or equal to that of the source Tensor. This operation will not affect dst's content past the
|
||||
* source Tensor's size.
|
||||
*/
|
||||
public void fetch(String outputName, DoubleBuffer dst) {
|
||||
getTensor(outputName).writeTo(dst);
|
||||
}
|
||||
|
||||
/**
|
||||
* Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
|
||||
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
|
||||
* or equal to that of the source Tensor. This operation will not affect dst's content past the
|
||||
* source Tensor's size.
|
||||
*/
|
||||
public void fetch(String outputName, ByteBuffer dst) {
|
||||
getTensor(outputName).writeTo(dst);
|
||||
}
|
||||
|
||||
private void prepareNativeRuntime() {
|
||||
Log.i(TAG, "Checking to see if TensorFlow native methods are already loaded");
|
||||
try {
|
||||
// Hack to see if the native libraries have been loaded.
|
||||
new RunStats();
|
||||
Log.i(TAG, "TensorFlow native methods already loaded");
|
||||
} catch (UnsatisfiedLinkError e1) {
|
||||
Log.i(
|
||||
TAG, "TensorFlow native methods not found, attempting to load via tensorflow_inference");
|
||||
try {
|
||||
System.loadLibrary("tensorflow_inference");
|
||||
Log.i(TAG, "Successfully loaded TensorFlow native methods (RunStats error may be ignored)");
|
||||
} catch (UnsatisfiedLinkError e2) {
|
||||
throw new RuntimeException(
|
||||
"Native TF methods not found; check that the correct native"
|
||||
+ " libraries are present in the APK.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void loadGraph(byte[] graphDef, Graph g) throws IOException {
|
||||
final long startMs = System.currentTimeMillis();
|
||||
|
||||
if (VERSION.SDK_INT >= 18) {
|
||||
Trace.beginSection("importGraphDef");
|
||||
}
|
||||
|
||||
try {
|
||||
g.importGraphDef(graphDef);
|
||||
} catch (IllegalArgumentException e) {
|
||||
throw new IOException("Not a valid TensorFlow Graph serialization: " + e.getMessage());
|
||||
}
|
||||
|
||||
if (VERSION.SDK_INT >= 18) {
|
||||
Trace.endSection(); // importGraphDef.
|
||||
}
|
||||
|
||||
final long endMs = System.currentTimeMillis();
|
||||
Log.i(
|
||||
TAG,
|
||||
"Model load took " + (endMs - startMs) + "ms, TensorFlow version: " + TensorFlow.version());
|
||||
}
|
||||
|
||||
private void addFeed(String inputName, Tensor<?> t) {
|
||||
// The string format accepted by TensorFlowInferenceInterface is node_name[:output_index].
|
||||
TensorId tid = TensorId.parse(inputName);
|
||||
runner.feed(tid.name, tid.outputIndex, t);
|
||||
feedNames.add(inputName);
|
||||
feedTensors.add(t);
|
||||
}
|
||||
|
||||
private static class TensorId {
|
||||
String name;
|
||||
int outputIndex;
|
||||
|
||||
// Parse output names into a TensorId.
|
||||
//
|
||||
// E.g., "foo" --> ("foo", 0), while "foo:1" --> ("foo", 1)
|
||||
public static TensorId parse(String name) {
|
||||
TensorId tid = new TensorId();
|
||||
int colonIndex = name.lastIndexOf(':');
|
||||
if (colonIndex < 0) {
|
||||
tid.outputIndex = 0;
|
||||
tid.name = name;
|
||||
return tid;
|
||||
}
|
||||
try {
|
||||
tid.outputIndex = Integer.parseInt(name.substring(colonIndex + 1));
|
||||
tid.name = name.substring(0, colonIndex);
|
||||
} catch (NumberFormatException e) {
|
||||
tid.outputIndex = 0;
|
||||
tid.name = name;
|
||||
}
|
||||
return tid;
|
||||
}
|
||||
}
|
||||
|
||||
private Tensor<?> getTensor(String outputName) {
|
||||
int i = 0;
|
||||
for (String n : fetchNames) {
|
||||
if (n.equals(outputName)) {
|
||||
return fetchTensors.get(i);
|
||||
}
|
||||
++i;
|
||||
}
|
||||
throw new RuntimeException(
|
||||
"Node '" + outputName + "' was not provided to run(), so it cannot be read");
|
||||
}
|
||||
|
||||
private void closeFeeds() {
|
||||
for (Tensor<?> t : feedTensors) {
|
||||
t.close();
|
||||
}
|
||||
feedTensors.clear();
|
||||
feedNames.clear();
|
||||
}
|
||||
|
||||
private void closeFetches() {
|
||||
for (Tensor<?> t : fetchTensors) {
|
||||
t.close();
|
||||
}
|
||||
fetchTensors.clear();
|
||||
fetchNames.clear();
|
||||
}
|
||||
|
||||
// Immutable state.
|
||||
private final String modelName;
|
||||
private final Graph g;
|
||||
private final Session sess;
|
||||
|
||||
// State reset on every call to run.
|
||||
private Session.Runner runner;
|
||||
private List<String> feedNames = new ArrayList<String>();
|
||||
private List<Tensor<?>> feedTensors = new ArrayList<Tensor<?>>();
|
||||
private List<String> fetchNames = new ArrayList<String>();
|
||||
private List<Tensor<?>> fetchTensors = new ArrayList<Tensor<?>>();
|
||||
|
||||
// Mutable state.
|
||||
private RunStats runStats;
|
||||
}
|
@ -0,0 +1,83 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/tools/android/inference_interface/jni/run_stats_jni.h"
|
||||
|
||||
#include <jni.h>
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
#include "tensorflow/core/util/stat_summarizer.h"
|
||||
|
||||
using tensorflow::RunMetadata;
|
||||
using tensorflow::StatSummarizer;
|
||||
|
||||
namespace {
|
||||
StatSummarizer* requireHandle(JNIEnv* env, jlong handle) {
|
||||
if (handle == 0) {
|
||||
env->ThrowNew(env->FindClass("java/lang/IllegalStateException"),
|
||||
"close() has been called on the RunStats object");
|
||||
return nullptr;
|
||||
}
|
||||
return reinterpret_cast<StatSummarizer*>(handle);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
#define RUN_STATS_METHOD(name) \
|
||||
JNICALL Java_org_tensorflow_contrib_android_RunStats_##name
|
||||
|
||||
JNIEXPORT jlong RUN_STATS_METHOD(allocate)(JNIEnv* env, jclass clazz) {
|
||||
static_assert(sizeof(jlong) >= sizeof(StatSummarizer*),
|
||||
"Cannot package C++ object pointers as a Java long");
|
||||
tensorflow::StatSummarizerOptions opts;
|
||||
return reinterpret_cast<jlong>(new StatSummarizer(opts));
|
||||
}
|
||||
|
||||
JNIEXPORT void RUN_STATS_METHOD(delete)(JNIEnv* env, jclass clazz,
|
||||
jlong handle) {
|
||||
if (handle == 0) return;
|
||||
delete reinterpret_cast<StatSummarizer*>(handle);
|
||||
}
|
||||
|
||||
JNIEXPORT void RUN_STATS_METHOD(add)(JNIEnv* env, jclass clazz, jlong handle,
|
||||
jbyteArray run_metadata) {
|
||||
StatSummarizer* s = requireHandle(env, handle);
|
||||
if (s == nullptr) return;
|
||||
jbyte* data = env->GetByteArrayElements(run_metadata, nullptr);
|
||||
int size = static_cast<int>(env->GetArrayLength(run_metadata));
|
||||
tensorflow::RunMetadata proto;
|
||||
if (!proto.ParseFromArray(data, size)) {
|
||||
env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"),
|
||||
"runMetadata does not seem to be a serialized RunMetadata "
|
||||
"protocol message");
|
||||
} else if (proto.has_step_stats()) {
|
||||
s->ProcessStepStats(proto.step_stats());
|
||||
}
|
||||
env->ReleaseByteArrayElements(run_metadata, data, JNI_ABORT);
|
||||
}
|
||||
|
||||
JNIEXPORT jstring RUN_STATS_METHOD(summary)(JNIEnv* env, jclass clazz,
|
||||
jlong handle) {
|
||||
StatSummarizer* s = requireHandle(env, handle);
|
||||
if (s == nullptr) return nullptr;
|
||||
std::stringstream ret;
|
||||
ret << s->GetStatsByMetric("Top 10 CPU", tensorflow::StatsCalculator::BY_TIME,
|
||||
10)
|
||||
<< s->GetStatsByNodeType() << s->ShortSummary();
|
||||
return env->NewStringUTF(ret.str().c_str());
|
||||
}
|
||||
|
||||
#undef RUN_STATS_METHOD
|
@ -0,0 +1,40 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef ORG_TENSORFLOW_JNI_RUN_STATS_JNI_H_
|
||||
#define ORG_TENSORFLOW_JNI_RUN_STATS_JNI_H_
|
||||
|
||||
#include <jni.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
#define RUN_STATS_METHOD(name) \
|
||||
Java_org_tensorflow_contrib_android_RunStats_##name
|
||||
|
||||
JNIEXPORT JNICALL jlong RUN_STATS_METHOD(allocate)(JNIEnv*, jclass);
|
||||
JNIEXPORT JNICALL void RUN_STATS_METHOD(delete)(JNIEnv*, jclass, jlong);
|
||||
JNIEXPORT JNICALL void RUN_STATS_METHOD(add)(JNIEnv*, jclass, jlong,
|
||||
jbyteArray);
|
||||
JNIEXPORT JNICALL jstring RUN_STATS_METHOD(summary)(JNIEnv*, jclass, jlong);
|
||||
|
||||
#undef RUN_STATS_METHOD
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // ORG_TENSORFLOW_JNI_RUN_STATS_JNI_H_
|
@ -0,0 +1,11 @@
|
||||
VERS_1.0 {
|
||||
# Export JNI symbols.
|
||||
global:
|
||||
Java_*;
|
||||
JNI_OnLoad;
|
||||
JNI_OnUnload;
|
||||
|
||||
# Hide everything else.
|
||||
local:
|
||||
*;
|
||||
};
|
Loading…
Reference in New Issue
Block a user