diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir index 87e9cd28d4c..fbcf2281fa6 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir @@ -1827,3 +1827,27 @@ func @convert_pad(%arg0: tensor<8x128xf32>, %arg1: tensor<f32>) -> tensor<11x131 return %0 : tensor<11x131xf32> } +// CHECK-LABEL: func @convert_round( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x128xbf16>) -> tensor<8x128xbf16> +// CHECK: %[[VAL_1:.*]] = "tf.Round"(%[[VAL_0]]) : (tensor<8x128xbf16>) -> tensor<8x128xbf16> +// CHECK: return %[[VAL_1]] +// CHECK: } +func @convert_round(%arg0: tensor<8x128xbf16>) -> tensor<8x128xbf16> { + %0 = mhlo.constant dense<2.000000e+00> : tensor<8x128xbf16> + %1 = mhlo.constant dense<5.000000e-01> : tensor<8x128xbf16> + %2 = mhlo.constant dense<1.000000e+00> : tensor<8x128xbf16> + %3 = "mhlo.floor"(%arg0) : (tensor<8x128xbf16>) -> tensor<8x128xbf16> + %4 = mhlo.subtract %arg0, %3 : tensor<8x128xbf16> + %5 = "mhlo.compare"(%4, %1) {comparison_direction = "GT"} : (tensor<8x128xbf16>, tensor<8x128xbf16>) -> tensor<8x128xi1> + %6 = "mhlo.compare"(%4, %1) {comparison_direction = "EQ"} : (tensor<8x128xbf16>, tensor<8x128xbf16>) -> tensor<8x128xi1> + %7 = mhlo.multiply %arg0, %1 : tensor<8x128xbf16> + %8 = "mhlo.floor"(%7) : (tensor<8x128xbf16>) -> tensor<8x128xbf16> + %9 = mhlo.multiply %8, %0 : tensor<8x128xbf16> + %10 = mhlo.subtract %3, %9 : tensor<8x128xbf16> + %11 = "mhlo.compare"(%10, %2) {comparison_direction = "EQ"} : (tensor<8x128xbf16>, tensor<8x128xbf16>) -> tensor<8x128xi1> + %12 = mhlo.and %6, %11 : tensor<8x128xi1> + %13 = mhlo.or %5, %12 : tensor<8x128xi1> + %14 = mhlo.add %3, %2 : tensor<8x128xbf16> + %15 = "mhlo.select"(%13, %14, %3) : (tensor<8x128xi1>, tensor<8x128xbf16>, tensor<8x128xbf16>) -> tensor<8x128xbf16> + return %15 : tensor<8x128xbf16> +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td index 8543e271fed..fe019ef957c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td @@ -225,3 +225,59 @@ def : Pat<(HLO_PadOp:$old_value $input, $pad_value, $pad_low, $pad_high, $pad_interior), (ConvertPadOp $old_value), [(IsZero $pad_interior)]>; + +class FloatValueEquals<string val> : Constraint<CPred< + "$0.isa<SplatElementsAttr>() && " + "$0.cast<SplatElementsAttr>().getSplatValue<APFloat>().isExactlyValue(" # val # ")">>; +def SameValue : Constraint<CPred<"$0 == $1">>; +def FloatOrDefaultCompare : Constraint<CPred< + "!$0 || $0.getValue() == \"FLOAT\"">>; + +// Converts a soup of HLOs representing banker rounding (round x.5 to nearest +// even) to tf.round. +// The pattern matched executes the following computation: +// frac = x - floor(x) +// to_even = (floor(x) - 2 * floor(0.5 * x)) == 1 +// if frac > 0.5 || (frac == 0.5 && to_even) +// return floor + 1 +// else +// return floor +def : Pat<(HLO_SelectOp + (HLO_OrOp + (HLO_CompareOp (HLO_SubOp:$frac + $input, + (HLO_FloorOp:$floor $input)), + (HLO_ConstOp $half), + HLO_COMPARISON_DIRECTION_GT, + $compare_type0), + (HLO_AndOp + (HLO_CompareOp + $frac1, + (HLO_ConstOp $half1), + HLO_COMPARISON_DIRECTION_EQ, + $compare_type1), + (HLO_CompareOp + (HLO_SubOp + $floor1, + (HLO_MulOp + (HLO_FloorOp (HLO_MulOp $input, (HLO_ConstOp $half2))), + (HLO_ConstOp $two))), + (HLO_ConstOp $one1), + HLO_COMPARISON_DIRECTION_EQ, + $compare_type2))), + (HLO_AddOp $floor2, (HLO_ConstOp $one)), + $floor3), + (TF_RoundOp $input), + [(FloatValueEquals<"1.0"> $one), + (FloatValueEquals<"1.0"> $one1), + (FloatValueEquals<"2.0"> $two), + (FloatValueEquals<"0.5"> $half), + (FloatValueEquals<"0.5"> $half1), + (FloatValueEquals<"0.5"> $half2), + (SameValue $floor, $floor1), + (SameValue $floor, $floor2), + (SameValue $floor, $floor3), + (SameValue $frac, $frac1), + (FloatOrDefaultCompare $compare_type0), + (FloatOrDefaultCompare $compare_type1), + (FloatOrDefaultCompare $compare_type2)]>;