Make string tensor accept byte array inputs for ParseExample op in Java API
PiperOrigin-RevId: 322446486 Change-Id: I1fec0d4f1dd645972fbf04f0647c07731cfc87fa
This commit is contained in:
		
							parent
							
								
									e48b48ffd2
								
							
						
					
					
						commit
						cc532b863f
					
				| @ -325,6 +325,7 @@ java_test( | ||||
|         "src/testdata/int32.bin", | ||||
|         "src/testdata/int64.bin", | ||||
|         "src/testdata/quantized.bin", | ||||
|         "src/testdata/string.bin", | ||||
|     ], | ||||
|     javacopts = JAVACOPTS, | ||||
|     tags = [ | ||||
|  | ||||
| @ -302,7 +302,7 @@ public final class Tensor { | ||||
|   } | ||||
| 
 | ||||
|   /** Returns the type of the data. */ | ||||
|   static DataType dataTypeOf(Object o) { | ||||
|   DataType dataTypeOf(Object o) { | ||||
|     if (o != null) { | ||||
|       Class<?> c = o.getClass(); | ||||
|       // For arrays, the data elements must be a *primitive* type, e.g., an | ||||
| @ -316,6 +316,10 @@ public final class Tensor { | ||||
|         } else if (int.class.equals(c)) { | ||||
|           return DataType.INT32; | ||||
|         } else if (byte.class.equals(c)) { | ||||
|           // Byte array can be used for storing string tensors, especially for ParseExample op. | ||||
|           if (dtype == DataType.STRING) { | ||||
|             return DataType.STRING; | ||||
|           } | ||||
|           return DataType.UINT8; | ||||
|         } else if (long.class.equals(c)) { | ||||
|           return DataType.INT64; | ||||
| @ -345,8 +349,21 @@ public final class Tensor { | ||||
|   } | ||||
| 
 | ||||
|   /** Returns the shape of an object as an int array. */ | ||||
|   static int[] computeShapeOf(Object o) { | ||||
|   int[] computeShapeOf(Object o) { | ||||
|     int size = computeNumDimensions(o); | ||||
|     if (dtype == DataType.STRING) { | ||||
|       Class<?> c = o.getClass(); | ||||
|       if (c.isArray()) { | ||||
|         while (c.isArray()) { | ||||
|           c = c.getComponentType(); | ||||
|         } | ||||
|         // If the given string data is stored in byte streams, the last array dimension should be | ||||
|         // treated as a value. | ||||
|         if (byte.class.equals(c)) { | ||||
|           --size; | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|     int[] dimensions = new int[size]; | ||||
|     fillShape(o, 0, dimensions); | ||||
|     return dimensions; | ||||
|  | ||||
| @ -28,6 +28,9 @@ using tflite::jni::ThrowException; | ||||
| 
 | ||||
| namespace { | ||||
| 
 | ||||
// JNI class names used by AddStringDynamicBuffer to decide how a Java object
// is copied into a string tensor: "[B" is the JVM descriptor for byte[], and
// "java/lang/String" is the slash-separated name of java.lang.String.
| static const char* kByteArrayClassPath = "[B"; | ||||
| static const char* kStringClassPath = "java/lang/String"; | ||||
| 
 | ||||
| // Convenience handle for obtaining a TfLiteTensor given an interpreter and
 | ||||
| // tensor index.
 | ||||
| //
 | ||||
| @ -271,13 +274,24 @@ size_t WriteMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type, | ||||
|   } | ||||
| } | ||||
| 
 | ||||
// Appends one Java object to `dst_buffer` as a single string entry.
//
// This is a rendered diff, so rows from both versions are shown together. The
// rows taking `jstring src` and calling the JNI string functions directly on
// `src` appear to be the pre-change version; the rows taking `jobject src`
// and branching on IsInstanceOf appear to be the post-change version.
//
// Post-change behavior, as written below:
//   - `src` is a java.lang.String: its UTF chars are fetched, added with a
//     length of GetStringUTFLength(str) + 1 (the terminating character is
//     counted), then released.
//   - `src` is a byte[]: its elements are added verbatim using the array
//     length (no extra terminator), then released with JNI_ABORT because the
//     bytes were only read.
//   - `src` is neither: nothing is added and no error is raised here.
//
// Parameters:
//   env        - JNI environment used for all type checks and accessors.
//   src        - Java object holding the string payload (String or byte[]).
//   dst_buffer - destination buffer that receives the entry via AddString.
//
// NOTE(review): FindClass is called for each type check on every invocation
// and the returned class references are not released in this function. This
// function is called once per element from PopulateStringDynamicBuffer, so
// confirm the reference count stays bounded for large inputs.
// NOTE(review): the results of GetStringUTFChars and GetByteArrayElements are
// passed to AddString without a nullptr check; confirm whether a failed JNI
// call needs handling here.
| void AddStringDynamicBuffer(JNIEnv* env, jstring src, | ||||
| void AddStringDynamicBuffer(JNIEnv* env, jobject src, | ||||
|                             tflite::DynamicBuffer* dst_buffer) { | ||||
|   const char* chars = env->GetStringUTFChars(src, nullptr); | ||||
|   if (env->IsInstanceOf(src, env->FindClass(kStringClassPath))) { | ||||
|     jstring str = static_cast<jstring>(src); | ||||
|     const char* chars = env->GetStringUTFChars(str, nullptr); | ||||
|     // + 1 for terminating character.
 | ||||
|   const int byte_len = env->GetStringUTFLength(src) + 1; | ||||
|     const int byte_len = env->GetStringUTFLength(str) + 1; | ||||
|     dst_buffer->AddString(chars, byte_len); | ||||
|   env->ReleaseStringUTFChars(src, chars); | ||||
|     env->ReleaseStringUTFChars(str, chars); | ||||
|   } | ||||
|   if (env->IsInstanceOf(src, env->FindClass(kByteArrayClassPath))) { | ||||
|     jbyteArray byte_array = static_cast<jbyteArray>(src); | ||||
|     jsize byte_array_length = env->GetArrayLength(byte_array); | ||||
|     jbyte* bytes = env->GetByteArrayElements(byte_array, nullptr); | ||||
|     dst_buffer->AddString(reinterpret_cast<const char*>(bytes), | ||||
|                           byte_array_length); | ||||
|     env->ReleaseByteArrayElements(byte_array, bytes, JNI_ABORT); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| void PopulateStringDynamicBuffer(JNIEnv* env, jobject src, | ||||
| @ -290,10 +304,9 @@ void PopulateStringDynamicBuffer(JNIEnv* env, jobject src, | ||||
|   // recursively call populateStringDynamicBuffer over sub-dimensions.
 | ||||
|   if (dims_left <= 1) { | ||||
|     for (int i = 0; i < num_elements; ++i) { | ||||
|       jstring string_obj = | ||||
|           static_cast<jstring>(env->GetObjectArrayElement(object_array, i)); | ||||
|       AddStringDynamicBuffer(env, string_obj, dst_buffer); | ||||
|       env->DeleteLocalRef(string_obj); | ||||
|       jobject obj = env->GetObjectArrayElement(object_array, i); | ||||
|       AddStringDynamicBuffer(env, obj, dst_buffer); | ||||
|       env->DeleteLocalRef(obj); | ||||
|     } | ||||
|   } else { | ||||
|     for (int i = 0; i < num_elements; ++i) { | ||||
| @ -358,7 +371,7 @@ void WriteScalar(JNIEnv* env, jobject src, TfLiteType type, void* dst, | ||||
| 
 | ||||
| void WriteScalarString(JNIEnv* env, jobject src, TfLiteTensor* tensor) { | ||||
|   tflite::DynamicBuffer dst_buffer; | ||||
|   AddStringDynamicBuffer(env, static_cast<jstring>(src), &dst_buffer); | ||||
|   AddStringDynamicBuffer(env, src, &dst_buffer); | ||||
|   if (!env->ExceptionCheck()) { | ||||
|     dst_buffer.WriteToTensor(tensor, /*new_shape=*/nullptr); | ||||
|   } | ||||
|  | ||||
| @ -46,6 +46,9 @@ public final class TensorTest { | ||||
|   private static final String LONG_MODEL_PATH = | ||||
|       "tensorflow/lite/java/src/testdata/int64.bin"; | ||||
| 
 | ||||
|   private static final String STRING_MODEL_PATH = | ||||
|       "tensorflow/lite/java/src/testdata/string.bin"; | ||||
| 
 | ||||
|   private static final String QUANTIZED_MODEL_PATH = | ||||
|       "tensorflow/lite/java/src/testdata/quantized.bin"; | ||||
| 
 | ||||
| @ -412,30 +415,30 @@ public final class TensorTest { | ||||
|   @Test | ||||
|   public void testDataTypeOf() { | ||||
|     float[] testEmptyArray = {}; | ||||
|     DataType dataType = Tensor.dataTypeOf(testEmptyArray); | ||||
|     DataType dataType = tensor.dataTypeOf(testEmptyArray); | ||||
|     assertThat(dataType).isEqualTo(DataType.FLOAT32); | ||||
|     float[] testFloatArray = {0.783f, 0.251f}; | ||||
|     dataType = Tensor.dataTypeOf(testFloatArray); | ||||
|     dataType = tensor.dataTypeOf(testFloatArray); | ||||
|     assertThat(dataType).isEqualTo(DataType.FLOAT32); | ||||
|     float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray}; | ||||
|     dataType = Tensor.dataTypeOf(testMultiDimArray); | ||||
|     dataType = tensor.dataTypeOf(testMultiDimArray); | ||||
|     assertThat(dataType).isEqualTo(DataType.FLOAT32); | ||||
|     FloatBuffer testFloatBuffer = FloatBuffer.allocate(1); | ||||
|     dataType = Tensor.dataTypeOf(testFloatBuffer); | ||||
|     dataType = tensor.dataTypeOf(testFloatBuffer); | ||||
|     assertThat(dataType).isEqualTo(DataType.FLOAT32); | ||||
|     float testFloat = 1.0f; | ||||
|     dataType = Tensor.dataTypeOf(testFloat); | ||||
|     dataType = tensor.dataTypeOf(testFloat); | ||||
|     assertThat(dataType).isEqualTo(DataType.FLOAT32); | ||||
|     try { | ||||
|       double[] testDoubleArray = {0.783, 0.251}; | ||||
|       Tensor.dataTypeOf(testDoubleArray); | ||||
|       tensor.dataTypeOf(testDoubleArray); | ||||
|       fail(); | ||||
|     } catch (IllegalArgumentException e) { | ||||
|       assertThat(e).hasMessageThat().contains("cannot resolve DataType of"); | ||||
|     } | ||||
|     try { | ||||
|       Float[] testBoxedArray = {0.783f, 0.251f}; | ||||
|       Tensor.dataTypeOf(testBoxedArray); | ||||
|       tensor.dataTypeOf(testBoxedArray); | ||||
|       fail(); | ||||
|     } catch (IllegalArgumentException e) { | ||||
|       assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;"); | ||||
| @ -528,4 +531,15 @@ public final class TensorTest { | ||||
|     assertThat(scale).isWithin(1e-6f).of(0.25f); | ||||
|     assertThat(zeroPoint).isEqualTo(127); | ||||
|   } | ||||
| 
 | ||||
// Verifies that a string-typed input tensor resolves a byte[][] argument to
// DataType.STRING. The test loads the model at STRING_MODEL_PATH, resizes
// input 0 to shape {1}, and asks that tensor for the data type of a byte[][]
// holding one single-byte element.
//
// NOTE(review): the wrapper created here is a local variable and is not
// closed in this method; confirm whether it needs an explicit close().
// NOTE(review): the last assertion reads the tensor's own shape() after
// resizeInput and does not involve byteArray, so computeShapeOf(byteArray) is
// not exercised directly. Confirm that is the intent.
|   @Test | ||||
|   public void testByteArrayStringTensorInput() { | ||||
|     NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(STRING_MODEL_PATH); | ||||
|     wrapper.resizeInput(0, new int[] {1}); | ||||
|     Tensor stringTensor = wrapper.getInputTensor(0); | ||||
| 
 | ||||
    // One string element whose value is a single zero byte.
|     byte[][] byteArray = new byte[][] {new byte[1]}; | ||||
|     assertThat(stringTensor.dataTypeOf(byteArray)).isEqualTo(DataType.STRING); | ||||
|     assertThat(stringTensor.shape()).isEqualTo(new int[] {1}); | ||||
|   } | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user