From fdd6db860702eae4cdf15bf052e37ee2f3fb0502 Mon Sep 17 00:00:00 2001 From: Thai Nguyen <thaink@google.com> Date: Tue, 9 Feb 2021 22:38:30 -0800 Subject: [PATCH] Support complex ops in MLIR converter PiperOrigin-RevId: 356670767 Change-Id: I26a871b4dbada6e85c8f1a26def313dfa9a5b578 --- tensorflow/compiler/mlir/lite/ir/tfl_ops.td | 63 +++++++++++++++++++ .../compiler/mlir/lite/tests/legalize-tf.mlir | 27 ++++++++ tensorflow/compiler/mlir/lite/tests/ops.mlir | 59 +++++++++++++++++ .../mlir/lite/transforms/legalize_patterns.td | 6 ++ tensorflow/lite/testing/split.h | 9 +++ tensorflow/lite/testing/tflite_driver.cc | 20 ++++++ tensorflow/lite/testing/zip_test_utils.py | 4 ++ 7 files changed, 188 insertions(+) diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 6d254d78f7a..6f9021e45a0 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -4683,4 +4683,67 @@ def TFL_Conv3DOp : TFL_Op<"conv_3d", [ let customOption = "Conv3DOptions"; } +def TFL_ComplexAbsOp : TFL_Op<"complex_abs", [ + NoSideEffect, + SameOperandsAndResultShape]> { + let summary = "Computes the complex absolute value of a tensor."; + + let description = [{ +Given a tensor `x` of complex numbers, this operation returns a tensor of type +`float` or `double` that is the absolute value of each element in `x`. All +elements in `x` must be complex numbers of the form \\(a + bj\\). The absolute +value is computed as \\( \sqrt{a^2 + b^2}\\). + }]; + + let arguments = (ins + TFL_TensorOf<[Complex<F<32>>, Complex<F<64>>]>:$input + ); + + let results = (outs + TFL_TensorOf<[F32, F64]>:$output + ); +} + +def TFL_RealOp : TFL_Op<"real", [ + NoSideEffect, + SameOperandsAndResultShape]> { + let summary = "Returns the real part of a complex number."; + + let description = [{ +Given a tensor `input` of complex numbers, this operation returns a tensor of +type `float` that is the real part of each element in `input`. All elements in +`input` must be complex numbers of the form \\(a + bj\\), where *a* is the real + part returned by this operation and *b* is the imaginary part. + }]; + + let arguments = (ins + TFL_TensorOf<[Complex<F<32>>, Complex<F<64>>]>:$input + ); + + let results = (outs + TFL_TensorOf<[F32, F64]>:$output + ); +} + +def TFL_ImagOp : TFL_Op<"imag", [ + NoSideEffect, + SameOperandsAndResultShape]> { + let summary = "Returns the imaginary part of a complex number."; + + let description = [{ +Given a tensor `input` of complex numbers, this operation returns a tensor of +type `float` that is the imaginary part of each element in `input`. All +elements in `input` must be complex numbers of the form \\(a + bj\\), where *a* +is the real part and *b* is the imaginary part returned by this operation. + }]; + + let arguments = (ins + TFL_TensorOf<[Complex<F<32>>, Complex<F<64>>]>:$input + ); + + let results = (outs + TFL_TensorOf<[F32, F64]>:$output + ); +} + #endif // TFL_OPS diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index dc85d9a79c0..c0ffa1641e7 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -2015,3 +2015,30 @@ func @conv3d_invalid_strides(%arg0: tensor<?x?x?x?x?xf32>,%arg1: tensor<?x?x?x? // CHECK: [[BCT:%.*]] = "tf.Conv3D"(%arg0, %arg1) {padding = "SAME", strides = [2, 1, 1, 1, 1]} : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> // CHECK: return [[BCT]] : tensor<?x?x?x?x?xf32> } + +func @complex_abs(%arg0: tensor<1 x complex<f32>>) -> tensor<1xf32> { + %0 = "tf.ComplexAbs"(%arg0) : (tensor<1 x complex<f32>>) -> tensor<1xf32> + return %0: tensor<1xf32> + +// CHECK-LABEL: complex_abs +// CHECK: "tfl.complex_abs"(%arg0) : (tensor<1xcomplex<f32>>) -> tensor<1xf32> +// CHECK: return +} + +func @real(%arg0: tensor<1 x complex<f64>>) -> tensor<1xf64> { + %0 = "tf.Real"(%arg0) : (tensor<1 x complex<f64>>) -> tensor<1xf64> + return %0: tensor<1xf64> + +// CHECK-LABEL: real +// CHECK: "tfl.real"(%arg0) : (tensor<1xcomplex<f64>>) -> tensor<1xf64> +// CHECK: return +} + +func @imag(%arg0: tensor<1 x complex<f64>>) -> tensor<1xf64> { + %0 = "tf.Imag"(%arg0) : (tensor<1 x complex<f64>>) -> tensor<1xf64> + return %0: tensor<1xf64> + +// CHECK-LABEL: imag +// CHECK: "tfl.imag"(%arg0) : (tensor<1xcomplex<f64>>) -> tensor<1xf64> +// CHECK: return +} diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index 00b9d67ddc5..eed2bf03010 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -2568,3 +2568,62 @@ func @testConv3dMisMatchBiasType(%arg0: tensor<2x3x4x5x2xf32>,%arg1: tensor<2x2 %0 = "tfl.conv_3d"(%arg0, %arg1, %arg2) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32, fused_activation_function = "NONE"}: (tensor<2x3x4x5x2xf32>, tensor<2x2x2x2x3xf32>, tensor<3xi32>) -> tensor<?x?x?x?x?xf32> return %0 : tensor<?x?x?x?x?xf32> } + +// ----- + +// CHECK-LABEL: testComplexAbs +func @testComplexAbs(%arg0: tensor<? x complex<f32>>) -> tensor<?xf32> { + // CHECK: "tfl.complex_abs"(%arg0) + %0 = "tfl.complex_abs"(%arg0): (tensor<? x complex<f32>>) -> tensor<?xf32> + return %0 : tensor<?xf32> +} + +// ----- + +func @testComplexAbsUnsupportedType(%arg0: tensor<?xf32>) -> tensor<?xf32> { + // expected-error @+1 {{operand #0 must be tensor of complex type with 32-bit float elements or complex type with 64-bit float elements values}} + %0 = "tfl.complex_abs"(%arg0): (tensor<?xf32>) -> tensor<?xf32> + return %0 : tensor<?xf32> +} + +// ----- + +func @testComplexAbsWrongShape(%arg0: tensor<2 x complex<f32>>) -> tensor<3xf32> { + // expected-error @+1 {{requires the same shape for all operands and results}} + %0 = "tfl.complex_abs"(%arg0): (tensor<2 x complex<f32>>) -> tensor<3xf32> + return %0 : tensor<3xf32> +} + +// ----- + +// CHECK-LABEL: testReal +func @testReal(%arg0: tensor<? x complex<f64>>) -> tensor<?xf64> { + // CHECK: "tfl.real"(%arg0) + %0 = "tfl.real"(%arg0): (tensor<? x complex<f64>>) -> tensor<?xf64> + return %0 : tensor<?xf64> +} + +// ----- + +func @testRealWrongShape(%arg0: tensor<3 x complex<f64>>) -> tensor<4xf32> { + // expected-error @+1 {{requires the same shape for all operands and results}} + %0 = "tfl.real"(%arg0): (tensor<3 x complex<f64>>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: testImag +func @testImag(%arg0: tensor<? x complex<f64>>) -> tensor<?xf64> { + // CHECK: "tfl.imag"(%arg0) + %0 = "tfl.imag"(%arg0): (tensor<? x complex<f64>>) -> tensor<?xf64> + return %0 : tensor<?xf64> +} + +// ----- + +func @testImagWrongType(%arg0: tensor<3 x complex<f64>>) -> tensor<4xi32> { + // expected-error @+1 {{requires the same shape for all operands and results}} + %0 = "tfl.imag"(%arg0): (tensor<3 x complex<f64>>) -> tensor<4xi32> + return %0 : tensor<4xi32> +} diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index b1403870430..80a4b3baac7 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -495,3 +495,9 @@ def LegalizeStridedSlice : Pat< def LegalizeRfft2d : Pat< (TF_RFFT2DOp $input, $fft_length), (TFL_RFFT2dOp $input, $fft_length)>; + +def LegalizeComplexAbs : Pat<(TF_ComplexAbsOp $arg), (TFL_ComplexAbsOp $arg)>; + +def LegalizeReal : Pat<(TF_RealOp $arg), (TFL_RealOp $arg)>; + +def LegalizeImag : Pat<(TF_ImagOp $arg), (TFL_ImagOp $arg)>; diff --git a/tensorflow/lite/testing/split.h b/tensorflow/lite/testing/split.h index 7e41060571c..c23f6f90ce0 100644 --- a/tensorflow/lite/testing/split.h +++ b/tensorflow/lite/testing/split.h @@ -87,6 +87,15 @@ inline std::vector<float> Split(const string& s, const string& delimiter) { return fields; } +template <> +inline std::vector<double> Split(const string& s, const string& delimiter) { + std::vector<double> fields; + for (const auto& p : SplitToPos(s, delimiter)) { + fields.push_back(strtod(s.data() + p.first, nullptr)); + } + return fields; +} + template <> inline std::vector<uint8_t> Split(const string& s, const string& delimiter) { std::vector<uint8_t> fields; diff --git a/tensorflow/lite/testing/tflite_driver.cc b/tensorflow/lite/testing/tflite_driver.cc index f2d8ee01d57..c858bf051d1 100644 --- a/tensorflow/lite/testing/tflite_driver.cc +++ b/tensorflow/lite/testing/tflite_driver.cc @@ -353,6 +353,8 @@ bool TfLiteDriver::DataExpectation::Check(bool verbose, case kTfLiteComplex128: return TypedCheck<std::complex<double>, std::complex<double>>(verbose, tensor); + case kTfLiteFloat64: + return TypedCheck<double, double>(verbose, tensor); default: fprintf(stderr, "Unsupported type %d in Check\n", tensor.type); return false; @@ -528,6 +530,21 @@ void TfLiteDriver::SetInput(int id, const string& csv_values) { break; } + case kTfLiteComplex64: { + const auto& values = testing::Split<std::complex<float>>(csv_values, ","); + if (!CheckSizes<std::complex<float>>(tensor->bytes, values.size())) + return; + SetTensorData(values, tensor->data.raw); + break; + } + case kTfLiteComplex128: { + const auto& values = + testing::Split<std::complex<double>>(csv_values, ","); + if (!CheckSizes<std::complex<double>>(tensor->bytes, values.size())) + return; + SetTensorData(values, tensor->data.raw); + break; + } default: Invalidate(absl::StrCat("Unsupported tensor type ", TfLiteTypeGetName(tensor->type), @@ -590,6 +607,9 @@ void TfLiteDriver::SetExpectation(int id, const string& csv_values) { case kTfLiteString: expected_output_[id]->SetData<string>(csv_values); break; + case kTfLiteFloat64: + expected_output_[id]->SetData<double>(csv_values); + break; case kTfLiteComplex64: expected_output_[id]->SetData<std::complex<float>>(csv_values); break; diff --git a/tensorflow/lite/testing/zip_test_utils.py b/tensorflow/lite/testing/zip_test_utils.py index b76f26f4873..4639ce0a515 100644 --- a/tensorflow/lite/testing/zip_test_utils.py +++ b/tensorflow/lite/testing/zip_test_utils.py @@ -111,6 +111,10 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100): if dtype in (tf.float32, tf.float16, tf.float64): value = (max_value - min_value) * np.random.random_sample(shape) + min_value + elif dtype in (tf.complex64, tf.complex128): + real = (max_value - min_value) * np.random.random_sample(shape) + min_value + imag = (max_value - min_value) * np.random.random_sample(shape) + min_value + value = real + imag * 1j elif dtype in (tf.int32, tf.uint8, tf.int64, tf.int16): value = np.random.randint(min_value, max_value + 1, shape) elif dtype == tf.bool: