Introduce unsigned integer types for the HLO dialect.

Also:
* Define to/from type and ElementsAttr conversions.
* Update the HLO_AbsOp and HLO_SignOp type constraints to disallow unsigned integers. These two seem to be the only ops disallowing unsigned integers according to HLO shape inference.
* Restrict Complex type elements to be of f32 or f64.

Similar to the TensorFlow dialect, this still uses signless integers for signed integers, which should be updated separately. Added a TODO for this.

PiperOrigin-RevId: 308360116
Change-Id: Ib4bea31fb5e284a1a7c407e032e75a09da179098
This commit is contained in:
parent
a07ca66517
commit
b72a719b11
@ -131,6 +131,14 @@ StatusOr<mlir::DenseElementsAttr> CreateDenseElementsAttrFromLiteral(
|
|||||||
return CreateDenseAttrFromLiteral<int32>(type, literal);
|
return CreateDenseAttrFromLiteral<int32>(type, literal);
|
||||||
case PrimitiveType::S64:
|
case PrimitiveType::S64:
|
||||||
return CreateDenseAttrFromLiteral<int64>(type, literal);
|
return CreateDenseAttrFromLiteral<int64>(type, literal);
|
||||||
|
case PrimitiveType::U8:
|
||||||
|
return CreateDenseAttrFromLiteral<uint8>(type, literal);
|
||||||
|
case PrimitiveType::U16:
|
||||||
|
return CreateDenseAttrFromLiteral<uint16>(type, literal);
|
||||||
|
case PrimitiveType::U32:
|
||||||
|
return CreateDenseAttrFromLiteral<uint32>(type, literal);
|
||||||
|
case PrimitiveType::U64:
|
||||||
|
return CreateDenseAttrFromLiteral<uint64>(type, literal);
|
||||||
default:
|
default:
|
||||||
return tensorflow::errors::Internal(
|
return tensorflow::errors::Internal(
|
||||||
absl::StrCat("Unsupported type: ", PrimitiveType_Name(element_type)));
|
absl::StrCat("Unsupported type: ", PrimitiveType_Name(element_type)));
|
||||||
@ -167,6 +175,14 @@ StatusOr<mlir::Type> ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type,
|
|||||||
return builder.getIntegerType(32);
|
return builder.getIntegerType(32);
|
||||||
case PrimitiveType::S64:
|
case PrimitiveType::S64:
|
||||||
return builder.getIntegerType(64);
|
return builder.getIntegerType(64);
|
||||||
|
case PrimitiveType::U8:
|
||||||
|
return builder.getIntegerType(8, /*isSigned=*/false);
|
||||||
|
case PrimitiveType::U16:
|
||||||
|
return builder.getIntegerType(16, /*isSigned=*/false);
|
||||||
|
case PrimitiveType::U32:
|
||||||
|
return builder.getIntegerType(32, /*isSigned=*/false);
|
||||||
|
case PrimitiveType::U64:
|
||||||
|
return builder.getIntegerType(64, /*isSigned=*/false);
|
||||||
case PrimitiveType::C64:
|
case PrimitiveType::C64:
|
||||||
return mlir::ComplexType::get(builder.getF32Type());
|
return mlir::ComplexType::get(builder.getF32Type());
|
||||||
case PrimitiveType::C128:
|
case PrimitiveType::C128:
|
||||||
|
@ -117,7 +117,8 @@ class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
|
|||||||
|
|
||||||
// Abs supports complex to real, so element type is not guaranteed to match.
|
// Abs supports complex to real, so element type is not guaranteed to match.
|
||||||
def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs",
|
def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs",
|
||||||
[NoSideEffect, SameOperandsAndResultShape], HLO_Tensor>, BASE_HLO_AbsOp {
|
[NoSideEffect, SameOperandsAndResultShape],
|
||||||
|
TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp {
|
||||||
let builders = [OpBuilder<
|
let builders = [OpBuilder<
|
||||||
"Builder *builder, OperationState &result, Value operand"
|
"Builder *builder, OperationState &result, Value operand"
|
||||||
>];
|
>];
|
||||||
@ -194,7 +195,8 @@ def HLO_RsqrtOp: HLO_UnaryElementwiseOp<"rsqrt",
|
|||||||
BASE_HLO_RsqrtOp;
|
BASE_HLO_RsqrtOp;
|
||||||
|
|
||||||
def HLO_SignOp: HLO_UnaryElementwiseOp<"sign",
|
def HLO_SignOp: HLO_UnaryElementwiseOp<"sign",
|
||||||
[NoSideEffect, SameOperandsAndResultType], HLO_IntFpOrComplexTensor>,
|
[NoSideEffect, SameOperandsAndResultType],
|
||||||
|
TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>,
|
||||||
BASE_HLO_SignOp;
|
BASE_HLO_SignOp;
|
||||||
|
|
||||||
def HLO_SinOp: HLO_UnaryElementwiseOp<"sine",
|
def HLO_SinOp: HLO_UnaryElementwiseOp<"sine",
|
||||||
|
@ -18,9 +18,16 @@ limitations under the License.
|
|||||||
|
|
||||||
include "mlir/IR/OpBase.td"
|
include "mlir/IR/OpBase.td"
|
||||||
|
|
||||||
def HLO_Int : SignlessIntOfWidths<[8, 16, 32, 64]>;
|
|
||||||
def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">;
|
def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">;
|
||||||
|
|
||||||
|
// TODO(hinsu): Use signed integers instead of signless integer which is being
|
||||||
|
// used for legacy reasons.
|
||||||
|
def HLO_SInt : SignlessIntOfWidths<[8, 16, 32, 64]>;
|
||||||
|
def HLO_UInt : UnsignedIntOfWidths<[8, 16, 32, 64]>;
|
||||||
|
def HLO_Int : AnyTypeOf<[HLO_SInt, HLO_UInt]>;
|
||||||
|
|
||||||
|
def HLO_Complex : Complex<AnyTypeOf<[F32, F64]>>;
|
||||||
|
|
||||||
// The broadcasting dimensions correspond to a tuple that describes how a
|
// The broadcasting dimensions correspond to a tuple that describes how a
|
||||||
// smaller rank shape is broadcast into a larger rank shape. For example,
|
// smaller rank shape is broadcast into a larger rank shape. For example,
|
||||||
// given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means
|
// given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means
|
||||||
@ -47,9 +54,9 @@ def HLO_FpTensor : TensorOf<[AnyFloat]>;
|
|||||||
|
|
||||||
def HLO_PredTensor : TensorOf<[HLO_Pred]>;
|
def HLO_PredTensor : TensorOf<[HLO_Pred]>;
|
||||||
|
|
||||||
def HLO_Tensor : TensorOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
|
def HLO_Tensor : TensorOf<[AnyFloat, HLO_Pred, HLO_Int, HLO_Complex]>;
|
||||||
|
|
||||||
def HLO_ComplexTensor : TensorOf<[AnyComplex]>;
|
def HLO_ComplexTensor : TensorOf<[HLO_Complex]>;
|
||||||
|
|
||||||
def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>;
|
def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>;
|
||||||
|
|
||||||
@ -64,7 +71,7 @@ def HLO_DimensionTensor : ShapedContainerType<
|
|||||||
// In general, static shaped tensor constraints should be avoided unless
|
// In general, static shaped tensor constraints should be avoided unless
|
||||||
// it is for a legacy op which is only correct with static shapes.
|
// it is for a legacy op which is only correct with static shapes.
|
||||||
def HLO_StaticShapeTensor : StaticShapeTensorOf<[
|
def HLO_StaticShapeTensor : StaticShapeTensorOf<[
|
||||||
AnyFloat, AnySignlessInteger, AnyComplex]>;
|
AnyFloat, AnySignlessInteger, HLO_Complex]>;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// XLA on tensors combined type definitions.
|
// XLA on tensors combined type definitions.
|
||||||
@ -77,10 +84,10 @@ def HLO_IntOrFpTensor : TensorOf<[HLO_Int, AnyFloat]>;
|
|||||||
def HLO_PredOrIntTensor : TensorOf<[HLO_Pred, HLO_Int]>;
|
def HLO_PredOrIntTensor : TensorOf<[HLO_Pred, HLO_Int]>;
|
||||||
|
|
||||||
// Any floating-point or complex tensor types
|
// Any floating-point or complex tensor types
|
||||||
def HLO_FpOrComplexTensor : TensorOf<[AnyFloat, AnyComplex]>;
|
def HLO_FpOrComplexTensor : TensorOf<[AnyFloat, HLO_Complex]>;
|
||||||
|
|
||||||
// Any int, floating-point or complex tensor types
|
// Any int, floating-point or complex tensor types
|
||||||
def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, AnyComplex]>;
|
def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, HLO_Complex]>;
|
||||||
|
|
||||||
// Any pred, int or floating-point tensor types
|
// Any pred, int or floating-point tensor types
|
||||||
def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>;
|
def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>;
|
||||||
|
@ -924,7 +924,6 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(ElementsAttr attr) {
|
|||||||
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S16, int16)
|
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S16, int16)
|
||||||
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S32, int32)
|
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S32, int32)
|
||||||
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S64, int64)
|
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S64, int64)
|
||||||
// TODO(b/130356985): Update once MLIR supports unsigned integers.
|
|
||||||
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U8, uint8)
|
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U8, uint8)
|
||||||
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U16, uint16)
|
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U16, uint16)
|
||||||
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U32, uint32)
|
ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U32, uint32)
|
||||||
|
@ -446,7 +446,7 @@ func @recv_non_token_second_result(%token: !xla_hlo.token) -> tuple<tensor<3x4xi
|
|||||||
|
|
||||||
func @rng_uniform_invalid_type(%mu: tensor<complex<f32>>, %sigma: tensor<f32>) -> tensor<2x3x5xf32> {
|
func @rng_uniform_invalid_type(%mu: tensor<complex<f32>>, %sigma: tensor<f32>) -> tensor<2x3x5xf32> {
|
||||||
%shape = xla_hlo.constant dense<[2, 3, 5]> : tensor<3xi64>
|
%shape = xla_hlo.constant dense<[2, 3, 5]> : tensor<3xi64>
|
||||||
// expected-error@+1 {{must be tensor of pred (AKA boolean or 1-bit integer) or 8/16/32/64-bit signless integer or floating-point values, but got 'tensor<complex<f32>>'}}
|
// expected-error@+1 {{but got 'tensor<complex<f32>>'}}
|
||||||
%0 = "xla_hlo.rng_uniform"(%mu, %sigma, %shape) : (tensor<complex<f32>>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
|
%0 = "xla_hlo.rng_uniform"(%mu, %sigma, %shape) : (tensor<complex<f32>>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
|
||||||
return %0 : tensor<2x3x5xf32>
|
return %0 : tensor<2x3x5xf32>
|
||||||
}
|
}
|
||||||
@ -583,7 +583,7 @@ func @dynamic_slice_mismatch_element_types(%arg0: tensor<3x4xi32>, %arg1: tensor
|
|||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @dynamic_slice_invalid_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
|
func @dynamic_slice_invalid_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
|
||||||
// expected-error@+1 {{operand #1 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer values, but got 'tensor<2xi64>'}}
|
// expected-error@+1 {{operand #1 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}}
|
||||||
%0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
|
%0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32>
|
||||||
return %0 : tensor<1x4xi32>
|
return %0 : tensor<1x4xi32>
|
||||||
}
|
}
|
||||||
@ -599,7 +599,7 @@ func @dynamic_update_slice(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %sta
|
|||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @dynamic_update_slice_invalid_start(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %start: tensor<2xi64>) -> tensor<3x4xi64> {
|
func @dynamic_update_slice_invalid_start(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %start: tensor<2xi64>) -> tensor<3x4xi64> {
|
||||||
// expected-error@+1 {{operand #2 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer values, but got 'tensor<2xi64>'}}
|
// expected-error@+1 {{operand #2 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}}
|
||||||
%0 = "xla_hlo.dynamic-update-slice"(%input, %update, %start) : (tensor<3x4xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<3x4xi64>
|
%0 = "xla_hlo.dynamic-update-slice"(%input, %update, %start) : (tensor<3x4xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<3x4xi64>
|
||||||
return %0 : tensor<3x4xi64>
|
return %0 : tensor<3x4xi64>
|
||||||
}
|
}
|
||||||
@ -778,7 +778,7 @@ func @or_i1_type(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
|
|||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @or_invalid_f32_type(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
func @or_invalid_f32_type(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||||
// expected-error@+1 {{must be tensor of pred (AKA boolean or 1-bit integer) or 8/16/32/64-bit signless integer values, but got 'tensor<4xf32>'}}
|
// expected-error@+1 {{but got 'tensor<4xf32>'}}
|
||||||
%0 = "xla_hlo.or"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
%0 = "xla_hlo.or"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||||
return %0 : tensor<4xf32>
|
return %0 : tensor<4xf32>
|
||||||
}
|
}
|
||||||
|
@ -285,11 +285,14 @@ func @main() {
|
|||||||
// CHECK: s32[2,2] constant({ { 3, 2 }, { 1, 4 } })
|
// CHECK: s32[2,2] constant({ { 3, 2 }, { 1, 4 } })
|
||||||
%cst_5 = constant dense<[[3, 2], [1, 4]]> : tensor<2x2xi32>
|
%cst_5 = constant dense<[[3, 2], [1, 4]]> : tensor<2x2xi32>
|
||||||
|
|
||||||
|
// CHECK: u32[2,2] constant({ { 1, 2 }, { 4, 8 } })
|
||||||
|
%cst_6 = constant dense<[[1, 2], [4, 8]]> : tensor<2x2xui32>
|
||||||
|
|
||||||
// CHECK: bf16[4] constant({1, 2, 3, 4})
|
// CHECK: bf16[4] constant({1, 2, 3, 4})
|
||||||
%cst_6 = constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16>
|
%cst_7 = constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16>
|
||||||
|
|
||||||
// CHECK: f16[4] constant({1, -4, -65504, 0.015625}
|
// CHECK: f16[4] constant({1, -4, -65504, 0.015625}
|
||||||
%cst_7 = constant dense<[1.0e+00, -4.0e+00, -65504.0e+00, 1.5625e-02]> : tensor<4xf16>
|
%cst_8 = constant dense<[1.0e+00, -4.0e+00, -65504.0e+00, 1.5625e-02]> : tensor<4xf16>
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -1023,3 +1026,15 @@ func @main(%arg0: tensor<2xcomplex<f32>>, %arg1: tensor<2xcomplex<f64>>) -> (ten
|
|||||||
// CHECK: %[[ARG1:.*]] = c128[2] parameter(1)
|
// CHECK: %[[ARG1:.*]] = c128[2] parameter(1)
|
||||||
// CHECK: %[[ABS1:.*]] = f64[2] abs(c128[2] %[[ARG1]])
|
// CHECK: %[[ABS1:.*]] = f64[2] abs(c128[2] %[[ARG1]])
|
||||||
// CHECK: ROOT %[[RESULT:.*]] = (f32[2], f64[2]) tuple(f32[2] %[[ABS0]], f64[2] %[[ABS1]])
|
// CHECK: ROOT %[[RESULT:.*]] = (f32[2], f64[2]) tuple(f32[2] %[[ABS0]], f64[2] %[[ABS1]])
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK: HloModule
|
||||||
|
func @main(%arg0: tensor<4xui8>) -> (tensor<4xui8>) {
|
||||||
|
%0 = "xla_hlo.not"(%arg0) : (tensor<4xui8>) -> tensor<4xui8>
|
||||||
|
return %0 : tensor<4xui8>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: ENTRY
|
||||||
|
// CHECK: %[[ARG0:.*]] = u8[4] parameter(0)
|
||||||
|
// ROOT %[[RESULT:.*]] = u8[4] not(u8[4] %[[ARG0]])
|
||||||
|
@ -206,11 +206,16 @@ add {
|
|||||||
// CHECK-NEXT: constant {name = "{{.*}}"} dense<{{\[\[\[\[}}1.000000e+00]], {{\[\[}}2.000000e+00]]], {{\[\[\[}}3.000000e+00]], {{\[\[}}4.000000e+00]]]]> : tensor<2x2x1x1xf32>
|
// CHECK-NEXT: constant {name = "{{.*}}"} dense<{{\[\[\[\[}}1.000000e+00]], {{\[\[}}2.000000e+00]]], {{\[\[\[}}3.000000e+00]], {{\[\[}}4.000000e+00]]]]> : tensor<2x2x1x1xf32>
|
||||||
%constant.1 = f32[2,2,1,1]{3,2,1,0} constant({{{{1.0}},{{2.0}}},{{{3.0}},{{4.0}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
|
%constant.1 = f32[2,2,1,1]{3,2,1,0} constant({{{{1.0}},{{2.0}}},{{{3.0}},{{4.0}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
|
||||||
|
|
||||||
|
// CHECK: dense<[1, 2, 4, 8]> : tensor<4xui64>
|
||||||
|
%constant.2 = u64[4] constant({ 1, 2, 4, 8 })
|
||||||
|
|
||||||
// CHECK: dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16>
|
// CHECK: dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16>
|
||||||
%constant.2 = bf16[4] constant({1, 2, 3, 4})
|
%constant.3 = bf16[4] constant({1, 2, 3, 4})
|
||||||
|
|
||||||
// CHECK: dense<[1.000000e+00, -4.000000e+00, -6.550400e+04, 1.562500e-02]> : tensor<4xf16>
|
// CHECK: dense<[1.000000e+00, -4.000000e+00, -6.550400e+04, 1.562500e-02]> : tensor<4xf16>
|
||||||
ROOT %constant.3 = f16[4] constant({1, -4, -65504, 0.015625})
|
ROOT %constant.4 = f16[4] constant({1, -4, -65504, 0.015625})
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(b/129422361) Potentially update when copy, reshape, and conv have actual
|
// TODO(b/129422361) Potentially update when copy, reshape, and conv have actual
|
||||||
@ -1009,3 +1014,12 @@ add {
|
|||||||
// CHECK: "xla_hlo.abs"(%[[ARG1]]) {name = "{{.*}}"} : (tensor<2xcomplex<f64>>) -> tensor<2xf64>
|
// CHECK: "xla_hlo.abs"(%[[ARG1]]) {name = "{{.*}}"} : (tensor<2xcomplex<f64>>) -> tensor<2xf64>
|
||||||
ROOT %tuple.5 = (f32[2], f64[2]) tuple(f32[2] %abs.3, f64[2] %abs.4)
|
ROOT %tuple.5 = (f32[2], f64[2]) tuple(f32[2] %abs.3, f64[2] %abs.4)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @unsigned_int
|
||||||
|
// CHECK-SAME: (%[[ARG0:.*]]: tensor<4xui16>)
|
||||||
|
%unsigned_int(Arg_0.1: u16[4]) -> u16[4] {
|
||||||
|
%Arg_0.1 = u16[4] parameter(0)
|
||||||
|
|
||||||
|
// CHECK: "xla_hlo.not"(%[[ARG0]]) {name = "{{.*}}"} : (tensor<4xui16>) -> tensor<4xui16>
|
||||||
|
ROOT %not.2 = u16[4] not(u16[4] %Arg_0.1)
|
||||||
|
}
|
||||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
@ -64,17 +65,18 @@ PrimitiveType TypeToPrimitiveType(mlir::Type type) {
|
|||||||
return PrimitiveType::F64;
|
return PrimitiveType::F64;
|
||||||
case mlir::StandardTypes::Integer: {
|
case mlir::StandardTypes::Integer: {
|
||||||
const auto integer = type.cast<IntegerType>();
|
const auto integer = type.cast<IntegerType>();
|
||||||
|
bool is_unsigned = integer.isUnsigned();
|
||||||
switch (integer.getWidth()) {
|
switch (integer.getWidth()) {
|
||||||
case 1:
|
case 1:
|
||||||
return PrimitiveType::PRED;
|
return PrimitiveType::PRED;
|
||||||
case 8:
|
case 8:
|
||||||
return PrimitiveType::S8;
|
return is_unsigned ? PrimitiveType::U8 : PrimitiveType::S8;
|
||||||
case 16:
|
case 16:
|
||||||
return PrimitiveType::S16;
|
return is_unsigned ? PrimitiveType::U16 : PrimitiveType::S16;
|
||||||
case 32:
|
case 32:
|
||||||
return PrimitiveType::S32;
|
return is_unsigned ? PrimitiveType::U32 : PrimitiveType::S32;
|
||||||
case 64:
|
case 64:
|
||||||
return PrimitiveType::S64;
|
return is_unsigned ? PrimitiveType::U64 : PrimitiveType::S64;
|
||||||
default:
|
default:
|
||||||
return PrimitiveType::PRIMITIVE_TYPE_INVALID;
|
return PrimitiveType::PRIMITIVE_TYPE_INVALID;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user