Change use of TFLiteDelegate to TFLiteOpaqueDelegate in nativeinterpreterwrapper_jni.cc and use Interpreter by its imported name instead of the fully-qualified one.

PiperOrigin-RevId: 357587845
Change-Id: I14bb303c28e7db30994548848671ad790ce08420
This commit is contained in:
A. Unique TensorFlower 2021-02-15 09:55:44 -08:00 committed by TensorFlower Gardener
parent c92a33f338
commit b4786e349f
4 changed files with 16 additions and 12 deletions

View File

@ -33,7 +33,8 @@ public interface Delegate {
* <p>Note: The Java {@link Delegate} maintains ownership of the native delegate instance, and
* must ensure its existence for the duration of usage with any {@link Interpreter}.
*
* @return The native delegate handle.
* @return The native delegate handle. In C/C++, this should be a pointer to
* 'TfLiteOpaqueDelegate'.
*/
public long getNativeHandle();
}

View File

@ -26,6 +26,7 @@ cc_library(
"-ldl",
],
deps = [
"//tensorflow/lite:minimal_logging",
"//tensorflow/lite:op_resolver",
"//tensorflow/lite:schema_fbs_version",
"//tensorflow/lite:string_util",

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <time.h>
#include <atomic>
#include <map>
#include <vector>
#include "tensorflow/lite/core/shims/c/common.h"
@ -69,13 +70,13 @@ BufferErrorReporter* convertLongToErrorReporter(JNIEnv* env, jlong handle) {
return reinterpret_cast<BufferErrorReporter*>(handle);
}
TfLiteDelegate* convertLongToDelegate(JNIEnv* env, jlong handle) {
TfLiteOpaqueDelegate* convertLongToDelegate(JNIEnv* env, jlong handle) {
if (handle == 0) {
ThrowException(env, tflite::jni::kIllegalArgumentException,
"Internal error: Invalid handle to delegate.");
return nullptr;
}
return reinterpret_cast<TfLiteDelegate*>(handle);
return reinterpret_cast<TfLiteOpaqueDelegate*>(handle);
}
std::vector<int> convertJIntArrayToVector(JNIEnv* env, jintArray inputs) {
@ -162,8 +163,7 @@ bool VerifyModel(const void* buf, size_t len) {
// from either inputs or outputs.
// Returns -1 if invalid names are passed.
int GetTensorIndexForSignature(JNIEnv* env, jstring signature_tensor_name,
jstring method_name,
tflite::Interpreter* interpreter,
jstring method_name, Interpreter* interpreter,
bool is_input) {
// Fetch name strings.
const char* method_name_ptr = env->GetStringUTFChars(method_name, nullptr);
@ -271,7 +271,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_hasUnresolvedFlexOp(
JNIEXPORT jobjectArray JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getSignatureDefNames(
JNIEnv* env, jclass clazz, jlong handle) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return nullptr;
jclass string_class = env->FindClass("java/lang/String");
if (string_class == nullptr) {
@ -293,7 +293,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getSignatureDefNames(
JNIEXPORT jobjectArray JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getSignatureInputs(
JNIEnv* env, jclass clazz, jlong handle, jstring method_name) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return nullptr;
const char* method_name_ptr = env->GetStringUTFChars(method_name, nullptr);
const jobjectArray signature_inputs = GetSignatureInputsOutputsList(
@ -306,7 +306,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getSignatureInputs(
JNIEXPORT jobjectArray JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getSignatureOutputs(
JNIEnv* env, jclass clazz, jlong handle, jstring method_name) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return nullptr;
const char* method_name_ptr = env->GetStringUTFChars(method_name, nullptr);
const jobjectArray signature_outputs = GetSignatureInputsOutputsList(
@ -320,7 +320,7 @@ JNIEXPORT jint JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensorIndexFromSignature(
JNIEnv* env, jclass clazz, jlong handle, jstring signature_input_name,
jstring method_name) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return -1;
return GetTensorIndexForSignature(env, signature_input_name, method_name,
interpreter, /*is_input=*/true);
@ -330,7 +330,7 @@ JNIEXPORT jint JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndexFromSignature(
JNIEnv* env, jclass clazz, jlong handle, jstring signature_output_name,
jstring method_name) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return -1;
return GetTensorIndexForSignature(env, signature_output_name, method_name,
interpreter, /*is_input=*/false);
@ -646,7 +646,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput(
if (is_changed) {
TfLiteStatus status;
if (strict) {
status = interpreter->ResizeInputTensorStrict(
status = interpreter->ResizeInputTensorStrict(
tensor_idx, convertJIntArrayToVector(env, dims));
} else {
status = interpreter->ResizeInputTensor(
@ -673,7 +673,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_applyDelegate(
convertLongToErrorReporter(env, error_handle);
if (error_reporter == nullptr) return;
TfLiteDelegate* delegate = convertLongToDelegate(env, delegate_handle);
TfLiteOpaqueDelegate* delegate = convertLongToDelegate(env, delegate_handle);
if (delegate == nullptr) return;
TfLiteStatus status = interpreter->ModifyGraphWithDelegate(delegate);
@ -709,6 +709,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createCancellationFlag(
if (interpreter == nullptr) {
ThrowException(env, tflite::jni::kIllegalArgumentException,
"Internal error: Invalid handle to interpreter.");
return 0;
}
std::atomic_bool* cancellation_flag = new std::atomic_bool(false);
interpreter->SetCancellationFunction(cancellation_flag, [](void* payload) {

View File

@ -23,6 +23,7 @@ tensorflow/lite/delegates/gpu/cl/compiled_program_cache_generated.h
tensorflow/lite/delegates/gpu/cl/serialization_generated.h
tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h
tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h
tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
tensorflow/lite/micro/build_def.bzl
tensorflow/lite/schema/schema_generated.h
tensorflow/opensource_only/BUILD