Support complex ops in MLIR converter

PiperOrigin-RevId: 356199595
Change-Id: I6119893bc6a307e2bd3436d64f498dfa4c3a8329
This commit is contained in:
A. Unique TensorFlower 2021-02-07 23:25:07 -08:00 committed by TensorFlower Gardener
parent d8f9b2e88e
commit 318b0fcd3d
7 changed files with 0 additions and 188 deletions

View File

@ -4683,67 +4683,4 @@ def TFL_Conv3DOp : TFL_Op<"conv_3d", [
let customOption = "Conv3DOptions";
}
def TFL_ComplexAbsOp : TFL_Op<"complex_abs", [
NoSideEffect,
SameOperandsAndResultShape]> {
let summary = "Computes the complex absolute value of a tensor.";
let description = [{
Given a tensor `x` of complex numbers, this operation returns a tensor of type
`float` or `double` that is the absolute value of each element in `x`. All
elements in `x` must be complex numbers of the form \\(a + bj\\). The absolute
value is computed as \\( \sqrt{a^2 + b^2}\\).
}];
let arguments = (ins
TFL_TensorOf<[Complex<F<32>>, Complex<F<64>>]>:$input
);
let results = (outs
TFL_TensorOf<[F32, F64]>:$output
);
}
def TFL_RealOp : TFL_Op<"real", [
NoSideEffect,
SameOperandsAndResultShape]> {
let summary = "Returns the real part of a complex number.";
let description = [{
Given a tensor `input` of complex numbers, this operation returns a tensor of
type `float` that is the real part of each element in `input`. All elements in
`input` must be complex numbers of the form \\(a + bj\\), where *a* is the real
part returned by this operation and *b* is the imaginary part.
}];
let arguments = (ins
TFL_TensorOf<[Complex<F<32>>, Complex<F<64>>]>:$input
);
let results = (outs
TFL_TensorOf<[F32, F64]>:$output
);
}
def TFL_ImagOp : TFL_Op<"imag", [
NoSideEffect,
SameOperandsAndResultShape]> {
let summary = "Returns the imaginary part of a complex number.";
let description = [{
Given a tensor `input` of complex numbers, this operation returns a tensor of
type `float` that is the imaginary part of each element in `input`. All
elements in `input` must be complex numbers of the form \\(a + bj\\), where *a*
is the real part and *b* is the imaginary part returned by this operation.
}];
let arguments = (ins
TFL_TensorOf<[Complex<F<32>>, Complex<F<64>>]>:$input
);
let results = (outs
TFL_TensorOf<[F32, F64]>:$output
);
}
#endif // TFL_OPS

View File

@ -2015,30 +2015,3 @@ func @conv3d_invalid_strides(%arg0: tensor<?x?x?x?x?xf32>,%arg1: tensor<?x?x?x?
// CHECK: [[BCT:%.*]] = "tf.Conv3D"(%arg0, %arg1) {padding = "SAME", strides = [2, 1, 1, 1, 1]} : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
// CHECK: return [[BCT]] : tensor<?x?x?x?x?xf32>
}
func @complex_abs(%arg0: tensor<1 x complex<f32>>) -> tensor<1xf32> {
%0 = "tf.ComplexAbs"(%arg0) : (tensor<1 x complex<f32>>) -> tensor<1xf32>
return %0: tensor<1xf32>
// CHECK-LABEL: complex_abs
// CHECK: "tfl.complex_abs"(%arg0) : (tensor<1xcomplex<f32>>) -> tensor<1xf32>
// CHECK: return
}
func @real(%arg0: tensor<1 x complex<f64>>) -> tensor<1xf64> {
%0 = "tf.Real"(%arg0) : (tensor<1 x complex<f64>>) -> tensor<1xf64>
return %0: tensor<1xf64>
// CHECK-LABEL: real
// CHECK: "tfl.real"(%arg0) : (tensor<1xcomplex<f64>>) -> tensor<1xf64>
// CHECK: return
}
func @imag(%arg0: tensor<1 x complex<f64>>) -> tensor<1xf64> {
%0 = "tf.Imag"(%arg0) : (tensor<1 x complex<f64>>) -> tensor<1xf64>
return %0: tensor<1xf64>
// CHECK-LABEL: imag
// CHECK: "tfl.imag"(%arg0) : (tensor<1xcomplex<f64>>) -> tensor<1xf64>
// CHECK: return
}

View File

@ -2568,62 +2568,3 @@ func @testConv3dMisMatchBiasType(%arg0: tensor<2x3x4x5x2xf32>,%arg1: tensor<2x2
%0 = "tfl.conv_3d"(%arg0, %arg1, %arg2) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32, fused_activation_function = "NONE"}: (tensor<2x3x4x5x2xf32>, tensor<2x2x2x2x3xf32>, tensor<3xi32>) -> tensor<?x?x?x?x?xf32>
return %0 : tensor<?x?x?x?x?xf32>
}
// -----
// CHECK-LABEL: testComplexAbs
func @testComplexAbs(%arg0: tensor<? x complex<f32>>) -> tensor<?xf32> {
// CHECK: "tfl.complex_abs"(%arg0)
%0 = "tfl.complex_abs"(%arg0): (tensor<? x complex<f32>>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
func @testComplexAbsUnsupportedType(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// expected-error @+1 {{operand #0 must be tensor of complex type with 32-bit float elements or complex type with 64-bit float elements values}}
%0 = "tfl.complex_abs"(%arg0): (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
func @testComplexAbsWrongShape(%arg0: tensor<2 x complex<f32>>) -> tensor<3xf32> {
// expected-error @+1 {{requires the same shape for all operands and results}}
%0 = "tfl.complex_abs"(%arg0): (tensor<2 x complex<f32>>) -> tensor<3xf32>
return %0 : tensor<3xf32>
}
// -----
// CHECK-LABEL: testReal
func @testReal(%arg0: tensor<? x complex<f64>>) -> tensor<?xf64> {
// CHECK: "tfl.real"(%arg0)
%0 = "tfl.real"(%arg0): (tensor<? x complex<f64>>) -> tensor<?xf64>
return %0 : tensor<?xf64>
}
// -----
func @testRealWrongShape(%arg0: tensor<3 x complex<f64>>) -> tensor<4xf32> {
// expected-error @+1 {{requires the same shape for all operands and results}}
%0 = "tfl.real"(%arg0): (tensor<3 x complex<f64>>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// -----
// CHECK-LABEL: testImag
func @testImag(%arg0: tensor<? x complex<f64>>) -> tensor<?xf64> {
// CHECK: "tfl.imag"(%arg0)
%0 = "tfl.imag"(%arg0): (tensor<? x complex<f64>>) -> tensor<?xf64>
return %0 : tensor<?xf64>
}
// -----
func @testImagWrongType(%arg0: tensor<3 x complex<f64>>) -> tensor<4xi32> {
// expected-error @+1 {{requires the same shape for all operands and results}}
%0 = "tfl.imag"(%arg0): (tensor<3 x complex<f64>>) -> tensor<4xi32>
return %0 : tensor<4xi32>
}

View File

@ -495,9 +495,3 @@ def LegalizeStridedSlice : Pat<
def LegalizeRfft2d : Pat<
(TF_RFFT2DOp $input, $fft_length),
(TFL_RFFT2dOp $input, $fft_length)>;
def LegalizeComplexAbs : Pat<(TF_ComplexAbsOp $arg), (TFL_ComplexAbsOp $arg)>;
def LegalizeReal : Pat<(TF_RealOp $arg), (TFL_RealOp $arg)>;
def LegalizeImag : Pat<(TF_ImagOp $arg), (TFL_ImagOp $arg)>;

View File

@ -87,15 +87,6 @@ inline std::vector<float> Split(const string& s, const string& delimiter) {
return fields;
}
template <>
inline std::vector<double> Split(const string& s, const string& delimiter) {
std::vector<double> fields;
for (const auto& p : SplitToPos(s, delimiter)) {
fields.push_back(strtod(s.data() + p.first, nullptr));
}
return fields;
}
template <>
inline std::vector<uint8_t> Split(const string& s, const string& delimiter) {
std::vector<uint8_t> fields;

View File

@ -353,8 +353,6 @@ bool TfLiteDriver::DataExpectation::Check(bool verbose,
case kTfLiteComplex128:
return TypedCheck<std::complex<double>, std::complex<double>>(verbose,
tensor);
case kTfLiteFloat64:
return TypedCheck<double, double>(verbose, tensor);
default:
fprintf(stderr, "Unsupported type %d in Check\n", tensor.type);
return false;
@ -530,21 +528,6 @@ void TfLiteDriver::SetInput(int id, const string& csv_values) {
break;
}
case kTfLiteComplex64: {
const auto& values = testing::Split<std::complex<float>>(csv_values, ",");
if (!CheckSizes<std::complex<float>>(tensor->bytes, values.size()))
return;
SetTensorData(values, tensor->data.raw);
break;
}
case kTfLiteComplex128: {
const auto& values =
testing::Split<std::complex<double>>(csv_values, ",");
if (!CheckSizes<std::complex<double>>(tensor->bytes, values.size()))
return;
SetTensorData(values, tensor->data.raw);
break;
}
default:
Invalidate(absl::StrCat("Unsupported tensor type ",
TfLiteTypeGetName(tensor->type),
@ -607,9 +590,6 @@ void TfLiteDriver::SetExpectation(int id, const string& csv_values) {
case kTfLiteString:
expected_output_[id]->SetData<string>(csv_values);
break;
case kTfLiteFloat64:
expected_output_[id]->SetData<double>(csv_values);
break;
case kTfLiteComplex64:
expected_output_[id]->SetData<std::complex<float>>(csv_values);
break;

View File

@ -111,10 +111,6 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100):
if dtype in (tf.float32, tf.float16, tf.float64):
value = (max_value - min_value) * np.random.random_sample(shape) + min_value
elif dtype in (tf.complex64, tf.complex128):
real = (max_value - min_value) * np.random.random_sample(shape) + min_value
imag = (max_value - min_value) * np.random.random_sample(shape) + min_value
value = real + imag * 1j
elif dtype in (tf.int32, tf.uint8, tf.int64, tf.int16):
value = np.random.randint(min_value, max_value + 1, shape)
elif dtype == tf.bool: