Promote rfft2d to builtin op and add mlir conversion support.
PiperOrigin-RevId: 344732711 Change-Id: I811c45a03d7c204120f2c9bc491e39b32cc5222b
This commit is contained in:
parent
4d1142b04b
commit
bc295ed3bc
@ -50,6 +50,8 @@
|
||||
* Added support for saved model's session initializer through
|
||||
`TFLiteConverter.from_saved_model`.
|
||||
* Added dynamic range quantization support for the BatchMatMul op.
|
||||
* Add `RFFT2D` as builtin op. (`RFFT2D` also supports `RFFTD`.) Currently
|
||||
only supports float32 input.
|
||||
|
||||
* TF Core:
|
||||
* Corrected higher-order gradients of control flow constructs (`tf.cond`,
|
||||
|
@ -172,6 +172,8 @@ def TFL_FpTensor : TFL_TensorOf<[F32]>;
|
||||
def TFL_I32OrI64Tensor : TFL_TensorOf<[TFL_Int32Or64]>;
|
||||
def TFL_I32Tensor : TFL_TensorOf<[I32]>;
|
||||
def TFL_I64Tensor : TFL_TensorOf<[I64]>;
|
||||
def TFL_Complex64Tensor : TFL_TensorOf<[Complex<F<32>>]>;
|
||||
|
||||
// TODO(jpienaar): Expand to all int types.
|
||||
def TFL_IntTensor : TypeAlias<TFL_I32Tensor, "tensor of any integer type">;
|
||||
|
||||
@ -4481,4 +4483,31 @@ subsequent operation and then be optimized away, however.)
|
||||
);
|
||||
}
|
||||
|
||||
def TFL_RFFT2dOp : TFL_Op<"RFFT2D", [NoSideEffect, NoQuantizableResult]> {
|
||||
let summary = "2D real-valued fast Fourier transform.";
|
||||
|
||||
let description = [{
|
||||
Computes the 2-dimensional discrete Fourier transform of a real-valued signal
|
||||
over the inner-most 2 dimensions of `input`.
|
||||
|
||||
Since the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the
|
||||
`fft_length / 2 + 1` unique components of the FFT for the inner-most dimension
|
||||
of `output`: the zero-frequency term, followed by the `fft_length / 2`
|
||||
positive-frequency terms.
|
||||
|
||||
Along each axis `RFFT2D` is computed on, if `fft_length` is smaller than the
|
||||
corresponding dimension of `input`, the dimension is cropped. If it is larger,
|
||||
the dimension is padded with zeros.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_FpTensor:$input,
|
||||
TFL_I32Tensor:$fft_length
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TFL_Complex64Tensor:$output
|
||||
);
|
||||
}
|
||||
|
||||
#endif // TFL_OPS
|
||||
|
@ -1966,3 +1966,17 @@ func @segmentsum_i64(%arg0: tensor<3x3xf32>, %arg1: tensor<i64>) -> tensor<*xf32
|
||||
// CHECK: "tfl.cast"
|
||||
// CHECK: "tfl.segment_sum"
|
||||
}
|
||||
|
||||
func @rfft2d(%arg0: tensor<10x20x10x30xf32>, %arg1: tensor<2xi32>) -> tensor<10x20x10x30xcomplex<f32>> {
|
||||
%0 = "tf.RFFT2D"(%arg0, %arg1) : (tensor<10x20x10x30xf32>, tensor<2xi32>) -> tensor<10x20x10x30xcomplex<f32>>
|
||||
return %0 : tensor<10x20x10x30xcomplex<f32>>
|
||||
// CHECK-LABEL: rfft2d
|
||||
// CHECK: "tfl.RFFT2D"(%arg0, %arg1) : (tensor<10x20x10x30xf32>, tensor<2xi32>) -> tensor<10x20x10x30xcomplex<f32>>
|
||||
}
|
||||
|
||||
func @rfft2d_invalid(%arg0: tensor<10x20x10x30xf64>, %arg1: tensor<2xi32>) -> tensor<10x20x10x30xcomplex<f64>> {
|
||||
%0 = "tf.RFFT2D"(%arg0, %arg1) : (tensor<10x20x10x30xf64>, tensor<2xi32>) -> tensor<10x20x10x30xcomplex<f64>>
|
||||
return %0 : tensor<10x20x10x30xcomplex<f64>>
|
||||
// CHECK-LABEL: rfft2d_invalid
|
||||
// CHECK-NOT: "tfl.RFFT2D"
|
||||
}
|
||||
|
@ -491,3 +491,7 @@ def LegalizeStridedSlice : Pat<
|
||||
(convertIntAttrTo32Bit $end_mask), (convertIntAttrTo32Bit $ellipsis_mask),
|
||||
(convertIntAttrTo32Bit $new_axis_mask),
|
||||
(convertIntAttrTo32Bit $shrink_axis_mask))>;
|
||||
|
||||
def LegalizeRfft2d : Pat<
|
||||
(TF_RFFT2DOp $input, $fft_length),
|
||||
(TFL_RFFT2dOp $input, $fft_length)>;
|
||||
|
@ -359,7 +359,6 @@ def generated_test_models():
|
||||
"resolve_constant_strided_slice",
|
||||
"reverse_sequence",
|
||||
"reverse_v2",
|
||||
"rfft2d",
|
||||
"round",
|
||||
"rsqrt",
|
||||
"scatter_nd",
|
||||
|
@ -158,6 +158,7 @@ typedef enum {
|
||||
kTfLiteBuiltinCumsum = 128,
|
||||
kTfLiteBuiltinCallOnce = 129,
|
||||
kTfLiteBuiltinBroadcastTo = 130,
|
||||
kTfLiteBuiltinRfft2d = 131,
|
||||
} TfLiteBuiltinOperator;
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
@ -823,6 +823,7 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
|
||||
case BuiltinOperator_DENSIFY:
|
||||
case BuiltinOperator_SEGMENT_SUM:
|
||||
case BuiltinOperator_BROADCAST_TO:
|
||||
case BuiltinOperator_RFFT2D:
|
||||
return kTfLiteOk;
|
||||
case BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES:
|
||||
return kTfLiteError;
|
||||
|
@ -621,6 +621,7 @@ BUILTIN_KERNEL_SRCS = [
|
||||
"where.cc",
|
||||
"while.cc",
|
||||
"zeros_like.cc",
|
||||
"rfft2d.cc",
|
||||
]
|
||||
|
||||
BUILTIN_KERNEL_DEPS = [
|
||||
@ -669,10 +670,12 @@ cc_library(
|
||||
copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS,
|
||||
visibility = ["//visibility:private"],
|
||||
deps = BUILTIN_KERNEL_DEPS + [
|
||||
"@fft2d",
|
||||
"@ruy//ruy/profiler:instrumentation",
|
||||
"//tensorflow/lite/kernels/internal:cppmath",
|
||||
"//tensorflow/lite:string",
|
||||
"@farmhash_archive//:farmhash",
|
||||
"//third_party/fft2d:fft2d_headers",
|
||||
],
|
||||
)
|
||||
|
||||
@ -713,7 +716,6 @@ cc_library(
|
||||
"complex_support.cc",
|
||||
"multinomial.cc",
|
||||
"random_standard_normal.cc",
|
||||
"rfft2d.cc",
|
||||
],
|
||||
hdrs = ["custom_ops_register.h"],
|
||||
copts = tflite_copts(),
|
||||
@ -722,8 +724,6 @@ cc_library(
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/kernels/internal:tensor",
|
||||
"//tensorflow/lite/kernels/internal:types",
|
||||
"//third_party/fft2d:fft2d_headers",
|
||||
"@fft2d",
|
||||
"@ruy//ruy/profiler:instrumentation",
|
||||
],
|
||||
)
|
||||
@ -2187,13 +2187,9 @@ cc_test(
|
||||
size = "small",
|
||||
srcs = ["rfft2d_test.cc"],
|
||||
deps = [
|
||||
":custom_ops",
|
||||
":test_main",
|
||||
":test_util",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"//tensorflow/lite/testing:util",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
@ -118,6 +118,7 @@ TfLiteRegistration* Register_RESIZE_BILINEAR();
|
||||
TfLiteRegistration* Register_RESIZE_NEAREST_NEIGHBOR();
|
||||
TfLiteRegistration* Register_REVERSE_SEQUENCE();
|
||||
TfLiteRegistration* Register_REVERSE_V2();
|
||||
TfLiteRegistration* Register_RFFT2D();
|
||||
TfLiteRegistration* Register_RNN();
|
||||
TfLiteRegistration* Register_ROUND();
|
||||
TfLiteRegistration* Register_RSQRT();
|
||||
|
@ -29,7 +29,6 @@ TfLiteRegistration* Register_IMAG();
|
||||
TfLiteRegistration* Register_MULTINOMIAL();
|
||||
TfLiteRegistration* Register_RANDOM_STANDARD_NORMAL();
|
||||
TfLiteRegistration* Register_REAL();
|
||||
TfLiteRegistration* Register_RFFT2D();
|
||||
|
||||
} // namespace custom
|
||||
} // namespace ops
|
||||
|
@ -309,6 +309,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
/* max_version = */ 2);
|
||||
AddBuiltin(BuiltinOperator_CALL_ONCE,
|
||||
tflite::ops::builtin::Register_CALL_ONCE());
|
||||
AddBuiltin(BuiltinOperator_RFFT2D, Register_RFFT2D());
|
||||
AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());
|
||||
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
|
||||
// custom ops aren't always included by default.
|
||||
|
@ -31,7 +31,7 @@ limitations under the License.
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace custom {
|
||||
namespace builtin {
|
||||
namespace rfft2d {
|
||||
|
||||
using std::complex;
|
||||
@ -467,6 +467,6 @@ TfLiteRegistration* Register_RFFT2D() {
|
||||
return &r;
|
||||
}
|
||||
|
||||
} // namespace custom
|
||||
} // namespace builtin
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
||||
|
@ -19,17 +19,12 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/custom_ops_register.h"
|
||||
#include "tensorflow/lite/kernels/test_util.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/testing/util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace custom {
|
||||
|
||||
TfLiteRegistration* Register_RFFT2D();
|
||||
namespace builtin {
|
||||
|
||||
namespace {
|
||||
|
||||
@ -44,8 +39,8 @@ class Rfft2dOpModel : public SingleOpModel {
|
||||
TensorType output_type = TensorType_COMPLEX64;
|
||||
output_ = AddOutput({output_type, {}});
|
||||
|
||||
const std::vector<uint8_t> custom_option;
|
||||
SetCustomOp("Rfft2d", custom_option, Register_RFFT2D);
|
||||
SetBuiltinOp(BuiltinOperator_RFFT2D, BuiltinOptions_Rfft2dOptions,
|
||||
CreateRfft2dOptions(builder_).Union());
|
||||
BuildInterpreter({GetShape(input_)});
|
||||
}
|
||||
|
||||
@ -147,6 +142,6 @@ TEST(Rfft2dOpTest, InputDimsGreaterThan2) {
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace custom
|
||||
} // namespace builtin
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
||||
|
@ -355,7 +355,8 @@ enum BuiltinOperator : int32 {
|
||||
PLACEHOLDER_FOR_GREATER_OP_CODES = 127,
|
||||
CUMSUM = 128,
|
||||
CALL_ONCE = 129,
|
||||
BROADCAST_TO = 130
|
||||
BROADCAST_TO = 130,
|
||||
RFFT2D = 131,
|
||||
}
|
||||
|
||||
|
||||
@ -464,7 +465,8 @@ union BuiltinOptions {
|
||||
BatchMatMulOptions,
|
||||
CumsumOptions,
|
||||
CallOnceOptions,
|
||||
BroadcastToOptions
|
||||
BroadcastToOptions,
|
||||
Rfft2dOptions,
|
||||
}
|
||||
|
||||
enum Padding : byte { SAME, VALID }
|
||||
@ -1004,6 +1006,9 @@ table CumsumOptions {
|
||||
table BroadcastToOptions {
|
||||
}
|
||||
|
||||
table Rfft2dOptions {
|
||||
}
|
||||
|
||||
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
|
||||
// builtin, or a string if the operator is custom.
|
||||
table OperatorCode {
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -124,7 +124,6 @@ from tensorflow.lite.testing.op_tests.resize_nearest_neighbor import make_resize
|
||||
from tensorflow.lite.testing.op_tests.resolve_constant_strided_slice import make_resolve_constant_strided_slice_tests
|
||||
from tensorflow.lite.testing.op_tests.reverse_sequence import make_reverse_sequence_tests
|
||||
from tensorflow.lite.testing.op_tests.reverse_v2 import make_reverse_v2_tests
|
||||
from tensorflow.lite.testing.op_tests.rfft2d import make_rfft2d_tests
|
||||
from tensorflow.lite.testing.op_tests.round import make_round_tests
|
||||
from tensorflow.lite.testing.op_tests.scatter_nd import make_scatter_nd_tests
|
||||
from tensorflow.lite.testing.op_tests.shape import make_shape_tests
|
||||
|
@ -1,57 +0,0 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Test configs for rfft2d."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
from tensorflow.lite.testing.zip_test_utils import create_tensor_data
|
||||
from tensorflow.lite.testing.zip_test_utils import ExtraTocoOptions
|
||||
from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
|
||||
from tensorflow.lite.testing.zip_test_utils import register_make_test_function
|
||||
|
||||
|
||||
@register_make_test_function()
|
||||
def make_rfft2d_tests(options):
|
||||
"""Make a set of tests to do rfft2d."""
|
||||
|
||||
test_parameters = [{
|
||||
"input_dtype": [tf.float32],
|
||||
"input_shape": [[8, 8], [3, 8, 8], [3, 1, 16]],
|
||||
"fft_length": [
|
||||
None, [4, 4], [4, 8], [8, 4], [8, 8], [8, 16], [16, 8], [16, 16],
|
||||
[1, 8], [1, 16]
|
||||
]
|
||||
}]
|
||||
|
||||
def build_graph(parameters):
|
||||
input_value = tf.compat.v1.placeholder(
|
||||
dtype=parameters["input_dtype"],
|
||||
name="input",
|
||||
shape=parameters["input_shape"])
|
||||
outs = tf.signal.rfft2d(input_value, fft_length=parameters["fft_length"])
|
||||
return [input_value], [outs]
|
||||
|
||||
def build_inputs(parameters, sess, inputs, outputs):
|
||||
input_value = create_tensor_data(parameters["input_dtype"],
|
||||
parameters["input_shape"])
|
||||
return [input_value], sess.run(
|
||||
outputs, feed_dict=dict(zip(inputs, [input_value])))
|
||||
|
||||
extra_toco_options = ExtraTocoOptions()
|
||||
extra_toco_options.allow_custom_ops = True
|
||||
make_zip_of_tests(options, test_parameters, build_graph, build_inputs,
|
||||
extra_toco_options)
|
@ -369,8 +369,6 @@ TfLiteDriver::TfLiteDriver(DelegateType delegate_type, bool reference_kernel)
|
||||
new ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
|
||||
ops::builtin::BuiltinOpResolver* buildinop_resolver_ =
|
||||
reinterpret_cast<ops::builtin::BuiltinOpResolver*>(resolver_.get());
|
||||
buildinop_resolver_->AddCustom("RFFT2D",
|
||||
tflite::ops::custom::Register_RFFT2D());
|
||||
tflite::ops::custom::AddHashtableOps(buildinop_resolver_);
|
||||
}
|
||||
|
||||
|
@ -334,6 +334,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
|
||||
{{BuiltinOperator_WHILE, 1}, "1.15.0"},
|
||||
{{BuiltinOperator_CUMSUM, 1}, "2.4.0"},
|
||||
{{BuiltinOperator_CALL_ONCE, 1}, kPendingReleaseVersion},
|
||||
{{BuiltinOperator_RFFT2D, 1}, kPendingReleaseVersion},
|
||||
});
|
||||
|
||||
std::pair<BuiltinOperator, int> version_key = {op_code, op_version};
|
||||
|
Loading…
Reference in New Issue
Block a user