Allow lowering Unranked Binary CWise TF->HLO

PiperOrigin-RevId: 346426752
Change-Id: Ie75fdf35c67ddfa5fe120d2429e850d27601ac6e
This commit is contained in:
Tres Popp 2020-12-08 15:44:34 -08:00 committed by TensorFlower Gardener
parent 17f251ee93
commit 98a9cbb5f3
2 changed files with 11 additions and 4 deletions

View File

@ -835,7 +835,14 @@ func @floordiv_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?
}
// CHECK-LABEL: func @floordiv_unranked
func @floordiv_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
func @floordiv_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NOT: tf.FloorDiv
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0: tensor<*xf32>
}
// CHECK-LABEL: func @floordiv_int
func @floordiv_int(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
// CHECK: tf.FloorDiv
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
return %0: tensor<*xi32>
@ -894,7 +901,7 @@ func @floormod_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?
// CHECK-LABEL: func @floormod_unranked
func @floormod_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
// CHECK: tf.FloorMod
// CHECK-NOT: tf.FloorMod
%0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
return %0: tensor<*xi32>
}

View File

@ -114,7 +114,7 @@ def : Pat<(TF_ComplexOp $r, $i), (HLO_ComplexOp $r, $i)>;
// Performs a substitution of FloorDiv, pseudo code below:
//
// return floor(div(x, y))
def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r),
def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r),
(HLO_FloorOp
(HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r))),
[(IEEEFloatTensor $l)]>;
@ -166,7 +166,7 @@ def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r),
// return trunc_mod != 0 && (y < 0 != trunc_mod < 0) ? trunc_mod + y
// Requires static shaped inputs to create constant splats and computation of
// broadcast attributes.
def : Pat<(TF_FloorModOp AnyRankedTensor:$l, AnyRankedTensor:$r),
def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r),
(HLO_SelectOp
(HLOClient_BroadcastAndOp
(HLOClient_BroadcastCompareOp