[hlo2tf] Add complex -> complex
PiperOrigin-RevId: 305591936 Change-Id: I1c5e8ded4ab8257003e6accd6c62a105ca5452ae
This commit is contained in:
parent
72600af909
commit
a70db4d1d1
|
@ -677,6 +677,11 @@ func @size_rank_one_i64(%arg0: tensor<f32>) -> tensor<i64> {
|
|||
return %0 : tensor<i64>
|
||||
}
|
||||
|
||||
func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex<f32>> {
|
||||
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex<f32>>
|
||||
return %0 : tensor<3xcomplex<f32>>
|
||||
}
|
||||
|
||||
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
|
||||
|
||||
// CHECK-LABEL: func @biasAdd_NHWC(
|
||||
|
@ -1481,3 +1486,10 @@ func @size_rank_one_i64(%arg0: tensor<f32>) -> tensor<i64> {
|
|||
// CHECK: [[VAL_366:%.*]] = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
|
||||
// CHECK: return [[VAL_366]] : tensor<i64>
|
||||
// CHECK: }
|
||||
|
||||
// CHECK-LABEL: func @complex(
|
||||
// CHECK-SAME: [[VAL_367:%.*]]: tensor<3xf32>, [[VAL_368:%.*]]: tensor<3xf32>) -> tensor<3xcomplex<f32>> {
|
||||
// CHECK: [[VAL_369:%.*]] = "tf.Complex"([[VAL_367]], [[VAL_368]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex<f32>>
|
||||
// CHECK: return [[VAL_369]] : tensor<3xcomplex<f32>>
|
||||
// CHECK: }
|
||||
|
||||
|
|
|
@ -64,6 +64,7 @@ def : Pat<(HLO_ShiftRightLogicalOp $l, $r, $_), (TF_RightShiftOp $l, $r),
|
|||
def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r, $_)), (TF_FloorDivOp $l, $r),
|
||||
[(AreBroadcastCompatible $l, $r)]>;
|
||||
|
||||
def : Pat<(HLO_ComplexOp $r, $i), (TF_ComplexOp $r, $i)>;
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Unary op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
Loading…
Reference in New Issue