Promote rfft2d to builtin op and add mlir conversion support.

PiperOrigin-RevId: 344732711
Change-Id: I811c45a03d7c204120f2c9bc491e39b32cc5222b
This commit is contained in:
Renjie Liu 2020-11-29 18:36:48 -08:00 committed by TensorFlower Gardener
parent 4d1142b04b
commit bc295ed3bc
19 changed files with 689 additions and 568 deletions

View File

@ -50,6 +50,8 @@
* Added support for saved model's session initializer through
`TFLiteConverter.from_saved_model`.
* Added dynamic range quantization support for the BatchMatMul op.
* Add `RFFT2D` as builtin op. (`RFFT2D` also supports `RFFTD`.) Currently
only supports float32 input.
* TF Core:
* Corrected higher-order gradients of control flow constructs (`tf.cond`,

View File

@ -172,6 +172,8 @@ def TFL_FpTensor : TFL_TensorOf<[F32]>;
def TFL_I32OrI64Tensor : TFL_TensorOf<[TFL_Int32Or64]>;
def TFL_I32Tensor : TFL_TensorOf<[I32]>;
def TFL_I64Tensor : TFL_TensorOf<[I64]>;
def TFL_Complex64Tensor : TFL_TensorOf<[Complex<F<32>>]>;
// TODO(jpienaar): Expand to all int types.
def TFL_IntTensor : TypeAlias<TFL_I32Tensor, "tensor of any integer type">;
@ -4481,4 +4483,31 @@ subsequent operation and then be optimized away, however.)
);
}
def TFL_RFFT2dOp : TFL_Op<"RFFT2D", [NoSideEffect, NoQuantizableResult]> {
let summary = "2D real-valued fast Fourier transform.";
let description = [{
Computes the 2-dimensional discrete Fourier transform of a real-valued signal
over the inner-most 2 dimensions of `input`.
Since the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the
`fft_length / 2 + 1` unique components of the FFT for the inner-most dimension
of `output`: the zero-frequency term, followed by the `fft_length / 2`
positive-frequency terms.
Along each axis `RFFT2D` is computed on, if `fft_length` is smaller than the
corresponding dimension of `input`, the dimension is cropped. If it is larger,
the dimension is padded with zeros.
}];
let arguments = (ins
TFL_FpTensor:$input,
TFL_I32Tensor:$fft_length
);
let results = (outs
TFL_Complex64Tensor:$output
);
}
#endif // TFL_OPS

View File

@ -1966,3 +1966,17 @@ func @segmentsum_i64(%arg0: tensor<3x3xf32>, %arg1: tensor<i64>) -> tensor<*xf32
// CHECK: "tfl.cast"
// CHECK: "tfl.segment_sum"
}
func @rfft2d(%arg0: tensor<10x20x10x30xf32>, %arg1: tensor<2xi32>) -> tensor<10x20x10x30xcomplex<f32>> {
%0 = "tf.RFFT2D"(%arg0, %arg1) : (tensor<10x20x10x30xf32>, tensor<2xi32>) -> tensor<10x20x10x30xcomplex<f32>>
return %0 : tensor<10x20x10x30xcomplex<f32>>
// CHECK-LABEL: rfft2d
// CHECK: "tfl.RFFT2D"(%arg0, %arg1) : (tensor<10x20x10x30xf32>, tensor<2xi32>) -> tensor<10x20x10x30xcomplex<f32>>
}
func @rfft2d_invalid(%arg0: tensor<10x20x10x30xf64>, %arg1: tensor<2xi32>) -> tensor<10x20x10x30xcomplex<f64>> {
%0 = "tf.RFFT2D"(%arg0, %arg1) : (tensor<10x20x10x30xf64>, tensor<2xi32>) -> tensor<10x20x10x30xcomplex<f64>>
return %0 : tensor<10x20x10x30xcomplex<f64>>
// CHECK-LABEL: rfft2d_invalid
// CHECK-NOT: "tfl.RFFT2D"
}

View File

@ -491,3 +491,7 @@ def LegalizeStridedSlice : Pat<
(convertIntAttrTo32Bit $end_mask), (convertIntAttrTo32Bit $ellipsis_mask),
(convertIntAttrTo32Bit $new_axis_mask),
(convertIntAttrTo32Bit $shrink_axis_mask))>;
def LegalizeRfft2d : Pat<
(TF_RFFT2DOp $input, $fft_length),
(TFL_RFFT2dOp $input, $fft_length)>;

View File

@ -359,7 +359,6 @@ def generated_test_models():
"resolve_constant_strided_slice",
"reverse_sequence",
"reverse_v2",
"rfft2d",
"round",
"rsqrt",
"scatter_nd",

View File

@ -158,6 +158,7 @@ typedef enum {
kTfLiteBuiltinCumsum = 128,
kTfLiteBuiltinCallOnce = 129,
kTfLiteBuiltinBroadcastTo = 130,
kTfLiteBuiltinRfft2d = 131,
} TfLiteBuiltinOperator;
#ifdef __cplusplus

View File

@ -823,6 +823,7 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_DENSIFY:
case BuiltinOperator_SEGMENT_SUM:
case BuiltinOperator_BROADCAST_TO:
case BuiltinOperator_RFFT2D:
return kTfLiteOk;
case BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES:
return kTfLiteError;

View File

@ -621,6 +621,7 @@ BUILTIN_KERNEL_SRCS = [
"where.cc",
"while.cc",
"zeros_like.cc",
"rfft2d.cc",
]
BUILTIN_KERNEL_DEPS = [
@ -669,10 +670,12 @@ cc_library(
copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS,
visibility = ["//visibility:private"],
deps = BUILTIN_KERNEL_DEPS + [
"@fft2d",
"@ruy//ruy/profiler:instrumentation",
"//tensorflow/lite/kernels/internal:cppmath",
"//tensorflow/lite:string",
"@farmhash_archive//:farmhash",
"//third_party/fft2d:fft2d_headers",
],
)
@ -713,7 +716,6 @@ cc_library(
"complex_support.cc",
"multinomial.cc",
"random_standard_normal.cc",
"rfft2d.cc",
],
hdrs = ["custom_ops_register.h"],
copts = tflite_copts(),
@ -722,8 +724,6 @@ cc_library(
"//tensorflow/lite/c:common",
"//tensorflow/lite/kernels/internal:tensor",
"//tensorflow/lite/kernels/internal:types",
"//third_party/fft2d:fft2d_headers",
"@fft2d",
"@ruy//ruy/profiler:instrumentation",
],
)
@ -2187,13 +2187,9 @@ cc_test(
size = "small",
srcs = ["rfft2d_test.cc"],
deps = [
":custom_ops",
":test_main",
":test_util",
"//tensorflow/lite:framework",
"//tensorflow/lite/c:common",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/testing:util",
"@com_google_googletest//:gtest",
],
)

View File

@ -118,6 +118,7 @@ TfLiteRegistration* Register_RESIZE_BILINEAR();
TfLiteRegistration* Register_RESIZE_NEAREST_NEIGHBOR();
TfLiteRegistration* Register_REVERSE_SEQUENCE();
TfLiteRegistration* Register_REVERSE_V2();
TfLiteRegistration* Register_RFFT2D();
TfLiteRegistration* Register_RNN();
TfLiteRegistration* Register_ROUND();
TfLiteRegistration* Register_RSQRT();

View File

@ -29,7 +29,6 @@ TfLiteRegistration* Register_IMAG();
TfLiteRegistration* Register_MULTINOMIAL();
TfLiteRegistration* Register_RANDOM_STANDARD_NORMAL();
TfLiteRegistration* Register_REAL();
TfLiteRegistration* Register_RFFT2D();
} // namespace custom
} // namespace ops

View File

@ -309,6 +309,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_CALL_ONCE,
tflite::ops::builtin::Register_CALL_ONCE());
AddBuiltin(BuiltinOperator_RFFT2D, Register_RFFT2D());
AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.

View File

@ -31,7 +31,7 @@ limitations under the License.
namespace tflite {
namespace ops {
namespace custom {
namespace builtin {
namespace rfft2d {
using std::complex;
@ -467,6 +467,6 @@ TfLiteRegistration* Register_RFFT2D() {
return &r;
}
} // namespace custom
} // namespace builtin
} // namespace ops
} // namespace tflite

View File

@ -19,17 +19,12 @@ limitations under the License.
#include <vector>
#include <gtest/gtest.h>
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/custom_ops_register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/testing/util.h"
namespace tflite {
namespace ops {
namespace custom {
TfLiteRegistration* Register_RFFT2D();
namespace builtin {
namespace {
@ -44,8 +39,8 @@ class Rfft2dOpModel : public SingleOpModel {
TensorType output_type = TensorType_COMPLEX64;
output_ = AddOutput({output_type, {}});
const std::vector<uint8_t> custom_option;
SetCustomOp("Rfft2d", custom_option, Register_RFFT2D);
SetBuiltinOp(BuiltinOperator_RFFT2D, BuiltinOptions_Rfft2dOptions,
CreateRfft2dOptions(builder_).Union());
BuildInterpreter({GetShape(input_)});
}
@ -147,6 +142,6 @@ TEST(Rfft2dOpTest, InputDimsGreaterThan2) {
}
} // namespace
} // namespace custom
} // namespace builtin
} // namespace ops
} // namespace tflite

View File

@ -355,7 +355,8 @@ enum BuiltinOperator : int32 {
PLACEHOLDER_FOR_GREATER_OP_CODES = 127,
CUMSUM = 128,
CALL_ONCE = 129,
BROADCAST_TO = 130
BROADCAST_TO = 130,
RFFT2D = 131,
}
@ -464,7 +465,8 @@ union BuiltinOptions {
BatchMatMulOptions,
CumsumOptions,
CallOnceOptions,
BroadcastToOptions
BroadcastToOptions,
Rfft2dOptions,
}
enum Padding : byte { SAME, VALID }
@ -1004,6 +1006,9 @@ table CumsumOptions {
table BroadcastToOptions {
}
table Rfft2dOptions {
}
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {

File diff suppressed because it is too large Load Diff

View File

@ -124,7 +124,6 @@ from tensorflow.lite.testing.op_tests.resize_nearest_neighbor import make_resize
from tensorflow.lite.testing.op_tests.resolve_constant_strided_slice import make_resolve_constant_strided_slice_tests
from tensorflow.lite.testing.op_tests.reverse_sequence import make_reverse_sequence_tests
from tensorflow.lite.testing.op_tests.reverse_v2 import make_reverse_v2_tests
from tensorflow.lite.testing.op_tests.rfft2d import make_rfft2d_tests
from tensorflow.lite.testing.op_tests.round import make_round_tests
from tensorflow.lite.testing.op_tests.scatter_nd import make_scatter_nd_tests
from tensorflow.lite.testing.op_tests.shape import make_shape_tests

View File

@ -1,57 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test configs for rfft2d."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v1 as tf
from tensorflow.lite.testing.zip_test_utils import create_tensor_data
from tensorflow.lite.testing.zip_test_utils import ExtraTocoOptions
from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
from tensorflow.lite.testing.zip_test_utils import register_make_test_function
@register_make_test_function()
def make_rfft2d_tests(options):
"""Make a set of tests to do rfft2d."""
test_parameters = [{
"input_dtype": [tf.float32],
"input_shape": [[8, 8], [3, 8, 8], [3, 1, 16]],
"fft_length": [
None, [4, 4], [4, 8], [8, 4], [8, 8], [8, 16], [16, 8], [16, 16],
[1, 8], [1, 16]
]
}]
def build_graph(parameters):
input_value = tf.compat.v1.placeholder(
dtype=parameters["input_dtype"],
name="input",
shape=parameters["input_shape"])
outs = tf.signal.rfft2d(input_value, fft_length=parameters["fft_length"])
return [input_value], [outs]
def build_inputs(parameters, sess, inputs, outputs):
input_value = create_tensor_data(parameters["input_dtype"],
parameters["input_shape"])
return [input_value], sess.run(
outputs, feed_dict=dict(zip(inputs, [input_value])))
extra_toco_options = ExtraTocoOptions()
extra_toco_options.allow_custom_ops = True
make_zip_of_tests(options, test_parameters, build_graph, build_inputs,
extra_toco_options)

View File

@ -369,8 +369,6 @@ TfLiteDriver::TfLiteDriver(DelegateType delegate_type, bool reference_kernel)
new ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
ops::builtin::BuiltinOpResolver* buildinop_resolver_ =
reinterpret_cast<ops::builtin::BuiltinOpResolver*>(resolver_.get());
buildinop_resolver_->AddCustom("RFFT2D",
tflite::ops::custom::Register_RFFT2D());
tflite::ops::custom::AddHashtableOps(buildinop_resolver_);
}

View File

@ -334,6 +334,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_WHILE, 1}, "1.15.0"},
{{BuiltinOperator_CUMSUM, 1}, "2.4.0"},
{{BuiltinOperator_CALL_ONCE, 1}, kPendingReleaseVersion},
{{BuiltinOperator_RFFT2D, 1}, kPendingReleaseVersion},
});
std::pair<BuiltinOperator, int> version_key = {op_code, op_version};