From 627a032dc2c9f5e18b4264ed5717a97b7f981df0 Mon Sep 17 00:00:00 2001
From: Fergus Henderson <fergus@google.com>
Date: Fri, 11 Dec 2020 13:15:36 -0800
Subject: [PATCH] Change the TF Lite Java API to use the shim layer (attempt
 #2).

PiperOrigin-RevId: 347061214
Change-Id: Iaf3e9cb423cc6b495a83e3c17b041199d14063e3
---
 tensorflow/lite/BUILD                          |  2 +-
 tensorflow/lite/core/shims/BUILD               |  4 +
 .../create_op_resolver_with_builtin_ops.cc     |  2 +-
 tensorflow/lite/java/src/main/native/BUILD     |  9 ++-
 .../native/nativeinterpreterwrapper_jni.cc     | 78 +++++++++----------
 .../lite/java/src/main/native/tensor_jni.cc    | 12 +--
 6 files changed, 56 insertions(+), 51 deletions(-)

diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD
index 295c071ba20..35b72eeaa11 100644
--- a/tensorflow/lite/BUILD
+++ b/tensorflow/lite/BUILD
@@ -719,7 +719,7 @@ cc_library(
     deps = [
         "//tensorflow/lite:op_resolver",
         "//tensorflow/lite/core/api",
-        "//tensorflow/lite/kernels:builtin_ops",
+        "//tensorflow/lite/core/shims:builtin_ops",
     ],
 )
 
diff --git a/tensorflow/lite/core/shims/BUILD b/tensorflow/lite/core/shims/BUILD
index aa250597447..e09573db872 100644
--- a/tensorflow/lite/core/shims/BUILD
+++ b/tensorflow/lite/core/shims/BUILD
@@ -112,6 +112,10 @@ cc_library(
         "//tensorflow/lite/kernels:fully_connected.h",
     ],
     compatible_with = get_compatible_with_portable(),
+    visibility = [
+        "//tensorflow/lite:__subpackages__",
+        "//tensorflow_lite_support:__subpackages__",
+    ],
     deps = [
         "//tensorflow/lite:cc_api",
         "//tensorflow/lite/c:common",
diff --git a/tensorflow/lite/create_op_resolver_with_builtin_ops.cc b/tensorflow/lite/create_op_resolver_with_builtin_ops.cc
index 5801fad369b..ff3c583b0e1 100644
--- a/tensorflow/lite/create_op_resolver_with_builtin_ops.cc
+++ b/tensorflow/lite/create_op_resolver_with_builtin_ops.cc
@@ -15,8 +15,8 @@ limitations under the License.
 
 #include <memory>
 
+#include "tensorflow/lite/core/shims/cc/kernels/register.h"
 #include "tensorflow/lite/create_op_resolver.h"
-#include "tensorflow/lite/kernels/register.h"
 
 namespace tflite {
 
diff --git a/tensorflow/lite/java/src/main/native/BUILD b/tensorflow/lite/java/src/main/native/BUILD
index fa3195062e6..9dc00105900 100644
--- a/tensorflow/lite/java/src/main/native/BUILD
+++ b/tensorflow/lite/java/src/main/native/BUILD
@@ -26,11 +26,12 @@ cc_library(
         "-ldl",
     ],
     deps = [
-        "//tensorflow/lite:framework",
+        "//tensorflow/lite:op_resolver",
         "//tensorflow/lite:schema_fbs_version",
         "//tensorflow/lite:string_util",
         "//tensorflow/lite:util",
-        "//tensorflow/lite/c:common",
+        "//tensorflow/lite/core/shims:common",
+        "//tensorflow/lite/core/shims:framework",
         "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate_hdrs_only",
         "//tensorflow/lite/java/jni",
     ],
@@ -45,7 +46,9 @@ cc_library(
     deps = [
         ":native_framework_only",
         "//tensorflow/lite:create_op_resolver_with_builtin_ops",
-        "//tensorflow/lite:framework",
+        "//tensorflow/lite/core/api",
+        "//tensorflow/lite/core/shims:builtin_ops",
+        "//tensorflow/lite/core/shims:framework",
     ],
     alwayslink = 1,
 )
diff --git a/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
index fdfab0bd078..ce0be597cdd 100644
--- a/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
+++ b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
@@ -21,12 +21,12 @@ limitations under the License.
 #include <atomic>
 #include <vector>
 
-#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/core/shims/c/common.h"
+#include "tensorflow/lite/core/shims/cc/interpreter.h"
+#include "tensorflow/lite/core/shims/cc/interpreter_builder.h"
+#include "tensorflow/lite/core/shims/cc/model_builder.h"
 #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
-#include "tensorflow/lite/interpreter.h"
-#include "tensorflow/lite/interpreter_builder.h"
 #include "tensorflow/lite/java/src/main/native/jni_utils.h"
-#include "tensorflow/lite/model_builder.h"
 #include "tensorflow/lite/util.h"
 
 namespace tflite {
@@ -36,25 +36,28 @@ extern std::unique_ptr<OpResolver> CreateOpResolver();
 
 using tflite::jni::BufferErrorReporter;
 using tflite::jni::ThrowException;
+using tflite_shims::FlatBufferModel;
+using tflite_shims::Interpreter;
+using tflite_shims::InterpreterBuilder;
 
 namespace {
 
-tflite::Interpreter* convertLongToInterpreter(JNIEnv* env, jlong handle) {
+Interpreter* convertLongToInterpreter(JNIEnv* env, jlong handle) {
   if (handle == 0) {
     ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Internal error: Invalid handle to Interpreter.");
     return nullptr;
   }
-  return reinterpret_cast<tflite::Interpreter*>(handle);
+  return reinterpret_cast<Interpreter*>(handle);
 }
 
-tflite::FlatBufferModel* convertLongToModel(JNIEnv* env, jlong handle) {
+FlatBufferModel* convertLongToModel(JNIEnv* env, jlong handle) {
   if (handle == 0) {
     ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Internal error: Invalid handle to model.");
     return nullptr;
   }
-  return reinterpret_cast<tflite::FlatBufferModel*>(handle);
+  return reinterpret_cast<FlatBufferModel*>(handle);
 }
 
 BufferErrorReporter* convertLongToErrorReporter(JNIEnv* env, jlong handle) {
@@ -163,7 +166,7 @@ JNIEXPORT jobjectArray JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env,
                                                                 jclass clazz,
                                                                 jlong handle) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return nullptr;
   jclass string_class = env->FindClass("java/lang/String");
   if (string_class == nullptr) {
@@ -185,7 +188,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env,
 JNIEXPORT void JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors(
     JNIEnv* env, jclass clazz, jlong handle, jlong error_handle) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return;
   BufferErrorReporter* error_reporter =
       convertLongToErrorReporter(env, error_handle);
@@ -203,7 +206,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors(
 JNIEXPORT jboolean JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_hasUnresolvedFlexOp(
     JNIEnv* env, jclass clazz, jlong handle) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return JNI_FALSE;
 
   // TODO(b/132995737): Remove this logic by caching whether an unresolved
@@ -226,7 +229,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_hasUnresolvedFlexOp(
 JNIEXPORT jint JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensorIndex(
     JNIEnv* env, jclass clazz, jlong handle, jint input_index) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return 0;
   return interpreter->inputs()[input_index];
 }
@@ -234,7 +237,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensorIndex(
 JNIEXPORT jint JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndex(
     JNIEnv* env, jclass clazz, jlong handle, jint output_index) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return 0;
   return interpreter->outputs()[output_index];
 }
@@ -242,7 +245,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndex(
 JNIEXPORT jint JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getExecutionPlanLength(
     JNIEnv* env, jclass clazz, jlong handle) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return 0;
   return static_cast<jint>(interpreter->execution_plan().size());
 }
@@ -251,7 +254,7 @@ JNIEXPORT jint JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputCount(JNIEnv* env,
                                                                 jclass clazz,
                                                                 jlong handle) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return 0;
   return static_cast<jint>(interpreter->inputs().size());
 }
@@ -260,7 +263,7 @@ JNIEXPORT jint JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputCount(JNIEnv* env,
                                                                  jclass clazz,
                                                                  jlong handle) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return 0;
   return static_cast<jint>(interpreter->outputs().size());
 }
@@ -269,7 +272,7 @@ JNIEXPORT jobjectArray JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env,
                                                                  jclass clazz,
                                                                  jlong handle) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return nullptr;
   jclass string_class = env->FindClass("java/lang/String");
   if (string_class == nullptr) {
@@ -291,7 +294,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env,
 JNIEXPORT void JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_allowFp16PrecisionForFp32(
     JNIEnv* env, jclass clazz, jlong handle, jboolean allow) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return;
   interpreter->SetAllowFp16PrecisionForFp32(static_cast<bool>(allow));
 }
@@ -299,7 +302,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_allowFp16PrecisionForFp32(
 JNIEXPORT void JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_allowBufferHandleOutput(
     JNIEnv* env, jclass clazz, jlong handle, jboolean allow) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return;
   interpreter->SetAllowBufferHandleOutput(allow);
 }
@@ -313,7 +316,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK(
     return;
   }
 
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) {
     return;
   }
@@ -341,8 +344,8 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK(
     if (num_threads > 0) {
       options.num_threads = num_threads;
     }
-    tflite::Interpreter::TfLiteDelegatePtr delegate(xnnpack_create(&options),
-                                                    xnnpack_delete);
+    Interpreter::TfLiteDelegatePtr delegate(xnnpack_create(&options),
+                                            xnnpack_delete);
     auto delegation_status =
         interpreter->ModifyGraphWithDelegate(std::move(delegate));
     // kTfLiteApplicationError occurs in cases where delegation fails but
@@ -374,7 +377,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
                                                              jclass clazz,
                                                              jlong handle,
                                                              jint num_threads) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return;
   interpreter->SetNumThreads(static_cast<int>(num_threads));
 }
@@ -411,8 +414,8 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel(
   std::unique_ptr<tflite::TfLiteVerifier> verifier;
   verifier.reset(new JNIFlatBufferVerifier());
 
-  auto model = tflite::FlatBufferModel::VerifyAndBuildFromFile(
-      path, verifier.get(), error_reporter);
+  auto model = FlatBufferModel::VerifyAndBuildFromFile(path, verifier.get(),
+                                                       error_reporter);
   if (!model) {
     ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Contents of %s does not encode a valid "
@@ -440,7 +443,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer(
     return 0;
   }
 
-  auto model = tflite::FlatBufferModel::BuildFromBuffer(
+  auto model = FlatBufferModel::BuildFromBuffer(
       buf, static_cast<size_t>(capacity), error_reporter);
   if (!model) {
     ThrowException(env, tflite::jni::kIllegalArgumentException,
@@ -455,14 +458,14 @@ JNIEXPORT jlong JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
     JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle,
     jint num_threads) {
-  tflite::FlatBufferModel* model = convertLongToModel(env, model_handle);
+  FlatBufferModel* model = convertLongToModel(env, model_handle);
   if (model == nullptr) return 0;
   BufferErrorReporter* error_reporter =
       convertLongToErrorReporter(env, error_handle);
   if (error_reporter == nullptr) return 0;
   auto resolver = ::tflite::CreateOpResolver();
-  std::unique_ptr<tflite::Interpreter> interpreter;
-  TfLiteStatus status = tflite::InterpreterBuilder(*model, *(resolver.get()))(
+  std::unique_ptr<Interpreter> interpreter;
+  TfLiteStatus status = InterpreterBuilder(*model, *(resolver.get()))(
       &interpreter, static_cast<int>(num_threads));
   if (status != kTfLiteOk) {
     ThrowException(env, tflite::jni::kIllegalArgumentException,
@@ -478,8 +481,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
 // Sets inputs, runs inference, and returns outputs as long handles.
 JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
     JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle) {
-  tflite::Interpreter* interpreter =
-      convertLongToInterpreter(env, interpreter_handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle);
   if (interpreter == nullptr) return;
   BufferErrorReporter* error_reporter =
       convertLongToErrorReporter(env, error_handle);
@@ -497,7 +499,7 @@ JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
 JNIEXPORT jint JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputDataType(
     JNIEnv* env, jclass clazz, jlong handle, jint output_idx) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return -1;
   const int idx = static_cast<int>(output_idx);
   if (output_idx < 0 || output_idx >= interpreter->outputs().size()) {
@@ -518,8 +520,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput(
   BufferErrorReporter* error_reporter =
       convertLongToErrorReporter(env, error_handle);
   if (error_reporter == nullptr) return JNI_FALSE;
-  tflite::Interpreter* interpreter =
-      convertLongToInterpreter(env, interpreter_handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle);
   if (interpreter == nullptr) return JNI_FALSE;
   if (input_idx < 0 || input_idx >= interpreter->inputs().size()) {
     ThrowException(env, tflite::jni::kIllegalArgumentException,
@@ -555,8 +556,7 @@ JNIEXPORT void JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_applyDelegate(
     JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
     jlong delegate_handle) {
-  tflite::Interpreter* interpreter =
-      convertLongToInterpreter(env, interpreter_handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle);
   if (interpreter == nullptr) return;
 
   BufferErrorReporter* error_reporter =
@@ -577,8 +577,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_applyDelegate(
 JNIEXPORT void JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_resetVariableTensors(
     JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle) {
-  tflite::Interpreter* interpreter =
-      convertLongToInterpreter(env, interpreter_handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle);
   if (interpreter == nullptr) return;
 
   BufferErrorReporter* error_reporter =
@@ -596,8 +595,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_resetVariableTensors(
 JNIEXPORT jlong JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_createCancellationFlag(
     JNIEnv* env, jclass clazz, jlong interpreter_handle) {
-  tflite::Interpreter* interpreter =
-      convertLongToInterpreter(env, interpreter_handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle);
   if (interpreter == nullptr) {
     ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Internal error: Invalid handle to interpreter.");
diff --git a/tensorflow/lite/java/src/main/native/tensor_jni.cc b/tensorflow/lite/java/src/main/native/tensor_jni.cc
index 44302a339b8..00f2a6904c0 100644
--- a/tensorflow/lite/java/src/main/native/tensor_jni.cc
+++ b/tensorflow/lite/java/src/main/native/tensor_jni.cc
@@ -19,12 +19,13 @@ limitations under the License.
 #include <memory>
 #include <string>
 
-#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/core/shims/c/common.h"
+#include "tensorflow/lite/core/shims/cc/interpreter.h"
 #include "tensorflow/lite/java/src/main/native/jni_utils.h"
 #include "tensorflow/lite/string_util.h"
 
 using tflite::jni::ThrowException;
+using tflite_shims::Interpreter;
 
 namespace {
 
@@ -39,14 +40,14 @@ static const char* kStringClassPath = "java/lang/String";
 // invalidate all TfLiteTensor* handles during inference or allocation.
 class TensorHandle {
  public:
-  TensorHandle(tflite::Interpreter* interpreter, int tensor_index)
+  TensorHandle(Interpreter* interpreter, int tensor_index)
       : interpreter_(interpreter), tensor_index_(tensor_index) {}
 
   TfLiteTensor* tensor() const { return interpreter_->tensor(tensor_index_); }
   int index() const { return tensor_index_; }
 
  private:
-  tflite::Interpreter* const interpreter_;
+  Interpreter* const interpreter_;
   const int tensor_index_;
 };
 
@@ -396,8 +397,7 @@ extern "C" {
 
 JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_Tensor_create(
    JNIEnv* env, jclass clazz, jlong interpreter_handle, jint tensor_index) {
-  tflite::Interpreter* interpreter =
-      reinterpret_cast<tflite::Interpreter*>(interpreter_handle);
+  Interpreter* interpreter = reinterpret_cast<Interpreter*>(interpreter_handle);
   return reinterpret_cast<jlong>(new TensorHandle(interpreter, tensor_index));
 }
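
Note (not part of the patch): for context, the usage pattern the JNI wrapper moves to is sketched below. It is a minimal, hypothetical example that assumes the tflite_shims aliases and shim headers shown in the diff mirror the corresponding tflite:: C++ API; the BuildInterpreterFromBuffer helper name is illustrative only and does not appear in the change.

    // Hypothetical sketch: building an interpreter through the shim-layer types.
    #include <cstddef>
    #include <memory>

    #include "tensorflow/lite/core/shims/c/common.h"
    #include "tensorflow/lite/core/shims/cc/interpreter.h"
    #include "tensorflow/lite/core/shims/cc/interpreter_builder.h"
    #include "tensorflow/lite/core/shims/cc/model_builder.h"
    #include "tensorflow/lite/create_op_resolver.h"

    // Mirrors the flow of createInterpreter() in the patched JNI wrapper.
    std::unique_ptr<tflite_shims::Interpreter> BuildInterpreterFromBuffer(
        const char* buf, size_t size, int num_threads) {
      // FlatBufferModel / InterpreterBuilder resolve to the shim aliases rather
      // than directly to tflite::FlatBufferModel / tflite::InterpreterBuilder.
      auto model = tflite_shims::FlatBufferModel::BuildFromBuffer(buf, size);
      if (!model) return nullptr;
      // CreateOpResolver() is provided at link time, as in the JNI wrapper.
      auto resolver = ::tflite::CreateOpResolver();
      std::unique_ptr<tflite_shims::Interpreter> interpreter;
      if (tflite_shims::InterpreterBuilder(*model, *resolver)(
              &interpreter, num_threads) != kTfLiteOk) {
        return nullptr;
      }
      return interpreter;
    }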