Support scalar inputs in Java TFLite API

PiperOrigin-RevId: 308864253
Change-Id: Ic9993903e571601b3d3f3a133b4abc5a64bc2155
This commit is contained in:
Jared Duke 2020-04-28 11:38:59 -07:00 committed by TensorFlower Gardener
parent f9f6b4cec2
commit b238a739ac
7 changed files with 192 additions and 19 deletions
tensorflow/lite/java

View File

@ -211,6 +211,8 @@ java_test(
"src/testdata/int64.bin",
"src/testdata/invalid_model.bin",
"src/testdata/string.bin",
# Takes a scalar string and reshapes to a rank-1, single element string.
"src/testdata/string_scalar.bin",
"src/testdata/uint8.bin",
"src/testdata/with_custom_op.lite",
],

View File

@ -31,6 +31,7 @@ import java.util.Arrays;
* not needed to be closed by the client. However, once the {@code NativeInterpreterWrapper} has
* been closed, the tensor handle will be invalidated.
*/
// TODO(b/153882978): Add scalar getters similar to TF's Java API.
public final class Tensor {
/**
@ -187,8 +188,10 @@ public final class Tensor {
throwIfDataIsIncompatible(src);
if (isBuffer(src)) {
setTo((Buffer) src);
} else {
} else if (src.getClass().isArray()) {
writeMultiDimensionalArray(nativeHandle, src);
} else {
writeScalar(nativeHandle, src);
}
}
@ -300,19 +303,39 @@ public final class Tensor {
static DataType dataTypeOf(Object o) {
if (o != null) {
Class<?> c = o.getClass();
while (c.isArray()) {
c = c.getComponentType();
}
if (float.class.equals(c) || o instanceof FloatBuffer) {
return DataType.FLOAT32;
} else if (int.class.equals(c) || o instanceof IntBuffer) {
return DataType.INT32;
} else if (byte.class.equals(c)) {
return DataType.UINT8;
} else if (long.class.equals(c) || o instanceof LongBuffer) {
return DataType.INT64;
} else if (String.class.equals(c)) {
return DataType.STRING;
// For arrays, the data elements must be a *primitive* type, e.g., an
// array of floats is fine, but not an array of Floats.
if (c.isArray()) {
while (c.isArray()) {
c = c.getComponentType();
}
if (float.class.equals(c)) {
return DataType.FLOAT32;
} else if (int.class.equals(c)) {
return DataType.INT32;
} else if (byte.class.equals(c)) {
return DataType.UINT8;
} else if (long.class.equals(c)) {
return DataType.INT64;
} else if (String.class.equals(c)) {
return DataType.STRING;
}
} else {
// For scalars, the type will be boxed.
if (Float.class.equals(c) || o instanceof FloatBuffer) {
return DataType.FLOAT32;
} else if (Integer.class.equals(c) || o instanceof IntBuffer) {
return DataType.INT32;
} else if (Byte.class.equals(c)) {
// Note that we don't check for ByteBuffer here; ByteBuffer payloads
// are allowed to map to any type, and should be handled earlier
// in the input/output processing pipeline.
return DataType.UINT8;
} else if (Long.class.equals(c) || o instanceof LongBuffer) {
return DataType.INT64;
} else if (String.class.equals(c)) {
return DataType.STRING;
}
}
}
throw new IllegalArgumentException(
@ -466,6 +489,8 @@ public final class Tensor {
private static native void writeMultiDimensionalArray(long handle, Object src);
private static native void writeScalar(long handle, Object src);
private static native int index(long handle);
private static native String name(long handle);

View File

@ -81,10 +81,16 @@ size_t ElementByteSize(TfLiteType data_type) {
"Interal error: Java int not compatible with kTfLiteInt");
return 4;
case kTfLiteUInt8:
case kTfLiteInt8:
static_assert(sizeof(jbyte) == 1,
"Interal error: Java byte not compatible with "
"kTfLiteUInt8");
return 1;
case kTfLiteBool:
static_assert(sizeof(jboolean) == 1,
"Interal error: Java boolean not compatible with "
"kTfLiteBool");
return 1;
case kTfLiteInt64:
static_assert(sizeof(jlong) == 8,
"Interal error: Java long not compatible with "
@ -265,6 +271,15 @@ size_t WriteMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type,
}
}
void AddStringDynamicBuffer(JNIEnv* env, jstring src,
tflite::DynamicBuffer* dst_buffer) {
const char* chars = env->GetStringUTFChars(src, nullptr);
// + 1 for terminating character.
const int byte_len = env->GetStringUTFLength(src) + 1;
dst_buffer->AddString(chars, byte_len);
env->ReleaseStringUTFChars(src, chars);
}
void PopulateStringDynamicBuffer(JNIEnv* env, jobject src,
tflite::DynamicBuffer* dst_buffer,
int dims_left) {
@ -277,11 +292,7 @@ void PopulateStringDynamicBuffer(JNIEnv* env, jobject src,
for (int i = 0; i < num_elements; ++i) {
jstring string_obj =
static_cast<jstring>(env->GetObjectArrayElement(object_array, i));
const char* chars = env->GetStringUTFChars(string_obj, nullptr);
// + 1 for terminating character.
const int byte_len = env->GetStringUTFLength(string_obj) + 1;
dst_buffer->AddString(chars, byte_len);
env->ReleaseStringUTFChars(string_obj, chars);
AddStringDynamicBuffer(env, string_obj, dst_buffer);
env->DeleteLocalRef(string_obj);
}
} else {
@ -303,6 +314,56 @@ void WriteMultiDimensionalStringArray(JNIEnv* env, jobject src,
}
}
void WriteScalar(JNIEnv* env, jobject src, TfLiteType type, void* dst,
int dst_size) {
size_t src_size = ElementByteSize(type);
if (src_size != dst_size) {
ThrowException(
env, kIllegalStateException,
"Scalar (%d bytes) not compatible with allocated tensor (%d bytes)",
src_size, dst_size);
return;
}
switch (type) {
// env->FindClass and env->GetMethodID are expensive and JNI best practices
// suggest that they should be cached. However, until the creation of scalar
// valued tensors seems to become a noticeable fraction of program execution,
// ignore that cost.
#define CASE(type, jtype, method_name, method_signature, call_type) \
case type: { \
jclass clazz = env->FindClass("java/lang/Number"); \
jmethodID method = env->GetMethodID(clazz, method_name, method_signature); \
jtype v = env->Call##call_type##Method(src, method); \
memcpy(dst, &v, src_size); \
return; \
}
CASE(kTfLiteFloat32, jfloat, "floatValue", "()F", Float);
CASE(kTfLiteInt32, jint, "intValue", "()I", Int);
CASE(kTfLiteInt64, jlong, "longValue", "()J", Long);
CASE(kTfLiteInt8, jbyte, "byteValue", "()B", Byte);
CASE(kTfLiteUInt8, jbyte, "byteValue", "()B", Byte);
#undef CASE
case kTfLiteBool: {
jclass clazz = env->FindClass("java/lang/Boolean");
jmethodID method = env->GetMethodID(clazz, "booleanValue", "()Z");
jboolean v = env->CallBooleanMethod(src, method);
*(static_cast<unsigned char*>(dst)) = v ? 1 : 0;
return;
}
default:
ThrowException(env, kIllegalStateException, "Invalid DataType(%d)", type);
return;
}
}
void WriteScalarString(JNIEnv* env, jobject src, TfLiteTensor* tensor) {
tflite::DynamicBuffer dst_buffer;
AddStringDynamicBuffer(env, static_cast<jstring>(src), &dst_buffer);
if (!env->ExceptionCheck()) {
dst_buffer.WriteToTensor(tensor, /*new_shape=*/nullptr);
}
}
} // namespace
#ifdef __cplusplus
@ -399,6 +460,28 @@ Java_org_tensorflow_lite_Tensor_writeMultiDimensionalArray(JNIEnv* env,
}
}
JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_writeScalar(
JNIEnv* env, jclass clazz, jlong handle, jobject src) {
TfLiteTensor* tensor = GetTensorFromHandle(env, handle);
if (tensor == nullptr) return;
if ((tensor->type != kTfLiteString) && (tensor->data.raw == nullptr)) {
ThrowException(env, kIllegalArgumentException,
"Internal error: Target Tensor hasn't been allocated.");
return;
}
if ((tensor->dims->size != 0) && (tensor->dims->data[0] != 1)) {
ThrowException(env, kIllegalArgumentException,
"Internal error: Cannot write Java scalar to non-scalar "
"Tensor.");
return;
}
if (tensor->type == kTfLiteString) {
WriteScalarString(env, src, tensor);
} else {
WriteScalar(env, src, tensor->type, tensor->data.data, tensor->bytes);
}
}
JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env,
jclass clazz,
jlong handle) {

View File

@ -22,6 +22,7 @@ import java.io.File;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.util.HashMap;
import java.util.Map;
import org.junit.Test;
@ -209,6 +210,15 @@ public final class InterpreterTest {
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
}
@Test
public void testRunWithScalarInput() {
FloatBuffer parsedOutput = FloatBuffer.allocate(1);
try (Interpreter interpreter = new Interpreter(MODEL_BUFFER)) {
interpreter.run(2.37f, parsedOutput);
}
assertThat(parsedOutput.get(0)).isWithin(0.1f).of(7.11f);
}
@Test
public void testResizeInput() {
try (Interpreter interpreter = new Interpreter(MODEL_BUFFER)) {

View File

@ -46,6 +46,9 @@ public final class NativeInterpreterWrapperTest {
private static final String STRING_MODEL_PATH =
"tensorflow/lite/java/src/testdata/string.bin";
private static final String STRING_SCALAR_MODEL_PATH =
"tensorflow/lite/java/src/testdata/string_scalar.bin";
private static final String INVALID_MODEL_PATH =
"tensorflow/lite/java/src/testdata/invalid_model.bin";
@ -245,6 +248,20 @@ public final class NativeInterpreterWrapperTest {
}
}
@Test
public void testRunWithScalarString() {
try (NativeInterpreterWrapper wrapper =
new NativeInterpreterWrapper(STRING_SCALAR_MODEL_PATH)) {
String[] parsedOutputs = new String[1];
Map<Integer, Object> outputs = new HashMap<>();
outputs.put(0, parsedOutputs);
Object[] inputs = {"s1"};
wrapper.run(inputs, outputs);
String[] expected = {"s1"};
assertThat(parsedOutputs).isEqualTo(expected);
}
}
@Test
public void testRunWithString_supplementaryUnicodeCharacters() {
try (NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(STRING_MODEL_PATH)) {

View File

@ -242,6 +242,22 @@ public final class TensorTest {
tensor.setTo(inputFloatBuffer);
tensor.copyTo(output);
assertThat(output[0][0][0][0]).isEqualTo(5.0f);
// Assign from scalar float.
wrapper.resizeInput(0, new int[0]);
wrapper.allocateTensors();
float scalar = 5.0f;
tensor.setTo(scalar);
FloatBuffer outputScalar = FloatBuffer.allocate(1);
tensor.copyTo(outputScalar);
assertThat(outputScalar.get(0)).isEqualTo(5.0f);
// Assign from boxed scalar Float.
Float boxedScalar = 9.0f;
tensor.setTo(boxedScalar);
outputScalar = FloatBuffer.allocate(1);
tensor.copyTo(outputScalar);
assertThat(outputScalar.get(0)).isEqualTo(9.0f);
}
@Test
@ -374,6 +390,9 @@ public final class TensorTest {
float[][][][] differentShapeInput = new float[1][8][8][3];
assertThat(tensor.getInputShapeIfDifferent(differentShapeInput))
.isEqualTo(new int[] {1, 8, 8, 3});
Float differentShapeInputScalar = 5.0f;
assertThat(tensor.getInputShapeIfDifferent(differentShapeInputScalar)).isEqualTo(new int[] {});
}
@Test
@ -390,6 +409,9 @@ public final class TensorTest {
FloatBuffer testFloatBuffer = FloatBuffer.allocate(1);
dataType = Tensor.dataTypeOf(testFloatBuffer);
assertThat(dataType).isEqualTo(DataType.FLOAT32);
float testFloat = 1.0f;
dataType = Tensor.dataTypeOf(testFloat);
assertThat(dataType).isEqualTo(DataType.FLOAT32);
try {
double[] testDoubleArray = {0.783, 0.251};
Tensor.dataTypeOf(testDoubleArray);
@ -445,6 +467,20 @@ public final class TensorTest {
assertThat(shape[2]).isEqualTo(1);
}
@Test
public void testCopyToScalarUnsupported() {
wrapper.resizeInput(0, new int[0]);
wrapper.allocateTensors();
tensor.setTo(5.0f);
Float outputScalar = 7.0f;
try {
tensor.copyTo(outputScalar);
fail();
} catch (IllegalArgumentException e) {
// Expected failure.
}
}
@Test
public void testUseAfterClose() {
tensor.close();

Binary file not shown.