Restrict GetDimensionSize HLO op result type to 32 bit integer

The XLA implementation has this limitation and always uses a 32-bit result for this instruction. This will cause a mismatch between the result type in MLIR and in XLA at export time.

This should be resolved once we have a special dialect mapping directly to HLOInstructionProto. Another option until then could be to introduce a pass to legalize mhlo itself to match XLA semantics.

PiperOrigin-RevId: 324286936
Change-Id: Ice7893f9920bbbc96936b90c8063248b1627e3e9
This commit is contained in:
Smit Hinsu 2020-07-31 14:37:20 -07:00 committed by TensorFlower Gardener
parent 934b4b6a35
commit 8a2c608cf7
3 changed files with 43 additions and 17 deletions

View File

@@ -1075,7 +1075,10 @@ def HLO_GetDimensionSizeOp: HLO_Op<"get_dimension_size", [NoSideEffect]>,
HLO_Tensor:$operand,
I32Attr:$dimension
);
let results = (outs HLO_IntTensor);
// TODO(hinsu): Allow 64-bit result types once XLA HLO dialect based on the
// XLA semantics is available. This limitation is because of the current XLA
// implementation.
let results = (outs I32Tensor);
}
def HLO_MapOp: HLO_Op<"map",

View File

@@ -3482,8 +3482,8 @@ func @cross_replica_sum(%input: tensor<10xf32>) -> tensor<10xf32> {
// tf.Size legalization
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @size_rank_one_i32
func @size_rank_one_i32(%input: tensor<f32>) -> (tensor<i32>) {
// CHECK-LABEL: @size_scalar_i32
func @size_scalar_i32(%input: tensor<f32>) -> (tensor<i32>) {
// CHECK: %[[CONST:.*]] = mhlo.constant dense<1>
// CHECK-SAME: tensor<i32>
%size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor<f32>) -> tensor<i32>
@@ -3491,8 +3491,8 @@ func @size_rank_one_i32(%input: tensor<f32>) -> (tensor<i32>) {
return %size : tensor<i32>
}
// CHECK-LABEL: @size_rank_one_i64
func @size_rank_one_i64(%input: tensor<f32>) -> (tensor<i64>) {
// CHECK-LABEL: @size_scalar_i64
func @size_scalar_i64(%input: tensor<f32>) -> (tensor<i64>) {
// CHECK: %[[CONST:.*]] = mhlo.constant dense<1>
// CHECK-SAME: tensor<i64>
%size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT64"} : (tensor<f32>) -> tensor<i64>
@@ -3500,19 +3500,40 @@ func @size_rank_one_i64(%input: tensor<f32>) -> (tensor<i64>) {
return %size : tensor<i64>
}
// CHECK-LABEL: @size_rank_one_i64
// CHECK-SAME: (%[[INPUT:.*]]: tensor<?xf32>)
func @size_rank_one_i64(%input: tensor<?xf32>) -> (tensor<i64>) {
// CHECK: %[[INIT:.*]] = mhlo.constant dense<1>
// CHECK-SAME: tensor<i64>
// CHECK: %[[DIM_0:.*]] = "mhlo.get_dimension_size"(%[[INPUT]])
// CHECK-SAME: dimension = 0
// CHECK-SAME: tensor<i32>
// CHECK: %[[CAST_DIM_0:.*]] = "mhlo.convert"(%[[DIM_0]]) : (tensor<i32>) -> tensor<i64>
// CHECK: %[[RESULT:.*]] = chlo.broadcast_multiply %[[INIT]], %[[CAST_DIM_0]]
%size = "tf.Size"(%input) : (tensor<?xf32>) -> tensor<i64>
// CHECK: return %[[RESULT]]
return %size : tensor<i64>
}
// CHECK-LABEL: @size_ranked
// CHECK-SAME: (%[[INPUT:.*]]: tensor<2x?x8xf32>)
func @size_ranked(%input: tensor<2x?x8xf32>) -> (tensor<i32>) {
// CHECK: %[[CONST:.*]] = mhlo.constant dense<1>
// CHECK: %[[DIM_0:.*]] = "mhlo.get_dimension_size"(%[[INPUT]])
// CHECK-SAME: dimension = 0
// CHECK: %[[MUL_0:.*]] = chlo.broadcast_multiply %[[CONST]], %[[DIM_0]]
// CHECK: %[[CAST_DIM_0:.*]] = "mhlo.convert"(%[[DIM_0]]) : (tensor<i32>) -> tensor<i32>
// CHECK: %[[MUL_0:.*]] = chlo.broadcast_multiply %[[CONST]], %[[CAST_DIM_0]]
// CHECK: %[[DIM_1:.*]] = "mhlo.get_dimension_size"(%[[INPUT]])
// CHECK-SAME: dimension = 1
// CHECK: %[[MUL_1:.*]] = chlo.broadcast_multiply %[[MUL_0]], %[[DIM_1]]
// CHECK: %[[CAST_DIM_1:.*]] = "mhlo.convert"(%[[DIM_1]]) : (tensor<i32>) -> tensor<i32>
// CHECK: %[[MUL_1:.*]] = chlo.broadcast_multiply %[[MUL_0]], %[[CAST_DIM_1]]
// CHECK: %[[DIM_2:.*]] = "mhlo.get_dimension_size"(%[[INPUT]])
// CHECK-SAME: dimension = 2
// CHECK: %[[MUL_2:.*]] = chlo.broadcast_multiply %[[MUL_1]], %[[DIM_2]]
// CHECK: %[[CAST_DIM_2:.*]] = "mhlo.convert"(%[[DIM_2]]) : (tensor<i32>) -> tensor<i32>
// CHECK: %[[MUL_2:.*]] = chlo.broadcast_multiply %[[MUL_1]], %[[CAST_DIM_2]]
%size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor<2x?x8xf32>) -> tensor<i32>
// CHECK: return %[[MUL_2]]
return %size : tensor<i32>

View File

@@ -2630,19 +2630,21 @@ class ConvertSizeOp : public OpRewritePattern<TF::SizeOp> {
if (!input_ty) return failure();
const int64_t rank = input_ty.getRank();
auto result_type = op.getResult().getType();
Operation *size =
GetScalarConstOfType(result_type.cast<TensorType>().getElementType(),
op.getLoc(), 1, &rewriter);
auto result_ty = op.getResult().getType();
auto element_ty = result_ty.cast<TensorType>().getElementType();
Value size = GetScalarConstOfType(element_ty, op.getLoc(), 1, &rewriter);
for (int64_t i = 0; i < rank; ++i) {
auto dim = rewriter.create<GetDimensionSizeOp>(
op.getLoc(), result_type, input,
rewriter.getIntegerAttr(rewriter.getIntegerType(32), i));
auto i32_ty = rewriter.getIntegerType(32);
auto size_ty = RankedTensorType::get({}, i32_ty);
auto dim_index = rewriter.getIntegerAttr(i32_ty, i);
Value dim = rewriter.create<GetDimensionSizeOp>(op.getLoc(), size_ty,
input, dim_index);
dim = rewriter.create<mhlo::ConvertOp>(op.getLoc(), result_ty, dim);
size = rewriter.create<chlo::BroadcastMulOp>(
op.getLoc(), size->getResult(0), dim.getResult(),
op.getLoc(), size, dim,
/*DenseIntElementsAttr=*/DenseIntElementsAttr());
}
rewriter.replaceOp(op, size->getResult(0));
rewriter.replaceOp(op, size);
return success();
}