Change use of TFLiteDelegate to TFLiteOpaqueDelegate in nativeinterpreterwrapper_jni.cc and use Interpreter by its imported name instead of the fully-qualified one.
PiperOrigin-RevId: 357587845 Change-Id: I14bb303c28e7db30994548848671ad790ce08420
This commit is contained in:
parent
c92a33f338
commit
b4786e349f
@ -33,7 +33,8 @@ public interface Delegate {
|
||||
* <p>Note: The Java {@link Delegate} maintains ownership of the native delegate instance, and
|
||||
* must ensure its existence for the duration of usage with any {@link Interpreter}.
|
||||
*
|
||||
* @return The native delegate handle.
|
||||
* @return The native delegate handle. In C/C++, this should be a pointer to
|
||||
* 'TfLiteOpaqueDelegate'.
|
||||
*/
|
||||
public long getNativeHandle();
|
||||
}
|
||||
|
||||
@ -26,6 +26,7 @@ cc_library(
|
||||
"-ldl",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/lite:minimal_logging",
|
||||
"//tensorflow/lite:op_resolver",
|
||||
"//tensorflow/lite:schema_fbs_version",
|
||||
"//tensorflow/lite:string_util",
|
||||
|
||||
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <time.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/core/shims/c/common.h"
|
||||
@ -69,13 +70,13 @@ BufferErrorReporter* convertLongToErrorReporter(JNIEnv* env, jlong handle) {
|
||||
return reinterpret_cast<BufferErrorReporter*>(handle);
|
||||
}
|
||||
|
||||
TfLiteDelegate* convertLongToDelegate(JNIEnv* env, jlong handle) {
|
||||
TfLiteOpaqueDelegate* convertLongToDelegate(JNIEnv* env, jlong handle) {
|
||||
if (handle == 0) {
|
||||
ThrowException(env, tflite::jni::kIllegalArgumentException,
|
||||
"Internal error: Invalid handle to delegate.");
|
||||
return nullptr;
|
||||
}
|
||||
return reinterpret_cast<TfLiteDelegate*>(handle);
|
||||
return reinterpret_cast<TfLiteOpaqueDelegate*>(handle);
|
||||
}
|
||||
|
||||
std::vector<int> convertJIntArrayToVector(JNIEnv* env, jintArray inputs) {
|
||||
@ -162,8 +163,7 @@ bool VerifyModel(const void* buf, size_t len) {
|
||||
// from either inputs or outputs.
|
||||
// Returns -1 if invalid names are passed.
|
||||
int GetTensorIndexForSignature(JNIEnv* env, jstring signature_tensor_name,
|
||||
jstring method_name,
|
||||
tflite::Interpreter* interpreter,
|
||||
jstring method_name, Interpreter* interpreter,
|
||||
bool is_input) {
|
||||
// Fetch name strings.
|
||||
const char* method_name_ptr = env->GetStringUTFChars(method_name, nullptr);
|
||||
@ -271,7 +271,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_hasUnresolvedFlexOp(
|
||||
JNIEXPORT jobjectArray JNICALL
|
||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_getSignatureDefNames(
|
||||
JNIEnv* env, jclass clazz, jlong handle) {
|
||||
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
if (interpreter == nullptr) return nullptr;
|
||||
jclass string_class = env->FindClass("java/lang/String");
|
||||
if (string_class == nullptr) {
|
||||
@ -293,7 +293,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getSignatureDefNames(
|
||||
JNIEXPORT jobjectArray JNICALL
|
||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_getSignatureInputs(
|
||||
JNIEnv* env, jclass clazz, jlong handle, jstring method_name) {
|
||||
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
if (interpreter == nullptr) return nullptr;
|
||||
const char* method_name_ptr = env->GetStringUTFChars(method_name, nullptr);
|
||||
const jobjectArray signature_inputs = GetSignatureInputsOutputsList(
|
||||
@ -306,7 +306,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getSignatureInputs(
|
||||
JNIEXPORT jobjectArray JNICALL
|
||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_getSignatureOutputs(
|
||||
JNIEnv* env, jclass clazz, jlong handle, jstring method_name) {
|
||||
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
if (interpreter == nullptr) return nullptr;
|
||||
const char* method_name_ptr = env->GetStringUTFChars(method_name, nullptr);
|
||||
const jobjectArray signature_outputs = GetSignatureInputsOutputsList(
|
||||
@ -320,7 +320,7 @@ JNIEXPORT jint JNICALL
|
||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensorIndexFromSignature(
|
||||
JNIEnv* env, jclass clazz, jlong handle, jstring signature_input_name,
|
||||
jstring method_name) {
|
||||
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
if (interpreter == nullptr) return -1;
|
||||
return GetTensorIndexForSignature(env, signature_input_name, method_name,
|
||||
interpreter, /*is_input=*/true);
|
||||
@ -330,7 +330,7 @@ JNIEXPORT jint JNICALL
|
||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndexFromSignature(
|
||||
JNIEnv* env, jclass clazz, jlong handle, jstring signature_output_name,
|
||||
jstring method_name) {
|
||||
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
Interpreter* interpreter = convertLongToInterpreter(env, handle);
|
||||
if (interpreter == nullptr) return -1;
|
||||
return GetTensorIndexForSignature(env, signature_output_name, method_name,
|
||||
interpreter, /*is_input=*/false);
|
||||
@ -646,7 +646,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput(
|
||||
if (is_changed) {
|
||||
TfLiteStatus status;
|
||||
if (strict) {
|
||||
status = interpreter->ResizeInputTensorStrict(
|
||||
status = interpreter->ResizeInputTensorStrict(
|
||||
tensor_idx, convertJIntArrayToVector(env, dims));
|
||||
} else {
|
||||
status = interpreter->ResizeInputTensor(
|
||||
@ -673,7 +673,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_applyDelegate(
|
||||
convertLongToErrorReporter(env, error_handle);
|
||||
if (error_reporter == nullptr) return;
|
||||
|
||||
TfLiteDelegate* delegate = convertLongToDelegate(env, delegate_handle);
|
||||
TfLiteOpaqueDelegate* delegate = convertLongToDelegate(env, delegate_handle);
|
||||
if (delegate == nullptr) return;
|
||||
|
||||
TfLiteStatus status = interpreter->ModifyGraphWithDelegate(delegate);
|
||||
@ -709,6 +709,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createCancellationFlag(
|
||||
if (interpreter == nullptr) {
|
||||
ThrowException(env, tflite::jni::kIllegalArgumentException,
|
||||
"Internal error: Invalid handle to interpreter.");
|
||||
return 0;
|
||||
}
|
||||
std::atomic_bool* cancellation_flag = new std::atomic_bool(false);
|
||||
interpreter->SetCancellationFunction(cancellation_flag, [](void* payload) {
|
||||
|
||||
@ -23,6 +23,7 @@ tensorflow/lite/delegates/gpu/cl/compiled_program_cache_generated.h
|
||||
tensorflow/lite/delegates/gpu/cl/serialization_generated.h
|
||||
tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h
|
||||
tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h
|
||||
tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
|
||||
tensorflow/lite/micro/build_def.bzl
|
||||
tensorflow/lite/schema/schema_generated.h
|
||||
tensorflow/opensource_only/BUILD
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user