fix support of multichaannel data in audio_spectrogram
PiperOrigin-RevId: 301185141 Change-Id: Iad0e70a6624c6865cd9244fd1d34e48586f9a205
This commit is contained in:
parent
2af901507b
commit
f32e0b9b40
@ -64,19 +64,35 @@ bool Spectrogram::Initialize(const std::vector<double>& window,
|
||||
output_frequency_channels_ = 1 + fft_length_ / 2;
|
||||
|
||||
// Allocate 2 more than what rdft needs, so we can rationalize the layout.
|
||||
fft_input_output_.assign(fft_length_ + 2, 0.0);
|
||||
fft_input_output_.resize(fft_length_ + 2);
|
||||
|
||||
int half_fft_length = fft_length_ / 2;
|
||||
fft_double_working_area_.assign(half_fft_length, 0.0);
|
||||
fft_integer_working_area_.assign(2 + static_cast<int>(sqrt(half_fft_length)),
|
||||
0);
|
||||
fft_double_working_area_.resize(half_fft_length);
|
||||
fft_integer_working_area_.resize(2 + static_cast<int>(sqrt(half_fft_length)));
|
||||
initialized_ = true;
|
||||
if (!Reset()) {
|
||||
LOG(ERROR) << "Failed to Reset()";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Spectrogram::Reset() {
|
||||
if (!initialized_) {
|
||||
LOG(ERROR) << "Initialize() has to be called, before Reset().";
|
||||
return false;
|
||||
}
|
||||
std::fill(fft_double_working_area_.begin(), fft_double_working_area_.end(),
|
||||
0.0);
|
||||
std::fill(fft_integer_working_area_.begin(), fft_integer_working_area_.end(),
|
||||
0);
|
||||
|
||||
// Set flag element to ensure that the working areas are initialized
|
||||
// on the first call to cdft. It's redundant given the assign above,
|
||||
// but keep it as a reminder.
|
||||
fft_integer_working_area_[0] = 0;
|
||||
input_queue_.clear();
|
||||
samples_to_next_step_ = window_length_;
|
||||
initialized_ = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -56,6 +56,19 @@ class Spectrogram {
|
||||
// Initialize with an explicit window instead of a length.
|
||||
bool Initialize(const std::vector<double>& window, int step_length);
|
||||
|
||||
// Reset internal variables.
|
||||
// Spectrogram keeps internal state: remaining input data from previous call.
|
||||
// As a result it can produce different number of frames when you call
|
||||
// ComputeComplexSpectrogram multiple times (even though input data
|
||||
// has the same size). As it is shown in
|
||||
// MultipleCallsToComputeComplexSpectrogramMayYieldDifferentNumbersOfFrames
|
||||
// in tensorflow/core/kernels/spectrogram_test.cc.
|
||||
// But if you need to compute Spectrogram on input data without keeping
|
||||
// internal state (and clear remaining input data from the previous call)
|
||||
// you have to call Reset() before computing Spectrogram.
|
||||
// For example in tensorflow/core/kernels/spectrogram_op.cc
|
||||
bool Reset();
|
||||
|
||||
// Processes an arbitrary amount of audio data (contained in input)
|
||||
// to yield complex spectrogram frames. After a successful call to
|
||||
// Initialize(), Process() may be called repeatedly with new input data
|
||||
|
@ -72,6 +72,9 @@ class AudioSpectrogramOp : public OpKernel {
|
||||
|
||||
std::vector<float> input_for_channel(sample_count);
|
||||
for (int64 channel = 0; channel < channel_count; ++channel) {
|
||||
OP_REQUIRES(context, spectrogram.Reset(),
|
||||
errors::InvalidArgument("Failed to Reset()"));
|
||||
|
||||
float* output_slice =
|
||||
output_flat + (channel * output_height * output_width);
|
||||
for (int i = 0; i < sample_count; ++i) {
|
||||
|
@ -101,6 +101,45 @@ TEST(SpectrogramOpTest, SquaredTest) {
|
||||
test::AsTensor<float>({0, 1, 4, 1, 0}, TensorShape({1, 1, 5})), 1e-3);
|
||||
}
|
||||
|
||||
TEST(SpectrogramOpTest, MultichannelTest) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
|
||||
const int audio_size = 8;
|
||||
const int channel_size = 2;
|
||||
Tensor audio_tensor(DT_FLOAT, TensorShape({audio_size, channel_size}));
|
||||
test::FillValues<float>(
|
||||
&audio_tensor, {-1.0f, -1.0f, 0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f, -1.0f,
|
||||
-1.0f, 0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f});
|
||||
|
||||
Output audio_const_op = Const(root.WithOpName("audio_const_op"),
|
||||
Input::Initializer(audio_tensor));
|
||||
|
||||
AudioSpectrogram spectrogram_op =
|
||||
AudioSpectrogram(root.WithOpName("spectrogram_op"), audio_const_op,
|
||||
audio_size, channel_size);
|
||||
|
||||
TF_ASSERT_OK(root.status());
|
||||
|
||||
ClientSession session(root);
|
||||
std::vector<Tensor> outputs;
|
||||
|
||||
TF_EXPECT_OK(session.Run(ClientSession::FeedType(),
|
||||
{spectrogram_op.spectrogram}, &outputs));
|
||||
|
||||
const Tensor& spectrogram_tensor = outputs[0];
|
||||
|
||||
EXPECT_EQ(3, spectrogram_tensor.dims());
|
||||
EXPECT_EQ(5, spectrogram_tensor.dim_size(2));
|
||||
EXPECT_EQ(1, spectrogram_tensor.dim_size(1));
|
||||
EXPECT_EQ(channel_size, spectrogram_tensor.dim_size(0));
|
||||
|
||||
for (int channel = 0; channel < channel_size; channel++) {
|
||||
test::ExpectTensorNear<float>(
|
||||
spectrogram_tensor.SubSlice(channel),
|
||||
test::AsTensor<float>({0, 1, 2, 1, 0}, TensorShape({1, 5})), 1e-3);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace ops
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user