Make string tensors accept byte array inputs for the ParseExample op in the Java API
PiperOrigin-RevId: 322446486 Change-Id: I1fec0d4f1dd645972fbf04f0647c07731cfc87fa
This commit is contained in:
parent
e48b48ffd2
commit
cc532b863f
@ -325,6 +325,7 @@ java_test(
|
||||
"src/testdata/int32.bin",
|
||||
"src/testdata/int64.bin",
|
||||
"src/testdata/quantized.bin",
|
||||
"src/testdata/string.bin",
|
||||
],
|
||||
javacopts = JAVACOPTS,
|
||||
tags = [
|
||||
|
@ -302,7 +302,7 @@ public final class Tensor {
|
||||
}
|
||||
|
||||
/** Returns the type of the data. */
|
||||
static DataType dataTypeOf(Object o) {
|
||||
DataType dataTypeOf(Object o) {
|
||||
if (o != null) {
|
||||
Class<?> c = o.getClass();
|
||||
// For arrays, the data elements must be a *primitive* type, e.g., an
|
||||
@ -316,6 +316,10 @@ public final class Tensor {
|
||||
} else if (int.class.equals(c)) {
|
||||
return DataType.INT32;
|
||||
} else if (byte.class.equals(c)) {
|
||||
// Byte array can be used for storing string tensors, especially for ParseExample op.
|
||||
if (dtype == DataType.STRING) {
|
||||
return DataType.STRING;
|
||||
}
|
||||
return DataType.UINT8;
|
||||
} else if (long.class.equals(c)) {
|
||||
return DataType.INT64;
|
||||
@ -345,8 +349,21 @@ public final class Tensor {
|
||||
}
|
||||
|
||||
/** Returns the shape of an object as an int array. */
|
||||
static int[] computeShapeOf(Object o) {
|
||||
int[] computeShapeOf(Object o) {
|
||||
int size = computeNumDimensions(o);
|
||||
if (dtype == DataType.STRING) {
|
||||
Class<?> c = o.getClass();
|
||||
if (c.isArray()) {
|
||||
while (c.isArray()) {
|
||||
c = c.getComponentType();
|
||||
}
|
||||
// If the given string data is stored in byte streams, the last array dimension should be
|
||||
// treated as a value.
|
||||
if (byte.class.equals(c)) {
|
||||
--size;
|
||||
}
|
||||
}
|
||||
}
|
||||
int[] dimensions = new int[size];
|
||||
fillShape(o, 0, dimensions);
|
||||
return dimensions;
|
||||
|
@ -28,6 +28,9 @@ using tflite::jni::ThrowException;
|
||||
|
||||
namespace {
|
||||
|
||||
// Class names handed to JNIEnv::FindClass() to recognise the Java object types
// accepted as string tensor elements. They are declared as constexpr arrays so
// that neither the characters nor the binding can be modified at run time; the
// enclosing anonymous namespace already gives them internal linkage.

// JNI descriptor of the Java `byte[]` class.
constexpr char kByteArrayClassPath[] = "[B";
// Slash-separated name of `java.lang.String`.
constexpr char kStringClassPath[] = "java/lang/String";
|
||||
|
||||
// Convenience handle for obtaining a TfLiteTensor given an interpreter and
|
||||
// tensor index.
|
||||
//
|
||||
@ -271,13 +274,24 @@ size_t WriteMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type,
|
||||
}
|
||||
}
|
||||
|
||||
void AddStringDynamicBuffer(JNIEnv* env, jstring src,
|
||||
void AddStringDynamicBuffer(JNIEnv* env, jobject src,
|
||||
tflite::DynamicBuffer* dst_buffer) {
|
||||
const char* chars = env->GetStringUTFChars(src, nullptr);
|
||||
// + 1 for terminating character.
|
||||
const int byte_len = env->GetStringUTFLength(src) + 1;
|
||||
dst_buffer->AddString(chars, byte_len);
|
||||
env->ReleaseStringUTFChars(src, chars);
|
||||
if (env->IsInstanceOf(src, env->FindClass(kStringClassPath))) {
|
||||
jstring str = static_cast<jstring>(src);
|
||||
const char* chars = env->GetStringUTFChars(str, nullptr);
|
||||
// + 1 for terminating character.
|
||||
const int byte_len = env->GetStringUTFLength(str) + 1;
|
||||
dst_buffer->AddString(chars, byte_len);
|
||||
env->ReleaseStringUTFChars(str, chars);
|
||||
}
|
||||
if (env->IsInstanceOf(src, env->FindClass(kByteArrayClassPath))) {
|
||||
jbyteArray byte_array = static_cast<jbyteArray>(src);
|
||||
jsize byte_array_length = env->GetArrayLength(byte_array);
|
||||
jbyte* bytes = env->GetByteArrayElements(byte_array, nullptr);
|
||||
dst_buffer->AddString(reinterpret_cast<const char*>(bytes),
|
||||
byte_array_length);
|
||||
env->ReleaseByteArrayElements(byte_array, bytes, JNI_ABORT);
|
||||
}
|
||||
}
|
||||
|
||||
void PopulateStringDynamicBuffer(JNIEnv* env, jobject src,
|
||||
@ -290,10 +304,9 @@ void PopulateStringDynamicBuffer(JNIEnv* env, jobject src,
|
||||
// recursively call populateStringDynamicBuffer over sub-dimensions.
|
||||
if (dims_left <= 1) {
|
||||
for (int i = 0; i < num_elements; ++i) {
|
||||
jstring string_obj =
|
||||
static_cast<jstring>(env->GetObjectArrayElement(object_array, i));
|
||||
AddStringDynamicBuffer(env, string_obj, dst_buffer);
|
||||
env->DeleteLocalRef(string_obj);
|
||||
jobject obj = env->GetObjectArrayElement(object_array, i);
|
||||
AddStringDynamicBuffer(env, obj, dst_buffer);
|
||||
env->DeleteLocalRef(obj);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < num_elements; ++i) {
|
||||
@ -358,7 +371,7 @@ void WriteScalar(JNIEnv* env, jobject src, TfLiteType type, void* dst,
|
||||
|
||||
void WriteScalarString(JNIEnv* env, jobject src, TfLiteTensor* tensor) {
|
||||
tflite::DynamicBuffer dst_buffer;
|
||||
AddStringDynamicBuffer(env, static_cast<jstring>(src), &dst_buffer);
|
||||
AddStringDynamicBuffer(env, src, &dst_buffer);
|
||||
if (!env->ExceptionCheck()) {
|
||||
dst_buffer.WriteToTensor(tensor, /*new_shape=*/nullptr);
|
||||
}
|
||||
|
@ -46,6 +46,9 @@ public final class TensorTest {
|
||||
private static final String LONG_MODEL_PATH =
|
||||
"tensorflow/lite/java/src/testdata/int64.bin";
|
||||
|
||||
private static final String STRING_MODEL_PATH =
|
||||
"tensorflow/lite/java/src/testdata/string.bin";
|
||||
|
||||
private static final String QUANTIZED_MODEL_PATH =
|
||||
"tensorflow/lite/java/src/testdata/quantized.bin";
|
||||
|
||||
@ -412,30 +415,30 @@ public final class TensorTest {
|
||||
@Test
|
||||
public void testDataTypeOf() {
|
||||
float[] testEmptyArray = {};
|
||||
DataType dataType = Tensor.dataTypeOf(testEmptyArray);
|
||||
DataType dataType = tensor.dataTypeOf(testEmptyArray);
|
||||
assertThat(dataType).isEqualTo(DataType.FLOAT32);
|
||||
float[] testFloatArray = {0.783f, 0.251f};
|
||||
dataType = Tensor.dataTypeOf(testFloatArray);
|
||||
dataType = tensor.dataTypeOf(testFloatArray);
|
||||
assertThat(dataType).isEqualTo(DataType.FLOAT32);
|
||||
float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray};
|
||||
dataType = Tensor.dataTypeOf(testMultiDimArray);
|
||||
dataType = tensor.dataTypeOf(testMultiDimArray);
|
||||
assertThat(dataType).isEqualTo(DataType.FLOAT32);
|
||||
FloatBuffer testFloatBuffer = FloatBuffer.allocate(1);
|
||||
dataType = Tensor.dataTypeOf(testFloatBuffer);
|
||||
dataType = tensor.dataTypeOf(testFloatBuffer);
|
||||
assertThat(dataType).isEqualTo(DataType.FLOAT32);
|
||||
float testFloat = 1.0f;
|
||||
dataType = Tensor.dataTypeOf(testFloat);
|
||||
dataType = tensor.dataTypeOf(testFloat);
|
||||
assertThat(dataType).isEqualTo(DataType.FLOAT32);
|
||||
try {
|
||||
double[] testDoubleArray = {0.783, 0.251};
|
||||
Tensor.dataTypeOf(testDoubleArray);
|
||||
tensor.dataTypeOf(testDoubleArray);
|
||||
fail();
|
||||
} catch (IllegalArgumentException e) {
|
||||
assertThat(e).hasMessageThat().contains("cannot resolve DataType of");
|
||||
}
|
||||
try {
|
||||
Float[] testBoxedArray = {0.783f, 0.251f};
|
||||
Tensor.dataTypeOf(testBoxedArray);
|
||||
tensor.dataTypeOf(testBoxedArray);
|
||||
fail();
|
||||
} catch (IllegalArgumentException e) {
|
||||
assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;");
|
||||
@ -528,4 +531,15 @@ public final class TensorTest {
|
||||
assertThat(scale).isWithin(1e-6f).of(0.25f);
|
||||
assertThat(zeroPoint).isEqualTo(127);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testByteArrayStringTensorInput() {
|
||||
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(STRING_MODEL_PATH);
|
||||
wrapper.resizeInput(0, new int[] {1});
|
||||
Tensor stringTensor = wrapper.getInputTensor(0);
|
||||
|
||||
byte[][] byteArray = new byte[][] {new byte[1]};
|
||||
assertThat(stringTensor.dataTypeOf(byteArray)).isEqualTo(DataType.STRING);
|
||||
assertThat(stringTensor.shape()).isEqualTo(new int[] {1});
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user