Support complex ops in MLIR converter
PiperOrigin-RevId: 356199595 Change-Id: I6119893bc6a307e2bd3436d64f498dfa4c3a8329
This commit is contained in:
parent
d8f9b2e88e
commit
318b0fcd3d
tensorflow
compiler/mlir/lite
lite/testing
@ -4683,67 +4683,4 @@ def TFL_Conv3DOp : TFL_Op<"conv_3d", [
|
||||
let customOption = "Conv3DOptions";
|
||||
}
|
||||
|
||||
def TFL_ComplexAbsOp : TFL_Op<"complex_abs", [
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultShape]> {
|
||||
let summary = "Computes the complex absolute value of a tensor.";
|
||||
|
||||
let description = [{
|
||||
Given a tensor `x` of complex numbers, this operation returns a tensor of type
|
||||
`float` or `double` that is the absolute value of each element in `x`. All
|
||||
elements in `x` must be complex numbers of the form \\(a + bj\\). The absolute
|
||||
value is computed as \\( \sqrt{a^2 + b^2}\\).
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[Complex<F<32>>, Complex<F<64>>]>:$input
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TFL_TensorOf<[F32, F64]>:$output
|
||||
);
|
||||
}
|
||||
|
||||
def TFL_RealOp : TFL_Op<"real", [
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultShape]> {
|
||||
let summary = "Returns the real part of a complex number.";
|
||||
|
||||
let description = [{
|
||||
Given a tensor `input` of complex numbers, this operation returns a tensor of
|
||||
type `float` that is the real part of each element in `input`. All elements in
|
||||
`input` must be complex numbers of the form \\(a + bj\\), where *a* is the real
|
||||
part returned by this operation and *b* is the imaginary part.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[Complex<F<32>>, Complex<F<64>>]>:$input
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TFL_TensorOf<[F32, F64]>:$output
|
||||
);
|
||||
}
|
||||
|
||||
def TFL_ImagOp : TFL_Op<"imag", [
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultShape]> {
|
||||
let summary = "Returns the imaginary part of a complex number.";
|
||||
|
||||
let description = [{
|
||||
Given a tensor `input` of complex numbers, this operation returns a tensor of
|
||||
type `float` that is the imaginary part of each element in `input`. All
|
||||
elements in `input` must be complex numbers of the form \\(a + bj\\), where *a*
|
||||
is the real part and *b* is the imaginary part returned by this operation.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[Complex<F<32>>, Complex<F<64>>]>:$input
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TFL_TensorOf<[F32, F64]>:$output
|
||||
);
|
||||
}
|
||||
|
||||
#endif // TFL_OPS
|
||||
|
@ -2015,30 +2015,3 @@ func @conv3d_invalid_strides(%arg0: tensor<?x?x?x?x?xf32>,%arg1: tensor<?x?x?x?
|
||||
// CHECK: [[BCT:%.*]] = "tf.Conv3D"(%arg0, %arg1) {padding = "SAME", strides = [2, 1, 1, 1, 1]} : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK: return [[BCT]] : tensor<?x?x?x?x?xf32>
|
||||
}
|
||||
|
||||
func @complex_abs(%arg0: tensor<1 x complex<f32>>) -> tensor<1xf32> {
|
||||
%0 = "tf.ComplexAbs"(%arg0) : (tensor<1 x complex<f32>>) -> tensor<1xf32>
|
||||
return %0: tensor<1xf32>
|
||||
|
||||
// CHECK-LABEL: complex_abs
|
||||
// CHECK: "tfl.complex_abs"(%arg0) : (tensor<1xcomplex<f32>>) -> tensor<1xf32>
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
func @real(%arg0: tensor<1 x complex<f64>>) -> tensor<1xf64> {
|
||||
%0 = "tf.Real"(%arg0) : (tensor<1 x complex<f64>>) -> tensor<1xf64>
|
||||
return %0: tensor<1xf64>
|
||||
|
||||
// CHECK-LABEL: real
|
||||
// CHECK: "tfl.real"(%arg0) : (tensor<1xcomplex<f64>>) -> tensor<1xf64>
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
func @imag(%arg0: tensor<1 x complex<f64>>) -> tensor<1xf64> {
|
||||
%0 = "tf.Imag"(%arg0) : (tensor<1 x complex<f64>>) -> tensor<1xf64>
|
||||
return %0: tensor<1xf64>
|
||||
|
||||
// CHECK-LABEL: imag
|
||||
// CHECK: "tfl.imag"(%arg0) : (tensor<1xcomplex<f64>>) -> tensor<1xf64>
|
||||
// CHECK: return
|
||||
}
|
||||
|
@ -2568,62 +2568,3 @@ func @testConv3dMisMatchBiasType(%arg0: tensor<2x3x4x5x2xf32>,%arg1: tensor<2x2
|
||||
%0 = "tfl.conv_3d"(%arg0, %arg1, %arg2) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32, fused_activation_function = "NONE"}: (tensor<2x3x4x5x2xf32>, tensor<2x2x2x2x3xf32>, tensor<3xi32>) -> tensor<?x?x?x?x?xf32>
|
||||
return %0 : tensor<?x?x?x?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testComplexAbs
|
||||
func @testComplexAbs(%arg0: tensor<? x complex<f32>>) -> tensor<?xf32> {
|
||||
// CHECK: "tfl.complex_abs"(%arg0)
|
||||
%0 = "tfl.complex_abs"(%arg0): (tensor<? x complex<f32>>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testComplexAbsUnsupportedType(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
// expected-error @+1 {{operand #0 must be tensor of complex type with 32-bit float elements or complex type with 64-bit float elements values}}
|
||||
%0 = "tfl.complex_abs"(%arg0): (tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testComplexAbsWrongShape(%arg0: tensor<2 x complex<f32>>) -> tensor<3xf32> {
|
||||
// expected-error @+1 {{requires the same shape for all operands and results}}
|
||||
%0 = "tfl.complex_abs"(%arg0): (tensor<2 x complex<f32>>) -> tensor<3xf32>
|
||||
return %0 : tensor<3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testReal
|
||||
func @testReal(%arg0: tensor<? x complex<f64>>) -> tensor<?xf64> {
|
||||
// CHECK: "tfl.real"(%arg0)
|
||||
%0 = "tfl.real"(%arg0): (tensor<? x complex<f64>>) -> tensor<?xf64>
|
||||
return %0 : tensor<?xf64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testRealWrongShape(%arg0: tensor<3 x complex<f64>>) -> tensor<4xf32> {
|
||||
// expected-error @+1 {{requires the same shape for all operands and results}}
|
||||
%0 = "tfl.real"(%arg0): (tensor<3 x complex<f64>>) -> tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testImag
|
||||
func @testImag(%arg0: tensor<? x complex<f64>>) -> tensor<?xf64> {
|
||||
// CHECK: "tfl.imag"(%arg0)
|
||||
%0 = "tfl.imag"(%arg0): (tensor<? x complex<f64>>) -> tensor<?xf64>
|
||||
return %0 : tensor<?xf64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testImagWrongType(%arg0: tensor<3 x complex<f64>>) -> tensor<4xi32> {
|
||||
// expected-error @+1 {{requires the same shape for all operands and results}}
|
||||
%0 = "tfl.imag"(%arg0): (tensor<3 x complex<f64>>) -> tensor<4xi32>
|
||||
return %0 : tensor<4xi32>
|
||||
}
|
||||
|
@ -495,9 +495,3 @@ def LegalizeStridedSlice : Pat<
|
||||
def LegalizeRfft2d : Pat<
|
||||
(TF_RFFT2DOp $input, $fft_length),
|
||||
(TFL_RFFT2dOp $input, $fft_length)>;
|
||||
|
||||
def LegalizeComplexAbs : Pat<(TF_ComplexAbsOp $arg), (TFL_ComplexAbsOp $arg)>;
|
||||
|
||||
def LegalizeReal : Pat<(TF_RealOp $arg), (TFL_RealOp $arg)>;
|
||||
|
||||
def LegalizeImag : Pat<(TF_ImagOp $arg), (TFL_ImagOp $arg)>;
|
||||
|
@ -87,15 +87,6 @@ inline std::vector<float> Split(const string& s, const string& delimiter) {
|
||||
return fields;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::vector<double> Split(const string& s, const string& delimiter) {
|
||||
std::vector<double> fields;
|
||||
for (const auto& p : SplitToPos(s, delimiter)) {
|
||||
fields.push_back(strtod(s.data() + p.first, nullptr));
|
||||
}
|
||||
return fields;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::vector<uint8_t> Split(const string& s, const string& delimiter) {
|
||||
std::vector<uint8_t> fields;
|
||||
|
@ -353,8 +353,6 @@ bool TfLiteDriver::DataExpectation::Check(bool verbose,
|
||||
case kTfLiteComplex128:
|
||||
return TypedCheck<std::complex<double>, std::complex<double>>(verbose,
|
||||
tensor);
|
||||
case kTfLiteFloat64:
|
||||
return TypedCheck<double, double>(verbose, tensor);
|
||||
default:
|
||||
fprintf(stderr, "Unsupported type %d in Check\n", tensor.type);
|
||||
return false;
|
||||
@ -530,21 +528,6 @@ void TfLiteDriver::SetInput(int id, const string& csv_values) {
|
||||
|
||||
break;
|
||||
}
|
||||
case kTfLiteComplex64: {
|
||||
const auto& values = testing::Split<std::complex<float>>(csv_values, ",");
|
||||
if (!CheckSizes<std::complex<float>>(tensor->bytes, values.size()))
|
||||
return;
|
||||
SetTensorData(values, tensor->data.raw);
|
||||
break;
|
||||
}
|
||||
case kTfLiteComplex128: {
|
||||
const auto& values =
|
||||
testing::Split<std::complex<double>>(csv_values, ",");
|
||||
if (!CheckSizes<std::complex<double>>(tensor->bytes, values.size()))
|
||||
return;
|
||||
SetTensorData(values, tensor->data.raw);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
Invalidate(absl::StrCat("Unsupported tensor type ",
|
||||
TfLiteTypeGetName(tensor->type),
|
||||
@ -607,9 +590,6 @@ void TfLiteDriver::SetExpectation(int id, const string& csv_values) {
|
||||
case kTfLiteString:
|
||||
expected_output_[id]->SetData<string>(csv_values);
|
||||
break;
|
||||
case kTfLiteFloat64:
|
||||
expected_output_[id]->SetData<double>(csv_values);
|
||||
break;
|
||||
case kTfLiteComplex64:
|
||||
expected_output_[id]->SetData<std::complex<float>>(csv_values);
|
||||
break;
|
||||
|
@ -111,10 +111,6 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100):
|
||||
|
||||
if dtype in (tf.float32, tf.float16, tf.float64):
|
||||
value = (max_value - min_value) * np.random.random_sample(shape) + min_value
|
||||
elif dtype in (tf.complex64, tf.complex128):
|
||||
real = (max_value - min_value) * np.random.random_sample(shape) + min_value
|
||||
imag = (max_value - min_value) * np.random.random_sample(shape) + min_value
|
||||
value = real + imag * 1j
|
||||
elif dtype in (tf.int32, tf.uint8, tf.int64, tf.int16):
|
||||
value = np.random.randint(min_value, max_value + 1, shape)
|
||||
elif dtype == tf.bool:
|
||||
|
Loading…
Reference in New Issue
Block a user