Make string tensor accept byte array inputs for ParseExample op in Java API

PiperOrigin-RevId: 322446486
Change-Id: I1fec0d4f1dd645972fbf04f0647c07731cfc87fa
This commit is contained in:
Jaesung Chung 2020-07-21 14:44:51 -07:00 committed by TensorFlower Gardener
parent e48b48ffd2
commit cc532b863f
4 changed files with 65 additions and 20 deletions

View File

@ -325,6 +325,7 @@ java_test(
"src/testdata/int32.bin",
"src/testdata/int64.bin",
"src/testdata/quantized.bin",
"src/testdata/string.bin",
],
javacopts = JAVACOPTS,
tags = [

View File

@ -302,7 +302,7 @@ public final class Tensor {
}
/** Returns the type of the data. */
static DataType dataTypeOf(Object o) {
DataType dataTypeOf(Object o) {
if (o != null) {
Class<?> c = o.getClass();
// For arrays, the data elements must be a *primitive* type, e.g., an
@ -316,6 +316,10 @@ public final class Tensor {
} else if (int.class.equals(c)) {
return DataType.INT32;
} else if (byte.class.equals(c)) {
// A byte array may carry the data of a string tensor (e.g. for the ParseExample op).
if (dtype == DataType.STRING) {
return DataType.STRING;
}
return DataType.UINT8;
} else if (long.class.equals(c)) {
return DataType.INT64;
@ -345,8 +349,21 @@ public final class Tensor {
}
/** Returns the shape of an object as an int array. */
static int[] computeShapeOf(Object o) {
int[] computeShapeOf(Object o) {
int size = computeNumDimensions(o);
if (dtype == DataType.STRING) {
Class<?> c = o.getClass();
if (c.isArray()) {
while (c.isArray()) {
c = c.getComponentType();
}
// If the string data is supplied as byte arrays, the innermost array holds the bytes of
// one string value, so it is not counted as a tensor dimension.
if (byte.class.equals(c)) {
--size;
}
}
}
int[] dimensions = new int[size];
fillShape(o, 0, dimensions);
return dimensions;

View File

@ -28,6 +28,9 @@ using tflite::jni::ThrowException;
namespace {
static const char* kByteArrayClassPath = "[B";
static const char* kStringClassPath = "java/lang/String";
// Convenience handle for obtaining a TfLiteTensor given an interpreter and
// tensor index.
//
@ -271,13 +274,24 @@ size_t WriteMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type,
}
}
// Appends the contents of one Java object to `dst_buffer` as a single string
// entry. In the jobject version, a java.lang.String is added through its
// GetStringUTFChars encoding with one extra byte for the terminator, and a
// byte[] is added as raw bytes using its exact array length.
// NOTE(review): this listing appears to interleave the removed and added sides
// of the change (two signatures, and the jstring-only statements that use
// `src` directly are repeated inside the first `if`) -- confirm against the
// real file before editing.
void AddStringDynamicBuffer(JNIEnv* env, jstring src,
void AddStringDynamicBuffer(JNIEnv* env, jobject src,
tflite::DynamicBuffer* dst_buffer) {
const char* chars = env->GetStringUTFChars(src, nullptr);
// + 1 for terminating character.
const int byte_len = env->GetStringUTFLength(src) + 1;
dst_buffer->AddString(chars, byte_len);
env->ReleaseStringUTFChars(src, chars);
// NOTE(review): the jclass returned by FindClass is not released in this
// function, and PopulateStringDynamicBuffer calls it once per array element --
// consider DeleteLocalRef or a cached class reference; TODO confirm.
if (env->IsInstanceOf(src, env->FindClass(kStringClassPath))) {
jstring str = static_cast<jstring>(src);
const char* chars = env->GetStringUTFChars(str, nullptr);
// + 1 for terminating character.
const int byte_len = env->GetStringUTFLength(str) + 1;
dst_buffer->AddString(chars, byte_len);
env->ReleaseStringUTFChars(str, chars);
}
// Tested with a separate `if` (not `else if`), so both class checks are
// evaluated for every object.
if (env->IsInstanceOf(src, env->FindClass(kByteArrayClassPath))) {
jbyteArray byte_array = static_cast<jbyteArray>(src);
jsize byte_array_length = env->GetArrayLength(byte_array);
jbyte* bytes = env->GetByteArrayElements(byte_array, nullptr);
dst_buffer->AddString(reinterpret_cast<const char*>(bytes),
byte_array_length);
// JNI_ABORT: the elements were only read, so nothing is copied back to the
// Java array on release.
env->ReleaseByteArrayElements(byte_array, bytes, JNI_ABORT);
}
}
void PopulateStringDynamicBuffer(JNIEnv* env, jobject src,
@ -290,10 +304,9 @@ void PopulateStringDynamicBuffer(JNIEnv* env, jobject src,
// recursively call populateStringDynamicBuffer over sub-dimensions.
if (dims_left <= 1) {
for (int i = 0; i < num_elements; ++i) {
jstring string_obj =
static_cast<jstring>(env->GetObjectArrayElement(object_array, i));
AddStringDynamicBuffer(env, string_obj, dst_buffer);
env->DeleteLocalRef(string_obj);
jobject obj = env->GetObjectArrayElement(object_array, i);
AddStringDynamicBuffer(env, obj, dst_buffer);
env->DeleteLocalRef(obj);
}
} else {
for (int i = 0; i < num_elements; ++i) {
@ -358,7 +371,7 @@ void WriteScalar(JNIEnv* env, jobject src, TfLiteType type, void* dst,
void WriteScalarString(JNIEnv* env, jobject src, TfLiteTensor* tensor) {
tflite::DynamicBuffer dst_buffer;
AddStringDynamicBuffer(env, static_cast<jstring>(src), &dst_buffer);
AddStringDynamicBuffer(env, src, &dst_buffer);
if (!env->ExceptionCheck()) {
dst_buffer.WriteToTensor(tensor, /*new_shape=*/nullptr);
}

View File

@ -46,6 +46,9 @@ public final class TensorTest {
private static final String LONG_MODEL_PATH =
"tensorflow/lite/java/src/testdata/int64.bin";
private static final String STRING_MODEL_PATH =
"tensorflow/lite/java/src/testdata/string.bin";
private static final String QUANTIZED_MODEL_PATH =
"tensorflow/lite/java/src/testdata/quantized.bin";
@ -412,30 +415,30 @@ public final class TensorTest {
@Test
public void testDataTypeOf() {
float[] testEmptyArray = {};
DataType dataType = Tensor.dataTypeOf(testEmptyArray);
DataType dataType = tensor.dataTypeOf(testEmptyArray);
assertThat(dataType).isEqualTo(DataType.FLOAT32);
float[] testFloatArray = {0.783f, 0.251f};
dataType = Tensor.dataTypeOf(testFloatArray);
dataType = tensor.dataTypeOf(testFloatArray);
assertThat(dataType).isEqualTo(DataType.FLOAT32);
float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray};
dataType = Tensor.dataTypeOf(testMultiDimArray);
dataType = tensor.dataTypeOf(testMultiDimArray);
assertThat(dataType).isEqualTo(DataType.FLOAT32);
FloatBuffer testFloatBuffer = FloatBuffer.allocate(1);
dataType = Tensor.dataTypeOf(testFloatBuffer);
dataType = tensor.dataTypeOf(testFloatBuffer);
assertThat(dataType).isEqualTo(DataType.FLOAT32);
float testFloat = 1.0f;
dataType = Tensor.dataTypeOf(testFloat);
dataType = tensor.dataTypeOf(testFloat);
assertThat(dataType).isEqualTo(DataType.FLOAT32);
try {
double[] testDoubleArray = {0.783, 0.251};
Tensor.dataTypeOf(testDoubleArray);
tensor.dataTypeOf(testDoubleArray);
fail();
} catch (IllegalArgumentException e) {
assertThat(e).hasMessageThat().contains("cannot resolve DataType of");
}
try {
Float[] testBoxedArray = {0.783f, 0.251f};
Tensor.dataTypeOf(testBoxedArray);
tensor.dataTypeOf(testBoxedArray);
fail();
} catch (IllegalArgumentException e) {
assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;");
@ -528,4 +531,15 @@ public final class TensorTest {
assertThat(scale).isWithin(1e-6f).of(0.25f);
assertThat(zeroPoint).isEqualTo(127);
}
/**
 * Verifies that the input tensor of the string model resolves a {@code byte[][]} input to
 * {@link DataType#STRING} and reports the shape {@code {1}} set by the preceding resize.
 */
@Test
public void testByteArrayStringTensorInput() {
  // The wrapper is created only for this test, so close it here (even when an assertion
  // fails) instead of leaking its native resources.
  try (NativeInterpreterWrapper stringWrapper = new NativeInterpreterWrapper(STRING_MODEL_PATH)) {
    stringWrapper.resizeInput(0, new int[] {1});
    Tensor stringTensor = stringWrapper.getInputTensor(0);

    // One string value, supplied as a one-byte array.
    byte[][] byteArray = new byte[][] {new byte[1]};
    assertThat(stringTensor.dataTypeOf(byteArray)).isEqualTo(DataType.STRING);
    assertThat(stringTensor.shape()).isEqualTo(new int[] {1});
  }
}
}