Support scalar inputs in Java TFLite API
PiperOrigin-RevId: 308864253 Change-Id: Ic9993903e571601b3d3f3a133b4abc5a64bc2155
This commit is contained in:
parent
f9f6b4cec2
commit
b238a739ac
tensorflow/lite/java
@ -211,6 +211,8 @@ java_test(
|
||||
"src/testdata/int64.bin",
|
||||
"src/testdata/invalid_model.bin",
|
||||
"src/testdata/string.bin",
|
||||
# Takes a scalar string and reshapes to a rank-1, single element string.
|
||||
"src/testdata/string_scalar.bin",
|
||||
"src/testdata/uint8.bin",
|
||||
"src/testdata/with_custom_op.lite",
|
||||
],
|
||||
|
@ -31,6 +31,7 @@ import java.util.Arrays;
|
||||
* not needed to be closed by the client. However, once the {@code NativeInterpreterWrapper} has
|
||||
* been closed, the tensor handle will be invalidated.
|
||||
*/
|
||||
// TODO(b/153882978): Add scalar getters similar to TF's Java API.
|
||||
public final class Tensor {
|
||||
|
||||
/**
|
||||
@ -187,8 +188,10 @@ public final class Tensor {
|
||||
throwIfDataIsIncompatible(src);
|
||||
if (isBuffer(src)) {
|
||||
setTo((Buffer) src);
|
||||
} else {
|
||||
} else if (src.getClass().isArray()) {
|
||||
writeMultiDimensionalArray(nativeHandle, src);
|
||||
} else {
|
||||
writeScalar(nativeHandle, src);
|
||||
}
|
||||
}
|
||||
|
||||
@ -300,19 +303,39 @@ public final class Tensor {
|
||||
static DataType dataTypeOf(Object o) {
|
||||
if (o != null) {
|
||||
Class<?> c = o.getClass();
|
||||
while (c.isArray()) {
|
||||
c = c.getComponentType();
|
||||
}
|
||||
if (float.class.equals(c) || o instanceof FloatBuffer) {
|
||||
return DataType.FLOAT32;
|
||||
} else if (int.class.equals(c) || o instanceof IntBuffer) {
|
||||
return DataType.INT32;
|
||||
} else if (byte.class.equals(c)) {
|
||||
return DataType.UINT8;
|
||||
} else if (long.class.equals(c) || o instanceof LongBuffer) {
|
||||
return DataType.INT64;
|
||||
} else if (String.class.equals(c)) {
|
||||
return DataType.STRING;
|
||||
// For arrays, the data elements must be a *primitive* type, e.g., an
|
||||
// array of floats is fine, but not an array of Floats.
|
||||
if (c.isArray()) {
|
||||
while (c.isArray()) {
|
||||
c = c.getComponentType();
|
||||
}
|
||||
if (float.class.equals(c)) {
|
||||
return DataType.FLOAT32;
|
||||
} else if (int.class.equals(c)) {
|
||||
return DataType.INT32;
|
||||
} else if (byte.class.equals(c)) {
|
||||
return DataType.UINT8;
|
||||
} else if (long.class.equals(c)) {
|
||||
return DataType.INT64;
|
||||
} else if (String.class.equals(c)) {
|
||||
return DataType.STRING;
|
||||
}
|
||||
} else {
|
||||
// For scalars, the type will be boxed.
|
||||
if (Float.class.equals(c) || o instanceof FloatBuffer) {
|
||||
return DataType.FLOAT32;
|
||||
} else if (Integer.class.equals(c) || o instanceof IntBuffer) {
|
||||
return DataType.INT32;
|
||||
} else if (Byte.class.equals(c)) {
|
||||
// Note that we don't check for ByteBuffer here; ByteBuffer payloads
|
||||
// are allowed to map to any type, and should be handled earlier
|
||||
// in the input/output processing pipeline.
|
||||
return DataType.UINT8;
|
||||
} else if (Long.class.equals(c) || o instanceof LongBuffer) {
|
||||
return DataType.INT64;
|
||||
} else if (String.class.equals(c)) {
|
||||
return DataType.STRING;
|
||||
}
|
||||
}
|
||||
}
|
||||
throw new IllegalArgumentException(
|
||||
@ -466,6 +489,8 @@ public final class Tensor {
|
||||
|
||||
private static native void writeMultiDimensionalArray(long handle, Object src);
|
||||
|
||||
private static native void writeScalar(long handle, Object src);
|
||||
|
||||
private static native int index(long handle);
|
||||
|
||||
private static native String name(long handle);
|
||||
|
@ -81,10 +81,16 @@ size_t ElementByteSize(TfLiteType data_type) {
|
||||
"Interal error: Java int not compatible with kTfLiteInt");
|
||||
return 4;
|
||||
case kTfLiteUInt8:
|
||||
case kTfLiteInt8:
|
||||
static_assert(sizeof(jbyte) == 1,
|
||||
"Interal error: Java byte not compatible with "
|
||||
"kTfLiteUInt8");
|
||||
return 1;
|
||||
case kTfLiteBool:
|
||||
static_assert(sizeof(jboolean) == 1,
|
||||
"Interal error: Java boolean not compatible with "
|
||||
"kTfLiteBool");
|
||||
return 1;
|
||||
case kTfLiteInt64:
|
||||
static_assert(sizeof(jlong) == 8,
|
||||
"Interal error: Java long not compatible with "
|
||||
@ -265,6 +271,15 @@ size_t WriteMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type,
|
||||
}
|
||||
}
|
||||
|
||||
void AddStringDynamicBuffer(JNIEnv* env, jstring src,
|
||||
tflite::DynamicBuffer* dst_buffer) {
|
||||
const char* chars = env->GetStringUTFChars(src, nullptr);
|
||||
// + 1 for terminating character.
|
||||
const int byte_len = env->GetStringUTFLength(src) + 1;
|
||||
dst_buffer->AddString(chars, byte_len);
|
||||
env->ReleaseStringUTFChars(src, chars);
|
||||
}
|
||||
|
||||
void PopulateStringDynamicBuffer(JNIEnv* env, jobject src,
|
||||
tflite::DynamicBuffer* dst_buffer,
|
||||
int dims_left) {
|
||||
@ -277,11 +292,7 @@ void PopulateStringDynamicBuffer(JNIEnv* env, jobject src,
|
||||
for (int i = 0; i < num_elements; ++i) {
|
||||
jstring string_obj =
|
||||
static_cast<jstring>(env->GetObjectArrayElement(object_array, i));
|
||||
const char* chars = env->GetStringUTFChars(string_obj, nullptr);
|
||||
// + 1 for terminating character.
|
||||
const int byte_len = env->GetStringUTFLength(string_obj) + 1;
|
||||
dst_buffer->AddString(chars, byte_len);
|
||||
env->ReleaseStringUTFChars(string_obj, chars);
|
||||
AddStringDynamicBuffer(env, string_obj, dst_buffer);
|
||||
env->DeleteLocalRef(string_obj);
|
||||
}
|
||||
} else {
|
||||
@ -303,6 +314,56 @@ void WriteMultiDimensionalStringArray(JNIEnv* env, jobject src,
|
||||
}
|
||||
}
|
||||
|
||||
void WriteScalar(JNIEnv* env, jobject src, TfLiteType type, void* dst,
|
||||
int dst_size) {
|
||||
size_t src_size = ElementByteSize(type);
|
||||
if (src_size != dst_size) {
|
||||
ThrowException(
|
||||
env, kIllegalStateException,
|
||||
"Scalar (%d bytes) not compatible with allocated tensor (%d bytes)",
|
||||
src_size, dst_size);
|
||||
return;
|
||||
}
|
||||
switch (type) {
|
||||
// env->FindClass and env->GetMethodID are expensive and JNI best practices
|
||||
// suggest that they should be cached. However, until the creation of scalar
|
||||
// valued tensors seems to become a noticeable fraction of program execution,
|
||||
// ignore that cost.
|
||||
#define CASE(type, jtype, method_name, method_signature, call_type) \
|
||||
case type: { \
|
||||
jclass clazz = env->FindClass("java/lang/Number"); \
|
||||
jmethodID method = env->GetMethodID(clazz, method_name, method_signature); \
|
||||
jtype v = env->Call##call_type##Method(src, method); \
|
||||
memcpy(dst, &v, src_size); \
|
||||
return; \
|
||||
}
|
||||
CASE(kTfLiteFloat32, jfloat, "floatValue", "()F", Float);
|
||||
CASE(kTfLiteInt32, jint, "intValue", "()I", Int);
|
||||
CASE(kTfLiteInt64, jlong, "longValue", "()J", Long);
|
||||
CASE(kTfLiteInt8, jbyte, "byteValue", "()B", Byte);
|
||||
CASE(kTfLiteUInt8, jbyte, "byteValue", "()B", Byte);
|
||||
#undef CASE
|
||||
case kTfLiteBool: {
|
||||
jclass clazz = env->FindClass("java/lang/Boolean");
|
||||
jmethodID method = env->GetMethodID(clazz, "booleanValue", "()Z");
|
||||
jboolean v = env->CallBooleanMethod(src, method);
|
||||
*(static_cast<unsigned char*>(dst)) = v ? 1 : 0;
|
||||
return;
|
||||
}
|
||||
default:
|
||||
ThrowException(env, kIllegalStateException, "Invalid DataType(%d)", type);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
void WriteScalarString(JNIEnv* env, jobject src, TfLiteTensor* tensor) {
|
||||
tflite::DynamicBuffer dst_buffer;
|
||||
AddStringDynamicBuffer(env, static_cast<jstring>(src), &dst_buffer);
|
||||
if (!env->ExceptionCheck()) {
|
||||
dst_buffer.WriteToTensor(tensor, /*new_shape=*/nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
#ifdef __cplusplus
|
||||
@ -399,6 +460,28 @@ Java_org_tensorflow_lite_Tensor_writeMultiDimensionalArray(JNIEnv* env,
|
||||
}
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_writeScalar(
|
||||
JNIEnv* env, jclass clazz, jlong handle, jobject src) {
|
||||
TfLiteTensor* tensor = GetTensorFromHandle(env, handle);
|
||||
if (tensor == nullptr) return;
|
||||
if ((tensor->type != kTfLiteString) && (tensor->data.raw == nullptr)) {
|
||||
ThrowException(env, kIllegalArgumentException,
|
||||
"Internal error: Target Tensor hasn't been allocated.");
|
||||
return;
|
||||
}
|
||||
if ((tensor->dims->size != 0) && (tensor->dims->data[0] != 1)) {
|
||||
ThrowException(env, kIllegalArgumentException,
|
||||
"Internal error: Cannot write Java scalar to non-scalar "
|
||||
"Tensor.");
|
||||
return;
|
||||
}
|
||||
if (tensor->type == kTfLiteString) {
|
||||
WriteScalarString(env, src, tensor);
|
||||
} else {
|
||||
WriteScalar(env, src, tensor->type, tensor->data.data, tensor->bytes);
|
||||
}
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env,
|
||||
jclass clazz,
|
||||
jlong handle) {
|
||||
|
@ -22,6 +22,7 @@ import java.io.File;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.nio.DoubleBuffer;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.junit.Test;
|
||||
@ -209,6 +210,15 @@ public final class InterpreterTest {
|
||||
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRunWithScalarInput() {
|
||||
FloatBuffer parsedOutput = FloatBuffer.allocate(1);
|
||||
try (Interpreter interpreter = new Interpreter(MODEL_BUFFER)) {
|
||||
interpreter.run(2.37f, parsedOutput);
|
||||
}
|
||||
assertThat(parsedOutput.get(0)).isWithin(0.1f).of(7.11f);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testResizeInput() {
|
||||
try (Interpreter interpreter = new Interpreter(MODEL_BUFFER)) {
|
||||
|
@ -46,6 +46,9 @@ public final class NativeInterpreterWrapperTest {
|
||||
private static final String STRING_MODEL_PATH =
|
||||
"tensorflow/lite/java/src/testdata/string.bin";
|
||||
|
||||
private static final String STRING_SCALAR_MODEL_PATH =
|
||||
"tensorflow/lite/java/src/testdata/string_scalar.bin";
|
||||
|
||||
private static final String INVALID_MODEL_PATH =
|
||||
"tensorflow/lite/java/src/testdata/invalid_model.bin";
|
||||
|
||||
@ -245,6 +248,20 @@ public final class NativeInterpreterWrapperTest {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRunWithScalarString() {
|
||||
try (NativeInterpreterWrapper wrapper =
|
||||
new NativeInterpreterWrapper(STRING_SCALAR_MODEL_PATH)) {
|
||||
String[] parsedOutputs = new String[1];
|
||||
Map<Integer, Object> outputs = new HashMap<>();
|
||||
outputs.put(0, parsedOutputs);
|
||||
Object[] inputs = {"s1"};
|
||||
wrapper.run(inputs, outputs);
|
||||
String[] expected = {"s1"};
|
||||
assertThat(parsedOutputs).isEqualTo(expected);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRunWithString_supplementaryUnicodeCharacters() {
|
||||
try (NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(STRING_MODEL_PATH)) {
|
||||
|
@ -242,6 +242,22 @@ public final class TensorTest {
|
||||
tensor.setTo(inputFloatBuffer);
|
||||
tensor.copyTo(output);
|
||||
assertThat(output[0][0][0][0]).isEqualTo(5.0f);
|
||||
|
||||
// Assign from scalar float.
|
||||
wrapper.resizeInput(0, new int[0]);
|
||||
wrapper.allocateTensors();
|
||||
float scalar = 5.0f;
|
||||
tensor.setTo(scalar);
|
||||
FloatBuffer outputScalar = FloatBuffer.allocate(1);
|
||||
tensor.copyTo(outputScalar);
|
||||
assertThat(outputScalar.get(0)).isEqualTo(5.0f);
|
||||
|
||||
// Assign from boxed scalar Float.
|
||||
Float boxedScalar = 9.0f;
|
||||
tensor.setTo(boxedScalar);
|
||||
outputScalar = FloatBuffer.allocate(1);
|
||||
tensor.copyTo(outputScalar);
|
||||
assertThat(outputScalar.get(0)).isEqualTo(9.0f);
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -374,6 +390,9 @@ public final class TensorTest {
|
||||
float[][][][] differentShapeInput = new float[1][8][8][3];
|
||||
assertThat(tensor.getInputShapeIfDifferent(differentShapeInput))
|
||||
.isEqualTo(new int[] {1, 8, 8, 3});
|
||||
|
||||
Float differentShapeInputScalar = 5.0f;
|
||||
assertThat(tensor.getInputShapeIfDifferent(differentShapeInputScalar)).isEqualTo(new int[] {});
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -390,6 +409,9 @@ public final class TensorTest {
|
||||
FloatBuffer testFloatBuffer = FloatBuffer.allocate(1);
|
||||
dataType = Tensor.dataTypeOf(testFloatBuffer);
|
||||
assertThat(dataType).isEqualTo(DataType.FLOAT32);
|
||||
float testFloat = 1.0f;
|
||||
dataType = Tensor.dataTypeOf(testFloat);
|
||||
assertThat(dataType).isEqualTo(DataType.FLOAT32);
|
||||
try {
|
||||
double[] testDoubleArray = {0.783, 0.251};
|
||||
Tensor.dataTypeOf(testDoubleArray);
|
||||
@ -445,6 +467,20 @@ public final class TensorTest {
|
||||
assertThat(shape[2]).isEqualTo(1);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCopyToScalarUnsupported() {
|
||||
wrapper.resizeInput(0, new int[0]);
|
||||
wrapper.allocateTensors();
|
||||
tensor.setTo(5.0f);
|
||||
Float outputScalar = 7.0f;
|
||||
try {
|
||||
tensor.copyTo(outputScalar);
|
||||
fail();
|
||||
} catch (IllegalArgumentException e) {
|
||||
// Expected failure.
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testUseAfterClose() {
|
||||
tensor.close();
|
||||
|
BIN
tensorflow/lite/java/src/testdata/string_scalar.bin
vendored
Normal file
BIN
tensorflow/lite/java/src/testdata/string_scalar.bin
vendored
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user