Make string tensor accept byte array inputs for ParseExample op in Java API

PiperOrigin-RevId: 322446486
Change-Id: I1fec0d4f1dd645972fbf04f0647c07731cfc87fa
This commit is contained in:
Jaesung Chung 2020-07-21 14:44:51 -07:00 committed by TensorFlower Gardener
parent e48b48ffd2
commit cc532b863f
4 changed files with 65 additions and 20 deletions

View File

@ -325,6 +325,7 @@ java_test(
"src/testdata/int32.bin", "src/testdata/int32.bin",
"src/testdata/int64.bin", "src/testdata/int64.bin",
"src/testdata/quantized.bin", "src/testdata/quantized.bin",
"src/testdata/string.bin",
], ],
javacopts = JAVACOPTS, javacopts = JAVACOPTS,
tags = [ tags = [

View File

@ -302,7 +302,7 @@ public final class Tensor {
} }
/** Returns the type of the data. */ /** Returns the type of the data. */
static DataType dataTypeOf(Object o) { DataType dataTypeOf(Object o) {
if (o != null) { if (o != null) {
Class<?> c = o.getClass(); Class<?> c = o.getClass();
// For arrays, the data elements must be a *primitive* type, e.g., an // For arrays, the data elements must be a *primitive* type, e.g., an
@ -316,6 +316,10 @@ public final class Tensor {
} else if (int.class.equals(c)) { } else if (int.class.equals(c)) {
return DataType.INT32; return DataType.INT32;
} else if (byte.class.equals(c)) { } else if (byte.class.equals(c)) {
// Byte array can be used for storing string tensors, especially for ParseExample op.
if (dtype == DataType.STRING) {
return DataType.STRING;
}
return DataType.UINT8; return DataType.UINT8;
} else if (long.class.equals(c)) { } else if (long.class.equals(c)) {
return DataType.INT64; return DataType.INT64;
@ -345,8 +349,21 @@ public final class Tensor {
} }
/** Returns the shape of an object as an int array. */ /** Returns the shape of an object as an int array. */
static int[] computeShapeOf(Object o) { int[] computeShapeOf(Object o) {
int size = computeNumDimensions(o); int size = computeNumDimensions(o);
if (dtype == DataType.STRING) {
Class<?> c = o.getClass();
if (c.isArray()) {
while (c.isArray()) {
c = c.getComponentType();
}
// If the given string data is stored in byte streams, the last array dimension should be
// treated as a value.
if (byte.class.equals(c)) {
--size;
}
}
}
int[] dimensions = new int[size]; int[] dimensions = new int[size];
fillShape(o, 0, dimensions); fillShape(o, 0, dimensions);
return dimensions; return dimensions;

View File

@ -28,6 +28,9 @@ using tflite::jni::ThrowException;
namespace { namespace {
static const char* kByteArrayClassPath = "[B";
static const char* kStringClassPath = "java/lang/String";
// Convenience handle for obtaining a TfLiteTensor given an interpreter and // Convenience handle for obtaining a TfLiteTensor given an interpreter and
// tensor index. // tensor index.
// //
@ -271,13 +274,24 @@ size_t WriteMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type,
} }
} }
void AddStringDynamicBuffer(JNIEnv* env, jstring src, void AddStringDynamicBuffer(JNIEnv* env, jobject src,
tflite::DynamicBuffer* dst_buffer) { tflite::DynamicBuffer* dst_buffer) {
const char* chars = env->GetStringUTFChars(src, nullptr); if (env->IsInstanceOf(src, env->FindClass(kStringClassPath))) {
jstring str = static_cast<jstring>(src);
const char* chars = env->GetStringUTFChars(str, nullptr);
// + 1 for terminating character. // + 1 for terminating character.
const int byte_len = env->GetStringUTFLength(src) + 1; const int byte_len = env->GetStringUTFLength(str) + 1;
dst_buffer->AddString(chars, byte_len); dst_buffer->AddString(chars, byte_len);
env->ReleaseStringUTFChars(src, chars); env->ReleaseStringUTFChars(str, chars);
}
if (env->IsInstanceOf(src, env->FindClass(kByteArrayClassPath))) {
jbyteArray byte_array = static_cast<jbyteArray>(src);
jsize byte_array_length = env->GetArrayLength(byte_array);
jbyte* bytes = env->GetByteArrayElements(byte_array, nullptr);
dst_buffer->AddString(reinterpret_cast<const char*>(bytes),
byte_array_length);
env->ReleaseByteArrayElements(byte_array, bytes, JNI_ABORT);
}
} }
void PopulateStringDynamicBuffer(JNIEnv* env, jobject src, void PopulateStringDynamicBuffer(JNIEnv* env, jobject src,
@ -290,10 +304,9 @@ void PopulateStringDynamicBuffer(JNIEnv* env, jobject src,
// recursively call populateStringDynamicBuffer over sub-dimensions. // recursively call populateStringDynamicBuffer over sub-dimensions.
if (dims_left <= 1) { if (dims_left <= 1) {
for (int i = 0; i < num_elements; ++i) { for (int i = 0; i < num_elements; ++i) {
jstring string_obj = jobject obj = env->GetObjectArrayElement(object_array, i);
static_cast<jstring>(env->GetObjectArrayElement(object_array, i)); AddStringDynamicBuffer(env, obj, dst_buffer);
AddStringDynamicBuffer(env, string_obj, dst_buffer); env->DeleteLocalRef(obj);
env->DeleteLocalRef(string_obj);
} }
} else { } else {
for (int i = 0; i < num_elements; ++i) { for (int i = 0; i < num_elements; ++i) {
@ -358,7 +371,7 @@ void WriteScalar(JNIEnv* env, jobject src, TfLiteType type, void* dst,
void WriteScalarString(JNIEnv* env, jobject src, TfLiteTensor* tensor) { void WriteScalarString(JNIEnv* env, jobject src, TfLiteTensor* tensor) {
tflite::DynamicBuffer dst_buffer; tflite::DynamicBuffer dst_buffer;
AddStringDynamicBuffer(env, static_cast<jstring>(src), &dst_buffer); AddStringDynamicBuffer(env, src, &dst_buffer);
if (!env->ExceptionCheck()) { if (!env->ExceptionCheck()) {
dst_buffer.WriteToTensor(tensor, /*new_shape=*/nullptr); dst_buffer.WriteToTensor(tensor, /*new_shape=*/nullptr);
} }

View File

@ -46,6 +46,9 @@ public final class TensorTest {
private static final String LONG_MODEL_PATH = private static final String LONG_MODEL_PATH =
"tensorflow/lite/java/src/testdata/int64.bin"; "tensorflow/lite/java/src/testdata/int64.bin";
private static final String STRING_MODEL_PATH =
"tensorflow/lite/java/src/testdata/string.bin";
private static final String QUANTIZED_MODEL_PATH = private static final String QUANTIZED_MODEL_PATH =
"tensorflow/lite/java/src/testdata/quantized.bin"; "tensorflow/lite/java/src/testdata/quantized.bin";
@ -412,30 +415,30 @@ public final class TensorTest {
@Test @Test
public void testDataTypeOf() { public void testDataTypeOf() {
float[] testEmptyArray = {}; float[] testEmptyArray = {};
DataType dataType = Tensor.dataTypeOf(testEmptyArray); DataType dataType = tensor.dataTypeOf(testEmptyArray);
assertThat(dataType).isEqualTo(DataType.FLOAT32); assertThat(dataType).isEqualTo(DataType.FLOAT32);
float[] testFloatArray = {0.783f, 0.251f}; float[] testFloatArray = {0.783f, 0.251f};
dataType = Tensor.dataTypeOf(testFloatArray); dataType = tensor.dataTypeOf(testFloatArray);
assertThat(dataType).isEqualTo(DataType.FLOAT32); assertThat(dataType).isEqualTo(DataType.FLOAT32);
float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray}; float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray};
dataType = Tensor.dataTypeOf(testMultiDimArray); dataType = tensor.dataTypeOf(testMultiDimArray);
assertThat(dataType).isEqualTo(DataType.FLOAT32); assertThat(dataType).isEqualTo(DataType.FLOAT32);
FloatBuffer testFloatBuffer = FloatBuffer.allocate(1); FloatBuffer testFloatBuffer = FloatBuffer.allocate(1);
dataType = Tensor.dataTypeOf(testFloatBuffer); dataType = tensor.dataTypeOf(testFloatBuffer);
assertThat(dataType).isEqualTo(DataType.FLOAT32); assertThat(dataType).isEqualTo(DataType.FLOAT32);
float testFloat = 1.0f; float testFloat = 1.0f;
dataType = Tensor.dataTypeOf(testFloat); dataType = tensor.dataTypeOf(testFloat);
assertThat(dataType).isEqualTo(DataType.FLOAT32); assertThat(dataType).isEqualTo(DataType.FLOAT32);
try { try {
double[] testDoubleArray = {0.783, 0.251}; double[] testDoubleArray = {0.783, 0.251};
Tensor.dataTypeOf(testDoubleArray); tensor.dataTypeOf(testDoubleArray);
fail(); fail();
} catch (IllegalArgumentException e) { } catch (IllegalArgumentException e) {
assertThat(e).hasMessageThat().contains("cannot resolve DataType of"); assertThat(e).hasMessageThat().contains("cannot resolve DataType of");
} }
try { try {
Float[] testBoxedArray = {0.783f, 0.251f}; Float[] testBoxedArray = {0.783f, 0.251f};
Tensor.dataTypeOf(testBoxedArray); tensor.dataTypeOf(testBoxedArray);
fail(); fail();
} catch (IllegalArgumentException e) { } catch (IllegalArgumentException e) {
assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;"); assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;");
@ -528,4 +531,15 @@ public final class TensorTest {
assertThat(scale).isWithin(1e-6f).of(0.25f); assertThat(scale).isWithin(1e-6f).of(0.25f);
assertThat(zeroPoint).isEqualTo(127); assertThat(zeroPoint).isEqualTo(127);
} }
@Test
public void testByteArrayStringTensorInput() {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(STRING_MODEL_PATH);
wrapper.resizeInput(0, new int[] {1});
Tensor stringTensor = wrapper.getInputTensor(0);
byte[][] byteArray = new byte[][] {new byte[1]};
assertThat(stringTensor.dataTypeOf(byteArray)).isEqualTo(DataType.STRING);
assertThat(stringTensor.shape()).isEqualTo(new int[] {1});
}
} }