Change the TF Lite Java API to use the shim layer (attempt #2).

PiperOrigin-RevId: 347061214
Change-Id: Iaf3e9cb423cc6b495a83e3c17b041199d14063e3
This commit is contained in:
Fergus Henderson 2020-12-11 13:15:36 -08:00 committed by TensorFlower Gardener
parent df3f233362
commit 627a032dc2
6 changed files with 56 additions and 51 deletions

View File

@ -719,7 +719,7 @@ cc_library(
deps = [
"//tensorflow/lite:op_resolver",
"//tensorflow/lite/core/api",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/core/shims:builtin_ops",
],
)

View File

@ -112,6 +112,10 @@ cc_library(
"//tensorflow/lite/kernels:fully_connected.h",
],
compatible_with = get_compatible_with_portable(),
visibility = [
"//tensorflow/lite:__subpackages__",
"//tensorflow_lite_support:__subpackages__",
],
deps = [
"//tensorflow/lite:cc_api",
"//tensorflow/lite/c:common",

View File

@ -15,8 +15,8 @@ limitations under the License.
#include <memory>
#include "tensorflow/lite/core/shims/cc/kernels/register.h"
#include "tensorflow/lite/create_op_resolver.h"
#include "tensorflow/lite/kernels/register.h"
namespace tflite {

View File

@ -26,11 +26,12 @@ cc_library(
"-ldl",
],
deps = [
"//tensorflow/lite:framework",
"//tensorflow/lite:op_resolver",
"//tensorflow/lite:schema_fbs_version",
"//tensorflow/lite:string_util",
"//tensorflow/lite:util",
"//tensorflow/lite/c:common",
"//tensorflow/lite/core/shims:common",
"//tensorflow/lite/core/shims:framework",
"//tensorflow/lite/delegates/xnnpack:xnnpack_delegate_hdrs_only",
"//tensorflow/lite/java/jni",
],
@ -45,7 +46,9 @@ cc_library(
deps = [
":native_framework_only",
"//tensorflow/lite:create_op_resolver_with_builtin_ops",
"//tensorflow/lite:framework",
"//tensorflow/lite/core/api",
"//tensorflow/lite/core/shims:builtin_ops",
"//tensorflow/lite/core/shims:framework",
],
alwayslink = 1,
)

View File

@ -21,12 +21,12 @@ limitations under the License.
#include <atomic>
#include <vector>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/shims/c/common.h"
#include "tensorflow/lite/core/shims/cc/interpreter.h"
#include "tensorflow/lite/core/shims/cc/interpreter_builder.h"
#include "tensorflow/lite/core/shims/cc/model_builder.h"
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h"
#include "tensorflow/lite/java/src/main/native/jni_utils.h"
#include "tensorflow/lite/model_builder.h"
#include "tensorflow/lite/util.h"
namespace tflite {
@ -36,25 +36,28 @@ extern std::unique_ptr<OpResolver> CreateOpResolver();
using tflite::jni::BufferErrorReporter;
using tflite::jni::ThrowException;
using tflite_shims::FlatBufferModel;
using tflite_shims::Interpreter;
using tflite_shims::InterpreterBuilder;
namespace {
tflite::Interpreter* convertLongToInterpreter(JNIEnv* env, jlong handle) {
Interpreter* convertLongToInterpreter(JNIEnv* env, jlong handle) {
if (handle == 0) {
ThrowException(env, tflite::jni::kIllegalArgumentException,
"Internal error: Invalid handle to Interpreter.");
return nullptr;
}
return reinterpret_cast<tflite::Interpreter*>(handle);
return reinterpret_cast<Interpreter*>(handle);
}
tflite::FlatBufferModel* convertLongToModel(JNIEnv* env, jlong handle) {
FlatBufferModel* convertLongToModel(JNIEnv* env, jlong handle) {
if (handle == 0) {
ThrowException(env, tflite::jni::kIllegalArgumentException,
"Internal error: Invalid handle to model.");
return nullptr;
}
return reinterpret_cast<tflite::FlatBufferModel*>(handle);
return reinterpret_cast<FlatBufferModel*>(handle);
}
BufferErrorReporter* convertLongToErrorReporter(JNIEnv* env, jlong handle) {
@ -163,7 +166,7 @@ JNIEXPORT jobjectArray JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env,
jclass clazz,
jlong handle) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return nullptr;
jclass string_class = env->FindClass("java/lang/String");
if (string_class == nullptr) {
@ -185,7 +188,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env,
JNIEXPORT void JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors(
JNIEnv* env, jclass clazz, jlong handle, jlong error_handle) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return;
BufferErrorReporter* error_reporter =
convertLongToErrorReporter(env, error_handle);
@ -203,7 +206,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors(
JNIEXPORT jboolean JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_hasUnresolvedFlexOp(
JNIEnv* env, jclass clazz, jlong handle) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return JNI_FALSE;
// TODO(b/132995737): Remove this logic by caching whether an unresolved
@ -226,7 +229,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_hasUnresolvedFlexOp(
JNIEXPORT jint JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensorIndex(
JNIEnv* env, jclass clazz, jlong handle, jint input_index) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return 0;
return interpreter->inputs()[input_index];
}
@ -234,7 +237,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensorIndex(
JNIEXPORT jint JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndex(
JNIEnv* env, jclass clazz, jlong handle, jint output_index) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return 0;
return interpreter->outputs()[output_index];
}
@ -242,7 +245,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndex(
JNIEXPORT jint JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getExecutionPlanLength(
JNIEnv* env, jclass clazz, jlong handle) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return 0;
return static_cast<jint>(interpreter->execution_plan().size());
}
@ -251,7 +254,7 @@ JNIEXPORT jint JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputCount(JNIEnv* env,
jclass clazz,
jlong handle) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return 0;
return static_cast<jint>(interpreter->inputs().size());
}
@ -260,7 +263,7 @@ JNIEXPORT jint JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputCount(JNIEnv* env,
jclass clazz,
jlong handle) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return 0;
return static_cast<jint>(interpreter->outputs().size());
}
@ -269,7 +272,7 @@ JNIEXPORT jobjectArray JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env,
jclass clazz,
jlong handle) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return nullptr;
jclass string_class = env->FindClass("java/lang/String");
if (string_class == nullptr) {
@ -291,7 +294,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env,
JNIEXPORT void JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_allowFp16PrecisionForFp32(
JNIEnv* env, jclass clazz, jlong handle, jboolean allow) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return;
interpreter->SetAllowFp16PrecisionForFp32(static_cast<bool>(allow));
}
@ -299,7 +302,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_allowFp16PrecisionForFp32(
JNIEXPORT void JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_allowBufferHandleOutput(
JNIEnv* env, jclass clazz, jlong handle, jboolean allow) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return;
interpreter->SetAllowBufferHandleOutput(allow);
}
@ -313,7 +316,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK(
return;
}
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) {
return;
}
@ -341,8 +344,8 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK(
if (num_threads > 0) {
options.num_threads = num_threads;
}
tflite::Interpreter::TfLiteDelegatePtr delegate(xnnpack_create(&options),
xnnpack_delete);
Interpreter::TfLiteDelegatePtr delegate(xnnpack_create(&options),
xnnpack_delete);
auto delegation_status =
interpreter->ModifyGraphWithDelegate(std::move(delegate));
// kTfLiteApplicationError occurs in cases where delegation fails but
@ -374,7 +377,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
jclass clazz,
jlong handle,
jint num_threads) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return;
interpreter->SetNumThreads(static_cast<int>(num_threads));
}
@ -411,8 +414,8 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel(
std::unique_ptr<tflite::TfLiteVerifier> verifier;
verifier.reset(new JNIFlatBufferVerifier());
auto model = tflite::FlatBufferModel::VerifyAndBuildFromFile(
path, verifier.get(), error_reporter);
auto model = FlatBufferModel::VerifyAndBuildFromFile(path, verifier.get(),
error_reporter);
if (!model) {
ThrowException(env, tflite::jni::kIllegalArgumentException,
"Contents of %s does not encode a valid "
@ -440,7 +443,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer(
return 0;
}
auto model = tflite::FlatBufferModel::BuildFromBuffer(
auto model = FlatBufferModel::BuildFromBuffer(
buf, static_cast<size_t>(capacity), error_reporter);
if (!model) {
ThrowException(env, tflite::jni::kIllegalArgumentException,
@ -455,14 +458,14 @@ JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle,
jint num_threads) {
tflite::FlatBufferModel* model = convertLongToModel(env, model_handle);
FlatBufferModel* model = convertLongToModel(env, model_handle);
if (model == nullptr) return 0;
BufferErrorReporter* error_reporter =
convertLongToErrorReporter(env, error_handle);
if (error_reporter == nullptr) return 0;
auto resolver = ::tflite::CreateOpResolver();
std::unique_ptr<tflite::Interpreter> interpreter;
TfLiteStatus status = tflite::InterpreterBuilder(*model, *(resolver.get()))(
std::unique_ptr<Interpreter> interpreter;
TfLiteStatus status = InterpreterBuilder(*model, *(resolver.get()))(
&interpreter, static_cast<int>(num_threads));
if (status != kTfLiteOk) {
ThrowException(env, tflite::jni::kIllegalArgumentException,
@ -478,8 +481,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
// Sets inputs, runs inference, and returns outputs as long handles.
JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle) {
tflite::Interpreter* interpreter =
convertLongToInterpreter(env, interpreter_handle);
Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle);
if (interpreter == nullptr) return;
BufferErrorReporter* error_reporter =
convertLongToErrorReporter(env, error_handle);
@ -497,7 +499,7 @@ JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
JNIEXPORT jint JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputDataType(
JNIEnv* env, jclass clazz, jlong handle, jint output_idx) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return -1;
const int idx = static_cast<int>(output_idx);
if (output_idx < 0 || output_idx >= interpreter->outputs().size()) {
@ -518,8 +520,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput(
BufferErrorReporter* error_reporter =
convertLongToErrorReporter(env, error_handle);
if (error_reporter == nullptr) return JNI_FALSE;
tflite::Interpreter* interpreter =
convertLongToInterpreter(env, interpreter_handle);
Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle);
if (interpreter == nullptr) return JNI_FALSE;
if (input_idx < 0 || input_idx >= interpreter->inputs().size()) {
ThrowException(env, tflite::jni::kIllegalArgumentException,
@ -555,8 +556,7 @@ JNIEXPORT void JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_applyDelegate(
JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
jlong delegate_handle) {
tflite::Interpreter* interpreter =
convertLongToInterpreter(env, interpreter_handle);
Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle);
if (interpreter == nullptr) return;
BufferErrorReporter* error_reporter =
@ -577,8 +577,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_applyDelegate(
JNIEXPORT void JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_resetVariableTensors(
JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle) {
tflite::Interpreter* interpreter =
convertLongToInterpreter(env, interpreter_handle);
Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle);
if (interpreter == nullptr) return;
BufferErrorReporter* error_reporter =
@ -596,8 +595,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_resetVariableTensors(
JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_createCancellationFlag(
JNIEnv* env, jclass clazz, jlong interpreter_handle) {
tflite::Interpreter* interpreter =
convertLongToInterpreter(env, interpreter_handle);
Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle);
if (interpreter == nullptr) {
ThrowException(env, tflite::jni::kIllegalArgumentException,
"Internal error: Invalid handle to interpreter.");

View File

@ -19,12 +19,13 @@ limitations under the License.
#include <memory>
#include <string>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/core/shims/c/common.h"
#include "tensorflow/lite/core/shims/cc/interpreter.h"
#include "tensorflow/lite/java/src/main/native/jni_utils.h"
#include "tensorflow/lite/string_util.h"
using tflite::jni::ThrowException;
using tflite_shims::Interpreter;
namespace {
@ -39,14 +40,14 @@ static const char* kStringClassPath = "java/lang/String";
// invalidate all TfLiteTensor* handles during inference or allocation.
class TensorHandle {
public:
TensorHandle(tflite::Interpreter* interpreter, int tensor_index)
TensorHandle(Interpreter* interpreter, int tensor_index)
: interpreter_(interpreter), tensor_index_(tensor_index) {}
TfLiteTensor* tensor() const { return interpreter_->tensor(tensor_index_); }
int index() const { return tensor_index_; }
private:
tflite::Interpreter* const interpreter_;
Interpreter* const interpreter_;
const int tensor_index_;
};
@ -396,8 +397,7 @@ extern "C" {
JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_Tensor_create(
JNIEnv* env, jclass clazz, jlong interpreter_handle, jint tensor_index) {
tflite::Interpreter* interpreter =
reinterpret_cast<tflite::Interpreter*>(interpreter_handle);
Interpreter* interpreter = reinterpret_cast<Interpreter*>(interpreter_handle);
return reinterpret_cast<jlong>(new TensorHandle(interpreter, tensor_index));
}