Enable lowering xla_hlo.convert to linalg.generic op.

Also sort the patterns alphabetically.

PiperOrigin-RevId: 309312612
Change-Id: Ic302751f0ab998a4a04ceee55579776a753c0c9f
This commit is contained in:
Hanhan Wang 2020-04-30 15:34:23 -07:00 committed by TensorFlower Gardener
parent c498e1a0a6
commit 306f371dbc
2 changed files with 62 additions and 1 deletions

View File

@ -444,3 +444,63 @@ func @reshape_multiple_collapse
// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
// CHECK-LABEL: func @reshape_multiple_collapse
// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
// -----
// CHECK-LABEL: func @convert_i32_to_f32
func @convert_i32_to_f32(%input: tensor<2x2xi32>) -> tensor<2x2xf32> {
%result = "xla_hlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32>
return %result : tensor<2x2xf32>
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32):
// CHECK-NEXT: %[[RESULT:.*]] = sitofp %[[OPERAND_IN]] : i32 to f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
// CHECK-LABEL: func @convert_i16_to_i32
func @convert_i16_to_i32(%input: tensor<2x2xi16>) -> tensor<2x2xi32> {
%result = "xla_hlo.convert"(%input) : (tensor<2x2xi16>) -> tensor<2x2xi32>
return %result : tensor<2x2xi32>
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16):
// CHECK-NEXT: %[[RESULT:.*]] = zexti %[[OPERAND_IN]] : i16 to i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
// -----
// CHECK-LABEL: func @convert_i32_to_i16
func @convert_i32_to_i16(%input: tensor<2x2xi32>) -> tensor<2x2xi16> {
%result = "xla_hlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xi16>
return %result : tensor<2x2xi16>
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32):
// CHECK-NEXT: %[[RESULT:.*]] = trunci %[[OPERAND_IN]] : i32 to i16
// CHECK-NEXT: linalg.yield %[[RESULT]] : i16
// -----
// CHECK-LABEL: func @convert_f32_to_f64
func @convert_f32_to_f64(%input: tensor<2x2xf32>) -> tensor<2x2xf64> {
%result = "xla_hlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xf64>
return %result : tensor<2x2xf64>
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32):
// CHECK-NEXT: %[[RESULT:.*]] = fpext %[[OPERAND_IN]] : f32 to f64
// CHECK-NEXT: linalg.yield %[[RESULT]] : f64
// -----
// CHECK-LABEL: func @convert_f64_to_f32
func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> {
%result = "xla_hlo.convert"(%input) : (tensor<2x2xf64>) -> tensor<2x2xf32>
return %result : tensor<2x2xf32>
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f64):
// CHECK-NEXT: %[[RESULT:.*]] = fptrunc %[[OPERAND_IN]] : f64 to f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32

View File

@ -717,8 +717,9 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<xla_hlo::AndOp, false>,
PointwiseToLinalgConverter<xla_hlo::CeilOp, false>,
PointwiseToLinalgConverter<xla_hlo::CompareOp, false>,
PointwiseToLinalgConverter<xla_hlo::CosOp, false>,
PointwiseToLinalgConverter<xla_hlo::ConvertOp, false>,
PointwiseToLinalgConverter<xla_hlo::CopyOp, false>,
PointwiseToLinalgConverter<xla_hlo::CosOp, false>,
PointwiseToLinalgConverter<xla_hlo::DivOp, false>,
PointwiseToLinalgConverter<xla_hlo::ExpOp, false>,
PointwiseToLinalgConverter<xla_hlo::LogOp, false>,