Add AudioSpectrogram op to TensorFlow for audio feature generation
Change: 152872386
This commit is contained in:
parent
b6d47b5e56
commit
7c9d2a458e
@ -277,6 +277,7 @@ filegroup(
|
|||||||
"//tensorflow/examples/tutorials/estimators:all_files",
|
"//tensorflow/examples/tutorials/estimators:all_files",
|
||||||
"//tensorflow/examples/tutorials/mnist:all_files",
|
"//tensorflow/examples/tutorials/mnist:all_files",
|
||||||
"//tensorflow/examples/tutorials/word2vec:all_files",
|
"//tensorflow/examples/tutorials/word2vec:all_files",
|
||||||
|
"//tensorflow/examples/wav_to_spectrogram:all_files",
|
||||||
"//tensorflow/go:all_files",
|
"//tensorflow/go:all_files",
|
||||||
"//tensorflow/java:all_files",
|
"//tensorflow/java:all_files",
|
||||||
"//tensorflow/java/src/main/java/org/tensorflow/examples:all_files",
|
"//tensorflow/java/src/main/java/org/tensorflow/examples:all_files",
|
||||||
|
@ -108,6 +108,7 @@ include(eigen)
|
|||||||
include(gemmlowp)
|
include(gemmlowp)
|
||||||
include(jsoncpp)
|
include(jsoncpp)
|
||||||
include(farmhash)
|
include(farmhash)
|
||||||
|
include(fft2d)
|
||||||
include(highwayhash)
|
include(highwayhash)
|
||||||
include(protobuf)
|
include(protobuf)
|
||||||
if (tensorflow_BUILD_CC_TESTS)
|
if (tensorflow_BUILD_CC_TESTS)
|
||||||
@ -121,6 +122,7 @@ set(tensorflow_EXTERNAL_LIBRARIES
|
|||||||
${jpeg_STATIC_LIBRARIES}
|
${jpeg_STATIC_LIBRARIES}
|
||||||
${jsoncpp_STATIC_LIBRARIES}
|
${jsoncpp_STATIC_LIBRARIES}
|
||||||
${farmhash_STATIC_LIBRARIES}
|
${farmhash_STATIC_LIBRARIES}
|
||||||
|
${fft2d_STATIC_LIBRARIES}
|
||||||
${highwayhash_STATIC_LIBRARIES}
|
${highwayhash_STATIC_LIBRARIES}
|
||||||
${protobuf_STATIC_LIBRARIES}
|
${protobuf_STATIC_LIBRARIES}
|
||||||
)
|
)
|
||||||
@ -135,6 +137,7 @@ set(tensorflow_EXTERNAL_DEPENDENCIES
|
|||||||
protobuf
|
protobuf
|
||||||
eigen
|
eigen
|
||||||
gemmlowp
|
gemmlowp
|
||||||
|
fft2d
|
||||||
)
|
)
|
||||||
|
|
||||||
include_directories(
|
include_directories(
|
||||||
|
52
tensorflow/contrib/cmake/external/fft2d.cmake
vendored
Normal file
52
tensorflow/contrib/cmake/external/fft2d.cmake
vendored
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
include (ExternalProject)
|
||||||
|
|
||||||
|
set(fft2d_URL http://www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz)
|
||||||
|
set(fft2d_HASH SHA256=52bb637c70b971958ec79c9c8752b1df5ff0218a4db4510e60826e0cb79b5296)
|
||||||
|
set(fft2d_BUILD ${CMAKE_CURRENT_BINARY_DIR}/fft2d/)
|
||||||
|
set(fft2d_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/fft2d/src)
|
||||||
|
|
||||||
|
if(WIN32)
|
||||||
|
set(fft2d_STATIC_LIBRARIES ${fft2d_BUILD}/src/lib/fft2d.lib)
|
||||||
|
|
||||||
|
ExternalProject_Add(fft2d
|
||||||
|
PREFIX fft2d
|
||||||
|
URL ${fft2d_URL}
|
||||||
|
URL_HASH ${fft2d_HASH}
|
||||||
|
DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
|
||||||
|
BUILD_IN_SOURCE 1
|
||||||
|
PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/fft2d/CMakeLists.txt ${fft2d_BUILD}/src/fft2d/CMakeLists.txt
|
||||||
|
INSTALL_DIR ${fft2d_INSTALL}
|
||||||
|
CMAKE_CACHE_ARGS
|
||||||
|
-DCMAKE_BUILD_TYPE:STRING=Release
|
||||||
|
-DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF
|
||||||
|
-DCMAKE_INSTALL_PREFIX:STRING=${fft2d_INSTALL})
|
||||||
|
else()
|
||||||
|
set(fft2d_STATIC_LIBRARIES ${fft2d_BUILD}/src/fft2d/libfft2d.a)
|
||||||
|
|
||||||
|
ExternalProject_Add(fft2d
|
||||||
|
PREFIX fft2d
|
||||||
|
URL ${fft2d_URL}
|
||||||
|
URL_HASH ${fft2d_HASH}
|
||||||
|
DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
|
||||||
|
BUILD_IN_SOURCE 1
|
||||||
|
PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/fft2d/CMakeLists.txt ${fft2d_BUILD}/src/fft2d/CMakeLists.txt
|
||||||
|
INSTALL_DIR $(fft2d_INSTALL)
|
||||||
|
INSTALL_COMMAND echo
|
||||||
|
BUILD_COMMAND $(MAKE))
|
||||||
|
|
||||||
|
endif()
|
17
tensorflow/contrib/cmake/patches/fft2d/CMakeLists.txt
Normal file
17
tensorflow/contrib/cmake/patches/fft2d/CMakeLists.txt
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
cmake_minimum_required(VERSION 2.8.3)
|
||||||
|
|
||||||
|
project(fft2d)
|
||||||
|
|
||||||
|
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||||
|
|
||||||
|
set(FFT2D_SRCS
|
||||||
|
"fftsg.c"
|
||||||
|
)
|
||||||
|
|
||||||
|
include_directories("${CMAKE_CURRENT_SOURCE_DIR}")
|
||||||
|
|
||||||
|
add_library(fft2d ${FFT2D_SRCS})
|
||||||
|
|
||||||
|
install(TARGETS fft2d
|
||||||
|
LIBRARY DESTINATION lib COMPONENT RuntimeLibraries
|
||||||
|
ARCHIVE DESTINATION lib COMPONENT Development)
|
@ -496,7 +496,6 @@ cc_library(
|
|||||||
tf_gen_op_libs(
|
tf_gen_op_libs(
|
||||||
op_lib_names = [
|
op_lib_names = [
|
||||||
"array_ops",
|
"array_ops",
|
||||||
"audio_ops",
|
|
||||||
"candidate_sampling_ops",
|
"candidate_sampling_ops",
|
||||||
"control_flow_ops",
|
"control_flow_ops",
|
||||||
"ctc_ops",
|
"ctc_ops",
|
||||||
@ -526,6 +525,13 @@ tf_gen_op_libs(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_gen_op_libs(
|
||||||
|
op_lib_names = [
|
||||||
|
"audio_ops",
|
||||||
|
],
|
||||||
|
deps = [":lib"],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "debug_ops_op_lib",
|
name = "debug_ops_op_lib",
|
||||||
srcs = ["ops/debug_ops.cc"],
|
srcs = ["ops/debug_ops.cc"],
|
||||||
@ -688,6 +694,7 @@ cc_library(
|
|||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/core/kernels:array",
|
"//tensorflow/core/kernels:array",
|
||||||
|
"//tensorflow/core/kernels:audio",
|
||||||
"//tensorflow/core/kernels:bincount_op",
|
"//tensorflow/core/kernels:bincount_op",
|
||||||
"//tensorflow/core/kernels:candidate_sampler_ops",
|
"//tensorflow/core/kernels:candidate_sampler_ops",
|
||||||
"//tensorflow/core/kernels:control_flow_ops",
|
"//tensorflow/core/kernels:control_flow_ops",
|
||||||
|
@ -3559,6 +3559,117 @@ tf_kernel_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "spectrogram_test_data",
|
||||||
|
srcs = [
|
||||||
|
"spectrogram_test_data/short_test_segment.wav",
|
||||||
|
"spectrogram_test_data/short_test_segment_spectrogram.csv.bin",
|
||||||
|
"spectrogram_test_data/short_test_segment_spectrogram_400_200.csv.bin",
|
||||||
|
],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "spectrogram",
|
||||||
|
srcs = ["spectrogram.cc"],
|
||||||
|
hdrs = ["spectrogram.h"],
|
||||||
|
copts = tf_copts(),
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//third_party/fft2d:fft2d_headers",
|
||||||
|
"@fft2d//:fft2d",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "spectrogram_test_utils",
|
||||||
|
testonly = 1,
|
||||||
|
srcs = ["spectrogram_test_utils.cc"],
|
||||||
|
hdrs = ["spectrogram_test_utils.h"],
|
||||||
|
copts = tf_copts(),
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_binary(
|
||||||
|
name = "spectrogram_convert_test_data",
|
||||||
|
testonly = 1,
|
||||||
|
srcs = ["spectrogram_convert_test_data.cc"],
|
||||||
|
deps = [
|
||||||
|
":spectrogram_test_utils",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "spectrogram_test",
|
||||||
|
size = "medium",
|
||||||
|
srcs = ["spectrogram_test.cc"],
|
||||||
|
data = [":spectrogram_test_data"],
|
||||||
|
deps = [
|
||||||
|
":spectrogram",
|
||||||
|
":spectrogram_test_utils",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
|
"//tensorflow/core:lib_test_internal",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//third_party/eigen3",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "spectrogram_op",
|
||||||
|
prefix = "spectrogram_op",
|
||||||
|
deps = [
|
||||||
|
":spectrogram",
|
||||||
|
"//tensorflow/core:audio_ops_op_lib",
|
||||||
|
"//tensorflow/core:core_cpu",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cuda_cc_test(
|
||||||
|
name = "spectrogram_op_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["spectrogram_op_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":ops_util",
|
||||||
|
":spectrogram_op",
|
||||||
|
"//tensorflow/cc:cc_ops",
|
||||||
|
"//tensorflow/cc:client_session",
|
||||||
|
"//tensorflow/core:core_cpu",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:framework_internal",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core:tensorflow",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core:testlib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "audio",
|
||||||
|
deps = [
|
||||||
|
":decode_wav_op",
|
||||||
|
":encode_wav_op",
|
||||||
|
":spectrogram_op",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# Android libraries -----------------------------------------------------------
|
# Android libraries -----------------------------------------------------------
|
||||||
|
|
||||||
# Changes to the Android srcs here should be replicated in
|
# Changes to the Android srcs here should be replicated in
|
||||||
@ -3962,6 +4073,7 @@ filegroup(
|
|||||||
"whole_file_read_ops.*",
|
"whole_file_read_ops.*",
|
||||||
"sample_distorted_bounding_box_op.*",
|
"sample_distorted_bounding_box_op.*",
|
||||||
"ctc_loss_op.*",
|
"ctc_loss_op.*",
|
||||||
|
"spectrogram_convert_test_data.cc",
|
||||||
# Excluded due to experimental status:
|
# Excluded due to experimental status:
|
||||||
"debug_ops.*",
|
"debug_ops.*",
|
||||||
"scatter_nd_op*",
|
"scatter_nd_op*",
|
||||||
|
212
tensorflow/core/kernels/spectrogram.cc
Normal file
212
tensorflow/core/kernels/spectrogram.cc
Normal file
@ -0,0 +1,212 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/spectrogram.h"
|
||||||
|
|
||||||
|
#include <math.h>
|
||||||
|
|
||||||
|
#include "third_party/fft2d/fft.h"
|
||||||
|
#include "tensorflow/core/lib/core/bits.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
using std::complex;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Returns the default Hann window function for the spectrogram.
|
||||||
|
void GetPeriodicHann(int window_length, std::vector<double>* window) {
|
||||||
|
// Some platforms don't have M_PI, so define a local constant here.
|
||||||
|
const double pi = std::atan(1) * 4;
|
||||||
|
window->resize(window_length);
|
||||||
|
for (int i = 0; i < window_length; ++i) {
|
||||||
|
(*window)[i] = 0.5 - 0.5 * cos((2 * pi * i) / window_length);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
bool Spectrogram::Initialize(int window_length, int step_length) {
|
||||||
|
std::vector<double> window;
|
||||||
|
GetPeriodicHann(window_length, &window);
|
||||||
|
return Initialize(window, step_length);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Spectrogram::Initialize(const std::vector<double>& window,
|
||||||
|
int step_length) {
|
||||||
|
window_length_ = window.size();
|
||||||
|
window_ = window; // Copy window.
|
||||||
|
if (window_length_ < 2) {
|
||||||
|
LOG(ERROR) << "Window length too short.";
|
||||||
|
initialized_ = false;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
step_length_ = step_length;
|
||||||
|
if (step_length_ < 1) {
|
||||||
|
LOG(ERROR) << "Step length must be positive.";
|
||||||
|
initialized_ = false;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
fft_length_ = NextPowerOfTwo(window_length_);
|
||||||
|
CHECK(fft_length_ >= window_length_);
|
||||||
|
output_frequency_channels_ = 1 + fft_length_ / 2;
|
||||||
|
|
||||||
|
// Allocate 2 more than what rdft needs, so we can rationalize the layout.
|
||||||
|
fft_input_output_.assign(fft_length_ + 2, 0.0);
|
||||||
|
|
||||||
|
int half_fft_length = fft_length_ / 2;
|
||||||
|
fft_double_working_area_.assign(half_fft_length, 0.0);
|
||||||
|
fft_integer_working_area_.assign(2 + static_cast<int>(sqrt(half_fft_length)),
|
||||||
|
0);
|
||||||
|
// Set flag element to ensure that the working areas are initialized
|
||||||
|
// on the first call to cdft. It's redundant given the assign above,
|
||||||
|
// but keep it as a reminder.
|
||||||
|
fft_integer_working_area_[0] = 0;
|
||||||
|
input_queue_.clear();
|
||||||
|
samples_to_next_step_ = window_length_;
|
||||||
|
initialized_ = true;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class InputSample, class OutputSample>
|
||||||
|
bool Spectrogram::ComputeComplexSpectrogram(
|
||||||
|
const std::vector<InputSample>& input,
|
||||||
|
std::vector<std::vector<complex<OutputSample>>>* output) {
|
||||||
|
if (!initialized_) {
|
||||||
|
LOG(ERROR) << "ComputeComplexSpectrogram() called before successful call "
|
||||||
|
<< "to Initialize().";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
CHECK(output);
|
||||||
|
output->clear();
|
||||||
|
int input_start = 0;
|
||||||
|
while (GetNextWindowOfSamples(input, &input_start)) {
|
||||||
|
DCHECK_EQ(input_queue_.size(), window_length_);
|
||||||
|
ProcessCoreFFT(); // Processes input_queue_ to fft_input_output_.
|
||||||
|
// Add a new slice vector onto the output, to save new result to.
|
||||||
|
output->resize(output->size() + 1);
|
||||||
|
// Get a reference to the newly added slice to fill in.
|
||||||
|
auto& spectrogram_slice = output->back();
|
||||||
|
spectrogram_slice.resize(output_frequency_channels_);
|
||||||
|
for (int i = 0; i < output_frequency_channels_; ++i) {
|
||||||
|
// This will convert double to float if it needs to.
|
||||||
|
spectrogram_slice[i] = complex<OutputSample>(
|
||||||
|
fft_input_output_[2 * i], fft_input_output_[2 * i + 1]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
// Instantiate it four ways:
|
||||||
|
template bool Spectrogram::ComputeComplexSpectrogram(
|
||||||
|
const std::vector<float>& input, std::vector<std::vector<complex<float>>>*);
|
||||||
|
template bool Spectrogram::ComputeComplexSpectrogram(
|
||||||
|
const std::vector<double>& input,
|
||||||
|
std::vector<std::vector<complex<float>>>*);
|
||||||
|
template bool Spectrogram::ComputeComplexSpectrogram(
|
||||||
|
const std::vector<float>& input,
|
||||||
|
std::vector<std::vector<complex<double>>>*);
|
||||||
|
template bool Spectrogram::ComputeComplexSpectrogram(
|
||||||
|
const std::vector<double>& input,
|
||||||
|
std::vector<std::vector<complex<double>>>*);
|
||||||
|
|
||||||
|
template <class InputSample, class OutputSample>
|
||||||
|
bool Spectrogram::ComputeSquaredMagnitudeSpectrogram(
|
||||||
|
const std::vector<InputSample>& input,
|
||||||
|
std::vector<std::vector<OutputSample>>* output) {
|
||||||
|
if (!initialized_) {
|
||||||
|
LOG(ERROR) << "ComputeSquaredMagnitudeSpectrogram() called before "
|
||||||
|
<< "successful call to Initialize().";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
CHECK(output);
|
||||||
|
output->clear();
|
||||||
|
int input_start = 0;
|
||||||
|
while (GetNextWindowOfSamples(input, &input_start)) {
|
||||||
|
DCHECK_EQ(input_queue_.size(), window_length_);
|
||||||
|
ProcessCoreFFT(); // Processes input_queue_ to fft_input_output_.
|
||||||
|
// Add a new slice vector onto the output, to save new result to.
|
||||||
|
output->resize(output->size() + 1);
|
||||||
|
// Get a reference to the newly added slice to fill in.
|
||||||
|
auto& spectrogram_slice = output->back();
|
||||||
|
spectrogram_slice.resize(output_frequency_channels_);
|
||||||
|
for (int i = 0; i < output_frequency_channels_; ++i) {
|
||||||
|
// Similar to the Complex case, except storing the norm.
|
||||||
|
// But the norm function is known to be a performance killer,
|
||||||
|
// so do it this way with explicit real and imagninary temps.
|
||||||
|
const double re = fft_input_output_[2 * i];
|
||||||
|
const double im = fft_input_output_[2 * i + 1];
|
||||||
|
// Which finally converts double to float if it needs to.
|
||||||
|
spectrogram_slice[i] = re * re + im * im;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
// Instantiate it four ways:
|
||||||
|
template bool Spectrogram::ComputeSquaredMagnitudeSpectrogram(
|
||||||
|
const std::vector<float>& input, std::vector<std::vector<float>>*);
|
||||||
|
template bool Spectrogram::ComputeSquaredMagnitudeSpectrogram(
|
||||||
|
const std::vector<double>& input, std::vector<std::vector<float>>*);
|
||||||
|
template bool Spectrogram::ComputeSquaredMagnitudeSpectrogram(
|
||||||
|
const std::vector<float>& input, std::vector<std::vector<double>>*);
|
||||||
|
template bool Spectrogram::ComputeSquaredMagnitudeSpectrogram(
|
||||||
|
const std::vector<double>& input, std::vector<std::vector<double>>*);
|
||||||
|
|
||||||
|
// Return true if a full window of samples is prepared; manage the queue.
|
||||||
|
template <class InputSample>
|
||||||
|
bool Spectrogram::GetNextWindowOfSamples(const std::vector<InputSample>& input,
|
||||||
|
int* input_start) {
|
||||||
|
auto input_it = input.begin() + *input_start;
|
||||||
|
int input_remaining = input.end() - input_it;
|
||||||
|
if (samples_to_next_step_ > input_remaining) {
|
||||||
|
// Copy in as many samples are left and return false, no full window.
|
||||||
|
input_queue_.insert(input_queue_.end(), input_it, input.end());
|
||||||
|
*input_start += input_remaining; // Increases it to input.size().
|
||||||
|
samples_to_next_step_ -= input_remaining;
|
||||||
|
return false; // Not enough for a full window.
|
||||||
|
} else {
|
||||||
|
// Copy just enough into queue to make a new window, then trim the
|
||||||
|
// front off the queue to make it window-sized.
|
||||||
|
input_queue_.insert(input_queue_.end(), input_it,
|
||||||
|
input_it + samples_to_next_step_);
|
||||||
|
*input_start += samples_to_next_step_;
|
||||||
|
input_queue_.erase(
|
||||||
|
input_queue_.begin(),
|
||||||
|
input_queue_.begin() + input_queue_.size() - window_length_);
|
||||||
|
DCHECK_EQ(window_length_, input_queue_.size());
|
||||||
|
samples_to_next_step_ = step_length_; // Be ready for next time.
|
||||||
|
return true; // Yes, input_queue_ now contains exactly a window-full.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Spectrogram::ProcessCoreFFT() {
|
||||||
|
for (int j = 0; j < window_length_; ++j) {
|
||||||
|
fft_input_output_[j] = input_queue_[j] * window_[j];
|
||||||
|
}
|
||||||
|
// Zero-pad the rest of the input buffer.
|
||||||
|
for (int j = window_length_; j < fft_length_; ++j) {
|
||||||
|
fft_input_output_[j] = 0.0;
|
||||||
|
}
|
||||||
|
const int kForwardFFT = 1; // 1 means forward; -1 reverse.
|
||||||
|
// This real FFT is a fair amount faster than using cdft here.
|
||||||
|
rdft(fft_length_, kForwardFFT, &fft_input_output_[0],
|
||||||
|
&fft_integer_working_area_[0], &fft_double_working_area_[0]);
|
||||||
|
// Make rdft result look like cdft result;
|
||||||
|
// unpack the last real value from the first position's imag slot.
|
||||||
|
fft_input_output_[fft_length_] = fft_input_output_[1];
|
||||||
|
fft_input_output_[fft_length_ + 1] = 0;
|
||||||
|
fft_input_output_[1] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
112
tensorflow/core/kernels/spectrogram.h
Normal file
112
tensorflow/core/kernels/spectrogram.h
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// Class for generating spectrogram slices from a waveform.
|
||||||
|
// Initialize() should be called before calls to other functions. Once
|
||||||
|
// Initialize() has been called and returned true, The Compute*() functions can
|
||||||
|
// be called repeatedly with sequential input data (ie. the first element of the
|
||||||
|
// next input vector directly follows the last element of the previous input
|
||||||
|
// vector). Whenever enough audio samples are buffered to produce a
|
||||||
|
// new frame, it will be placed in output. Output is cleared on each
|
||||||
|
// call to Compute*(). This class is thread-unsafe, and should only be
|
||||||
|
// called from one thread at a time.
|
||||||
|
// With the default parameters, the output of this class should be very
|
||||||
|
// close to the results of the following MATLAB code:
|
||||||
|
// overlap_samples = window_length_samples - step_samples;
|
||||||
|
// window = hann(window_length_samples, 'periodic');
|
||||||
|
// S = abs(spectrogram(audio, window, overlap_samples)).^2;
|
||||||
|
|
||||||
|
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SPECTROGRAM_H_
|
||||||
|
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SPECTROGRAM_H_
|
||||||
|
|
||||||
|
#include <complex>
|
||||||
|
#include <deque>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "third_party/fft2d/fft.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class Spectrogram {
|
||||||
|
public:
|
||||||
|
Spectrogram() : initialized_(false) {}
|
||||||
|
~Spectrogram() {}
|
||||||
|
|
||||||
|
// Initializes the class with a given window length and step length
|
||||||
|
// (both in samples). Internally a Hann window is used as the window
|
||||||
|
// function. Returns true on success, after which calls to Process()
|
||||||
|
// are possible. window_length must be greater than 1 and step
|
||||||
|
// length must be greater than 0.
|
||||||
|
bool Initialize(int window_length, int step_length);
|
||||||
|
|
||||||
|
// Initialize with an explicit window instead of a length.
|
||||||
|
bool Initialize(const std::vector<double>& window, int step_length);
|
||||||
|
|
||||||
|
// Processes an arbitrary amount of audio data (contained in input)
|
||||||
|
// to yield complex spectrogram frames. After a successful call to
|
||||||
|
// Initialize(), Process() may be called repeatedly with new input data
|
||||||
|
// each time. The audio input is buffered internally, and the output
|
||||||
|
// vector is populated with as many temporally-ordered spectral slices
|
||||||
|
// as it is possible to generate from the input. The output is cleared
|
||||||
|
// on each call before the new frames (if any) are added.
|
||||||
|
//
|
||||||
|
// The template parameters can be float or double.
|
||||||
|
template <class InputSample, class OutputSample>
|
||||||
|
bool ComputeComplexSpectrogram(
|
||||||
|
const std::vector<InputSample>& input,
|
||||||
|
std::vector<std::vector<std::complex<OutputSample>>>* output);
|
||||||
|
|
||||||
|
// This function works as the one above, but returns the power
|
||||||
|
// (the L2 norm, or the squared magnitude) of each complex value.
|
||||||
|
template <class InputSample, class OutputSample>
|
||||||
|
bool ComputeSquaredMagnitudeSpectrogram(
|
||||||
|
const std::vector<InputSample>& input,
|
||||||
|
std::vector<std::vector<OutputSample>>* output);
|
||||||
|
|
||||||
|
// Return reference to the window function used internally.
|
||||||
|
const std::vector<double>& GetWindow() const { return window_; }
|
||||||
|
|
||||||
|
// Return the number of frequency channels in the spectrogram.
|
||||||
|
int output_frequency_channels() const { return output_frequency_channels_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
template <class InputSample>
|
||||||
|
bool GetNextWindowOfSamples(const std::vector<InputSample>& input,
|
||||||
|
int* input_start);
|
||||||
|
void ProcessCoreFFT();
|
||||||
|
|
||||||
|
int fft_length_;
|
||||||
|
int output_frequency_channels_;
|
||||||
|
int window_length_;
|
||||||
|
int step_length_;
|
||||||
|
bool initialized_;
|
||||||
|
int samples_to_next_step_;
|
||||||
|
|
||||||
|
std::vector<double> window_;
|
||||||
|
std::vector<double> fft_input_output_;
|
||||||
|
std::deque<double> input_queue_;
|
||||||
|
|
||||||
|
// Working data areas for the FFT routines.
|
||||||
|
std::vector<int> fft_integer_working_area_;
|
||||||
|
std::vector<double> fft_double_working_area_;
|
||||||
|
|
||||||
|
TF_DISALLOW_COPY_AND_ASSIGN(Spectrogram);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SPECTROGRAM_H_
|
56
tensorflow/core/kernels/spectrogram_convert_test_data.cc
Normal file
56
tensorflow/core/kernels/spectrogram_convert_test_data.cc
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/spectrogram_test_utils.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace wav {
|
||||||
|
|
||||||
|
// This takes a CSV file representing an array of complex numbers, and saves out
|
||||||
|
// a version using a binary format to save space in the repository.
|
||||||
|
Status ConvertCsvToRaw(const string& input_filename) {
|
||||||
|
std::vector<std::vector<std::complex<double>>> input_data;
|
||||||
|
ReadCSVFileToComplexVectorOrDie(input_filename, &input_data);
|
||||||
|
const string output_filename = input_filename + ".bin";
|
||||||
|
if (!WriteComplexVectorToRawFloatFile(output_filename, input_data)) {
|
||||||
|
return errors::InvalidArgument("Failed to write raw float file ",
|
||||||
|
input_filename);
|
||||||
|
}
|
||||||
|
LOG(INFO) << "Wrote raw file to " << output_filename;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace wav
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
int main(int argc, char* argv[]) {
|
||||||
|
tensorflow::port::InitMain(argv[0], &argc, &argv);
|
||||||
|
if (argc < 2) {
|
||||||
|
LOG(ERROR) << "You must supply a CSV file as the first argument";
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
tensorflow::string filename(argv[1]);
|
||||||
|
tensorflow::Status status = tensorflow::wav::ConvertCsvToRaw(filename);
|
||||||
|
if (!status.ok()) {
|
||||||
|
LOG(ERROR) << "Error processing '" << filename << "':" << status;
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
120
tensorflow/core/kernels/spectrogram_op.cc
Normal file
120
tensorflow/core/kernels/spectrogram_op.cc
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// See docs in ../ops/audio_ops.cc
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/framework/types.h"
|
||||||
|
#include "tensorflow/core/kernels/spectrogram.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Create a spectrogram frequency visualization from audio data.
|
||||||
|
class SpectrogramOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit SpectrogramOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("window_size", &window_size_));
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("stride", &stride_));
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
context->GetAttr("magnitude_squared", &magnitude_squared_));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* context) override {
|
||||||
|
const Tensor& input = context->input(0);
|
||||||
|
OP_REQUIRES(context, input.dims() == 2,
|
||||||
|
errors::InvalidArgument("input must be 2-dimensional",
|
||||||
|
input.shape().DebugString()));
|
||||||
|
Spectrogram spectrogram;
|
||||||
|
OP_REQUIRES(context, spectrogram.Initialize(window_size_, stride_),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Spectrogram initialization failed for window size ",
|
||||||
|
window_size_, " and stride ", stride_));
|
||||||
|
|
||||||
|
const auto input_as_matrix = input.matrix<float>();
|
||||||
|
|
||||||
|
const int64 sample_count = input.dim_size(0);
|
||||||
|
const int64 channel_count = input.dim_size(1);
|
||||||
|
|
||||||
|
const int64 output_width = spectrogram.output_frequency_channels();
|
||||||
|
const int64 length_minus_window = (sample_count - window_size_);
|
||||||
|
int64 output_height;
|
||||||
|
if (length_minus_window < 0) {
|
||||||
|
output_height = 0;
|
||||||
|
} else {
|
||||||
|
output_height = 1 + (length_minus_window / stride_);
|
||||||
|
}
|
||||||
|
const int64 output_slices = channel_count;
|
||||||
|
|
||||||
|
Tensor* output_tensor = nullptr;
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
context,
|
||||||
|
context->allocate_output(
|
||||||
|
0, TensorShape({output_slices, output_height, output_width}),
|
||||||
|
&output_tensor));
|
||||||
|
auto output_flat = output_tensor->flat<float>().data();
|
||||||
|
|
||||||
|
std::vector<float> input_for_channel(sample_count);
|
||||||
|
for (int64 channel = 0; channel < channel_count; ++channel) {
|
||||||
|
float* output_slice =
|
||||||
|
output_flat + (channel * output_height * output_width);
|
||||||
|
for (int i = 0; i < sample_count; ++i) {
|
||||||
|
input_for_channel[i] = input_as_matrix(i, channel);
|
||||||
|
}
|
||||||
|
std::vector<std::vector<float>> spectrogram_output;
|
||||||
|
OP_REQUIRES(context,
|
||||||
|
spectrogram.ComputeSquaredMagnitudeSpectrogram(
|
||||||
|
input_for_channel, &spectrogram_output),
|
||||||
|
errors::InvalidArgument("Spectrogram compute failed"));
|
||||||
|
OP_REQUIRES(context, (spectrogram_output.size() == output_height),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Spectrogram size calculation failed: Expected height ",
|
||||||
|
output_height, " but got ", spectrogram_output.size()));
|
||||||
|
OP_REQUIRES(context,
|
||||||
|
spectrogram_output.empty() ||
|
||||||
|
(spectrogram_output[0].size() == output_width),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Spectrogram size calculation failed: Expected width ",
|
||||||
|
output_width, " but got ", spectrogram_output[0].size()));
|
||||||
|
for (int row_index = 0; row_index < output_height; ++row_index) {
|
||||||
|
const std::vector<float>& spectrogram_row =
|
||||||
|
spectrogram_output[row_index];
|
||||||
|
DCHECK_EQ(spectrogram_row.size(), output_width);
|
||||||
|
float* output_row = output_slice + (row_index * output_width);
|
||||||
|
if (magnitude_squared_) {
|
||||||
|
for (int i = 0; i < output_width; ++i) {
|
||||||
|
output_row[i] = spectrogram_row[i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < output_width; ++i) {
|
||||||
|
output_row[i] = sqrtf(spectrogram_row[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int32 window_size_;
|
||||||
|
int32 stride_;
|
||||||
|
bool magnitude_squared_;
|
||||||
|
};
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("AudioSpectrogram").Device(DEVICE_CPU),
|
||||||
|
SpectrogramOp);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
104
tensorflow/core/kernels/spectrogram_op_test.cc
Normal file
104
tensorflow/core/kernels/spectrogram_op_test.cc
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/cc/client/client_session.h"
|
||||||
|
#include "tensorflow/cc/ops/audio_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/const_op.h"
|
||||||
|
#include "tensorflow/cc/ops/math_ops.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/framework/types.h"
|
||||||
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
|
#include "tensorflow/core/kernels/ops_util.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
using namespace ops; // NOLINT(build/namespaces)
|
||||||
|
|
||||||
|
TEST(SpectrogramOpTest, SimpleTest) {
|
||||||
|
Scope root = Scope::NewRootScope();
|
||||||
|
|
||||||
|
Tensor audio_tensor(DT_FLOAT, TensorShape({8, 1}));
|
||||||
|
test::FillValues<float>(&audio_tensor,
|
||||||
|
{-1.0f, 0.0f, 1.0f, 0.0f, -1.0f, 0.0f, 1.0f, 0.0f});
|
||||||
|
|
||||||
|
Output audio_const_op = Const(root.WithOpName("audio_const_op"),
|
||||||
|
Input::Initializer(audio_tensor));
|
||||||
|
|
||||||
|
AudioSpectrogram spectrogram_op =
|
||||||
|
AudioSpectrogram(root.WithOpName("spectrogram_op"), audio_const_op, 8, 1);
|
||||||
|
|
||||||
|
TF_ASSERT_OK(root.status());
|
||||||
|
|
||||||
|
ClientSession session(root);
|
||||||
|
std::vector<Tensor> outputs;
|
||||||
|
|
||||||
|
TF_EXPECT_OK(session.Run(ClientSession::FeedType(),
|
||||||
|
{spectrogram_op.spectrogram}, &outputs));
|
||||||
|
|
||||||
|
const Tensor& spectrogram_tensor = outputs[0];
|
||||||
|
|
||||||
|
EXPECT_EQ(3, spectrogram_tensor.dims());
|
||||||
|
EXPECT_EQ(5, spectrogram_tensor.dim_size(2));
|
||||||
|
EXPECT_EQ(1, spectrogram_tensor.dim_size(1));
|
||||||
|
EXPECT_EQ(1, spectrogram_tensor.dim_size(0));
|
||||||
|
|
||||||
|
test::ExpectTensorNear<float>(
|
||||||
|
spectrogram_tensor,
|
||||||
|
test::AsTensor<float>({0, 1, 2, 1, 0}, TensorShape({1, 1, 5})), 1e-3);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SpectrogramOpTest, SquaredTest) {
|
||||||
|
Scope root = Scope::NewRootScope();
|
||||||
|
|
||||||
|
Tensor audio_tensor(DT_FLOAT, TensorShape({8, 1}));
|
||||||
|
test::FillValues<float>(&audio_tensor,
|
||||||
|
{-1.0f, 0.0f, 1.0f, 0.0f, -1.0f, 0.0f, 1.0f, 0.0f});
|
||||||
|
|
||||||
|
Output audio_const_op = Const(root.WithOpName("audio_const_op"),
|
||||||
|
Input::Initializer(audio_tensor));
|
||||||
|
|
||||||
|
AudioSpectrogram spectrogram_op =
|
||||||
|
AudioSpectrogram(root.WithOpName("spectrogram_op"), audio_const_op, 8, 1,
|
||||||
|
AudioSpectrogram::Attrs().MagnitudeSquared(true));
|
||||||
|
|
||||||
|
TF_ASSERT_OK(root.status());
|
||||||
|
|
||||||
|
ClientSession session(root);
|
||||||
|
std::vector<Tensor> outputs;
|
||||||
|
|
||||||
|
TF_EXPECT_OK(session.Run(ClientSession::FeedType(),
|
||||||
|
{spectrogram_op.spectrogram}, &outputs));
|
||||||
|
|
||||||
|
const Tensor& spectrogram_tensor = outputs[0];
|
||||||
|
|
||||||
|
EXPECT_EQ(3, spectrogram_tensor.dims());
|
||||||
|
EXPECT_EQ(5, spectrogram_tensor.dim_size(2));
|
||||||
|
EXPECT_EQ(1, spectrogram_tensor.dim_size(1));
|
||||||
|
EXPECT_EQ(1, spectrogram_tensor.dim_size(0));
|
||||||
|
|
||||||
|
test::ExpectTensorNear<float>(
|
||||||
|
spectrogram_tensor,
|
||||||
|
test::AsTensor<float>({0, 1, 4, 1, 0}, TensorShape({1, 1, 5})), 1e-3);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
340
tensorflow/core/kernels/spectrogram_test.cc
Normal file
340
tensorflow/core/kernels/spectrogram_test.cc
Normal file
@ -0,0 +1,340 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// The MATLAB test data were generated using GenerateTestData.m.
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/spectrogram.h"
|
||||||
|
|
||||||
|
#include <complex>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/spectrogram_test_utils.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
using ::std::complex;
|
||||||
|
|
||||||
|
const char kInputFilename[] =
|
||||||
|
"core/kernels/spectrogram_test_data/short_test_segment.wav";
|
||||||
|
|
||||||
|
const char kExpectedFilename[] =
|
||||||
|
"core/kernels/spectrogram_test_data/short_test_segment_spectrogram.csv.bin";
|
||||||
|
const int kDataVectorLength = 257;
|
||||||
|
const int kNumberOfFramesInTestData = 178;
|
||||||
|
|
||||||
|
const char kExpectedNonPowerOfTwoFilename[] =
|
||||||
|
"core/kernels/spectrogram_test_data/"
|
||||||
|
"short_test_segment_spectrogram_400_200.csv.bin";
|
||||||
|
const int kNonPowerOfTwoDataVectorLength = 257;
|
||||||
|
const int kNumberOfFramesInNonPowerOfTwoTestData = 228;
|
||||||
|
|
||||||
|
TEST(SpectrogramTest, TooLittleDataYieldsNoFrames) {
|
||||||
|
Spectrogram sgram;
|
||||||
|
sgram.Initialize(400, 200);
|
||||||
|
std::vector<double> input;
|
||||||
|
// Generate 44 samples of audio.
|
||||||
|
SineWave(44100, 1000.0, 0.001, &input);
|
||||||
|
EXPECT_EQ(44, input.size());
|
||||||
|
std::vector<std::vector<complex<double>>> output;
|
||||||
|
sgram.ComputeComplexSpectrogram(input, &output);
|
||||||
|
EXPECT_EQ(0, output.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SpectrogramTest, StepSizeSmallerThanWindow) {
|
||||||
|
Spectrogram sgram;
|
||||||
|
EXPECT_TRUE(sgram.Initialize(400, 200));
|
||||||
|
std::vector<double> input;
|
||||||
|
// Generate 661 samples of audio.
|
||||||
|
SineWave(44100, 1000.0, 0.015, &input);
|
||||||
|
EXPECT_EQ(661, input.size());
|
||||||
|
std::vector<std::vector<complex<double>>> output;
|
||||||
|
sgram.ComputeComplexSpectrogram(input, &output);
|
||||||
|
EXPECT_EQ(2, output.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SpectrogramTest, StepSizeBiggerThanWindow) {
|
||||||
|
Spectrogram sgram;
|
||||||
|
EXPECT_TRUE(sgram.Initialize(200, 400));
|
||||||
|
std::vector<double> input;
|
||||||
|
// Generate 882 samples of audio.
|
||||||
|
SineWave(44100, 1000.0, 0.02, &input);
|
||||||
|
EXPECT_EQ(882, input.size());
|
||||||
|
std::vector<std::vector<complex<double>>> output;
|
||||||
|
sgram.ComputeComplexSpectrogram(input, &output);
|
||||||
|
EXPECT_EQ(2, output.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SpectrogramTest, StepSizeBiggerThanWindow2) {
|
||||||
|
Spectrogram sgram;
|
||||||
|
EXPECT_TRUE(sgram.Initialize(200, 400));
|
||||||
|
std::vector<double> input;
|
||||||
|
// Generate more than 600 but fewer than 800 samples of audio.
|
||||||
|
SineWave(44100, 1000.0, 0.016, &input);
|
||||||
|
EXPECT_GT(input.size(), 600);
|
||||||
|
EXPECT_LT(input.size(), 800);
|
||||||
|
std::vector<std::vector<complex<double>>> output;
|
||||||
|
sgram.ComputeComplexSpectrogram(input, &output);
|
||||||
|
EXPECT_EQ(2, output.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SpectrogramTest,
|
||||||
|
MultipleCallsToComputeComplexSpectrogramMayYieldDifferentNumbersOfFrames) {
|
||||||
|
// Repeatedly pass inputs with "extra" samples beyond complete windows
|
||||||
|
// and check that the excess points cumulate to eventually cause an
|
||||||
|
// extra output frame.
|
||||||
|
Spectrogram sgram;
|
||||||
|
sgram.Initialize(200, 400);
|
||||||
|
std::vector<double> input;
|
||||||
|
// Generate 882 samples of audio.
|
||||||
|
SineWave(44100, 1000.0, 0.02, &input);
|
||||||
|
EXPECT_EQ(882, input.size());
|
||||||
|
std::vector<std::vector<complex<double>>> output;
|
||||||
|
const std::vector<int> expected_output_sizes = {
|
||||||
|
2, // One pass of input leaves 82 samples buffered after two steps of
|
||||||
|
// 400.
|
||||||
|
2, // Passing in 882 samples again will now leave 164 samples buffered.
|
||||||
|
3, // Third time gives 246 extra samples, triggering an extra output
|
||||||
|
// frame.
|
||||||
|
};
|
||||||
|
for (int expected_output_size : expected_output_sizes) {
|
||||||
|
sgram.ComputeComplexSpectrogram(input, &output);
|
||||||
|
EXPECT_EQ(expected_output_size, output.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SpectrogramTest, CumulatingExcessInputsForOverlappingFrames) {
|
||||||
|
// Input frames that don't fit into whole windows are cumulated even when
|
||||||
|
// the windows have overlap (similar to
|
||||||
|
// MultipleCallsToComputeComplexSpectrogramMayYieldDifferentNumbersOfFrames
|
||||||
|
// but with window size/hop size swapped).
|
||||||
|
Spectrogram sgram;
|
||||||
|
sgram.Initialize(400, 200);
|
||||||
|
std::vector<double> input;
|
||||||
|
// Generate 882 samples of audio.
|
||||||
|
SineWave(44100, 1000.0, 0.02, &input);
|
||||||
|
EXPECT_EQ(882, input.size());
|
||||||
|
std::vector<std::vector<complex<double>>> output;
|
||||||
|
const std::vector<int> expected_output_sizes = {
|
||||||
|
3, // Windows 0..400, 200..600, 400..800 with 82 samples buffered.
|
||||||
|
4, // 1764 frames input; outputs from 600, 800, 1000, 1200..1600.
|
||||||
|
5, // 2646 frames in; outputs from 1400, 1600, 1800, 2000, 2200..2600.
|
||||||
|
};
|
||||||
|
for (int expected_output_size : expected_output_sizes) {
|
||||||
|
sgram.ComputeComplexSpectrogram(input, &output);
|
||||||
|
EXPECT_EQ(expected_output_size, output.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SpectrogramTest, StepSizeEqualToWindowWorks) {
|
||||||
|
Spectrogram sgram;
|
||||||
|
sgram.Initialize(200, 200);
|
||||||
|
std::vector<double> input;
|
||||||
|
// Generate 2205 samples of audio.
|
||||||
|
SineWave(44100, 1000.0, 0.05, &input);
|
||||||
|
EXPECT_EQ(2205, input.size());
|
||||||
|
std::vector<std::vector<complex<double>>> output;
|
||||||
|
sgram.ComputeComplexSpectrogram(input, &output);
|
||||||
|
EXPECT_EQ(11, output.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ExpectedSample, class ActualSample>
|
||||||
|
void CompareComplexData(
|
||||||
|
const std::vector<std::vector<complex<ExpectedSample>>>& expected,
|
||||||
|
const std::vector<std::vector<complex<ActualSample>>>& actual,
|
||||||
|
double tolerance) {
|
||||||
|
ASSERT_EQ(actual.size(), expected.size());
|
||||||
|
for (int i = 0; i < expected.size(); ++i) {
|
||||||
|
ASSERT_EQ(expected[i].size(), actual[i].size());
|
||||||
|
for (int j = 0; j < expected[i].size(); ++j) {
|
||||||
|
ASSERT_NEAR(real(expected[i][j]), real(actual[i][j]), tolerance)
|
||||||
|
<< ": where i=" << i << " and j=" << j << ".";
|
||||||
|
ASSERT_NEAR(imag(expected[i][j]), imag(actual[i][j]), tolerance)
|
||||||
|
<< ": where i=" << i << " and j=" << j << ".";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class Sample>
|
||||||
|
double GetMaximumAbsolute(const std::vector<std::vector<Sample>>& spectrogram) {
|
||||||
|
double max_absolute = 0.0;
|
||||||
|
for (int i = 0; i < spectrogram.size(); ++i) {
|
||||||
|
for (int j = 0; j < spectrogram[i].size(); ++j) {
|
||||||
|
double absolute_value = std::abs(spectrogram[i][j]);
|
||||||
|
if (absolute_value > max_absolute) {
|
||||||
|
max_absolute = absolute_value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return max_absolute;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ExpectedSample, class ActualSample>
|
||||||
|
void CompareMagnitudeData(
|
||||||
|
const std::vector<std::vector<complex<ExpectedSample>>>&
|
||||||
|
expected_complex_output,
|
||||||
|
const std::vector<std::vector<ActualSample>>& actual_squared_magnitude,
|
||||||
|
double tolerance) {
|
||||||
|
ASSERT_EQ(actual_squared_magnitude.size(), expected_complex_output.size());
|
||||||
|
for (int i = 0; i < expected_complex_output.size(); ++i) {
|
||||||
|
ASSERT_EQ(expected_complex_output[i].size(),
|
||||||
|
actual_squared_magnitude[i].size());
|
||||||
|
for (int j = 0; j < expected_complex_output[i].size(); ++j) {
|
||||||
|
ASSERT_NEAR(norm(expected_complex_output[i][j]),
|
||||||
|
actual_squared_magnitude[i][j], tolerance)
|
||||||
|
<< ": where i=" << i << " and j=" << j << ".";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SpectrogramTest, ReInitializationWorks) {
|
||||||
|
Spectrogram sgram;
|
||||||
|
sgram.Initialize(512, 256);
|
||||||
|
std::vector<double> input;
|
||||||
|
CHECK(ReadWaveFileToVector(
|
||||||
|
tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), kInputFilename),
|
||||||
|
&input));
|
||||||
|
std::vector<std::vector<complex<double>>> first_output;
|
||||||
|
std::vector<std::vector<complex<double>>> second_output;
|
||||||
|
sgram.Initialize(512, 256);
|
||||||
|
sgram.ComputeComplexSpectrogram(input, &first_output);
|
||||||
|
// Re-Initialize it.
|
||||||
|
sgram.Initialize(512, 256);
|
||||||
|
sgram.ComputeComplexSpectrogram(input, &second_output);
|
||||||
|
// Verify identical outputs.
|
||||||
|
ASSERT_EQ(first_output.size(), second_output.size());
|
||||||
|
int slice_size = first_output[0].size();
|
||||||
|
for (int i = 0; i < first_output.size(); ++i) {
|
||||||
|
ASSERT_EQ(slice_size, first_output[i].size());
|
||||||
|
ASSERT_EQ(slice_size, second_output[i].size());
|
||||||
|
for (int j = 0; j < slice_size; ++j) {
|
||||||
|
ASSERT_EQ(first_output[i][j], second_output[i][j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SpectrogramTest, ComputedComplexDataAgreeWithMatlab) {
|
||||||
|
const int kInputDataLength = 45870;
|
||||||
|
Spectrogram sgram;
|
||||||
|
sgram.Initialize(512, 256);
|
||||||
|
std::vector<double> input;
|
||||||
|
CHECK(ReadWaveFileToVector(
|
||||||
|
tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), kInputFilename),
|
||||||
|
&input));
|
||||||
|
EXPECT_EQ(kInputDataLength, input.size());
|
||||||
|
std::vector<std::vector<complex<double>>> expected_output;
|
||||||
|
ASSERT_TRUE(ReadRawFloatFileToComplexVector(
|
||||||
|
tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), kExpectedFilename),
|
||||||
|
kDataVectorLength, &expected_output));
|
||||||
|
EXPECT_EQ(kNumberOfFramesInTestData, expected_output.size());
|
||||||
|
EXPECT_EQ(kDataVectorLength, expected_output[0].size());
|
||||||
|
std::vector<std::vector<complex<double>>> output;
|
||||||
|
sgram.ComputeComplexSpectrogram(input, &output);
|
||||||
|
CompareComplexData(expected_output, output, 1e-5);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SpectrogramTest, ComputedFloatComplexDataAgreeWithMatlab) {
|
||||||
|
const int kInputDataLength = 45870;
|
||||||
|
Spectrogram sgram;
|
||||||
|
sgram.Initialize(512, 256);
|
||||||
|
std::vector<double> double_input;
|
||||||
|
CHECK(ReadWaveFileToVector(
|
||||||
|
tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), kInputFilename),
|
||||||
|
&double_input));
|
||||||
|
std::vector<float> input;
|
||||||
|
input.assign(double_input.begin(), double_input.end());
|
||||||
|
EXPECT_EQ(kInputDataLength, input.size());
|
||||||
|
std::vector<std::vector<complex<double>>> expected_output;
|
||||||
|
ASSERT_TRUE(ReadRawFloatFileToComplexVector(
|
||||||
|
tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), kExpectedFilename),
|
||||||
|
kDataVectorLength, &expected_output));
|
||||||
|
EXPECT_EQ(kNumberOfFramesInTestData, expected_output.size());
|
||||||
|
EXPECT_EQ(kDataVectorLength, expected_output[0].size());
|
||||||
|
std::vector<std::vector<complex<float>>> output;
|
||||||
|
sgram.ComputeComplexSpectrogram(input, &output);
|
||||||
|
CompareComplexData(expected_output, output, 1e-4);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SpectrogramTest, ComputedSquaredMagnitudeDataAgreeWithMatlab) {
|
||||||
|
const int kInputDataLength = 45870;
|
||||||
|
Spectrogram sgram;
|
||||||
|
sgram.Initialize(512, 256);
|
||||||
|
std::vector<double> input;
|
||||||
|
CHECK(ReadWaveFileToVector(
|
||||||
|
tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), kInputFilename),
|
||||||
|
&input));
|
||||||
|
EXPECT_EQ(kInputDataLength, input.size());
|
||||||
|
std::vector<std::vector<complex<double>>> expected_output;
|
||||||
|
ASSERT_TRUE(ReadRawFloatFileToComplexVector(
|
||||||
|
tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), kExpectedFilename),
|
||||||
|
kDataVectorLength, &expected_output));
|
||||||
|
EXPECT_EQ(kNumberOfFramesInTestData, expected_output.size());
|
||||||
|
EXPECT_EQ(kDataVectorLength, expected_output[0].size());
|
||||||
|
std::vector<std::vector<double>> output;
|
||||||
|
sgram.ComputeSquaredMagnitudeSpectrogram(input, &output);
|
||||||
|
CompareMagnitudeData(expected_output, output, 1e-3);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SpectrogramTest, ComputedFloatSquaredMagnitudeDataAgreeWithMatlab) {
|
||||||
|
const int kInputDataLength = 45870;
|
||||||
|
Spectrogram sgram;
|
||||||
|
sgram.Initialize(512, 256);
|
||||||
|
std::vector<double> double_input;
|
||||||
|
CHECK(ReadWaveFileToVector(
|
||||||
|
tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), kInputFilename),
|
||||||
|
&double_input));
|
||||||
|
EXPECT_EQ(kInputDataLength, double_input.size());
|
||||||
|
std::vector<float> input;
|
||||||
|
input.assign(double_input.begin(), double_input.end());
|
||||||
|
std::vector<std::vector<complex<double>>> expected_output;
|
||||||
|
ASSERT_TRUE(ReadRawFloatFileToComplexVector(
|
||||||
|
tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), kExpectedFilename),
|
||||||
|
kDataVectorLength, &expected_output));
|
||||||
|
EXPECT_EQ(kNumberOfFramesInTestData, expected_output.size());
|
||||||
|
EXPECT_EQ(kDataVectorLength, expected_output[0].size());
|
||||||
|
std::vector<std::vector<float>> output;
|
||||||
|
sgram.ComputeSquaredMagnitudeSpectrogram(input, &output);
|
||||||
|
double max_absolute = GetMaximumAbsolute(output);
|
||||||
|
EXPECT_GT(max_absolute, 2300.0); // Verify that we have some big numbers.
|
||||||
|
// Squaring increases dynamic range; max square is about 2300,
|
||||||
|
// so 2e-4 is about 7 decimal digits; not bad for a float.
|
||||||
|
CompareMagnitudeData(expected_output, output, 2e-4);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SpectrogramTest, ComputedNonPowerOfTwoComplexDataAgreeWithMatlab) {
|
||||||
|
const int kInputDataLength = 45870;
|
||||||
|
Spectrogram sgram;
|
||||||
|
sgram.Initialize(400, 200);
|
||||||
|
std::vector<double> input;
|
||||||
|
CHECK(ReadWaveFileToVector(
|
||||||
|
tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), kInputFilename),
|
||||||
|
&input));
|
||||||
|
EXPECT_EQ(kInputDataLength, input.size());
|
||||||
|
std::vector<std::vector<complex<double>>> expected_output;
|
||||||
|
ASSERT_TRUE(ReadRawFloatFileToComplexVector(
|
||||||
|
tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(),
|
||||||
|
kExpectedNonPowerOfTwoFilename),
|
||||||
|
kNonPowerOfTwoDataVectorLength, &expected_output));
|
||||||
|
EXPECT_EQ(kNumberOfFramesInNonPowerOfTwoTestData, expected_output.size());
|
||||||
|
EXPECT_EQ(kNonPowerOfTwoDataVectorLength, expected_output[0].size());
|
||||||
|
std::vector<std::vector<complex<double>>> output;
|
||||||
|
sgram.ComputeComplexSpectrogram(input, &output);
|
||||||
|
CompareComplexData(expected_output, output, 1e-5);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
8
tensorflow/core/kernels/spectrogram_test_data/README
Normal file
8
tensorflow/core/kernels/spectrogram_test_data/README
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
The CSV spectrogram files in this directory are generated from the
|
||||||
|
matlab code in ./matlab/GenerateTestData.m
|
||||||
|
To save space in the repo, you'll then need to convert them into a binary packed
|
||||||
|
format using the convert_test_data.cc command line tool.
|
||||||
|
|
||||||
|
|
||||||
|
short_test_segment.wav is approximately 1s of music audio.
|
||||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
288
tensorflow/core/kernels/spectrogram_test_utils.cc
Normal file
288
tensorflow/core/kernels/spectrogram_test_utils.cc
Normal file
@ -0,0 +1,288 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/spectrogram_test_utils.h"
|
||||||
|
|
||||||
|
#include <math.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
|
||||||
|
#include "tensorflow/core/lib/core/error_codes.pb.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
|
#include "tensorflow/core/lib/wav/wav_io.h"
|
||||||
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
bool ReadWaveFileToVector(const string& file_name, std::vector<double>* data) {
|
||||||
|
string wav_data;
|
||||||
|
if (!ReadFileToString(Env::Default(), file_name, &wav_data).ok()) {
|
||||||
|
LOG(ERROR) << "Wave file read failed for " << file_name;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
std::vector<float> decoded_data;
|
||||||
|
uint32 decoded_sample_count;
|
||||||
|
uint16 decoded_channel_count;
|
||||||
|
uint32 decoded_sample_rate;
|
||||||
|
if (!wav::DecodeLin16WaveAsFloatVector(
|
||||||
|
wav_data, &decoded_data, &decoded_sample_count,
|
||||||
|
&decoded_channel_count, &decoded_sample_rate)
|
||||||
|
.ok()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// Convert from float to double for the output value.
|
||||||
|
data->resize(decoded_data.size());
|
||||||
|
for (int i = 0; i < decoded_data.size(); ++i) {
|
||||||
|
(*data)[i] = decoded_data[i];
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ReadRawFloatFileToComplexVector(
|
||||||
|
const string& file_name, int row_length,
|
||||||
|
std::vector<std::vector<std::complex<double> > >* data) {
|
||||||
|
data->clear();
|
||||||
|
string data_string;
|
||||||
|
if (!ReadFileToString(Env::Default(), file_name, &data_string).ok()) {
|
||||||
|
LOG(ERROR) << "Failed to open file " << file_name;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
float real_out;
|
||||||
|
float imag_out;
|
||||||
|
const int kBytesPerValue = 4;
|
||||||
|
CHECK_EQ(sizeof(real_out), kBytesPerValue);
|
||||||
|
std::vector<std::complex<double> > data_row;
|
||||||
|
int row_counter = 0;
|
||||||
|
int offset = 0;
|
||||||
|
const int end = data_string.size();
|
||||||
|
while (offset < end) {
|
||||||
|
memcpy(&real_out, data_string.data() + offset, kBytesPerValue);
|
||||||
|
offset += kBytesPerValue;
|
||||||
|
memcpy(&imag_out, data_string.data() + offset, kBytesPerValue);
|
||||||
|
offset += kBytesPerValue;
|
||||||
|
if (row_counter >= row_length) {
|
||||||
|
data->push_back(data_row);
|
||||||
|
data_row.clear();
|
||||||
|
row_counter = 0;
|
||||||
|
}
|
||||||
|
data_row.push_back(std::complex<double>(real_out, imag_out));
|
||||||
|
++row_counter;
|
||||||
|
}
|
||||||
|
if (row_counter >= row_length) {
|
||||||
|
data->push_back(data_row);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ReadCSVFileToComplexVectorOrDie(
|
||||||
|
const string& file_name,
|
||||||
|
std::vector<std::vector<std::complex<double> > >* data) {
|
||||||
|
data->clear();
|
||||||
|
string data_string;
|
||||||
|
if (!ReadFileToString(Env::Default(), file_name, &data_string).ok()) {
|
||||||
|
LOG(FATAL) << "Failed to open file " << file_name;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
std::vector<string> lines = str_util::Split(data_string, '\n');
|
||||||
|
for (const string& line : lines) {
|
||||||
|
if (line == "") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
std::vector<std::complex<double> > data_line;
|
||||||
|
std::vector<string> values = str_util::Split(line, ',');
|
||||||
|
for (std::vector<string>::const_iterator i = values.begin();
|
||||||
|
i != values.end(); ++i) {
|
||||||
|
// each element of values may be in the form:
|
||||||
|
// 0.001+0.002i, 0.001, 0.001i, -1.2i, -1.2-3.2i, 1.5, 1.5e-03+21.0i
|
||||||
|
std::vector<string> parts;
|
||||||
|
// Find the first instance of + or - after the second character
|
||||||
|
// in the string, that does not immediately follow an 'e'.
|
||||||
|
size_t operator_index = i->find_first_of("+-", 2);
|
||||||
|
if (operator_index < i->size() &&
|
||||||
|
i->substr(operator_index - 1, 1) == "e") {
|
||||||
|
operator_index = i->find_first_of("+-", operator_index + 1);
|
||||||
|
}
|
||||||
|
parts.push_back(i->substr(0, operator_index));
|
||||||
|
if (operator_index < i->size()) {
|
||||||
|
parts.push_back(i->substr(operator_index, string::npos));
|
||||||
|
}
|
||||||
|
|
||||||
|
double real_part = 0.0;
|
||||||
|
double imaginary_part = 0.0;
|
||||||
|
for (std::vector<string>::const_iterator j = parts.begin();
|
||||||
|
j != parts.end(); ++j) {
|
||||||
|
if (j->find_first_of("ij") != string::npos) {
|
||||||
|
strings::safe_strtod((*j).c_str(), &imaginary_part);
|
||||||
|
} else {
|
||||||
|
strings::safe_strtod((*j).c_str(), &real_part);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
data_line.push_back(std::complex<double>(real_part, imaginary_part));
|
||||||
|
}
|
||||||
|
data->push_back(data_line);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ReadCSVFileToArrayOrDie(const string& filename,
|
||||||
|
std::vector<std::vector<float> >* array) {
|
||||||
|
string contents;
|
||||||
|
TF_CHECK_OK(ReadFileToString(Env::Default(), filename, &contents));
|
||||||
|
std::vector<string> lines = str_util::Split(contents, '\n');
|
||||||
|
contents.clear();
|
||||||
|
|
||||||
|
array->clear();
|
||||||
|
std::vector<float> values;
|
||||||
|
for (int l = 0; l < lines.size(); ++l) {
|
||||||
|
values.clear();
|
||||||
|
CHECK(str_util::SplitAndParseAsFloats(lines[l], ',', &values));
|
||||||
|
array->push_back(values);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool WriteDoubleVectorToFile(const string& file_name,
|
||||||
|
const std::vector<double>& data) {
|
||||||
|
std::unique_ptr<WritableFile> file;
|
||||||
|
if (!Env::Default()->NewWritableFile(file_name, &file).ok()) {
|
||||||
|
LOG(ERROR) << "Failed to open file " << file_name;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < data.size(); ++i) {
|
||||||
|
if (!file->Append(StringPiece(reinterpret_cast<const char*>(&(data[i])),
|
||||||
|
sizeof(data[i])))
|
||||||
|
.ok()) {
|
||||||
|
LOG(ERROR) << "Failed to append to file " << file_name;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!file->Close().ok()) {
|
||||||
|
LOG(ERROR) << "Failed to close file " << file_name;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool WriteFloatVectorToFile(const string& file_name,
|
||||||
|
const std::vector<float>& data) {
|
||||||
|
std::unique_ptr<WritableFile> file;
|
||||||
|
if (!Env::Default()->NewWritableFile(file_name, &file).ok()) {
|
||||||
|
LOG(ERROR) << "Failed to open file " << file_name;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < data.size(); ++i) {
|
||||||
|
if (!file->Append(StringPiece(reinterpret_cast<const char*>(&(data[i])),
|
||||||
|
sizeof(data[i])))
|
||||||
|
.ok()) {
|
||||||
|
LOG(ERROR) << "Failed to append to file " << file_name;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!file->Close().ok()) {
|
||||||
|
LOG(ERROR) << "Failed to close file " << file_name;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool WriteDoubleArrayToFile(const string& file_name, int size,
|
||||||
|
const double* data) {
|
||||||
|
std::unique_ptr<WritableFile> file;
|
||||||
|
if (!Env::Default()->NewWritableFile(file_name, &file).ok()) {
|
||||||
|
LOG(ERROR) << "Failed to open file " << file_name;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < size; ++i) {
|
||||||
|
if (!file->Append(StringPiece(reinterpret_cast<const char*>(&(data[i])),
|
||||||
|
sizeof(data[i])))
|
||||||
|
.ok()) {
|
||||||
|
LOG(ERROR) << "Failed to append to file " << file_name;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!file->Close().ok()) {
|
||||||
|
LOG(ERROR) << "Failed to close file " << file_name;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool WriteFloatArrayToFile(const string& file_name, int size,
|
||||||
|
const float* data) {
|
||||||
|
std::unique_ptr<WritableFile> file;
|
||||||
|
if (!Env::Default()->NewWritableFile(file_name, &file).ok()) {
|
||||||
|
LOG(ERROR) << "Failed to open file " << file_name;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < size; ++i) {
|
||||||
|
if (!file->Append(StringPiece(reinterpret_cast<const char*>(&(data[i])),
|
||||||
|
sizeof(data[i])))
|
||||||
|
.ok()) {
|
||||||
|
LOG(ERROR) << "Failed to append to file " << file_name;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!file->Close().ok()) {
|
||||||
|
LOG(ERROR) << "Failed to close file " << file_name;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool WriteComplexVectorToRawFloatFile(
|
||||||
|
const string& file_name,
|
||||||
|
const std::vector<std::vector<std::complex<double> > >& data) {
|
||||||
|
std::unique_ptr<WritableFile> file;
|
||||||
|
if (!Env::Default()->NewWritableFile(file_name, &file).ok()) {
|
||||||
|
LOG(ERROR) << "Failed to open file " << file_name;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < data.size(); ++i) {
|
||||||
|
for (int j = 0; j < data[i].size(); ++j) {
|
||||||
|
const float real_part(real(data[i][j]));
|
||||||
|
if (!file->Append(StringPiece(reinterpret_cast<const char*>(&real_part),
|
||||||
|
sizeof(real_part)))
|
||||||
|
.ok()) {
|
||||||
|
LOG(ERROR) << "Failed to append to file " << file_name;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float imag_part(imag(data[i][j]));
|
||||||
|
if (!file->Append(StringPiece(reinterpret_cast<const char*>(&imag_part),
|
||||||
|
sizeof(imag_part)))
|
||||||
|
.ok()) {
|
||||||
|
LOG(ERROR) << "Failed to append to file " << file_name;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!file->Close().ok()) {
|
||||||
|
LOG(ERROR) << "Failed to close file " << file_name;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SineWave(int sample_rate, float frequency, float duration_seconds,
|
||||||
|
std::vector<double>* data) {
|
||||||
|
data->clear();
|
||||||
|
for (int i = 0; i < static_cast<int>(sample_rate * duration_seconds); ++i) {
|
||||||
|
data->push_back(
|
||||||
|
sin(2.0 * M_PI * i * frequency / static_cast<double>(sample_rate)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
81
tensorflow/core/kernels/spectrogram_test_utils.h
Normal file
81
tensorflow/core/kernels/spectrogram_test_utils.h
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SPECTROGRAM_TEST_UTILS_H_
|
||||||
|
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SPECTROGRAM_TEST_UTILS_H_
|
||||||
|
|
||||||
|
#include <complex>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/types.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Reads a wav format file into a vector of floating-point values with range
|
||||||
|
// -1.0 to 1.0.
|
||||||
|
bool ReadWaveFileToVector(const string& file_name, std::vector<double>* data);
|
||||||
|
|
||||||
|
// Reads a binary file containing 32-bit floating point values in the
|
||||||
|
// form [real_1, imag_1, real_2, imag_2, ...] into a rectangular array
|
||||||
|
// of complex values where row_length is the length of each inner vector.
|
||||||
|
bool ReadRawFloatFileToComplexVector(
|
||||||
|
const string& file_name, int row_length,
|
||||||
|
std::vector<std::vector<std::complex<double> > >* data);
|
||||||
|
|
||||||
|
// Reads a CSV file of numbers in the format 1.1+2.2i,1.1,2.2i,3.3j into data.
|
||||||
|
void ReadCSVFileToComplexVectorOrDie(
|
||||||
|
const string& file_name,
|
||||||
|
std::vector<std::vector<std::complex<double> > >* data);
|
||||||
|
|
||||||
|
// Reads a 2D array of floats from an ASCII text file, where each line is a row
|
||||||
|
// of the array, and elements are separated by commas.
|
||||||
|
void ReadCSVFileToArrayOrDie(const string& filename,
|
||||||
|
std::vector<std::vector<float> >* array);
|
||||||
|
|
||||||
|
// Write a binary file containing 64-bit floating-point values for
|
||||||
|
// reading by, for example, MATLAB.
|
||||||
|
bool WriteDoubleVectorToFile(const string& file_name,
|
||||||
|
const std::vector<double>& data);
|
||||||
|
|
||||||
|
// Write a binary file containing 32-bit floating-point values for
|
||||||
|
// reading by, for example, MATLAB.
|
||||||
|
bool WriteFloatVectorToFile(const string& file_name,
|
||||||
|
const std::vector<float>& data);
|
||||||
|
|
||||||
|
// Write a binary file containing 64-bit floating-point values for
|
||||||
|
// reading by, for example, MATLAB.
|
||||||
|
bool WriteDoubleArrayToFile(const string& file_name, int size,
|
||||||
|
const double* data);
|
||||||
|
|
||||||
|
// Write a binary file containing 32-bit floating-point values for
|
||||||
|
// reading by, for example, MATLAB.
|
||||||
|
bool WriteFloatArrayToFile(const string& file_name, int size,
|
||||||
|
const float* data);
|
||||||
|
|
||||||
|
// Write a binary file in the format read by
|
||||||
|
// ReadRawDoubleFileToComplexVector above.
|
||||||
|
bool WriteComplexVectorToRawFloatFile(
|
||||||
|
const string& file_name,
|
||||||
|
const std::vector<std::vector<std::complex<double> > >& data);
|
||||||
|
|
||||||
|
// Generate a sine wave with the provided parameters, and populate
|
||||||
|
// data with the samples.
|
||||||
|
void SineWave(int sample_rate, float frequency, float duration_seconds,
|
||||||
|
std::vector<double>* data);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SPECTROGRAM_TEST_UTILS_H_
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_LIB_CORE_BITS_H_
|
#ifndef TENSORFLOW_LIB_CORE_BITS_H_
|
||||||
#define TENSORFLOW_LIB_CORE_BITS_H_
|
#define TENSORFLOW_LIB_CORE_BITS_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -91,6 +92,18 @@ inline int Log2Ceiling64(uint64 n) {
|
|||||||
return floor + 1;
|
return floor + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline uint32 NextPowerOfTwo(uint32 value) {
|
||||||
|
int exponent = Log2Ceiling(value);
|
||||||
|
DCHECK_LT(exponent, std::numeric_limits<uint32>::digits);
|
||||||
|
return 1 << exponent;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline uint64 NextPowerOfTwo64(uint64 value) {
|
||||||
|
int exponent = Log2Ceiling(value);
|
||||||
|
DCHECK_LT(exponent, std::numeric_limits<uint64>::digits);
|
||||||
|
return 1LL << exponent;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_LIB_CORE_BITS_H_
|
#endif // TENSORFLOW_LIB_CORE_BITS_H_
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/framework/shape_inference.h"
|
#include "tensorflow/core/framework/shape_inference.h"
|
||||||
|
#include "tensorflow/core/lib/core/bits.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
@ -66,6 +67,39 @@ Status EncodeWavShapeFn(InferenceContext* c) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status SpectrogramShapeFn(InferenceContext* c) {
|
||||||
|
ShapeHandle input;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input));
|
||||||
|
int32 window_size;
|
||||||
|
TF_RETURN_IF_ERROR(c->GetAttr("window_size", &window_size));
|
||||||
|
int32 stride;
|
||||||
|
TF_RETURN_IF_ERROR(c->GetAttr("stride", &stride));
|
||||||
|
|
||||||
|
DimensionHandle input_channels = c->Dim(input, 0);
|
||||||
|
DimensionHandle input_length = c->Dim(input, 1);
|
||||||
|
|
||||||
|
DimensionHandle output_length;
|
||||||
|
if (!c->ValueKnown(input_length)) {
|
||||||
|
output_length = c->UnknownDim();
|
||||||
|
} else {
|
||||||
|
const int64 input_length_value = c->Value(input_length);
|
||||||
|
const int64 length_minus_window = (input_length_value - window_size);
|
||||||
|
int64 output_length_value;
|
||||||
|
if (length_minus_window < 0) {
|
||||||
|
output_length_value = 0;
|
||||||
|
} else {
|
||||||
|
output_length_value = 1 + (length_minus_window / stride);
|
||||||
|
}
|
||||||
|
output_length = c->MakeDim(output_length_value);
|
||||||
|
}
|
||||||
|
|
||||||
|
DimensionHandle output_channels =
|
||||||
|
c->MakeDim(1 + NextPowerOfTwo(window_size) / 2);
|
||||||
|
c->set_output(0,
|
||||||
|
c->MakeShape({input_channels, output_length, output_channels}));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
REGISTER_OP("DecodeWav")
|
REGISTER_OP("DecodeWav")
|
||||||
@ -121,4 +155,49 @@ sample_rate: Scalar containing the sample frequency.
|
|||||||
contents: 0-D. WAV-encoded file contents.
|
contents: 0-D. WAV-encoded file contents.
|
||||||
)doc");
|
)doc");
|
||||||
|
|
||||||
|
REGISTER_OP("AudioSpectrogram")
|
||||||
|
.Input("input: float")
|
||||||
|
.Attr("window_size: int")
|
||||||
|
.Attr("stride: int")
|
||||||
|
.Attr("magnitude_squared: bool = false")
|
||||||
|
.Output("spectrogram: float")
|
||||||
|
.SetShapeFn(SpectrogramShapeFn)
|
||||||
|
.Doc(R"doc(
|
||||||
|
Produces a visualization of audio data over time.
|
||||||
|
|
||||||
|
Spectrograms are a standard way of representing audio information as a series of
|
||||||
|
slices of frequency information, one slice for each window of time. By joining
|
||||||
|
these together into a sequence, they form a distinctive fingerprint of the sound
|
||||||
|
over time.
|
||||||
|
|
||||||
|
This op expects to receive audio data as an input, stored as floats in the range
|
||||||
|
-1 to 1, together with a window width in samples, and a stride specifying how
|
||||||
|
far to move the window between slices. From this it generates a three
|
||||||
|
dimensional output. The lowest dimension has an amplitude value for each
|
||||||
|
frequency during that time slice. The next dimension is time, with successive
|
||||||
|
frequency slices. The final dimension is for the channels in the input, so a
|
||||||
|
stereo audio input would have two here for example.
|
||||||
|
|
||||||
|
This means the layout when converted and saved as an image is rotated 90 degrees
|
||||||
|
clockwise from a typical spectrogram. Time is descending down the Y axis, and
|
||||||
|
the frequency decreases from left to right.
|
||||||
|
|
||||||
|
Each value in the result represents the square root of the sum of the real and
|
||||||
|
imaginary parts of an FFT on the current window of samples. In this way, the
|
||||||
|
lowest dimension represents the power of each frequency in the current window,
|
||||||
|
and adjacent windows are concatenated in the next dimension.
|
||||||
|
|
||||||
|
To get a more intuitive and visual look at what this operation does, you can run
|
||||||
|
tensorflow/examples/wav_to_spectrogram to read in an audio file and save out the
|
||||||
|
resulting spectrogram as a PNG image.
|
||||||
|
|
||||||
|
input: Float representation of audio data.
|
||||||
|
window_size: How wide the input window is in samples. For the highest efficiency
|
||||||
|
this should be a power of two, but other values are accepted.
|
||||||
|
stride: How widely apart the center of adjacent sample windows should be.
|
||||||
|
magnitude_squared: Whether to return the squared magnitude or just the
|
||||||
|
magnitude. Using squared magnitude can avoid extra calculations.
|
||||||
|
spectrogram: 3D representation of the audio frequencies as an image.
|
||||||
|
)doc");
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -92,6 +92,7 @@ cc_library(
|
|||||||
"//tensorflow/core:protos_cc",
|
"//tensorflow/core:protos_cc",
|
||||||
"@com_googlesource_code_re2//:re2",
|
"@com_googlesource_code_re2//:re2",
|
||||||
"@farmhash_archive//:farmhash",
|
"@farmhash_archive//:farmhash",
|
||||||
|
"@fft2d//:fft2d",
|
||||||
"@highwayhash//:sip_hash",
|
"@highwayhash//:sip_hash",
|
||||||
"@png_archive//:png",
|
"@png_archive//:png",
|
||||||
],
|
],
|
||||||
|
@ -93,6 +93,22 @@ bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool ParseFloatFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
|
||||||
|
float* dst, bool* value_parsing_ok) {
|
||||||
|
*value_parsing_ok = true;
|
||||||
|
if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) {
|
||||||
|
char extra;
|
||||||
|
if (sscanf(arg.data(), "%f%c", dst, &extra) != 1) {
|
||||||
|
LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
|
||||||
|
<< ".";
|
||||||
|
*value_parsing_ok = false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text)
|
Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text)
|
||||||
@ -116,6 +132,12 @@ Flag::Flag(const char* name, string* dst, const string& usage_text)
|
|||||||
string_value_(dst),
|
string_value_(dst),
|
||||||
usage_text_(usage_text) {}
|
usage_text_(usage_text) {}
|
||||||
|
|
||||||
|
Flag::Flag(const char* name, float* dst, const string& usage_text)
|
||||||
|
: name_(name),
|
||||||
|
type_(TYPE_FLOAT),
|
||||||
|
float_value_(dst),
|
||||||
|
usage_text_(usage_text) {}
|
||||||
|
|
||||||
bool Flag::Parse(string arg, bool* value_parsing_ok) const {
|
bool Flag::Parse(string arg, bool* value_parsing_ok) const {
|
||||||
bool result = false;
|
bool result = false;
|
||||||
if (type_ == TYPE_INT) {
|
if (type_ == TYPE_INT) {
|
||||||
@ -126,6 +148,8 @@ bool Flag::Parse(string arg, bool* value_parsing_ok) const {
|
|||||||
result = ParseBoolFlag(arg, name_, bool_value_, value_parsing_ok);
|
result = ParseBoolFlag(arg, name_, bool_value_, value_parsing_ok);
|
||||||
} else if (type_ == TYPE_STRING) {
|
} else if (type_ == TYPE_STRING) {
|
||||||
result = ParseStringFlag(arg, name_, string_value_, value_parsing_ok);
|
result = ParseStringFlag(arg, name_, string_value_, value_parsing_ok);
|
||||||
|
} else if (type_ == TYPE_FLOAT) {
|
||||||
|
result = ParseFloatFlag(arg, name_, float_value_, value_parsing_ok);
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
@ -195,6 +219,10 @@ bool Flag::Parse(string arg, bool* value_parsing_ok) const {
|
|||||||
type_name = "string";
|
type_name = "string";
|
||||||
flag_string = strings::Printf("--%s=\"%s\"", flag.name_.c_str(),
|
flag_string = strings::Printf("--%s=\"%s\"", flag.name_.c_str(),
|
||||||
flag.string_value_->c_str());
|
flag.string_value_->c_str());
|
||||||
|
} else if (flag.type_ == Flag::TYPE_FLOAT) {
|
||||||
|
type_name = "float";
|
||||||
|
flag_string =
|
||||||
|
strings::Printf("--%s=%f", flag.name_.c_str(), *flag.float_value_);
|
||||||
}
|
}
|
||||||
strings::Appendf(&usage_text, "\t%-33s\t%s\t%s\n", flag_string.c_str(),
|
strings::Appendf(&usage_text, "\t%-33s\t%s\t%s\n", flag_string.c_str(),
|
||||||
type_name, flag.usage_text_.c_str());
|
type_name, flag.usage_text_.c_str());
|
||||||
|
@ -65,6 +65,7 @@ class Flag {
|
|||||||
Flag(const char* name, int64* dst1, const string& usage_text);
|
Flag(const char* name, int64* dst1, const string& usage_text);
|
||||||
Flag(const char* name, bool* dst, const string& usage_text);
|
Flag(const char* name, bool* dst, const string& usage_text);
|
||||||
Flag(const char* name, string* dst, const string& usage_text);
|
Flag(const char* name, string* dst, const string& usage_text);
|
||||||
|
Flag(const char* name, float* dst, const string& usage_text);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class Flags;
|
friend class Flags;
|
||||||
@ -72,11 +73,12 @@ class Flag {
|
|||||||
bool Parse(string arg, bool* value_parsing_ok) const;
|
bool Parse(string arg, bool* value_parsing_ok) const;
|
||||||
|
|
||||||
string name_;
|
string name_;
|
||||||
enum { TYPE_INT, TYPE_INT64, TYPE_BOOL, TYPE_STRING } type_;
|
enum { TYPE_INT, TYPE_INT64, TYPE_BOOL, TYPE_STRING, TYPE_FLOAT } type_;
|
||||||
int* int_value_;
|
int* int_value_;
|
||||||
int64* int64_value_;
|
int64* int64_value_;
|
||||||
bool* bool_value_;
|
bool* bool_value_;
|
||||||
string* string_value_;
|
string* string_value_;
|
||||||
|
float* float_value_;
|
||||||
string usage_text_;
|
string usage_text_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -32,29 +32,35 @@ std::vector<char *> CharPointerVectorFromStrings(
|
|||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
}
|
} // namespace
|
||||||
|
|
||||||
TEST(CommandLineFlagsTest, BasicUsage) {
|
TEST(CommandLineFlagsTest, BasicUsage) {
|
||||||
int some_int = 10;
|
int some_int = 10;
|
||||||
int64 some_int64 = 21474836470; // max int32 is 2147483647
|
int64 some_int64 = 21474836470; // max int32 is 2147483647
|
||||||
bool some_switch = false;
|
bool some_switch = false;
|
||||||
string some_name = "something";
|
string some_name = "something";
|
||||||
int argc = 5;
|
float some_float = -23.23f;
|
||||||
std::vector<string> argv_strings = {
|
int argc = 6;
|
||||||
"program_name", "--some_int=20", "--some_int64=214748364700",
|
std::vector<string> argv_strings = {"program_name",
|
||||||
"--some_switch", "--some_name=somethingelse"};
|
"--some_int=20",
|
||||||
|
"--some_int64=214748364700",
|
||||||
|
"--some_switch",
|
||||||
|
"--some_name=somethingelse",
|
||||||
|
"--some_float=42.0"};
|
||||||
std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
|
std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
|
||||||
bool parsed_ok =
|
bool parsed_ok =
|
||||||
Flags::Parse(&argc, argv_array.data(),
|
Flags::Parse(&argc, argv_array.data(),
|
||||||
{Flag("some_int", &some_int, "some int"),
|
{Flag("some_int", &some_int, "some int"),
|
||||||
Flag("some_int64", &some_int64, "some int64"),
|
Flag("some_int64", &some_int64, "some int64"),
|
||||||
Flag("some_switch", &some_switch, "some switch"),
|
Flag("some_switch", &some_switch, "some switch"),
|
||||||
Flag("some_name", &some_name, "some name")});
|
Flag("some_name", &some_name, "some name"),
|
||||||
|
Flag("some_float", &some_float, "some float")});
|
||||||
EXPECT_EQ(true, parsed_ok);
|
EXPECT_EQ(true, parsed_ok);
|
||||||
EXPECT_EQ(20, some_int);
|
EXPECT_EQ(20, some_int);
|
||||||
EXPECT_EQ(214748364700, some_int64);
|
EXPECT_EQ(214748364700, some_int64);
|
||||||
EXPECT_EQ(true, some_switch);
|
EXPECT_EQ(true, some_switch);
|
||||||
EXPECT_EQ("somethingelse", some_name);
|
EXPECT_EQ("somethingelse", some_name);
|
||||||
|
EXPECT_NEAR(42.0f, some_float, 1e-5f);
|
||||||
EXPECT_EQ(argc, 1);
|
EXPECT_EQ(argc, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -85,6 +91,21 @@ TEST(CommandLineFlagsTest, BadBoolValue) {
|
|||||||
EXPECT_EQ(argc, 1);
|
EXPECT_EQ(argc, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(CommandLineFlagsTest, BadFloatValue) {
|
||||||
|
float some_float = -23.23f;
|
||||||
|
int argc = 2;
|
||||||
|
std::vector<string> argv_strings = {"program_name",
|
||||||
|
"--some_float=notanumber"};
|
||||||
|
std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
|
||||||
|
bool parsed_ok =
|
||||||
|
Flags::Parse(&argc, argv_array.data(),
|
||||||
|
{Flag("some_float", &some_float, "some float")});
|
||||||
|
|
||||||
|
EXPECT_EQ(false, parsed_ok);
|
||||||
|
EXPECT_NEAR(-23.23f, some_float, 1e-5f);
|
||||||
|
EXPECT_EQ(argc, 1);
|
||||||
|
}
|
||||||
|
|
||||||
// Return whether str==pat, but allowing any whitespace in pat
|
// Return whether str==pat, but allowing any whitespace in pat
|
||||||
// to match zero or more whitespace characters in str.
|
// to match zero or more whitespace characters in str.
|
||||||
static bool MatchWithAnyWhitespace(const string &str, const string &pat) {
|
static bool MatchWithAnyWhitespace(const string &str, const string &pat) {
|
||||||
@ -111,6 +132,8 @@ TEST(CommandLineFlagsTest, UsageString) {
|
|||||||
int64 some_int64 = 21474836470; // max int32 is 2147483647
|
int64 some_int64 = 21474836470; // max int32 is 2147483647
|
||||||
bool some_switch = false;
|
bool some_switch = false;
|
||||||
string some_name = "something";
|
string some_name = "something";
|
||||||
|
// Don't test float in this case, because precision is hard to predict and
|
||||||
|
// match against, and we don't want a flakey test.
|
||||||
const string tool_name = "some_tool_name";
|
const string tool_name = "some_tool_name";
|
||||||
string usage = Flags::Usage(tool_name + "<flags>",
|
string usage = Flags::Usage(tool_name + "<flags>",
|
||||||
{Flag("some_int", &some_int, "some int"),
|
{Flag("some_int", &some_int, "some int"),
|
||||||
|
68
tensorflow/examples/wav_to_spectrogram/BUILD
Normal file
68
tensorflow/examples/wav_to_spectrogram/BUILD
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
# Description:
|
||||||
|
# TensorFlow C++ inference example for labeling images.
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = ["//tensorflow:internal"],
|
||||||
|
features = [
|
||||||
|
"-layering_check",
|
||||||
|
"-parse_headers",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
|
exports_files(["LICENSE"])
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "wav_to_spectrogram_lib",
|
||||||
|
srcs = [
|
||||||
|
"wav_to_spectrogram.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"wav_to_spectrogram.h",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/cc:cc_ops",
|
||||||
|
"//tensorflow/core:framework_internal",
|
||||||
|
"//tensorflow/core:tensorflow",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_binary(
|
||||||
|
name = "wav_to_spectrogram",
|
||||||
|
srcs = [
|
||||||
|
"main.cc",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":wav_to_spectrogram_lib",
|
||||||
|
"//tensorflow/core:framework_internal",
|
||||||
|
"//tensorflow/core:tensorflow",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "wav_to_spectrogram_test",
|
||||||
|
size = "medium",
|
||||||
|
srcs = ["wav_to_spectrogram_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":wav_to_spectrogram_lib",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:lib_internal",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "all_files",
|
||||||
|
srcs = glob(
|
||||||
|
["**/*"],
|
||||||
|
exclude = [
|
||||||
|
"**/METADATA",
|
||||||
|
"**/OWNERS",
|
||||||
|
"bin/**",
|
||||||
|
"gen/**",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
visibility = ["//tensorflow:__subpackages__"],
|
||||||
|
)
|
49
tensorflow/examples/wav_to_spectrogram/README.md
Normal file
49
tensorflow/examples/wav_to_spectrogram/README.md
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
# TensorFlow Spectrogram Example
|
||||||
|
|
||||||
|
This example shows how you can load audio from a .wav file, convert it to a
|
||||||
|
spectrogram, and then save it out as a PNG image. A spectrogram is a
|
||||||
|
visualization of the frequencies in sound over time, and can be useful as a
|
||||||
|
feature for neural network recognition on noise or speech.
|
||||||
|
|
||||||
|
## Building
|
||||||
|
|
||||||
|
To build it, run this command:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bazel build tensorflow/examples/wav_to_spectrogram/...
|
||||||
|
```
|
||||||
|
|
||||||
|
That should build a binary executable that you can then run like this:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bazel-bin/tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram
|
||||||
|
```
|
||||||
|
|
||||||
|
This uses a default test audio file that's part of the TensorFlow source code,
|
||||||
|
and writes out the image to the current directory as spectrogram.png.
|
||||||
|
|
||||||
|
## Options
|
||||||
|
|
||||||
|
To load your own audio, you need to supply a .wav file in LIN16 format, and use
|
||||||
|
the `--input_audio` flag to pass in the path.
|
||||||
|
|
||||||
|
To control how the spectrogram is created, you can specify the `--window_size`
|
||||||
|
and `--stride` arguments, which control how wide the window used to estimate
|
||||||
|
frequencies is, and how widely adjacent windows are spaced.
|
||||||
|
|
||||||
|
The `--output_image` flag sets the path to save the image file to. This is
|
||||||
|
always written out in PNG format, even if you specify a different file
|
||||||
|
extension.
|
||||||
|
|
||||||
|
If your result seems too dark, try using the `--brightness` flag to make the
|
||||||
|
output image easier to see.
|
||||||
|
|
||||||
|
Here's an example of how to use all of them together:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bazel-bin/tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram \
|
||||||
|
--input_wav=/tmp/my_audio.wav \
|
||||||
|
--window=1024 \
|
||||||
|
--stride=512 \
|
||||||
|
--output_image=/tmp/my_spectrogram.png
|
||||||
|
```
|
66
tensorflow/examples/wav_to_spectrogram/main.cc
Normal file
66
tensorflow/examples/wav_to_spectrogram/main.cc
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
|
|
||||||
|
int main(int argc, char* argv[]) {
|
||||||
|
// These are the command-line flags the program can understand.
|
||||||
|
// They define where the graph and input data is located, and what kind of
|
||||||
|
// input the model expects. If you train your own model, or use something
|
||||||
|
// other than inception_v3, then you'll need to update these.
|
||||||
|
tensorflow::string input_wav =
|
||||||
|
"tensorflow/core/kernels/spectrogram_test_data/short_test_segment.wav";
|
||||||
|
tensorflow::int32 window_size = 256;
|
||||||
|
tensorflow::int32 stride = 128;
|
||||||
|
float brightness = 64.0f;
|
||||||
|
tensorflow::string output_image = "spectrogram.png";
|
||||||
|
std::vector<tensorflow::Flag> flag_list = {
|
||||||
|
tensorflow::Flag("input_wav", &input_wav, "audio file to load"),
|
||||||
|
tensorflow::Flag("window_size", &window_size,
|
||||||
|
"frequency sample window width"),
|
||||||
|
tensorflow::Flag("stride", &stride,
|
||||||
|
"how far apart to place frequency windows"),
|
||||||
|
tensorflow::Flag("brightness", &brightness,
|
||||||
|
"controls how bright the output image is"),
|
||||||
|
tensorflow::Flag("output_image", &output_image,
|
||||||
|
"where to save the spectrogram image to"),
|
||||||
|
};
|
||||||
|
tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
|
||||||
|
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
|
||||||
|
if (!parse_result) {
|
||||||
|
LOG(ERROR) << usage;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// We need to call this to set up global state for TensorFlow.
|
||||||
|
tensorflow::port::InitMain(argv[0], &argc, &argv);
|
||||||
|
if (argc > 1) {
|
||||||
|
LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
tensorflow::Status wav_status = WavToSpectrogram(
|
||||||
|
input_wav, window_size, stride, brightness, output_image);
|
||||||
|
if (!wav_status.ok()) {
|
||||||
|
LOG(ERROR) << "WavToSpectrogram failed with " << wav_status;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
97
tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram.cc
Normal file
97
tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram.cc
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/cc/ops/audio_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/const_op.h"
|
||||||
|
#include "tensorflow/cc/ops/image_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/graph/default_device.h"
|
||||||
|
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||||
|
#include "tensorflow/core/lib/core/threadpool.h"
|
||||||
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
|
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "tensorflow/core/public/session.h"
|
||||||
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
|
|
||||||
|
using tensorflow::DT_FLOAT;
|
||||||
|
using tensorflow::DT_UINT8;
|
||||||
|
using tensorflow::Output;
|
||||||
|
using tensorflow::TensorShape;
|
||||||
|
|
||||||
|
// Runs a TensorFlow graph to convert an audio file into a visualization.
|
||||||
|
tensorflow::Status WavToSpectrogram(const tensorflow::string& input_wav,
|
||||||
|
tensorflow::int32 window_size,
|
||||||
|
tensorflow::int32 stride, float brightness,
|
||||||
|
const tensorflow::string& output_image) {
|
||||||
|
auto root = tensorflow::Scope::NewRootScope();
|
||||||
|
using namespace tensorflow::ops; // NOLINT(build/namespaces)
|
||||||
|
// The following block creates a TensorFlow graph that:
|
||||||
|
// - Reads and decodes the audio file into a tensor of float samples.
|
||||||
|
// - Creates a float spectrogram from those samples.
|
||||||
|
// - Scales, clamps, and converts that spectrogram to 0 to 255 uint8's.
|
||||||
|
// - Reshapes the tensor so that it's [height, width, 1] for imaging.
|
||||||
|
// - Encodes it as a PNG stream and saves it out to a file.
|
||||||
|
Output file_reader = ReadFile(root.WithOpName("input_wav"), input_wav);
|
||||||
|
DecodeWav wav_decoder =
|
||||||
|
DecodeWav(root.WithOpName("wav_decoder"), file_reader);
|
||||||
|
Output spectrogram = AudioSpectrogram(root.WithOpName("spectrogram"),
|
||||||
|
wav_decoder.audio, window_size, stride);
|
||||||
|
Output brightness_placeholder =
|
||||||
|
Placeholder(root.WithOpName("brightness_placeholder"), DT_FLOAT,
|
||||||
|
Placeholder::Attrs().Shape(TensorShape({})));
|
||||||
|
Output mul = Mul(root.WithOpName("mul"), spectrogram, brightness_placeholder);
|
||||||
|
Output min_const = Const(root.WithOpName("min_const"), 255.0f);
|
||||||
|
Output min = Minimum(root.WithOpName("min"), mul, min_const);
|
||||||
|
Output cast = Cast(root.WithOpName("cast"), min, DT_UINT8);
|
||||||
|
Output expand_dims_const = Const(root.WithOpName("expand_dims_const"), -1);
|
||||||
|
Output expand_dims =
|
||||||
|
ExpandDims(root.WithOpName("expand_dims"), cast, expand_dims_const);
|
||||||
|
Output squeeze = Squeeze(root.WithOpName("squeeze"), expand_dims,
|
||||||
|
Squeeze::Attrs().SqueezeDims({0}));
|
||||||
|
Output png_encoder = EncodePng(root.WithOpName("png_encoder"), squeeze);
|
||||||
|
WriteFile file_writer =
|
||||||
|
WriteFile(root.WithOpName("output_image"), output_image, png_encoder);
|
||||||
|
tensorflow::GraphDef graph;
|
||||||
|
TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));
|
||||||
|
|
||||||
|
// Build a session object from this graph definition. The power of TensorFlow
|
||||||
|
// is that you can reuse complex computations like this, so usually we'd run a
|
||||||
|
// lot of different inputs through it. In this example, we're just doing a
|
||||||
|
// one-off run, so we'll create it and then use it immediately.
|
||||||
|
std::unique_ptr<tensorflow::Session> session(
|
||||||
|
tensorflow::NewSession(tensorflow::SessionOptions()));
|
||||||
|
TF_RETURN_IF_ERROR(session->Create(graph));
|
||||||
|
|
||||||
|
// We're passing in the brightness as an input, so create a tensor to hold the
|
||||||
|
// value.
|
||||||
|
tensorflow::Tensor brightness_tensor(DT_FLOAT, TensorShape({}));
|
||||||
|
brightness_tensor.scalar<float>()() = brightness;
|
||||||
|
|
||||||
|
// Run the session to analyze the audio and write out the file.
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
session->Run({{"brightness_placeholder", brightness_tensor}}, {},
|
||||||
|
{"output_image"}, nullptr));
|
||||||
|
return tensorflow::Status::OK();
|
||||||
|
}
|
31
tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram.h
Normal file
31
tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram.h
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_WAV_TO_SPECTROGRAM_WAV_TO_SPECTROGRAM_H_
|
||||||
|
#define THIRD_PARTY_TENSORFLOW_EXAMPLES_WAV_TO_SPECTROGRAM_WAV_TO_SPECTROGRAM_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
// Runs a TensorFlow graph to convert an audio file into a visualization. Takes
|
||||||
|
// in the path to the audio file, the window size and stride parameters
|
||||||
|
// controlling the spectrogram creation, the brightness scaling to use, and a
|
||||||
|
// path to save the output PNG file to.
|
||||||
|
tensorflow::Status WavToSpectrogram(const tensorflow::string& input_wav,
|
||||||
|
tensorflow::int32 window_size,
|
||||||
|
tensorflow::int32 stride, float brightness,
|
||||||
|
const tensorflow::string& output_image);
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_WAV_TO_SPECTROGRAM_WAV_TO_SPECTROGRAM_H_
|
@ -0,0 +1,37 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
|
#include "tensorflow/core/lib/wav/wav_io.h"
|
||||||
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
TEST(WavToSpectrogramTest, WavToSpectrogramTest) {
|
||||||
|
const tensorflow::string input_wav =
|
||||||
|
tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "input_wav.wav");
|
||||||
|
const tensorflow::string output_image = tensorflow::io::JoinPath(
|
||||||
|
tensorflow::testing::TmpDir(), "output_image.png");
|
||||||
|
float audio[8] = {-1.0f, 0.0f, 1.0f, 0.0f, -1.0f, 0.0f, 1.0f, 0.0f};
|
||||||
|
tensorflow::string wav_string;
|
||||||
|
TF_ASSERT_OK(
|
||||||
|
tensorflow::wav::EncodeAudioAsS16LEWav(audio, 44100, 1, 8, &wav_string));
|
||||||
|
TF_ASSERT_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(),
|
||||||
|
input_wav, wav_string));
|
||||||
|
TF_ASSERT_OK(WavToSpectrogram(input_wav, 4, 4, 64.0f, output_image));
|
||||||
|
TF_EXPECT_OK(tensorflow::Env::Default()->FileExists(output_image));
|
||||||
|
}
|
@ -79,11 +79,13 @@ genrule(
|
|||||||
srcs = [
|
srcs = [
|
||||||
"//third_party/hadoop:LICENSE.txt",
|
"//third_party/hadoop:LICENSE.txt",
|
||||||
"//third_party/eigen3:LICENSE",
|
"//third_party/eigen3:LICENSE",
|
||||||
|
"//third_party/fft2d:LICENSE",
|
||||||
"@boringssl//:LICENSE",
|
"@boringssl//:LICENSE",
|
||||||
"@com_googlesource_code_re2//:LICENSE",
|
"@com_googlesource_code_re2//:LICENSE",
|
||||||
"@curl//:COPYING",
|
"@curl//:COPYING",
|
||||||
"@eigen_archive//:COPYING.MPL2",
|
"@eigen_archive//:COPYING.MPL2",
|
||||||
"@farmhash_archive//:COPYING",
|
"@farmhash_archive//:COPYING",
|
||||||
|
"@fft2d//:fft/readme.txt",
|
||||||
"@gemmlowp//:LICENSE",
|
"@gemmlowp//:LICENSE",
|
||||||
"@gif_archive//:COPYING",
|
"@gif_archive//:COPYING",
|
||||||
"@highwayhash//:LICENSE",
|
"@highwayhash//:LICENSE",
|
||||||
@ -106,11 +108,13 @@ genrule(
|
|||||||
srcs = [
|
srcs = [
|
||||||
"//third_party/hadoop:LICENSE.txt",
|
"//third_party/hadoop:LICENSE.txt",
|
||||||
"//third_party/eigen3:LICENSE",
|
"//third_party/eigen3:LICENSE",
|
||||||
|
"//third_party/fft2d:LICENSE",
|
||||||
"@boringssl//:LICENSE",
|
"@boringssl//:LICENSE",
|
||||||
"@com_googlesource_code_re2//:LICENSE",
|
"@com_googlesource_code_re2//:LICENSE",
|
||||||
"@curl//:COPYING",
|
"@curl//:COPYING",
|
||||||
"@eigen_archive//:COPYING.MPL2",
|
"@eigen_archive//:COPYING.MPL2",
|
||||||
"@farmhash_archive//:COPYING",
|
"@farmhash_archive//:COPYING",
|
||||||
|
"@fft2d//:fft/readme.txt",
|
||||||
"@gemmlowp//:LICENSE",
|
"@gemmlowp//:LICENSE",
|
||||||
"@gif_archive//:COPYING",
|
"@gif_archive//:COPYING",
|
||||||
"@highwayhash//:LICENSE",
|
"@highwayhash//:LICENSE",
|
||||||
|
@ -91,12 +91,14 @@ filegroup(
|
|||||||
name = "licenses",
|
name = "licenses",
|
||||||
data = [
|
data = [
|
||||||
"//third_party/eigen3:LICENSE",
|
"//third_party/eigen3:LICENSE",
|
||||||
|
"//third_party/fft2d:LICENSE",
|
||||||
"//third_party/hadoop:LICENSE.txt",
|
"//third_party/hadoop:LICENSE.txt",
|
||||||
"@boringssl//:LICENSE",
|
"@boringssl//:LICENSE",
|
||||||
"@com_googlesource_code_re2//:LICENSE",
|
"@com_googlesource_code_re2//:LICENSE",
|
||||||
"@curl//:COPYING",
|
"@curl//:COPYING",
|
||||||
"@eigen_archive//:COPYING.MPL2",
|
"@eigen_archive//:COPYING.MPL2",
|
||||||
"@farmhash_archive//:COPYING",
|
"@farmhash_archive//:COPYING",
|
||||||
|
"@fft2d//:fft/readme.txt",
|
||||||
"@gemmlowp//:LICENSE",
|
"@gemmlowp//:LICENSE",
|
||||||
"@gif_archive//:COPYING",
|
"@gif_archive//:COPYING",
|
||||||
"@grpc//:LICENSE",
|
"@grpc//:LICENSE",
|
||||||
|
@ -500,6 +500,16 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
|
|||||||
name="zlib",
|
name="zlib",
|
||||||
actual="@zlib_archive//:zlib",)
|
actual="@zlib_archive//:zlib",)
|
||||||
|
|
||||||
|
native.new_http_archive(
|
||||||
|
name = "fft2d",
|
||||||
|
urls = [
|
||||||
|
"http://bazel-mirror.storage.googleapis.com/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz",
|
||||||
|
"http://www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz",
|
||||||
|
],
|
||||||
|
sha256 = "52bb637c70b971958ec79c9c8752b1df5ff0218a4db4510e60826e0cb79b5296",
|
||||||
|
build_file = str(Label("//third_party/fft2d:fft2d.BUILD")),
|
||||||
|
)
|
||||||
|
|
||||||
temp_workaround_http_archive(
|
temp_workaround_http_archive(
|
||||||
name="snappy",
|
name="snappy",
|
||||||
urls=[
|
urls=[
|
||||||
|
Loading…
Reference in New Issue
Block a user