Merge pull request from freedomtan:enable_int8_for_tflite_java_binding

PiperOrigin-RevId: 294933944
Change-Id: I2e852573e689b081de50f558e66eb914727a5524
This commit is contained in:
TensorFlower Gardener 2020-02-13 09:41:50 -08:00
commit 9533ea1ca9
4 changed files with 21 additions and 1 deletions
tensorflow/lite/java/src
main
java/org/tensorflow/lite
native
test/java/org/tensorflow/lite

View File

@ -30,7 +30,10 @@ public enum DataType {
INT64(4),
/** Strings. */
STRING(5);
STRING(5),
/** 8-bit signed integer. */
INT8(9);
private final int value;
@ -45,6 +48,7 @@ public enum DataType {
return 4;
case INT32:
return 4;
case INT8:
case UINT8:
return 1;
case INT64:
@ -83,6 +87,7 @@ public enum DataType {
return "float";
case INT32:
return "int";
case INT8:
case UINT8:
return "byte";
case INT64:

View File

@ -311,7 +311,13 @@ public final class Tensor {
return;
}
DataType oType = dataTypeOf(o);
if (oType != dtype) {
// INT8 and UINT8 have the same string name, "byte"
if (oType.toStringName().equals(dtype.toStringName())) {
return;
}
throw new IllegalArgumentException(
String.format(
"Cannot convert between a TensorFlowLite tensor with type %s and a Java "

View File

@ -126,6 +126,7 @@ size_t WriteOneDimensionalArray(JNIEnv* env, jobject object, TfLiteType type,
env->GetLongArrayRegion(long_array, 0, num_elements, long_dst);
return to_copy;
}
case kTfLiteInt8:
case kTfLiteUInt8: {
jbyteArray byte_array = static_cast<jbyteArray>(array);
jbyte* byte_dst = static_cast<jbyte*>(dst);
@ -174,6 +175,7 @@ size_t ReadOneDimensionalArray(JNIEnv* env, TfLiteType data_type,
static_cast<const jlong*>(src));
return size;
}
case kTfLiteInt8:
case kTfLiteUInt8: {
jbyteArray byte_array = static_cast<jbyteArray>(dst);
env->SetByteArrayRegion(byte_array, 0, len,

View File

@ -39,4 +39,11 @@ public final class DataTypeTest {
assertThat(DataType.fromC(dataType.c())).isEqualTo(dataType);
}
}
@Test
public void testINT8AndUINT8() {
assertThat(DataType.INT8.toStringName()).isEqualTo("byte");
assertThat(DataType.UINT8.toStringName()).isEqualTo("byte");
assertThat(DataType.INT8.toStringName()).isEqualTo(DataType.UINT8.toStringName());
}
}