[MLIR:TF] Lower xdivy, xlog1py and xlogy via tf2tf

These are just flavors of x == 0 ? 0 : op(x, y)

PiperOrigin-RevId: 346180716
Change-Id: Ib1a0dcb34157cfd8d2ee28240356d92aa62b2ac7
This commit is contained in:
Benjamin Kramer 2020-12-07 14:28:53 -08:00 committed by TensorFlower Gardener
parent c50af433c1
commit 19fa254274
3 changed files with 54 additions and 3 deletions

View File

@ -908,3 +908,41 @@ func @imag_resize_nearest_full_dyn(%arg0: tensor<1x?x?x1xi32>, %arg1: tensor<2xi
%resize = "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = false, device = "", half_pixel_centers = false} : (tensor<1x?x?x1xi32>, tensor<2xi32>) -> tensor<1x?x?x1xi32> %resize = "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = false, device = "", half_pixel_centers = false} : (tensor<1x?x?x1xi32>, tensor<2xi32>) -> tensor<1x?x?x1xi32>
return %resize: tensor<1x?x?x1xi32> return %resize: tensor<1x?x?x1xi32>
} }
// CHECK-LABEL: func @xdivy
// CHECK-SAME: (%[[X:.*]]: tensor<*xf32>, %[[Y:.*]]: tensor<*xf32>)
func @xdivy(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[IS_ZERO:.*]] = "tf.Equal"(%[[X]], %[[ZERO]]) {incompatible_shape_error = true} : (tensor<*xf32>, tensor<f32>) -> tensor<*xi1>
// CHECK: %[[MUL:.*]] = "tf.Div"(%[[X]], %[[Y]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: %[[RESULT:.*]] = "tf.SelectV2"(%[[IS_ZERO]], %[[ZERO]], %[[MUL]]) : (tensor<*xi1>, tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
%0 = "tf.Xdivy"(%lhs, %rhs) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: return %[[RESULT]]
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @xlog1py
// CHECK-SAME: (%[[X:.*]]: tensor<*xf32>, %[[Y:.*]]: tensor<*xf32>)
func @xlog1py(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[IS_ZERO:.*]] = "tf.Equal"(%[[X]], %[[ZERO]]) {incompatible_shape_error = true} : (tensor<*xf32>, tensor<f32>) -> tensor<*xi1>
// CHECK: %[[LOG:.*]] = "tf.Log1p"(%[[Y]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: %[[MUL:.*]] = "tf.Mul"(%[[X]], %[[LOG]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: %[[RESULT:.*]] = "tf.SelectV2"(%[[IS_ZERO]], %[[ZERO]], %[[MUL]]) : (tensor<*xi1>, tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
%0 = "tf.Xlog1py"(%lhs, %rhs) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: return %[[RESULT]]
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @xlogy
// CHECK-SAME: (%[[X:.*]]: tensor<*xf32>, %[[Y:.*]]: tensor<*xf32>)
func @xlogy(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[IS_ZERO:.*]] = "tf.Equal"(%[[X]], %[[ZERO]]) {incompatible_shape_error = true} : (tensor<*xf32>, tensor<f32>) -> tensor<*xi1>
// CHECK: %[[LOG:.*]] = "tf.Log"(%[[Y]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: %[[MUL:.*]] = "tf.Mul"(%[[X]], %[[LOG]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: %[[RESULT:.*]] = "tf.SelectV2"(%[[IS_ZERO]], %[[ZERO]], %[[MUL]]) : (tensor<*xi1>, tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
%0 = "tf.Xlogy"(%lhs, %rhs) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: return %[[RESULT]]
return %0 : tensor<*xf32>
}

View File

@ -298,3 +298,19 @@ def LowerScatterNdOp :
(TF_TensorScatterAddOp (TF_TensorScatterAddOp
(TF_FillOp $shape, (TF_ConstOp (GetScalarOfType<0> $updates))), (TF_FillOp $shape, (TF_ConstOp (GetScalarOfType<0> $updates))),
$indices, $updates)>; $indices, $updates)>;
//===----------------------------------------------------------------------===//
// Xdivy, Xlog1p and Xlogy op patterns.
//===----------------------------------------------------------------------===//
class BinaryXopyPat<dag From, dag To>
: Pat<From,
(TF_SelectV2Op (TF_EqualOp $x,
(TF_ConstOp:$zero (GetScalarOfType<0> $x)),
/*incompatible_shape_error*/ConstBoolAttrTrue),
$zero, To)>;
foreach fromToPair = [[(TF_XdivyOp $x, $y), (TF_DivOp $x, $y)],
[(TF_Xlog1pyOp $x, $y), (TF_MulOp $x, (TF_Log1pOp $y))],
[(TF_XlogyOp $x, $y), (TF_MulOp $x, (TF_LogOp $y))]] in
def : BinaryXopyPat<fromToPair[0], fromToPair[1]>;

View File

@ -255,7 +255,6 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::TruncateModOp>(), TypeID::get<TF::TruncateModOp>(),
TypeID::get<TF::UnpackOp>(), TypeID::get<TF::UnpackOp>(),
TypeID::get<TF::UpperBoundOp>(), TypeID::get<TF::UpperBoundOp>(),
TypeID::get<TF::XdivyOp>(),
TypeID::get<TF::XlaBroadcastHelperOp>(), TypeID::get<TF::XlaBroadcastHelperOp>(),
TypeID::get<TF::XlaConvOp>(), TypeID::get<TF::XlaConvOp>(),
TypeID::get<TF::XlaDotOp>(), TypeID::get<TF::XlaDotOp>(),
@ -265,8 +264,6 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::XlaKeyValueSortOp>(), TypeID::get<TF::XlaKeyValueSortOp>(),
TypeID::get<TF::XlaPadOp>(), TypeID::get<TF::XlaPadOp>(),
TypeID::get<TF::XlaSetDynamicDimensionSizeOp>(), TypeID::get<TF::XlaSetDynamicDimensionSizeOp>(),
TypeID::get<TF::Xlog1pyOp>(),
TypeID::get<TF::XlogyOp>(),
TypeID::get<TF::XlaSortOp>(), TypeID::get<TF::XlaSortOp>(),
TypeID::get<TF::XlaSvdOp>() TypeID::get<TF::XlaSvdOp>()
}; };