diff --git a/tensorflow/examples/android/BUILD b/tensorflow/examples/android/BUILD index 50470c027e7..8baf6229c4e 100644 --- a/tensorflow/examples/android/BUILD +++ b/tensorflow/examples/android/BUILD @@ -53,7 +53,7 @@ cc_library( name = "tensorflow_native_libs", srcs = [ ":libtensorflow_demo.so", - "//tensorflow/contrib/android:libtensorflow_inference.so", + "//tensorflow/tools/android/inference_interface:libtensorflow_inference.so", ], tags = [ "manual", @@ -84,7 +84,7 @@ android_binary( ], deps = [ ":tensorflow_native_libs", - "//tensorflow/contrib/android:android_tensorflow_inference_java", + "//tensorflow/tools/android/inference_interface:android_tensorflow_inference_java", ], ) diff --git a/tensorflow/examples/android/README.md b/tensorflow/examples/android/README.md index bb646d2da0e..af93adfb1e4 100644 --- a/tensorflow/examples/android/README.md +++ b/tensorflow/examples/android/README.md @@ -9,9 +9,9 @@ The demos in this folder are designed to give straightforward samples of using TensorFlow in mobile applications. Inference is done using the [TensorFlow Android Inference -Interface](../../../tensorflow/contrib/android), which may be built separately -if you want a standalone library to drop into your existing application. Object -tracking and efficient YUV -> RGB conversion are handled by +Interface](../../tools/android/inference_interface), which may be built +separately if you want a standalone library to drop into your existing +application. Object tracking and efficient YUV -> RGB conversion are handled by `libtensorflow_demo.so`. A device running Android 5.0 (API 21) or higher is required to run the demo due @@ -49,7 +49,7 @@ The fastest path to trying the demo is to download the [prebuilt demo APK](https Also available are precompiled native libraries, and a jcenter package that you may simply drop into your own applications. See -[tensorflow/contrib/android/README.md](../../../tensorflow/contrib/android/README.md) +[tensorflow/tools/android/inference_interface/README.md](../../tools/android/inference_interface/README.md) for more details. ## Running the Demo @@ -89,7 +89,7 @@ For any project that does not include custom low level TensorFlow code, this is likely sufficient. For details on how to include this JCenter package in your own project see -[tensorflow/contrib/android/README.md](../../../tensorflow/contrib/android/README.md) +[tensorflow/tools/android/inference_interface/README.md](../../tools/android/inference_interface/README.md) ## Building the Demo with TensorFlow from Source @@ -212,4 +212,4 @@ NDK). Full CMake support for the demo is coming soon, but for now it is possible to build the TensorFlow Android Inference library using -[tensorflow/contrib/android/cmake](../../../tensorflow/contrib/android/cmake). +[tensorflow/tools/android/inference_interface/cmake](../../tools/android/inference_interface/cmake). diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index 92b84b6ca81..884c52e6da2 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -41,6 +41,7 @@ filegroup( visibility = [ "//tensorflow/contrib/android:__pkg__", "//tensorflow/java:__pkg__", + "//tensorflow/tools/android/inference_interface:__pkg__", ], ) diff --git a/tensorflow/java/src/main/native/BUILD b/tensorflow/java/src/main/native/BUILD index a80955be9d2..4c3077e0d16 100644 --- a/tensorflow/java/src/main/native/BUILD +++ b/tensorflow/java/src/main/native/BUILD @@ -5,11 +5,12 @@ package(default_visibility = [ "//tensorflow/java:__pkg__", # TODO(ashankar): Temporary hack for the Java API and - # //third_party/tensorflow/contrib/android:android_tensorflow_inference_jni + # //third_party/tensorflow/tools/android/inference_interface:android_tensorflow_inference_jni # to co-exist in a single shared library. However, the hope is that - # //third_party/tensorflow/contrib/android:android_tensorflow_jni can be + # //third_party/tensorflow/tools/android/inference_interface:android_tensorflow_jni can be # removed once the Java API provides feature parity with it. "//tensorflow/contrib/android:__pkg__", + "//tensorflow/tools/android/inference_interface:__pkg__", ]) licenses(["notice"]) # Apache 2.0 diff --git a/tensorflow/tools/android/README.md b/tensorflow/tools/android/README.md new file mode 100644 index 00000000000..750b6a8f90e --- /dev/null +++ b/tensorflow/tools/android/README.md @@ -0,0 +1,4 @@ +# Deprecated android inference interface. + +WARNING: This directory contains deprecated tf-mobile android inference +interface do not use this for anything new. Use TFLite. diff --git a/tensorflow/tools/android/inference_interface/BUILD b/tensorflow/tools/android/inference_interface/BUILD new file mode 100644 index 00000000000..00d23b274e5 --- /dev/null +++ b/tensorflow/tools/android/inference_interface/BUILD @@ -0,0 +1,89 @@ +# Description: +# JNI-based Java inference interface for TensorFlow. + +load("@build_bazel_rules_android//android:rules.bzl", "android_library") +load( + "//tensorflow:tensorflow.bzl", + "if_android", + "tf_cc_binary", + "tf_copts", +) + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files([ + "LICENSE", + "jni/version_script.lds", +]) + +filegroup( + name = "android_tensorflow_inference_jni_srcs", + srcs = glob([ + "**/*.cc", + "**/*.h", + ]), + visibility = ["//visibility:public"], +) + +cc_library( + name = "android_tensorflow_inference_jni", + srcs = if_android([":android_tensorflow_inference_jni_srcs"]), + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/java/src/main/native", + ], + alwayslink = 1, +) + +# JAR with Java bindings to TF. +android_library( + name = "android_tensorflow_inference_java", + srcs = glob(["java/**/*.java"]) + ["//tensorflow/java:java_sources"], + tags = [ + "manual", + "notap", + ], +) + +# Build the native .so. +# bazel build //tensorflow/tools/android/inference_interface:libtensorflow_inference.so \ +# --crosstool_top=//external:android/crosstool \ +# --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ +# --cpu=armeabi-v7a +LINKER_SCRIPT = "//tensorflow/tools/android/inference_interface:jni/version_script.lds" + +# This fails to buiild if converted to tf_cc_binary. +cc_binary( + name = "libtensorflow_inference.so", + copts = tf_copts() + [ + "-ffunction-sections", + "-fdata-sections", + ], + linkopts = if_android([ + "-landroid", + "-latomic", + "-ldl", + "-llog", + "-lm", + "-z defs", + "-s", + "-Wl,--gc-sections", + "-Wl,--version-script,$(location {})".format(LINKER_SCRIPT), + ]), + linkshared = 1, + linkstatic = 1, + tags = [ + "manual", + "notap", + ], + deps = [ + ":android_tensorflow_inference_jni", + "//tensorflow/core:android_tensorflow_lib", + LINKER_SCRIPT, + ], +) diff --git a/tensorflow/tools/android/inference_interface/README.md b/tensorflow/tools/android/inference_interface/README.md new file mode 100644 index 00000000000..31045614a50 --- /dev/null +++ b/tensorflow/tools/android/inference_interface/README.md @@ -0,0 +1,95 @@ +# Android TensorFlow support + +This directory defines components (a native `.so` library and a Java JAR) +geared towards supporting TensorFlow on Android. This includes: + +- The [TensorFlow Java API](../../java/README.md) +- A `TensorFlowInferenceInterface` class that provides a smaller API + surface suitable for inference and summarizing performance of model execution. + +For example usage, see [TensorFlowImageClassifier.java](../../examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java) +in the [TensorFlow Android Demo](../../examples/android). + +For prebuilt libraries, see the +[nightly Android build artifacts](https://ci.tensorflow.org/view/Nightly/job/nightly-android/) +page for a recent build. + +The TensorFlow Inference Interface is also available as a +[JCenter package](https://bintray.com/google/tensorflow/tensorflow) +(see the tensorflow-android directory) and can be included quite simply in your +android project with a couple of lines in the project's `build.gradle` file: + +``` +allprojects { + repositories { + jcenter() + } +} + +dependencies { + compile 'org.tensorflow:tensorflow-android:+' +} +``` + +This will tell Gradle to use the +[latest version](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) +of the TensorFlow AAR that has been released to +[JCenter](https://jcenter.bintray.com/org/tensorflow/tensorflow-android/). +You may replace the `+` with an explicit version label if you wish to +use a specific release of TensorFlow in your app. + +To build the libraries yourself (if, for example, you want to support custom +TensorFlow operators), pick your preferred approach below: + +### Bazel + +First follow the Bazel setup instructions described in +[tensorflow/examples/android/README.md](../../examples/android/README.md) + +Then, to build the native TF library: + +``` +bazel build -c opt //tensorflow/tools/android/inference_interface:libtensorflow_inference.so \ + --crosstool_top=//external:android/crosstool \ + --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ + --cxxopt=-std=c++11 \ + --cpu=armeabi-v7a +``` + +Replacing `armeabi-v7a` with your desired target architecture. + +The library will be located at: + +``` +bazel-bin/tensorflow/tools/android/inference_interface/libtensorflow_inference.so +``` + +To build the Java counterpart: + +``` +bazel build //tensorflow/tools/android/inference_interface:android_tensorflow_inference_java +``` + +You will find the JAR file at: + +``` +bazel-bin/tensorflow/tools/android/inference_interface/libandroid_tensorflow_inference_java.jar +``` + +### CMake + +For documentation on building a self-contained AAR file with cmake, see +[tensorflow/tools/android/inference_interface/cmake](cmake). + + +### Makefile + +For documentation on building native TF libraries with make, including a CUDA-enabled variant for devices like the Nvidia Shield TV, see [tensorflow/contrib/makefile/README.md](../makefile/README.md) + + +## AssetManagerFileSystem + +This directory also contains a TensorFlow filesystem supporting the Android +asset manager. This may be useful when writing native (C++) code that is tightly +coupled with TensorFlow. For typical usage, the library above will be +sufficient. diff --git a/tensorflow/tools/android/inference_interface/asset_manager_filesystem.cc b/tensorflow/tools/android/inference_interface/asset_manager_filesystem.cc new file mode 100644 index 00000000000..ee56f9affdf --- /dev/null +++ b/tensorflow/tools/android/inference_interface/asset_manager_filesystem.cc @@ -0,0 +1,272 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/tools/android/inference_interface/asset_manager_filesystem.h" + +#include + +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/file_system_helper.h" + +namespace tensorflow { +namespace { + +string RemoveSuffix(const string& name, const string& suffix) { + string output(name); + StringPiece piece(output); + absl::ConsumeSuffix(&piece, suffix); + return string(piece); +} + +// Closes the given AAsset when variable is destructed. +class ScopedAsset { + public: + ScopedAsset(AAsset* asset) : asset_(asset) {} + ~ScopedAsset() { + if (asset_ != nullptr) { + AAsset_close(asset_); + } + } + + AAsset* get() const { return asset_; } + + private: + AAsset* asset_; +}; + +// Closes the given AAssetDir when variable is destructed. +class ScopedAssetDir { + public: + ScopedAssetDir(AAssetDir* asset_dir) : asset_dir_(asset_dir) {} + ~ScopedAssetDir() { + if (asset_dir_ != nullptr) { + AAssetDir_close(asset_dir_); + } + } + + AAssetDir* get() const { return asset_dir_; } + + private: + AAssetDir* asset_dir_; +}; + +class ReadOnlyMemoryRegionFromAsset : public ReadOnlyMemoryRegion { + public: + ReadOnlyMemoryRegionFromAsset(std::unique_ptr data, uint64 length) + : data_(std::move(data)), length_(length) {} + ~ReadOnlyMemoryRegionFromAsset() override = default; + + const void* data() override { return reinterpret_cast(data_.get()); } + uint64 length() override { return length_; } + + private: + std::unique_ptr data_; + uint64 length_; +}; + +// Note that AAssets are not thread-safe and cannot be used across threads. +// However, AAssetManager is. Because RandomAccessFile must be thread-safe and +// used across threads, new AAssets must be created for every access. +// TODO(tylerrhodes): is there a more efficient way to do this? +class RandomAccessFileFromAsset : public RandomAccessFile { + public: + RandomAccessFileFromAsset(AAssetManager* asset_manager, const string& name) + : asset_manager_(asset_manager), file_name_(name) {} + ~RandomAccessFileFromAsset() override = default; + + Status Read(uint64 offset, size_t to_read, StringPiece* result, + char* scratch) const override { + auto asset = ScopedAsset(AAssetManager_open( + asset_manager_, file_name_.c_str(), AASSET_MODE_RANDOM)); + if (asset.get() == nullptr) { + return errors::NotFound("File ", file_name_, " not found."); + } + + off64_t new_offset = AAsset_seek64(asset.get(), offset, SEEK_SET); + off64_t length = AAsset_getLength64(asset.get()); + if (new_offset < 0) { + *result = StringPiece(scratch, 0); + return errors::OutOfRange("Read after file end."); + } + const off64_t region_left = + std::min(length - new_offset, static_cast(to_read)); + int read = AAsset_read(asset.get(), scratch, region_left); + if (read < 0) { + return errors::Internal("Error reading from asset."); + } + *result = StringPiece(scratch, region_left); + return (region_left == to_read) + ? Status::OK() + : errors::OutOfRange("Read less bytes than requested."); + } + + private: + AAssetManager* asset_manager_; + string file_name_; +}; + +} // namespace + +AssetManagerFileSystem::AssetManagerFileSystem(AAssetManager* asset_manager, + const string& prefix) + : asset_manager_(asset_manager), prefix_(prefix) {} + +Status AssetManagerFileSystem::FileExists(const string& fname) { + string path = RemoveAssetPrefix(fname); + auto asset = ScopedAsset( + AAssetManager_open(asset_manager_, path.c_str(), AASSET_MODE_RANDOM)); + if (asset.get() == nullptr) { + return errors::NotFound("File ", fname, " not found."); + } + return Status::OK(); +} + +Status AssetManagerFileSystem::NewRandomAccessFile( + const string& fname, std::unique_ptr* result) { + string path = RemoveAssetPrefix(fname); + auto asset = ScopedAsset( + AAssetManager_open(asset_manager_, path.c_str(), AASSET_MODE_RANDOM)); + if (asset.get() == nullptr) { + return errors::NotFound("File ", fname, " not found."); + } + result->reset(new RandomAccessFileFromAsset(asset_manager_, path)); + return Status::OK(); +} + +Status AssetManagerFileSystem::NewReadOnlyMemoryRegionFromFile( + const string& fname, std::unique_ptr* result) { + string path = RemoveAssetPrefix(fname); + auto asset = ScopedAsset( + AAssetManager_open(asset_manager_, path.c_str(), AASSET_MODE_STREAMING)); + if (asset.get() == nullptr) { + return errors::NotFound("File ", fname, " not found."); + } + + off64_t start, length; + int fd = AAsset_openFileDescriptor64(asset.get(), &start, &length); + std::unique_ptr data; + if (fd >= 0) { + data.reset(new char[length]); + ssize_t result = pread(fd, data.get(), length, start); + if (result < 0) { + return errors::Internal("Error reading from file ", fname, + " using 'read': ", result); + } + if (result != length) { + return errors::Internal("Expected size does not match size read: ", + "Expected ", length, " vs. read ", result); + } + close(fd); + } else { + length = AAsset_getLength64(asset.get()); + data.reset(new char[length]); + const void* asset_buffer = AAsset_getBuffer(asset.get()); + if (asset_buffer == nullptr) { + return errors::Internal("Error reading ", fname, " from asset manager."); + } + memcpy(data.get(), asset_buffer, length); + } + result->reset(new ReadOnlyMemoryRegionFromAsset(std::move(data), length)); + return Status::OK(); +} + +Status AssetManagerFileSystem::GetChildren(const string& prefixed_dir, + std::vector* r) { + std::string path = NormalizeDirectoryPath(prefixed_dir); + auto dir = + ScopedAssetDir(AAssetManager_openDir(asset_manager_, path.c_str())); + if (dir.get() == nullptr) { + return errors::NotFound("Directory ", prefixed_dir, " not found."); + } + const char* next_file = AAssetDir_getNextFileName(dir.get()); + while (next_file != nullptr) { + r->push_back(next_file); + next_file = AAssetDir_getNextFileName(dir.get()); + } + return Status::OK(); +} + +Status AssetManagerFileSystem::GetFileSize(const string& fname, uint64* s) { + // If fname corresponds to a directory, return early. It doesn't map to an + // AAsset, and would otherwise return NotFound. + if (DirectoryExists(fname)) { + *s = 0; + return Status::OK(); + } + string path = RemoveAssetPrefix(fname); + auto asset = ScopedAsset( + AAssetManager_open(asset_manager_, path.c_str(), AASSET_MODE_RANDOM)); + if (asset.get() == nullptr) { + return errors::NotFound("File ", fname, " not found."); + } + *s = AAsset_getLength64(asset.get()); + return Status::OK(); +} + +Status AssetManagerFileSystem::Stat(const string& fname, FileStatistics* stat) { + uint64 size; + stat->is_directory = DirectoryExists(fname); + TF_RETURN_IF_ERROR(GetFileSize(fname, &size)); + stat->length = size; + return Status::OK(); +} + +string AssetManagerFileSystem::NormalizeDirectoryPath(const string& fname) { + return RemoveSuffix(RemoveAssetPrefix(fname), "/"); +} + +string AssetManagerFileSystem::RemoveAssetPrefix(const string& name) { + StringPiece piece(name); + absl::ConsumePrefix(&piece, prefix_); + return string(piece); +} + +bool AssetManagerFileSystem::DirectoryExists(const std::string& fname) { + std::string path = NormalizeDirectoryPath(fname); + auto dir = + ScopedAssetDir(AAssetManager_openDir(asset_manager_, path.c_str())); + // Note that openDir will return something even if the directory doesn't + // exist. Therefore, we need to ensure one file exists in the folder. + return AAssetDir_getNextFileName(dir.get()) != NULL; +} + +Status AssetManagerFileSystem::GetMatchingPaths(const string& pattern, + std::vector* results) { + return internal::GetMatchingPaths(this, Env::Default(), pattern, results); +} + +Status AssetManagerFileSystem::NewWritableFile( + const string& fname, std::unique_ptr* result) { + return errors::Unimplemented("Asset storage is read only."); +} +Status AssetManagerFileSystem::NewAppendableFile( + const string& fname, std::unique_ptr* result) { + return errors::Unimplemented("Asset storage is read only."); +} +Status AssetManagerFileSystem::DeleteFile(const string& f) { + return errors::Unimplemented("Asset storage is read only."); +} +Status AssetManagerFileSystem::CreateDir(const string& d) { + return errors::Unimplemented("Asset storage is read only."); +} +Status AssetManagerFileSystem::DeleteDir(const string& d) { + return errors::Unimplemented("Asset storage is read only."); +} +Status AssetManagerFileSystem::RenameFile(const string& s, const string& t) { + return errors::Unimplemented("Asset storage is read only."); +} + +} // namespace tensorflow diff --git a/tensorflow/tools/android/inference_interface/asset_manager_filesystem.h b/tensorflow/tools/android/inference_interface/asset_manager_filesystem.h new file mode 100644 index 00000000000..a87ff42ae21 --- /dev/null +++ b/tensorflow/tools/android/inference_interface/asset_manager_filesystem.h @@ -0,0 +1,85 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_ +#define TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_ + +#include +#include + +#include "tensorflow/core/platform/file_system.h" + +namespace tensorflow { + +// FileSystem that uses Android's AAssetManager. Once initialized with a given +// AAssetManager, files in the given AAssetManager can be accessed through the +// prefix given when registered with the TensorFlow Env. +// Note that because APK assets are immutable, any operation that tries to +// modify the FileSystem will return tensorflow::error::code::UNIMPLEMENTED. +class AssetManagerFileSystem : public FileSystem { + public: + // Initialize an AssetManagerFileSystem. Note that this does not register the + // file system with TensorFlow. + // asset_manager - Non-null Android AAssetManager that backs this file + // system. The asset manager is not owned by this file system, and must + // outlive this class. + // prefix - Common prefix to strip from all file URIs before passing them to + // the asset_manager. This is required because TensorFlow gives the entire + // file URI (file:///my_dir/my_file.txt) and AssetManager only knows paths + // relative to its base directory. + AssetManagerFileSystem(AAssetManager* asset_manager, const string& prefix); + ~AssetManagerFileSystem() override = default; + + Status FileExists(const string& fname) override; + Status NewRandomAccessFile( + const string& filename, + std::unique_ptr* result) override; + Status NewReadOnlyMemoryRegionFromFile( + const string& filename, + std::unique_ptr* result) override; + + Status GetFileSize(const string& f, uint64* s) override; + // Currently just returns size. + Status Stat(const string& fname, FileStatistics* stat) override; + Status GetChildren(const string& dir, std::vector* r) override; + + // All these functions return Unimplemented error. Asset storage is + // read only. + Status NewWritableFile(const string& fname, + std::unique_ptr* result) override; + Status NewAppendableFile(const string& fname, + std::unique_ptr* result) override; + Status DeleteFile(const string& f) override; + Status CreateDir(const string& d) override; + Status DeleteDir(const string& d) override; + Status RenameFile(const string& s, const string& t) override; + + Status GetMatchingPaths(const string& pattern, + std::vector* results) override; + + private: + string RemoveAssetPrefix(const string& name); + + // Return a string path that can be passed into AAssetManager functions. + // For example, 'my_prefix://some/dir/' would return 'some/dir'. + string NormalizeDirectoryPath(const string& fname); + bool DirectoryExists(const std::string& fname); + + AAssetManager* asset_manager_; + string prefix_; +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_ diff --git a/tensorflow/tools/android/inference_interface/cmake/CMakeLists.txt b/tensorflow/tools/android/inference_interface/cmake/CMakeLists.txt new file mode 100644 index 00000000000..ecf1a103d29 --- /dev/null +++ b/tensorflow/tools/android/inference_interface/cmake/CMakeLists.txt @@ -0,0 +1,80 @@ +# +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +cmake_minimum_required(VERSION 3.4.1) +include(ExternalProject) + +# TENSORFLOW_ROOT_DIR: +# root directory of tensorflow repo +# used for shared source files and pre-built libs +get_filename_component(TENSORFLOW_ROOT_DIR ../../../.. ABSOLUTE) +set(PREBUILT_DIR ${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/gen) + +add_library(lib_proto STATIC IMPORTED ) +set_target_properties(lib_proto PROPERTIES IMPORTED_LOCATION + ${PREBUILT_DIR}/protobuf/lib/libprotobuf.a) + +add_library(lib_nsync STATIC IMPORTED ) +set_target_properties(lib_nsync PROPERTIES IMPORTED_LOCATION + ${TARGET_NSYNC_LIB}/lib/libnsync.a) + +add_library(lib_tf STATIC IMPORTED ) +set_target_properties(lib_tf PROPERTIES IMPORTED_LOCATION + ${PREBUILT_DIR}/lib/libtensorflow-core.a) +# Change to compile flags should be replicated into bazel build file +# TODO: Consider options other than -O2 for binary size. +# e.g. -Os for gcc, and -Oz for clang. +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DIS_SLIM_BUILD \ + -std=c++11 -fno-rtti -fno-exceptions \ + -O2 -Wno-narrowing -fomit-frame-pointer \ + -mfpu=neon -mfloat-abi=softfp -fPIE -fPIC \ + -ftemplate-depth=900 \ + -DGOOGLE_PROTOBUF_NO_RTTI \ + -DGOOGLE_PROTOBUF_NO_STATIC_INITIALIZER") + +set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} \ + -Wl,--allow-multiple-definition \ + -Wl,--whole-archive \ + -fPIE -pie -v") +file(GLOB tensorflow_inference_sources + ${CMAKE_CURRENT_SOURCE_DIR}/../jni/*.cc) +file(GLOB java_api_native_sources + ${TENSORFLOW_ROOT_DIR}/tensorflow/java/src/main/native/*.cc) + +add_library(tensorflow_inference SHARED + ${tensorflow_inference_sources} + ${TENSORFLOW_ROOT_DIR}/tensorflow/c/tf_status_helper.cc + ${TENSORFLOW_ROOT_DIR}/tensorflow/c/checkpoint_reader.cc + ${TENSORFLOW_ROOT_DIR}/tensorflow/c/test_op.cc + ${TENSORFLOW_ROOT_DIR}/tensorflow/c/c_api.cc + ${java_api_native_sources}) + +# Include libraries needed for hello-jni lib +target_link_libraries(tensorflow_inference + android + dl + log + m + z + lib_tf + lib_proto + lib_nsync) + +include_directories( + ${PREBUILT_DIR}/proto + ${PREBUILT_DIR}/protobuf/include + ${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/downloads/eigen + ${TENSORFLOW_ROOT_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/..) diff --git a/tensorflow/tools/android/inference_interface/cmake/README.md b/tensorflow/tools/android/inference_interface/cmake/README.md new file mode 100644 index 00000000000..e04507dd2d1 --- /dev/null +++ b/tensorflow/tools/android/inference_interface/cmake/README.md @@ -0,0 +1,48 @@ +TensorFlow-Android-Inference +============================ +This directory contains CMake support for building the Android Java Inference +interface to the TensorFlow native APIs. + +See [tensorflow/tools/android/inference_interface](..) for more details about +the library, and instructions for building with Bazel. + +Usage +----- +Add TensorFlow-Android-Inference as a dependency of your Android application + +* settings.gradle + +``` +include ':TensorFlow-Android-Inference' +findProject(":TensorFlow-Android-Inference").projectDir = + new File("${/path/to/tensorflow_repo}/examples/android_inference_interface/cmake") +``` + +* application's build.gradle (adding dependency): + +``` +debugCompile project(path: ':tensorflow_inference', configuration: 'debug') +releaseCompile project(path: ':tensorflow_inference', configuration: 'release') +``` +Note: this makes native code in the lib traceable from your app. + +Dependencies +------------ +TensorFlow-Android-Inference depends on the TensorFlow static libs already built +in your local TensorFlow repo directory. For Linux/Mac OS, build_all_android.sh +is used in build.gradle to build it. It DOES take time to build the core libs; +so, by default, it is commented out to avoid confusion (otherwise +Android Studio would appear to hang during opening the project). +To enable it, refer to the comment in + +* build.gradle + +Output +------ +- TensorFlow-Inference-debug.aar +- TensorFlow-Inference-release.aar + +File libtensorflow_inference.so should be packed under jni/${ANDROID_ABI}/ +in the above aar, and it is transparent to the app as it will access them via +equivalent java APIs. + diff --git a/tensorflow/tools/android/inference_interface/cmake/build.gradle b/tensorflow/tools/android/inference_interface/cmake/build.gradle new file mode 100644 index 00000000000..16940c3911f --- /dev/null +++ b/tensorflow/tools/android/inference_interface/cmake/build.gradle @@ -0,0 +1,105 @@ +apply plugin: 'com.android.library' + +// TensorFlow repo root dir on local machine +def TF_SRC_DIR = projectDir.toString() + "/../../../.." + +android { + compileSdkVersion 24 + // Check local build_tools_version as this is liable to change within Android Studio. + buildToolsVersion '25.0.2' + + // for debugging native code purpose + publishNonDefault true + + defaultConfig { + archivesBaseName = "Tensorflow-Android-Inference" + minSdkVersion 21 + targetSdkVersion 23 + versionCode 1 + versionName "1.0" + ndk { + abiFilters 'armeabi-v7a' + } + externalNativeBuild { + cmake { + arguments '-DANDROID_TOOLCHAIN=clang', + '-DANDROID_STL=c++_static' + } + } + } + sourceSets { + main { + java { + srcDir "${TF_SRC_DIR}/tensorflow/tools/android/inference_interface/java" + srcDir "${TF_SRC_DIR}/tensorflow/java/src/main/java" + exclude '**/examples/**' + } + } + } + + externalNativeBuild { + cmake { + path 'CMakeLists.txt' + } + } + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android.txt'), + 'proguard-rules.pro' + } + } +} + +// Build libtensorflow-core.a if necessary +// Note: the environment needs to be set up already +// [ such as installing autoconfig, make, etc ] +// How to use: +// 1) install all of the necessary tools to build libtensorflow-core.a +// 2) inside Android Studio IDE, uncomment buildTensorFlow in +// whenTaskAdded{...} +// 3) re-sync and re-build. It could take a long time if NOT building +// with multiple processes. +import org.apache.tools.ant.taskdefs.condition.Os + +Properties properties = new Properties() +properties.load(project.rootProject.file('local.properties') + .newDataInputStream()) +def ndkDir = properties.getProperty('ndk.dir') +if (ndkDir == null || ndkDir == "") { + ndkDir = System.getenv('ANDROID_NDK_HOME') +} + +if (!Os.isFamily(Os.FAMILY_WINDOWS)) { + // This script is for non-Windows OS. For Windows OS, MANUALLY build + // (or copy the built) libs/headers to the + // ${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/gen + // refer to CMakeLists.txt about lib and header directories for details + task buildTensorflow(type: Exec) { + group 'buildTensorflowLib' + workingDir getProjectDir().toString() + '/../../../../' + environment PATH: '/opt/local/bin:/opt/local/sbin:' + + System.getenv('PATH') + environment NDK_ROOT: ndkDir + commandLine 'tensorflow/contrib/makefile/build_all_android.sh' + } + + tasks.whenTaskAdded { task -> + group 'buildTensorflowLib' + if (task.name.toLowerCase().contains('sources')) { + def tensorflowTarget = new File(getProjectDir().toString() + + '/../../makefile/gen/lib/libtensorflow-core.a') + if (!tensorflowTarget.exists()) { + // Note: + // just uncomment this line to use it: + // it can take long time to build by default + // it is disabled to avoid false first impression + task.dependsOn buildTensorflow + } + } + } +} + +dependencies { + compile fileTree(dir: 'libs', include: ['*.jar']) +} diff --git a/tensorflow/tools/android/inference_interface/cmake/src/main/AndroidManifest.xml b/tensorflow/tools/android/inference_interface/cmake/src/main/AndroidManifest.xml new file mode 100644 index 00000000000..c17110a78be --- /dev/null +++ b/tensorflow/tools/android/inference_interface/cmake/src/main/AndroidManifest.xml @@ -0,0 +1,13 @@ + + + + + + + + + diff --git a/tensorflow/tools/android/inference_interface/cmake/src/main/res/values/strings.xml b/tensorflow/tools/android/inference_interface/cmake/src/main/res/values/strings.xml new file mode 100644 index 00000000000..92dc3a1baf0 --- /dev/null +++ b/tensorflow/tools/android/inference_interface/cmake/src/main/res/values/strings.xml @@ -0,0 +1,3 @@ + + TensorFlowInference + diff --git a/tensorflow/tools/android/inference_interface/java/org/tensorflow/contrib/android/RunStats.java b/tensorflow/tools/android/inference_interface/java/org/tensorflow/contrib/android/RunStats.java new file mode 100644 index 00000000000..39996f6ab03 --- /dev/null +++ b/tensorflow/tools/android/inference_interface/java/org/tensorflow/contrib/android/RunStats.java @@ -0,0 +1,63 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.contrib.android; + +/** Accumulate and analyze stats from metadata obtained from Session.Runner.run. */ +public class RunStats implements AutoCloseable { + + /** + * Options to be provided to a {@link org.tensorflow.Session.Runner} to enable stats accumulation. + */ + public static byte[] runOptions() { + return fullTraceRunOptions; + } + + public RunStats() { + nativeHandle = allocate(); + } + + @Override + public void close() { + if (nativeHandle != 0) { + delete(nativeHandle); + } + nativeHandle = 0; + } + + /** Accumulate stats obtained when executing a graph. */ + public synchronized void add(byte[] runMetadata) { + add(nativeHandle, runMetadata); + } + + /** Summary of the accumulated runtime stats. */ + public synchronized String summary() { + return summary(nativeHandle); + } + + private long nativeHandle; + + // Hack: This is what a serialized RunOptions protocol buffer with trace_level: FULL_TRACE ends + // up as. + private static byte[] fullTraceRunOptions = new byte[] {0x08, 0x03}; + + private static native long allocate(); + + private static native void delete(long handle); + + private static native void add(long handle, byte[] runMetadata); + + private static native String summary(long handle); +} diff --git a/tensorflow/tools/android/inference_interface/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java b/tensorflow/tools/android/inference_interface/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java new file mode 100644 index 00000000000..abddadac5bc --- /dev/null +++ b/tensorflow/tools/android/inference_interface/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java @@ -0,0 +1,650 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.contrib.android; + +import android.content.res.AssetManager; +import android.os.Build.VERSION; +import android.os.Trace; +import android.text.TextUtils; +import android.util.Log; +import java.io.ByteArrayOutputStream; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.nio.LongBuffer; +import java.util.ArrayList; +import java.util.List; +import org.tensorflow.Graph; +import org.tensorflow.Operation; +import org.tensorflow.Session; +import org.tensorflow.Tensor; +import org.tensorflow.TensorFlow; +import org.tensorflow.Tensors; +import org.tensorflow.types.UInt8; + +/** + * Wrapper over the TensorFlow API ({@link Graph}, {@link Session}) providing a smaller API surface + * for inference. + * + *

See tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java for an + * example usage. + */ +public class TensorFlowInferenceInterface { + private static final String TAG = "TensorFlowInferenceInterface"; + private static final String ASSET_FILE_PREFIX = "file:///android_asset/"; + + /* + * Load a TensorFlow model from the AssetManager or from disk if it is not an asset file. + * + * @param assetManager The AssetManager to use to load the model file. + * @param model The filepath to the GraphDef proto representing the model. + */ + public TensorFlowInferenceInterface(AssetManager assetManager, String model) { + prepareNativeRuntime(); + + this.modelName = model; + this.g = new Graph(); + this.sess = new Session(g); + this.runner = sess.runner(); + + final boolean hasAssetPrefix = model.startsWith(ASSET_FILE_PREFIX); + InputStream is = null; + try { + String aname = hasAssetPrefix ? model.split(ASSET_FILE_PREFIX)[1] : model; + is = assetManager.open(aname); + } catch (IOException e) { + if (hasAssetPrefix) { + throw new RuntimeException("Failed to load model from '" + model + "'", e); + } + // Perhaps the model file is not an asset but is on disk. + try { + is = new FileInputStream(model); + } catch (IOException e2) { + throw new RuntimeException("Failed to load model from '" + model + "'", e); + } + } + + try { + if (VERSION.SDK_INT >= 18) { + Trace.beginSection("initializeTensorFlow"); + Trace.beginSection("readGraphDef"); + } + + // TODO(ashankar): Can we somehow mmap the contents instead of copying them? + byte[] graphDef = new byte[is.available()]; + final int numBytesRead = is.read(graphDef); + if (numBytesRead != graphDef.length) { + throw new IOException( + "read error: read only " + + numBytesRead + + " of the graph, expected to read " + + graphDef.length); + } + + if (VERSION.SDK_INT >= 18) { + Trace.endSection(); // readGraphDef. + } + + loadGraph(graphDef, g); + is.close(); + Log.i(TAG, "Successfully loaded model from '" + model + "'"); + + if (VERSION.SDK_INT >= 18) { + Trace.endSection(); // initializeTensorFlow. + } + } catch (IOException e) { + throw new RuntimeException("Failed to load model from '" + model + "'", e); + } + } + + /* + * Load a TensorFlow model from provided InputStream. + * Note: The InputStream will not be closed after loading model, users need to + * close it themselves. + * + * @param is The InputStream to use to load the model. + */ + public TensorFlowInferenceInterface(InputStream is) { + prepareNativeRuntime(); + + // modelName is redundant for model loading from input stream, here is for + // avoiding error in initialization as modelName is marked final. + this.modelName = ""; + this.g = new Graph(); + this.sess = new Session(g); + this.runner = sess.runner(); + + try { + if (VERSION.SDK_INT >= 18) { + Trace.beginSection("initializeTensorFlow"); + Trace.beginSection("readGraphDef"); + } + + int baosInitSize = is.available() > 16384 ? is.available() : 16384; + ByteArrayOutputStream baos = new ByteArrayOutputStream(baosInitSize); + int numBytesRead; + byte[] buf = new byte[16384]; + while ((numBytesRead = is.read(buf, 0, buf.length)) != -1) { + baos.write(buf, 0, numBytesRead); + } + byte[] graphDef = baos.toByteArray(); + + if (VERSION.SDK_INT >= 18) { + Trace.endSection(); // readGraphDef. + } + + loadGraph(graphDef, g); + Log.i(TAG, "Successfully loaded model from the input stream"); + + if (VERSION.SDK_INT >= 18) { + Trace.endSection(); // initializeTensorFlow. + } + } catch (IOException e) { + throw new RuntimeException("Failed to load model from the input stream", e); + } + } + + /* + * Construct a TensorFlowInferenceInterface with provided Graph + * + * @param g The Graph to use to construct this interface. + */ + public TensorFlowInferenceInterface(Graph g) { + prepareNativeRuntime(); + + // modelName is redundant here, here is for + // avoiding error in initialization as modelName is marked final. + this.modelName = ""; + this.g = g; + this.sess = new Session(g); + this.runner = sess.runner(); + } + + /** + * Runs inference between the previously registered input nodes (via feed*) and the requested + * output nodes. Output nodes can then be queried with the fetch* methods. + * + * @param outputNames A list of output nodes which should be filled by the inference pass. + */ + public void run(String[] outputNames) { + run(outputNames, false); + } + + /** + * Runs inference between the previously registered input nodes (via feed*) and the requested + * output nodes. Output nodes can then be queried with the fetch* methods. + * + * @param outputNames A list of output nodes which should be filled by the inference pass. + */ + public void run(String[] outputNames, boolean enableStats) { + run(outputNames, enableStats, new String[] {}); + } + + /** An overloaded version of runInference that allows supplying targetNodeNames as well */ + public void run(String[] outputNames, boolean enableStats, String[] targetNodeNames) { + // Release any Tensors from the previous run calls. + closeFetches(); + + // Add fetches. + for (String o : outputNames) { + fetchNames.add(o); + TensorId tid = TensorId.parse(o); + runner.fetch(tid.name, tid.outputIndex); + } + + // Add targets. + for (String t : targetNodeNames) { + runner.addTarget(t); + } + + // Run the session. + try { + if (enableStats) { + Session.Run r = runner.setOptions(RunStats.runOptions()).runAndFetchMetadata(); + fetchTensors = r.outputs; + + if (runStats == null) { + runStats = new RunStats(); + } + runStats.add(r.metadata); + } else { + fetchTensors = runner.run(); + } + } catch (RuntimeException e) { + // Ideally the exception would have been let through, but since this interface predates the + // TensorFlow Java API, must return -1. + Log.e( + TAG, + "Failed to run TensorFlow inference with inputs:[" + + TextUtils.join(", ", feedNames) + + "], outputs:[" + + TextUtils.join(", ", fetchNames) + + "]"); + throw e; + } finally { + // Always release the feeds (to save resources) and reset the runner, this run is + // over. + closeFeeds(); + runner = sess.runner(); + } + } + + /** Returns a reference to the Graph describing the computation run during inference. */ + public Graph graph() { + return g; + } + + public Operation graphOperation(String operationName) { + final Operation operation = g.operation(operationName); + if (operation == null) { + throw new RuntimeException( + "Node '" + operationName + "' does not exist in model '" + modelName + "'"); + } + return operation; + } + + /** Returns the last stat summary string if logging is enabled. */ + public String getStatString() { + return (runStats == null) ? "" : runStats.summary(); + } + + /** + * Cleans up the state associated with this Object. + * + *

The TenosrFlowInferenceInterface object is no longer usable after this method returns. + */ + public void close() { + closeFeeds(); + closeFetches(); + sess.close(); + g.close(); + if (runStats != null) { + runStats.close(); + } + runStats = null; + } + + @Override + protected void finalize() throws Throwable { + try { + close(); + } finally { + super.finalize(); + } + } + + // Methods for taking a native Tensor and filling it with values from Java arrays. + + /** + * Given a source array with shape {@link dims} and content {@link src}, copy the contents into + * the input Tensor with name {@link inputName}. The source array {@link src} must have at least + * as many elements as that of the destination Tensor. If {@link src} has more elements than the + * destination has capacity, the copy is truncated. + */ + public void feed(String inputName, boolean[] src, long... dims) { + byte[] b = new byte[src.length]; + + for (int i = 0; i < src.length; i++) { + b[i] = src[i] ? (byte) 1 : (byte) 0; + } + + addFeed(inputName, Tensor.create(Boolean.class, dims, ByteBuffer.wrap(b))); + } + + /** + * Given a source array with shape {@link dims} and content {@link src}, copy the contents into + * the input Tensor with name {@link inputName}. The source array {@link src} must have at least + * as many elements as that of the destination Tensor. If {@link src} has more elements than the + * destination has capacity, the copy is truncated. + */ + public void feed(String inputName, float[] src, long... dims) { + addFeed(inputName, Tensor.create(dims, FloatBuffer.wrap(src))); + } + + /** + * Given a source array with shape {@link dims} and content {@link src}, copy the contents into + * the input Tensor with name {@link inputName}. The source array {@link src} must have at least + * as many elements as that of the destination Tensor. If {@link src} has more elements than the + * destination has capacity, the copy is truncated. + */ + public void feed(String inputName, int[] src, long... dims) { + addFeed(inputName, Tensor.create(dims, IntBuffer.wrap(src))); + } + + /** + * Given a source array with shape {@link dims} and content {@link src}, copy the contents into + * the input Tensor with name {@link inputName}. The source array {@link src} must have at least + * as many elements as that of the destination Tensor. If {@link src} has more elements than the + * destination has capacity, the copy is truncated. + */ + public void feed(String inputName, long[] src, long... dims) { + addFeed(inputName, Tensor.create(dims, LongBuffer.wrap(src))); + } + + /** + * Given a source array with shape {@link dims} and content {@link src}, copy the contents into + * the input Tensor with name {@link inputName}. The source array {@link src} must have at least + * as many elements as that of the destination Tensor. If {@link src} has more elements than the + * destination has capacity, the copy is truncated. + */ + public void feed(String inputName, double[] src, long... dims) { + addFeed(inputName, Tensor.create(dims, DoubleBuffer.wrap(src))); + } + + /** + * Given a source array with shape {@link dims} and content {@link src}, copy the contents into + * the input Tensor with name {@link inputName}. The source array {@link src} must have at least + * as many elements as that of the destination Tensor. If {@link src} has more elements than the + * destination has capacity, the copy is truncated. + */ + public void feed(String inputName, byte[] src, long... dims) { + addFeed(inputName, Tensor.create(UInt8.class, dims, ByteBuffer.wrap(src))); + } + + /** + * Copy a byte sequence into the input Tensor with name {@link inputName} as a string-valued + * scalar tensor. In the TensorFlow type system, a "string" is an arbitrary sequence of bytes, not + * a Java {@code String} (which is a sequence of characters). + */ + public void feedString(String inputName, byte[] src) { + addFeed(inputName, Tensors.create(src)); + } + + /** + * Copy an array of byte sequences into the input Tensor with name {@link inputName} as a + * string-valued one-dimensional tensor (vector). In the TensorFlow type system, a "string" is an + * arbitrary sequence of bytes, not a Java {@code String} (which is a sequence of characters). + */ + public void feedString(String inputName, byte[][] src) { + addFeed(inputName, Tensors.create(src)); + } + + // Methods for taking a native Tensor and filling it with src from Java native IO buffers. + + /** + * Given a source buffer with shape {@link dims} and content {@link src}, both stored as + * direct and native ordered java.nio buffers, copy the contents into the input + * Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many + * elements as that of the destination Tensor. If {@link src} has more elements than the + * destination has capacity, the copy is truncated. + */ + public void feed(String inputName, FloatBuffer src, long... dims) { + addFeed(inputName, Tensor.create(dims, src)); + } + + /** + * Given a source buffer with shape {@link dims} and content {@link src}, both stored as + * direct and native ordered java.nio buffers, copy the contents into the input + * Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many + * elements as that of the destination Tensor. If {@link src} has more elements than the + * destination has capacity, the copy is truncated. + */ + public void feed(String inputName, IntBuffer src, long... dims) { + addFeed(inputName, Tensor.create(dims, src)); + } + + /** + * Given a source buffer with shape {@link dims} and content {@link src}, both stored as + * direct and native ordered java.nio buffers, copy the contents into the input + * Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many + * elements as that of the destination Tensor. If {@link src} has more elements than the + * destination has capacity, the copy is truncated. + */ + public void feed(String inputName, LongBuffer src, long... dims) { + addFeed(inputName, Tensor.create(dims, src)); + } + + /** + * Given a source buffer with shape {@link dims} and content {@link src}, both stored as + * direct and native ordered java.nio buffers, copy the contents into the input + * Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many + * elements as that of the destination Tensor. If {@link src} has more elements than the + * destination has capacity, the copy is truncated. + */ + public void feed(String inputName, DoubleBuffer src, long... dims) { + addFeed(inputName, Tensor.create(dims, src)); + } + + /** + * Given a source buffer with shape {@link dims} and content {@link src}, both stored as + * direct and native ordered java.nio buffers, copy the contents into the input + * Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many + * elements as that of the destination Tensor. If {@link src} has more elements than the + * destination has capacity, the copy is truncated. + */ + public void feed(String inputName, ByteBuffer src, long... dims) { + addFeed(inputName, Tensor.create(UInt8.class, dims, src)); + } + + /** + * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link + * dst} must have length greater than or equal to that of the source Tensor. This operation will + * not affect dst's content past the source Tensor's size. + */ + public void fetch(String outputName, float[] dst) { + fetch(outputName, FloatBuffer.wrap(dst)); + } + + /** + * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link + * dst} must have length greater than or equal to that of the source Tensor. This operation will + * not affect dst's content past the source Tensor's size. + */ + public void fetch(String outputName, int[] dst) { + fetch(outputName, IntBuffer.wrap(dst)); + } + + /** + * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link + * dst} must have length greater than or equal to that of the source Tensor. This operation will + * not affect dst's content past the source Tensor's size. + */ + public void fetch(String outputName, long[] dst) { + fetch(outputName, LongBuffer.wrap(dst)); + } + + /** + * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link + * dst} must have length greater than or equal to that of the source Tensor. This operation will + * not affect dst's content past the source Tensor's size. + */ + public void fetch(String outputName, double[] dst) { + fetch(outputName, DoubleBuffer.wrap(dst)); + } + + /** + * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link + * dst} must have length greater than or equal to that of the source Tensor. This operation will + * not affect dst's content past the source Tensor's size. + */ + public void fetch(String outputName, byte[] dst) { + fetch(outputName, ByteBuffer.wrap(dst)); + } + + /** + * Read from a Tensor named {@link outputName} and copy the contents into the direct and + * native ordered java.nio buffer {@link dst}. {@link dst} must have capacity greater than + * or equal to that of the source Tensor. This operation will not affect dst's content past the + * source Tensor's size. + */ + public void fetch(String outputName, FloatBuffer dst) { + getTensor(outputName).writeTo(dst); + } + + /** + * Read from a Tensor named {@link outputName} and copy the contents into the direct and + * native ordered java.nio buffer {@link dst}. {@link dst} must have capacity greater than + * or equal to that of the source Tensor. This operation will not affect dst's content past the + * source Tensor's size. + */ + public void fetch(String outputName, IntBuffer dst) { + getTensor(outputName).writeTo(dst); + } + + /** + * Read from a Tensor named {@link outputName} and copy the contents into the direct and + * native ordered java.nio buffer {@link dst}. {@link dst} must have capacity greater than + * or equal to that of the source Tensor. This operation will not affect dst's content past the + * source Tensor's size. + */ + public void fetch(String outputName, LongBuffer dst) { + getTensor(outputName).writeTo(dst); + } + + /** + * Read from a Tensor named {@link outputName} and copy the contents into the direct and + * native ordered java.nio buffer {@link dst}. {@link dst} must have capacity greater than + * or equal to that of the source Tensor. This operation will not affect dst's content past the + * source Tensor's size. + */ + public void fetch(String outputName, DoubleBuffer dst) { + getTensor(outputName).writeTo(dst); + } + + /** + * Read from a Tensor named {@link outputName} and copy the contents into the direct and + * native ordered java.nio buffer {@link dst}. {@link dst} must have capacity greater than + * or equal to that of the source Tensor. This operation will not affect dst's content past the + * source Tensor's size. + */ + public void fetch(String outputName, ByteBuffer dst) { + getTensor(outputName).writeTo(dst); + } + + private void prepareNativeRuntime() { + Log.i(TAG, "Checking to see if TensorFlow native methods are already loaded"); + try { + // Hack to see if the native libraries have been loaded. + new RunStats(); + Log.i(TAG, "TensorFlow native methods already loaded"); + } catch (UnsatisfiedLinkError e1) { + Log.i( + TAG, "TensorFlow native methods not found, attempting to load via tensorflow_inference"); + try { + System.loadLibrary("tensorflow_inference"); + Log.i(TAG, "Successfully loaded TensorFlow native methods (RunStats error may be ignored)"); + } catch (UnsatisfiedLinkError e2) { + throw new RuntimeException( + "Native TF methods not found; check that the correct native" + + " libraries are present in the APK."); + } + } + } + + private void loadGraph(byte[] graphDef, Graph g) throws IOException { + final long startMs = System.currentTimeMillis(); + + if (VERSION.SDK_INT >= 18) { + Trace.beginSection("importGraphDef"); + } + + try { + g.importGraphDef(graphDef); + } catch (IllegalArgumentException e) { + throw new IOException("Not a valid TensorFlow Graph serialization: " + e.getMessage()); + } + + if (VERSION.SDK_INT >= 18) { + Trace.endSection(); // importGraphDef. + } + + final long endMs = System.currentTimeMillis(); + Log.i( + TAG, + "Model load took " + (endMs - startMs) + "ms, TensorFlow version: " + TensorFlow.version()); + } + + private void addFeed(String inputName, Tensor t) { + // The string format accepted by TensorFlowInferenceInterface is node_name[:output_index]. + TensorId tid = TensorId.parse(inputName); + runner.feed(tid.name, tid.outputIndex, t); + feedNames.add(inputName); + feedTensors.add(t); + } + + private static class TensorId { + String name; + int outputIndex; + + // Parse output names into a TensorId. + // + // E.g., "foo" --> ("foo", 0), while "foo:1" --> ("foo", 1) + public static TensorId parse(String name) { + TensorId tid = new TensorId(); + int colonIndex = name.lastIndexOf(':'); + if (colonIndex < 0) { + tid.outputIndex = 0; + tid.name = name; + return tid; + } + try { + tid.outputIndex = Integer.parseInt(name.substring(colonIndex + 1)); + tid.name = name.substring(0, colonIndex); + } catch (NumberFormatException e) { + tid.outputIndex = 0; + tid.name = name; + } + return tid; + } + } + + private Tensor getTensor(String outputName) { + int i = 0; + for (String n : fetchNames) { + if (n.equals(outputName)) { + return fetchTensors.get(i); + } + ++i; + } + throw new RuntimeException( + "Node '" + outputName + "' was not provided to run(), so it cannot be read"); + } + + private void closeFeeds() { + for (Tensor t : feedTensors) { + t.close(); + } + feedTensors.clear(); + feedNames.clear(); + } + + private void closeFetches() { + for (Tensor t : fetchTensors) { + t.close(); + } + fetchTensors.clear(); + fetchNames.clear(); + } + + // Immutable state. + private final String modelName; + private final Graph g; + private final Session sess; + + // State reset on every call to run. + private Session.Runner runner; + private List feedNames = new ArrayList(); + private List> feedTensors = new ArrayList>(); + private List fetchNames = new ArrayList(); + private List> fetchTensors = new ArrayList>(); + + // Mutable state. + private RunStats runStats; +} diff --git a/tensorflow/tools/android/inference_interface/jni/run_stats_jni.cc b/tensorflow/tools/android/inference_interface/jni/run_stats_jni.cc new file mode 100644 index 00000000000..01b6196cc01 --- /dev/null +++ b/tensorflow/tools/android/inference_interface/jni/run_stats_jni.cc @@ -0,0 +1,83 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/tools/android/inference_interface/jni/run_stats_jni.h" + +#include + +#include + +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/util/stat_summarizer.h" + +using tensorflow::RunMetadata; +using tensorflow::StatSummarizer; + +namespace { +StatSummarizer* requireHandle(JNIEnv* env, jlong handle) { + if (handle == 0) { + env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), + "close() has been called on the RunStats object"); + return nullptr; + } + return reinterpret_cast(handle); +} +} // namespace + +#define RUN_STATS_METHOD(name) \ + JNICALL Java_org_tensorflow_contrib_android_RunStats_##name + +JNIEXPORT jlong RUN_STATS_METHOD(allocate)(JNIEnv* env, jclass clazz) { + static_assert(sizeof(jlong) >= sizeof(StatSummarizer*), + "Cannot package C++ object pointers as a Java long"); + tensorflow::StatSummarizerOptions opts; + return reinterpret_cast(new StatSummarizer(opts)); +} + +JNIEXPORT void RUN_STATS_METHOD(delete)(JNIEnv* env, jclass clazz, + jlong handle) { + if (handle == 0) return; + delete reinterpret_cast(handle); +} + +JNIEXPORT void RUN_STATS_METHOD(add)(JNIEnv* env, jclass clazz, jlong handle, + jbyteArray run_metadata) { + StatSummarizer* s = requireHandle(env, handle); + if (s == nullptr) return; + jbyte* data = env->GetByteArrayElements(run_metadata, nullptr); + int size = static_cast(env->GetArrayLength(run_metadata)); + tensorflow::RunMetadata proto; + if (!proto.ParseFromArray(data, size)) { + env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), + "runMetadata does not seem to be a serialized RunMetadata " + "protocol message"); + } else if (proto.has_step_stats()) { + s->ProcessStepStats(proto.step_stats()); + } + env->ReleaseByteArrayElements(run_metadata, data, JNI_ABORT); +} + +JNIEXPORT jstring RUN_STATS_METHOD(summary)(JNIEnv* env, jclass clazz, + jlong handle) { + StatSummarizer* s = requireHandle(env, handle); + if (s == nullptr) return nullptr; + std::stringstream ret; + ret << s->GetStatsByMetric("Top 10 CPU", tensorflow::StatsCalculator::BY_TIME, + 10) + << s->GetStatsByNodeType() << s->ShortSummary(); + return env->NewStringUTF(ret.str().c_str()); +} + +#undef RUN_STATS_METHOD diff --git a/tensorflow/tools/android/inference_interface/jni/run_stats_jni.h b/tensorflow/tools/android/inference_interface/jni/run_stats_jni.h new file mode 100644 index 00000000000..de3bceff0a1 --- /dev/null +++ b/tensorflow/tools/android/inference_interface/jni/run_stats_jni.h @@ -0,0 +1,40 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef ORG_TENSORFLOW_JNI_RUN_STATS_JNI_H_ +#define ORG_TENSORFLOW_JNI_RUN_STATS_JNI_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +#define RUN_STATS_METHOD(name) \ + Java_org_tensorflow_contrib_android_RunStats_##name + +JNIEXPORT JNICALL jlong RUN_STATS_METHOD(allocate)(JNIEnv*, jclass); +JNIEXPORT JNICALL void RUN_STATS_METHOD(delete)(JNIEnv*, jclass, jlong); +JNIEXPORT JNICALL void RUN_STATS_METHOD(add)(JNIEnv*, jclass, jlong, + jbyteArray); +JNIEXPORT JNICALL jstring RUN_STATS_METHOD(summary)(JNIEnv*, jclass, jlong); + +#undef RUN_STATS_METHOD + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // ORG_TENSORFLOW_JNI_RUN_STATS_JNI_H_ diff --git a/tensorflow/tools/android/inference_interface/jni/version_script.lds b/tensorflow/tools/android/inference_interface/jni/version_script.lds new file mode 100644 index 00000000000..38c93dda730 --- /dev/null +++ b/tensorflow/tools/android/inference_interface/jni/version_script.lds @@ -0,0 +1,11 @@ +VERS_1.0 { + # Export JNI symbols. + global: + Java_*; + JNI_OnLoad; + JNI_OnUnload; + + # Hide everything else. + local: + *; +};