Restrict GetDimensionSize HLO op result type to 32 bit integer
XLA implementation has this limitation and always uses 32 bit result for this instruction. This will cause mismatch between the result type in MLIR and XLA at the time of export. This should be resolved once we have a special dialect mapping directly to HLOInstructionProto. Another option until then could be to introduce a pass to legalize mhlo itself to match XLA semantics. PiperOrigin-RevId: 324286936 Change-Id: Ice7893f9920bbbc96936b90c8063248b1627e3e9
This commit is contained in:
parent
934b4b6a35
commit
8a2c608cf7
@ -1075,7 +1075,10 @@ def HLO_GetDimensionSizeOp: HLO_Op<"get_dimension_size", [NoSideEffect]>,
|
||||
HLO_Tensor:$operand,
|
||||
I32Attr:$dimension
|
||||
);
|
||||
let results = (outs HLO_IntTensor);
|
||||
// TODO(hinsu): Allow 64-bit result types once XLA HLO dialect based on the
|
||||
// XLA semantics is available. This limitation is because of the current XLA
|
||||
// implementation.
|
||||
let results = (outs I32Tensor);
|
||||
}
|
||||
|
||||
def HLO_MapOp: HLO_Op<"map",
|
||||
|
@ -3482,8 +3482,8 @@ func @cross_replica_sum(%input: tensor<10xf32>) -> tensor<10xf32> {
|
||||
// tf.Size legalization
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: @size_rank_one_i32
|
||||
func @size_rank_one_i32(%input: tensor<f32>) -> (tensor<i32>) {
|
||||
// CHECK-LABEL: @size_scalar_i32
|
||||
func @size_scalar_i32(%input: tensor<f32>) -> (tensor<i32>) {
|
||||
// CHECK: %[[CONST:.*]] = mhlo.constant dense<1>
|
||||
// CHECK-SAME: tensor<i32>
|
||||
%size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor<f32>) -> tensor<i32>
|
||||
@ -3491,8 +3491,8 @@ func @size_rank_one_i32(%input: tensor<f32>) -> (tensor<i32>) {
|
||||
return %size : tensor<i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @size_rank_one_i64
|
||||
func @size_rank_one_i64(%input: tensor<f32>) -> (tensor<i64>) {
|
||||
// CHECK-LABEL: @size_scalar_i64
|
||||
func @size_scalar_i64(%input: tensor<f32>) -> (tensor<i64>) {
|
||||
// CHECK: %[[CONST:.*]] = mhlo.constant dense<1>
|
||||
// CHECK-SAME: tensor<i64>
|
||||
%size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT64"} : (tensor<f32>) -> tensor<i64>
|
||||
@ -3500,19 +3500,40 @@ func @size_rank_one_i64(%input: tensor<f32>) -> (tensor<i64>) {
|
||||
return %size : tensor<i64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @size_rank_one_i64
|
||||
// CHECK-SAME: (%[[INPUT:.*]]: tensor<?xf32>)
|
||||
func @size_rank_one_i64(%input: tensor<?xf32>) -> (tensor<i64>) {
|
||||
// CHECK: %[[INIT:.*]] = mhlo.constant dense<1>
|
||||
// CHECK-SAME: tensor<i64>
|
||||
|
||||
// CHECK: %[[DIM_0:.*]] = "mhlo.get_dimension_size"(%[[INPUT]])
|
||||
// CHECK-SAME: dimension = 0
|
||||
// CHECK-SAME: tensor<i32>
|
||||
|
||||
// CHECK: %[[CAST_DIM_0:.*]] = "mhlo.convert"(%[[DIM_0]]) : (tensor<i32>) -> tensor<i64>
|
||||
// CHECK: %[[RESULT:.*]] = chlo.broadcast_multiply %[[INIT]], %[[CAST_DIM_0]]
|
||||
|
||||
%size = "tf.Size"(%input) : (tensor<?xf32>) -> tensor<i64>
|
||||
// CHECK: return %[[RESULT]]
|
||||
return %size : tensor<i64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @size_ranked
|
||||
// CHECK-SAME: (%[[INPUT:.*]]: tensor<2x?x8xf32>)
|
||||
func @size_ranked(%input: tensor<2x?x8xf32>) -> (tensor<i32>) {
|
||||
// CHECK: %[[CONST:.*]] = mhlo.constant dense<1>
|
||||
// CHECK: %[[DIM_0:.*]] = "mhlo.get_dimension_size"(%[[INPUT]])
|
||||
// CHECK-SAME: dimension = 0
|
||||
// CHECK: %[[MUL_0:.*]] = chlo.broadcast_multiply %[[CONST]], %[[DIM_0]]
|
||||
// CHECK: %[[CAST_DIM_0:.*]] = "mhlo.convert"(%[[DIM_0]]) : (tensor<i32>) -> tensor<i32>
|
||||
// CHECK: %[[MUL_0:.*]] = chlo.broadcast_multiply %[[CONST]], %[[CAST_DIM_0]]
|
||||
// CHECK: %[[DIM_1:.*]] = "mhlo.get_dimension_size"(%[[INPUT]])
|
||||
// CHECK-SAME: dimension = 1
|
||||
// CHECK: %[[MUL_1:.*]] = chlo.broadcast_multiply %[[MUL_0]], %[[DIM_1]]
|
||||
// CHECK: %[[CAST_DIM_1:.*]] = "mhlo.convert"(%[[DIM_1]]) : (tensor<i32>) -> tensor<i32>
|
||||
// CHECK: %[[MUL_1:.*]] = chlo.broadcast_multiply %[[MUL_0]], %[[CAST_DIM_1]]
|
||||
// CHECK: %[[DIM_2:.*]] = "mhlo.get_dimension_size"(%[[INPUT]])
|
||||
// CHECK-SAME: dimension = 2
|
||||
// CHECK: %[[MUL_2:.*]] = chlo.broadcast_multiply %[[MUL_1]], %[[DIM_2]]
|
||||
// CHECK: %[[CAST_DIM_2:.*]] = "mhlo.convert"(%[[DIM_2]]) : (tensor<i32>) -> tensor<i32>
|
||||
// CHECK: %[[MUL_2:.*]] = chlo.broadcast_multiply %[[MUL_1]], %[[CAST_DIM_2]]
|
||||
%size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor<2x?x8xf32>) -> tensor<i32>
|
||||
// CHECK: return %[[MUL_2]]
|
||||
return %size : tensor<i32>
|
||||
|
@ -2630,19 +2630,21 @@ class ConvertSizeOp : public OpRewritePattern<TF::SizeOp> {
|
||||
if (!input_ty) return failure();
|
||||
|
||||
const int64_t rank = input_ty.getRank();
|
||||
auto result_type = op.getResult().getType();
|
||||
Operation *size =
|
||||
GetScalarConstOfType(result_type.cast<TensorType>().getElementType(),
|
||||
op.getLoc(), 1, &rewriter);
|
||||
auto result_ty = op.getResult().getType();
|
||||
auto element_ty = result_ty.cast<TensorType>().getElementType();
|
||||
Value size = GetScalarConstOfType(element_ty, op.getLoc(), 1, &rewriter);
|
||||
for (int64_t i = 0; i < rank; ++i) {
|
||||
auto dim = rewriter.create<GetDimensionSizeOp>(
|
||||
op.getLoc(), result_type, input,
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(32), i));
|
||||
auto i32_ty = rewriter.getIntegerType(32);
|
||||
auto size_ty = RankedTensorType::get({}, i32_ty);
|
||||
auto dim_index = rewriter.getIntegerAttr(i32_ty, i);
|
||||
Value dim = rewriter.create<GetDimensionSizeOp>(op.getLoc(), size_ty,
|
||||
input, dim_index);
|
||||
dim = rewriter.create<mhlo::ConvertOp>(op.getLoc(), result_ty, dim);
|
||||
size = rewriter.create<chlo::BroadcastMulOp>(
|
||||
op.getLoc(), size->getResult(0), dim.getResult(),
|
||||
op.getLoc(), size, dim,
|
||||
/*DenseIntElementsAttr=*/DenseIntElementsAttr());
|
||||
}
|
||||
rewriter.replaceOp(op, size->getResult(0));
|
||||
rewriter.replaceOp(op, size);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user