From 627a032dc2c9f5e18b4264ed5717a97b7f981df0 Mon Sep 17 00:00:00 2001
From: Fergus Henderson <fergus@google.com>
Date: Fri, 11 Dec 2020 13:15:36 -0800
Subject: [PATCH] Change the TF Lite Java API to use the shim layer (attempt
 #2).

PiperOrigin-RevId: 347061214
Change-Id: Iaf3e9cb423cc6b495a83e3c17b041199d14063e3
---
 tensorflow/lite/BUILD                         |  2 +-
 tensorflow/lite/core/shims/BUILD              |  4 +
 .../create_op_resolver_with_builtin_ops.cc    |  2 +-
 tensorflow/lite/java/src/main/native/BUILD    |  9 ++-
 .../native/nativeinterpreterwrapper_jni.cc    | 78 +++++++++----------
 .../lite/java/src/main/native/tensor_jni.cc   | 12 +--
 6 files changed, 56 insertions(+), 51 deletions(-)

diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD
index 295c071ba20..35b72eeaa11 100644
--- a/tensorflow/lite/BUILD
+++ b/tensorflow/lite/BUILD
@@ -719,7 +719,7 @@ cc_library(
     deps = [
         "//tensorflow/lite:op_resolver",
         "//tensorflow/lite/core/api",
-        "//tensorflow/lite/kernels:builtin_ops",
+        "//tensorflow/lite/core/shims:builtin_ops",
     ],
 )
 
diff --git a/tensorflow/lite/core/shims/BUILD b/tensorflow/lite/core/shims/BUILD
index aa250597447..e09573db872 100644
--- a/tensorflow/lite/core/shims/BUILD
+++ b/tensorflow/lite/core/shims/BUILD
@@ -112,6 +112,10 @@ cc_library(
         "//tensorflow/lite/kernels:fully_connected.h",
     ],
     compatible_with = get_compatible_with_portable(),
+    visibility = [
+        "//tensorflow/lite:__subpackages__",
+        "//tensorflow_lite_support:__subpackages__",
+    ],
     deps = [
         "//tensorflow/lite:cc_api",
         "//tensorflow/lite/c:common",
diff --git a/tensorflow/lite/create_op_resolver_with_builtin_ops.cc b/tensorflow/lite/create_op_resolver_with_builtin_ops.cc
index 5801fad369b..ff3c583b0e1 100644
--- a/tensorflow/lite/create_op_resolver_with_builtin_ops.cc
+++ b/tensorflow/lite/create_op_resolver_with_builtin_ops.cc
@@ -15,8 +15,8 @@ limitations under the License.
 
 #include <memory>
 
+#include "tensorflow/lite/core/shims/cc/kernels/register.h"
 #include "tensorflow/lite/create_op_resolver.h"
-#include "tensorflow/lite/kernels/register.h"
 
 namespace tflite {
 
diff --git a/tensorflow/lite/java/src/main/native/BUILD b/tensorflow/lite/java/src/main/native/BUILD
index fa3195062e6..9dc00105900 100644
--- a/tensorflow/lite/java/src/main/native/BUILD
+++ b/tensorflow/lite/java/src/main/native/BUILD
@@ -26,11 +26,12 @@ cc_library(
         "-ldl",
     ],
     deps = [
-        "//tensorflow/lite:framework",
+        "//tensorflow/lite:op_resolver",
         "//tensorflow/lite:schema_fbs_version",
         "//tensorflow/lite:string_util",
         "//tensorflow/lite:util",
-        "//tensorflow/lite/c:common",
+        "//tensorflow/lite/core/shims:common",
+        "//tensorflow/lite/core/shims:framework",
         "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate_hdrs_only",
         "//tensorflow/lite/java/jni",
     ],
@@ -45,7 +46,9 @@ cc_library(
     deps = [
         ":native_framework_only",
         "//tensorflow/lite:create_op_resolver_with_builtin_ops",
-        "//tensorflow/lite:framework",
+        "//tensorflow/lite/core/api",
+        "//tensorflow/lite/core/shims:builtin_ops",
+        "//tensorflow/lite/core/shims:framework",
     ],
     alwayslink = 1,
 )
diff --git a/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
index fdfab0bd078..ce0be597cdd 100644
--- a/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
+++ b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
@@ -21,12 +21,12 @@ limitations under the License.
 #include <atomic>
 #include <vector>
 
-#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/core/shims/c/common.h"
+#include "tensorflow/lite/core/shims/cc/interpreter.h"
+#include "tensorflow/lite/core/shims/cc/interpreter_builder.h"
+#include "tensorflow/lite/core/shims/cc/model_builder.h"
 #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
-#include "tensorflow/lite/interpreter.h"
-#include "tensorflow/lite/interpreter_builder.h"
 #include "tensorflow/lite/java/src/main/native/jni_utils.h"
-#include "tensorflow/lite/model_builder.h"
 #include "tensorflow/lite/util.h"
 
 namespace tflite {
@@ -36,25 +36,28 @@ extern std::unique_ptr<OpResolver> CreateOpResolver();
 
 using tflite::jni::BufferErrorReporter;
 using tflite::jni::ThrowException;
+using tflite_shims::FlatBufferModel;
+using tflite_shims::Interpreter;
+using tflite_shims::InterpreterBuilder;
 
 namespace {
 
-tflite::Interpreter* convertLongToInterpreter(JNIEnv* env, jlong handle) {
+Interpreter* convertLongToInterpreter(JNIEnv* env, jlong handle) {
   if (handle == 0) {
     ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Internal error: Invalid handle to Interpreter.");
     return nullptr;
   }
-  return reinterpret_cast<tflite::Interpreter*>(handle);
+  return reinterpret_cast<Interpreter*>(handle);
 }
 
-tflite::FlatBufferModel* convertLongToModel(JNIEnv* env, jlong handle) {
+FlatBufferModel* convertLongToModel(JNIEnv* env, jlong handle) {
   if (handle == 0) {
     ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Internal error: Invalid handle to model.");
     return nullptr;
   }
-  return reinterpret_cast<tflite::FlatBufferModel*>(handle);
+  return reinterpret_cast<FlatBufferModel*>(handle);
 }
 
 BufferErrorReporter* convertLongToErrorReporter(JNIEnv* env, jlong handle) {
@@ -163,7 +166,7 @@ JNIEXPORT jobjectArray JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env,
                                                                 jclass clazz,
                                                                 jlong handle) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return nullptr;
   jclass string_class = env->FindClass("java/lang/String");
   if (string_class == nullptr) {
@@ -185,7 +188,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env,
 JNIEXPORT void JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors(
     JNIEnv* env, jclass clazz, jlong handle, jlong error_handle) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return;
   BufferErrorReporter* error_reporter =
       convertLongToErrorReporter(env, error_handle);
@@ -203,7 +206,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors(
 JNIEXPORT jboolean JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_hasUnresolvedFlexOp(
     JNIEnv* env, jclass clazz, jlong handle) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return JNI_FALSE;
 
   // TODO(b/132995737): Remove this logic by caching whether an unresolved
@@ -226,7 +229,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_hasUnresolvedFlexOp(
 JNIEXPORT jint JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensorIndex(
     JNIEnv* env, jclass clazz, jlong handle, jint input_index) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return 0;
   return interpreter->inputs()[input_index];
 }
@@ -234,7 +237,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensorIndex(
 JNIEXPORT jint JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndex(
     JNIEnv* env, jclass clazz, jlong handle, jint output_index) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return 0;
   return interpreter->outputs()[output_index];
 }
@@ -242,7 +245,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndex(
 JNIEXPORT jint JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getExecutionPlanLength(
     JNIEnv* env, jclass clazz, jlong handle) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return 0;
   return static_cast<jint>(interpreter->execution_plan().size());
 }
@@ -251,7 +254,7 @@ JNIEXPORT jint JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputCount(JNIEnv* env,
                                                                 jclass clazz,
                                                                 jlong handle) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return 0;
   return static_cast<jint>(interpreter->inputs().size());
 }
@@ -260,7 +263,7 @@ JNIEXPORT jint JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputCount(JNIEnv* env,
                                                                  jclass clazz,
                                                                  jlong handle) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return 0;
   return static_cast<jint>(interpreter->outputs().size());
 }
@@ -269,7 +272,7 @@ JNIEXPORT jobjectArray JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env,
                                                                  jclass clazz,
                                                                  jlong handle) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return nullptr;
   jclass string_class = env->FindClass("java/lang/String");
   if (string_class == nullptr) {
@@ -291,7 +294,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env,
 JNIEXPORT void JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_allowFp16PrecisionForFp32(
     JNIEnv* env, jclass clazz, jlong handle, jboolean allow) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return;
   interpreter->SetAllowFp16PrecisionForFp32(static_cast<bool>(allow));
 }
@@ -299,7 +302,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_allowFp16PrecisionForFp32(
 JNIEXPORT void JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_allowBufferHandleOutput(
     JNIEnv* env, jclass clazz, jlong handle, jboolean allow) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return;
   interpreter->SetAllowBufferHandleOutput(allow);
 }
@@ -313,7 +316,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK(
     return;
   }
 
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) {
     return;
   }
@@ -341,8 +344,8 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK(
     if (num_threads > 0) {
       options.num_threads = num_threads;
     }
-    tflite::Interpreter::TfLiteDelegatePtr delegate(xnnpack_create(&options),
-                                                    xnnpack_delete);
+    Interpreter::TfLiteDelegatePtr delegate(xnnpack_create(&options),
+                                            xnnpack_delete);
     auto delegation_status =
         interpreter->ModifyGraphWithDelegate(std::move(delegate));
     // kTfLiteApplicationError occurs in cases where delegation fails but
@@ -374,7 +377,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
                                                              jclass clazz,
                                                              jlong handle,
                                                              jint num_threads) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return;
   interpreter->SetNumThreads(static_cast<int>(num_threads));
 }
@@ -411,8 +414,8 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel(
   std::unique_ptr<tflite::TfLiteVerifier> verifier;
   verifier.reset(new JNIFlatBufferVerifier());
 
-  auto model = tflite::FlatBufferModel::VerifyAndBuildFromFile(
-      path, verifier.get(), error_reporter);
+  auto model = FlatBufferModel::VerifyAndBuildFromFile(path, verifier.get(),
+                                                       error_reporter);
   if (!model) {
     ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Contents of %s does not encode a valid "
@@ -440,7 +443,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer(
     return 0;
   }
 
-  auto model = tflite::FlatBufferModel::BuildFromBuffer(
+  auto model = FlatBufferModel::BuildFromBuffer(
       buf, static_cast<size_t>(capacity), error_reporter);
   if (!model) {
     ThrowException(env, tflite::jni::kIllegalArgumentException,
@@ -455,14 +458,14 @@ JNIEXPORT jlong JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
     JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle,
     jint num_threads) {
-  tflite::FlatBufferModel* model = convertLongToModel(env, model_handle);
+  FlatBufferModel* model = convertLongToModel(env, model_handle);
   if (model == nullptr) return 0;
   BufferErrorReporter* error_reporter =
       convertLongToErrorReporter(env, error_handle);
   if (error_reporter == nullptr) return 0;
   auto resolver = ::tflite::CreateOpResolver();
-  std::unique_ptr<tflite::Interpreter> interpreter;
-  TfLiteStatus status = tflite::InterpreterBuilder(*model, *(resolver.get()))(
+  std::unique_ptr<Interpreter> interpreter;
+  TfLiteStatus status = InterpreterBuilder(*model, *(resolver.get()))(
       &interpreter, static_cast<int>(num_threads));
   if (status != kTfLiteOk) {
     ThrowException(env, tflite::jni::kIllegalArgumentException,
@@ -478,8 +481,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
 // Sets inputs, runs inference, and returns outputs as long handles.
 JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
     JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle) {
-  tflite::Interpreter* interpreter =
-      convertLongToInterpreter(env, interpreter_handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle);
   if (interpreter == nullptr) return;
   BufferErrorReporter* error_reporter =
       convertLongToErrorReporter(env, error_handle);
@@ -497,7 +499,7 @@ JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
 JNIEXPORT jint JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputDataType(
     JNIEnv* env, jclass clazz, jlong handle, jint output_idx) {
-  tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, handle);
   if (interpreter == nullptr) return -1;
   const int idx = static_cast<int>(output_idx);
   if (output_idx < 0 || output_idx >= interpreter->outputs().size()) {
@@ -518,8 +520,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput(
   BufferErrorReporter* error_reporter =
       convertLongToErrorReporter(env, error_handle);
   if (error_reporter == nullptr) return JNI_FALSE;
-  tflite::Interpreter* interpreter =
-      convertLongToInterpreter(env, interpreter_handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle);
   if (interpreter == nullptr) return JNI_FALSE;
   if (input_idx < 0 || input_idx >= interpreter->inputs().size()) {
     ThrowException(env, tflite::jni::kIllegalArgumentException,
@@ -555,8 +556,7 @@ JNIEXPORT void JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_applyDelegate(
     JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
     jlong delegate_handle) {
-  tflite::Interpreter* interpreter =
-      convertLongToInterpreter(env, interpreter_handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle);
   if (interpreter == nullptr) return;
 
   BufferErrorReporter* error_reporter =
@@ -577,8 +577,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_applyDelegate(
 JNIEXPORT void JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_resetVariableTensors(
     JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle) {
-  tflite::Interpreter* interpreter =
-      convertLongToInterpreter(env, interpreter_handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle);
   if (interpreter == nullptr) return;
 
   BufferErrorReporter* error_reporter =
@@ -596,8 +595,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_resetVariableTensors(
 JNIEXPORT jlong JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_createCancellationFlag(
     JNIEnv* env, jclass clazz, jlong interpreter_handle) {
-  tflite::Interpreter* interpreter =
-      convertLongToInterpreter(env, interpreter_handle);
+  Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle);
   if (interpreter == nullptr) {
     ThrowException(env, tflite::jni::kIllegalArgumentException,
                    "Internal error: Invalid handle to interpreter.");
diff --git a/tensorflow/lite/java/src/main/native/tensor_jni.cc b/tensorflow/lite/java/src/main/native/tensor_jni.cc
index 44302a339b8..00f2a6904c0 100644
--- a/tensorflow/lite/java/src/main/native/tensor_jni.cc
+++ b/tensorflow/lite/java/src/main/native/tensor_jni.cc
@@ -19,12 +19,13 @@ limitations under the License.
 #include <memory>
 #include <string>
 
-#include "tensorflow/lite/c/common.h"
-#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/core/shims/c/common.h"
+#include "tensorflow/lite/core/shims/cc/interpreter.h"
 #include "tensorflow/lite/java/src/main/native/jni_utils.h"
 #include "tensorflow/lite/string_util.h"
 
 using tflite::jni::ThrowException;
+using tflite_shims::Interpreter;
 
 namespace {
 
@@ -39,14 +40,14 @@ static const char* kStringClassPath = "java/lang/String";
 // invalidate all TfLiteTensor* handles during inference or allocation.
 class TensorHandle {
  public:
-  TensorHandle(tflite::Interpreter* interpreter, int tensor_index)
+  TensorHandle(Interpreter* interpreter, int tensor_index)
       : interpreter_(interpreter), tensor_index_(tensor_index) {}
 
   TfLiteTensor* tensor() const { return interpreter_->tensor(tensor_index_); }
   int index() const { return tensor_index_; }
 
  private:
-  tflite::Interpreter* const interpreter_;
+  Interpreter* const interpreter_;
   const int tensor_index_;
 };
 
@@ -396,8 +397,7 @@ extern "C" {
 
 JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_Tensor_create(
     JNIEnv* env, jclass clazz, jlong interpreter_handle, jint tensor_index) {
-  tflite::Interpreter* interpreter =
-      reinterpret_cast<tflite::Interpreter*>(interpreter_handle);
+  Interpreter* interpreter = reinterpret_cast<Interpreter*>(interpreter_handle);
   return reinterpret_cast<jlong>(new TensorHandle(interpreter, tensor_index));
 }