Add Mfcc op to TensorFlow for speech feature generation
Change: 153847440
This commit is contained in:
parent
7c8fffaf5d
commit
fa8f9da8f2
@ -3692,11 +3692,130 @@ tf_cuda_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
# Library target for the MfccDct class (mfcc_dct.cc / mfcc_dct.h).
cc_library(
    name = "mfcc_dct",
    srcs = ["mfcc_dct.cc"],
    hdrs = ["mfcc_dct.h"],
    copts = tf_copts(),
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
    ],
)
|
||||
|
||||
# Unit test target for the :mfcc_dct library.
tf_cc_test(
    name = "mfcc_dct_test",
    size = "small",
    srcs = ["mfcc_dct_test.cc"],
    deps = [
        ":mfcc_dct",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:lib_test_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//third_party/eigen3",
    ],
)
|
||||
|
||||
# Library target for the MfccMelFilterbank class
# (mfcc_mel_filterbank.cc / mfcc_mel_filterbank.h).
cc_library(
    name = "mfcc_mel_filterbank",
    srcs = ["mfcc_mel_filterbank.cc"],
    hdrs = ["mfcc_mel_filterbank.h"],
    copts = tf_copts(),
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
    ],
)
|
||||
|
||||
# Unit test target for the :mfcc_mel_filterbank library.
tf_cc_test(
    name = "mfcc_mel_filterbank_test",
    size = "small",
    srcs = ["mfcc_mel_filterbank_test.cc"],
    deps = [
        ":mfcc_mel_filterbank",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:lib_test_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//third_party/eigen3",
    ],
)
|
||||
|
||||
# Library target for the Mfcc class, which is built on top of :mfcc_dct and
# :mfcc_mel_filterbank.
cc_library(
    name = "mfcc",
    srcs = ["mfcc.cc"],
    hdrs = ["mfcc.h"],
    copts = tf_copts(),
    deps = [
        ":mfcc_dct",
        ":mfcc_mel_filterbank",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
    ],
)
|
||||
|
||||
# Unit test target for the :mfcc library.
tf_cc_test(
    name = "mfcc_test",
    size = "small",
    srcs = ["mfcc_test.cc"],
    deps = [
        ":mfcc",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:lib_test_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//third_party/eigen3",
    ],
)
|
||||
|
||||
# Kernel library for the Mfcc op; sources are selected by the "mfcc_op" prefix.
# alwayslink = 1 keeps the object code in the final binary even when nothing
# references its symbols directly.
tf_kernel_library(
    name = "mfcc_op",
    prefix = "mfcc_op",
    deps = [
        ":mfcc",
        "//tensorflow/core:audio_ops_op_lib",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
    ],
    alwayslink = 1,
)
|
||||
|
||||
# Test target for the :mfcc_op kernel library.
tf_cuda_cc_test(
    name = "mfcc_op_test",
    size = "small",
    srcs = ["mfcc_op_test.cc"],
    deps = [
        ":mfcc_op",
        ":ops_util",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:client_session",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:tensorflow",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
    ],
)
|
||||
|
||||
# Umbrella target that pulls in the audio-related op kernel libraries.
cc_library(
    name = "audio",
    deps = [
        ":decode_wav_op",
        ":encode_wav_op",
        ":mfcc_op",
        ":spectrogram_op",
    ],
)
|
||||
|
67
tensorflow/core/kernels/mfcc.cc
Normal file
67
tensorflow/core/kernels/mfcc.cc
Normal file
@ -0,0 +1,67 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <math.h>
|
||||
|
||||
#include "tensorflow/core/kernels/mfcc.h"
|
||||
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Default settings that the Mfcc constructor copies into a new instance.
const double kDefaultLowerFrequencyLimit = 20.0;
const double kDefaultUpperFrequencyLimit = 4000.0;
const int kDefaultFilterbankChannelCount = 40;
const int kDefaultDCTCoefficientCount = 13;

// Smallest filterbank value passed to log() in Mfcc::Compute; anything below
// it is raised to this floor first.
const double kFilterbankFloor = 1e-12;
|
||||
|
||||
// Builds an uninitialized Mfcc that carries the default frequency limits,
// filterbank channel count and DCT coefficient count. Initialize() has to
// succeed before Compute() will produce any output.
Mfcc::Mfcc()
    : initialized_(false),
      lower_frequency_limit_(kDefaultLowerFrequencyLimit),
      upper_frequency_limit_(kDefaultUpperFrequencyLimit),
      filterbank_channel_count_(kDefaultFilterbankChannelCount),
      dct_coefficient_count_(kDefaultDCTCoefficientCount) {}
|
||||
|
||||
bool Mfcc::Initialize(int input_length,
|
||||
double input_sample_rate) {
|
||||
bool initialized = mel_filterbank_.Initialize(input_length,
|
||||
input_sample_rate,
|
||||
filterbank_channel_count_,
|
||||
lower_frequency_limit_,
|
||||
upper_frequency_limit_);
|
||||
initialized &= dct_.Initialize(filterbank_channel_count_,
|
||||
dct_coefficient_count_);
|
||||
initialized_ = initialized;
|
||||
return initialized;
|
||||
}
|
||||
|
||||
void Mfcc::Compute(const std::vector<double>& spectrogram_frame,
|
||||
std::vector<double>* output) const {
|
||||
if (!initialized_) {
|
||||
LOG(ERROR) << "Mfcc not initialized.";
|
||||
return;
|
||||
}
|
||||
std::vector<double> working;
|
||||
mel_filterbank_.Compute(spectrogram_frame, &working);
|
||||
for (int i = 0; i < working.size(); ++i) {
|
||||
double val = working[i];
|
||||
if (val < kFilterbankFloor) {
|
||||
val = kFilterbankFloor;
|
||||
}
|
||||
working[i] = log(val);
|
||||
}
|
||||
dct_.Compute(working, output);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
76
tensorflow/core/kernels/mfcc.h
Normal file
76
tensorflow/core/kernels/mfcc.h
Normal file
@ -0,0 +1,76 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Basic class for computing MFCCs from spectrogram slices.
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/kernels/mfcc_dct.h"
|
||||
#include "tensorflow/core/kernels/mfcc_mel_filterbank.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class Mfcc {
|
||||
public:
|
||||
Mfcc();
|
||||
bool Initialize(int input_length,
|
||||
double input_sample_rate);
|
||||
|
||||
// Input is a single magnitude spectrogram frame. The input spectrum
|
||||
// is filtered into bands using a triangular mel filterbank and a
|
||||
// discrete cosine transform (DCT) of the values is taken. Output is
|
||||
// populated with the lowest dct_coefficient_count of these values.
|
||||
void Compute(const std::vector<double>& spectrogram_frame,
|
||||
std::vector<double>* output) const;
|
||||
|
||||
void set_upper_frequency_limit(double upper_frequency_limit) {
|
||||
CHECK(!initialized_) << "Set frequency limits before calling Initialize.";
|
||||
upper_frequency_limit_ = upper_frequency_limit;
|
||||
}
|
||||
|
||||
void set_lower_frequency_limit(double lower_frequency_limit) {
|
||||
CHECK(!initialized_) << "Set frequency limits before calling Initialize.";
|
||||
lower_frequency_limit_ = lower_frequency_limit;
|
||||
}
|
||||
|
||||
void set_filterbank_channel_count(int filterbank_channel_count) {
|
||||
CHECK(!initialized_) << "Set channel count before calling Initialize.";
|
||||
filterbank_channel_count_ = filterbank_channel_count;
|
||||
}
|
||||
|
||||
void set_dct_coefficient_count(int dct_coefficient_count) {
|
||||
CHECK(!initialized_) << "Set coefficient count before calling Initialize.";
|
||||
dct_coefficient_count_ = dct_coefficient_count;
|
||||
}
|
||||
|
||||
private:
|
||||
MfccMelFilterbank mel_filterbank_;
|
||||
MfccDct dct_;
|
||||
bool initialized_;
|
||||
double lower_frequency_limit_;
|
||||
double upper_frequency_limit_;
|
||||
int filterbank_channel_count_;
|
||||
int dct_coefficient_count_;
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(Mfcc);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_H_
|
82
tensorflow/core/kernels/mfcc_dct.cc
Normal file
82
tensorflow/core/kernels/mfcc_dct.cc
Normal file
@ -0,0 +1,82 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/mfcc_dct.h"
|
||||
|
||||
#include <math.h>
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Creates an uninitialized DCT. The size fields are zeroed so that no member
// holds an indeterminate value before Initialize() is called.
MfccDct::MfccDct()
    : initialized_(false), coefficient_count_(0), input_length_(0) {}
|
||||
|
||||
bool MfccDct::Initialize(int input_length, int coefficient_count) {
|
||||
coefficient_count_ = coefficient_count;
|
||||
input_length_ = input_length;
|
||||
|
||||
if (coefficient_count_ < 1) {
|
||||
LOG(ERROR) << "Coefficient count must be positive.";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (input_length < 1) {
|
||||
LOG(ERROR) << "Input length must be positive.";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (coefficient_count_ > input_length_) {
|
||||
LOG(ERROR) << "Coefficient count must be less than or equal to "
|
||||
<< "input length.";
|
||||
return false;
|
||||
}
|
||||
|
||||
cosines_.resize(coefficient_count_);
|
||||
double fnorm = sqrt(2.0 / input_length_);
|
||||
// Some platforms don't have M_PI, so define a local constant here.
|
||||
const double pi = std::atan(1) * 4;
|
||||
double arg = pi / input_length_;
|
||||
for (int i = 0; i < coefficient_count_; ++i) {
|
||||
cosines_[i].resize(input_length_);
|
||||
for (int j = 0; j < input_length_; ++j) {
|
||||
cosines_[i][j] = fnorm * cos(i * arg * (j + 0.5));
|
||||
}
|
||||
}
|
||||
initialized_ = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
void MfccDct::Compute(const std::vector<double> &input,
|
||||
std::vector<double> *output) const {
|
||||
if (!initialized_) {
|
||||
LOG(ERROR) << "DCT not initialized.";
|
||||
return;
|
||||
}
|
||||
|
||||
output->resize(coefficient_count_);
|
||||
int length = input.size();
|
||||
if (length > input_length_) {
|
||||
length = input_length_;
|
||||
}
|
||||
|
||||
for (int i = 0; i < coefficient_count_; ++i) {
|
||||
double sum = 0.0;
|
||||
for (int j = 0; j < length; ++j) {
|
||||
sum += cosines_[i][j] * input[j];
|
||||
}
|
||||
(*output)[i] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
44
tensorflow/core/kernels/mfcc_dct.h
Normal file
44
tensorflow/core/kernels/mfcc_dct.h
Normal file
@ -0,0 +1,44 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Basic minimal DCT class for MFCC speech processing.
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_DCT_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_DCT_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class MfccDct {
|
||||
public:
|
||||
MfccDct();
|
||||
bool Initialize(int input_length, int coefficient_count);
|
||||
void Compute(const std::vector<double>& input,
|
||||
std::vector<double>* output) const;
|
||||
|
||||
private:
|
||||
bool initialized_;
|
||||
int coefficient_count_;
|
||||
int input_length_;
|
||||
std::vector<std::vector<double> > cosines_;
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(MfccDct);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_DCT_H_
|
55
tensorflow/core/kernels/mfcc_dct_test.cc
Normal file
55
tensorflow/core/kernels/mfcc_dct_test.cc
Normal file
@ -0,0 +1,55 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/mfcc_dct.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Checks the DCT output against reference values from MATLAB's dct function.
TEST(MfccDctTest, AgreesWithMatlab) {
  const int kCoefficientCount = 6;
  const std::vector<double> input = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};

  MfccDct dct;
  ASSERT_TRUE(dct.Initialize(input.size(), kCoefficientCount));

  std::vector<double> output;
  dct.Compute(input, &output);

  // MATLAB's dct divides the first coefficient by sqrt(2) and this
  // implementation does not, so the first reference value below is the
  // MATLAB result multiplied by sqrt(2).
  const std::vector<double> expected = {12.1243556530, -4.1625617959,
                                        0.0,           -0.4082482905,
                                        0.0,           -0.0800788912};

  ASSERT_EQ(output.size(), kCoefficientCount);
  for (int i = 0; i < kCoefficientCount; ++i) {
    EXPECT_NEAR(output[i], expected[i], 1e-10);
  }
}
|
||||
|
||||
// Each invalid configuration is tried on its own fresh object so that one
// rejected call cannot influence the next. (Previously dct2..dct4 were
// declared but every call went through dct1.)
TEST(MfccDctTest, InitializeFailsOnInvalidInput) {
  // Negative input length.
  MfccDct dct1;
  EXPECT_FALSE(dct1.Initialize(-50, 1));
  // Negative coefficient count.
  MfccDct dct2;
  EXPECT_FALSE(dct2.Initialize(10, -4));
  // Both arguments negative.
  MfccDct dct3;
  EXPECT_FALSE(dct3.Initialize(-1, -1));
  // More coefficients requested than input samples.
  MfccDct dct4;
  EXPECT_FALSE(dct4.Initialize(20, 21));
}
|
||||
|
||||
} // namespace tensorflow
|
204
tensorflow/core/kernels/mfcc_mel_filterbank.cc
Normal file
204
tensorflow/core/kernels/mfcc_mel_filterbank.cc
Normal file
@ -0,0 +1,204 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This code resamples the FFT bins, and smooths them with triangle-shaped
|
||||
// weights to create a mel-frequency filter bank. For filter i centered at f_i,
|
||||
// there is a triangular weighting of the FFT bins that extends from
|
||||
// filter f_i-1 (with a value of zero at the left edge of the triangle) to f_i
|
||||
// (where the filter value is 1) to f_i+1 (where the filter value returns to
|
||||
// zero).
|
||||
|
||||
// Note: this code fails if you ask for too many channels. The algorithm used
|
||||
// here assumes that each FFT bin contributes to at most two channels: the
|
||||
// right side of a triangle for channel i, and the left side of the triangle
|
||||
// for channel i+1. If you ask for so many channels that some of the
|
||||
// resulting mel triangle filters are smaller than a single FFT bin, these
|
||||
// channels may end up with no contributing FFT bins. The resulting mel
|
||||
// spectrum output will have some channels that are always zero.
|
||||
|
||||
#include "tensorflow/core/kernels/mfcc_mel_filterbank.h"
|
||||
|
||||
#include <math.h>
|
||||
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Creates an uninitialized filterbank. The scalar members are zeroed so that
// none of them holds an indeterminate value before Initialize() is called.
MfccMelFilterbank::MfccMelFilterbank()
    : initialized_(false),
      num_channels_(0),
      sample_rate_(0.0),
      input_length_(0),
      start_index_(0),
      end_index_(0) {}
|
||||
|
||||
bool MfccMelFilterbank::Initialize(int input_length,
|
||||
double input_sample_rate,
|
||||
int output_channel_count,
|
||||
double lower_frequency_limit,
|
||||
double upper_frequency_limit) {
|
||||
num_channels_ = output_channel_count;
|
||||
sample_rate_ = input_sample_rate;
|
||||
input_length_ = input_length;
|
||||
|
||||
if (num_channels_ < 1) {
|
||||
LOG(ERROR) << "Number of filterbank channels must be positive.";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (sample_rate_ <= 0) {
|
||||
LOG(ERROR) << "Sample rate must be positive.";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (input_length < 2) {
|
||||
LOG(ERROR) << "Input length must greater than 1.";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (lower_frequency_limit <= 0) {
|
||||
LOG(ERROR) << "Lower frequency limit must be positive.";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (upper_frequency_limit <= lower_frequency_limit) {
|
||||
LOG(ERROR) << "Upper frequency limit must be greater than "
|
||||
<< "lower frequency limit.";
|
||||
return false;
|
||||
}
|
||||
|
||||
// An extra center frequency is computed at the top to get the upper
|
||||
// limit on the high side of the final triangular filter.
|
||||
center_frequencies_.resize(num_channels_ + 1);
|
||||
const double mel_low = FreqToMel(lower_frequency_limit);
|
||||
const double mel_hi = FreqToMel(upper_frequency_limit);
|
||||
const double mel_span = mel_hi - mel_low;
|
||||
const double mel_spacing = mel_span / static_cast<double>(num_channels_ + 1);
|
||||
for (int i = 0; i < num_channels_ + 1; ++i) {
|
||||
center_frequencies_[i] = mel_low + (mel_spacing * (i + 1));
|
||||
}
|
||||
|
||||
// Always exclude DC; emulate HTK.
|
||||
const double hz_per_sbin = 0.5 * sample_rate_ /
|
||||
static_cast<double>(input_length_ - 1);
|
||||
start_index_ = static_cast<int>(1.5 + (lower_frequency_limit /
|
||||
hz_per_sbin));
|
||||
end_index_ = static_cast<int>(upper_frequency_limit / hz_per_sbin);
|
||||
|
||||
// Maps the input spectrum bin indices to filter bank channels/indices. For
|
||||
// each FFT bin, band_mapper tells us which channel this bin contributes to
|
||||
// on the right side of the triangle. Thus this bin also contributes to the
|
||||
// left side of the next channel's triangle response.
|
||||
band_mapper_.resize(input_length_);
|
||||
int channel = 0;
|
||||
for (int i = 0; i < input_length_; ++i) {
|
||||
double melf = FreqToMel(i * hz_per_sbin);
|
||||
if ((i < start_index_) || (i > end_index_)) {
|
||||
band_mapper_[i] = -2; // Indicate an unused Fourier coefficient.
|
||||
} else {
|
||||
while ((center_frequencies_[channel] < melf) &&
|
||||
(channel < num_channels_)) {
|
||||
++channel;
|
||||
}
|
||||
band_mapper_[i] = channel - 1; // Can be == -1
|
||||
}
|
||||
}
|
||||
|
||||
// Create the weighting functions to taper the band edges. The contribution
|
||||
// of any one FFT bin is based on its distance along the continuum between two
|
||||
// mel-channel center frequencies. This bin contributes weights_[i] to the
|
||||
// current channel and 1-weights_[i] to the next channel.
|
||||
weights_.resize(input_length_);
|
||||
for (int i = 0; i < input_length_; ++i) {
|
||||
channel = band_mapper_[i];
|
||||
if ((i < start_index_) || (i > end_index_)) {
|
||||
weights_[i] = 0.0;
|
||||
} else {
|
||||
if (channel >= 0) {
|
||||
weights_[i] = (center_frequencies_[channel + 1] -
|
||||
FreqToMel(i * hz_per_sbin)) /
|
||||
(center_frequencies_[channel + 1] - center_frequencies_[channel]);
|
||||
} else {
|
||||
weights_[i] = (center_frequencies_[0] - FreqToMel(i * hz_per_sbin)) /
|
||||
(center_frequencies_[0] - mel_low);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Check the sum of FFT bin weights for every mel band to identify
|
||||
// situations where the mel bands are so narrow that they don't get
|
||||
// significant weight on enough (or any) FFT bins -- i.e., too many
|
||||
// mel bands have been requested for the given FFT size.
|
||||
std::vector<int> bad_channels;
|
||||
for (int c = 0; c < num_channels_; ++c) {
|
||||
float band_weights_sum = 0.0;
|
||||
for (int i = 0; i < input_length_; ++i) {
|
||||
if (band_mapper_[i] == c - 1) {
|
||||
band_weights_sum += (1.0 - weights_[i]);
|
||||
} else if (band_mapper_[i] == c) {
|
||||
band_weights_sum += weights_[i];
|
||||
}
|
||||
}
|
||||
// The lowest mel channels have the fewest FFT bins and the lowest
|
||||
// weights sum. But given that the target gain at the center frequency
|
||||
// is 1.0, if the total sum of weights is 0.5, we're in bad shape.
|
||||
if (band_weights_sum < 0.5) {
|
||||
bad_channels.push_back(c);
|
||||
}
|
||||
}
|
||||
if (!bad_channels.empty()) {
|
||||
LOG(ERROR) << "Missing " << bad_channels.size() << " bands " <<
|
||||
" starting at " << bad_channels[0] <<
|
||||
" in mel-frequency design. " <<
|
||||
"Perhaps too many channels or " <<
|
||||
"not enough frequency resolution in spectrum. (" <<
|
||||
"input_length: " << input_length <<
|
||||
" input_sample_rate: " << input_sample_rate <<
|
||||
" output_channel_count: " << output_channel_count <<
|
||||
" lower_frequency_limit: " << lower_frequency_limit <<
|
||||
" upper_frequency_limit: " << upper_frequency_limit;
|
||||
}
|
||||
initialized_ = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Compute the mel spectrum from the squared-magnitude FFT input by taking the
|
||||
// square root, then summing FFT magnitudes under triangular integration windows
|
||||
// whose widths increase with frequency.
|
||||
void MfccMelFilterbank::Compute(const std::vector<double> &input,
|
||||
std::vector<double> *output) const {
|
||||
if (!initialized_) {
|
||||
LOG(ERROR) << "Mel Filterbank not initialized.";
|
||||
return;
|
||||
}
|
||||
|
||||
if (input.size() <= end_index_) {
|
||||
LOG(ERROR) << "Input too short to compute filterbank";
|
||||
return;
|
||||
}
|
||||
|
||||
// Ensure output is right length and reset all values.
|
||||
output->assign(num_channels_, 0.0);
|
||||
|
||||
for (int i = start_index_; i <= end_index_; i++) { // For each FFT bin
|
||||
double spec_val = sqrt(input[i]);
|
||||
double weighted = spec_val * weights_[i];
|
||||
int channel = band_mapper_[i];
|
||||
if (channel >= 0)
|
||||
(*output)[channel] += weighted; // Right side of triangle, downward slope
|
||||
channel++;
|
||||
if (channel < num_channels_)
|
||||
(*output)[channel] += spec_val - weighted; // Left side of triangle
|
||||
}
|
||||
}
|
||||
|
||||
// Maps a frequency onto the mel scale as 1127 * ln(1 + freq / 700).
// Initialize() calls this with frequencies expressed in Hz.
double MfccMelFilterbank::FreqToMel(double freq) const {
  const double scaled = 1.0 + (freq / 700.0);
  return 1127.0 * log(scaled);
}
|
||||
|
||||
} // namespace tensorflow
|
65
tensorflow/core/kernels/mfcc_mel_filterbank.h
Normal file
65
tensorflow/core/kernels/mfcc_mel_filterbank.h
Normal file
@ -0,0 +1,65 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Basic class for applying a mel-scale filterbank to an input.
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_MEL_FILTERBANK_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_MEL_FILTERBANK_H_
|
||||
|
||||
#include <vector>
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class MfccMelFilterbank {
|
||||
public:
|
||||
MfccMelFilterbank();
|
||||
bool Initialize(int input_length, // Number of unique FFT bins fftsize/2+1.
|
||||
double input_sample_rate,
|
||||
int output_channel_count,
|
||||
double lower_frequency_limit,
|
||||
double upper_frequency_limit);
|
||||
|
||||
// Takes a magnitude spectrogram slice as input, computes a
|
||||
// traingular mel filterbank and places the result in output.
|
||||
void Compute(const std::vector<double>& input,
|
||||
std::vector<double>* output) const;
|
||||
|
||||
private:
|
||||
double FreqToMel(double freq) const;
|
||||
bool initialized_;
|
||||
int num_channels_;
|
||||
double sample_rate_;
|
||||
int input_length_;
|
||||
std::vector<double> center_frequencies_; // In mel, for each mel channel.
|
||||
|
||||
// Each FFT bin b contributes to two triangular mel channels, with
|
||||
// proportion weights_[b] going into mel channel band_mapper_[b], and
|
||||
// proportion (1 - weights_[b]) going into channel band_mapper_[b] + 1.
|
||||
// Thus, weights_ contains the weighting applied to each FFT bin for the
|
||||
// upper-half of the triangular band.
|
||||
std::vector<double> weights_; // Right-side weight for this fft bin.
|
||||
|
||||
// FFT bin i contributes to the upper side of mel channel band_mapper_[i]
|
||||
std::vector<int> band_mapper_;
|
||||
int start_index_; // Lowest FFT bin used to calculate mel spectrum.
|
||||
int end_index_; // Highest FFT bin used to calculate mel spectrum.
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(MfccMelFilterbank);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_MEL_FILTERBANK_H_
|
92
tensorflow/core/kernels/mfcc_mel_filterbank_test.cc
Normal file
92
tensorflow/core/kernels/mfcc_mel_filterbank_test.cc
Normal file
@ -0,0 +1,92 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/mfcc_mel_filterbank.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Verifies the mel filterbank against "golden values" that come from an
// independent Python mel implementation.
TEST(MfccMelFilterbankTest, AgreesWithPythonGoldenValues) {
  MfccMelFilterbank filterbank;

  std::vector<double> input;
  const int kSampleCount = 513;
  for (int i = 0; i < kSampleCount; ++i) {
    input.push_back(i + 1);
  }
  const int kChannelCount = 20;
  // Fail right here if the configuration is rejected, instead of surfacing
  // later as a confusing output-size mismatch.
  ASSERT_TRUE(filterbank.Initialize(input.size(),
                                    22050 /* sample rate */,
                                    kChannelCount /* channels */,
                                    20.0 /* lower frequency limit */,
                                    4000.0 /* upper frequency limit */));

  std::vector<double> output;
  filterbank.Compute(input, &output);

  std::vector<double> expected = {
      7.38894574,   10.30330648, 13.72703292,  17.24158686,  21.35253118,
      25.77781089,  31.30624108, 37.05877236,  43.9436536,   51.80306637,
      60.79867148,  71.14363376, 82.90910141,  96.50069158,  112.08428368,
      129.96721968, 150.4277597, 173.74997634, 200.86037462, 231.59802942};

  ASSERT_EQ(output.size(), kChannelCount);

  for (int i = 0; i < kChannelCount; ++i) {
    EXPECT_NEAR(output[i], expected[i], 1e-04);
  }
}
|
||||
|
||||
// Regression test for a bug where the output vector was not cleared before
// accumulating the next frame's weighted spectral values.
TEST(MfccMelFilterbankTest, IgnoresExistingContentOfOutputVector) {
  MfccMelFilterbank filterbank;

  const int kSampleCount = 513;
  std::vector<double> input;
  std::vector<double> output;

  // Without this assertion a rejected configuration would leave `output`
  // empty and both loops below would pass without checking anything.
  ASSERT_TRUE(filterbank.Initialize(kSampleCount,
                                    22050 /* sample rate */,
                                    20 /* channels */,
                                    20.0 /* lower frequency limit */,
                                    4000.0 /* upper frequency limit */));

  // First call with nonzero input value, and an empty output vector,
  // will resize the output and fill it with the correct, nonzero outputs.
  input.assign(kSampleCount, 1.0);
  filterbank.Compute(input, &output);
  for (const double value : output) {
    EXPECT_LE(0.0, value);
  }

  // Second call with zero input should also generate zero output. However,
  // the output vector now is already the correct size, but full of nonzero
  // values. Make sure these don't affect the output.
  input.assign(kSampleCount, 0.0);
  filterbank.Compute(input, &output);
  for (const double value : output) {
    EXPECT_EQ(0.0, value);
  }
}
|
||||
|
||||
} // namespace tensorflow
|
111
tensorflow/core/kernels/mfcc_op.cc
Normal file
111
tensorflow/core/kernels/mfcc_op.cc
Normal file
@ -0,0 +1,111 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// See docs in ../ops/audio_ops.cc
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/kernels/mfcc.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Create a speech fingerpring from spectrogram data.
|
||||
class MfccOp : public OpKernel {
|
||||
public:
|
||||
explicit MfccOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("upper_frequency_limit",
|
||||
&upper_frequency_limit_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("lower_frequency_limit",
|
||||
&lower_frequency_limit_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("filterbank_channel_count",
|
||||
&filterbank_channel_count_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("dct_coefficient_count",
|
||||
&dct_coefficient_count_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor& spectrogram = context->input(0);
|
||||
OP_REQUIRES(context, spectrogram.dims() == 3,
|
||||
errors::InvalidArgument("spectrogram must be 3-dimensional",
|
||||
spectrogram.shape().DebugString()));
|
||||
const Tensor& sample_rate_tensor = context->input(1);
|
||||
OP_REQUIRES(context, TensorShapeUtils::IsScalar(sample_rate_tensor.shape()),
|
||||
errors::InvalidArgument(
|
||||
"Input sample_rate should be a scalar tensor, got ",
|
||||
sample_rate_tensor.shape().DebugString(), " instead."));
|
||||
const int32 sample_rate = sample_rate_tensor.scalar<int32>()();
|
||||
|
||||
const int spectrogram_channels = spectrogram.dim_size(2);
|
||||
const int spectrogram_samples = spectrogram.dim_size(1);
|
||||
const int audio_channels = spectrogram.dim_size(0);
|
||||
|
||||
Mfcc mfcc;
|
||||
mfcc.set_upper_frequency_limit(upper_frequency_limit_);
|
||||
mfcc.set_lower_frequency_limit(lower_frequency_limit_);
|
||||
mfcc.set_filterbank_channel_count(filterbank_channel_count_);
|
||||
mfcc.set_dct_coefficient_count(dct_coefficient_count_);
|
||||
OP_REQUIRES(context, mfcc.Initialize(spectrogram_channels, sample_rate),
|
||||
errors::InvalidArgument(
|
||||
"Mfcc initialization failed for channel count ",
|
||||
spectrogram_channels, " and sample rate ", sample_rate));
|
||||
|
||||
Tensor* output_tensor = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(
|
||||
0,
|
||||
TensorShape({audio_channels, spectrogram_samples,
|
||||
dct_coefficient_count_}),
|
||||
&output_tensor));
|
||||
|
||||
const float* spectrogram_flat = spectrogram.flat<float>().data();
|
||||
float* output_flat = output_tensor->flat<float>().data();
|
||||
|
||||
for (int audio_channel = 0; audio_channel < audio_channels;
|
||||
++audio_channel) {
|
||||
for (int spectrogram_sample = 0; spectrogram_sample < spectrogram_samples;
|
||||
++spectrogram_sample) {
|
||||
const float* sample_data =
|
||||
spectrogram_flat +
|
||||
(audio_channel * spectrogram_samples * spectrogram_channels) +
|
||||
(spectrogram_sample * spectrogram_channels);
|
||||
std::vector<double> mfcc_input(sample_data,
|
||||
sample_data + spectrogram_channels);
|
||||
std::vector<double> mfcc_output;
|
||||
mfcc.Compute(mfcc_input, &mfcc_output);
|
||||
DCHECK_EQ(dct_coefficient_count_, mfcc_output.size());
|
||||
float* output_data =
|
||||
output_flat +
|
||||
(audio_channel * spectrogram_samples * dct_coefficient_count_) +
|
||||
(spectrogram_sample * dct_coefficient_count_);
|
||||
for (int i = 0; i < dct_coefficient_count_; ++i) {
|
||||
output_data[i] = mfcc_output[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
float upper_frequency_limit_;
|
||||
float lower_frequency_limit_;
|
||||
int32 filterbank_channel_count_;
|
||||
int32 dct_coefficient_count_;
|
||||
};
|
||||
// Registers MfccOp as the CPU kernel for the "Mfcc" op.
REGISTER_KERNEL_BUILDER(Name("Mfcc").Device(DEVICE_CPU), MfccOp);
|
||||
|
||||
} // namespace tensorflow
|
77
tensorflow/core/kernels/mfcc_op_test.cc
Normal file
77
tensorflow/core/kernels/mfcc_op_test.cc
Normal file
@ -0,0 +1,77 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/cc/client/client_session.h"
|
||||
#include "tensorflow/cc/ops/audio_ops.h"
|
||||
#include "tensorflow/cc/ops/const_op.h"
|
||||
#include "tensorflow/cc/ops/math_ops.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using namespace ops; // NOLINT(build/namespaces)
|
||||
|
||||
// Runs the Mfcc op end to end on a single 513-bin spectrogram frame filled
// with 1, 2, 3, ... and compares the 13 output coefficients to golden values.
TEST(MfccOpTest, SimpleTest) {
  Scope root = Scope::NewRootScope();

  Tensor spectrogram_tensor(DT_FLOAT, TensorShape({1, 1, 513}));
  test::FillIota<float>(&spectrogram_tensor, 1.0f);

  Output spectrogram_const_op = Const(root.WithOpName("spectrogram_const_op"),
                                      Input::Initializer(spectrogram_tensor));

  Output sample_rate_const_op =
      Const(root.WithOpName("sample_rate_const_op"), 22050);

  Mfcc mfcc_op = Mfcc(root.WithOpName("mfcc_op"), spectrogram_const_op,
                      sample_rate_const_op);

  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;

  // A failed Run() leaves nothing valid to inspect, so abort the test here
  // instead of indexing into a possibly empty output vector below.
  TF_ASSERT_OK(
      session.Run(ClientSession::FeedType(), {mfcc_op.output}, &outputs));
  ASSERT_EQ(1u, outputs.size());

  const Tensor& mfcc_tensor = outputs[0];

  EXPECT_EQ(3, mfcc_tensor.dims());
  EXPECT_EQ(13, mfcc_tensor.dim_size(2));
  EXPECT_EQ(1, mfcc_tensor.dim_size(1));
  EXPECT_EQ(1, mfcc_tensor.dim_size(0));

  test::ExpectTensorNear<float>(
      mfcc_tensor,
      test::AsTensor<float>(
          {29.13970072, -6.41568601, -0.61903012, -0.96778652, -0.26819878,
           -0.40907028, -0.15614748, -0.23203119, -0.10481487, -0.1543029,
           -0.0769791, -0.10806114, -0.06047613},
          TensorShape({1, 1, 13})),
      1e-3);
}
|
||||
|
||||
} // namespace tensorflow
|
92
tensorflow/core/kernels/mfcc_test.cc
Normal file
92
tensorflow/core/kernels/mfcc_test.cc
Normal file
@ -0,0 +1,92 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/mfcc.h"

#include <algorithm>
#include <cmath>
#include <vector>

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Feeds the ramp 1, 2, ..., 513 through a default-configured Mfcc and checks
// the resulting coefficients against golden reference values.
TEST(MfccTest, AgreesWithPythonGoldenValues) {
  const int kSampleCount = 513;
  std::vector<double> ramp(kSampleCount);
  for (int bin = 0; bin < kSampleCount; ++bin) {
    ramp[bin] = bin + 1;
  }

  Mfcc mfcc;
  ASSERT_TRUE(mfcc.Initialize(ramp.size(), 22050 /*sample rate*/));

  std::vector<double> actual;
  mfcc.Compute(ramp, &actual);

  const std::vector<double> golden = {
      29.13970072, -6.41568601, -0.61903012, -0.96778652, -0.26819878,
      -0.40907028, -0.15614748, -0.23203119, -0.10481487, -0.1543029,
      -0.0769791,  -0.10806114, -0.06047613};

  ASSERT_EQ(golden.size(), actual.size());
  for (size_t k = 0; k < actual.size(); ++k) {
    EXPECT_NEAR(actual[k], golden[k], 1e-04);
  }
}
|
||||
|
||||
// An all-zero spectrogram frame must still yield 13 coefficients, none of
// which is NaN.
TEST(MfccTest, AvoidsNansWithZeroInput) {
  Mfcc mfcc;
  const int kSampleCount = 513;
  const std::vector<double> input(kSampleCount, 0.0);

  ASSERT_TRUE(mfcc.Initialize(input.size(), 22050 /*sample rate*/));

  std::vector<double> output;
  mfcc.Compute(input, &output);

  // Unsigned, to match the type returned by output.size().
  const size_t kExpectedSize = 13;
  ASSERT_EQ(kExpectedSize, output.size());
  for (const double value : output) {
    // Qualified std::isnan from <cmath>; the unqualified name is not
    // guaranteed to be declared on every toolchain.
    EXPECT_FALSE(std::isnan(value));
  }
}
|
||||
|
||||
// Sanity check on a spectrogram frame with a single hot low-frequency bin:
// the largest output coefficient must be c_1.
TEST(MfccTest, SimpleInputSaneResult) {
  Mfcc mfcc;
  mfcc.set_lower_frequency_limit(125.0);
  mfcc.set_upper_frequency_limit(3800.0);
  mfcc.set_filterbank_channel_count(40);
  mfcc.set_dct_coefficient_count(40);

  // Simulate a low-frequency sinusoid: every bin is silent except one.
  const int kSpectrogramSize = 129;
  const int kHotBin = 10;
  std::vector<double> frame(kSpectrogramSize, 0.0);
  frame[kHotBin] = 1.0;
  ASSERT_TRUE(mfcc.Initialize(frame.size(), 8000));

  std::vector<double> coefficients;
  mfcc.Compute(frame, &coefficients);

  // For a single low-frequency input, output beyond c_0 should look like
  // a slow cosine, with a slight delay, so the peak is expected at c_1.
  const auto peak =
      std::max_element(coefficients.begin(), coefficients.end());
  EXPECT_EQ(coefficients.begin() + 1, peak);
}
|
||||
|
||||
} // namespace tensorflow
|
@ -100,6 +100,26 @@ Status SpectrogramShapeFn(InferenceContext* c) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Shape function for the Mfcc op.
//
// Requires input 0 (spectrogram) to have rank 3 and input 1 (sample_rate) to
// be a scalar. The output keeps the first two spectrogram dimensions and uses
// the dct_coefficient_count attribute as its last dimension.
Status MfccShapeFn(InferenceContext* c) {
  ShapeHandle spectrogram;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &spectrogram));
  // The kernel rejects any non-scalar sample_rate, so shape inference must
  // demand rank 0 here; requiring rank 1 would refuse every valid graph.
  ShapeHandle unused;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));

  int32 dct_coefficient_count;
  TF_RETURN_IF_ERROR(
      c->GetAttr("dct_coefficient_count", &dct_coefficient_count));

  // Dimensions 0 and 1 of the spectrogram pass through unchanged.
  DimensionHandle spectrogram_channels = c->Dim(spectrogram, 0);
  DimensionHandle spectrogram_length = c->Dim(spectrogram, 1);

  DimensionHandle output_channels = c->MakeDim(dct_coefficient_count);

  c->set_output(0, c->MakeShape({spectrogram_channels, spectrogram_length,
                                 output_channels}));
  return Status::OK();
}
|
||||
|
||||
} // namespace
|
||||
|
||||
REGISTER_OP("DecodeWav")
|
||||
@ -200,4 +220,34 @@ magnitude_squared: Whether to return the squared magnitude or just the
|
||||
spectrogram: 3D representation of the audio frequencies as an image.
|
||||
)doc");
|
||||
|
||||
// Op definition for Mfcc: two inputs (float spectrogram, int32 sample_rate),
// four attributes with defaults, one float output, shape-checked by
// MfccShapeFn.
REGISTER_OP("Mfcc")
    .Input("spectrogram: float")
    .Input("sample_rate: int32")
    .Attr("upper_frequency_limit: float = 4000")
    .Attr("lower_frequency_limit: float = 20")
    .Attr("filterbank_channel_count: int = 40")
    .Attr("dct_coefficient_count: int = 13")
    .Output("output: float")
    .SetShapeFn(MfccShapeFn)
    .Doc(R"doc(
Transforms a spectrogram into a form that's useful for speech recognition.

Mel Frequency Cepstral Coefficients are a way of representing audio data that's
been effective as an input feature for machine learning. They are created by
taking the spectrum of a spectrogram (a 'cepstrum'), and discarding some of the
higher frequencies that are less significant to the human ear. They have a long
history in the speech recognition world, and https://en.wikipedia.org/wiki/Mel-frequency_cepstrum
is a good resource to learn more.

spectrogram: Typically produced by the Spectrogram op, with magnitude_squared
  set to true.
sample_rate: How many samples per second the source audio used.
upper_frequency_limit: The highest frequency to use when calculating the
  cepstrum.
lower_frequency_limit: The lowest frequency to use when calculating the
  cepstrum.
filterbank_channel_count: Resolution of the Mel bank used internally.
dct_coefficient_count: How many output channels to produce per time slice.
)doc");
|
||||
|
||||
} // namespace tensorflow
|
||||
|
Loading…
x
Reference in New Issue
Block a user