Introduce unsigned integer types for HLO dialect

Also,
* Define to/from type and ElementsAttr conversions
* Update HLO_AbsOp and HLO_SignOp type constraint to disallow unsigned integers. These two
  seems to be the only ops disallowing unsigned integers according to HLO shape inference.
* Restrict Complex type elements to be of f32 and f64

Similar to TensorFlow dialect, this still uses signless integer for signed integers which should be updated separated. Added a TODO for this.

PiperOrigin-RevId: 308360116
Change-Id: Ib4bea31fb5e284a1a7c407e032e75a09da179098
This commit is contained in:
Smit Hinsu 2020-04-24 17:53:18 -07:00 committed by TensorFlower Gardener
parent a07ca66517
commit b72a719b11
8 changed files with 76 additions and 21 deletions

View File

@ -131,6 +131,14 @@ StatusOr<mlir::DenseElementsAttr> CreateDenseElementsAttrFromLiteral(
return CreateDenseAttrFromLiteral<int32>(type, literal); return CreateDenseAttrFromLiteral<int32>(type, literal);
case PrimitiveType::S64: case PrimitiveType::S64:
return CreateDenseAttrFromLiteral<int64>(type, literal); return CreateDenseAttrFromLiteral<int64>(type, literal);
case PrimitiveType::U8:
return CreateDenseAttrFromLiteral<uint8>(type, literal);
case PrimitiveType::U16:
return CreateDenseAttrFromLiteral<uint16>(type, literal);
case PrimitiveType::U32:
return CreateDenseAttrFromLiteral<uint32>(type, literal);
case PrimitiveType::U64:
return CreateDenseAttrFromLiteral<uint64>(type, literal);
default: default:
return tensorflow::errors::Internal( return tensorflow::errors::Internal(
absl::StrCat("Unsupported type: ", PrimitiveType_Name(element_type))); absl::StrCat("Unsupported type: ", PrimitiveType_Name(element_type)));
@ -167,6 +175,14 @@ StatusOr<mlir::Type> ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type,
return builder.getIntegerType(32); return builder.getIntegerType(32);
case PrimitiveType::S64: case PrimitiveType::S64:
return builder.getIntegerType(64); return builder.getIntegerType(64);
case PrimitiveType::U8:
return builder.getIntegerType(8, /*isSigned=*/false);
case PrimitiveType::U16:
return builder.getIntegerType(16, /*isSigned=*/false);
case PrimitiveType::U32:
return builder.getIntegerType(32, /*isSigned=*/false);
case PrimitiveType::U64:
return builder.getIntegerType(64, /*isSigned=*/false);
case PrimitiveType::C64: case PrimitiveType::C64:
return mlir::ComplexType::get(builder.getF32Type()); return mlir::ComplexType::get(builder.getF32Type());
case PrimitiveType::C128: case PrimitiveType::C128:

View File

@ -117,7 +117,8 @@ class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
// Abs supports complex to real, so element type is not guaranteed to match. // Abs supports complex to real, so element type is not guaranteed to match.
def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs", def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs",
[NoSideEffect, SameOperandsAndResultShape], HLO_Tensor>, BASE_HLO_AbsOp { [NoSideEffect, SameOperandsAndResultShape],
TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp {
let builders = [OpBuilder< let builders = [OpBuilder<
"Builder *builder, OperationState &result, Value operand" "Builder *builder, OperationState &result, Value operand"
>]; >];
@ -194,7 +195,8 @@ def HLO_RsqrtOp: HLO_UnaryElementwiseOp<"rsqrt",
BASE_HLO_RsqrtOp; BASE_HLO_RsqrtOp;
def HLO_SignOp: HLO_UnaryElementwiseOp<"sign", def HLO_SignOp: HLO_UnaryElementwiseOp<"sign",
[NoSideEffect, SameOperandsAndResultType], HLO_IntFpOrComplexTensor>, [NoSideEffect, SameOperandsAndResultType],
TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>,
BASE_HLO_SignOp; BASE_HLO_SignOp;
def HLO_SinOp: HLO_UnaryElementwiseOp<"sine", def HLO_SinOp: HLO_UnaryElementwiseOp<"sine",

View File

@ -18,9 +18,16 @@ limitations under the License.
include "mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
def HLO_Int : SignlessIntOfWidths<[8, 16, 32, 64]>;
def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">; def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">;
// TODO(hinsu): Use signed integers instead of signless integer which is being
// used for legacy reasons.
def HLO_SInt : SignlessIntOfWidths<[8, 16, 32, 64]>;
def HLO_UInt : UnsignedIntOfWidths<[8, 16, 32, 64]>;
def HLO_Int : AnyTypeOf<[HLO_SInt, HLO_UInt]>;
def HLO_Complex : Complex<AnyTypeOf<[F32, F64]>>;
// The broadcasting dimensions correspond to a tuple that describes how a // The broadcasting dimensions correspond to a tuple that describes how a
// smaller rank shape is broadcast into a larger rank shape. For example, // smaller rank shape is broadcast into a larger rank shape. For example,
// given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means // given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means
@ -47,9 +54,9 @@ def HLO_FpTensor : TensorOf<[AnyFloat]>;
def HLO_PredTensor : TensorOf<[HLO_Pred]>; def HLO_PredTensor : TensorOf<[HLO_Pred]>;
def HLO_Tensor : TensorOf<[AnyFloat, AnySignlessInteger, AnyComplex]>; def HLO_Tensor : TensorOf<[AnyFloat, HLO_Pred, HLO_Int, HLO_Complex]>;
def HLO_ComplexTensor : TensorOf<[AnyComplex]>; def HLO_ComplexTensor : TensorOf<[HLO_Complex]>;
def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>; def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>;
@ -64,7 +71,7 @@ def HLO_DimensionTensor : ShapedContainerType<
// In general, static shaped tensor constraints should be avoided unless // In general, static shaped tensor constraints should be avoided unless
// it is for a legacy op which is only correct with static shapes. // it is for a legacy op which is only correct with static shapes.
def HLO_StaticShapeTensor : StaticShapeTensorOf<[ def HLO_StaticShapeTensor : StaticShapeTensorOf<[
AnyFloat, AnySignlessInteger, AnyComplex]>; AnyFloat, AnySignlessInteger, HLO_Complex]>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA on tensors combined type definitions. // XLA on tensors combined type definitions.
@ -77,10 +84,10 @@ def HLO_IntOrFpTensor : TensorOf<[HLO_Int, AnyFloat]>;
def HLO_PredOrIntTensor : TensorOf<[HLO_Pred, HLO_Int]>; def HLO_PredOrIntTensor : TensorOf<[HLO_Pred, HLO_Int]>;
// Any floating-point or complex tensor types // Any floating-point or complex tensor types
def HLO_FpOrComplexTensor : TensorOf<[AnyFloat, AnyComplex]>; def HLO_FpOrComplexTensor : TensorOf<[AnyFloat, HLO_Complex]>;
// Any int, floating-point or complex tensor types // Any int, floating-point or complex tensor types
def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, AnyComplex]>; def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, HLO_Complex]>;
// Any pred, int or floating-point tensor types // Any pred, int or floating-point tensor types
def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>; def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>;

View File

@ -924,7 +924,6 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S16, int16) ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S16, int16)
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S32, int32) ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S32, int32)
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S64, int64) ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S64, int64)
// TODO(b/130356985): Update once MLIR supports unsigned integers.
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U8, uint8) ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U8, uint8)
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U16, uint16) ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U16, uint16)
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U32, uint32) ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U32, uint32)

View File

@ -446,7 +446,7 @@ func @recv_non_token_second_result(%token: !xla_hlo.token) -> tuple<tensor<3x4xi
func @rng_uniform_invalid_type(%mu: tensor<complex<f32>>, %sigma: tensor<f32>) -> tensor<2x3x5xf32> { func @rng_uniform_invalid_type(%mu: tensor<complex<f32>>, %sigma: tensor<f32>) -> tensor<2x3x5xf32> {
%shape = xla_hlo.constant dense<[2, 3, 5]> : tensor<3xi64> %shape = xla_hlo.constant dense<[2, 3, 5]> : tensor<3xi64>
// expected-error@+1 {{must be tensor of pred (AKA boolean or 1-bit integer) or 8/16/32/64-bit signless integer or floating-point values, but got 'tensor<complex<f32>>'}} // expected-error@+1 {{but got 'tensor<complex<f32>>'}}
%0 = "xla_hlo.rng_uniform"(%mu, %sigma, %shape) : (tensor<complex<f32>>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32> %0 = "xla_hlo.rng_uniform"(%mu, %sigma, %shape) : (tensor<complex<f32>>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
return %0 : tensor<2x3x5xf32> return %0 : tensor<2x3x5xf32>
} }
@ -583,7 +583,7 @@ func @dynamic_slice_mismatch_element_types(%arg0: tensor<3x4xi32>, %arg1: tensor
// ----- // -----
func @dynamic_slice_invalid_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { func @dynamic_slice_invalid_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
// expected-error@+1 {{operand #1 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer values, but got 'tensor<2xi64>'}} // expected-error@+1 {{operand #1 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}}
%0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32> %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
return %0 : tensor<1x4xi32> return %0 : tensor<1x4xi32>
} }
@ -599,7 +599,7 @@ func @dynamic_update_slice(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %sta
// ----- // -----
func @dynamic_update_slice_invalid_start(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %start: tensor<2xi64>) -> tensor<3x4xi64> { func @dynamic_update_slice_invalid_start(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %start: tensor<2xi64>) -> tensor<3x4xi64> {
// expected-error@+1 {{operand #2 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer values, but got 'tensor<2xi64>'}} // expected-error@+1 {{operand #2 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}}
%0 = "xla_hlo.dynamic-update-slice"(%input, %update, %start) : (tensor<3x4xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<3x4xi64> %0 = "xla_hlo.dynamic-update-slice"(%input, %update, %start) : (tensor<3x4xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<3x4xi64>
return %0 : tensor<3x4xi64> return %0 : tensor<3x4xi64>
} }
@ -778,7 +778,7 @@ func @or_i1_type(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// ----- // -----
func @or_invalid_f32_type(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { func @or_invalid_f32_type(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// expected-error@+1 {{must be tensor of pred (AKA boolean or 1-bit integer) or 8/16/32/64-bit signless integer values, but got 'tensor<4xf32>'}} // expected-error@+1 {{but got 'tensor<4xf32>'}}
%0 = "xla_hlo.or"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> %0 = "xla_hlo.or"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32> return %0 : tensor<4xf32>
} }

View File

@ -285,11 +285,14 @@ func @main() {
// CHECK: s32[2,2] constant({ { 3, 2 }, { 1, 4 } }) // CHECK: s32[2,2] constant({ { 3, 2 }, { 1, 4 } })
%cst_5 = constant dense<[[3, 2], [1, 4]]> : tensor<2x2xi32> %cst_5 = constant dense<[[3, 2], [1, 4]]> : tensor<2x2xi32>
// CHECK: u32[2,2] constant({ { 1, 2 }, { 4, 8 } })
%cst_6 = constant dense<[[1, 2], [4, 8]]> : tensor<2x2xui32>
// CHECK: bf16[4] constant({1, 2, 3, 4}) // CHECK: bf16[4] constant({1, 2, 3, 4})
%cst_6 = constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16> %cst_7 = constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16>
// CHECK: f16[4] constant({1, -4, -65504, 0.015625} // CHECK: f16[4] constant({1, -4, -65504, 0.015625}
%cst_7 = constant dense<[1.0e+00, -4.0e+00, -65504.0e+00, 1.5625e-02]> : tensor<4xf16> %cst_8 = constant dense<[1.0e+00, -4.0e+00, -65504.0e+00, 1.5625e-02]> : tensor<4xf16>
return return
} }
@ -1023,3 +1026,15 @@ func @main(%arg0: tensor<2xcomplex<f32>>, %arg1: tensor<2xcomplex<f64>>) -> (ten
// CHECK: %[[ARG1:.*]] = c128[2] parameter(1) // CHECK: %[[ARG1:.*]] = c128[2] parameter(1)
// CHECK: %[[ABS1:.*]] = f64[2] abs(c128[2] %[[ARG1]]) // CHECK: %[[ABS1:.*]] = f64[2] abs(c128[2] %[[ARG1]])
// CHECK: ROOT %[[RESULT:.*]] = (f32[2], f64[2]) tuple(f32[2] %[[ABS0]], f64[2] %[[ABS1]]) // CHECK: ROOT %[[RESULT:.*]] = (f32[2], f64[2]) tuple(f32[2] %[[ABS0]], f64[2] %[[ABS1]])
// -----
// CHECK: HloModule
func @main(%arg0: tensor<4xui8>) -> (tensor<4xui8>) {
%0 = "xla_hlo.not"(%arg0) : (tensor<4xui8>) -> tensor<4xui8>
return %0 : tensor<4xui8>
}
// CHECK: ENTRY
// CHECK: %[[ARG0:.*]] = u8[4] parameter(0)
// ROOT %[[RESULT:.*]] = u8[4] not(u8[4] %[[ARG0]])

View File

@ -206,11 +206,16 @@ add {
// CHECK-NEXT: constant {name = "{{.*}}"} dense<{{\[\[\[\[}}1.000000e+00]], {{\[\[}}2.000000e+00]]], {{\[\[\[}}3.000000e+00]], {{\[\[}}4.000000e+00]]]]> : tensor<2x2x1x1xf32> // CHECK-NEXT: constant {name = "{{.*}}"} dense<{{\[\[\[\[}}1.000000e+00]], {{\[\[}}2.000000e+00]]], {{\[\[\[}}3.000000e+00]], {{\[\[}}4.000000e+00]]]]> : tensor<2x2x1x1xf32>
%constant.1 = f32[2,2,1,1]{3,2,1,0} constant({{{{1.0}},{{2.0}}},{{{3.0}},{{4.0}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"} %constant.1 = f32[2,2,1,1]{3,2,1,0} constant({{{{1.0}},{{2.0}}},{{{3.0}},{{4.0}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
// CHECK: dense<[1, 2, 4, 8]> : tensor<4xui64>
%constant.2 = u64[4] constant({ 1, 2, 4, 8 })
// CHECK: dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16> // CHECK: dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16>
%constant.2 = bf16[4] constant({1, 2, 3, 4}) %constant.3 = bf16[4] constant({1, 2, 3, 4})
// CHECK: dense<[1.000000e+00, -4.000000e+00, -6.550400e+04, 1.562500e-02]> : tensor<4xf16> // CHECK: dense<[1.000000e+00, -4.000000e+00, -6.550400e+04, 1.562500e-02]> : tensor<4xf16>
ROOT %constant.3 = f16[4] constant({1, -4, -65504, 0.015625}) ROOT %constant.4 = f16[4] constant({1, -4, -65504, 0.015625})
} }
// TODO(b/129422361) Potentially update when copy, reshape, and conv have actual // TODO(b/129422361) Potentially update when copy, reshape, and conv have actual
@ -1009,3 +1014,12 @@ add {
// CHECK: "xla_hlo.abs"(%[[ARG1]]) {name = "{{.*}}"} : (tensor<2xcomplex<f64>>) -> tensor<2xf64> // CHECK: "xla_hlo.abs"(%[[ARG1]]) {name = "{{.*}}"} : (tensor<2xcomplex<f64>>) -> tensor<2xf64>
ROOT %tuple.5 = (f32[2], f64[2]) tuple(f32[2] %abs.3, f64[2] %abs.4) ROOT %tuple.5 = (f32[2], f64[2]) tuple(f32[2] %abs.3, f64[2] %abs.4)
} }
// CHECK-LABEL: func @unsigned_int
// CHECK-SAME: (%[[ARG0:.*]]: tensor<4xui16>)
%unsigned_int(Arg_0.1: u16[4]) -> u16[4] {
%Arg_0.1 = u16[4] parameter(0)
// CHECK: "xla_hlo.not"(%[[ARG0]]) {name = "{{.*}}"} : (tensor<4xui16>) -> tensor<4xui16>
ROOT %not.2 = u16[4] not(u16[4] %Arg_0.1)
}

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
@ -64,17 +65,18 @@ PrimitiveType TypeToPrimitiveType(mlir::Type type) {
return PrimitiveType::F64; return PrimitiveType::F64;
case mlir::StandardTypes::Integer: { case mlir::StandardTypes::Integer: {
const auto integer = type.cast<IntegerType>(); const auto integer = type.cast<IntegerType>();
bool is_unsigned = integer.isUnsigned();
switch (integer.getWidth()) { switch (integer.getWidth()) {
case 1: case 1:
return PrimitiveType::PRED; return PrimitiveType::PRED;
case 8: case 8:
return PrimitiveType::S8; return is_unsigned ? PrimitiveType::U8 : PrimitiveType::S8;
case 16: case 16:
return PrimitiveType::S16; return is_unsigned ? PrimitiveType::U16 : PrimitiveType::S16;
case 32: case 32:
return PrimitiveType::S32; return is_unsigned ? PrimitiveType::U32 : PrimitiveType::S32;
case 64: case 64:
return PrimitiveType::S64; return is_unsigned ? PrimitiveType::U64 : PrimitiveType::S64;
default: default:
return PrimitiveType::PRIMITIVE_TYPE_INVALID; return PrimitiveType::PRIMITIVE_TYPE_INVALID;
} }