Add AudioSpectrogram op to TensorFlow for audio feature generation

Change: 152872386
This commit is contained in:
Pete Warden 2017-04-11 14:53:41 -08:00 committed by TensorFlower Gardener
parent b6d47b5e56
commit 7c9d2a458e
33 changed files with 2031 additions and 8 deletions

View File

@ -277,6 +277,7 @@ filegroup(
"//tensorflow/examples/tutorials/estimators:all_files",
"//tensorflow/examples/tutorials/mnist:all_files",
"//tensorflow/examples/tutorials/word2vec:all_files",
"//tensorflow/examples/wav_to_spectrogram:all_files",
"//tensorflow/go:all_files",
"//tensorflow/java:all_files",
"//tensorflow/java/src/main/java/org/tensorflow/examples:all_files",

View File

@ -108,6 +108,7 @@ include(eigen)
include(gemmlowp)
include(jsoncpp)
include(farmhash)
include(fft2d)
include(highwayhash)
include(protobuf)
if (tensorflow_BUILD_CC_TESTS)
@ -121,6 +122,7 @@ set(tensorflow_EXTERNAL_LIBRARIES
${jpeg_STATIC_LIBRARIES}
${jsoncpp_STATIC_LIBRARIES}
${farmhash_STATIC_LIBRARIES}
${fft2d_STATIC_LIBRARIES}
${highwayhash_STATIC_LIBRARIES}
${protobuf_STATIC_LIBRARIES}
)
@ -135,6 +137,7 @@ set(tensorflow_EXTERNAL_DEPENDENCIES
protobuf
eigen
gemmlowp
fft2d
)
include_directories(

View File

@ -0,0 +1,52 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
include (ExternalProject)
set(fft2d_URL http://www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz)
set(fft2d_HASH SHA256=52bb637c70b971958ec79c9c8752b1df5ff0218a4db4510e60826e0cb79b5296)
set(fft2d_BUILD ${CMAKE_CURRENT_BINARY_DIR}/fft2d/)
set(fft2d_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/fft2d/src)
if(WIN32)
set(fft2d_STATIC_LIBRARIES ${fft2d_BUILD}/src/lib/fft2d.lib)
ExternalProject_Add(fft2d
PREFIX fft2d
URL ${fft2d_URL}
URL_HASH ${fft2d_HASH}
DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
BUILD_IN_SOURCE 1
PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/fft2d/CMakeLists.txt ${fft2d_BUILD}/src/fft2d/CMakeLists.txt
INSTALL_DIR ${fft2d_INSTALL}
CMAKE_CACHE_ARGS
-DCMAKE_BUILD_TYPE:STRING=Release
-DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF
-DCMAKE_INSTALL_PREFIX:STRING=${fft2d_INSTALL})
else()
set(fft2d_STATIC_LIBRARIES ${fft2d_BUILD}/src/fft2d/libfft2d.a)
ExternalProject_Add(fft2d
PREFIX fft2d
URL ${fft2d_URL}
URL_HASH ${fft2d_HASH}
DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
BUILD_IN_SOURCE 1
PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/fft2d/CMakeLists.txt ${fft2d_BUILD}/src/fft2d/CMakeLists.txt
INSTALL_DIR $(fft2d_INSTALL)
INSTALL_COMMAND echo
BUILD_COMMAND $(MAKE))
endif()

View File

@ -0,0 +1,17 @@
cmake_minimum_required(VERSION 2.8.3)
project(fft2d)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(FFT2D_SRCS
"fftsg.c"
)
include_directories("${CMAKE_CURRENT_SOURCE_DIR}")
add_library(fft2d ${FFT2D_SRCS})
install(TARGETS fft2d
LIBRARY DESTINATION lib COMPONENT RuntimeLibraries
ARCHIVE DESTINATION lib COMPONENT Development)

View File

@ -496,7 +496,6 @@ cc_library(
tf_gen_op_libs(
op_lib_names = [
"array_ops",
"audio_ops",
"candidate_sampling_ops",
"control_flow_ops",
"ctc_ops",
@ -526,6 +525,13 @@ tf_gen_op_libs(
],
)
tf_gen_op_libs(
op_lib_names = [
"audio_ops",
],
deps = [":lib"],
)
cc_library(
name = "debug_ops_op_lib",
srcs = ["ops/debug_ops.cc"],
@ -688,6 +694,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core/kernels:array",
"//tensorflow/core/kernels:audio",
"//tensorflow/core/kernels:bincount_op",
"//tensorflow/core/kernels:candidate_sampler_ops",
"//tensorflow/core/kernels:control_flow_ops",

View File

@ -3559,6 +3559,117 @@ tf_kernel_library(
],
)
filegroup(
name = "spectrogram_test_data",
srcs = [
"spectrogram_test_data/short_test_segment.wav",
"spectrogram_test_data/short_test_segment_spectrogram.csv.bin",
"spectrogram_test_data/short_test_segment_spectrogram_400_200.csv.bin",
],
visibility = ["//visibility:public"],
)
cc_library(
name = "spectrogram",
srcs = ["spectrogram.cc"],
hdrs = ["spectrogram.h"],
copts = tf_copts(),
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//third_party/fft2d:fft2d_headers",
"@fft2d//:fft2d",
],
)
cc_library(
name = "spectrogram_test_utils",
testonly = 1,
srcs = ["spectrogram_test_utils.cc"],
hdrs = ["spectrogram_test_utils.h"],
copts = tf_copts(),
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
],
)
cc_binary(
name = "spectrogram_convert_test_data",
testonly = 1,
srcs = ["spectrogram_convert_test_data.cc"],
deps = [
":spectrogram_test_utils",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_cc_test(
name = "spectrogram_test",
size = "medium",
srcs = ["spectrogram_test.cc"],
data = [":spectrogram_test_data"],
deps = [
":spectrogram",
":spectrogram_test_utils",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:lib_test_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//third_party/eigen3",
],
)
tf_kernel_library(
name = "spectrogram_op",
prefix = "spectrogram_op",
deps = [
":spectrogram",
"//tensorflow/core:audio_ops_op_lib",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
alwayslink = 1,
)
tf_cuda_cc_test(
name = "spectrogram_op_test",
size = "small",
srcs = ["spectrogram_op_test.cc"],
deps = [
":ops_util",
":spectrogram_op",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:client_session",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
cc_library(
name = "audio",
deps = [
":decode_wav_op",
":encode_wav_op",
":spectrogram_op",
],
)
# Android libraries -----------------------------------------------------------
# Changes to the Android srcs here should be replicated in
@ -3962,6 +4073,7 @@ filegroup(
"whole_file_read_ops.*",
"sample_distorted_bounding_box_op.*",
"ctc_loss_op.*",
"spectrogram_convert_test_data.cc",
# Excluded due to experimental status:
"debug_ops.*",
"scatter_nd_op*",

View File

@ -0,0 +1,212 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/spectrogram.h"
#include <math.h>
#include "third_party/fft2d/fft.h"
#include "tensorflow/core/lib/core/bits.h"
namespace tensorflow {
using std::complex;
namespace {
// Returns the default Hann window function for the spectrogram.
void GetPeriodicHann(int window_length, std::vector<double>* window) {
// Some platforms don't have M_PI, so define a local constant here.
const double pi = std::atan(1) * 4;
window->resize(window_length);
for (int i = 0; i < window_length; ++i) {
(*window)[i] = 0.5 - 0.5 * cos((2 * pi * i) / window_length);
}
}
} // namespace
bool Spectrogram::Initialize(int window_length, int step_length) {
std::vector<double> window;
GetPeriodicHann(window_length, &window);
return Initialize(window, step_length);
}
bool Spectrogram::Initialize(const std::vector<double>& window,
int step_length) {
window_length_ = window.size();
window_ = window; // Copy window.
if (window_length_ < 2) {
LOG(ERROR) << "Window length too short.";
initialized_ = false;
return false;
}
step_length_ = step_length;
if (step_length_ < 1) {
LOG(ERROR) << "Step length must be positive.";
initialized_ = false;
return false;
}
fft_length_ = NextPowerOfTwo(window_length_);
CHECK(fft_length_ >= window_length_);
output_frequency_channels_ = 1 + fft_length_ / 2;
// Allocate 2 more than what rdft needs, so we can rationalize the layout.
fft_input_output_.assign(fft_length_ + 2, 0.0);
int half_fft_length = fft_length_ / 2;
fft_double_working_area_.assign(half_fft_length, 0.0);
fft_integer_working_area_.assign(2 + static_cast<int>(sqrt(half_fft_length)),
0);
// Set flag element to ensure that the working areas are initialized
// on the first call to cdft. It's redundant given the assign above,
// but keep it as a reminder.
fft_integer_working_area_[0] = 0;
input_queue_.clear();
samples_to_next_step_ = window_length_;
initialized_ = true;
return true;
}
template <class InputSample, class OutputSample>
bool Spectrogram::ComputeComplexSpectrogram(
const std::vector<InputSample>& input,
std::vector<std::vector<complex<OutputSample>>>* output) {
if (!initialized_) {
LOG(ERROR) << "ComputeComplexSpectrogram() called before successful call "
<< "to Initialize().";
return false;
}
CHECK(output);
output->clear();
int input_start = 0;
while (GetNextWindowOfSamples(input, &input_start)) {
DCHECK_EQ(input_queue_.size(), window_length_);
ProcessCoreFFT(); // Processes input_queue_ to fft_input_output_.
// Add a new slice vector onto the output, to save new result to.
output->resize(output->size() + 1);
// Get a reference to the newly added slice to fill in.
auto& spectrogram_slice = output->back();
spectrogram_slice.resize(output_frequency_channels_);
for (int i = 0; i < output_frequency_channels_; ++i) {
// This will convert double to float if it needs to.
spectrogram_slice[i] = complex<OutputSample>(
fft_input_output_[2 * i], fft_input_output_[2 * i + 1]);
}
}
return true;
}
// Instantiate it four ways:
template bool Spectrogram::ComputeComplexSpectrogram(
const std::vector<float>& input, std::vector<std::vector<complex<float>>>*);
template bool Spectrogram::ComputeComplexSpectrogram(
const std::vector<double>& input,
std::vector<std::vector<complex<float>>>*);
template bool Spectrogram::ComputeComplexSpectrogram(
const std::vector<float>& input,
std::vector<std::vector<complex<double>>>*);
template bool Spectrogram::ComputeComplexSpectrogram(
const std::vector<double>& input,
std::vector<std::vector<complex<double>>>*);
template <class InputSample, class OutputSample>
bool Spectrogram::ComputeSquaredMagnitudeSpectrogram(
const std::vector<InputSample>& input,
std::vector<std::vector<OutputSample>>* output) {
if (!initialized_) {
LOG(ERROR) << "ComputeSquaredMagnitudeSpectrogram() called before "
<< "successful call to Initialize().";
return false;
}
CHECK(output);
output->clear();
int input_start = 0;
while (GetNextWindowOfSamples(input, &input_start)) {
DCHECK_EQ(input_queue_.size(), window_length_);
ProcessCoreFFT(); // Processes input_queue_ to fft_input_output_.
// Add a new slice vector onto the output, to save new result to.
output->resize(output->size() + 1);
// Get a reference to the newly added slice to fill in.
auto& spectrogram_slice = output->back();
spectrogram_slice.resize(output_frequency_channels_);
for (int i = 0; i < output_frequency_channels_; ++i) {
// Similar to the Complex case, except storing the norm.
// But the norm function is known to be a performance killer,
// so do it this way with explicit real and imagninary temps.
const double re = fft_input_output_[2 * i];
const double im = fft_input_output_[2 * i + 1];
// Which finally converts double to float if it needs to.
spectrogram_slice[i] = re * re + im * im;
}
}
return true;
}
// Instantiate it four ways:
template bool Spectrogram::ComputeSquaredMagnitudeSpectrogram(
const std::vector<float>& input, std::vector<std::vector<float>>*);
template bool Spectrogram::ComputeSquaredMagnitudeSpectrogram(
const std::vector<double>& input, std::vector<std::vector<float>>*);
template bool Spectrogram::ComputeSquaredMagnitudeSpectrogram(
const std::vector<float>& input, std::vector<std::vector<double>>*);
template bool Spectrogram::ComputeSquaredMagnitudeSpectrogram(
const std::vector<double>& input, std::vector<std::vector<double>>*);
// Return true if a full window of samples is prepared; manage the queue.
template <class InputSample>
bool Spectrogram::GetNextWindowOfSamples(const std::vector<InputSample>& input,
int* input_start) {
auto input_it = input.begin() + *input_start;
int input_remaining = input.end() - input_it;
if (samples_to_next_step_ > input_remaining) {
// Copy in as many samples are left and return false, no full window.
input_queue_.insert(input_queue_.end(), input_it, input.end());
*input_start += input_remaining; // Increases it to input.size().
samples_to_next_step_ -= input_remaining;
return false; // Not enough for a full window.
} else {
// Copy just enough into queue to make a new window, then trim the
// front off the queue to make it window-sized.
input_queue_.insert(input_queue_.end(), input_it,
input_it + samples_to_next_step_);
*input_start += samples_to_next_step_;
input_queue_.erase(
input_queue_.begin(),
input_queue_.begin() + input_queue_.size() - window_length_);
DCHECK_EQ(window_length_, input_queue_.size());
samples_to_next_step_ = step_length_; // Be ready for next time.
return true; // Yes, input_queue_ now contains exactly a window-full.
}
}
void Spectrogram::ProcessCoreFFT() {
for (int j = 0; j < window_length_; ++j) {
fft_input_output_[j] = input_queue_[j] * window_[j];
}
// Zero-pad the rest of the input buffer.
for (int j = window_length_; j < fft_length_; ++j) {
fft_input_output_[j] = 0.0;
}
const int kForwardFFT = 1; // 1 means forward; -1 reverse.
// This real FFT is a fair amount faster than using cdft here.
rdft(fft_length_, kForwardFFT, &fft_input_output_[0],
&fft_integer_working_area_[0], &fft_double_working_area_[0]);
// Make rdft result look like cdft result;
// unpack the last real value from the first position's imag slot.
fft_input_output_[fft_length_] = fft_input_output_[1];
fft_input_output_[fft_length_ + 1] = 0;
fft_input_output_[1] = 0;
}
} // namespace tensorflow

View File

@ -0,0 +1,112 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Class for generating spectrogram slices from a waveform.
// Initialize() should be called before calls to other functions. Once
// Initialize() has been called and returned true, The Compute*() functions can
// be called repeatedly with sequential input data (ie. the first element of the
// next input vector directly follows the last element of the previous input
// vector). Whenever enough audio samples are buffered to produce a
// new frame, it will be placed in output. Output is cleared on each
// call to Compute*(). This class is thread-unsafe, and should only be
// called from one thread at a time.
// With the default parameters, the output of this class should be very
// close to the results of the following MATLAB code:
// overlap_samples = window_length_samples - step_samples;
// window = hann(window_length_samples, 'periodic');
// S = abs(spectrogram(audio, window, overlap_samples)).^2;
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SPECTROGRAM_H_
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SPECTROGRAM_H_
#include <complex>
#include <deque>
#include <vector>
#include "third_party/fft2d/fft.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
namespace tensorflow {
class Spectrogram {
public:
Spectrogram() : initialized_(false) {}
~Spectrogram() {}
// Initializes the class with a given window length and step length
// (both in samples). Internally a Hann window is used as the window
// function. Returns true on success, after which calls to Process()
// are possible. window_length must be greater than 1 and step
// length must be greater than 0.
bool Initialize(int window_length, int step_length);
// Initialize with an explicit window instead of a length.
bool Initialize(const std::vector<double>& window, int step_length);
// Processes an arbitrary amount of audio data (contained in input)
// to yield complex spectrogram frames. After a successful call to
// Initialize(), Process() may be called repeatedly with new input data
// each time. The audio input is buffered internally, and the output
// vector is populated with as many temporally-ordered spectral slices
// as it is possible to generate from the input. The output is cleared
// on each call before the new frames (if any) are added.
//
// The template parameters can be float or double.
template <class InputSample, class OutputSample>
bool ComputeComplexSpectrogram(
const std::vector<InputSample>& input,
std::vector<std::vector<std::complex<OutputSample>>>* output);
// This function works as the one above, but returns the power
// (the L2 norm, or the squared magnitude) of each complex value.
template <class InputSample, class OutputSample>
bool ComputeSquaredMagnitudeSpectrogram(
const std::vector<InputSample>& input,
std::vector<std::vector<OutputSample>>* output);
// Return reference to the window function used internally.
const std::vector<double>& GetWindow() const { return window_; }
// Return the number of frequency channels in the spectrogram.
int output_frequency_channels() const { return output_frequency_channels_; }
private:
template <class InputSample>
bool GetNextWindowOfSamples(const std::vector<InputSample>& input,
int* input_start);
void ProcessCoreFFT();
int fft_length_;
int output_frequency_channels_;
int window_length_;
int step_length_;
bool initialized_;
int samples_to_next_step_;
std::vector<double> window_;
std::vector<double> fft_input_output_;
std::deque<double> input_queue_;
// Working data areas for the FFT routines.
std::vector<int> fft_integer_working_area_;
std::vector<double> fft_double_working_area_;
TF_DISALLOW_COPY_AND_ASSIGN(Spectrogram);
};
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SPECTROGRAM_H_

View File

@ -0,0 +1,56 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/spectrogram_test_utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
namespace wav {
// This takes a CSV file representing an array of complex numbers, and saves out
// a version using a binary format to save space in the repository.
Status ConvertCsvToRaw(const string& input_filename) {
std::vector<std::vector<std::complex<double>>> input_data;
ReadCSVFileToComplexVectorOrDie(input_filename, &input_data);
const string output_filename = input_filename + ".bin";
if (!WriteComplexVectorToRawFloatFile(output_filename, input_data)) {
return errors::InvalidArgument("Failed to write raw float file ",
input_filename);
}
LOG(INFO) << "Wrote raw file to " << output_filename;
return Status::OK();
}
} // namespace wav
} // namespace tensorflow
int main(int argc, char* argv[]) {
tensorflow::port::InitMain(argv[0], &argc, &argv);
if (argc < 2) {
LOG(ERROR) << "You must supply a CSV file as the first argument";
return 1;
}
tensorflow::string filename(argv[1]);
tensorflow::Status status = tensorflow::wav::ConvertCsvToRaw(filename);
if (!status.ok()) {
LOG(ERROR) << "Error processing '" << filename << "':" << status;
return 1;
}
return 0;
}

View File

@ -0,0 +1,120 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/audio_ops.cc
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/spectrogram.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
// Create a spectrogram frequency visualization from audio data.
class SpectrogramOp : public OpKernel {
public:
explicit SpectrogramOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("window_size", &window_size_));
OP_REQUIRES_OK(context, context->GetAttr("stride", &stride_));
OP_REQUIRES_OK(context,
context->GetAttr("magnitude_squared", &magnitude_squared_));
}
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
OP_REQUIRES(context, input.dims() == 2,
errors::InvalidArgument("input must be 2-dimensional",
input.shape().DebugString()));
Spectrogram spectrogram;
OP_REQUIRES(context, spectrogram.Initialize(window_size_, stride_),
errors::InvalidArgument(
"Spectrogram initialization failed for window size ",
window_size_, " and stride ", stride_));
const auto input_as_matrix = input.matrix<float>();
const int64 sample_count = input.dim_size(0);
const int64 channel_count = input.dim_size(1);
const int64 output_width = spectrogram.output_frequency_channels();
const int64 length_minus_window = (sample_count - window_size_);
int64 output_height;
if (length_minus_window < 0) {
output_height = 0;
} else {
output_height = 1 + (length_minus_window / stride_);
}
const int64 output_slices = channel_count;
Tensor* output_tensor = nullptr;
OP_REQUIRES_OK(
context,
context->allocate_output(
0, TensorShape({output_slices, output_height, output_width}),
&output_tensor));
auto output_flat = output_tensor->flat<float>().data();
std::vector<float> input_for_channel(sample_count);
for (int64 channel = 0; channel < channel_count; ++channel) {
float* output_slice =
output_flat + (channel * output_height * output_width);
for (int i = 0; i < sample_count; ++i) {
input_for_channel[i] = input_as_matrix(i, channel);
}
std::vector<std::vector<float>> spectrogram_output;
OP_REQUIRES(context,
spectrogram.ComputeSquaredMagnitudeSpectrogram(
input_for_channel, &spectrogram_output),
errors::InvalidArgument("Spectrogram compute failed"));
OP_REQUIRES(context, (spectrogram_output.size() == output_height),
errors::InvalidArgument(
"Spectrogram size calculation failed: Expected height ",
output_height, " but got ", spectrogram_output.size()));
OP_REQUIRES(context,
spectrogram_output.empty() ||
(spectrogram_output[0].size() == output_width),
errors::InvalidArgument(
"Spectrogram size calculation failed: Expected width ",
output_width, " but got ", spectrogram_output[0].size()));
for (int row_index = 0; row_index < output_height; ++row_index) {
const std::vector<float>& spectrogram_row =
spectrogram_output[row_index];
DCHECK_EQ(spectrogram_row.size(), output_width);
float* output_row = output_slice + (row_index * output_width);
if (magnitude_squared_) {
for (int i = 0; i < output_width; ++i) {
output_row[i] = spectrogram_row[i];
}
} else {
for (int i = 0; i < output_width; ++i) {
output_row[i] = sqrtf(spectrogram_row[i]);
}
}
}
}
}
private:
int32 window_size_;
int32 stride_;
bool magnitude_squared_;
};
REGISTER_KERNEL_BUILDER(Name("AudioSpectrogram").Device(DEVICE_CPU),
SpectrogramOp);
} // namespace tensorflow

View File

@ -0,0 +1,104 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include <functional>
#include <memory>
#include <vector>
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/audio_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
using namespace ops; // NOLINT(build/namespaces)
TEST(SpectrogramOpTest, SimpleTest) {
Scope root = Scope::NewRootScope();
Tensor audio_tensor(DT_FLOAT, TensorShape({8, 1}));
test::FillValues<float>(&audio_tensor,
{-1.0f, 0.0f, 1.0f, 0.0f, -1.0f, 0.0f, 1.0f, 0.0f});
Output audio_const_op = Const(root.WithOpName("audio_const_op"),
Input::Initializer(audio_tensor));
AudioSpectrogram spectrogram_op =
AudioSpectrogram(root.WithOpName("spectrogram_op"), audio_const_op, 8, 1);
TF_ASSERT_OK(root.status());
ClientSession session(root);
std::vector<Tensor> outputs;
TF_EXPECT_OK(session.Run(ClientSession::FeedType(),
{spectrogram_op.spectrogram}, &outputs));
const Tensor& spectrogram_tensor = outputs[0];
EXPECT_EQ(3, spectrogram_tensor.dims());
EXPECT_EQ(5, spectrogram_tensor.dim_size(2));
EXPECT_EQ(1, spectrogram_tensor.dim_size(1));
EXPECT_EQ(1, spectrogram_tensor.dim_size(0));
test::ExpectTensorNear<float>(
spectrogram_tensor,
test::AsTensor<float>({0, 1, 2, 1, 0}, TensorShape({1, 1, 5})), 1e-3);
}
TEST(SpectrogramOpTest, SquaredTest) {
Scope root = Scope::NewRootScope();
Tensor audio_tensor(DT_FLOAT, TensorShape({8, 1}));
test::FillValues<float>(&audio_tensor,
{-1.0f, 0.0f, 1.0f, 0.0f, -1.0f, 0.0f, 1.0f, 0.0f});
Output audio_const_op = Const(root.WithOpName("audio_const_op"),
Input::Initializer(audio_tensor));
AudioSpectrogram spectrogram_op =
AudioSpectrogram(root.WithOpName("spectrogram_op"), audio_const_op, 8, 1,
AudioSpectrogram::Attrs().MagnitudeSquared(true));
TF_ASSERT_OK(root.status());
ClientSession session(root);
std::vector<Tensor> outputs;
TF_EXPECT_OK(session.Run(ClientSession::FeedType(),
{spectrogram_op.spectrogram}, &outputs));
const Tensor& spectrogram_tensor = outputs[0];
EXPECT_EQ(3, spectrogram_tensor.dims());
EXPECT_EQ(5, spectrogram_tensor.dim_size(2));
EXPECT_EQ(1, spectrogram_tensor.dim_size(1));
EXPECT_EQ(1, spectrogram_tensor.dim_size(0));
test::ExpectTensorNear<float>(
spectrogram_tensor,
test::AsTensor<float>({0, 1, 4, 1, 0}, TensorShape({1, 1, 5})), 1e-3);
}
} // namespace tensorflow

View File

@ -0,0 +1,340 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// The MATLAB test data were generated using GenerateTestData.m.
#include "tensorflow/core/kernels/spectrogram.h"
#include <complex>
#include <vector>
#include "tensorflow/core/kernels/spectrogram_test_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
using ::std::complex;
const char kInputFilename[] =
"core/kernels/spectrogram_test_data/short_test_segment.wav";
const char kExpectedFilename[] =
"core/kernels/spectrogram_test_data/short_test_segment_spectrogram.csv.bin";
const int kDataVectorLength = 257;
const int kNumberOfFramesInTestData = 178;
const char kExpectedNonPowerOfTwoFilename[] =
"core/kernels/spectrogram_test_data/"
"short_test_segment_spectrogram_400_200.csv.bin";
const int kNonPowerOfTwoDataVectorLength = 257;
const int kNumberOfFramesInNonPowerOfTwoTestData = 228;
TEST(SpectrogramTest, TooLittleDataYieldsNoFrames) {
Spectrogram sgram;
sgram.Initialize(400, 200);
std::vector<double> input;
// Generate 44 samples of audio.
SineWave(44100, 1000.0, 0.001, &input);
EXPECT_EQ(44, input.size());
std::vector<std::vector<complex<double>>> output;
sgram.ComputeComplexSpectrogram(input, &output);
EXPECT_EQ(0, output.size());
}
TEST(SpectrogramTest, StepSizeSmallerThanWindow) {
Spectrogram sgram;
EXPECT_TRUE(sgram.Initialize(400, 200));
std::vector<double> input;
// Generate 661 samples of audio.
SineWave(44100, 1000.0, 0.015, &input);
EXPECT_EQ(661, input.size());
std::vector<std::vector<complex<double>>> output;
sgram.ComputeComplexSpectrogram(input, &output);
EXPECT_EQ(2, output.size());
}
TEST(SpectrogramTest, StepSizeBiggerThanWindow) {
Spectrogram sgram;
EXPECT_TRUE(sgram.Initialize(200, 400));
std::vector<double> input;
// Generate 882 samples of audio.
SineWave(44100, 1000.0, 0.02, &input);
EXPECT_EQ(882, input.size());
std::vector<std::vector<complex<double>>> output;
sgram.ComputeComplexSpectrogram(input, &output);
EXPECT_EQ(2, output.size());
}
TEST(SpectrogramTest, StepSizeBiggerThanWindow2) {
Spectrogram sgram;
EXPECT_TRUE(sgram.Initialize(200, 400));
std::vector<double> input;
// Generate more than 600 but fewer than 800 samples of audio.
SineWave(44100, 1000.0, 0.016, &input);
EXPECT_GT(input.size(), 600);
EXPECT_LT(input.size(), 800);
std::vector<std::vector<complex<double>>> output;
sgram.ComputeComplexSpectrogram(input, &output);
EXPECT_EQ(2, output.size());
}
TEST(SpectrogramTest,
MultipleCallsToComputeComplexSpectrogramMayYieldDifferentNumbersOfFrames) {
// Repeatedly pass inputs with "extra" samples beyond complete windows
// and check that the excess points cumulate to eventually cause an
// extra output frame.
Spectrogram sgram;
sgram.Initialize(200, 400);
std::vector<double> input;
// Generate 882 samples of audio.
SineWave(44100, 1000.0, 0.02, &input);
EXPECT_EQ(882, input.size());
std::vector<std::vector<complex<double>>> output;
const std::vector<int> expected_output_sizes = {
2, // One pass of input leaves 82 samples buffered after two steps of
// 400.
2, // Passing in 882 samples again will now leave 164 samples buffered.
3, // Third time gives 246 extra samples, triggering an extra output
// frame.
};
for (int expected_output_size : expected_output_sizes) {
sgram.ComputeComplexSpectrogram(input, &output);
EXPECT_EQ(expected_output_size, output.size());
}
}
TEST(SpectrogramTest, CumulatingExcessInputsForOverlappingFrames) {
// Input frames that don't fit into whole windows are cumulated even when
// the windows have overlap (similar to
// MultipleCallsToComputeComplexSpectrogramMayYieldDifferentNumbersOfFrames
// but with window size/hop size swapped).
Spectrogram sgram;
sgram.Initialize(400, 200);
std::vector<double> input;
// Generate 882 samples of audio.
SineWave(44100, 1000.0, 0.02, &input);
EXPECT_EQ(882, input.size());
std::vector<std::vector<complex<double>>> output;
const std::vector<int> expected_output_sizes = {
3, // Windows 0..400, 200..600, 400..800 with 82 samples buffered.
4, // 1764 frames input; outputs from 600, 800, 1000, 1200..1600.
5, // 2646 frames in; outputs from 1400, 1600, 1800, 2000, 2200..2600.
};
for (int expected_output_size : expected_output_sizes) {
sgram.ComputeComplexSpectrogram(input, &output);
EXPECT_EQ(expected_output_size, output.size());
}
}
TEST(SpectrogramTest, StepSizeEqualToWindowWorks) {
Spectrogram sgram;
sgram.Initialize(200, 200);
std::vector<double> input;
// Generate 2205 samples of audio.
SineWave(44100, 1000.0, 0.05, &input);
EXPECT_EQ(2205, input.size());
std::vector<std::vector<complex<double>>> output;
sgram.ComputeComplexSpectrogram(input, &output);
EXPECT_EQ(11, output.size());
}
template <class ExpectedSample, class ActualSample>
void CompareComplexData(
const std::vector<std::vector<complex<ExpectedSample>>>& expected,
const std::vector<std::vector<complex<ActualSample>>>& actual,
double tolerance) {
ASSERT_EQ(actual.size(), expected.size());
for (int i = 0; i < expected.size(); ++i) {
ASSERT_EQ(expected[i].size(), actual[i].size());
for (int j = 0; j < expected[i].size(); ++j) {
ASSERT_NEAR(real(expected[i][j]), real(actual[i][j]), tolerance)
<< ": where i=" << i << " and j=" << j << ".";
ASSERT_NEAR(imag(expected[i][j]), imag(actual[i][j]), tolerance)
<< ": where i=" << i << " and j=" << j << ".";
}
}
}
template <class Sample>
double GetMaximumAbsolute(const std::vector<std::vector<Sample>>& spectrogram) {
double max_absolute = 0.0;
for (int i = 0; i < spectrogram.size(); ++i) {
for (int j = 0; j < spectrogram[i].size(); ++j) {
double absolute_value = std::abs(spectrogram[i][j]);
if (absolute_value > max_absolute) {
max_absolute = absolute_value;
}
}
}
return max_absolute;
}
template <class ExpectedSample, class ActualSample>
void CompareMagnitudeData(
const std::vector<std::vector<complex<ExpectedSample>>>&
expected_complex_output,
const std::vector<std::vector<ActualSample>>& actual_squared_magnitude,
double tolerance) {
ASSERT_EQ(actual_squared_magnitude.size(), expected_complex_output.size());
for (int i = 0; i < expected_complex_output.size(); ++i) {
ASSERT_EQ(expected_complex_output[i].size(),
actual_squared_magnitude[i].size());
for (int j = 0; j < expected_complex_output[i].size(); ++j) {
ASSERT_NEAR(norm(expected_complex_output[i][j]),
actual_squared_magnitude[i][j], tolerance)
<< ": where i=" << i << " and j=" << j << ".";
}
}
}
TEST(SpectrogramTest, ReInitializationWorks) {
Spectrogram sgram;
sgram.Initialize(512, 256);
std::vector<double> input;
CHECK(ReadWaveFileToVector(
tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), kInputFilename),
&input));
std::vector<std::vector<complex<double>>> first_output;
std::vector<std::vector<complex<double>>> second_output;
sgram.Initialize(512, 256);
sgram.ComputeComplexSpectrogram(input, &first_output);
// Re-Initialize it.
sgram.Initialize(512, 256);
sgram.ComputeComplexSpectrogram(input, &second_output);
// Verify identical outputs.
ASSERT_EQ(first_output.size(), second_output.size());
int slice_size = first_output[0].size();
for (int i = 0; i < first_output.size(); ++i) {
ASSERT_EQ(slice_size, first_output[i].size());
ASSERT_EQ(slice_size, second_output[i].size());
for (int j = 0; j < slice_size; ++j) {
ASSERT_EQ(first_output[i][j], second_output[i][j]);
}
}
}
TEST(SpectrogramTest, ComputedComplexDataAgreeWithMatlab) {
const int kInputDataLength = 45870;
Spectrogram sgram;
sgram.Initialize(512, 256);
std::vector<double> input;
CHECK(ReadWaveFileToVector(
tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), kInputFilename),
&input));
EXPECT_EQ(kInputDataLength, input.size());
std::vector<std::vector<complex<double>>> expected_output;
ASSERT_TRUE(ReadRawFloatFileToComplexVector(
tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), kExpectedFilename),
kDataVectorLength, &expected_output));
EXPECT_EQ(kNumberOfFramesInTestData, expected_output.size());
EXPECT_EQ(kDataVectorLength, expected_output[0].size());
std::vector<std::vector<complex<double>>> output;
sgram.ComputeComplexSpectrogram(input, &output);
CompareComplexData(expected_output, output, 1e-5);
}
TEST(SpectrogramTest, ComputedFloatComplexDataAgreeWithMatlab) {
const int kInputDataLength = 45870;
Spectrogram sgram;
sgram.Initialize(512, 256);
std::vector<double> double_input;
CHECK(ReadWaveFileToVector(
tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), kInputFilename),
&double_input));
std::vector<float> input;
input.assign(double_input.begin(), double_input.end());
EXPECT_EQ(kInputDataLength, input.size());
std::vector<std::vector<complex<double>>> expected_output;
ASSERT_TRUE(ReadRawFloatFileToComplexVector(
tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), kExpectedFilename),
kDataVectorLength, &expected_output));
EXPECT_EQ(kNumberOfFramesInTestData, expected_output.size());
EXPECT_EQ(kDataVectorLength, expected_output[0].size());
std::vector<std::vector<complex<float>>> output;
sgram.ComputeComplexSpectrogram(input, &output);
CompareComplexData(expected_output, output, 1e-4);
}
TEST(SpectrogramTest, ComputedSquaredMagnitudeDataAgreeWithMatlab) {
const int kInputDataLength = 45870;
Spectrogram sgram;
sgram.Initialize(512, 256);
std::vector<double> input;
CHECK(ReadWaveFileToVector(
tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), kInputFilename),
&input));
EXPECT_EQ(kInputDataLength, input.size());
std::vector<std::vector<complex<double>>> expected_output;
ASSERT_TRUE(ReadRawFloatFileToComplexVector(
tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), kExpectedFilename),
kDataVectorLength, &expected_output));
EXPECT_EQ(kNumberOfFramesInTestData, expected_output.size());
EXPECT_EQ(kDataVectorLength, expected_output[0].size());
std::vector<std::vector<double>> output;
sgram.ComputeSquaredMagnitudeSpectrogram(input, &output);
CompareMagnitudeData(expected_output, output, 1e-3);
}
TEST(SpectrogramTest, ComputedFloatSquaredMagnitudeDataAgreeWithMatlab) {
const int kInputDataLength = 45870;
Spectrogram sgram;
sgram.Initialize(512, 256);
std::vector<double> double_input;
CHECK(ReadWaveFileToVector(
tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), kInputFilename),
&double_input));
EXPECT_EQ(kInputDataLength, double_input.size());
std::vector<float> input;
input.assign(double_input.begin(), double_input.end());
std::vector<std::vector<complex<double>>> expected_output;
ASSERT_TRUE(ReadRawFloatFileToComplexVector(
tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), kExpectedFilename),
kDataVectorLength, &expected_output));
EXPECT_EQ(kNumberOfFramesInTestData, expected_output.size());
EXPECT_EQ(kDataVectorLength, expected_output[0].size());
std::vector<std::vector<float>> output;
sgram.ComputeSquaredMagnitudeSpectrogram(input, &output);
double max_absolute = GetMaximumAbsolute(output);
EXPECT_GT(max_absolute, 2300.0); // Verify that we have some big numbers.
// Squaring increases dynamic range; max square is about 2300,
// so 2e-4 is about 7 decimal digits; not bad for a float.
CompareMagnitudeData(expected_output, output, 2e-4);
}
TEST(SpectrogramTest, ComputedNonPowerOfTwoComplexDataAgreeWithMatlab) {
const int kInputDataLength = 45870;
Spectrogram sgram;
sgram.Initialize(400, 200);
std::vector<double> input;
CHECK(ReadWaveFileToVector(
tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), kInputFilename),
&input));
EXPECT_EQ(kInputDataLength, input.size());
std::vector<std::vector<complex<double>>> expected_output;
ASSERT_TRUE(ReadRawFloatFileToComplexVector(
tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(),
kExpectedNonPowerOfTwoFilename),
kNonPowerOfTwoDataVectorLength, &expected_output));
EXPECT_EQ(kNumberOfFramesInNonPowerOfTwoTestData, expected_output.size());
EXPECT_EQ(kNonPowerOfTwoDataVectorLength, expected_output[0].size());
std::vector<std::vector<complex<double>>> output;
sgram.ComputeComplexSpectrogram(input, &output);
CompareComplexData(expected_output, output, 1e-5);
}
} // namespace tensorflow

View File

@ -0,0 +1,8 @@
The CSV spectrogram files in this directory are generated from the
matlab code in ./matlab/GenerateTestData.m
To save space in the repo, you'll then need to convert them into a binary packed
format using the convert_test_data.cc command line tool.
short_test_segment.wav is approximately 1s of music audio.

View File

@ -0,0 +1,288 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/spectrogram_test_utils.h"
#include <math.h>
#include <stddef.h>
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/wav/wav_io.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
bool ReadWaveFileToVector(const string& file_name, std::vector<double>* data) {
string wav_data;
if (!ReadFileToString(Env::Default(), file_name, &wav_data).ok()) {
LOG(ERROR) << "Wave file read failed for " << file_name;
return false;
}
std::vector<float> decoded_data;
uint32 decoded_sample_count;
uint16 decoded_channel_count;
uint32 decoded_sample_rate;
if (!wav::DecodeLin16WaveAsFloatVector(
wav_data, &decoded_data, &decoded_sample_count,
&decoded_channel_count, &decoded_sample_rate)
.ok()) {
return false;
}
// Convert from float to double for the output value.
data->resize(decoded_data.size());
for (int i = 0; i < decoded_data.size(); ++i) {
(*data)[i] = decoded_data[i];
}
return true;
}
bool ReadRawFloatFileToComplexVector(
const string& file_name, int row_length,
std::vector<std::vector<std::complex<double> > >* data) {
data->clear();
string data_string;
if (!ReadFileToString(Env::Default(), file_name, &data_string).ok()) {
LOG(ERROR) << "Failed to open file " << file_name;
return false;
}
float real_out;
float imag_out;
const int kBytesPerValue = 4;
CHECK_EQ(sizeof(real_out), kBytesPerValue);
std::vector<std::complex<double> > data_row;
int row_counter = 0;
int offset = 0;
const int end = data_string.size();
while (offset < end) {
memcpy(&real_out, data_string.data() + offset, kBytesPerValue);
offset += kBytesPerValue;
memcpy(&imag_out, data_string.data() + offset, kBytesPerValue);
offset += kBytesPerValue;
if (row_counter >= row_length) {
data->push_back(data_row);
data_row.clear();
row_counter = 0;
}
data_row.push_back(std::complex<double>(real_out, imag_out));
++row_counter;
}
if (row_counter >= row_length) {
data->push_back(data_row);
}
return true;
}
void ReadCSVFileToComplexVectorOrDie(
const string& file_name,
std::vector<std::vector<std::complex<double> > >* data) {
data->clear();
string data_string;
if (!ReadFileToString(Env::Default(), file_name, &data_string).ok()) {
LOG(FATAL) << "Failed to open file " << file_name;
return;
}
std::vector<string> lines = str_util::Split(data_string, '\n');
for (const string& line : lines) {
if (line == "") {
continue;
}
std::vector<std::complex<double> > data_line;
std::vector<string> values = str_util::Split(line, ',');
for (std::vector<string>::const_iterator i = values.begin();
i != values.end(); ++i) {
// each element of values may be in the form:
// 0.001+0.002i, 0.001, 0.001i, -1.2i, -1.2-3.2i, 1.5, 1.5e-03+21.0i
std::vector<string> parts;
// Find the first instance of + or - after the second character
// in the string, that does not immediately follow an 'e'.
size_t operator_index = i->find_first_of("+-", 2);
if (operator_index < i->size() &&
i->substr(operator_index - 1, 1) == "e") {
operator_index = i->find_first_of("+-", operator_index + 1);
}
parts.push_back(i->substr(0, operator_index));
if (operator_index < i->size()) {
parts.push_back(i->substr(operator_index, string::npos));
}
double real_part = 0.0;
double imaginary_part = 0.0;
for (std::vector<string>::const_iterator j = parts.begin();
j != parts.end(); ++j) {
if (j->find_first_of("ij") != string::npos) {
strings::safe_strtod((*j).c_str(), &imaginary_part);
} else {
strings::safe_strtod((*j).c_str(), &real_part);
}
}
data_line.push_back(std::complex<double>(real_part, imaginary_part));
}
data->push_back(data_line);
}
}
void ReadCSVFileToArrayOrDie(const string& filename,
std::vector<std::vector<float> >* array) {
string contents;
TF_CHECK_OK(ReadFileToString(Env::Default(), filename, &contents));
std::vector<string> lines = str_util::Split(contents, '\n');
contents.clear();
array->clear();
std::vector<float> values;
for (int l = 0; l < lines.size(); ++l) {
values.clear();
CHECK(str_util::SplitAndParseAsFloats(lines[l], ',', &values));
array->push_back(values);
}
}
bool WriteDoubleVectorToFile(const string& file_name,
const std::vector<double>& data) {
std::unique_ptr<WritableFile> file;
if (!Env::Default()->NewWritableFile(file_name, &file).ok()) {
LOG(ERROR) << "Failed to open file " << file_name;
return false;
}
for (int i = 0; i < data.size(); ++i) {
if (!file->Append(StringPiece(reinterpret_cast<const char*>(&(data[i])),
sizeof(data[i])))
.ok()) {
LOG(ERROR) << "Failed to append to file " << file_name;
return false;
}
}
if (!file->Close().ok()) {
LOG(ERROR) << "Failed to close file " << file_name;
return false;
}
return true;
}
bool WriteFloatVectorToFile(const string& file_name,
const std::vector<float>& data) {
std::unique_ptr<WritableFile> file;
if (!Env::Default()->NewWritableFile(file_name, &file).ok()) {
LOG(ERROR) << "Failed to open file " << file_name;
return false;
}
for (int i = 0; i < data.size(); ++i) {
if (!file->Append(StringPiece(reinterpret_cast<const char*>(&(data[i])),
sizeof(data[i])))
.ok()) {
LOG(ERROR) << "Failed to append to file " << file_name;
return false;
}
}
if (!file->Close().ok()) {
LOG(ERROR) << "Failed to close file " << file_name;
return false;
}
return true;
}
bool WriteDoubleArrayToFile(const string& file_name, int size,
const double* data) {
std::unique_ptr<WritableFile> file;
if (!Env::Default()->NewWritableFile(file_name, &file).ok()) {
LOG(ERROR) << "Failed to open file " << file_name;
return false;
}
for (int i = 0; i < size; ++i) {
if (!file->Append(StringPiece(reinterpret_cast<const char*>(&(data[i])),
sizeof(data[i])))
.ok()) {
LOG(ERROR) << "Failed to append to file " << file_name;
return false;
}
}
if (!file->Close().ok()) {
LOG(ERROR) << "Failed to close file " << file_name;
return false;
}
return true;
}
bool WriteFloatArrayToFile(const string& file_name, int size,
const float* data) {
std::unique_ptr<WritableFile> file;
if (!Env::Default()->NewWritableFile(file_name, &file).ok()) {
LOG(ERROR) << "Failed to open file " << file_name;
return false;
}
for (int i = 0; i < size; ++i) {
if (!file->Append(StringPiece(reinterpret_cast<const char*>(&(data[i])),
sizeof(data[i])))
.ok()) {
LOG(ERROR) << "Failed to append to file " << file_name;
return false;
}
}
if (!file->Close().ok()) {
LOG(ERROR) << "Failed to close file " << file_name;
return false;
}
return true;
}
bool WriteComplexVectorToRawFloatFile(
const string& file_name,
const std::vector<std::vector<std::complex<double> > >& data) {
std::unique_ptr<WritableFile> file;
if (!Env::Default()->NewWritableFile(file_name, &file).ok()) {
LOG(ERROR) << "Failed to open file " << file_name;
return false;
}
for (int i = 0; i < data.size(); ++i) {
for (int j = 0; j < data[i].size(); ++j) {
const float real_part(real(data[i][j]));
if (!file->Append(StringPiece(reinterpret_cast<const char*>(&real_part),
sizeof(real_part)))
.ok()) {
LOG(ERROR) << "Failed to append to file " << file_name;
return false;
}
const float imag_part(imag(data[i][j]));
if (!file->Append(StringPiece(reinterpret_cast<const char*>(&imag_part),
sizeof(imag_part)))
.ok()) {
LOG(ERROR) << "Failed to append to file " << file_name;
return false;
}
}
}
if (!file->Close().ok()) {
LOG(ERROR) << "Failed to close file " << file_name;
return false;
}
return true;
}
void SineWave(int sample_rate, float frequency, float duration_seconds,
std::vector<double>* data) {
data->clear();
for (int i = 0; i < static_cast<int>(sample_rate * duration_seconds); ++i) {
data->push_back(
sin(2.0 * M_PI * i * frequency / static_cast<double>(sample_rate)));
}
}
} // namespace tensorflow

View File

@ -0,0 +1,81 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SPECTROGRAM_TEST_UTILS_H_
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SPECTROGRAM_TEST_UTILS_H_
#include <complex>
#include <string>
#include <vector>
#include "tensorflow/core/framework/types.h"
namespace tensorflow {
// Reads a wav format file into a vector of floating-point values with range
// -1.0 to 1.0.
bool ReadWaveFileToVector(const string& file_name, std::vector<double>* data);
// Reads a binary file containing 32-bit floating point values in the
// form [real_1, imag_1, real_2, imag_2, ...] into a rectangular array
// of complex values where row_length is the length of each inner vector.
bool ReadRawFloatFileToComplexVector(
const string& file_name, int row_length,
std::vector<std::vector<std::complex<double> > >* data);
// Reads a CSV file of numbers in the format 1.1+2.2i,1.1,2.2i,3.3j into data.
void ReadCSVFileToComplexVectorOrDie(
const string& file_name,
std::vector<std::vector<std::complex<double> > >* data);
// Reads a 2D array of floats from an ASCII text file, where each line is a row
// of the array, and elements are separated by commas.
void ReadCSVFileToArrayOrDie(const string& filename,
std::vector<std::vector<float> >* array);
// Write a binary file containing 64-bit floating-point values for
// reading by, for example, MATLAB.
bool WriteDoubleVectorToFile(const string& file_name,
const std::vector<double>& data);
// Write a binary file containing 32-bit floating-point values for
// reading by, for example, MATLAB.
bool WriteFloatVectorToFile(const string& file_name,
const std::vector<float>& data);
// Write a binary file containing 64-bit floating-point values for
// reading by, for example, MATLAB.
bool WriteDoubleArrayToFile(const string& file_name, int size,
const double* data);
// Write a binary file containing 32-bit floating-point values for
// reading by, for example, MATLAB.
bool WriteFloatArrayToFile(const string& file_name, int size,
const float* data);
// Write a binary file in the format read by
// ReadRawDoubleFileToComplexVector above.
bool WriteComplexVectorToRawFloatFile(
const string& file_name,
const std::vector<std::vector<std::complex<double> > >& data);
// Generate a sine wave with the provided parameters, and populate
// data with the samples.
void SineWave(int sample_rate, float frequency, float duration_seconds,
std::vector<double>* data);
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SPECTROGRAM_TEST_UTILS_H_

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_LIB_CORE_BITS_H_
#define TENSORFLOW_LIB_CORE_BITS_H_
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
@ -91,6 +92,18 @@ inline int Log2Ceiling64(uint64 n) {
return floor + 1;
}
inline uint32 NextPowerOfTwo(uint32 value) {
int exponent = Log2Ceiling(value);
DCHECK_LT(exponent, std::numeric_limits<uint32>::digits);
return 1 << exponent;
}
inline uint64 NextPowerOfTwo64(uint64 value) {
int exponent = Log2Ceiling(value);
DCHECK_LT(exponent, std::numeric_limits<uint64>::digits);
return 1LL << exponent;
}
} // namespace tensorflow
#endif // TENSORFLOW_LIB_CORE_BITS_H_

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/bits.h"
namespace tensorflow {
@ -66,6 +67,39 @@ Status EncodeWavShapeFn(InferenceContext* c) {
return Status::OK();
}
Status SpectrogramShapeFn(InferenceContext* c) {
ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input));
int32 window_size;
TF_RETURN_IF_ERROR(c->GetAttr("window_size", &window_size));
int32 stride;
TF_RETURN_IF_ERROR(c->GetAttr("stride", &stride));
DimensionHandle input_channels = c->Dim(input, 0);
DimensionHandle input_length = c->Dim(input, 1);
DimensionHandle output_length;
if (!c->ValueKnown(input_length)) {
output_length = c->UnknownDim();
} else {
const int64 input_length_value = c->Value(input_length);
const int64 length_minus_window = (input_length_value - window_size);
int64 output_length_value;
if (length_minus_window < 0) {
output_length_value = 0;
} else {
output_length_value = 1 + (length_minus_window / stride);
}
output_length = c->MakeDim(output_length_value);
}
DimensionHandle output_channels =
c->MakeDim(1 + NextPowerOfTwo(window_size) / 2);
c->set_output(0,
c->MakeShape({input_channels, output_length, output_channels}));
return Status::OK();
}
} // namespace
REGISTER_OP("DecodeWav")
@ -121,4 +155,49 @@ sample_rate: Scalar containing the sample frequency.
contents: 0-D. WAV-encoded file contents.
)doc");
REGISTER_OP("AudioSpectrogram")
.Input("input: float")
.Attr("window_size: int")
.Attr("stride: int")
.Attr("magnitude_squared: bool = false")
.Output("spectrogram: float")
.SetShapeFn(SpectrogramShapeFn)
.Doc(R"doc(
Produces a visualization of audio data over time.
Spectrograms are a standard way of representing audio information as a series of
slices of frequency information, one slice for each window of time. By joining
these together into a sequence, they form a distinctive fingerprint of the sound
over time.
This op expects to receive audio data as an input, stored as floats in the range
-1 to 1, together with a window width in samples, and a stride specifying how
far to move the window between slices. From this it generates a three
dimensional output. The lowest dimension has an amplitude value for each
frequency during that time slice. The next dimension is time, with successive
frequency slices. The final dimension is for the channels in the input, so a
stereo audio input would have two here for example.
This means the layout when converted and saved as an image is rotated 90 degrees
clockwise from a typical spectrogram. Time is descending down the Y axis, and
the frequency decreases from left to right.
Each value in the result represents the square root of the sum of the real and
imaginary parts of an FFT on the current window of samples. In this way, the
lowest dimension represents the power of each frequency in the current window,
and adjacent windows are concatenated in the next dimension.
To get a more intuitive and visual look at what this operation does, you can run
tensorflow/examples/wav_to_spectrogram to read in an audio file and save out the
resulting spectrogram as a PNG image.
input: Float representation of audio data.
window_size: How wide the input window is in samples. For the highest efficiency
this should be a power of two, but other values are accepted.
stride: How widely apart the center of adjacent sample windows should be.
magnitude_squared: Whether to return the squared magnitude or just the
magnitude. Using squared magnitude can avoid extra calculations.
spectrogram: 3D representation of the audio frequencies as an image.
)doc");
} // namespace tensorflow

View File

@ -92,6 +92,7 @@ cc_library(
"//tensorflow/core:protos_cc",
"@com_googlesource_code_re2//:re2",
"@farmhash_archive//:farmhash",
"@fft2d//:fft2d",
"@highwayhash//:sip_hash",
"@png_archive//:png",
],

View File

@ -93,6 +93,22 @@ bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
return false;
}
bool ParseFloatFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
float* dst, bool* value_parsing_ok) {
*value_parsing_ok = true;
if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) {
char extra;
if (sscanf(arg.data(), "%f%c", dst, &extra) != 1) {
LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
<< ".";
*value_parsing_ok = false;
}
return true;
}
return false;
}
} // namespace
Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text)
@ -116,6 +132,12 @@ Flag::Flag(const char* name, string* dst, const string& usage_text)
string_value_(dst),
usage_text_(usage_text) {}
Flag::Flag(const char* name, float* dst, const string& usage_text)
: name_(name),
type_(TYPE_FLOAT),
float_value_(dst),
usage_text_(usage_text) {}
bool Flag::Parse(string arg, bool* value_parsing_ok) const {
bool result = false;
if (type_ == TYPE_INT) {
@ -126,6 +148,8 @@ bool Flag::Parse(string arg, bool* value_parsing_ok) const {
result = ParseBoolFlag(arg, name_, bool_value_, value_parsing_ok);
} else if (type_ == TYPE_STRING) {
result = ParseStringFlag(arg, name_, string_value_, value_parsing_ok);
} else if (type_ == TYPE_FLOAT) {
result = ParseFloatFlag(arg, name_, float_value_, value_parsing_ok);
}
return result;
}
@ -195,6 +219,10 @@ bool Flag::Parse(string arg, bool* value_parsing_ok) const {
type_name = "string";
flag_string = strings::Printf("--%s=\"%s\"", flag.name_.c_str(),
flag.string_value_->c_str());
} else if (flag.type_ == Flag::TYPE_FLOAT) {
type_name = "float";
flag_string =
strings::Printf("--%s=%f", flag.name_.c_str(), *flag.float_value_);
}
strings::Appendf(&usage_text, "\t%-33s\t%s\t%s\n", flag_string.c_str(),
type_name, flag.usage_text_.c_str());

View File

@ -65,6 +65,7 @@ class Flag {
Flag(const char* name, int64* dst1, const string& usage_text);
Flag(const char* name, bool* dst, const string& usage_text);
Flag(const char* name, string* dst, const string& usage_text);
Flag(const char* name, float* dst, const string& usage_text);
private:
friend class Flags;
@ -72,11 +73,12 @@ class Flag {
bool Parse(string arg, bool* value_parsing_ok) const;
string name_;
enum { TYPE_INT, TYPE_INT64, TYPE_BOOL, TYPE_STRING } type_;
enum { TYPE_INT, TYPE_INT64, TYPE_BOOL, TYPE_STRING, TYPE_FLOAT } type_;
int* int_value_;
int64* int64_value_;
bool* bool_value_;
string* string_value_;
float* float_value_;
string usage_text_;
};

View File

@ -32,29 +32,35 @@ std::vector<char *> CharPointerVectorFromStrings(
}
return result;
}
}
} // namespace
TEST(CommandLineFlagsTest, BasicUsage) {
int some_int = 10;
int64 some_int64 = 21474836470; // max int32 is 2147483647
bool some_switch = false;
string some_name = "something";
int argc = 5;
std::vector<string> argv_strings = {
"program_name", "--some_int=20", "--some_int64=214748364700",
"--some_switch", "--some_name=somethingelse"};
float some_float = -23.23f;
int argc = 6;
std::vector<string> argv_strings = {"program_name",
"--some_int=20",
"--some_int64=214748364700",
"--some_switch",
"--some_name=somethingelse",
"--some_float=42.0"};
std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
bool parsed_ok =
Flags::Parse(&argc, argv_array.data(),
{Flag("some_int", &some_int, "some int"),
Flag("some_int64", &some_int64, "some int64"),
Flag("some_switch", &some_switch, "some switch"),
Flag("some_name", &some_name, "some name")});
Flag("some_name", &some_name, "some name"),
Flag("some_float", &some_float, "some float")});
EXPECT_EQ(true, parsed_ok);
EXPECT_EQ(20, some_int);
EXPECT_EQ(214748364700, some_int64);
EXPECT_EQ(true, some_switch);
EXPECT_EQ("somethingelse", some_name);
EXPECT_NEAR(42.0f, some_float, 1e-5f);
EXPECT_EQ(argc, 1);
}
@ -85,6 +91,21 @@ TEST(CommandLineFlagsTest, BadBoolValue) {
EXPECT_EQ(argc, 1);
}
TEST(CommandLineFlagsTest, BadFloatValue) {
float some_float = -23.23f;
int argc = 2;
std::vector<string> argv_strings = {"program_name",
"--some_float=notanumber"};
std::vector<char *> argv_array = CharPointerVectorFromStrings(argv_strings);
bool parsed_ok =
Flags::Parse(&argc, argv_array.data(),
{Flag("some_float", &some_float, "some float")});
EXPECT_EQ(false, parsed_ok);
EXPECT_NEAR(-23.23f, some_float, 1e-5f);
EXPECT_EQ(argc, 1);
}
// Return whether str==pat, but allowing any whitespace in pat
// to match zero or more whitespace characters in str.
static bool MatchWithAnyWhitespace(const string &str, const string &pat) {
@ -111,6 +132,8 @@ TEST(CommandLineFlagsTest, UsageString) {
int64 some_int64 = 21474836470; // max int32 is 2147483647
bool some_switch = false;
string some_name = "something";
// Don't test float in this case, because precision is hard to predict and
// match against, and we don't want a flakey test.
const string tool_name = "some_tool_name";
string usage = Flags::Usage(tool_name + "<flags>",
{Flag("some_int", &some_int, "some int"),

View File

@ -0,0 +1,68 @@
# Description:
# TensorFlow C++ inference example for labeling images.
package(
default_visibility = ["//tensorflow:internal"],
features = [
"-layering_check",
"-parse_headers",
],
)
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
cc_library(
name = "wav_to_spectrogram_lib",
srcs = [
"wav_to_spectrogram.cc",
],
hdrs = [
"wav_to_spectrogram.h",
],
deps = [
"//tensorflow/cc:cc_ops",
"//tensorflow/core:framework_internal",
"//tensorflow/core:tensorflow",
],
)
cc_binary(
name = "wav_to_spectrogram",
srcs = [
"main.cc",
],
deps = [
":wav_to_spectrogram_lib",
"//tensorflow/core:framework_internal",
"//tensorflow/core:tensorflow",
],
)
cc_test(
name = "wav_to_spectrogram_test",
size = "medium",
srcs = ["wav_to_spectrogram_test.cc"],
deps = [
":wav_to_spectrogram_lib",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
"bin/**",
"gen/**",
],
),
visibility = ["//tensorflow:__subpackages__"],
)

View File

@ -0,0 +1,49 @@
# TensorFlow Spectrogram Example
This example shows how you can load audio from a .wav file, convert it to a
spectrogram, and then save it out as a PNG image. A spectrogram is a
visualization of the frequencies in sound over time, and can be useful as a
feature for neural network recognition on noise or speech.
## Building
To build it, run this command:
```bash
bazel build tensorflow/examples/wav_to_spectrogram/...
```
That should build a binary executable that you can then run like this:
```bash
bazel-bin/tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram
```
This uses a default test audio file that's part of the TensorFlow source code,
and writes out the image to the current directory as spectrogram.png.
## Options
To load your own audio, you need to supply a .wav file in LIN16 format, and use
the `--input_audio` flag to pass in the path.
To control how the spectrogram is created, you can specify the `--window_size`
and `--stride` arguments, which control how wide the window used to estimate
frequencies is, and how widely adjacent windows are spaced.
The `--output_image` flag sets the path to save the image file to. This is
always written out in PNG format, even if you specify a different file
extension.
If your result seems too dark, try using the `--brightness` flag to make the
output image easier to see.
Here's an example of how to use all of them together:
```bash
bazel-bin/tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram \
--input_wav=/tmp/my_audio.wav \
--window=1024 \
--stride=512 \
--output_image=/tmp/my_spectrogram.png
```

View File

@ -0,0 +1,66 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/util/command_line_flags.h"
int main(int argc, char* argv[]) {
// These are the command-line flags the program can understand.
// They define where the graph and input data is located, and what kind of
// input the model expects. If you train your own model, or use something
// other than inception_v3, then you'll need to update these.
tensorflow::string input_wav =
"tensorflow/core/kernels/spectrogram_test_data/short_test_segment.wav";
tensorflow::int32 window_size = 256;
tensorflow::int32 stride = 128;
float brightness = 64.0f;
tensorflow::string output_image = "spectrogram.png";
std::vector<tensorflow::Flag> flag_list = {
tensorflow::Flag("input_wav", &input_wav, "audio file to load"),
tensorflow::Flag("window_size", &window_size,
"frequency sample window width"),
tensorflow::Flag("stride", &stride,
"how far apart to place frequency windows"),
tensorflow::Flag("brightness", &brightness,
"controls how bright the output image is"),
tensorflow::Flag("output_image", &output_image,
"where to save the spectrogram image to"),
};
tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
if (!parse_result) {
LOG(ERROR) << usage;
return -1;
}
// We need to call this to set up global state for TensorFlow.
tensorflow::port::InitMain(argv[0], &argc, &argv);
if (argc > 1) {
LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
return -1;
}
tensorflow::Status wav_status = WavToSpectrogram(
input_wav, window_size, stride, brightness, output_image);
if (!wav_status.ok()) {
LOG(ERROR) << "WavToSpectrogram failed with " << wav_status;
return -1;
}
return 0;
}

View File

@ -0,0 +1,97 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram.h"
#include <vector>
#include "tensorflow/cc/ops/audio_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
using tensorflow::DT_FLOAT;
using tensorflow::DT_UINT8;
using tensorflow::Output;
using tensorflow::TensorShape;
// Runs a TensorFlow graph to convert an audio file into a visualization.
tensorflow::Status WavToSpectrogram(const tensorflow::string& input_wav,
tensorflow::int32 window_size,
tensorflow::int32 stride, float brightness,
const tensorflow::string& output_image) {
auto root = tensorflow::Scope::NewRootScope();
using namespace tensorflow::ops; // NOLINT(build/namespaces)
// The following block creates a TensorFlow graph that:
// - Reads and decodes the audio file into a tensor of float samples.
// - Creates a float spectrogram from those samples.
// - Scales, clamps, and converts that spectrogram to 0 to 255 uint8's.
// - Reshapes the tensor so that it's [height, width, 1] for imaging.
// - Encodes it as a PNG stream and saves it out to a file.
Output file_reader = ReadFile(root.WithOpName("input_wav"), input_wav);
DecodeWav wav_decoder =
DecodeWav(root.WithOpName("wav_decoder"), file_reader);
Output spectrogram = AudioSpectrogram(root.WithOpName("spectrogram"),
wav_decoder.audio, window_size, stride);
Output brightness_placeholder =
Placeholder(root.WithOpName("brightness_placeholder"), DT_FLOAT,
Placeholder::Attrs().Shape(TensorShape({})));
Output mul = Mul(root.WithOpName("mul"), spectrogram, brightness_placeholder);
Output min_const = Const(root.WithOpName("min_const"), 255.0f);
Output min = Minimum(root.WithOpName("min"), mul, min_const);
Output cast = Cast(root.WithOpName("cast"), min, DT_UINT8);
Output expand_dims_const = Const(root.WithOpName("expand_dims_const"), -1);
Output expand_dims =
ExpandDims(root.WithOpName("expand_dims"), cast, expand_dims_const);
Output squeeze = Squeeze(root.WithOpName("squeeze"), expand_dims,
Squeeze::Attrs().SqueezeDims({0}));
Output png_encoder = EncodePng(root.WithOpName("png_encoder"), squeeze);
WriteFile file_writer =
WriteFile(root.WithOpName("output_image"), output_image, png_encoder);
tensorflow::GraphDef graph;
TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));
// Build a session object from this graph definition. The power of TensorFlow
// is that you can reuse complex computations like this, so usually we'd run a
// lot of different inputs through it. In this example, we're just doing a
// one-off run, so we'll create it and then use it immediately.
std::unique_ptr<tensorflow::Session> session(
tensorflow::NewSession(tensorflow::SessionOptions()));
TF_RETURN_IF_ERROR(session->Create(graph));
// We're passing in the brightness as an input, so create a tensor to hold the
// value.
tensorflow::Tensor brightness_tensor(DT_FLOAT, TensorShape({}));
brightness_tensor.scalar<float>()() = brightness;
// Run the session to analyze the audio and write out the file.
TF_RETURN_IF_ERROR(
session->Run({{"brightness_placeholder", brightness_tensor}}, {},
{"output_image"}, nullptr));
return tensorflow::Status::OK();
}

View File

@ -0,0 +1,31 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_WAV_TO_SPECTROGRAM_WAV_TO_SPECTROGRAM_H_
#define THIRD_PARTY_TENSORFLOW_EXAMPLES_WAV_TO_SPECTROGRAM_WAV_TO_SPECTROGRAM_H_
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
// Runs a TensorFlow graph to convert an audio file into a visualization. Takes
// in the path to the audio file, the window size and stride parameters
// controlling the spectrogram creation, the brightness scaling to use, and a
// path to save the output PNG file to.
tensorflow::Status WavToSpectrogram(const tensorflow::string& input_wav,
tensorflow::int32 window_size,
tensorflow::int32 stride, float brightness,
const tensorflow::string& output_image);
#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_WAV_TO_SPECTROGRAM_WAV_TO_SPECTROGRAM_H_

View File

@ -0,0 +1,37 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/wav/wav_io.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
TEST(WavToSpectrogramTest, WavToSpectrogramTest) {
const tensorflow::string input_wav =
tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "input_wav.wav");
const tensorflow::string output_image = tensorflow::io::JoinPath(
tensorflow::testing::TmpDir(), "output_image.png");
float audio[8] = {-1.0f, 0.0f, 1.0f, 0.0f, -1.0f, 0.0f, 1.0f, 0.0f};
tensorflow::string wav_string;
TF_ASSERT_OK(
tensorflow::wav::EncodeAudioAsS16LEWav(audio, 44100, 1, 8, &wav_string));
TF_ASSERT_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(),
input_wav, wav_string));
TF_ASSERT_OK(WavToSpectrogram(input_wav, 4, 4, 64.0f, output_image));
TF_EXPECT_OK(tensorflow::Env::Default()->FileExists(output_image));
}

View File

@ -79,11 +79,13 @@ genrule(
srcs = [
"//third_party/hadoop:LICENSE.txt",
"//third_party/eigen3:LICENSE",
"//third_party/fft2d:LICENSE",
"@boringssl//:LICENSE",
"@com_googlesource_code_re2//:LICENSE",
"@curl//:COPYING",
"@eigen_archive//:COPYING.MPL2",
"@farmhash_archive//:COPYING",
"@fft2d//:fft/readme.txt",
"@gemmlowp//:LICENSE",
"@gif_archive//:COPYING",
"@highwayhash//:LICENSE",
@ -106,11 +108,13 @@ genrule(
srcs = [
"//third_party/hadoop:LICENSE.txt",
"//third_party/eigen3:LICENSE",
"//third_party/fft2d:LICENSE",
"@boringssl//:LICENSE",
"@com_googlesource_code_re2//:LICENSE",
"@curl//:COPYING",
"@eigen_archive//:COPYING.MPL2",
"@farmhash_archive//:COPYING",
"@fft2d//:fft/readme.txt",
"@gemmlowp//:LICENSE",
"@gif_archive//:COPYING",
"@highwayhash//:LICENSE",

View File

@ -91,12 +91,14 @@ filegroup(
name = "licenses",
data = [
"//third_party/eigen3:LICENSE",
"//third_party/fft2d:LICENSE",
"//third_party/hadoop:LICENSE.txt",
"@boringssl//:LICENSE",
"@com_googlesource_code_re2//:LICENSE",
"@curl//:COPYING",
"@eigen_archive//:COPYING.MPL2",
"@farmhash_archive//:COPYING",
"@fft2d//:fft/readme.txt",
"@gemmlowp//:LICENSE",
"@gif_archive//:COPYING",
"@grpc//:LICENSE",

View File

@ -500,6 +500,16 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
name="zlib",
actual="@zlib_archive//:zlib",)
native.new_http_archive(
name = "fft2d",
urls = [
"http://bazel-mirror.storage.googleapis.com/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz",
"http://www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz",
],
sha256 = "52bb637c70b971958ec79c9c8752b1df5ff0218a4db4510e60826e0cb79b5296",
build_file = str(Label("//third_party/fft2d:fft2d.BUILD")),
)
temp_workaround_http_archive(
name="snappy",
urls=[