Expose microfrontend ops in Lite
PiperOrigin-RevId: 256410602
This commit is contained in:
parent
aacf4909bc
commit
6c366272cf
@ -18,14 +18,16 @@ cc_library(
|
|||||||
cc_library(
|
cc_library(
|
||||||
name = "fft",
|
name = "fft",
|
||||||
srcs = [
|
srcs = [
|
||||||
"fft.c",
|
"fft.cc",
|
||||||
"fft_util.c",
|
"fft_util.cc",
|
||||||
],
|
],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"fft.h",
|
"fft.h",
|
||||||
"fft_util.h",
|
"fft_util.h",
|
||||||
],
|
],
|
||||||
deps = ["@kissfft//:kiss_fftr_16"],
|
deps = [
|
||||||
|
"@kissfft//:kiss_fftr_16",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
|
@ -18,10 +18,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#define FIXED_POINT 16
|
#define FIXED_POINT 16
|
||||||
#include "kiss_fft.h"
|
#include "kiss_fft.h"
|
||||||
// Internal test dependency placeholder1
|
|
||||||
// Internal test dependency placeholder2
|
|
||||||
#include "tools/kiss_fftr.h"
|
#include "tools/kiss_fftr.h"
|
||||||
// Internal test dependency placeholder3
|
|
||||||
|
|
||||||
void FftCompute(struct FftState* state, const int16_t* input,
|
void FftCompute(struct FftState* state, const int16_t* input,
|
||||||
int input_scale_shift) {
|
int input_scale_shift) {
|
||||||
@ -40,8 +37,10 @@ void FftCompute(struct FftState* state, const int16_t* input,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Apply the FFT.
|
// Apply the FFT.
|
||||||
kiss_fftr((const kiss_fftr_cfg)state->scratch, state->input,
|
kiss_fftr(
|
||||||
(kiss_fft_cpx*)state->output);
|
reinterpret_cast<const kiss_fftr_cfg>(state->scratch),
|
||||||
|
state->input,
|
||||||
|
reinterpret_cast<kiss_fft_cpx*>(state->output));
|
||||||
}
|
}
|
||||||
|
|
||||||
void FftInit(struct FftState* state) {
|
void FftInit(struct FftState* state) {
|
@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#include "tensorflow/lite/experimental/microfrontend/lib/fft.h"
|
#include "tensorflow/lite/experimental/microfrontend/lib/fft.h"
|
||||||
#include "tensorflow/lite/experimental/microfrontend/lib/fft_util.h"
|
|
||||||
|
|
||||||
#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
|
#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
|
||||||
|
#include "tensorflow/lite/experimental/microfrontend/lib/fft_util.h"
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
@ -42,7 +42,7 @@ TF_LITE_MICRO_TEST(FftTest_CheckOutputValues) {
|
|||||||
{0, 25}, {9, -10}, {19, 0}, {9, 9}, {0, 0}};
|
{0, 25}, {9, -10}, {19, 0}, {9, 9}, {0, 0}};
|
||||||
TF_LITE_MICRO_EXPECT_EQ(state.fft_size / 2 + 1,
|
TF_LITE_MICRO_EXPECT_EQ(state.fft_size / 2 + 1,
|
||||||
sizeof(expected) / sizeof(expected[0]));
|
sizeof(expected) / sizeof(expected[0]));
|
||||||
int i;
|
unsigned int i;
|
||||||
for (i = 0; i <= state.fft_size / 2; ++i) {
|
for (i = 0; i <= state.fft_size / 2; ++i) {
|
||||||
TF_LITE_MICRO_EXPECT_EQ(state.output[i].real, expected[i].real);
|
TF_LITE_MICRO_EXPECT_EQ(state.output[i].real, expected[i].real);
|
||||||
TF_LITE_MICRO_EXPECT_EQ(state.output[i].imag, expected[i].imag);
|
TF_LITE_MICRO_EXPECT_EQ(state.output[i].imag, expected[i].imag);
|
||||||
|
@ -27,35 +27,37 @@ int FftPopulateState(struct FftState* state, size_t input_size) {
|
|||||||
state->fft_size <<= 1;
|
state->fft_size <<= 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
state->input = malloc(state->fft_size * sizeof(*state->input));
|
state->input = reinterpret_cast<int16_t*>(
|
||||||
if (state->input == NULL) {
|
malloc(state->fft_size * sizeof(*state->input)));
|
||||||
|
if (state->input == nullptr) {
|
||||||
fprintf(stderr, "Failed to alloc fft input buffer\n");
|
fprintf(stderr, "Failed to alloc fft input buffer\n");
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
state->output =
|
state->output = reinterpret_cast<complex_int16_t*>(
|
||||||
malloc((state->fft_size / 2 + 1) * sizeof(*state->output) * 2);
|
malloc((state->fft_size / 2 + 1) * sizeof(*state->output) * 2));
|
||||||
if (state->output == NULL) {
|
if (state->output == nullptr) {
|
||||||
fprintf(stderr, "Failed to alloc fft output buffer\n");
|
fprintf(stderr, "Failed to alloc fft output buffer\n");
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ask kissfft how much memory it wants.
|
// Ask kissfft how much memory it wants.
|
||||||
size_t scratch_size = 0;
|
size_t scratch_size = 0;
|
||||||
kiss_fftr_cfg kfft_cfg =
|
kiss_fftr_cfg kfft_cfg = kiss_fftr_alloc(
|
||||||
kiss_fftr_alloc(state->fft_size, 0, NULL, &scratch_size);
|
state->fft_size, 0, nullptr, &scratch_size);
|
||||||
if (kfft_cfg != NULL) {
|
if (kfft_cfg != nullptr) {
|
||||||
fprintf(stderr, "Kiss memory sizing failed.\n");
|
fprintf(stderr, "Kiss memory sizing failed.\n");
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
state->scratch = malloc(scratch_size);
|
state->scratch = malloc(scratch_size);
|
||||||
if (state->scratch == NULL) {
|
if (state->scratch == nullptr) {
|
||||||
fprintf(stderr, "Failed to alloc fft scratch buffer\n");
|
fprintf(stderr, "Failed to alloc fft scratch buffer\n");
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
state->scratch_size = scratch_size;
|
state->scratch_size = scratch_size;
|
||||||
// Let kissfft configure the scratch space we just allocated
|
// Let kissfft configure the scratch space we just allocated
|
||||||
kfft_cfg = kiss_fftr_alloc(state->fft_size, 0, state->scratch, &scratch_size);
|
kfft_cfg = kiss_fftr_alloc(state->fft_size, 0,
|
||||||
|
state->scratch, &scratch_size);
|
||||||
if (kfft_cfg != state->scratch) {
|
if (kfft_cfg != state->scratch) {
|
||||||
fprintf(stderr, "Kiss memory preallocation strategy failed.\n");
|
fprintf(stderr, "Kiss memory preallocation strategy failed.\n");
|
||||||
return 0;
|
return 0;
|
@ -18,14 +18,15 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import tensorflow as tf
|
|
||||||
|
|
||||||
from tensorflow.lite.experimental.microfrontend.ops import gen_audio_microfrontend_op
|
from tensorflow.lite.experimental.microfrontend.ops import gen_audio_microfrontend_op
|
||||||
from tensorflow.contrib.util import loader
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import load_library
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.platform import resource_loader
|
from tensorflow.python.platform import resource_loader
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
_audio_microfrontend_op = loader.load_op_library(
|
_audio_microfrontend_op = load_library.load_op_library(
|
||||||
resource_loader.get_path_to_datafile("_audio_microfrontend_op.so"))
|
resource_loader.get_path_to_datafile("_audio_microfrontend_op.so"))
|
||||||
|
|
||||||
|
|
||||||
@ -52,7 +53,7 @@ def audio_microfrontend(audio,
|
|||||||
frame_stride=1,
|
frame_stride=1,
|
||||||
zero_padding=False,
|
zero_padding=False,
|
||||||
out_scale=1,
|
out_scale=1,
|
||||||
out_type=tf.uint16):
|
out_type=dtypes.uint16):
|
||||||
"""Audio Microfrontend Op.
|
"""Audio Microfrontend Op.
|
||||||
|
|
||||||
This Op converts a sequence of audio data into one or more
|
This Op converts a sequence of audio data into one or more
|
||||||
@ -102,7 +103,7 @@ def audio_microfrontend(audio,
|
|||||||
if audio_shape.ndims is None:
|
if audio_shape.ndims is None:
|
||||||
raise ValueError("Input to `AudioMicrofrontend` should have known rank.")
|
raise ValueError("Input to `AudioMicrofrontend` should have known rank.")
|
||||||
if len(audio_shape) > 1:
|
if len(audio_shape) > 1:
|
||||||
audio = tf.reshape(audio, [-1])
|
audio = array_ops.reshape(audio, [-1])
|
||||||
|
|
||||||
return gen_audio_microfrontend_op.audio_microfrontend(
|
return gen_audio_microfrontend_op.audio_microfrontend(
|
||||||
audio, sample_rate, window_size, window_step, num_channels,
|
audio, sample_rate, window_size, window_step, num_channels,
|
||||||
@ -112,4 +113,4 @@ def audio_microfrontend(audio,
|
|||||||
right_context, frame_stride, zero_padding, out_scale, out_type)
|
right_context, frame_stride, zero_padding, out_scale, out_type)
|
||||||
|
|
||||||
|
|
||||||
tf.NotDifferentiable("AudioMicrofrontend")
|
ops.NotDifferentiable("AudioMicrofrontend")
|
||||||
|
@ -79,6 +79,7 @@ py_library(
|
|||||||
":op_hint",
|
":op_hint",
|
||||||
":util",
|
":util",
|
||||||
"//tensorflow/lite/experimental/examples/lstm:tflite_lstm_ops",
|
"//tensorflow/lite/experimental/examples/lstm:tflite_lstm_ops",
|
||||||
|
"//tensorflow/lite/experimental/microfrontend:audio_microfrontend_py",
|
||||||
"//tensorflow/lite/experimental/tensorboard:ops_util",
|
"//tensorflow/lite/experimental/tensorboard:ops_util",
|
||||||
"//tensorflow/lite/python/optimize:calibrator",
|
"//tensorflow/lite/python/optimize:calibrator",
|
||||||
"//tensorflow/python:graph_util",
|
"//tensorflow/python:graph_util",
|
||||||
|
@ -28,6 +28,7 @@ from tensorflow.core.framework import graph_pb2 as _graph_pb2
|
|||||||
from tensorflow.lite.experimental.examples.lstm.rnn import dynamic_rnn # pylint: disable=unused-import
|
from tensorflow.lite.experimental.examples.lstm.rnn import dynamic_rnn # pylint: disable=unused-import
|
||||||
from tensorflow.lite.experimental.examples.lstm.rnn_cell import TFLiteLSTMCell # pylint: disable=unused-import
|
from tensorflow.lite.experimental.examples.lstm.rnn_cell import TFLiteLSTMCell # pylint: disable=unused-import
|
||||||
from tensorflow.lite.experimental.examples.lstm.rnn_cell import TfLiteRNNCell # pylint: disable=unused-import
|
from tensorflow.lite.experimental.examples.lstm.rnn_cell import TfLiteRNNCell # pylint: disable=unused-import
|
||||||
|
from tensorflow.lite.experimental.microfrontend.python.ops import audio_microfrontend_op # pylint: disable=unused-import
|
||||||
from tensorflow.lite.experimental.tensorboard.ops_util import get_potentially_supported_ops # pylint: disable=unused-import
|
from tensorflow.lite.experimental.tensorboard.ops_util import get_potentially_supported_ops # pylint: disable=unused-import
|
||||||
from tensorflow.lite.python import lite_constants as constants
|
from tensorflow.lite.python import lite_constants as constants
|
||||||
from tensorflow.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import
|
from tensorflow.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import
|
||||||
|
@ -31,6 +31,9 @@ TENSORFLOW_API_INIT_FILES = [
|
|||||||
"linalg/__init__.py",
|
"linalg/__init__.py",
|
||||||
"lite/__init__.py",
|
"lite/__init__.py",
|
||||||
"lite/experimental/__init__.py",
|
"lite/experimental/__init__.py",
|
||||||
|
"lite/experimental/microfrontend/__init__.py",
|
||||||
|
"lite/experimental/microfrontend/python/__init__.py",
|
||||||
|
"lite/experimental/microfrontend/python/ops/__init__.py",
|
||||||
"lookup/__init__.py",
|
"lookup/__init__.py",
|
||||||
"lookup/experimental/__init__.py",
|
"lookup/experimental/__init__.py",
|
||||||
"math/__init__.py",
|
"math/__init__.py",
|
||||||
|
@ -39,6 +39,9 @@ TENSORFLOW_API_INIT_FILES_V1 = [
|
|||||||
"lite/constants/__init__.py",
|
"lite/constants/__init__.py",
|
||||||
"lite/experimental/__init__.py",
|
"lite/experimental/__init__.py",
|
||||||
"lite/experimental/nn/__init__.py",
|
"lite/experimental/nn/__init__.py",
|
||||||
|
"lite/experimental/microfrontend/__init__.py",
|
||||||
|
"lite/experimental/microfrontend/python/__init__.py",
|
||||||
|
"lite/experimental/microfrontend/python/ops/__init__.py",
|
||||||
"logging/__init__.py",
|
"logging/__init__.py",
|
||||||
"lookup/__init__.py",
|
"lookup/__init__.py",
|
||||||
"lookup/experimental/__init__.py",
|
"lookup/experimental/__init__.py",
|
||||||
|
@ -179,6 +179,7 @@ filegroup(
|
|||||||
"@icu//:icu4c/LICENSE",
|
"@icu//:icu4c/LICENSE",
|
||||||
"@jpeg//:LICENSE.md",
|
"@jpeg//:LICENSE.md",
|
||||||
"@keras_applications_archive//:LICENSE",
|
"@keras_applications_archive//:LICENSE",
|
||||||
|
"@kissfft//:LICENSE",
|
||||||
"@lmdb//:LICENSE",
|
"@lmdb//:LICENSE",
|
||||||
"@local_config_sycl//sycl:LICENSE.text",
|
"@local_config_sycl//sycl:LICENSE.text",
|
||||||
"@local_config_tensorrt//:LICENSE",
|
"@local_config_tensorrt//:LICENSE",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user