From b238a739ac0ba78cb162ff1ee9080ba7e827fd95 Mon Sep 17 00:00:00 2001
From: Jared Duke <jdduke@google.com>
Date: Tue, 28 Apr 2020 11:38:59 -0700
Subject: [PATCH] Support scalar inputs in Java TFLite API

PiperOrigin-RevId: 308864253
Change-Id: Ic9993903e571601b3d3f3a133b4abc5a64bc2155
---
 tensorflow/lite/java/BUILD                    |   2 +
 .../main/java/org/tensorflow/lite/Tensor.java |  53 +++++++---
 .../lite/java/src/main/native/tensor_jni.cc   |  93 +++++++++++++++++-
 .../org/tensorflow/lite/InterpreterTest.java  |  10 ++
 .../lite/NativeInterpreterWrapperTest.java    |  17 ++++
 .../java/org/tensorflow/lite/TensorTest.java  |  36 +++++++
 .../lite/java/src/testdata/string_scalar.bin  | Bin 0 -> 448 bytes
 7 files changed, 192 insertions(+), 19 deletions(-)
 create mode 100644 tensorflow/lite/java/src/testdata/string_scalar.bin

diff --git a/tensorflow/lite/java/BUILD b/tensorflow/lite/java/BUILD
index 857974ecce2..c736c7c4f31 100644
--- a/tensorflow/lite/java/BUILD
+++ b/tensorflow/lite/java/BUILD
@@ -211,6 +211,8 @@ java_test(
         "src/testdata/int64.bin",
         "src/testdata/invalid_model.bin",
         "src/testdata/string.bin",
+        # Takes a scalar string and reshapes to a rank-1, single element string.
+        "src/testdata/string_scalar.bin",
         "src/testdata/uint8.bin",
         "src/testdata/with_custom_op.lite",
     ],
diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
index 34647275b92..89a2a6a0639 100644
--- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
+++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
@@ -31,6 +31,7 @@ import java.util.Arrays;
  * not needed to be closed by the client. However, once the {@code NativeInterpreterWrapper} has
  * been closed, the tensor handle will be invalidated.
  */
+// TODO(b/153882978): Add scalar getters similar to TF's Java API.
 public final class Tensor {
 
   /**
@@ -187,8 +188,10 @@ public final class Tensor {
     throwIfDataIsIncompatible(src);
     if (isBuffer(src)) {
       setTo((Buffer) src);
-    } else {
+    } else if (src.getClass().isArray()) {
       writeMultiDimensionalArray(nativeHandle, src);
+    } else {
+      writeScalar(nativeHandle, src);
     }
   }
 
@@ -300,19 +303,39 @@ public final class Tensor {
   static DataType dataTypeOf(Object o) {
     if (o != null) {
       Class<?> c = o.getClass();
-      while (c.isArray()) {
-        c = c.getComponentType();
-      }
-      if (float.class.equals(c) || o instanceof FloatBuffer) {
-        return DataType.FLOAT32;
-      } else if (int.class.equals(c) || o instanceof IntBuffer) {
-        return DataType.INT32;
-      } else if (byte.class.equals(c)) {
-        return DataType.UINT8;
-      } else if (long.class.equals(c) || o instanceof LongBuffer) {
-        return DataType.INT64;
-      } else if (String.class.equals(c)) {
-        return DataType.STRING;
+      // For arrays, the data elements must be a *primitive* type, e.g., an
+      // array of floats is fine, but not an array of Floats.
+      if (c.isArray()) {
+        while (c.isArray()) {
+          c = c.getComponentType();
+        }
+        if (float.class.equals(c)) {
+          return DataType.FLOAT32;
+        } else if (int.class.equals(c)) {
+          return DataType.INT32;
+        } else if (byte.class.equals(c)) {
+          return DataType.UINT8;
+        } else if (long.class.equals(c)) {
+          return DataType.INT64;
+        } else if (String.class.equals(c)) {
+          return DataType.STRING;
+        }
+      } else {
+        // For scalars, the type will be boxed.
+        if (Float.class.equals(c) || o instanceof FloatBuffer) {
+          return DataType.FLOAT32;
+        } else if (Integer.class.equals(c) || o instanceof IntBuffer) {
+          return DataType.INT32;
+        } else if (Byte.class.equals(c)) {
+          // Note that we don't check for ByteBuffer here; ByteBuffer payloads
+          // are allowed to map to any type, and should be handled earlier
+          // in the input/output processing pipeline.
+          return DataType.UINT8;
+        } else if (Long.class.equals(c) || o instanceof LongBuffer) {
+          return DataType.INT64;
+        } else if (String.class.equals(c)) {
+          return DataType.STRING;
+        }
       }
     }
     throw new IllegalArgumentException(
@@ -466,6 +489,8 @@ public final class Tensor {
 
   private static native void writeMultiDimensionalArray(long handle, Object src);
 
+  private static native void writeScalar(long handle, Object src);
+
   private static native int index(long handle);
 
   private static native String name(long handle);
diff --git a/tensorflow/lite/java/src/main/native/tensor_jni.cc b/tensorflow/lite/java/src/main/native/tensor_jni.cc
index 00706ef0a46..99be71ba37d 100644
--- a/tensorflow/lite/java/src/main/native/tensor_jni.cc
+++ b/tensorflow/lite/java/src/main/native/tensor_jni.cc
@@ -81,10 +81,16 @@ size_t ElementByteSize(TfLiteType data_type) {
                     "Interal error: Java int not compatible with kTfLiteInt");
       return 4;
     case kTfLiteUInt8:
+    case kTfLiteInt8:
       static_assert(sizeof(jbyte) == 1,
                     "Interal error: Java byte not compatible with "
                     "kTfLiteUInt8");
       return 1;
+    case kTfLiteBool:
+      static_assert(sizeof(jboolean) == 1,
+                    "Interal error: Java boolean not compatible with "
+                    "kTfLiteBool");
+      return 1;
     case kTfLiteInt64:
       static_assert(sizeof(jlong) == 8,
                     "Interal error: Java long not compatible with "
@@ -265,6 +271,15 @@ size_t WriteMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type,
   }
 }
 
+void AddStringDynamicBuffer(JNIEnv* env, jstring src,
+                            tflite::DynamicBuffer* dst_buffer) {
+  const char* chars = env->GetStringUTFChars(src, nullptr);
+  // + 1 for terminating character.
+  const int byte_len = env->GetStringUTFLength(src) + 1;
+  dst_buffer->AddString(chars, byte_len);
+  env->ReleaseStringUTFChars(src, chars);
+}
+
 void PopulateStringDynamicBuffer(JNIEnv* env, jobject src,
                                  tflite::DynamicBuffer* dst_buffer,
                                  int dims_left) {
@@ -277,11 +292,7 @@ void PopulateStringDynamicBuffer(JNIEnv* env, jobject src,
     for (int i = 0; i < num_elements; ++i) {
       jstring string_obj =
           static_cast<jstring>(env->GetObjectArrayElement(object_array, i));
-      const char* chars = env->GetStringUTFChars(string_obj, nullptr);
-      // + 1 for terminating character.
-      const int byte_len = env->GetStringUTFLength(string_obj) + 1;
-      dst_buffer->AddString(chars, byte_len);
-      env->ReleaseStringUTFChars(string_obj, chars);
+      AddStringDynamicBuffer(env, string_obj, dst_buffer);
       env->DeleteLocalRef(string_obj);
     }
   } else {
@@ -303,6 +314,56 @@ void WriteMultiDimensionalStringArray(JNIEnv* env, jobject src,
   }
 }
 
+void WriteScalar(JNIEnv* env, jobject src, TfLiteType type, void* dst,
+                 int dst_size) {
+  size_t src_size = ElementByteSize(type);
+  if (src_size != dst_size) {
+    ThrowException(
+        env, kIllegalStateException,
+        "Scalar (%d bytes) not compatible with allocated tensor (%d bytes)",
+        src_size, dst_size);
+    return;
+  }
+  switch (type) {
+// env->FindClass and env->GetMethodID are expensive and JNI best practices
+// suggest that they should be cached. However, until the creation of scalar
+// valued tensors seems to become a noticeable fraction of program execution,
+// ignore that cost.
+#define CASE(type, jtype, method_name, method_signature, call_type)            \
+  case type: {                                                                 \
+    jclass clazz = env->FindClass("java/lang/Number");                         \
+    jmethodID method = env->GetMethodID(clazz, method_name, method_signature); \
+    jtype v = env->Call##call_type##Method(src, method);                       \
+    memcpy(dst, &v, src_size);                                                 \
+    return;                                                                    \
+  }
+    CASE(kTfLiteFloat32, jfloat, "floatValue", "()F", Float);
+    CASE(kTfLiteInt32, jint, "intValue", "()I", Int);
+    CASE(kTfLiteInt64, jlong, "longValue", "()J", Long);
+    CASE(kTfLiteInt8, jbyte, "byteValue", "()B", Byte);
+    CASE(kTfLiteUInt8, jbyte, "byteValue", "()B", Byte);
+#undef CASE
+    case kTfLiteBool: {
+      jclass clazz = env->FindClass("java/lang/Boolean");
+      jmethodID method = env->GetMethodID(clazz, "booleanValue", "()Z");
+      jboolean v = env->CallBooleanMethod(src, method);
+      *(static_cast<unsigned char*>(dst)) = v ? 1 : 0;
+      return;
+    }
+    default:
+      ThrowException(env, kIllegalStateException, "Invalid DataType(%d)", type);
+      return;
+  }
+}
+
+void WriteScalarString(JNIEnv* env, jobject src, TfLiteTensor* tensor) {
+  tflite::DynamicBuffer dst_buffer;
+  AddStringDynamicBuffer(env, static_cast<jstring>(src), &dst_buffer);
+  if (!env->ExceptionCheck()) {
+    dst_buffer.WriteToTensor(tensor, /*new_shape=*/nullptr);
+  }
+}
+
 }  // namespace
 
 #ifdef __cplusplus
@@ -399,6 +460,28 @@ Java_org_tensorflow_lite_Tensor_writeMultiDimensionalArray(JNIEnv* env,
   }
 }
 
+JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_writeScalar(
+    JNIEnv* env, jclass clazz, jlong handle, jobject src) {
+  TfLiteTensor* tensor = GetTensorFromHandle(env, handle);
+  if (tensor == nullptr) return;
+  if ((tensor->type != kTfLiteString) && (tensor->data.raw == nullptr)) {
+    ThrowException(env, kIllegalArgumentException,
+                   "Internal error: Target Tensor hasn't been allocated.");
+    return;
+  }
+  if ((tensor->dims->size != 0) && (tensor->dims->data[0] != 1)) {
+    ThrowException(env, kIllegalArgumentException,
+                   "Internal error: Cannot write Java scalar to non-scalar "
+                   "Tensor.");
+    return;
+  }
+  if (tensor->type == kTfLiteString) {
+    WriteScalarString(env, src, tensor);
+  } else {
+    WriteScalar(env, src, tensor->type, tensor->data.data, tensor->bytes);
+  }
+}
+
 JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env,
                                                              jclass clazz,
                                                              jlong handle) {
diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
index b38f1ad771d..328ccf8cef6 100644
--- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
+++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
@@ -22,6 +22,7 @@ import java.io.File;
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.nio.DoubleBuffer;
+import java.nio.FloatBuffer;
 import java.util.HashMap;
 import java.util.Map;
 import org.junit.Test;
@@ -209,6 +210,15 @@ public final class InterpreterTest {
     assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
   }
 
+  @Test
+  public void testRunWithScalarInput() {
+    FloatBuffer parsedOutput = FloatBuffer.allocate(1);
+    try (Interpreter interpreter = new Interpreter(MODEL_BUFFER)) {
+      interpreter.run(2.37f, parsedOutput);
+    }
+    assertThat(parsedOutput.get(0)).isWithin(0.1f).of(7.11f);
+  }
+
   @Test
   public void testResizeInput() {
     try (Interpreter interpreter = new Interpreter(MODEL_BUFFER)) {
diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
index bab39793130..6436481c285 100644
--- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
+++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
@@ -46,6 +46,9 @@ public final class NativeInterpreterWrapperTest {
   private static final String STRING_MODEL_PATH =
       "tensorflow/lite/java/src/testdata/string.bin";
 
+  private static final String STRING_SCALAR_MODEL_PATH =
+      "tensorflow/lite/java/src/testdata/string_scalar.bin";
+
   private static final String INVALID_MODEL_PATH =
       "tensorflow/lite/java/src/testdata/invalid_model.bin";
 
@@ -245,6 +248,20 @@ public final class NativeInterpreterWrapperTest {
     }
   }
 
+  @Test
+  public void testRunWithScalarString() {
+    try (NativeInterpreterWrapper wrapper =
+        new NativeInterpreterWrapper(STRING_SCALAR_MODEL_PATH)) {
+      String[] parsedOutputs = new String[1];
+      Map<Integer, Object> outputs = new HashMap<>();
+      outputs.put(0, parsedOutputs);
+      Object[] inputs = {"s1"};
+      wrapper.run(inputs, outputs);
+      String[] expected = {"s1"};
+      assertThat(parsedOutputs).isEqualTo(expected);
+    }
+  }
+
   @Test
   public void testRunWithString_supplementaryUnicodeCharacters() {
     try (NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(STRING_MODEL_PATH)) {
diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
index f828f26f4c5..06a7deacc2c 100644
--- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
+++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
@@ -242,6 +242,22 @@ public final class TensorTest {
     tensor.setTo(inputFloatBuffer);
     tensor.copyTo(output);
     assertThat(output[0][0][0][0]).isEqualTo(5.0f);
+
+    // Assign from scalar float.
+    wrapper.resizeInput(0, new int[0]);
+    wrapper.allocateTensors();
+    float scalar = 5.0f;
+    tensor.setTo(scalar);
+    FloatBuffer outputScalar = FloatBuffer.allocate(1);
+    tensor.copyTo(outputScalar);
+    assertThat(outputScalar.get(0)).isEqualTo(5.0f);
+
+    // Assign from boxed scalar Float.
+    Float boxedScalar = 9.0f;
+    tensor.setTo(boxedScalar);
+    outputScalar = FloatBuffer.allocate(1);
+    tensor.copyTo(outputScalar);
+    assertThat(outputScalar.get(0)).isEqualTo(9.0f);
   }
 
   @Test
@@ -374,6 +390,9 @@ public final class TensorTest {
     float[][][][] differentShapeInput = new float[1][8][8][3];
     assertThat(tensor.getInputShapeIfDifferent(differentShapeInput))
         .isEqualTo(new int[] {1, 8, 8, 3});
+
+    Float differentShapeInputScalar = 5.0f;
+    assertThat(tensor.getInputShapeIfDifferent(differentShapeInputScalar)).isEqualTo(new int[] {});
   }
 
   @Test
@@ -390,6 +409,9 @@ public final class TensorTest {
     FloatBuffer testFloatBuffer = FloatBuffer.allocate(1);
     dataType = Tensor.dataTypeOf(testFloatBuffer);
     assertThat(dataType).isEqualTo(DataType.FLOAT32);
+    float testFloat = 1.0f;
+    dataType = Tensor.dataTypeOf(testFloat);
+    assertThat(dataType).isEqualTo(DataType.FLOAT32);
     try {
       double[] testDoubleArray = {0.783, 0.251};
       Tensor.dataTypeOf(testDoubleArray);
@@ -445,6 +467,20 @@ public final class TensorTest {
     assertThat(shape[2]).isEqualTo(1);
   }
 
+  @Test
+  public void testCopyToScalarUnsupported() {
+    wrapper.resizeInput(0, new int[0]);
+    wrapper.allocateTensors();
+    tensor.setTo(5.0f);
+    Float outputScalar = 7.0f;
+    try {
+      tensor.copyTo(outputScalar);
+      fail();
+    } catch (IllegalArgumentException e) {
+      // Expected failure.
+    }
+  }
+
   @Test
   public void testUseAfterClose() {
     tensor.close();
diff --git a/tensorflow/lite/java/src/testdata/string_scalar.bin b/tensorflow/lite/java/src/testdata/string_scalar.bin
new file mode 100644
index 0000000000000000000000000000000000000000..8f7d0f69ccf743f6bd310817e74fe86a9d7e0edf
GIT binary patch
literal 448
zcmYL_y=nqc5QRtGY!*$7g#=3r3kyjK+jN2;D9Dl*SmK2h{6j;EkTQi&6H?>}#0T;L
zgw1#Mx_IGmX6~IiXXZv`_BtH*f3hsA5H~;*^e$Y2k3jqa0<b6f1}EY#_yL}bM{#?w
zUVXUjyPG|6Q@0k?k=>!$B=fp$xC6RC9`X#(_=8hGm>+Q&h+9Cu%VfSXtD>kpBJ%U<
zt*BRzj@MD`Et#(+c+~wTyFxeGwb)r6dK1(UyIq=Woc-ptAqs(o+{>%*J9wS4bmS}U
yfI{b<_@eXd)zlm2{r-O<ysrO4c|Mc+wWDUzse$Ptd2{J{F>_lx;~Su&IsO5(i7zAo

literal 0
HcmV?d00001