[MLIR:TF] Lower xdivy, xlog1py and xlogy via tf2tf
These are just flavors of x == 0 ? 0 : op(x, y) PiperOrigin-RevId: 346180716 Change-Id: Ib1a0dcb34157cfd8d2ee28240356d92aa62b2ac7
This commit is contained in:
parent
c50af433c1
commit
19fa254274
@ -908,3 +908,41 @@ func @imag_resize_nearest_full_dyn(%arg0: tensor<1x?x?x1xi32>, %arg1: tensor<2xi
|
|||||||
%resize = "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = false, device = "", half_pixel_centers = false} : (tensor<1x?x?x1xi32>, tensor<2xi32>) -> tensor<1x?x?x1xi32>
|
%resize = "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = false, device = "", half_pixel_centers = false} : (tensor<1x?x?x1xi32>, tensor<2xi32>) -> tensor<1x?x?x1xi32>
|
||||||
return %resize: tensor<1x?x?x1xi32>
|
return %resize: tensor<1x?x?x1xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @xdivy
|
||||||
|
// CHECK-SAME: (%[[X:.*]]: tensor<*xf32>, %[[Y:.*]]: tensor<*xf32>)
|
||||||
|
func @xdivy(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<*xf32> {
|
||||||
|
// CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||||
|
// CHECK: %[[IS_ZERO:.*]] = "tf.Equal"(%[[X]], %[[ZERO]]) {incompatible_shape_error = true} : (tensor<*xf32>, tensor<f32>) -> tensor<*xi1>
|
||||||
|
// CHECK: %[[MUL:.*]] = "tf.Div"(%[[X]], %[[Y]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK: %[[RESULT:.*]] = "tf.SelectV2"(%[[IS_ZERO]], %[[ZERO]], %[[MUL]]) : (tensor<*xi1>, tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
%0 = "tf.Xdivy"(%lhs, %rhs) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK: return %[[RESULT]]
|
||||||
|
return %0 : tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @xlog1py
|
||||||
|
// CHECK-SAME: (%[[X:.*]]: tensor<*xf32>, %[[Y:.*]]: tensor<*xf32>)
|
||||||
|
func @xlog1py(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<*xf32> {
|
||||||
|
// CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||||
|
// CHECK: %[[IS_ZERO:.*]] = "tf.Equal"(%[[X]], %[[ZERO]]) {incompatible_shape_error = true} : (tensor<*xf32>, tensor<f32>) -> tensor<*xi1>
|
||||||
|
// CHECK: %[[LOG:.*]] = "tf.Log1p"(%[[Y]]) : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK: %[[MUL:.*]] = "tf.Mul"(%[[X]], %[[LOG]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK: %[[RESULT:.*]] = "tf.SelectV2"(%[[IS_ZERO]], %[[ZERO]], %[[MUL]]) : (tensor<*xi1>, tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
%0 = "tf.Xlog1py"(%lhs, %rhs) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK: return %[[RESULT]]
|
||||||
|
return %0 : tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @xlogy
|
||||||
|
// CHECK-SAME: (%[[X:.*]]: tensor<*xf32>, %[[Y:.*]]: tensor<*xf32>)
|
||||||
|
func @xlogy(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<*xf32> {
|
||||||
|
// CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||||
|
// CHECK: %[[IS_ZERO:.*]] = "tf.Equal"(%[[X]], %[[ZERO]]) {incompatible_shape_error = true} : (tensor<*xf32>, tensor<f32>) -> tensor<*xi1>
|
||||||
|
// CHECK: %[[LOG:.*]] = "tf.Log"(%[[Y]]) : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK: %[[MUL:.*]] = "tf.Mul"(%[[X]], %[[LOG]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK: %[[RESULT:.*]] = "tf.SelectV2"(%[[IS_ZERO]], %[[ZERO]], %[[MUL]]) : (tensor<*xi1>, tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
%0 = "tf.Xlogy"(%lhs, %rhs) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK: return %[[RESULT]]
|
||||||
|
return %0 : tensor<*xf32>
|
||||||
|
}
|
||||||
|
@ -298,3 +298,19 @@ def LowerScatterNdOp :
|
|||||||
(TF_TensorScatterAddOp
|
(TF_TensorScatterAddOp
|
||||||
(TF_FillOp $shape, (TF_ConstOp (GetScalarOfType<0> $updates))),
|
(TF_FillOp $shape, (TF_ConstOp (GetScalarOfType<0> $updates))),
|
||||||
$indices, $updates)>;
|
$indices, $updates)>;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Xdivy, Xlog1p and Xlogy op patterns.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
class BinaryXopyPat<dag From, dag To>
|
||||||
|
: Pat<From,
|
||||||
|
(TF_SelectV2Op (TF_EqualOp $x,
|
||||||
|
(TF_ConstOp:$zero (GetScalarOfType<0> $x)),
|
||||||
|
/*incompatible_shape_error*/ConstBoolAttrTrue),
|
||||||
|
$zero, To)>;
|
||||||
|
|
||||||
|
foreach fromToPair = [[(TF_XdivyOp $x, $y), (TF_DivOp $x, $y)],
|
||||||
|
[(TF_Xlog1pyOp $x, $y), (TF_MulOp $x, (TF_Log1pOp $y))],
|
||||||
|
[(TF_XlogyOp $x, $y), (TF_MulOp $x, (TF_LogOp $y))]] in
|
||||||
|
def : BinaryXopyPat<fromToPair[0], fromToPair[1]>;
|
||||||
|
@ -255,7 +255,6 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
|
|||||||
TypeID::get<TF::TruncateModOp>(),
|
TypeID::get<TF::TruncateModOp>(),
|
||||||
TypeID::get<TF::UnpackOp>(),
|
TypeID::get<TF::UnpackOp>(),
|
||||||
TypeID::get<TF::UpperBoundOp>(),
|
TypeID::get<TF::UpperBoundOp>(),
|
||||||
TypeID::get<TF::XdivyOp>(),
|
|
||||||
TypeID::get<TF::XlaBroadcastHelperOp>(),
|
TypeID::get<TF::XlaBroadcastHelperOp>(),
|
||||||
TypeID::get<TF::XlaConvOp>(),
|
TypeID::get<TF::XlaConvOp>(),
|
||||||
TypeID::get<TF::XlaDotOp>(),
|
TypeID::get<TF::XlaDotOp>(),
|
||||||
@ -265,8 +264,6 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
|
|||||||
TypeID::get<TF::XlaKeyValueSortOp>(),
|
TypeID::get<TF::XlaKeyValueSortOp>(),
|
||||||
TypeID::get<TF::XlaPadOp>(),
|
TypeID::get<TF::XlaPadOp>(),
|
||||||
TypeID::get<TF::XlaSetDynamicDimensionSizeOp>(),
|
TypeID::get<TF::XlaSetDynamicDimensionSizeOp>(),
|
||||||
TypeID::get<TF::Xlog1pyOp>(),
|
|
||||||
TypeID::get<TF::XlogyOp>(),
|
|
||||||
TypeID::get<TF::XlaSortOp>(),
|
TypeID::get<TF::XlaSortOp>(),
|
||||||
TypeID::get<TF::XlaSvdOp>()
|
TypeID::get<TF::XlaSvdOp>()
|
||||||
};
|
};
|
||||||
|
Loading…
x
Reference in New Issue
Block a user