From 5e007c2770d3a1630f8837aed829238538302d51 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Thu, 17 Sep 2020 15:09:20 -0700 Subject: [PATCH] Legalization to CHLO for unranked binary ops The unspecified broadcast attribute is left padded, just generalize the previous rule to match this case too. For use by KernelGen. PiperOrigin-RevId: 332325673 Change-Id: I7b6b72e73412d69cfbe5aa53d8a6691e8befbad2 --- .../mlir/xla/tests/legalize-tf-binary-elementwise.mlir | 10 ++++++++++ .../mlir/xla/transforms/legalize_tf_patterns.td | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir index 1a6e0e1229f..887fdea5a21 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir @@ -2,6 +2,7 @@ // (unlike the rest), since this is the primary use case for such ops and // verification of shapes and broadcasts is desired. // RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=true" -canonicalize %s | FileCheck %s +// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=false" %s | FileCheck --check-prefix CHLO %s //===----------------------------------------------------------------------===// // Binary op legalizations. @@ -58,6 +59,15 @@ func @add_dynamic(%arg0: tensor, %arg1: tensor) -> tensor } +// CHECK-LABEL: func @broadcast_add_unranked +// CHLO-LABEL: func @broadcast_add_unranked +func @broadcast_add_unranked(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> { + // CHECK: tf.Add + // CHLO: chlo.broadcast_add %arg0, %arg1 + %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<*xi32>) -> tensor<*xi32> + return %0: tensor<*xi32> +} + // CHECK-LABEL: func @div func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { // CHECK-NEXT: %0 = mhlo.divide %arg0, %arg0 : tensor<2xi32> diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index 7b4e7aea5ec..b1460421f16 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -86,7 +86,7 @@ def AreBroadcastCompatible : Constraint, "types must be broadcastable">; class DirectBinaryPat - : Pat<(FromOp AnyRankedTensor:$l, AnyRankedTensor:$r), + : Pat<(FromOp AnyTensor:$l, AnyTensor:$r), (ToOp $l, $r, (BinBroadcastDimensions $l, $r))>; foreach fromToBinPair = [[TF_AddOp, HLOClient_BroadcastAddOp],