Remove the mlir::TF::UintN types in favor of using the standard IntegerType.
The standard integer type is now able to represent signedness, which makes these custom types unnecessary. PiperOrigin-RevId: 298708764 Change-Id: If5a642bc522c9c48a860fbda30ad2dd1807d9307
This commit is contained in:
parent
a8ac8f901a
commit
750e9df721
tensorflow/compiler/mlir
lite
tensorflow
ir
tests
utils
xla/tests
@ -179,8 +179,6 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
|
||||
return tflite::TensorType_FLOAT16;
|
||||
case mlir::TF::TensorFlowTypes::STRING:
|
||||
return tflite::TensorType_STRING;
|
||||
case mlir::TF::TensorFlowTypes::UINT8:
|
||||
return tflite::TensorType_UINT8;
|
||||
case mlir::TF::TensorFlowTypes::QUINT8:
|
||||
return tflite::TensorType_UINT8;
|
||||
case mlir::StandardTypes::Complex: {
|
||||
@ -196,7 +194,8 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
|
||||
case 1:
|
||||
return tflite::TensorType_BOOL;
|
||||
case 8:
|
||||
return tflite::TensorType_INT8;
|
||||
return itype.isUnsigned() ? tflite::TensorType_UINT8
|
||||
: tflite::TensorType_INT8;
|
||||
case 16:
|
||||
return tflite::TensorType_INT16;
|
||||
case 32:
|
||||
|
@ -47,13 +47,6 @@ def TFL_Str : Type<CPred<"$_self.isa<mlir::TF::StringType>()">,
|
||||
"TFLite string type">,
|
||||
BuildableType<"getType<mlir::TF::StringType>()">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TFLite dialect uint8 type - uses the TF uint8 type as implementation
|
||||
//===----------------------------------------------------------------------===//
|
||||
def TFL_Uint8 : Type<CPred<"$_self.isa<mlir::TF::Uint8Type>()">,
|
||||
"TFLite uint8 type">,
|
||||
BuildableType<"getType<mlir::TF::Uint8Type>()">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TFLite dialect quint8 type - uses the TF quint8 type as implementation
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -141,6 +134,7 @@ class TFL_VariadicTensorOf<list<Type> allowedRuntimeTypes,
|
||||
Variadic<TensorOf<allowedOpTypes>>,
|
||||
TFL_RuntimeType<Variadic<TensorOf<allowedRuntimeTypes>>>;
|
||||
|
||||
def TFL_Uint8 : UI<8>;
|
||||
def TFL_Int32Or64 : IntOfWidths<[32, 64]>;
|
||||
|
||||
def TFL_BoolTensor : TFL_TensorOf<[I1]>;
|
||||
@ -223,9 +217,9 @@ class TFL_Operand0DOr1ElementTensor<int x> :
|
||||
class TFL_TFTypesWithSameBits<int i, int j, int num> :
|
||||
And<[
|
||||
Or<[CPred<"getElementTypeOrSelf($_op.getResult(" # i # ")).isa<mlir::TF::Quint" # num # "Type>()">,
|
||||
CPred<"getElementTypeOrSelf($_op.getResult(" # i # ")).isa<mlir::TF::Uint" # num # "Type>()">]>,
|
||||
CPred<"getElementTypeOrSelf($_op.getResult(" # i # ")).isUnsignedInteger(" # num # ")">]>,
|
||||
Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa<mlir::TF::Quint" # num # "Type>()">,
|
||||
CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa<mlir::TF::Uint" # num # "Type>()">]>]>;
|
||||
CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>;
|
||||
|
||||
class TFL_OperandHasRankLessThan<int n, int m> :
|
||||
PredOpTrait<"operand " # n # " is maximum " # m # "-D",
|
||||
|
@ -61,11 +61,11 @@ func @i64() -> tensor<4xi64> {
|
||||
// the same sort of opaque round-trip we get for complex64, but it might be good
|
||||
// to check
|
||||
|
||||
func @uint8() -> tensor<4x!tf.uint8> {
|
||||
func @uint8() -> tensor<4xui8> {
|
||||
// CHECK-LABEL: @uint8
|
||||
// CHECK: value = opaque<"tf", "0x746674656E736F722464747970653A2044545F55494E54382074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3333365C3235355C3237365C33353722"> : tensor<4x!tf.uint8>
|
||||
%0 = "tfl.pseudo_const"() { value = opaque<"tf", "0x746674656E736F722464747970653A2044545F55494E54382074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3333365C3235355C3237365C33353722"> : tensor<4x!tf.uint8> } : () -> tensor<4x!tf.uint8>
|
||||
return %0 : tensor<4x!tf.uint8>
|
||||
// CHECK: value = dense<[222, 173, 190, 239]> : tensor<4xui8>
|
||||
%0 = "tfl.pseudo_const"() {value = dense<[222, 173, 190, 239]> : tensor<4xui8>} : () -> tensor<4xui8>
|
||||
return %0 : tensor<4xui8>
|
||||
}
|
||||
|
||||
func @qi32_per_axis() -> tensor<3x3x!quant.uniform<i32:f32:1, {1.0, 0.5:1, 0.25:1}>> {
|
||||
|
@ -1141,8 +1141,8 @@ func @testStridedSliceWithQUI8(%arg0: tensor<12x2x2x5x!quant.uniform<u8:f32, 0.1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testStridedSliceTFType
|
||||
func @testStridedSliceTFType(%arg0: tensor<12x2x2x5x!tf.uint8>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf.quint8> {
|
||||
%0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!tf.uint8>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.quint8>
|
||||
func @testStridedSliceTFType(%arg0: tensor<12x2x2x5xui8>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf.quint8> {
|
||||
%0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xui8>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.quint8>
|
||||
return %0 : tensor<1x2x2x5x!tf.quint8>
|
||||
}
|
||||
|
||||
|
@ -38,7 +38,7 @@ mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) {
|
||||
case tflite::TensorType_INT32:
|
||||
return builder.getIntegerType(32);
|
||||
case tflite::TensorType_UINT8:
|
||||
return mlir::TF::Uint8Type::get(builder.getContext());
|
||||
return builder.getIntegerType(8, /*isSigned=*/false);
|
||||
case tflite::TensorType_INT64:
|
||||
return builder.getIntegerType(64);
|
||||
case tflite::TensorType_STRING:
|
||||
|
@ -86,6 +86,7 @@ class TF_TensorFlowType <string name, string description> :
|
||||
// Any tensor element type allowed in TensorFlow ops
|
||||
def TF_ElementType : Type<Or<[AnyFloat.predicate,
|
||||
AnySignlessInteger.predicate,
|
||||
AnyUnsignedInteger.predicate,
|
||||
AnyComplex.predicate,
|
||||
TF_TFDialectType.predicate]>,
|
||||
"tf.dtype">;
|
||||
@ -100,13 +101,13 @@ def TF_I32Or64 : IntOfWidths<[32, 64]>;
|
||||
|
||||
def TF_I32OrI64Tensor : TensorOf<[TF_I32Or64]>;
|
||||
|
||||
def TF_Uint8 : TF_TensorFlowType<"Uint8", "uint8">;
|
||||
def TF_Uint16 : TF_TensorFlowType<"Uint16", "uint16">;
|
||||
def TF_Uint32 : TF_TensorFlowType<"Uint32", "uint32">;
|
||||
def TF_Uint64 : TF_TensorFlowType<"Uint64", "uint64">;
|
||||
def TF_Uint8 : UI<8>;
|
||||
def TF_Uint16 : UI<16>;
|
||||
def TF_Uint32 : UI<32>;
|
||||
def TF_Uint64 : UI<64>;
|
||||
|
||||
// Any unsigned integer type
|
||||
def TF_UInt : AnyTypeOf<[TF_Uint8, TF_Uint16, TF_Uint32, TF_Uint64]>;
|
||||
def TF_UInt : UnsignedIntOfWidths<[8, 16, 32, 64]>;
|
||||
|
||||
// Any signed integer type
|
||||
def TF_SInt : IntOfWidths<[8, 16, 32, 64]>;
|
||||
|
@ -77,13 +77,17 @@ TensorFlowType TensorFlowRefType::get(Type type) {
|
||||
case 1:
|
||||
return BoolRefType::get(ctx);
|
||||
case 8:
|
||||
return Int8RefType::get(ctx);
|
||||
return itype.isUnsigned() ? TensorFlowType(Uint8RefType::get(ctx))
|
||||
: Int8RefType::get(ctx);
|
||||
case 16:
|
||||
return Int16RefType::get(ctx);
|
||||
return itype.isUnsigned() ? TensorFlowType(Uint16RefType::get(ctx))
|
||||
: Int16RefType::get(ctx);
|
||||
case 32:
|
||||
return Int32RefType::get(ctx);
|
||||
return itype.isUnsigned() ? TensorFlowType(Uint32RefType::get(ctx))
|
||||
: Int32RefType::get(ctx);
|
||||
case 64:
|
||||
return Int64RefType::get(ctx);
|
||||
return itype.isUnsigned() ? TensorFlowType(Uint64RefType::get(ctx))
|
||||
: Int64RefType::get(ctx);
|
||||
default:
|
||||
llvm_unreachable("unexpected integer type");
|
||||
}
|
||||
@ -121,6 +125,14 @@ Type TensorFlowRefType::RemoveRef() {
|
||||
return mlir::IntegerType::get(32, ctx);
|
||||
case TensorFlowTypes::INT64_REF:
|
||||
return mlir::IntegerType::get(64, ctx);
|
||||
case TensorFlowTypes::UINT8_REF:
|
||||
return mlir::IntegerType::get(8, IntegerType::Unsigned, ctx);
|
||||
case TensorFlowTypes::UINT16_REF:
|
||||
return mlir::IntegerType::get(16, IntegerType::Unsigned, ctx);
|
||||
case TensorFlowTypes::UINT32_REF:
|
||||
return mlir::IntegerType::get(32, IntegerType::Unsigned, ctx);
|
||||
case TensorFlowTypes::UINT64_REF:
|
||||
return mlir::IntegerType::get(64, IntegerType::Unsigned, ctx);
|
||||
case TensorFlowTypes::COMPLEX64_REF:
|
||||
return mlir::ComplexType::get(mlir::FloatType::getF32(ctx));
|
||||
case TensorFlowTypes::COMPLEX128_REF:
|
||||
|
@ -19,10 +19,6 @@ limitations under the License.
|
||||
#ifdef HANDLE_TF_TYPE
|
||||
|
||||
// class, enumerant, name
|
||||
HANDLE_TF_TYPE(Uint8, UINT8, "uint8")
|
||||
HANDLE_TF_TYPE(Uint16, UINT16, "uint16")
|
||||
HANDLE_TF_TYPE(Uint32, UINT32, "uint32")
|
||||
HANDLE_TF_TYPE(Uint64, UINT64, "uint64")
|
||||
HANDLE_TF_TYPE(Qint8, QINT8, "qint8")
|
||||
HANDLE_TF_TYPE(Qint16, QINT16, "qint16")
|
||||
HANDLE_TF_TYPE(Qint32, QINT32, "qint32")
|
||||
|
@ -91,7 +91,7 @@ class TensorFlowType : public Type {
|
||||
// Returns true if the specified type is a valid TensorFlow element type.
|
||||
static inline bool IsValidTFElementType(Type type) {
|
||||
return type.isa<ComplexType>() || type.isa<FloatType>() ||
|
||||
type.isSignlessInteger() || type.isa<TensorFlowType>();
|
||||
type.isa<IntegerType>() || type.isa<TensorFlowType>();
|
||||
}
|
||||
|
||||
// Returns true if this is a valid TensorFlow tensor type.
|
||||
|
@ -107,5 +107,5 @@ versions {
|
||||
# CHECK: "tf.PartitionedCall"()
|
||||
# CHECK-SAME: Tout = ["tfdtype$DT_UINT8"]
|
||||
# CHECK-SAME: f = @[[FUNCTION:[A-Za-z0-9_]*]]
|
||||
# CHECK: func @[[FUNCTION]]() -> tensor<!tf.uint8>
|
||||
# CHECK: return {{.*}} : tensor<!tf.uint8>
|
||||
# CHECK: func @[[FUNCTION]]() -> tensor<ui8>
|
||||
# CHECK: return {{.*}} : tensor<ui8>
|
||||
|
@ -392,12 +392,12 @@ func @DynamicStitch_scalar_matrix_indices(%arg0: tensor<2xf32>, %arg1: tensor<2x
|
||||
|
||||
// Verify that custom types are lowered and have legal output.
|
||||
// CHECK-LABEL: func @DynamicStitch_uint8
|
||||
func @DynamicStitch_uint8(%arg0: tensor<2x2x!tf.uint8>) -> tensor<2x2x!tf.uint8> {
|
||||
func @DynamicStitch_uint8(%arg0: tensor<2x2xui8>) -> tensor<2x2xui8> {
|
||||
// CHECK-NOT: tf.DynamicStitch
|
||||
|
||||
%indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
%0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<2x2x!tf.uint8>) -> tensor<2x2x!tf.uint8>
|
||||
return %0 : tensor<2x2x!tf.uint8>
|
||||
%0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<2x2xui8>) -> tensor<2x2xui8>
|
||||
return %0 : tensor<2x2xui8>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @DynamicStitch_scalar_item
|
||||
|
@ -66,17 +66,17 @@ func @testIdentity(%arg0: tensor<4x2x!tf.stringref>) -> tensor<4x2x!tf.string> {
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @testBitcast
|
||||
func @testBitcast(%arg0: tensor<3x4x!tf.uint16>) -> tensor<3x4x!tf.quint16> {
|
||||
%0 = "tf.Bitcast"(%arg0) : (tensor<3x4x!tf.uint16>) -> tensor<3x4x!tf.quint16>
|
||||
func @testBitcast(%arg0: tensor<3x4xui16>) -> tensor<3x4x!tf.quint16> {
|
||||
%0 = "tf.Bitcast"(%arg0) : (tensor<3x4xui16>) -> tensor<3x4x!tf.quint16>
|
||||
return %0 : tensor<3x4x!tf.quint16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @testReverseV2
|
||||
func @testReverseV2(%arg0: tensor<2x4x3x!tf.uint8>, %arg1: tensor<1xi32>) -> tensor<2x4x3x!tf.uint8> {
|
||||
%0 = "tf.ReverseV2"(%arg0, %arg1) : (tensor<2x4x3x!tf.uint8>, tensor<1xi32>) -> tensor<2x4x3x!tf.uint8>
|
||||
return %0 : tensor<2x4x3x!tf.uint8>
|
||||
func @testReverseV2(%arg0: tensor<2x4x3xui8>, %arg1: tensor<1xi32>) -> tensor<2x4x3xui8> {
|
||||
%0 = "tf.ReverseV2"(%arg0, %arg1) : (tensor<2x4x3xui8>, tensor<1xi32>) -> tensor<2x4x3xui8>
|
||||
return %0 : tensor<2x4x3xui8>
|
||||
}
|
||||
|
||||
// -----
|
||||
@ -210,9 +210,9 @@ func @testLeakyWrongAlphaType(tensor<16xf32>) -> tensor<16xf32> {
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @testMul
|
||||
func @testMul(%arg0: tensor<2x!tf.uint16>) -> (tensor<2x!tf.uint16>) {
|
||||
%0 = "tf.Mul"(%arg0, %arg0) {T = "tfdtype$DT_UINT16", device = "/device:CPU:0", name = "Mul"} : (tensor<2x!tf.uint16>, tensor<2x!tf.uint16>) -> tensor<2x!tf.uint16>
|
||||
return %0 : tensor<2x!tf.uint16>
|
||||
func @testMul(%arg0: tensor<2xui16>) -> (tensor<2xui16>) {
|
||||
%0 = "tf.Mul"(%arg0, %arg0) {T = "tfdtype$DT_UINT16", device = "/device:CPU:0", name = "Mul"} : (tensor<2xui16>, tensor<2xui16>) -> tensor<2xui16>
|
||||
return %0 : tensor<2xui16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -57,6 +57,18 @@ Status ConvertDataType(DataType dtype, Builder builder, Type* type) {
|
||||
case DT_INT64:
|
||||
*type = builder.getIntegerType(64);
|
||||
return Status::OK();
|
||||
case DT_UINT8:
|
||||
*type = builder.getIntegerType(8, /*isSigned=*/false);
|
||||
return Status::OK();
|
||||
case DT_UINT16:
|
||||
*type = builder.getIntegerType(16, /*isSigned=*/false);
|
||||
return Status::OK();
|
||||
case DT_UINT32:
|
||||
*type = builder.getIntegerType(32, /*isSigned=*/false);
|
||||
return Status::OK();
|
||||
case DT_UINT64:
|
||||
*type = builder.getIntegerType(64, /*isSigned=*/false);
|
||||
return Status::OK();
|
||||
case DT_BFLOAT16:
|
||||
*type = builder.getBF16Type();
|
||||
return Status::OK();
|
||||
@ -99,16 +111,16 @@ Status ConvertScalarTypeToDataType(Type type, DataType* dtype) {
|
||||
*dtype = DT_BOOL;
|
||||
return Status::OK();
|
||||
case 8:
|
||||
*dtype = DT_INT8;
|
||||
*dtype = itype.isUnsigned() ? DT_UINT8 : DT_INT8;
|
||||
return Status::OK();
|
||||
case 16:
|
||||
*dtype = DT_INT16;
|
||||
*dtype = itype.isUnsigned() ? DT_UINT16 : DT_INT16;
|
||||
return Status::OK();
|
||||
case 32:
|
||||
*dtype = DT_INT32;
|
||||
*dtype = itype.isUnsigned() ? DT_UINT32 : DT_INT32;
|
||||
return Status::OK();
|
||||
case 64:
|
||||
*dtype = DT_INT64;
|
||||
*dtype = itype.isUnsigned() ? DT_UINT64 : DT_INT64;
|
||||
return Status::OK();
|
||||
default:
|
||||
return errors::Unimplemented(
|
||||
|
@ -523,17 +523,17 @@ func @broadcast_shift_right(%arg0: tensor<4xi32>, %arg1: tensor<2x4xi32>) -> ten
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @shift_right_unsigned
|
||||
func @shift_right_unsigned(%arg0: tensor<4x!tf.uint8>, %arg1: tensor<4x!tf.uint8>) -> tensor<4x!tf.uint8> {
|
||||
func @shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<4xui8>) -> tensor<4xui8> {
|
||||
// CHECK: tf.RightShift
|
||||
%0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4x!tf.uint8>, tensor<4x!tf.uint8>) -> tensor<4x!tf.uint8>
|
||||
return %0 : tensor<4x!tf.uint8>
|
||||
%0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<4xui8>) -> tensor<4xui8>
|
||||
return %0 : tensor<4xui8>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @broadcast_shift_right_unsigned
|
||||
func @broadcast_shift_right_unsigned(%arg0: tensor<4x!tf.uint8>, %arg1: tensor<2x4x!tf.uint8>) -> tensor<2x4x!tf.uint8> {
|
||||
func @broadcast_shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<2x4xui8>) -> tensor<2x4xui8> {
|
||||
// CHECK: tf.RightShift
|
||||
%0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4x!tf.uint8>, tensor<2x4x!tf.uint8>) -> tensor<2x4x!tf.uint8>
|
||||
return %0 : tensor<2x4x!tf.uint8>
|
||||
%0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<2x4xui8>) -> tensor<2x4xui8>
|
||||
return %0 : tensor<2x4xui8>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @and
|
||||
|
Loading…
Reference in New Issue
Block a user