From d768d147870b202559878c610c366a0ac536a748 Mon Sep 17 00:00:00 2001 From: Koan-Sin Tan Date: Sat, 1 Feb 2020 12:06:42 +0800 Subject: [PATCH 1/4] [tflite] enable INT8 for Java binding some models created by full-integer post training quantization, e.g., the mobilenet v3 edgetpu one [1], have INT8 input and output tensors. [1] https://storage.cloud.google.com/mobilenet_edgetpu/checkpoints/mobilenet_edgetpu_224_1.0.tgz --- .../java/src/main/java/org/tensorflow/lite/DataType.java | 7 ++++++- .../java/src/main/java/org/tensorflow/lite/Tensor.java | 5 +++++ tensorflow/lite/java/src/main/native/tensor_jni.cc | 2 ++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/DataType.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/DataType.java index 407e01c6e17..527346c3c9b 100644 --- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/DataType.java +++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/DataType.java @@ -30,7 +30,10 @@ public enum DataType { INT64(4), /** Strings. */ - STRING(5); + STRING(5), + + /** 8-bit signed integer. */ + INT8(9); private final int value; @@ -45,6 +48,7 @@ public enum DataType { return 4; case INT32: return 4; + case INT8: case UINT8: return 1; case INT64: @@ -83,6 +87,7 @@ public enum DataType { return "float"; case INT32: return "int"; + case INT8: case UINT8: return "byte"; case INT64: diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java index 68952ff6e49..cdca8d48f9c 100644 --- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java +++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java @@ -311,6 +311,11 @@ public final class Tensor { return; } DataType oType = dataTypeOf(o); + + // INT8 and UINT8 have the same string name, "byte" + if (oType.toStringName() == dtype.toStringName()) { + return; + } if (oType != dtype) { throw new IllegalArgumentException( String.format( diff --git a/tensorflow/lite/java/src/main/native/tensor_jni.cc b/tensorflow/lite/java/src/main/native/tensor_jni.cc index f2cb1f81ab8..8beafa0c48e 100644 --- a/tensorflow/lite/java/src/main/native/tensor_jni.cc +++ b/tensorflow/lite/java/src/main/native/tensor_jni.cc @@ -126,6 +126,7 @@ size_t WriteOneDimensionalArray(JNIEnv* env, jobject object, TfLiteType type, env->GetLongArrayRegion(long_array, 0, num_elements, long_dst); return to_copy; } + case kTfLiteInt8: case kTfLiteUInt8: { jbyteArray byte_array = static_cast(array); jbyte* byte_dst = static_cast(dst); @@ -174,6 +175,7 @@ size_t ReadOneDimensionalArray(JNIEnv* env, TfLiteType data_type, static_cast(src)); return size; } + case kTfLiteInt8: case kTfLiteUInt8: { jbyteArray byte_array = static_cast(dst); env->SetByteArrayRegion(byte_array, 0, len, From 937ffc157f16677357699569d604bb3c76b9dce1 Mon Sep 17 00:00:00 2001 From: Koan-Sin Tan Date: Tue, 11 Feb 2020 05:51:09 +0800 Subject: [PATCH 2/4] don't use == for string comparison --- .../lite/java/src/main/java/org/tensorflow/lite/Tensor.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java index cdca8d48f9c..ae415b9e948 100644 --- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java +++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java @@ -313,7 +313,7 @@ public final class Tensor { DataType oType = dataTypeOf(o); // INT8 and UINT8 have the same string name, "byte" - if (oType.toStringName() == dtype.toStringName()) { + if (oType.toStringName().equals(dtype.toStringName())) { return; } if (oType != dtype) { From c78d9c6d5ce497130d752d5ac36dc2c25fd940eb Mon Sep 17 00:00:00 2001 From: Koan-Sin Tan Date: Tue, 11 Feb 2020 08:50:58 +0800 Subject: [PATCH 3/4] avoid unnecessary int8 check --- .../java/src/main/java/org/tensorflow/lite/Tensor.java | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java index ae415b9e948..8ed019dc3f1 100644 --- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java +++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Tensor.java @@ -312,11 +312,12 @@ public final class Tensor { } DataType oType = dataTypeOf(o); - // INT8 and UINT8 have the same string name, "byte" - if (oType.toStringName().equals(dtype.toStringName())) { - return; - } if (oType != dtype) { + // INT8 and UINT8 have the same string name, "byte" + if (oType.toStringName().equals(dtype.toStringName())) { + return; + } + throw new IllegalArgumentException( String.format( "Cannot convert between a TensorFlowLite tensor with type %s and a Java " From 81f764e5d173fdfcf93898cf3416a71b97519fc9 Mon Sep 17 00:00:00 2001 From: Koan-Sin Tan Date: Wed, 12 Feb 2020 09:56:38 +0800 Subject: [PATCH 4/4] add unit test for equality of INT8 and UINT8 --- .../src/test/java/org/tensorflow/lite/DataTypeTest.java | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java index 8412ec0e9da..d1e9c03ddd6 100644 --- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java +++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java @@ -39,4 +39,11 @@ public final class DataTypeTest { assertThat(DataType.fromC(dataType.c())).isEqualTo(dataType); } } + + @Test + public void testINT8AndUINT8() { + assertThat(DataType.INT8.toStringName()).isEqualTo("byte"); + assertThat(DataType.UINT8.toStringName()).isEqualTo("byte"); + assertThat(DataType.INT8.toStringName()).isEqualTo(DataType.UINT8.toStringName()); + } }