Allow lowering Unranked Binary CWise TF->HLO
PiperOrigin-RevId: 346426752 Change-Id: Ie75fdf35c67ddfa5fe120d2429e850d27601ac6e
This commit is contained in:
parent
17f251ee93
commit
98a9cbb5f3
@ -835,7 +835,14 @@ func @floordiv_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @floordiv_unranked
|
||||
func @floordiv_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
|
||||
func @floordiv_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK-NOT: tf.FloorDiv
|
||||
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0: tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @floordiv_int
|
||||
func @floordiv_int(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
|
||||
// CHECK: tf.FloorDiv
|
||||
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
|
||||
return %0: tensor<*xi32>
|
||||
@ -894,7 +901,7 @@ func @floormod_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?
|
||||
|
||||
// CHECK-LABEL: func @floormod_unranked
|
||||
func @floormod_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
|
||||
// CHECK: tf.FloorMod
|
||||
// CHECK-NOT: tf.FloorMod
|
||||
%0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
|
||||
return %0: tensor<*xi32>
|
||||
}
|
||||
|
@ -114,7 +114,7 @@ def : Pat<(TF_ComplexOp $r, $i), (HLO_ComplexOp $r, $i)>;
|
||||
// Performs a substitution of FloorDiv, pseudo code below:
|
||||
//
|
||||
// return floor(div(x, y))
|
||||
def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r),
|
||||
def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r),
|
||||
(HLO_FloorOp
|
||||
(HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r))),
|
||||
[(IEEEFloatTensor $l)]>;
|
||||
@ -166,7 +166,7 @@ def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r),
|
||||
// return trunc_mod != 0 && (y < 0 != trunc_mod < 0) ? trunc_mod + y
|
||||
// Requires static shaped inputs to create constant splats and computation of
|
||||
// broadcast attributes.
|
||||
def : Pat<(TF_FloorModOp AnyRankedTensor:$l, AnyRankedTensor:$r),
|
||||
def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r),
|
||||
(HLO_SelectOp
|
||||
(HLOClient_BroadcastAndOp
|
||||
(HLOClient_BroadcastCompareOp
|
||||
|
Loading…
Reference in New Issue
Block a user