From fdd6db860702eae4cdf15bf052e37ee2f3fb0502 Mon Sep 17 00:00:00 2001
From: Thai Nguyen <thaink@google.com>
Date: Tue, 9 Feb 2021 22:38:30 -0800
Subject: [PATCH] Support complex ops in MLIR converter

PiperOrigin-RevId: 356670767
Change-Id: I26a871b4dbada6e85c8f1a26def313dfa9a5b578
---
 tensorflow/compiler/mlir/lite/ir/tfl_ops.td   | 63 +++++++++++++++++++
 .../compiler/mlir/lite/tests/legalize-tf.mlir | 27 ++++++++
 tensorflow/compiler/mlir/lite/tests/ops.mlir  | 59 +++++++++++++++++
 .../mlir/lite/transforms/legalize_patterns.td |  6 ++
 tensorflow/lite/testing/split.h               |  9 +++
 tensorflow/lite/testing/tflite_driver.cc      | 20 ++++++
 tensorflow/lite/testing/zip_test_utils.py     |  4 ++
 7 files changed, 188 insertions(+)

diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index 6d254d78f7a..6f9021e45a0 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -4683,4 +4683,67 @@ def TFL_Conv3DOp : TFL_Op<"conv_3d", [
   let customOption = "Conv3DOptions";
 }
 
+def TFL_ComplexAbsOp : TFL_Op<"complex_abs", [
+  NoSideEffect,
+  SameOperandsAndResultShape]> {
+  let summary = "Computes the complex absolute value of a tensor.";
+
+  let description = [{
+Given a tensor `x` of complex numbers, this operation returns a tensor of type
+`float` or `double` that is the absolute value of each element in `x`. All
+elements in `x` must be complex numbers of the form \\(a + bj\\). The absolute
+value is computed as \\( \sqrt{a^2 + b^2}\\).
+  }];
+
+  let arguments = (ins
+    TFL_TensorOf<[Complex<F<32>>, Complex<F<64>>]>:$input
+  );
+
+  let results = (outs
+    TFL_TensorOf<[F32, F64]>:$output
+  );
+}
+
+def TFL_RealOp : TFL_Op<"real", [
+  NoSideEffect,
+  SameOperandsAndResultShape]> {
+  let summary = "Returns the real part of a complex number.";
+
+  let description = [{
+Given a tensor `input` of complex numbers, this operation returns a tensor of
+type `float` that is the real part of each element in `input`. All elements in
+`input` must be complex numbers of the form \\(a + bj\\), where *a* is the real
+ part returned by this operation and *b* is the imaginary part.
+  }];
+
+  let arguments = (ins
+    TFL_TensorOf<[Complex<F<32>>, Complex<F<64>>]>:$input
+  );
+
+  let results = (outs
+    TFL_TensorOf<[F32, F64]>:$output
+  );
+}
+
+def TFL_ImagOp : TFL_Op<"imag", [
+  NoSideEffect,
+  SameOperandsAndResultShape]> {
+  let summary = "Returns the imaginary part of a complex number.";
+
+  let description = [{
+Given a tensor `input` of complex numbers, this operation returns a tensor of
+type `float` that is the imaginary part of each element in `input`. All
+elements in `input` must be complex numbers of the form \\(a + bj\\), where *a*
+is the real part and *b* is the imaginary part returned by this operation.
+  }];
+
+  let arguments = (ins
+    TFL_TensorOf<[Complex<F<32>>, Complex<F<64>>]>:$input
+  );
+
+  let results = (outs
+    TFL_TensorOf<[F32, F64]>:$output
+  );
+}
+
 #endif // TFL_OPS
diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
index dc85d9a79c0..c0ffa1641e7 100644
--- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
@@ -2015,3 +2015,30 @@ func @conv3d_invalid_strides(%arg0: tensor<?x?x?x?x?xf32>,%arg1:  tensor<?x?x?x?
   // CHECK:  [[BCT:%.*]] = "tf.Conv3D"(%arg0, %arg1) {padding = "SAME", strides = [2, 1, 1, 1, 1]} : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
   // CHECK:  return [[BCT]] : tensor<?x?x?x?x?xf32>
 }
+
+func @complex_abs(%arg0: tensor<1 x complex<f32>>) -> tensor<1xf32> {
+  %0 = "tf.ComplexAbs"(%arg0) : (tensor<1 x complex<f32>>) -> tensor<1xf32>
+  return %0: tensor<1xf32>
+
+// CHECK-LABEL: complex_abs
+// CHECK:  "tfl.complex_abs"(%arg0) : (tensor<1xcomplex<f32>>) -> tensor<1xf32>
+// CHECK:  return
+}
+
+func @real(%arg0: tensor<1 x complex<f64>>) -> tensor<1xf64> {
+  %0 = "tf.Real"(%arg0) : (tensor<1 x complex<f64>>) -> tensor<1xf64>
+  return %0: tensor<1xf64>
+
+// CHECK-LABEL: real
+// CHECK:  "tfl.real"(%arg0) : (tensor<1xcomplex<f64>>) -> tensor<1xf64>
+// CHECK:  return
+}
+
+func @imag(%arg0: tensor<1 x complex<f64>>) -> tensor<1xf64> {
+  %0 = "tf.Imag"(%arg0) : (tensor<1 x complex<f64>>) -> tensor<1xf64>
+  return %0: tensor<1xf64>
+
+// CHECK-LABEL: imag
+// CHECK:  "tfl.imag"(%arg0) : (tensor<1xcomplex<f64>>) -> tensor<1xf64>
+// CHECK:  return
+}
diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir
index 00b9d67ddc5..eed2bf03010 100644
--- a/tensorflow/compiler/mlir/lite/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir
@@ -2568,3 +2568,62 @@ func @testConv3dMisMatchBiasType(%arg0: tensor<2x3x4x5x2xf32>,%arg1:  tensor<2x2
   %0 = "tfl.conv_3d"(%arg0, %arg1, %arg2) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32, fused_activation_function = "NONE"}: (tensor<2x3x4x5x2xf32>, tensor<2x2x2x2x3xf32>, tensor<3xi32>) -> tensor<?x?x?x?x?xf32>
   return %0 : tensor<?x?x?x?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: testComplexAbs
+func @testComplexAbs(%arg0: tensor<? x complex<f32>>) -> tensor<?xf32> {
+  // CHECK: "tfl.complex_abs"(%arg0)
+  %0 = "tfl.complex_abs"(%arg0): (tensor<? x complex<f32>>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+func @testComplexAbsUnsupportedType(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  // expected-error @+1 {{operand #0 must be tensor of complex type with 32-bit float elements or complex type with 64-bit float elements values}}
+  %0 = "tfl.complex_abs"(%arg0): (tensor<?xf32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+func @testComplexAbsWrongShape(%arg0: tensor<2 x complex<f32>>) -> tensor<3xf32> {
+  // expected-error @+1 {{requires the same shape for all operands and results}}
+  %0 = "tfl.complex_abs"(%arg0): (tensor<2 x complex<f32>>) -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: testReal
+func @testReal(%arg0: tensor<? x complex<f64>>) -> tensor<?xf64> {
+  // CHECK: "tfl.real"(%arg0)
+  %0 = "tfl.real"(%arg0): (tensor<? x complex<f64>>) -> tensor<?xf64>
+  return %0 : tensor<?xf64>
+}
+
+// -----
+
+func @testRealWrongShape(%arg0: tensor<3 x complex<f64>>) -> tensor<4xf32> {
+  // expected-error @+1 {{requires the same shape for all operands and results}}
+  %0 = "tfl.real"(%arg0): (tensor<3 x complex<f64>>) -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: testImag
+func @testImag(%arg0: tensor<? x complex<f64>>) -> tensor<?xf64> {
+  // CHECK: "tfl.imag"(%arg0)
+  %0 = "tfl.imag"(%arg0): (tensor<? x complex<f64>>) -> tensor<?xf64>
+  return %0 : tensor<?xf64>
+}
+
+// -----
+
+func @testImagWrongType(%arg0: tensor<3 x complex<f64>>) -> tensor<4xi32> {
+  // expected-error @+1 {{requires the same shape for all operands and results}}
+  %0 = "tfl.imag"(%arg0): (tensor<3 x complex<f64>>) -> tensor<4xi32>
+  return %0 : tensor<4xi32>
+}
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
index b1403870430..80a4b3baac7 100644
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
@@ -495,3 +495,9 @@ def LegalizeStridedSlice : Pat<
 def LegalizeRfft2d : Pat<
   (TF_RFFT2DOp $input, $fft_length),
   (TFL_RFFT2dOp $input, $fft_length)>;
+
+def LegalizeComplexAbs : Pat<(TF_ComplexAbsOp $arg), (TFL_ComplexAbsOp $arg)>;
+
+def LegalizeReal : Pat<(TF_RealOp $arg), (TFL_RealOp $arg)>;
+
+def LegalizeImag : Pat<(TF_ImagOp $arg), (TFL_ImagOp $arg)>;
diff --git a/tensorflow/lite/testing/split.h b/tensorflow/lite/testing/split.h
index 7e41060571c..c23f6f90ce0 100644
--- a/tensorflow/lite/testing/split.h
+++ b/tensorflow/lite/testing/split.h
@@ -87,6 +87,15 @@ inline std::vector<float> Split(const string& s, const string& delimiter) {
   return fields;
 }
 
+template <>
+inline std::vector<double> Split(const string& s, const string& delimiter) {
+  std::vector<double> fields;
+  for (const auto& p : SplitToPos(s, delimiter)) {
+    fields.push_back(strtod(s.data() + p.first, nullptr));
+  }
+  return fields;
+}
+
 template <>
 inline std::vector<uint8_t> Split(const string& s, const string& delimiter) {
   std::vector<uint8_t> fields;
diff --git a/tensorflow/lite/testing/tflite_driver.cc b/tensorflow/lite/testing/tflite_driver.cc
index f2d8ee01d57..c858bf051d1 100644
--- a/tensorflow/lite/testing/tflite_driver.cc
+++ b/tensorflow/lite/testing/tflite_driver.cc
@@ -353,6 +353,8 @@ bool TfLiteDriver::DataExpectation::Check(bool verbose,
     case kTfLiteComplex128:
       return TypedCheck<std::complex<double>, std::complex<double>>(verbose,
                                                                     tensor);
+    case kTfLiteFloat64:
+      return TypedCheck<double, double>(verbose, tensor);
     default:
       fprintf(stderr, "Unsupported type %d in Check\n", tensor.type);
       return false;
@@ -528,6 +530,21 @@ void TfLiteDriver::SetInput(int id, const string& csv_values) {
 
       break;
     }
+    case kTfLiteComplex64: {
+      const auto& values = testing::Split<std::complex<float>>(csv_values, ",");
+      if (!CheckSizes<std::complex<float>>(tensor->bytes, values.size()))
+        return;
+      SetTensorData(values, tensor->data.raw);
+      break;
+    }
+    case kTfLiteComplex128: {
+      const auto& values =
+          testing::Split<std::complex<double>>(csv_values, ",");
+      if (!CheckSizes<std::complex<double>>(tensor->bytes, values.size()))
+        return;
+      SetTensorData(values, tensor->data.raw);
+      break;
+    }
     default:
       Invalidate(absl::StrCat("Unsupported tensor type ",
                               TfLiteTypeGetName(tensor->type),
@@ -590,6 +607,9 @@ void TfLiteDriver::SetExpectation(int id, const string& csv_values) {
     case kTfLiteString:
       expected_output_[id]->SetData<string>(csv_values);
       break;
+    case kTfLiteFloat64:
+      expected_output_[id]->SetData<double>(csv_values);
+      break;
     case kTfLiteComplex64:
       expected_output_[id]->SetData<std::complex<float>>(csv_values);
       break;
diff --git a/tensorflow/lite/testing/zip_test_utils.py b/tensorflow/lite/testing/zip_test_utils.py
index b76f26f4873..4639ce0a515 100644
--- a/tensorflow/lite/testing/zip_test_utils.py
+++ b/tensorflow/lite/testing/zip_test_utils.py
@@ -111,6 +111,10 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100):
 
   if dtype in (tf.float32, tf.float16, tf.float64):
     value = (max_value - min_value) * np.random.random_sample(shape) + min_value
+  elif dtype in (tf.complex64, tf.complex128):
+    real = (max_value - min_value) * np.random.random_sample(shape) + min_value
+    imag = (max_value - min_value) * np.random.random_sample(shape) + min_value
+    value = real + imag * 1j
   elif dtype in (tf.int32, tf.uint8, tf.int64, tf.int16):
     value = np.random.randint(min_value, max_value + 1, shape)
   elif dtype == tf.bool: