Allow lowering Unranked Binary CWise TF->HLO
PiperOrigin-RevId: 346426752 Change-Id: Ie75fdf35c67ddfa5fe120d2429e850d27601ac6e
This commit is contained in:
parent
17f251ee93
commit
98a9cbb5f3
@ -835,7 +835,14 @@ func @floordiv_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @floordiv_unranked
|
// CHECK-LABEL: func @floordiv_unranked
|
||||||
func @floordiv_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
|
func @floordiv_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||||
|
// CHECK-NOT: tf.FloorDiv
|
||||||
|
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
return %0: tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @floordiv_int
|
||||||
|
func @floordiv_int(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
|
||||||
// CHECK: tf.FloorDiv
|
// CHECK: tf.FloorDiv
|
||||||
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
|
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
|
||||||
return %0: tensor<*xi32>
|
return %0: tensor<*xi32>
|
||||||
@ -894,7 +901,7 @@ func @floormod_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?
|
|||||||
|
|
||||||
// CHECK-LABEL: func @floormod_unranked
|
// CHECK-LABEL: func @floormod_unranked
|
||||||
func @floormod_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
|
func @floormod_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
|
||||||
// CHECK: tf.FloorMod
|
// CHECK-NOT: tf.FloorMod
|
||||||
%0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
|
%0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
|
||||||
return %0: tensor<*xi32>
|
return %0: tensor<*xi32>
|
||||||
}
|
}
|
||||||
|
|||||||
@ -114,7 +114,7 @@ def : Pat<(TF_ComplexOp $r, $i), (HLO_ComplexOp $r, $i)>;
|
|||||||
// Performs a substitution of FloorDiv, pseudo code below:
|
// Performs a substitution of FloorDiv, pseudo code below:
|
||||||
//
|
//
|
||||||
// return floor(div(x, y))
|
// return floor(div(x, y))
|
||||||
def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r),
|
def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r),
|
||||||
(HLO_FloorOp
|
(HLO_FloorOp
|
||||||
(HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r))),
|
(HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r))),
|
||||||
[(IEEEFloatTensor $l)]>;
|
[(IEEEFloatTensor $l)]>;
|
||||||
@ -166,7 +166,7 @@ def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r),
|
|||||||
// return trunc_mod != 0 && (y < 0 != trunc_mod < 0) ? trunc_mod + y
|
// return trunc_mod != 0 && (y < 0 != trunc_mod < 0) ? trunc_mod + y
|
||||||
// Requires static shaped inputs to create constant splats and computation of
|
// Requires static shaped inputs to create constant splats and computation of
|
||||||
// broadcast attributes.
|
// broadcast attributes.
|
||||||
def : Pat<(TF_FloorModOp AnyRankedTensor:$l, AnyRankedTensor:$r),
|
def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r),
|
||||||
(HLO_SelectOp
|
(HLO_SelectOp
|
||||||
(HLOClient_BroadcastAndOp
|
(HLOClient_BroadcastAndOp
|
||||||
(HLOClient_BroadcastCompareOp
|
(HLOClient_BroadcastCompareOp
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user