Make string tensor accept byte array inputs for ParseExample op in Java API
PiperOrigin-RevId: 322446486 Change-Id: I1fec0d4f1dd645972fbf04f0647c07731cfc87fa
This commit is contained in:
parent
e48b48ffd2
commit
cc532b863f
@ -325,6 +325,7 @@ java_test(
|
|||||||
"src/testdata/int32.bin",
|
"src/testdata/int32.bin",
|
||||||
"src/testdata/int64.bin",
|
"src/testdata/int64.bin",
|
||||||
"src/testdata/quantized.bin",
|
"src/testdata/quantized.bin",
|
||||||
|
"src/testdata/string.bin",
|
||||||
],
|
],
|
||||||
javacopts = JAVACOPTS,
|
javacopts = JAVACOPTS,
|
||||||
tags = [
|
tags = [
|
||||||
|
@ -302,7 +302,7 @@ public final class Tensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/** Returns the type of the data. */
|
/** Returns the type of the data. */
|
||||||
static DataType dataTypeOf(Object o) {
|
DataType dataTypeOf(Object o) {
|
||||||
if (o != null) {
|
if (o != null) {
|
||||||
Class<?> c = o.getClass();
|
Class<?> c = o.getClass();
|
||||||
// For arrays, the data elements must be a *primitive* type, e.g., an
|
// For arrays, the data elements must be a *primitive* type, e.g., an
|
||||||
@ -316,6 +316,10 @@ public final class Tensor {
|
|||||||
} else if (int.class.equals(c)) {
|
} else if (int.class.equals(c)) {
|
||||||
return DataType.INT32;
|
return DataType.INT32;
|
||||||
} else if (byte.class.equals(c)) {
|
} else if (byte.class.equals(c)) {
|
||||||
|
// Byte arrays can be used to store string tensors, e.g. as input to the ParseExample op.
|
||||||
|
if (dtype == DataType.STRING) {
|
||||||
|
return DataType.STRING;
|
||||||
|
}
|
||||||
return DataType.UINT8;
|
return DataType.UINT8;
|
||||||
} else if (long.class.equals(c)) {
|
} else if (long.class.equals(c)) {
|
||||||
return DataType.INT64;
|
return DataType.INT64;
|
||||||
@ -345,8 +349,21 @@ public final class Tensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/** Returns the shape of an object as an int array. */
|
/** Returns the shape of an object as an int array. */
|
||||||
static int[] computeShapeOf(Object o) {
|
int[] computeShapeOf(Object o) {
|
||||||
int size = computeNumDimensions(o);
|
int size = computeNumDimensions(o);
|
||||||
|
if (dtype == DataType.STRING) {
|
||||||
|
Class<?> c = o.getClass();
|
||||||
|
if (c.isArray()) {
|
||||||
|
while (c.isArray()) {
|
||||||
|
c = c.getComponentType();
|
||||||
|
}
|
||||||
|
// If the given string data is stored as byte arrays, the innermost array dimension holds
|
||||||
|
// the bytes of one string value, so it is not counted as a tensor dimension.
|
||||||
|
if (byte.class.equals(c)) {
|
||||||
|
--size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
int[] dimensions = new int[size];
|
int[] dimensions = new int[size];
|
||||||
fillShape(o, 0, dimensions);
|
fillShape(o, 0, dimensions);
|
||||||
return dimensions;
|
return dimensions;
|
||||||
|
@ -28,6 +28,9 @@ using tflite::jni::ThrowException;
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
static const char* kByteArrayClassPath = "[B";
|
||||||
|
static const char* kStringClassPath = "java/lang/String";
|
||||||
|
|
||||||
// Convenience handle for obtaining a TfLiteTensor given an interpreter and
|
// Convenience handle for obtaining a TfLiteTensor given an interpreter and
|
||||||
// tensor index.
|
// tensor index.
|
||||||
//
|
//
|
||||||
@ -271,13 +274,24 @@ size_t WriteMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void AddStringDynamicBuffer(JNIEnv* env, jstring src,
|
void AddStringDynamicBuffer(JNIEnv* env, jobject src,
|
||||||
tflite::DynamicBuffer* dst_buffer) {
|
tflite::DynamicBuffer* dst_buffer) {
|
||||||
const char* chars = env->GetStringUTFChars(src, nullptr);
|
if (env->IsInstanceOf(src, env->FindClass(kStringClassPath))) {
|
||||||
|
jstring str = static_cast<jstring>(src);
|
||||||
|
const char* chars = env->GetStringUTFChars(str, nullptr);
|
||||||
// + 1 for terminating character.
|
// + 1 for terminating character.
|
||||||
const int byte_len = env->GetStringUTFLength(src) + 1;
|
const int byte_len = env->GetStringUTFLength(str) + 1;
|
||||||
dst_buffer->AddString(chars, byte_len);
|
dst_buffer->AddString(chars, byte_len);
|
||||||
env->ReleaseStringUTFChars(src, chars);
|
env->ReleaseStringUTFChars(str, chars);
|
||||||
|
}
|
||||||
|
if (env->IsInstanceOf(src, env->FindClass(kByteArrayClassPath))) {
|
||||||
|
jbyteArray byte_array = static_cast<jbyteArray>(src);
|
||||||
|
jsize byte_array_length = env->GetArrayLength(byte_array);
|
||||||
|
jbyte* bytes = env->GetByteArrayElements(byte_array, nullptr);
|
||||||
|
dst_buffer->AddString(reinterpret_cast<const char*>(bytes),
|
||||||
|
byte_array_length);
|
||||||
|
env->ReleaseByteArrayElements(byte_array, bytes, JNI_ABORT);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void PopulateStringDynamicBuffer(JNIEnv* env, jobject src,
|
void PopulateStringDynamicBuffer(JNIEnv* env, jobject src,
|
||||||
@ -290,10 +304,9 @@ void PopulateStringDynamicBuffer(JNIEnv* env, jobject src,
|
|||||||
// recursively call PopulateStringDynamicBuffer over sub-dimensions.
|
// recursively call PopulateStringDynamicBuffer over sub-dimensions.
|
||||||
if (dims_left <= 1) {
|
if (dims_left <= 1) {
|
||||||
for (int i = 0; i < num_elements; ++i) {
|
for (int i = 0; i < num_elements; ++i) {
|
||||||
jstring string_obj =
|
jobject obj = env->GetObjectArrayElement(object_array, i);
|
||||||
static_cast<jstring>(env->GetObjectArrayElement(object_array, i));
|
AddStringDynamicBuffer(env, obj, dst_buffer);
|
||||||
AddStringDynamicBuffer(env, string_obj, dst_buffer);
|
env->DeleteLocalRef(obj);
|
||||||
env->DeleteLocalRef(string_obj);
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (int i = 0; i < num_elements; ++i) {
|
for (int i = 0; i < num_elements; ++i) {
|
||||||
@ -358,7 +371,7 @@ void WriteScalar(JNIEnv* env, jobject src, TfLiteType type, void* dst,
|
|||||||
|
|
||||||
void WriteScalarString(JNIEnv* env, jobject src, TfLiteTensor* tensor) {
|
void WriteScalarString(JNIEnv* env, jobject src, TfLiteTensor* tensor) {
|
||||||
tflite::DynamicBuffer dst_buffer;
|
tflite::DynamicBuffer dst_buffer;
|
||||||
AddStringDynamicBuffer(env, static_cast<jstring>(src), &dst_buffer);
|
AddStringDynamicBuffer(env, src, &dst_buffer);
|
||||||
if (!env->ExceptionCheck()) {
|
if (!env->ExceptionCheck()) {
|
||||||
dst_buffer.WriteToTensor(tensor, /*new_shape=*/nullptr);
|
dst_buffer.WriteToTensor(tensor, /*new_shape=*/nullptr);
|
||||||
}
|
}
|
||||||
|
@ -46,6 +46,9 @@ public final class TensorTest {
|
|||||||
private static final String LONG_MODEL_PATH =
|
private static final String LONG_MODEL_PATH =
|
||||||
"tensorflow/lite/java/src/testdata/int64.bin";
|
"tensorflow/lite/java/src/testdata/int64.bin";
|
||||||
|
|
||||||
|
private static final String STRING_MODEL_PATH =
|
||||||
|
"tensorflow/lite/java/src/testdata/string.bin";
|
||||||
|
|
||||||
private static final String QUANTIZED_MODEL_PATH =
|
private static final String QUANTIZED_MODEL_PATH =
|
||||||
"tensorflow/lite/java/src/testdata/quantized.bin";
|
"tensorflow/lite/java/src/testdata/quantized.bin";
|
||||||
|
|
||||||
@ -412,30 +415,30 @@ public final class TensorTest {
|
|||||||
@Test
|
@Test
|
||||||
public void testDataTypeOf() {
|
public void testDataTypeOf() {
|
||||||
float[] testEmptyArray = {};
|
float[] testEmptyArray = {};
|
||||||
DataType dataType = Tensor.dataTypeOf(testEmptyArray);
|
DataType dataType = tensor.dataTypeOf(testEmptyArray);
|
||||||
assertThat(dataType).isEqualTo(DataType.FLOAT32);
|
assertThat(dataType).isEqualTo(DataType.FLOAT32);
|
||||||
float[] testFloatArray = {0.783f, 0.251f};
|
float[] testFloatArray = {0.783f, 0.251f};
|
||||||
dataType = Tensor.dataTypeOf(testFloatArray);
|
dataType = tensor.dataTypeOf(testFloatArray);
|
||||||
assertThat(dataType).isEqualTo(DataType.FLOAT32);
|
assertThat(dataType).isEqualTo(DataType.FLOAT32);
|
||||||
float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray};
|
float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray};
|
||||||
dataType = Tensor.dataTypeOf(testMultiDimArray);
|
dataType = tensor.dataTypeOf(testMultiDimArray);
|
||||||
assertThat(dataType).isEqualTo(DataType.FLOAT32);
|
assertThat(dataType).isEqualTo(DataType.FLOAT32);
|
||||||
FloatBuffer testFloatBuffer = FloatBuffer.allocate(1);
|
FloatBuffer testFloatBuffer = FloatBuffer.allocate(1);
|
||||||
dataType = Tensor.dataTypeOf(testFloatBuffer);
|
dataType = tensor.dataTypeOf(testFloatBuffer);
|
||||||
assertThat(dataType).isEqualTo(DataType.FLOAT32);
|
assertThat(dataType).isEqualTo(DataType.FLOAT32);
|
||||||
float testFloat = 1.0f;
|
float testFloat = 1.0f;
|
||||||
dataType = Tensor.dataTypeOf(testFloat);
|
dataType = tensor.dataTypeOf(testFloat);
|
||||||
assertThat(dataType).isEqualTo(DataType.FLOAT32);
|
assertThat(dataType).isEqualTo(DataType.FLOAT32);
|
||||||
try {
|
try {
|
||||||
double[] testDoubleArray = {0.783, 0.251};
|
double[] testDoubleArray = {0.783, 0.251};
|
||||||
Tensor.dataTypeOf(testDoubleArray);
|
tensor.dataTypeOf(testDoubleArray);
|
||||||
fail();
|
fail();
|
||||||
} catch (IllegalArgumentException e) {
|
} catch (IllegalArgumentException e) {
|
||||||
assertThat(e).hasMessageThat().contains("cannot resolve DataType of");
|
assertThat(e).hasMessageThat().contains("cannot resolve DataType of");
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
Float[] testBoxedArray = {0.783f, 0.251f};
|
Float[] testBoxedArray = {0.783f, 0.251f};
|
||||||
Tensor.dataTypeOf(testBoxedArray);
|
tensor.dataTypeOf(testBoxedArray);
|
||||||
fail();
|
fail();
|
||||||
} catch (IllegalArgumentException e) {
|
} catch (IllegalArgumentException e) {
|
||||||
assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;");
|
assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;");
|
||||||
@ -528,4 +531,15 @@ public final class TensorTest {
|
|||||||
assertThat(scale).isWithin(1e-6f).of(0.25f);
|
assertThat(scale).isWithin(1e-6f).of(0.25f);
|
||||||
assertThat(zeroPoint).isEqualTo(127);
|
assertThat(zeroPoint).isEqualTo(127);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testByteArrayStringTensorInput() {
|
||||||
|
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(STRING_MODEL_PATH);
|
||||||
|
wrapper.resizeInput(0, new int[] {1});
|
||||||
|
Tensor stringTensor = wrapper.getInputTensor(0);
|
||||||
|
|
||||||
|
byte[][] byteArray = new byte[][] {new byte[1]};
|
||||||
|
assertThat(stringTensor.dataTypeOf(byteArray)).isEqualTo(DataType.STRING);
|
||||||
|
assertThat(stringTensor.shape()).isEqualTo(new int[] {1});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user