From 7f6fd8b48bef935c5adcb1bf09dce45893fab5f2 Mon Sep 17 00:00:00 2001
From: Reuben Morais
Date: Wed, 3 Apr 2019 11:54:41 -0300
Subject: [PATCH 1/2] Embed more metadata in exported model and read it in native client

---
 DeepSpeech.py               | 15 +++++++++--
 native_client/deepspeech.cc | 51 +++++++++++++++++--------------------
 util/config.py              |  6 +++++
 util/feeding.py             |  5 +++-
 util/flags.py               |  5 ++++
 5 files changed, 52 insertions(+), 30 deletions(-)

diff --git a/DeepSpeech.py b/DeepSpeech.py
index 3b8e9860..bd2e6866 100755
--- a/DeepSpeech.py
+++ b/DeepSpeech.py
@@ -548,9 +548,9 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
     batch_size = batch_size if batch_size > 0 else None
 
     # Create feature computation graph
-    input_samples = tf.placeholder(tf.float32, [512], 'input_samples')
+    input_samples = tf.placeholder(tf.float32, [Config.audio_window_samples], 'input_samples')
     samples = tf.expand_dims(input_samples, -1)
-    mfccs, _ = samples_to_mfccs(samples, 16000)
+    mfccs, _ = samples_to_mfccs(samples, FLAGS.audio_sample_rate)
     mfccs = tf.identity(mfccs, name='mfccs')
 
     # Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
@@ -724,6 +724,17 @@ def export():
         if not FLAGS.export_tflite:
             frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
             frozen_graph.version = int(file_relative_read('GRAPH_VERSION').strip())
+
+            # Add a no-op node to the graph with metadata information to be loaded by the native client
+            metadata = frozen_graph.node.add()
+            metadata.name = 'model_metadata'
+            metadata.op = 'NoOp'
+            metadata.attr['sample_rate'].i = FLAGS.audio_sample_rate
+            metadata.attr['feature_win_len'].i = FLAGS.feature_win_len
+            metadata.attr['feature_win_step'].i = FLAGS.feature_win_step
+            if FLAGS.export_model_language:
+                metadata.attr['language'].s = FLAGS.export_model_language.encode('ascii')
+
             with open(output_graph_path, 'wb') as fout:
                 fout.write(frozen_graph.SerializeToString())
         else:
diff --git a/native_client/deepspeech.cc b/native_client/deepspeech.cc
index a4466cb1..56acf89a 100644
--- a/native_client/deepspeech.cc
+++ b/native_client/deepspeech.cc
@@ -39,25 +39,9 @@
 //TODO: infer batch size from model/use dynamic batch size
 constexpr unsigned int BATCH_SIZE = 1;
 
-//TODO: use dynamic sample rate
-constexpr unsigned int SAMPLE_RATE = 16000;
-
-constexpr float AUDIO_WIN_LEN = 0.032f;
-constexpr float AUDIO_WIN_STEP = 0.02f;
-constexpr unsigned int AUDIO_WIN_LEN_SAMPLES = (unsigned int)(AUDIO_WIN_LEN * SAMPLE_RATE);
-constexpr unsigned int AUDIO_WIN_STEP_SAMPLES = (unsigned int)(AUDIO_WIN_STEP * SAMPLE_RATE);
-
-constexpr size_t WINDOW_SIZE = AUDIO_WIN_LEN * SAMPLE_RATE;
-
-std::array<float, WINDOW_SIZE> calc_hamming_window() {
-  std::array<float, WINDOW_SIZE> a{0};
-  for (int i = 0; i < WINDOW_SIZE; ++i) {
-    a[i] = 0.54 - 0.46 * std::cos(2*M_PI*i/(WINDOW_SIZE-1));
-  }
-  return a;
-}
-
-std::array<float, WINDOW_SIZE> hamming_window = calc_hamming_window();
+constexpr unsigned int DEFAULT_SAMPLE_RATE = 16000;
+constexpr unsigned int DEFAULT_WINDOW_LENGTH = DEFAULT_SAMPLE_RATE * 0.032;
+constexpr unsigned int DEFAULT_WINDOW_STEP = DEFAULT_SAMPLE_RATE * 0.02;
 
 #ifndef USE_TFLITE
 using namespace tensorflow;
@@ -134,6 +118,9 @@ struct ModelState {
   unsigned int n_context;
   unsigned int n_features;
   unsigned int mfcc_feats_per_timestep;
+  unsigned int sample_rate;
+  unsigned int audio_win_len;
+  unsigned int audio_win_step;
 
 #ifdef USE_TFLITE
   size_t previous_state_size;
@@ -220,6 +207,9 @@ ModelState::ModelState()
   , n_context(-1)
   , n_features(-1)
   , mfcc_feats_per_timestep(-1)
+  , sample_rate(DEFAULT_SAMPLE_RATE)
+  , audio_win_len(DEFAULT_WINDOW_LENGTH)
+  , audio_win_step(DEFAULT_WINDOW_STEP)
 #ifdef USE_TFLITE
   , previous_state_size(0)
   , previous_state_c_(nullptr)
@@ -258,7 +248,7 @@ StreamingState::feedAudioContent(const short* buffer,
 {
   // Consume all the data that was passed in, processing full buffers if needed
   while (buffer_size > 0) {
-    while (buffer_size > 0 && audio_buffer.size() < AUDIO_WIN_LEN_SAMPLES) {
+    while (buffer_size > 0 && audio_buffer.size() < model->audio_win_len) {
       // Convert i16 sample into f32
       float multiplier = 1.0f / (1 << 15);
       audio_buffer.push_back((float)(*buffer) * multiplier);
@@ -267,10 +257,10 @@ StreamingState::feedAudioContent(const short* buffer,
     }
 
     // If the buffer is full, process and shift it
-    if (audio_buffer.size() == AUDIO_WIN_LEN_SAMPLES) {
+    if (audio_buffer.size() == model->audio_win_len) {
      processAudioWindow(audio_buffer);
       // Shift data by one step
-      shift_buffer_left(audio_buffer, AUDIO_WIN_STEP_SAMPLES);
+      shift_buffer_left(audio_buffer, model->audio_win_step);
     }
 
     // Repeat until buffer empty
@@ -461,13 +451,13 @@ void
 ModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_output)
 {
 #ifndef USE_TFLITE
-  Tensor input(DT_FLOAT, TensorShape({AUDIO_WIN_LEN_SAMPLES}));
+  Tensor input(DT_FLOAT, TensorShape({audio_win_len}));
   auto input_mapped = input.flat<float>();
   int i;
   for (i = 0; i < samples.size(); ++i) {
     input_mapped(i) = samples[i];
   }
-  for (; i < AUDIO_WIN_LEN_SAMPLES; ++i) {
+  for (; i < audio_win_len; ++i) {
     input_mapped(i) = 0.f;
   }
 
@@ -556,8 +546,8 @@ ModelState::decode_metadata(const vector<float>& logits)
   for (int i = 0; i < out[0].tokens.size(); ++i) {
     metadata->items[i].character = (char*)alphabet->StringFromLabel(out[0].tokens[i]).c_str();
     metadata->items[i].timestep = out[0].timesteps[i];
-    metadata->items[i].start_time = static_cast<float>(out[0].timesteps[i] * AUDIO_WIN_STEP);
-
+    metadata->items[i].start_time = out[0].timesteps[i] * ((float)audio_win_step / sample_rate);
+
     if (metadata->items[i].start_time < 0) {
       metadata->items[i].start_time = 0;
     }
@@ -700,6 +690,13 @@ DS_CreateModel(const char* aModelPath,
                   << std::endl;
         return DS_ERR_INVALID_ALPHABET;
       }
+    } else if (node.name() == "model_metadata") {
+      int sample_rate = node.attr().at("sample_rate").i();
+      model->sample_rate = sample_rate;
+      int win_len_ms = node.attr().at("feature_win_len").i();
+      int win_step_ms = node.attr().at("feature_win_step").i();
+      model->audio_win_len = sample_rate * (win_len_ms / 1000.0);
+      model->audio_win_step = sample_rate * (win_step_ms / 1000.0);
     }
   }
 
@@ -833,7 +830,7 @@ DS_SetupStream(ModelState* aCtx,
 
   ctx->accumulated_logits.reserve(aPreAllocFrames * BATCH_SIZE * num_classes);
 
-  ctx->audio_buffer.reserve(AUDIO_WIN_LEN_SAMPLES);
+  ctx->audio_buffer.reserve(aCtx->audio_win_len);
   ctx->mfcc_buffer.reserve(aCtx->mfcc_feats_per_timestep);
   ctx->mfcc_buffer.resize(aCtx->n_features*aCtx->n_context, 0.f);
   ctx->batch_buffer.reserve(aCtx->n_steps * aCtx->mfcc_feats_per_timestep);
diff --git a/util/config.py b/util/config.py
index b7765451..14077f6c 100644
--- a/util/config.py
+++ b/util/config.py
@@ -92,6 +92,12 @@ def initialize_globals():
     # Units in the sixth layer = number of characters in the target language plus one
     c.n_hidden_6 = c.alphabet.size() + 1 # +1 for CTC blank label
 
+    # Size of audio window in samples
+    c.audio_window_samples = FLAGS.audio_sample_rate * (FLAGS.feature_win_len / 1000)
+
+    # Stride for feature computations in samples
+    c.audio_step_samples = FLAGS.audio_sample_rate * (FLAGS.feature_win_step / 1000)
+
     if FLAGS.one_shot_infer:
         if not os.path.exists(FLAGS.one_shot_infer):
             log_error('Path specified in --one_shot_infer is not a valid file.')
diff --git a/util/feeding.py b/util/feeding.py
index 2feb5bbc..9c91518b 100644
--- a/util/feeding.py
+++ b/util/feeding.py
@@ -27,7 +27,10 @@ def read_csvs(csv_files):
 
 
 def samples_to_mfccs(samples, sample_rate):
-    spectrogram = contrib_audio.audio_spectrogram(samples, window_size=512, stride=320, magnitude_squared=True)
+    spectrogram = contrib_audio.audio_spectrogram(samples,
+                                                  window_size=Config.audio_window_samples,
+                                                  stride=Config.audio_step_samples,
+                                                  magnitude_squared=True)
     mfccs = contrib_audio.mfcc(spectrogram, sample_rate, dct_coefficient_count=Config.n_input)
     mfccs = tf.reshape(mfccs, [-1, Config.n_input])
 
diff --git a/util/flags.py b/util/flags.py
index 1683e32b..efdedae6 100644
--- a/util/flags.py
+++ b/util/flags.py
@@ -19,6 +19,10 @@ def create_flags():
     tf.app.flags.DEFINE_string  ('dev_cached_features_path',  '', 'comma separated list of files specifying the dataset used for validation. multiple files will get merged')
     tf.app.flags.DEFINE_string  ('test_cached_features_path', '', 'comma separated list of files specifying the dataset used for testing. multiple files will get merged')
 
+    tf.app.flags.DEFINE_integer ('feature_win_len',    32,    'feature extraction audio window length in milliseconds')
+    tf.app.flags.DEFINE_integer ('feature_win_step',   20,    'feature extraction window step length in milliseconds')
+    tf.app.flags.DEFINE_integer ('audio_sample_rate',  16000, 'sample rate value expected by model')
+
     # Global Constants
     # ================
 
@@ -73,6 +77,7 @@ def create_flags():
     tf.app.flags.DEFINE_boolean ('export_tflite',      False, 'export a graph ready for TF Lite engine')
     tf.app.flags.DEFINE_boolean ('use_seq_length',     True,  'have sequence_length in the exported graph (will make tfcompile unhappy)')
     tf.app.flags.DEFINE_integer ('n_steps',            16,    'how many timesteps to process at once by the export graph, higher values mean more latency')
+    tf.app.flags.DEFINE_string  ('export_model_language', '', 'language the model was trained on. Gets embedded into exported model.')
 
     # Reporting
 

From 5b80f216684f0f8c0889d0754496161f26337fed Mon Sep 17 00:00:00 2001
From: Reuben Morais
Date: Fri, 5 Apr 2019 11:54:02 -0300
Subject: [PATCH 2/2] Rename language flag

---
 util/flags.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/util/flags.py b/util/flags.py
index efdedae6..f373c0e5 100644
--- a/util/flags.py
+++ b/util/flags.py
@@ -77,7 +77,7 @@ def create_flags():
     tf.app.flags.DEFINE_boolean ('export_tflite',      False, 'export a graph ready for TF Lite engine')
     tf.app.flags.DEFINE_boolean ('use_seq_length',     True,  'have sequence_length in the exported graph (will make tfcompile unhappy)')
     tf.app.flags.DEFINE_integer ('n_steps',            16,    'how many timesteps to process at once by the export graph, higher values mean more latency')
-    tf.app.flags.DEFINE_string  ('export_model_language', '', 'language the model was trained on. Gets embedded into exported model.')
+    tf.app.flags.DEFINE_string  ('export_language',    '',    'language the model was trained on e.g. "en" or "English". Gets embedded into exported model.')
 
     # Reporting