From a70db4d1d13f454529521371b464a7766c4e72dc Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Wed, 8 Apr 2020 17:54:45 -0700 Subject: [PATCH] [hlo2tf] Add complex -> complex PiperOrigin-RevId: 305591936 Change-Id: I1c5e8ded4ab8257003e6accd6c62a105ca5452ae --- .../compiler/mlir/tensorflow/tests/legalize_hlo.mlir | 12 ++++++++++++ .../tensorflow/transforms/legalize_hlo_patterns.td | 1 + 2 files changed, 13 insertions(+) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir index 4b38465257d..dca015f87ab 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir @@ -677,6 +677,11 @@ func @size_rank_one_i64(%arg0: tensor) -> tensor { return %0 : tensor } +func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex> { + %0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex> + return %0 : tensor<3xcomplex> +} + // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py // CHECK-LABEL: func @biasAdd_NHWC( @@ -1481,3 +1486,10 @@ func @size_rank_one_i64(%arg0: tensor) -> tensor { // CHECK: [[VAL_366:%.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor // CHECK: return [[VAL_366]] : tensor // CHECK: } + +// CHECK-LABEL: func @complex( +// CHECK-SAME: [[VAL_367:%.*]]: tensor<3xf32>, [[VAL_368:%.*]]: tensor<3xf32>) -> tensor<3xcomplex> { +// CHECK: [[VAL_369:%.*]] = "tf.Complex"([[VAL_367]], [[VAL_368]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex> +// CHECK: return [[VAL_369]] : tensor<3xcomplex> +// CHECK: } + diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td index 8a71005bf70..853fd806c5f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td @@ -64,6 +64,7 @@ def : Pat<(HLO_ShiftRightLogicalOp $l, $r, $_), (TF_RightShiftOp $l, $r), def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r, $_)), (TF_FloorDivOp $l, $r), [(AreBroadcastCompatible $l, $r)]>; +def : Pat<(HLO_ComplexOp $r, $i), (TF_ComplexOp $r, $i)>; //===----------------------------------------------------------------------===// // Unary op patterns. //===----------------------------------------------------------------------===//