Change the TF Lite Java API to use the shim layer (attempt #2).
PiperOrigin-RevId: 347061214 Change-Id: Iaf3e9cb423cc6b495a83e3c17b041199d14063e3
This commit is contained in:
parent
df3f233362
commit
627a032dc2
@ -719,7 +719,7 @@ cc_library(
|
||||
deps = [
|
||||
"//tensorflow/lite:op_resolver",
|
||||
"//tensorflow/lite/core/api",
|
||||
"//tensorflow/lite/kernels:builtin_ops",
|
||||
"//tensorflow/lite/core/shims:builtin_ops",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -112,6 +112,10 @@ cc_library(
|
||||
"//tensorflow/lite/kernels:fully_connected.h",
|
||||
],
|
||||
compatible_with = get_compatible_with_portable(),
|
||||
visibility = [
|
||||
"//tensorflow/lite:__subpackages__",
|
||||
"//tensorflow_lite_support:__subpackages__",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/lite:cc_api",
|
||||
"//tensorflow/lite/c:common",
|
||||
|
@ -15,8 +15,8 @@ limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/lite/core/shims/cc/kernels/register.h"
|
||||
#include "tensorflow/lite/create_op_resolver.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
|
@ -26,11 +26,12 @@ cc_library(
|
||||
"-ldl",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite:op_resolver",
|
||||
"//tensorflow/lite:schema_fbs_version",
|
||||
"//tensorflow/lite:string_util",
|
||||
"//tensorflow/lite:util",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/core/shims:common",
|
||||
"//tensorflow/lite/core/shims:framework",
|
||||
"//tensorflow/lite/delegates/xnnpack:xnnpack_delegate_hdrs_only",
|
||||
"//tensorflow/lite/java/jni",
|
||||
],
|
||||
@ -45,7 +46,9 @@ cc_library(
|
||||
deps = [
|
||||
":native_framework_only",
|
||||
"//tensorflow/lite:create_op_resolver_with_builtin_ops",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/core/api",
|
||||
"//tensorflow/lite/core/shims:builtin_ops",
|
||||
"//tensorflow/lite/core/shims:framework",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -21,12 +21,12 @@ limitations under the License.
|
||||
#include <atomic>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/core/shims/c/common.h"
|
||||
#include "tensorflow/lite/core/shims/cc/interpreter.h"
|
||||
#include "tensorflow/lite/core/shims/cc/interpreter_builder.h"
|
||||
#include "tensorflow/lite/core/shims/cc/model_builder.h"
|
||||
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/interpreter_builder.h"
|
||||
#include "tensorflow/lite/java/src/main/native/jni_utils.h"
|
||||
#include "tensorflow/lite/model_builder.h"
|
||||
#include "tensorflow/lite/util.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -36,25 +36,28 @@ extern std::unique_ptr<OpResolver> CreateOpResolver();
|
||||
|
||||
using tflite::jni::BufferErrorReporter;
|
||||
using tflite::jni::ThrowException;
|
||||
using tflite_shims::FlatBufferModel;
|
||||
using tflite_shims::Interpreter;
|
||||
using tflite_shims::InterpreterBuilder;
|
||||
|
||||
namespace {
|
||||
|
||||
tflite::Interpreter* convertLongToInterpreter(JNIEnv* env, jlong handle) {
|
||||
Interpreter* convertLongToInterpreter(JNIEnv* env, jlong handle) {
|
||||
if (handle == 0) {
|
||||
ThrowException(env, tflite::jni::kIllegalArgumentException,
|
||||
"Internal error: Invalid handle to Interpreter.");
|
||||
return nullptr;
|
||||
}
|
||||
return reinterpret_cast<tflite::Interpreter*>(handle);
|
||||
return reinterpret_cast<Interpreter*>(handle);
|
||||
}
|
||||
|
||||
tflite::FlatBufferModel* convertLongToModel(JNIEnv* env, jlong handle) {
|
||||
FlatBufferModel* convertLongToModel(JNIEnv* env, jlong handle) {
|
||||
if (handle == 0) {
|
||||
ThrowException(env, tflite::jni::kIllegalArgumentException,
|
||||
"Internal error: Invalid handle to model.");
|
||||
return nullptr;
|
||||
}
|
||||
return reinterpret_cast<tflite::FlatBufferModel*>(handle);
|
||||
return reinterpret_cast<FlatBufferModel*>(handle);
|
||||
}
|
||||
|
||||
BufferErrorReporter* convertLongToErrorReporter(JNIEnv* env, jlong handle) {
|
||||
@ -163,7 +166,7 @@ JNIEXPORT jobjectArray JNICALL
|
||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env,
|
||||
jclass clazz,
|
||||
jlong handle) {
|
||||
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
if (interpreter == nullptr) return nullptr;
|
||||
jclass string_class = env->FindClass("java/lang/String");
|
||||
if (string_class == nullptr) {
|
||||
@ -185,7 +188,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env,
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors(
|
||||
JNIEnv* env, jclass clazz, jlong handle, jlong error_handle) {
|
||||
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
if (interpreter == nullptr) return;
|
||||
BufferErrorReporter* error_reporter =
|
||||
convertLongToErrorReporter(env, error_handle);
|
||||
@ -203,7 +206,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors(
|
||||
JNIEXPORT jboolean JNICALL
|
||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_hasUnresolvedFlexOp(
|
||||
JNIEnv* env, jclass clazz, jlong handle) {
|
||||
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
if (interpreter == nullptr) return JNI_FALSE;
|
||||
|
||||
// TODO(b/132995737): Remove this logic by caching whether an unresolved
|
||||
@ -226,7 +229,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_hasUnresolvedFlexOp(
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensorIndex(
|
||||
JNIEnv* env, jclass clazz, jlong handle, jint input_index) {
|
||||
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
if (interpreter == nullptr) return 0;
|
||||
return interpreter->inputs()[input_index];
|
||||
}
|
||||
@ -234,7 +237,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensorIndex(
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndex(
|
||||
JNIEnv* env, jclass clazz, jlong handle, jint output_index) {
|
||||
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
if (interpreter == nullptr) return 0;
|
||||
return interpreter->outputs()[output_index];
|
||||
}
|
||||
@ -242,7 +245,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndex(
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_getExecutionPlanLength(
|
||||
JNIEnv* env, jclass clazz, jlong handle) {
|
||||
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
if (interpreter == nullptr) return 0;
|
||||
return static_cast<jint>(interpreter->execution_plan().size());
|
||||
}
|
||||
@ -251,7 +254,7 @@ JNIEXPORT jint JNICALL
|
||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputCount(JNIEnv* env,
|
||||
jclass clazz,
|
||||
jlong handle) {
|
||||
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
if (interpreter == nullptr) return 0;
|
||||
return static_cast<jint>(interpreter->inputs().size());
|
||||
}
|
||||
@ -260,7 +263,7 @@ JNIEXPORT jint JNICALL
|
||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputCount(JNIEnv* env,
|
||||
jclass clazz,
|
||||
jlong handle) {
|
||||
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
if (interpreter == nullptr) return 0;
|
||||
return static_cast<jint>(interpreter->outputs().size());
|
||||
}
|
||||
@ -269,7 +272,7 @@ JNIEXPORT jobjectArray JNICALL
|
||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env,
|
||||
jclass clazz,
|
||||
jlong handle) {
|
||||
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
if (interpreter == nullptr) return nullptr;
|
||||
jclass string_class = env->FindClass("java/lang/String");
|
||||
if (string_class == nullptr) {
|
||||
@ -291,7 +294,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env,
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_allowFp16PrecisionForFp32(
|
||||
JNIEnv* env, jclass clazz, jlong handle, jboolean allow) {
|
||||
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
if (interpreter == nullptr) return;
|
||||
interpreter->SetAllowFp16PrecisionForFp32(static_cast<bool>(allow));
|
||||
}
|
||||
@ -299,7 +302,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_allowFp16PrecisionForFp32(
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_allowBufferHandleOutput(
|
||||
JNIEnv* env, jclass clazz, jlong handle, jboolean allow) {
|
||||
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
if (interpreter == nullptr) return;
|
||||
interpreter->SetAllowBufferHandleOutput(allow);
|
||||
}
|
||||
@ -313,7 +316,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK(
|
||||
return;
|
||||
}
|
||||
|
||||
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
if (interpreter == nullptr) {
|
||||
return;
|
||||
}
|
||||
@ -341,8 +344,8 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK(
|
||||
if (num_threads > 0) {
|
||||
options.num_threads = num_threads;
|
||||
}
|
||||
tflite::Interpreter::TfLiteDelegatePtr delegate(xnnpack_create(&options),
|
||||
xnnpack_delete);
|
||||
Interpreter::TfLiteDelegatePtr delegate(xnnpack_create(&options),
|
||||
xnnpack_delete);
|
||||
auto delegation_status =
|
||||
interpreter->ModifyGraphWithDelegate(std::move(delegate));
|
||||
// kTfLiteApplicationError occurs in cases where delegation fails but
|
||||
@ -374,7 +377,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
|
||||
jclass clazz,
|
||||
jlong handle,
|
||||
jint num_threads) {
|
||||
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
if (interpreter == nullptr) return;
|
||||
interpreter->SetNumThreads(static_cast<int>(num_threads));
|
||||
}
|
||||
@ -411,8 +414,8 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel(
|
||||
std::unique_ptr<tflite::TfLiteVerifier> verifier;
|
||||
verifier.reset(new JNIFlatBufferVerifier());
|
||||
|
||||
auto model = tflite::FlatBufferModel::VerifyAndBuildFromFile(
|
||||
path, verifier.get(), error_reporter);
|
||||
auto model = FlatBufferModel::VerifyAndBuildFromFile(path, verifier.get(),
|
||||
error_reporter);
|
||||
if (!model) {
|
||||
ThrowException(env, tflite::jni::kIllegalArgumentException,
|
||||
"Contents of %s does not encode a valid "
|
||||
@ -440,7 +443,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer(
|
||||
return 0;
|
||||
}
|
||||
|
||||
auto model = tflite::FlatBufferModel::BuildFromBuffer(
|
||||
auto model = FlatBufferModel::BuildFromBuffer(
|
||||
buf, static_cast<size_t>(capacity), error_reporter);
|
||||
if (!model) {
|
||||
ThrowException(env, tflite::jni::kIllegalArgumentException,
|
||||
@ -455,14 +458,14 @@ JNIEXPORT jlong JNICALL
|
||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
|
||||
JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle,
|
||||
jint num_threads) {
|
||||
tflite::FlatBufferModel* model = convertLongToModel(env, model_handle);
|
||||
FlatBufferModel* model = convertLongToModel(env, model_handle);
|
||||
if (model == nullptr) return 0;
|
||||
BufferErrorReporter* error_reporter =
|
||||
convertLongToErrorReporter(env, error_handle);
|
||||
if (error_reporter == nullptr) return 0;
|
||||
auto resolver = ::tflite::CreateOpResolver();
|
||||
std::unique_ptr<tflite::Interpreter> interpreter;
|
||||
TfLiteStatus status = tflite::InterpreterBuilder(*model, *(resolver.get()))(
|
||||
std::unique_ptr<Interpreter> interpreter;
|
||||
TfLiteStatus status = InterpreterBuilder(*model, *(resolver.get()))(
|
||||
&interpreter, static_cast<int>(num_threads));
|
||||
if (status != kTfLiteOk) {
|
||||
ThrowException(env, tflite::jni::kIllegalArgumentException,
|
||||
@ -478,8 +481,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
|
||||
// Sets inputs, runs inference, and returns outputs as long handles.
|
||||
JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
|
||||
JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle) {
|
||||
tflite::Interpreter* interpreter =
|
||||
convertLongToInterpreter(env, interpreter_handle);
|
||||
Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle);
|
||||
if (interpreter == nullptr) return;
|
||||
BufferErrorReporter* error_reporter =
|
||||
convertLongToErrorReporter(env, error_handle);
|
||||
@ -497,7 +499,7 @@ JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputDataType(
|
||||
JNIEnv* env, jclass clazz, jlong handle, jint output_idx) {
|
||||
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
if (interpreter == nullptr) return -1;
|
||||
const int idx = static_cast<int>(output_idx);
|
||||
if (output_idx < 0 || output_idx >= interpreter->outputs().size()) {
|
||||
@ -518,8 +520,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput(
|
||||
BufferErrorReporter* error_reporter =
|
||||
convertLongToErrorReporter(env, error_handle);
|
||||
if (error_reporter == nullptr) return JNI_FALSE;
|
||||
tflite::Interpreter* interpreter =
|
||||
convertLongToInterpreter(env, interpreter_handle);
|
||||
Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle);
|
||||
if (interpreter == nullptr) return JNI_FALSE;
|
||||
if (input_idx < 0 || input_idx >= interpreter->inputs().size()) {
|
||||
ThrowException(env, tflite::jni::kIllegalArgumentException,
|
||||
@ -555,8 +556,7 @@ JNIEXPORT void JNICALL
|
||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_applyDelegate(
|
||||
JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
|
||||
jlong delegate_handle) {
|
||||
tflite::Interpreter* interpreter =
|
||||
convertLongToInterpreter(env, interpreter_handle);
|
||||
Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle);
|
||||
if (interpreter == nullptr) return;
|
||||
|
||||
BufferErrorReporter* error_reporter =
|
||||
@ -577,8 +577,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_applyDelegate(
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_resetVariableTensors(
|
||||
JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle) {
|
||||
tflite::Interpreter* interpreter =
|
||||
convertLongToInterpreter(env, interpreter_handle);
|
||||
Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle);
|
||||
if (interpreter == nullptr) return;
|
||||
|
||||
BufferErrorReporter* error_reporter =
|
||||
@ -596,8 +595,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_resetVariableTensors(
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_createCancellationFlag(
|
||||
JNIEnv* env, jclass clazz, jlong interpreter_handle) {
|
||||
tflite::Interpreter* interpreter =
|
||||
convertLongToInterpreter(env, interpreter_handle);
|
||||
Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle);
|
||||
if (interpreter == nullptr) {
|
||||
ThrowException(env, tflite::jni::kIllegalArgumentException,
|
||||
"Internal error: Invalid handle to interpreter.");
|
||||
|
@ -19,12 +19,13 @@ limitations under the License.
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/core/shims/c/common.h"
|
||||
#include "tensorflow/lite/core/shims/cc/interpreter.h"
|
||||
#include "tensorflow/lite/java/src/main/native/jni_utils.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
|
||||
using tflite::jni::ThrowException;
|
||||
using tflite_shims::Interpreter;
|
||||
|
||||
namespace {
|
||||
|
||||
@ -39,14 +40,14 @@ static const char* kStringClassPath = "java/lang/String";
|
||||
// invalidate all TfLiteTensor* handles during inference or allocation.
|
||||
class TensorHandle {
|
||||
public:
|
||||
TensorHandle(tflite::Interpreter* interpreter, int tensor_index)
|
||||
TensorHandle(Interpreter* interpreter, int tensor_index)
|
||||
: interpreter_(interpreter), tensor_index_(tensor_index) {}
|
||||
|
||||
TfLiteTensor* tensor() const { return interpreter_->tensor(tensor_index_); }
|
||||
int index() const { return tensor_index_; }
|
||||
|
||||
private:
|
||||
tflite::Interpreter* const interpreter_;
|
||||
Interpreter* const interpreter_;
|
||||
const int tensor_index_;
|
||||
};
|
||||
|
||||
@ -396,8 +397,7 @@ extern "C" {
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_Tensor_create(
|
||||
JNIEnv* env, jclass clazz, jlong interpreter_handle, jint tensor_index) {
|
||||
tflite::Interpreter* interpreter =
|
||||
reinterpret_cast<tflite::Interpreter*>(interpreter_handle);
|
||||
Interpreter* interpreter = reinterpret_cast<Interpreter*>(interpreter_handle);
|
||||
return reinterpret_cast<jlong>(new TensorHandle(interpreter, tensor_index));
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user