Legalize soup of mhlo ops to tf.Round
np.round and tf.round uses banker rounding (round x.5 to even) what can only be represented as a soup of HLO ops. This change adds a new pattern to the mhlo->TF legaliser to pattern match this soup and rewrite it to tf.round. PiperOrigin-RevId: 345432796 Change-Id: Idadf8bfea9552d5624fb9bd4579bddc073375ad9
This commit is contained in:
parent
8d5103318f
commit
74ecc3ec25
@ -1827,3 +1827,27 @@ func @convert_pad(%arg0: tensor<8x128xf32>, %arg1: tensor<f32>) -> tensor<11x131
|
||||
return %0 : tensor<11x131xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @convert_round(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x128xbf16>) -> tensor<8x128xbf16>
|
||||
// CHECK: %[[VAL_1:.*]] = "tf.Round"(%[[VAL_0]]) : (tensor<8x128xbf16>) -> tensor<8x128xbf16>
|
||||
// CHECK: return %[[VAL_1]]
|
||||
// CHECK: }
|
||||
func @convert_round(%arg0: tensor<8x128xbf16>) -> tensor<8x128xbf16> {
|
||||
%0 = mhlo.constant dense<2.000000e+00> : tensor<8x128xbf16>
|
||||
%1 = mhlo.constant dense<5.000000e-01> : tensor<8x128xbf16>
|
||||
%2 = mhlo.constant dense<1.000000e+00> : tensor<8x128xbf16>
|
||||
%3 = "mhlo.floor"(%arg0) : (tensor<8x128xbf16>) -> tensor<8x128xbf16>
|
||||
%4 = mhlo.subtract %arg0, %3 : tensor<8x128xbf16>
|
||||
%5 = "mhlo.compare"(%4, %1) {comparison_direction = "GT"} : (tensor<8x128xbf16>, tensor<8x128xbf16>) -> tensor<8x128xi1>
|
||||
%6 = "mhlo.compare"(%4, %1) {comparison_direction = "EQ"} : (tensor<8x128xbf16>, tensor<8x128xbf16>) -> tensor<8x128xi1>
|
||||
%7 = mhlo.multiply %arg0, %1 : tensor<8x128xbf16>
|
||||
%8 = "mhlo.floor"(%7) : (tensor<8x128xbf16>) -> tensor<8x128xbf16>
|
||||
%9 = mhlo.multiply %8, %0 : tensor<8x128xbf16>
|
||||
%10 = mhlo.subtract %3, %9 : tensor<8x128xbf16>
|
||||
%11 = "mhlo.compare"(%10, %2) {comparison_direction = "EQ"} : (tensor<8x128xbf16>, tensor<8x128xbf16>) -> tensor<8x128xi1>
|
||||
%12 = mhlo.and %6, %11 : tensor<8x128xi1>
|
||||
%13 = mhlo.or %5, %12 : tensor<8x128xi1>
|
||||
%14 = mhlo.add %3, %2 : tensor<8x128xbf16>
|
||||
%15 = "mhlo.select"(%13, %14, %3) : (tensor<8x128xi1>, tensor<8x128xbf16>, tensor<8x128xbf16>) -> tensor<8x128xbf16>
|
||||
return %15 : tensor<8x128xbf16>
|
||||
}
|
||||
|
@ -225,3 +225,59 @@ def : Pat<(HLO_PadOp:$old_value $input, $pad_value, $pad_low, $pad_high,
|
||||
$pad_interior),
|
||||
(ConvertPadOp $old_value),
|
||||
[(IsZero $pad_interior)]>;
|
||||
|
||||
class FloatValueEquals<string val> : Constraint<CPred<
|
||||
"$0.isa<SplatElementsAttr>() && "
|
||||
"$0.cast<SplatElementsAttr>().getSplatValue<APFloat>().isExactlyValue(" # val # ")">>;
|
||||
def SameValue : Constraint<CPred<"$0 == $1">>;
|
||||
def FloatOrDefaultCompare : Constraint<CPred<
|
||||
"!$0 || $0.getValue() == \"FLOAT\"">>;
|
||||
|
||||
// Converts a soup of HLOs representing banker rounding (round x.5 to nearest
|
||||
// even) to tf.round.
|
||||
// The pattern matched executes the following computation:
|
||||
// frac = x - floor(x)
|
||||
// to_even = (floor(x) - 2 * floor(0.5 * x)) == 1
|
||||
// if frac > 0.5 || (frac == 0.5 && to_even)
|
||||
// return floor + 1
|
||||
// else
|
||||
// return floor
|
||||
def : Pat<(HLO_SelectOp
|
||||
(HLO_OrOp
|
||||
(HLO_CompareOp (HLO_SubOp:$frac
|
||||
$input,
|
||||
(HLO_FloorOp:$floor $input)),
|
||||
(HLO_ConstOp $half),
|
||||
HLO_COMPARISON_DIRECTION_GT,
|
||||
$compare_type0),
|
||||
(HLO_AndOp
|
||||
(HLO_CompareOp
|
||||
$frac1,
|
||||
(HLO_ConstOp $half1),
|
||||
HLO_COMPARISON_DIRECTION_EQ,
|
||||
$compare_type1),
|
||||
(HLO_CompareOp
|
||||
(HLO_SubOp
|
||||
$floor1,
|
||||
(HLO_MulOp
|
||||
(HLO_FloorOp (HLO_MulOp $input, (HLO_ConstOp $half2))),
|
||||
(HLO_ConstOp $two))),
|
||||
(HLO_ConstOp $one1),
|
||||
HLO_COMPARISON_DIRECTION_EQ,
|
||||
$compare_type2))),
|
||||
(HLO_AddOp $floor2, (HLO_ConstOp $one)),
|
||||
$floor3),
|
||||
(TF_RoundOp $input),
|
||||
[(FloatValueEquals<"1.0"> $one),
|
||||
(FloatValueEquals<"1.0"> $one1),
|
||||
(FloatValueEquals<"2.0"> $two),
|
||||
(FloatValueEquals<"0.5"> $half),
|
||||
(FloatValueEquals<"0.5"> $half1),
|
||||
(FloatValueEquals<"0.5"> $half2),
|
||||
(SameValue $floor, $floor1),
|
||||
(SameValue $floor, $floor2),
|
||||
(SameValue $floor, $floor3),
|
||||
(SameValue $frac, $frac1),
|
||||
(FloatOrDefaultCompare $compare_type0),
|
||||
(FloatOrDefaultCompare $compare_type1),
|
||||
(FloatOrDefaultCompare $compare_type2)]>;
|
||||
|
Loading…
Reference in New Issue
Block a user