diff --git a/tensorflow/compiler/mlir/xla/hlo_utils.cc b/tensorflow/compiler/mlir/xla/hlo_utils.cc
index 19272f6e1f7..c685cc296fd 100644
--- a/tensorflow/compiler/mlir/xla/hlo_utils.cc
+++ b/tensorflow/compiler/mlir/xla/hlo_utils.cc
@@ -131,6 +131,14 @@ StatusOr<mlir::DenseElementsAttr> CreateDenseElementsAttrFromLiteral(
       return CreateDenseAttrFromLiteral<int32>(type, literal);
     case PrimitiveType::S64:
       return CreateDenseAttrFromLiteral<int64>(type, literal);
+    case PrimitiveType::U8:
+      return CreateDenseAttrFromLiteral<uint8>(type, literal);
+    case PrimitiveType::U16:
+      return CreateDenseAttrFromLiteral<uint16>(type, literal);
+    case PrimitiveType::U32:
+      return CreateDenseAttrFromLiteral<uint32>(type, literal);
+    case PrimitiveType::U64:
+      return CreateDenseAttrFromLiteral<uint64>(type, literal);
     default:
       return tensorflow::errors::Internal(
           absl::StrCat("Unsupported type: ", PrimitiveType_Name(element_type)));
@@ -167,6 +175,14 @@ StatusOr<mlir::Type> ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type,
       return builder.getIntegerType(32);
     case PrimitiveType::S64:
       return builder.getIntegerType(64);
+    case PrimitiveType::U8:
+      return builder.getIntegerType(8, /*isSigned=*/false);
+    case PrimitiveType::U16:
+      return builder.getIntegerType(16, /*isSigned=*/false);
+    case PrimitiveType::U32:
+      return builder.getIntegerType(32, /*isSigned=*/false);
+    case PrimitiveType::U64:
+      return builder.getIntegerType(64, /*isSigned=*/false);
     case PrimitiveType::C64:
       return mlir::ComplexType::get(builder.getF32Type());
     case PrimitiveType::C128:
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
index ff1c74f1037..57f8fe51b18 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
@@ -117,7 +117,8 @@ class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
 
 // Abs supports complex to real, so element type is not guaranteed to match.
 def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs",
-    [NoSideEffect, SameOperandsAndResultShape], HLO_Tensor>, BASE_HLO_AbsOp {
+    [NoSideEffect, SameOperandsAndResultShape],
+    TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp {
   let builders = [OpBuilder<
     "Builder *builder, OperationState &result, Value operand"
   >];
@@ -194,7 +195,8 @@ def HLO_RsqrtOp: HLO_UnaryElementwiseOp<"rsqrt",
       BASE_HLO_RsqrtOp;
 
 def HLO_SignOp: HLO_UnaryElementwiseOp<"sign",
-    [NoSideEffect, SameOperandsAndResultType], HLO_IntFpOrComplexTensor>,
+    [NoSideEffect, SameOperandsAndResultType],
+    TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>,
     BASE_HLO_SignOp;
 
 def HLO_SinOp: HLO_UnaryElementwiseOp<"sine",
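The hlo_utils.cc hunks lean on MLIR's distinction between signless and explicitly unsigned integer types. A minimal sketch of that distinction, assuming only a live MLIRContext (illustrative code, not part of the patch):

#include <cassert>

#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::MLIRContext context;
  mlir::Builder builder(&context);
  // Signless i32: what XLA's S32 still maps to, for legacy reasons.
  mlir::Type s32 = builder.getIntegerType(32);
  // Explicitly unsigned ui32: what XLA's U32 now maps to.
  mlir::Type u32 = builder.getIntegerType(32, /*isSigned=*/false);
  // The two are distinct types, so unsigned element types survive the
  // HLO -> MLIR -> HLO round trip instead of collapsing into S32.
  assert(s32 != u32);
  return 0;
}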
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
index 287ad1b4614..44e0abab031 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
@@ -18,9 +18,16 @@ limitations under the License.
 
 include "mlir/IR/OpBase.td"
 
-def HLO_Int : SignlessIntOfWidths<[8, 16, 32, 64]>;
 def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">;
 
+// TODO(hinsu): Use signed integers instead of signless integer which is being
+// used for legacy reasons.
+def HLO_SInt : SignlessIntOfWidths<[8, 16, 32, 64]>;
+def HLO_UInt : UnsignedIntOfWidths<[8, 16, 32, 64]>;
+def HLO_Int : AnyTypeOf<[HLO_SInt, HLO_UInt]>;
+
+def HLO_Complex : Complex<AnyTypeOf<[F32, F64]>>;
+
 // The broadcasting dimensions correspond to a tuple that describes how a
 // smaller rank shape is broadcast into a larger rank shape. For example,
 // given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means
@@ -47,9 +54,9 @@ def HLO_FpTensor : TensorOf<[AnyFloat]>;
 
 def HLO_PredTensor : TensorOf<[HLO_Pred]>;
 
-def HLO_Tensor : TensorOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
+def HLO_Tensor : TensorOf<[AnyFloat, HLO_Pred, HLO_Int, HLO_Complex]>;
 
-def HLO_ComplexTensor : TensorOf<[AnyComplex]>;
+def HLO_ComplexTensor : TensorOf<[HLO_Complex]>;
 
 def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>;
 
@@ -64,7 +71,7 @@ def HLO_DimensionTensor : ShapedContainerType<
 // In general, static shaped tensor constraints should be avoided unless
 // it is for a legacy op which is only correct with static shapes.
 def HLO_StaticShapeTensor : StaticShapeTensorOf<[
-    AnyFloat, AnySignlessInteger, AnyComplex]>;
+    AnyFloat, AnySignlessInteger, HLO_Complex]>;
 
 //===----------------------------------------------------------------------===//
 // XLA on tensors combined type definitions.
@@ -77,10 +84,10 @@ def HLO_IntOrFpTensor : TensorOf<[HLO_Int, AnyFloat]>;
 def HLO_PredOrIntTensor : TensorOf<[HLO_Pred, HLO_Int]>;
 
 // Any floating-point or complex tensor types
-def HLO_FpOrComplexTensor : TensorOf<[AnyFloat, AnyComplex]>;
+def HLO_FpOrComplexTensor : TensorOf<[AnyFloat, HLO_Complex]>;
 
 // Any int, floating-point or complex tensor types
-def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, AnyComplex]>;
+def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, HLO_Complex]>;
 
 // Any pred, int or floating-point tensor types
 def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>;
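For readers who don't speak TableGen: UnsignedIntOfWidths<[8, 16, 32, 64]> amounts to a C++ type predicate roughly like the hand-written approximation below (a sketch of the semantics, not the generated code):

#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"

// Approximation of the constraint behind HLO_UInt: an explicitly
// unsigned integer type of width 8, 16, 32, or 64.
bool IsHloUInt(mlir::Type type) {
  auto integer = type.dyn_cast<mlir::IntegerType>();
  if (!integer || !integer.isUnsigned()) return false;
  switch (integer.getWidth()) {
    case 8:
    case 16:
    case 32:
    case 64:
      return true;
    default:
      return false;
  }
}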
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
index 51193640235..d92e3d25343 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
@@ -924,7 +924,6 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
     ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S16, int16)
     ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S32, int32)
     ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S64, int64)
-    // TODO(b/130356985): Update once MLIR supports unsigned integers.
     ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U8, uint8)
     ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U16, uint16)
     ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U32, uint32)
diff --git a/tensorflow/compiler/mlir/xla/tests/ops.mlir b/tensorflow/compiler/mlir/xla/tests/ops.mlir
index 10ff5fa8f3c..4cdfe1c459d 100644
--- a/tensorflow/compiler/mlir/xla/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/ops.mlir
@@ -446,7 +446,7 @@ func @recv_non_token_second_result(%token: !xla_hlo.token) -> tuple<tensor<3x4x
 
 func @rng_uniform_invalid_type(%mu: tensor<complex<f32>>, %sigma: tensor<f32>) -> tensor<2x3x5xf32> {
   %shape = xla_hlo.constant dense<[2, 3, 5]> : tensor<3xi64>
-  // expected-error@+1 {{must be tensor of pred (AKA boolean or 1-bit integer) or 8/16/32/64-bit signless integer or floating-point values, but got 'tensor<complex<f32>>'}}
+  // expected-error@+1 {{but got 'tensor<complex<f32>>'}}
   %0 = "xla_hlo.rng_uniform"(%mu, %sigma, %shape) : (tensor<complex<f32>>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
   return %0 : tensor<2x3x5xf32>
 }
@@ -583,7 +583,7 @@ func @dynamic_slice_mismatch_element_types(%arg0: tensor<3x4xi32>, %arg1: tensor
 // -----
 
 func @dynamic_slice_invalid_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
-  // expected-error@+1 {{operand #1 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer values, but got 'tensor<2xi64>'}}
+  // expected-error@+1 {{operand #1 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}}
   %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
   return %0 : tensor<1x4xi32>
 }
@@ -599,7 +599,7 @@ func @dynamic_update_slice(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %sta
 // -----
 
 func @dynamic_update_slice_invalid_start(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %start: tensor<2xi64>) -> tensor<3x4xi64> {
-  // expected-error@+1 {{operand #2 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer values, but got 'tensor<2xi64>'}}
+  // expected-error@+1 {{operand #2 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}}
   %0 = "xla_hlo.dynamic-update-slice"(%input, %update, %start) : (tensor<3x4xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<3x4xi64>
   return %0 : tensor<3x4xi64>
 }
@@ -778,7 +778,7 @@ func @or_i1_type(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
 // -----
 
 func @or_invalid_f32_type(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
-  // expected-error@+1 {{must be tensor of pred (AKA boolean or 1-bit integer) or 8/16/32/64-bit signless integer values, but got 'tensor<4xf32>'}}
+  // expected-error@+1 {{but got 'tensor<4xf32>'}}
   %0 = "xla_hlo.or"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
   return %0 : tensor<4xf32>
 }
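The relaxed constraints are what the updated expected-error strings above reflect: wherever HLO_Int appears, both signless and unsigned widths are now legal. A sketch of building one of the newly legal unsigned dense constants by hand, assuming a Builder over a live context (the templated DenseElementsAttr::get overload is expected to match the element width; treat this as illustrative):

#include <cstdint>

#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"

// Builds dense<[[1, 2], [4, 8]]> : tensor<2x2xui32>.
mlir::DenseElementsAttr MakeU32Constant(mlir::Builder& builder) {
  auto ui32 = builder.getIntegerType(32, /*isSigned=*/false);
  auto type = mlir::RankedTensorType::get({2, 2}, ui32);
  const uint32_t values[] = {1, 2, 4, 8};
  return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(values));
}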
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
index ee8340f1c18..3650307ea94 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
@@ -285,11 +285,14 @@ func @main() {
   // CHECK: s32[2,2] constant({ { 3, 2 }, { 1, 4 } })
   %cst_5 = constant dense<[[3, 2], [1, 4]]> : tensor<2x2xi32>
 
+  // CHECK: u32[2,2] constant({ { 1, 2 }, { 4, 8 } })
+  %cst_6 = constant dense<[[1, 2], [4, 8]]> : tensor<2x2xui32>
+
   // CHECK: bf16[4] constant({1, 2, 3, 4})
-  %cst_6 = constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16>
+  %cst_7 = constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16>
 
   // CHECK: f16[4] constant({1, -4, -65504, 0.015625})
-  %cst_7 = constant dense<[1.0e+00, -4.0e+00, -65504.0e+00, 1.5625e-02]> : tensor<4xf16>
+  %cst_8 = constant dense<[1.0e+00, -4.0e+00, -65504.0e+00, 1.5625e-02]> : tensor<4xf16>
   return
 }
@@ -1023,3 +1026,15 @@ func @main(%arg0: tensor<2xcomplex<f32>>, %arg1: tensor<2xcomplex<f64>>) -> (ten
 // CHECK: %[[ARG1:.*]] = c128[2] parameter(1)
 // CHECK: %[[ABS1:.*]] = f64[2] abs(c128[2] %[[ARG1]])
 // CHECK: ROOT %[[RESULT:.*]] = (f32[2], f64[2]) tuple(f32[2] %[[ABS0]], f64[2] %[[ABS1]])
+
+// -----
+
+// CHECK: HloModule
+func @main(%arg0: tensor<4xui8>) -> (tensor<4xui8>) {
+  %0 = "xla_hlo.not"(%arg0) : (tensor<4xui8>) -> tensor<4xui8>
+  return %0 : tensor<4xui8>
+}
+
+// CHECK: ENTRY
+// CHECK: %[[ARG0:.*]] = u8[4] parameter(0)
+// CHECK: ROOT %[[RESULT:.*]] = u8[4] not(u8[4] %[[ARG0]])
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
index 0de2826f3e7..d1133057544 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt
@@ -206,11 +206,16 @@ add {
   // CHECK-NEXT: constant {name = "{{.*}}"} dense<{{\[\[\[\[}}1.000000e+00]], {{\[\[}}2.000000e+00]]], {{\[\[\[}}3.000000e+00]], {{\[\[}}4.000000e+00]]]]> : tensor<2x2x1x1xf32>
   %constant.1 = f32[2,2,1,1]{3,2,1,0} constant({{{{1.0}},{{2.0}}},{{{3.0}},{{4.0}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
 
+  // CHECK: dense<[1, 2, 4, 8]> : tensor<4xui64>
+  %constant.2 = u64[4] constant({ 1, 2, 4, 8 })
+
   // CHECK: dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16>
-  %constant.2 = bf16[4] constant({1, 2, 3, 4})
+  %constant.3 = bf16[4] constant({1, 2, 3, 4})
 
   // CHECK: dense<[1.000000e+00, -4.000000e+00, -6.550400e+04, 1.562500e-02]> : tensor<4xf16>
-  ROOT %constant.3 = f16[4] constant({1, -4, -65504, 0.015625})
+  ROOT %constant.4 = f16[4] constant({1, -4, -65504, 0.015625})
+
+
 }
 
 // TODO(b/129422361) Potentially update when copy, reshape, and conv have actual
@@ -1009,3 +1014,12 @@ add {
   // CHECK: "xla_hlo.abs"(%[[ARG1]]) {name = "{{.*}}"} : (tensor<2xcomplex<f64>>) -> tensor<2xf64>
   ROOT %tuple.5 = (f32[2], f64[2]) tuple(f32[2] %abs.3, f64[2] %abs.4)
 }
+
+// CHECK-LABEL: func @unsigned_int
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<4xui16>)
+%unsigned_int (Arg_0.1: u16[4]) -> u16[4] {
+  %Arg_0.1 = u16[4] parameter(0)
+
+  // CHECK: "xla_hlo.not"(%[[ARG0]]) {name = "{{.*}}"} : (tensor<4xui16>) -> tensor<4xui16>
+  ROOT %not.2 = u16[4] not(u16[4] %Arg_0.1)
+}
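The import test above exercises the literal-to-attribute direction. Roughly, and with error handling elided, that path looks like the sketch below; the CreateDenseElementsAttrFromLiteral name comes from the hlo_utils.cc hunk in this patch, while the exact signature and surrounding scaffolding are assumptions:

#include <cstdint>

#include "mlir/IR/Builders.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/xla/literal_util.h"

void ImportU64Constant(mlir::Builder builder) {
  // u64[4] constant({ 1, 2, 4, 8 }) as an XLA literal.
  xla::Literal literal = xla::LiteralUtil::CreateR1<uint64_t>({1, 2, 4, 8});
  // With the new U64 case in hlo_utils.cc, this succeeds instead of
  // returning an "Unsupported type" error.
  auto attr_or = xla::CreateDenseElementsAttrFromLiteral(literal, builder);
  // On success the attribute prints as dense<[1, 2, 4, 8]> : tensor<4xui64>,
  // which is what the CHECK line in import.hlotxt matches.
  (void)attr_or;
}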
diff --git a/tensorflow/compiler/mlir/xla/type_to_shape.cc b/tensorflow/compiler/mlir/xla/type_to_shape.cc
index 3b1ae934c48..9f144bb4a45 100644
--- a/tensorflow/compiler/mlir/xla/type_to_shape.cc
+++ b/tensorflow/compiler/mlir/xla/type_to_shape.cc
@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
@@ -64,17 +65,18 @@ PrimitiveType TypeToPrimitiveType(mlir::Type type) {
       return PrimitiveType::F64;
     case mlir::StandardTypes::Integer: {
       const auto integer = type.cast<mlir::IntegerType>();
+      bool is_unsigned = integer.isUnsigned();
       switch (integer.getWidth()) {
         case 1:
           return PrimitiveType::PRED;
         case 8:
-          return PrimitiveType::S8;
+          return is_unsigned ? PrimitiveType::U8 : PrimitiveType::S8;
         case 16:
-          return PrimitiveType::S16;
+          return is_unsigned ? PrimitiveType::U16 : PrimitiveType::S16;
        case 32:
-          return PrimitiveType::S32;
+          return is_unsigned ? PrimitiveType::U32 : PrimitiveType::S32;
        case 64:
-          return PrimitiveType::S64;
+          return is_unsigned ? PrimitiveType::U64 : PrimitiveType::S64;
         default:
           return PrimitiveType::PRIMITIVE_TYPE_INVALID;
       }
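With the type_to_shape.cc change in place, the MLIR-to-XLA mapping mirrors the XLA-to-MLIR mapping added in hlo_utils.cc. A quick illustrative check of the new behavior (a sketch, not a test from the patch):

#include <cassert>

#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

int main() {
  mlir::MLIRContext context;
  mlir::Builder builder(&context);
  // Signless integers keep their legacy mapping to XLA's signed types...
  assert(xla::TypeToPrimitiveType(builder.getIntegerType(8)) ==
         xla::PrimitiveType::S8);
  // ...while explicitly unsigned integers now map to the U types.
  assert(xla::TypeToPrimitiveType(
             builder.getIntegerType(8, /*isSigned=*/false)) ==
         xla::PrimitiveType::U8);
  return 0;
}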