Legalization to CHLO for unranked binary ops

The unspecified broadcast attribute is left padded, just generalize the previous rule to match this case too. For use by KernelGen.

PiperOrigin-RevId: 332325673
Change-Id: I7b6b72e73412d69cfbe5aa53d8a6691e8befbad2
This commit is contained in:
Jacques Pienaar 2020-09-17 15:09:20 -07:00 committed by TensorFlower Gardener
parent 97d2a8469b
commit 5e007c2770
2 changed files with 11 additions and 1 deletions

View File

@ -2,6 +2,7 @@
// (unlike the rest), since this is the primary use case for such ops and
// verification of shapes and broadcasts is desired.
// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=true" -canonicalize %s | FileCheck %s
// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=false" %s | FileCheck --check-prefix CHLO %s
//===----------------------------------------------------------------------===//
// Binary op legalizations.
@ -58,6 +59,15 @@ func @add_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi3
return %0: tensor<?x?xi32>
}
// CHECK-LABEL: func @broadcast_add_unranked
// CHLO-LABEL: func @broadcast_add_unranked
func @broadcast_add_unranked(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
// CHECK: tf.Add
// CHLO: chlo.broadcast_add %arg0, %arg1
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<*xi32>) -> tensor<*xi32>
return %0: tensor<*xi32>
}
// CHECK-LABEL: func @div
func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// CHECK-NEXT: %0 = mhlo.divide %arg0, %arg0 : tensor<2xi32>

View File

@ -86,7 +86,7 @@ def AreBroadcastCompatible : Constraint<CPred<"AreBroadcastCompatible($0, $1)">,
"types must be broadcastable">;
class DirectBinaryPat<Op FromOp, Op ToOp>
: Pat<(FromOp AnyRankedTensor:$l, AnyRankedTensor:$r),
: Pat<(FromOp AnyTensor:$l, AnyTensor:$r),
(ToOp $l, $r, (BinBroadcastDimensions $l, $r))>;
foreach fromToBinPair = [[TF_AddOp, HLOClient_BroadcastAddOp],