From b238a739ac0ba78cb162ff1ee9080ba7e827fd95 Mon Sep 17 00:00:00 2001 From: Jared Duke <jdduke@google.com> Date: Tue, 28 Apr 2020 11:38:59 -0700 Subject: [PATCH] Support scalar inputs in Java TFLite API PiperOrigin-RevId: 308864253 Change-Id: Ic9993903e571601b3d3f3a133b4abc5a64bc2155 --- tensorflow/lite/java/BUILD | 2 + .../main/java/org/tensorflow/lite/Tensor.java | 53 +++++++--- .../lite/java/src/main/native/tensor_jni.cc | 93 +++++++++++++++++- .../org/tensorflow/lite/InterpreterTest.java | 10 ++ .../lite/NativeInterpreterWrapperTest.java | 17 ++++ .../java/org/tensorflow/lite/TensorTest.java | 36 +++++++ .../lite/java/src/testdata/string_scalar.bin | Bin 0 -> 448 bytes 7 files changed, 192 insertions(+), 19 deletions(-) create mode 100644 tensorflow/lite/java/src/testdata/string_scalar.bin diff --git a/tensorflow/lite/java/BUILD b/tensorflow/lite/java/BUILD index 857974ecce2..c736c7c4f31 100644 --- a/tensorflow/lite/java/BUILD +++ b/tensorflow/lite/java/BUILD @@ -211,6 +211,8 @@ java_test( "src/testdata/int64.bin", "src/testdata/invalid_model.bin", "src/testdata/string.bin", + # Takes a scalar string and reshapes to a rank-1, single element string. + "src/testdata/string_scalar.bin", "src/testdata/uint8.bin", "src/testdata/with_custom_op.lite", ], diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java index 34647275b92..89a2a6a0639 100644 --- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java +++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java @@ -31,6 +31,7 @@ import java.util.Arrays; * not needed to be closed by the client. However, once the {@code NativeInterpreterWrapper} has * been closed, the tensor handle will be invalidated. */ +// TODO(b/153882978): Add scalar getters similar to TF's Java API. public final class Tensor { /** @@ -187,8 +188,10 @@ public final class Tensor { throwIfDataIsIncompatible(src); if (isBuffer(src)) { setTo((Buffer) src); - } else { + } else if (src.getClass().isArray()) { writeMultiDimensionalArray(nativeHandle, src); + } else { + writeScalar(nativeHandle, src); } } @@ -300,19 +303,39 @@ public final class Tensor { static DataType dataTypeOf(Object o) { if (o != null) { Class<?> c = o.getClass(); - while (c.isArray()) { - c = c.getComponentType(); - } - if (float.class.equals(c) || o instanceof FloatBuffer) { - return DataType.FLOAT32; - } else if (int.class.equals(c) || o instanceof IntBuffer) { - return DataType.INT32; - } else if (byte.class.equals(c)) { - return DataType.UINT8; - } else if (long.class.equals(c) || o instanceof LongBuffer) { - return DataType.INT64; - } else if (String.class.equals(c)) { - return DataType.STRING; + // For arrays, the data elements must be a *primitive* type, e.g., an + // array of floats is fine, but not an array of Floats. + if (c.isArray()) { + while (c.isArray()) { + c = c.getComponentType(); + } + if (float.class.equals(c)) { + return DataType.FLOAT32; + } else if (int.class.equals(c)) { + return DataType.INT32; + } else if (byte.class.equals(c)) { + return DataType.UINT8; + } else if (long.class.equals(c)) { + return DataType.INT64; + } else if (String.class.equals(c)) { + return DataType.STRING; + } + } else { + // For scalars, the type will be boxed. + if (Float.class.equals(c) || o instanceof FloatBuffer) { + return DataType.FLOAT32; + } else if (Integer.class.equals(c) || o instanceof IntBuffer) { + return DataType.INT32; + } else if (Byte.class.equals(c)) { + // Note that we don't check for ByteBuffer here; ByteBuffer payloads + // are allowed to map to any type, and should be handled earlier + // in the input/output processing pipeline. + return DataType.UINT8; + } else if (Long.class.equals(c) || o instanceof LongBuffer) { + return DataType.INT64; + } else if (String.class.equals(c)) { + return DataType.STRING; + } } } throw new IllegalArgumentException( @@ -466,6 +489,8 @@ public final class Tensor { private static native void writeMultiDimensionalArray(long handle, Object src); + private static native void writeScalar(long handle, Object src); + private static native int index(long handle); private static native String name(long handle); diff --git a/tensorflow/lite/java/src/main/native/tensor_jni.cc b/tensorflow/lite/java/src/main/native/tensor_jni.cc index 00706ef0a46..99be71ba37d 100644 --- a/tensorflow/lite/java/src/main/native/tensor_jni.cc +++ b/tensorflow/lite/java/src/main/native/tensor_jni.cc @@ -81,10 +81,16 @@ size_t ElementByteSize(TfLiteType data_type) { "Interal error: Java int not compatible with kTfLiteInt"); return 4; case kTfLiteUInt8: + case kTfLiteInt8: static_assert(sizeof(jbyte) == 1, "Interal error: Java byte not compatible with " "kTfLiteUInt8"); return 1; + case kTfLiteBool: + static_assert(sizeof(jboolean) == 1, + "Interal error: Java boolean not compatible with " + "kTfLiteBool"); + return 1; case kTfLiteInt64: static_assert(sizeof(jlong) == 8, "Interal error: Java long not compatible with " @@ -265,6 +271,15 @@ size_t WriteMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type, } } +void AddStringDynamicBuffer(JNIEnv* env, jstring src, + tflite::DynamicBuffer* dst_buffer) { + const char* chars = env->GetStringUTFChars(src, nullptr); + // + 1 for terminating character. + const int byte_len = env->GetStringUTFLength(src) + 1; + dst_buffer->AddString(chars, byte_len); + env->ReleaseStringUTFChars(src, chars); +} + void PopulateStringDynamicBuffer(JNIEnv* env, jobject src, tflite::DynamicBuffer* dst_buffer, int dims_left) { @@ -277,11 +292,7 @@ void PopulateStringDynamicBuffer(JNIEnv* env, jobject src, for (int i = 0; i < num_elements; ++i) { jstring string_obj = static_cast<jstring>(env->GetObjectArrayElement(object_array, i)); - const char* chars = env->GetStringUTFChars(string_obj, nullptr); - // + 1 for terminating character. - const int byte_len = env->GetStringUTFLength(string_obj) + 1; - dst_buffer->AddString(chars, byte_len); - env->ReleaseStringUTFChars(string_obj, chars); + AddStringDynamicBuffer(env, string_obj, dst_buffer); env->DeleteLocalRef(string_obj); } } else { @@ -303,6 +314,56 @@ void WriteMultiDimensionalStringArray(JNIEnv* env, jobject src, } } +void WriteScalar(JNIEnv* env, jobject src, TfLiteType type, void* dst, + int dst_size) { + size_t src_size = ElementByteSize(type); + if (src_size != dst_size) { + ThrowException( + env, kIllegalStateException, + "Scalar (%d bytes) not compatible with allocated tensor (%d bytes)", + src_size, dst_size); + return; + } + switch (type) { +// env->FindClass and env->GetMethodID are expensive and JNI best practices +// suggest that they should be cached. However, until the creation of scalar +// valued tensors seems to become a noticeable fraction of program execution, +// ignore that cost. +#define CASE(type, jtype, method_name, method_signature, call_type) \ + case type: { \ + jclass clazz = env->FindClass("java/lang/Number"); \ + jmethodID method = env->GetMethodID(clazz, method_name, method_signature); \ + jtype v = env->Call##call_type##Method(src, method); \ + memcpy(dst, &v, src_size); \ + return; \ + } + CASE(kTfLiteFloat32, jfloat, "floatValue", "()F", Float); + CASE(kTfLiteInt32, jint, "intValue", "()I", Int); + CASE(kTfLiteInt64, jlong, "longValue", "()J", Long); + CASE(kTfLiteInt8, jbyte, "byteValue", "()B", Byte); + CASE(kTfLiteUInt8, jbyte, "byteValue", "()B", Byte); +#undef CASE + case kTfLiteBool: { + jclass clazz = env->FindClass("java/lang/Boolean"); + jmethodID method = env->GetMethodID(clazz, "booleanValue", "()Z"); + jboolean v = env->CallBooleanMethod(src, method); + *(static_cast<unsigned char*>(dst)) = v ? 1 : 0; + return; + } + default: + ThrowException(env, kIllegalStateException, "Invalid DataType(%d)", type); + return; + } +} + +void WriteScalarString(JNIEnv* env, jobject src, TfLiteTensor* tensor) { + tflite::DynamicBuffer dst_buffer; + AddStringDynamicBuffer(env, static_cast<jstring>(src), &dst_buffer); + if (!env->ExceptionCheck()) { + dst_buffer.WriteToTensor(tensor, /*new_shape=*/nullptr); + } +} + } // namespace #ifdef __cplusplus @@ -399,6 +460,28 @@ Java_org_tensorflow_lite_Tensor_writeMultiDimensionalArray(JNIEnv* env, } } +JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_writeScalar( + JNIEnv* env, jclass clazz, jlong handle, jobject src) { + TfLiteTensor* tensor = GetTensorFromHandle(env, handle); + if (tensor == nullptr) return; + if ((tensor->type != kTfLiteString) && (tensor->data.raw == nullptr)) { + ThrowException(env, kIllegalArgumentException, + "Internal error: Target Tensor hasn't been allocated."); + return; + } + if ((tensor->dims->size != 0) && (tensor->dims->data[0] != 1)) { + ThrowException(env, kIllegalArgumentException, + "Internal error: Cannot write Java scalar to non-scalar " + "Tensor."); + return; + } + if (tensor->type == kTfLiteString) { + WriteScalarString(env, src, tensor); + } else { + WriteScalar(env, src, tensor->type, tensor->data.data, tensor->bytes); + } +} + JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env, jclass clazz, jlong handle) { diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java index b38f1ad771d..328ccf8cef6 100644 --- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java +++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java @@ -22,6 +22,7 @@ import java.io.File; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; import java.util.HashMap; import java.util.Map; import org.junit.Test; @@ -209,6 +210,15 @@ public final class InterpreterTest { assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); } + @Test + public void testRunWithScalarInput() { + FloatBuffer parsedOutput = FloatBuffer.allocate(1); + try (Interpreter interpreter = new Interpreter(MODEL_BUFFER)) { + interpreter.run(2.37f, parsedOutput); + } + assertThat(parsedOutput.get(0)).isWithin(0.1f).of(7.11f); + } + @Test public void testResizeInput() { try (Interpreter interpreter = new Interpreter(MODEL_BUFFER)) { diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java index bab39793130..6436481c285 100644 --- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java +++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java @@ -46,6 +46,9 @@ public final class NativeInterpreterWrapperTest { private static final String STRING_MODEL_PATH = "tensorflow/lite/java/src/testdata/string.bin"; + private static final String STRING_SCALAR_MODEL_PATH = + "tensorflow/lite/java/src/testdata/string_scalar.bin"; + private static final String INVALID_MODEL_PATH = "tensorflow/lite/java/src/testdata/invalid_model.bin"; @@ -245,6 +248,20 @@ public final class NativeInterpreterWrapperTest { } } + @Test + public void testRunWithScalarString() { + try (NativeInterpreterWrapper wrapper = + new NativeInterpreterWrapper(STRING_SCALAR_MODEL_PATH)) { + String[] parsedOutputs = new String[1]; + Map<Integer, Object> outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); + Object[] inputs = {"s1"}; + wrapper.run(inputs, outputs); + String[] expected = {"s1"}; + assertThat(parsedOutputs).isEqualTo(expected); + } + } + @Test public void testRunWithString_supplementaryUnicodeCharacters() { try (NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(STRING_MODEL_PATH)) { diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java index f828f26f4c5..06a7deacc2c 100644 --- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java +++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java @@ -242,6 +242,22 @@ public final class TensorTest { tensor.setTo(inputFloatBuffer); tensor.copyTo(output); assertThat(output[0][0][0][0]).isEqualTo(5.0f); + + // Assign from scalar float. + wrapper.resizeInput(0, new int[0]); + wrapper.allocateTensors(); + float scalar = 5.0f; + tensor.setTo(scalar); + FloatBuffer outputScalar = FloatBuffer.allocate(1); + tensor.copyTo(outputScalar); + assertThat(outputScalar.get(0)).isEqualTo(5.0f); + + // Assign from boxed scalar Float. + Float boxedScalar = 9.0f; + tensor.setTo(boxedScalar); + outputScalar = FloatBuffer.allocate(1); + tensor.copyTo(outputScalar); + assertThat(outputScalar.get(0)).isEqualTo(9.0f); } @Test @@ -374,6 +390,9 @@ public final class TensorTest { float[][][][] differentShapeInput = new float[1][8][8][3]; assertThat(tensor.getInputShapeIfDifferent(differentShapeInput)) .isEqualTo(new int[] {1, 8, 8, 3}); + + Float differentShapeInputScalar = 5.0f; + assertThat(tensor.getInputShapeIfDifferent(differentShapeInputScalar)).isEqualTo(new int[] {}); } @Test @@ -390,6 +409,9 @@ public final class TensorTest { FloatBuffer testFloatBuffer = FloatBuffer.allocate(1); dataType = Tensor.dataTypeOf(testFloatBuffer); assertThat(dataType).isEqualTo(DataType.FLOAT32); + float testFloat = 1.0f; + dataType = Tensor.dataTypeOf(testFloat); + assertThat(dataType).isEqualTo(DataType.FLOAT32); try { double[] testDoubleArray = {0.783, 0.251}; Tensor.dataTypeOf(testDoubleArray); @@ -445,6 +467,20 @@ public final class TensorTest { assertThat(shape[2]).isEqualTo(1); } + @Test + public void testCopyToScalarUnsupported() { + wrapper.resizeInput(0, new int[0]); + wrapper.allocateTensors(); + tensor.setTo(5.0f); + Float outputScalar = 7.0f; + try { + tensor.copyTo(outputScalar); + fail(); + } catch (IllegalArgumentException e) { + // Expected failure. + } + } + @Test public void testUseAfterClose() { tensor.close(); diff --git a/tensorflow/lite/java/src/testdata/string_scalar.bin b/tensorflow/lite/java/src/testdata/string_scalar.bin new file mode 100644 index 0000000000000000000000000000000000000000..8f7d0f69ccf743f6bd310817e74fe86a9d7e0edf GIT binary patch literal 448 zcmYL_y=nqc5QRtGY!*$7g#=3r3kyjK+jN2;D9Dl*SmK2h{6j;EkTQi&6H?>}#0T;L zgw1#Mx_IGmX6~IiXXZv`_BtH*f3hsA5H~;*^e$Y2k3jqa0<b6f1}EY#_yL}bM{#?w zUVXUjyPG|6Q@0k?k=>!$B=fp$xC6RC9`X#(_=8hGm>+Q&h+9Cu%VfSXtD>kpBJ%U< zt*BRzj@MD`Et#(+c+~wTyFxeGwb)r6dK1(UyIq=Woc-ptAqs(o+{>%*J9wS4bmS}U yfI{b<_@eXd)zlm2{r-O<ysrO4c|Mc+wWDUzse$Ptd2{J{F>_lx;~Su&IsO5(i7zAo literal 0 HcmV?d00001