Merge pull request #36397 from freedomtan:enable_int8_for_tflite_java_binding
PiperOrigin-RevId: 294933944 Change-Id: I2e852573e689b081de50f558e66eb914727a5524
This commit is contained in:
commit
9533ea1ca9
tensorflow/lite/java/src
main
test/java/org/tensorflow/lite
@ -30,7 +30,10 @@ public enum DataType {
|
||||
INT64(4),
|
||||
|
||||
/** Strings. */
|
||||
STRING(5);
|
||||
STRING(5),
|
||||
|
||||
/** 8-bit signed integer. */
|
||||
INT8(9);
|
||||
|
||||
private final int value;
|
||||
|
||||
@ -45,6 +48,7 @@ public enum DataType {
|
||||
return 4;
|
||||
case INT32:
|
||||
return 4;
|
||||
case INT8:
|
||||
case UINT8:
|
||||
return 1;
|
||||
case INT64:
|
||||
@ -83,6 +87,7 @@ public enum DataType {
|
||||
return "float";
|
||||
case INT32:
|
||||
return "int";
|
||||
case INT8:
|
||||
case UINT8:
|
||||
return "byte";
|
||||
case INT64:
|
||||
|
@ -311,7 +311,13 @@ public final class Tensor {
|
||||
return;
|
||||
}
|
||||
DataType oType = dataTypeOf(o);
|
||||
|
||||
if (oType != dtype) {
|
||||
// INT8 and UINT8 have the same string name, "byte"
|
||||
if (oType.toStringName().equals(dtype.toStringName())) {
|
||||
return;
|
||||
}
|
||||
|
||||
throw new IllegalArgumentException(
|
||||
String.format(
|
||||
"Cannot convert between a TensorFlowLite tensor with type %s and a Java "
|
||||
|
@ -126,6 +126,7 @@ size_t WriteOneDimensionalArray(JNIEnv* env, jobject object, TfLiteType type,
|
||||
env->GetLongArrayRegion(long_array, 0, num_elements, long_dst);
|
||||
return to_copy;
|
||||
}
|
||||
case kTfLiteInt8:
|
||||
case kTfLiteUInt8: {
|
||||
jbyteArray byte_array = static_cast<jbyteArray>(array);
|
||||
jbyte* byte_dst = static_cast<jbyte*>(dst);
|
||||
@ -174,6 +175,7 @@ size_t ReadOneDimensionalArray(JNIEnv* env, TfLiteType data_type,
|
||||
static_cast<const jlong*>(src));
|
||||
return size;
|
||||
}
|
||||
case kTfLiteInt8:
|
||||
case kTfLiteUInt8: {
|
||||
jbyteArray byte_array = static_cast<jbyteArray>(dst);
|
||||
env->SetByteArrayRegion(byte_array, 0, len,
|
||||
|
@ -39,4 +39,11 @@ public final class DataTypeTest {
|
||||
assertThat(DataType.fromC(dataType.c())).isEqualTo(dataType);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testINT8AndUINT8() {
|
||||
assertThat(DataType.INT8.toStringName()).isEqualTo("byte");
|
||||
assertThat(DataType.UINT8.toStringName()).isEqualTo("byte");
|
||||
assertThat(DataType.INT8.toStringName()).isEqualTo(DataType.UINT8.toStringName());
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user