Remove the mlir::TF::UintN types in favor of using the standard IntegerType.

The standard integer type is now able to represent signedness, which makes these custom types unnecessary.

PiperOrigin-RevId: 298708764
Change-Id: If5a642bc522c9c48a860fbda30ad2dd1807d9307
This commit is contained in:
River Riddle 2020-03-03 15:27:55 -08:00 committed by TensorFlower Gardener
parent a8ac8f901a
commit 750e9df721
14 changed files with 70 additions and 56 deletions

View File

@ -179,8 +179,6 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
return tflite::TensorType_FLOAT16;
case mlir::TF::TensorFlowTypes::STRING:
return tflite::TensorType_STRING;
case mlir::TF::TensorFlowTypes::UINT8:
return tflite::TensorType_UINT8;
case mlir::TF::TensorFlowTypes::QUINT8:
return tflite::TensorType_UINT8;
case mlir::StandardTypes::Complex: {
@ -196,7 +194,8 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
case 1:
return tflite::TensorType_BOOL;
case 8:
return tflite::TensorType_INT8;
return itype.isUnsigned() ? tflite::TensorType_UINT8
: tflite::TensorType_INT8;
case 16:
return tflite::TensorType_INT16;
case 32:

View File

@ -47,13 +47,6 @@ def TFL_Str : Type<CPred<"$_self.isa<mlir::TF::StringType>()">,
"TFLite string type">,
BuildableType<"getType<mlir::TF::StringType>()">;
//===----------------------------------------------------------------------===//
// TFLite dialect uint8 type - uses the TF uint8 type as implementation
//===----------------------------------------------------------------------===//
def TFL_Uint8 : Type<CPred<"$_self.isa<mlir::TF::Uint8Type>()">,
"TFLite uint8 type">,
BuildableType<"getType<mlir::TF::Uint8Type>()">;
//===----------------------------------------------------------------------===//
// TFLite dialect quint8 type - uses the TF quint8 type as implementation
//===----------------------------------------------------------------------===//
@ -141,6 +134,7 @@ class TFL_VariadicTensorOf<list<Type> allowedRuntimeTypes,
Variadic<TensorOf<allowedOpTypes>>,
TFL_RuntimeType<Variadic<TensorOf<allowedRuntimeTypes>>>;
def TFL_Uint8 : UI<8>;
def TFL_Int32Or64 : IntOfWidths<[32, 64]>;
def TFL_BoolTensor : TFL_TensorOf<[I1]>;
@ -223,9 +217,9 @@ class TFL_Operand0DOr1ElementTensor<int x> :
class TFL_TFTypesWithSameBits<int i, int j, int num> :
And<[
Or<[CPred<"getElementTypeOrSelf($_op.getResult(" # i # ")).isa<mlir::TF::Quint" # num # "Type>()">,
CPred<"getElementTypeOrSelf($_op.getResult(" # i # ")).isa<mlir::TF::Uint" # num # "Type>()">]>,
CPred<"getElementTypeOrSelf($_op.getResult(" # i # ")).isUnsignedInteger(" # num # ")">]>,
Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa<mlir::TF::Quint" # num # "Type>()">,
CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa<mlir::TF::Uint" # num # "Type>()">]>]>;
CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>;
class TFL_OperandHasRankLessThan<int n, int m> :
PredOpTrait<"operand " # n # " is maximum " # m # "-D",

View File

@ -61,11 +61,11 @@ func @i64() -> tensor<4xi64> {
// the same sort of opaque round-trip we get for complex64, but it might be good
// to check
func @uint8() -> tensor<4x!tf.uint8> {
func @uint8() -> tensor<4xui8> {
// CHECK-LABEL: @uint8
// CHECK: value = opaque<"tf", "0x746674656E736F722464747970653A2044545F55494E54382074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3333365C3235355C3237365C33353722"> : tensor<4x!tf.uint8>
%0 = "tfl.pseudo_const"() { value = opaque<"tf", "0x746674656E736F722464747970653A2044545F55494E54382074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3333365C3235355C3237365C33353722"> : tensor<4x!tf.uint8> } : () -> tensor<4x!tf.uint8>
return %0 : tensor<4x!tf.uint8>
// CHECK: value = dense<[222, 173, 190, 239]> : tensor<4xui8>
%0 = "tfl.pseudo_const"() {value = dense<[222, 173, 190, 239]> : tensor<4xui8>} : () -> tensor<4xui8>
return %0 : tensor<4xui8>
}
func @qi32_per_axis() -> tensor<3x3x!quant.uniform<i32:f32:1, {1.0, 0.5:1, 0.25:1}>> {

View File

@ -1141,8 +1141,8 @@ func @testStridedSliceWithQUI8(%arg0: tensor<12x2x2x5x!quant.uniform<u8:f32, 0.1
}
// CHECK-LABEL: testStridedSliceTFType
func @testStridedSliceTFType(%arg0: tensor<12x2x2x5x!tf.uint8>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf.quint8> {
%0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!tf.uint8>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.quint8>
func @testStridedSliceTFType(%arg0: tensor<12x2x2x5xui8>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf.quint8> {
%0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xui8>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.quint8>
return %0 : tensor<1x2x2x5x!tf.quint8>
}

View File

@ -38,7 +38,7 @@ mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) {
case tflite::TensorType_INT32:
return builder.getIntegerType(32);
case tflite::TensorType_UINT8:
return mlir::TF::Uint8Type::get(builder.getContext());
return builder.getIntegerType(8, /*isSigned=*/false);
case tflite::TensorType_INT64:
return builder.getIntegerType(64);
case tflite::TensorType_STRING:

View File

@ -86,6 +86,7 @@ class TF_TensorFlowType <string name, string description> :
// Any tensor element type allowed in TensorFlow ops
def TF_ElementType : Type<Or<[AnyFloat.predicate,
AnySignlessInteger.predicate,
AnyUnsignedInteger.predicate,
AnyComplex.predicate,
TF_TFDialectType.predicate]>,
"tf.dtype">;
@ -100,13 +101,13 @@ def TF_I32Or64 : IntOfWidths<[32, 64]>;
def TF_I32OrI64Tensor : TensorOf<[TF_I32Or64]>;
def TF_Uint8 : TF_TensorFlowType<"Uint8", "uint8">;
def TF_Uint16 : TF_TensorFlowType<"Uint16", "uint16">;
def TF_Uint32 : TF_TensorFlowType<"Uint32", "uint32">;
def TF_Uint64 : TF_TensorFlowType<"Uint64", "uint64">;
def TF_Uint8 : UI<8>;
def TF_Uint16 : UI<16>;
def TF_Uint32 : UI<32>;
def TF_Uint64 : UI<64>;
// Any unsigned integer type
def TF_UInt : AnyTypeOf<[TF_Uint8, TF_Uint16, TF_Uint32, TF_Uint64]>;
def TF_UInt : UnsignedIntOfWidths<[8, 16, 32, 64]>;
// Any signed integer type
def TF_SInt : IntOfWidths<[8, 16, 32, 64]>;

View File

@ -77,13 +77,17 @@ TensorFlowType TensorFlowRefType::get(Type type) {
case 1:
return BoolRefType::get(ctx);
case 8:
return Int8RefType::get(ctx);
return itype.isUnsigned() ? TensorFlowType(Uint8RefType::get(ctx))
: Int8RefType::get(ctx);
case 16:
return Int16RefType::get(ctx);
return itype.isUnsigned() ? TensorFlowType(Uint16RefType::get(ctx))
: Int16RefType::get(ctx);
case 32:
return Int32RefType::get(ctx);
return itype.isUnsigned() ? TensorFlowType(Uint32RefType::get(ctx))
: Int32RefType::get(ctx);
case 64:
return Int64RefType::get(ctx);
return itype.isUnsigned() ? TensorFlowType(Uint64RefType::get(ctx))
: Int64RefType::get(ctx);
default:
llvm_unreachable("unexpected integer type");
}
@ -121,6 +125,14 @@ Type TensorFlowRefType::RemoveRef() {
return mlir::IntegerType::get(32, ctx);
case TensorFlowTypes::INT64_REF:
return mlir::IntegerType::get(64, ctx);
case TensorFlowTypes::UINT8_REF:
return mlir::IntegerType::get(8, IntegerType::Unsigned, ctx);
case TensorFlowTypes::UINT16_REF:
return mlir::IntegerType::get(16, IntegerType::Unsigned, ctx);
case TensorFlowTypes::UINT32_REF:
return mlir::IntegerType::get(32, IntegerType::Unsigned, ctx);
case TensorFlowTypes::UINT64_REF:
return mlir::IntegerType::get(64, IntegerType::Unsigned, ctx);
case TensorFlowTypes::COMPLEX64_REF:
return mlir::ComplexType::get(mlir::FloatType::getF32(ctx));
case TensorFlowTypes::COMPLEX128_REF:

View File

@ -19,10 +19,6 @@ limitations under the License.
#ifdef HANDLE_TF_TYPE
// class, enumerant, name
HANDLE_TF_TYPE(Uint8, UINT8, "uint8")
HANDLE_TF_TYPE(Uint16, UINT16, "uint16")
HANDLE_TF_TYPE(Uint32, UINT32, "uint32")
HANDLE_TF_TYPE(Uint64, UINT64, "uint64")
HANDLE_TF_TYPE(Qint8, QINT8, "qint8")
HANDLE_TF_TYPE(Qint16, QINT16, "qint16")
HANDLE_TF_TYPE(Qint32, QINT32, "qint32")

View File

@ -91,7 +91,7 @@ class TensorFlowType : public Type {
// Returns true if the specified type is a valid TensorFlow element type.
static inline bool IsValidTFElementType(Type type) {
return type.isa<ComplexType>() || type.isa<FloatType>() ||
type.isSignlessInteger() || type.isa<TensorFlowType>();
type.isa<IntegerType>() || type.isa<TensorFlowType>();
}
// Returns true if this is a valid TensorFlow tensor type.

View File

@ -107,5 +107,5 @@ versions {
# CHECK: "tf.PartitionedCall"()
# CHECK-SAME: Tout = ["tfdtype$DT_UINT8"]
# CHECK-SAME: f = @[[FUNCTION:[A-Za-z0-9_]*]]
# CHECK: func @[[FUNCTION]]() -> tensor<!tf.uint8>
# CHECK: return {{.*}} : tensor<!tf.uint8>
# CHECK: func @[[FUNCTION]]() -> tensor<ui8>
# CHECK: return {{.*}} : tensor<ui8>

View File

@ -392,12 +392,12 @@ func @DynamicStitch_scalar_matrix_indices(%arg0: tensor<2xf32>, %arg1: tensor<2x
// Verify that custom types are lowered and have legal output.
// CHECK-LABEL: func @DynamicStitch_uint8
func @DynamicStitch_uint8(%arg0: tensor<2x2x!tf.uint8>) -> tensor<2x2x!tf.uint8> {
func @DynamicStitch_uint8(%arg0: tensor<2x2xui8>) -> tensor<2x2xui8> {
// CHECK-NOT: tf.DynamicStitch
%indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
%0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<2x2x!tf.uint8>) -> tensor<2x2x!tf.uint8>
return %0 : tensor<2x2x!tf.uint8>
%0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<2x2xui8>) -> tensor<2x2xui8>
return %0 : tensor<2x2xui8>
}
// CHECK-LABEL: func @DynamicStitch_scalar_item

View File

@ -66,17 +66,17 @@ func @testIdentity(%arg0: tensor<4x2x!tf.stringref>) -> tensor<4x2x!tf.string> {
// -----
// CHECK-LABEL: func @testBitcast
func @testBitcast(%arg0: tensor<3x4x!tf.uint16>) -> tensor<3x4x!tf.quint16> {
%0 = "tf.Bitcast"(%arg0) : (tensor<3x4x!tf.uint16>) -> tensor<3x4x!tf.quint16>
func @testBitcast(%arg0: tensor<3x4xui16>) -> tensor<3x4x!tf.quint16> {
%0 = "tf.Bitcast"(%arg0) : (tensor<3x4xui16>) -> tensor<3x4x!tf.quint16>
return %0 : tensor<3x4x!tf.quint16>
}
// -----
// CHECK-LABEL: func @testReverseV2
func @testReverseV2(%arg0: tensor<2x4x3x!tf.uint8>, %arg1: tensor<1xi32>) -> tensor<2x4x3x!tf.uint8> {
%0 = "tf.ReverseV2"(%arg0, %arg1) : (tensor<2x4x3x!tf.uint8>, tensor<1xi32>) -> tensor<2x4x3x!tf.uint8>
return %0 : tensor<2x4x3x!tf.uint8>
func @testReverseV2(%arg0: tensor<2x4x3xui8>, %arg1: tensor<1xi32>) -> tensor<2x4x3xui8> {
%0 = "tf.ReverseV2"(%arg0, %arg1) : (tensor<2x4x3xui8>, tensor<1xi32>) -> tensor<2x4x3xui8>
return %0 : tensor<2x4x3xui8>
}
// -----
@ -210,9 +210,9 @@ func @testLeakyWrongAlphaType(tensor<16xf32>) -> tensor<16xf32> {
// -----
// CHECK-LABEL: func @testMul
func @testMul(%arg0: tensor<2x!tf.uint16>) -> (tensor<2x!tf.uint16>) {
%0 = "tf.Mul"(%arg0, %arg0) {T = "tfdtype$DT_UINT16", device = "/device:CPU:0", name = "Mul"} : (tensor<2x!tf.uint16>, tensor<2x!tf.uint16>) -> tensor<2x!tf.uint16>
return %0 : tensor<2x!tf.uint16>
func @testMul(%arg0: tensor<2xui16>) -> (tensor<2xui16>) {
%0 = "tf.Mul"(%arg0, %arg0) {T = "tfdtype$DT_UINT16", device = "/device:CPU:0", name = "Mul"} : (tensor<2xui16>, tensor<2xui16>) -> tensor<2xui16>
return %0 : tensor<2xui16>
}
// -----

View File

@ -57,6 +57,18 @@ Status ConvertDataType(DataType dtype, Builder builder, Type* type) {
case DT_INT64:
*type = builder.getIntegerType(64);
return Status::OK();
case DT_UINT8:
*type = builder.getIntegerType(8, /*isSigned=*/false);
return Status::OK();
case DT_UINT16:
*type = builder.getIntegerType(16, /*isSigned=*/false);
return Status::OK();
case DT_UINT32:
*type = builder.getIntegerType(32, /*isSigned=*/false);
return Status::OK();
case DT_UINT64:
*type = builder.getIntegerType(64, /*isSigned=*/false);
return Status::OK();
case DT_BFLOAT16:
*type = builder.getBF16Type();
return Status::OK();
@ -99,16 +111,16 @@ Status ConvertScalarTypeToDataType(Type type, DataType* dtype) {
*dtype = DT_BOOL;
return Status::OK();
case 8:
*dtype = DT_INT8;
*dtype = itype.isUnsigned() ? DT_UINT8 : DT_INT8;
return Status::OK();
case 16:
*dtype = DT_INT16;
*dtype = itype.isUnsigned() ? DT_UINT16 : DT_INT16;
return Status::OK();
case 32:
*dtype = DT_INT32;
*dtype = itype.isUnsigned() ? DT_UINT32 : DT_INT32;
return Status::OK();
case 64:
*dtype = DT_INT64;
*dtype = itype.isUnsigned() ? DT_UINT64 : DT_INT64;
return Status::OK();
default:
return errors::Unimplemented(

View File

@ -523,17 +523,17 @@ func @broadcast_shift_right(%arg0: tensor<4xi32>, %arg1: tensor<2x4xi32>) -> ten
}
// CHECK-LABEL: func @shift_right_unsigned
func @shift_right_unsigned(%arg0: tensor<4x!tf.uint8>, %arg1: tensor<4x!tf.uint8>) -> tensor<4x!tf.uint8> {
func @shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<4xui8>) -> tensor<4xui8> {
// CHECK: tf.RightShift
%0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4x!tf.uint8>, tensor<4x!tf.uint8>) -> tensor<4x!tf.uint8>
return %0 : tensor<4x!tf.uint8>
%0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<4xui8>) -> tensor<4xui8>
return %0 : tensor<4xui8>
}
// CHECK-LABEL: func @broadcast_shift_right_unsigned
func @broadcast_shift_right_unsigned(%arg0: tensor<4x!tf.uint8>, %arg1: tensor<2x4x!tf.uint8>) -> tensor<2x4x!tf.uint8> {
func @broadcast_shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<2x4xui8>) -> tensor<2x4xui8> {
// CHECK: tf.RightShift
%0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4x!tf.uint8>, tensor<2x4x!tf.uint8>) -> tensor<2x4x!tf.uint8>
return %0 : tensor<2x4x!tf.uint8>
%0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<2x4xui8>) -> tensor<2x4xui8>
return %0 : tensor<2x4xui8>
}
// CHECK-LABEL: func @and