From e51b9d987d162bd4cbad0a0c94295ae7809ea086 Mon Sep 17 00:00:00 2001
From: Reuben Morais
Date: Thu, 6 Jun 2019 16:40:19 -0300
Subject: [PATCH] Remove previous state model variable, track by hand in
 StreamingState instead

---
 DeepSpeech.py                     |  86 +++++++----------
 GRAPH_VERSION                     |   2 +-
 native_client/BUILD               |  14 +--
 native_client/deepspeech.cc       |  18 ++--
 native_client/modelstate.cc       |   1 +
 native_client/modelstate.h        |  11 ++-
 native_client/tflitemodelstate.cc | 150 +++++++++++++++---------------
 native_client/tflitemodelstate.h  |  27 ++++--
 native_client/tfmodelstate.cc     |  96 +++++++++++--------
 native_client/tfmodelstate.h      |  10 +-
 10 files changed, 213 insertions(+), 202 deletions(-)

diff --git a/DeepSpeech.py b/DeepSpeech.py
index 1883724d..7e92e202 100755
--- a/DeepSpeech.py
+++ b/DeepSpeech.py
@@ -574,12 +574,8 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
         # no state management since n_step is expected to be dynamic too (see below)
         previous_state = previous_state_c = previous_state_h = None
     else:
-        if tflite:
-            previous_state_c = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
-            previous_state_h = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
-        else:
-            previous_state_c = variable_on_cpu('previous_state_c', [batch_size, Config.n_cell_dim], initializer=None)
-            previous_state_h = variable_on_cpu('previous_state_h', [batch_size, Config.n_cell_dim], initializer=None)
+        previous_state_c = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
+        previous_state_h = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
 
         previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h)
 
@@ -605,7 +601,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
         logits = tf.squeeze(logits, [1])
 
     # Apply softmax for CTC decoder
-    logits = tf.nn.softmax(logits)
+    logits = tf.nn.softmax(logits, name='logits')
 
     if batch_size <= 0:
         if tflite:
@@ -618,51 +614,31 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
                 'input_lengths': seq_length,
             },
             {
-                'outputs': tf.identity(logits, name='logits'),
+                'outputs': logits,
             },
             layers
         )
 
     new_state_c, new_state_h = layers['rnn_output_state']
-    if tflite:
-        logits = tf.identity(logits, name='logits')
-        new_state_c = tf.identity(new_state_c, name='new_state_c')
-        new_state_h = tf.identity(new_state_h, name='new_state_h')
+    new_state_c = tf.identity(new_state_c, name='new_state_c')
+    new_state_h = tf.identity(new_state_h, name='new_state_h')
 
-        inputs = {
-            'input': input_tensor,
-            'previous_state_c': previous_state_c,
-            'previous_state_h': previous_state_h,
-            'input_samples': input_samples,
-        }
+    inputs = {
+        'input': input_tensor,
+        'previous_state_c': previous_state_c,
+        'previous_state_h': previous_state_h,
+        'input_samples': input_samples,
+    }
 
-        if FLAGS.use_seq_length:
-            inputs.update({'input_lengths': seq_length})
+    if FLAGS.use_seq_length:
+        inputs.update({'input_lengths': seq_length})
 
-        outputs = {
-            'outputs': logits,
-            'new_state_c': new_state_c,
-            'new_state_h': new_state_h,
-            'mfccs': mfccs,
-        }
-    else:
-        zero_state = tf.zeros([batch_size, Config.n_cell_dim], tf.float32)
-        initialize_c = tf.assign(previous_state_c, zero_state)
-        initialize_h = tf.assign(previous_state_h, zero_state)
-        initialize_state = tf.group(initialize_c, initialize_h, name='initialize_state')
-        with tf.control_dependencies([tf.assign(previous_state_c, new_state_c),
-                                      tf.assign(previous_state_h, new_state_h)]):
-            logits = tf.identity(logits, name='logits')
-
-        inputs = {
-            'input': input_tensor,
-            'input_lengths': seq_length,
-            'input_samples': input_samples,
-        }
-        outputs = {
-            'outputs': logits,
-            'initialize_state': initialize_state,
-            'mfccs': mfccs,
-        }
+    outputs = {
+        'outputs': logits,
+        'new_state_c': new_state_c,
+        'new_state_h': new_state_h,
+        'mfccs': mfccs,
+    }
 
     return inputs, outputs, layers
 
@@ -682,10 +658,12 @@ def export():
     output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
     output_names = ",".join(output_names_tensors + output_names_ops)
 
-    if not FLAGS.export_tflite:
-        mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
-    else:
+    mapping = None
+    if FLAGS.export_tflite:
         # Create a saver using variables from the above newly created graph
+        # Training graph uses LSTMFusedCell, but the TFLite inference graph uses
+        # a static RNN with a normal cell, so we need to rewrite the names to
+        # match the training weights when restoring.
         def fixup(name):
             if name.startswith('rnn/lstm_cell/'):
                 return name.replace('rnn/lstm_cell/', 'lstm_fused_cell/')
@@ -710,7 +688,7 @@ def export():
     if not os.path.isdir(FLAGS.export_dir):
         os.makedirs(FLAGS.export_dir)
 
-    def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
+    def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=''):
         frozen = freeze_graph.freeze_graph_with_def_protos(
             input_graph_def=tf.get_default_graph().as_graph_def(),
             input_saver_def=saver.as_saver_def(),
@@ -731,7 +709,7 @@ def export():
             placeholder_type_enum=tf.float32.as_datatype_enum)
 
     if not FLAGS.export_tflite:
-        frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
+        frozen_graph = do_graph_freeze(output_node_names=output_names)
         frozen_graph.version = int(file_relative_read('GRAPH_VERSION').strip())
 
         # Add a no-op node to the graph with metadata information to be loaded by the native client
@@ -747,7 +725,7 @@ def export():
         with open(output_graph_path, 'wb') as fout:
             fout.write(frozen_graph.SerializeToString())
     else:
-        frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='')
+        frozen_graph = do_graph_freeze(output_node_names=output_names)
         output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
 
         converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
@@ -771,8 +749,7 @@ def do_single_file_inference(input_file_path):
         inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)
 
         # Create a saver using variables from the above newly created graph
-        mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
-        saver = tf.train.Saver(mapping)
+        saver = tf.train.Saver()
 
         # Restore variables from training checkpoint
         # TODO: This restores the most recent checkpoint, but if we use validation to counteract
@@ -784,9 +761,10 @@ def do_single_file_inference(input_file_path):
             checkpoint_path = checkpoint.model_checkpoint_path
 
         saver.restore(session, checkpoint_path)
-        session.run(outputs['initialize_state'])
 
         features, features_len = audiofile_to_features(input_file_path)
+        previous_state_c = np.zeros([1, Config.n_cell_dim])
+        previous_state_h = np.zeros([1, Config.n_cell_dim])
 
         # Add batch dimension
         features = tf.expand_dims(features, 0)
@@ -799,6 +777,8 @@ def do_single_file_inference(input_file_path):
         logits = outputs['outputs'].eval(feed_dict={
            inputs['input']: features,
            inputs['input_lengths']: features_len,
+           inputs['previous_state_c']: previous_state_c,
+           inputs['previous_state_h']: previous_state_h,
        }, session=session)
 
        logits = np.squeeze(logits)
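With the `initialize_state` op and the graph-side assigns gone, the exported graph is stateless: on every run the caller feeds `previous_state_c`/`previous_state_h` and reads back `new_state_c`/`new_state_h`. A minimal C++ sketch of that contract follows; `run_acoustic_model` is a hypothetical stand-in for one `Session::Run` or `Interpreter::Invoke` over the frozen graph, not an API from this patch.

#include <vector>

struct StepOutput {
  std::vector<float> logits;
  std::vector<float> state_c;
  std::vector<float> state_h;
};

// Hypothetical stand-in for one run of the exported graph: feeds
// previous_state_{c,h}, fetches logits and new_state_{c,h}. A real
// implementation would call Session::Run / Interpreter::Invoke here.
StepOutput run_acoustic_model(const std::vector<float>& mfcc,
                              const std::vector<float>& state_c,
                              const std::vector<float>& state_h)
{
  return {std::vector<float>(mfcc.size(), 0.f), state_c, state_h};
}

std::vector<float> decode_chunks(const std::vector<std::vector<float>>& chunks,
                                 unsigned int state_size)
{
  // The caller owns the recurrent state now: start from zeros instead of
  // running the removed 'initialize_state' op...
  std::vector<float> state_c(state_size, 0.f);
  std::vector<float> state_h(state_size, 0.f);
  std::vector<float> all_logits;
  for (const auto& chunk : chunks) {
    StepOutput out = run_acoustic_model(chunk, state_c, state_h);
    all_logits.insert(all_logits.end(), out.logits.begin(), out.logits.end());
    // ...and thread the new state into the next step by hand.
    state_c = std::move(out.state_c);
    state_h = std::move(out.state_h);
  }
  return all_logits;
}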
diff --git a/GRAPH_VERSION b/GRAPH_VERSION
index 56a6051c..d8263ee9 100644
--- a/GRAPH_VERSION
+++ b/GRAPH_VERSION
@@ -1 +1 @@
-1
\ No newline at end of file
+2
\ No newline at end of file
diff --git a/native_client/BUILD b/native_client/BUILD
index d7813d29..5203eb47 100644
--- a/native_client/BUILD
+++ b/native_client/BUILD
@@ -114,34 +114,26 @@ tf_cc_shared_object(
         ### => Trying to be more fine-grained
         ### Use bin/ops_in_graph.py to list all the ops used by a frozen graph.
         ### CPU only build, libdeepspeech.so file size reduced by ~50%
-        "//tensorflow/core/kernels:dense_update_ops",      # Assign
-        "//tensorflow/core/kernels:constant_op",           # Const
-        "//tensorflow/core/kernels:immutable_constant_op", # ImmutableConst
+        "//tensorflow/core/kernels:dense_update_ops",      # Assign (remove once prod model no longer depends on it)
+        "//tensorflow/core/kernels:constant_op",           # Placeholder
+        "//tensorflow/core/kernels:immutable_constant_op", # ImmutableConst (used in memmapped models)
         "//tensorflow/core/kernels:identity_op",           # Identity
         "//tensorflow/core/kernels:softmax_op",            # Softmax
         "//tensorflow/core/kernels:transpose_op",          # Transpose
         "//tensorflow/core/kernels:reshape_op",            # Reshape
         "//tensorflow/core/kernels:shape_ops",             # Shape
         "//tensorflow/core/kernels:concat_op",             # ConcatV2
-        "//tensorflow/core/kernels:split_op",              # Split
-        "//tensorflow/core/kernels:variable_ops",          # VariableV2
         "//tensorflow/core/kernels:relu_op",               # Relu
         "//tensorflow/core/kernels:bias_op",               # BiasAdd
         "//tensorflow/core/kernels:math",                  # Range, MatMul
-        "//tensorflow/core/kernels:control_flow_ops",      # Enter
         "//tensorflow/core/kernels:tile_ops",              # Tile
-        "//tensorflow/core/kernels:gather_op",             # Gather
         "//tensorflow/core/kernels:mfcc_op",               # Mfcc
         "//tensorflow/core/kernels:spectrogram_op",        # AudioSpectrogram
         "//tensorflow/core/kernels:strided_slice_op",      # StridedSlice
         "//tensorflow/core/kernels:slice_op",              # Slice, needed by StridedSlice
         "//tensorflow/contrib/rnn:lstm_ops_kernels",       # BlockLSTM
-        "//tensorflow/core/kernels:random_ops",            # RandomGammaGrad
         "//tensorflow/core/kernels:pack_op",               # Pack
         "//tensorflow/core/kernels:gather_nd_op",          # GatherNd
-        #### Needed by production model produced without "--use_seq_length False"
-        #"//tensorflow/core/kernels:logging_ops",         # Assert
-        #"//tensorflow/core/kernels:reverse_sequence_op", # ReverseSequence
     ],
 }) + if_cuda([
     "//tensorflow/core:core",
diff --git a/native_client/deepspeech.cc b/native_client/deepspeech.cc
index 1ee22d58..7dd96574 100644
--- a/native_client/deepspeech.cc
+++ b/native_client/deepspeech.cc
@@ -67,6 +67,9 @@ struct StreamingState {
   vector<float> audio_buffer_;
   vector<float> mfcc_buffer_;
   vector<float> batch_buffer_;
+  vector<float> previous_state_c_;
+  vector<float> previous_state_h_;
+
   ModelState* model_;
   std::unique_ptr<DecoderState> decoder_state_;
@@ -233,7 +236,13 @@ void
 StreamingState::processBatch(const vector<float>& buf, unsigned int n_steps)
 {
   vector<float> logits;
-  model_->infer(buf.data(), n_steps, logits);
+  model_->infer(buf,
+                n_steps,
+                previous_state_c_,
+                previous_state_h_,
+                logits,
+                previous_state_c_,
+                previous_state_h_);
 
   const int cutoff_top_n = 40;
   const double cutoff_prob = 1.0;
@@ -326,11 +335,6 @@ DS_SetupStream(ModelState* aCtx,
 {
   *retval = nullptr;
 
-  int err = aCtx->initialize_state();
-  if (err != DS_ERR_OK) {
-    return err;
-  }
-
   std::unique_ptr<StreamingState> ctx(new StreamingState());
   if (!ctx) {
     std::cerr << "Could not allocate streaming state." << std::endl;
@@ -348,6 +352,8 @@ DS_SetupStream(ModelState* aCtx,
   ctx->mfcc_buffer_.reserve(aCtx->mfcc_feats_per_timestep_);
   ctx->mfcc_buffer_.resize(aCtx->n_features_*aCtx->n_context_, 0.f);
   ctx->batch_buffer_.reserve(aCtx->n_steps_ * aCtx->mfcc_feats_per_timestep_);
+  ctx->previous_state_c_.resize(aCtx->state_size_, 0.f);
+  ctx->previous_state_h_.resize(aCtx->state_size_, 0.f);
 
   ctx->model_ = aCtx;
   ctx->decoder_state_.reset(decoder_init(*aCtx->alphabet_, num_classes, aCtx->scorer_));
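Note that `processBatch` passes `previous_state_c_`/`previous_state_h_` as both the input state and the output state arguments of the same `infer` call, so implementations must fully consume the input state before mutating the output vectors. A self-contained sketch of why that ordering keeps the aliased call safe; `infer_step` and its snapshot are illustrative, not the patch's API.

#include <cassert>
#include <vector>

// Illustrative only: consume the (possibly aliased) input state first,
// then overwrite the output, mirroring the order both backends use
// (copy state into the model, run it, then clear/fill the outputs).
void infer_step(const std::vector<float>& state_in,
                std::vector<float>& state_out)
{
  std::vector<float> fed = state_in; // snapshot: the model consumes the input here
  // ... the acoustic model would run on `fed` and produce a new state ...
  state_out.clear();                 // only now mutate the (aliased) output
  state_out.assign(fed.begin(), fed.end());
}

int main()
{
  std::vector<float> state(4, 1.f);
  infer_step(state, state);          // aliased, as in StreamingState
  assert(state.size() == 4 && state[0] == 1.f);
  return 0;
}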
diff --git a/native_client/modelstate.cc b/native_client/modelstate.cc
index c3fda2b9..7bb7f073 100644
--- a/native_client/modelstate.cc
+++ b/native_client/modelstate.cc
@@ -17,6 +17,7 @@ ModelState::ModelState()
   , sample_rate_(DEFAULT_SAMPLE_RATE)
   , audio_win_len_(DEFAULT_WINDOW_LENGTH)
   , audio_win_step_(DEFAULT_WINDOW_STEP)
+  , state_size_(-1)
 {
 }
diff --git a/native_client/modelstate.h b/native_client/modelstate.h
index 7f53c63e..71799421 100644
--- a/native_client/modelstate.h
+++ b/native_client/modelstate.h
@@ -28,6 +28,7 @@ struct ModelState {
   unsigned int sample_rate_;
   unsigned int audio_win_len_;
   unsigned int audio_win_step_;
+  unsigned int state_size_;
 
   ModelState();
   virtual ~ModelState();
@@ -38,8 +39,6 @@ struct ModelState {
                    const char* alphabet_path,
                    unsigned int beam_width);
 
-  virtual int initialize_state() = 0;
-
   virtual void compute_mfcc(const std::vector<float>& audio_buffer, std::vector<float>& mfcc_output) = 0;
 
   /**
@@ -52,7 +51,13 @@ struct ModelState {
    *
    * @param[out] output_logits Where to store computed logits.
    */
-  virtual void infer(const float* mfcc, unsigned int n_frames, std::vector<float>& logits_output) = 0;
+  virtual void infer(const std::vector<float>& mfcc,
+                     unsigned int n_frames,
+                     const std::vector<float>& previous_state_c,
+                     const std::vector<float>& previous_state_h,
+                     std::vector<float>& logits_output,
+                     std::vector<float>& state_c_output,
+                     std::vector<float>& state_h_output) = 0;
 
   /**
    * @brief Perform decoding of the logits, using basic CTC decoder or
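The TFLite backend below satisfies this interface with two small copy helpers. `copy_vector_to_tensor` zero-fills everything past the end of the source vector, so an input shorter than the tensor's capacity still leaves deterministic values behind it. A standalone sketch of that behavior, using a plain float buffer in place of a TFLite tensor:

#include <cassert>
#include <vector>

// Same copy-then-zero-pad pattern as copy_vector_to_tensor below, but
// against a plain buffer instead of interpreter_->typed_tensor<float>().
void copy_vector_to_buffer(const std::vector<float>& vec,
                           float* buf,
                           int num_elements)
{
  int i = 0;
  for (; i < static_cast<int>(vec.size()); ++i) {
    buf[i] = vec[i];  // real data
  }
  for (; i < num_elements; ++i) {
    buf[i] = 0.f;     // deterministic zero padding
  }
}

int main()
{
  std::vector<float> src = {1.f, 2.f, 3.f};
  float buf[6];
  copy_vector_to_buffer(src, buf, 6);
  assert(buf[2] == 3.f && buf[3] == 0.f && buf[5] == 0.f);
  return 0;
}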
diff --git a/native_client/tflitemodelstate.cc b/native_client/tflitemodelstate.cc
index 92d9c014..9af0ae86 100644
--- a/native_client/tflitemodelstate.cc
+++ b/native_client/tflitemodelstate.cc
@@ -4,14 +4,13 @@ using namespace tflite;
 using std::vector;
 
 int
-tflite_get_tensor_by_name(const Interpreter* interpreter,
-                          const vector<int>& list,
-                          const char* name)
+TFLiteModelState::get_tensor_by_name(const vector<int>& list,
+                                     const char* name)
 {
   int rv = -1;
 
   for (int i = 0; i < list.size(); ++i) {
-    const string& node_name = interpreter->tensor(list[i])->name;
+    const string& node_name = interpreter_->tensor(list[i])->name;
     if (node_name.compare(string(name)) == 0) {
       rv = i;
     }
@@ -22,17 +21,17 @@ tflite_get_tensor_by_name(const Interpreter* interpreter,
 }
 
 int
-tflite_get_input_tensor_by_name(const Interpreter* interpreter, const char* name)
+TFLiteModelState::get_input_tensor_by_name(const char* name)
 {
-  int idx = tflite_get_tensor_by_name(interpreter, interpreter->inputs(), name);
-  return interpreter->inputs()[idx];
+  int idx = get_tensor_by_name(interpreter_->inputs(), name);
+  return interpreter_->inputs()[idx];
 }
 
 int
-tflite_get_output_tensor_by_name(const Interpreter* interpreter, const char* name)
+TFLiteModelState::get_output_tensor_by_name(const char* name)
 {
-  int idx = tflite_get_tensor_by_name(interpreter, interpreter->outputs(), name);
-  return interpreter->outputs()[idx];
+  int idx = get_tensor_by_name(interpreter_->outputs(), name);
+  return interpreter_->outputs()[idx];
 }
 
 void
@@ -48,8 +47,8 @@ push_back_if_not_present(std::deque<int>& list, int value)
 // output, add it to the parent list, and add its input tensors to the frontier
 // list. Because we start from the final tensor and work backwards to the inputs,
 // the parents list is constructed in reverse, adding elements to its front.
-std::vector<int>
-tflite_find_parent_node_ids(Interpreter* interpreter, int tensor_id)
+vector<int>
+TFLiteModelState::find_parent_node_ids(int tensor_id)
 {
   std::deque<int> parents;
   std::deque<int> frontier;
@@ -58,8 +57,8 @@ tflite_find_parent_node_ids(Interpreter* interpreter, int tensor_id)
     int next_tensor_id = frontier.front();
     frontier.pop_front();
     // Find all nodes that have next_tensor_id as an output
-    for (int node_id = 0; node_id < interpreter->nodes_size(); ++node_id) {
-      TfLiteNode node = interpreter->node_and_registration(node_id)->first;
+    for (int node_id = 0; node_id < interpreter_->nodes_size(); ++node_id) {
+      TfLiteNode node = interpreter_->node_and_registration(node_id)->first;
       // Search node outputs for the tensor we're looking for
       for (int i = 0; i < node.outputs->size; ++i) {
         if (node.outputs->data[i] == next_tensor_id) {
@@ -74,16 +73,13 @@ tflite_find_parent_node_ids(Interpreter* interpreter, int tensor_id)
     }
   }
 
-  return std::vector<int>(parents.begin(), parents.end());
+  return vector<int>(parents.begin(), parents.end());
 }
 
 TFLiteModelState::TFLiteModelState()
   : ModelState()
   , interpreter_(nullptr)
   , fbmodel_(nullptr)
-  , previous_state_size_(0)
-  , previous_state_c_(nullptr)
-  , previous_state_h_(nullptr)
 {
 }
 
@@ -120,21 +116,21 @@ TFLiteModelState::init(const char* model_path,
   interpreter_->SetNumThreads(4);
 
   // Query all the index once
-  input_node_idx_       = tflite_get_input_tensor_by_name(interpreter_.get(), "input_node");
-  previous_state_c_idx_ = tflite_get_input_tensor_by_name(interpreter_.get(), "previous_state_c");
-  previous_state_h_idx_ = tflite_get_input_tensor_by_name(interpreter_.get(), "previous_state_h");
-  input_samples_idx_    = tflite_get_input_tensor_by_name(interpreter_.get(), "input_samples");
-  logits_idx_           = tflite_get_output_tensor_by_name(interpreter_.get(), "logits");
-  new_state_c_idx_      = tflite_get_output_tensor_by_name(interpreter_.get(), "new_state_c");
-  new_state_h_idx_      = tflite_get_output_tensor_by_name(interpreter_.get(), "new_state_h");
-  mfccs_idx_            = tflite_get_output_tensor_by_name(interpreter_.get(), "mfccs");
+  input_node_idx_       = get_input_tensor_by_name("input_node");
+  previous_state_c_idx_ = get_input_tensor_by_name("previous_state_c");
+  previous_state_h_idx_ = get_input_tensor_by_name("previous_state_h");
+  input_samples_idx_    = get_input_tensor_by_name("input_samples");
+  logits_idx_           = get_output_tensor_by_name("logits");
+  new_state_c_idx_      = get_output_tensor_by_name("new_state_c");
+  new_state_h_idx_      = get_output_tensor_by_name("new_state_h");
+  mfccs_idx_            = get_output_tensor_by_name("mfccs");
 
   // When we call Interpreter::Invoke, the whole graph is executed by default,
   // which means every time compute_mfcc is called the entire acoustic model is
   // also executed. To work around that problem, we walk up the dependency DAG
   // from the mfccs output tensor to find all the relevant nodes required for
   // feature computation, building an execution plan that runs just those nodes.
-  auto mfcc_plan = tflite_find_parent_node_ids(interpreter_.get(), mfccs_idx_);
+  auto mfcc_plan = find_parent_node_ids(mfccs_idx_);
   auto orig_plan = interpreter_->execution_plan();
 
   // Remove MFCC nodes from original plan (all nodes) to create the acoustic model plan
@@ -168,50 +164,57 @@ TFLiteModelState::init(const char* model_path,
   TfLiteIntArray* dims_c = interpreter_->tensor(previous_state_c_idx_)->dims;
   TfLiteIntArray* dims_h = interpreter_->tensor(previous_state_h_idx_)->dims;
   assert(dims_c->data[1] == dims_h->data[1]);
-
-  previous_state_size_ = dims_c->data[1];
-  previous_state_c_.reset(new float[previous_state_size_]());
-  previous_state_h_.reset(new float[previous_state_size_]());
-
-  // Set initial values for previous_state_c and previous_state_h
-  memset(previous_state_c_.get(), 0, sizeof(float) * previous_state_size_);
-  memset(previous_state_h_.get(), 0, sizeof(float) * previous_state_size_);
-
-  return DS_ERR_OK;
-}
-
-int
-TFLiteModelState::initialize_state()
-{
-  /* Ensure previous_state_{c,h} are not holding previous stream value */
-  memset(previous_state_c_.get(), 0, sizeof(float) * previous_state_size_);
-  memset(previous_state_h_.get(), 0, sizeof(float) * previous_state_size_);
+  state_size_ = dims_c->data[1];
+  assert(state_size_ > 0);
 
   return DS_ERR_OK;
 }
 
 void
-TFLiteModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logits_output)
+TFLiteModelState::copy_vector_to_tensor(const vector<float>& vec,
+                                        int tensor_idx,
+                                        int num_elements)
+{
+  float* tensor = interpreter_->typed_tensor<float>(tensor_idx);
+  int i;
+  for (i = 0; i < vec.size(); ++i) {
+    tensor[i] = vec[i];
+  }
+  for (; i < num_elements; ++i) {
+    tensor[i] = 0.f;
+  }
+}
+
+void
+TFLiteModelState::copy_tensor_to_vector(int tensor_idx,
+                                        int num_elements,
+                                        vector<float>& vec)
+{
+  float* tensor = interpreter_->typed_tensor<float>(tensor_idx);
+  for (int i = 0; i < num_elements; ++i) {
+    vec.push_back(tensor[i]);
+  }
+}
+
+void
+TFLiteModelState::infer(const vector<float>& mfcc,
+                        unsigned int n_frames,
+                        const vector<float>& previous_state_c,
+                        const vector<float>& previous_state_h,
+                        vector<float>& logits_output,
+                        vector<float>& state_c_output,
+                        vector<float>& state_h_output)
 {
   const size_t num_classes = alphabet_->GetSize() + 1; // +1 for blank
 
   // Feeding input_node
-  float* input_node = interpreter_->typed_tensor<float>(input_node_idx_);
-  {
-    int i;
-    for (i = 0; i < n_frames*mfcc_feats_per_timestep_; ++i) {
-      input_node[i] = aMfcc[i];
-    }
-    for (; i < n_steps_*mfcc_feats_per_timestep_; ++i) {
-      input_node[i] = 0;
-    }
-  }
-
-  assert(previous_state_size_ > 0);
+  copy_vector_to_tensor(mfcc, input_node_idx_, n_frames*mfcc_feats_per_timestep_);
 
   // Feeding previous_state_c, previous_state_h
-  memcpy(interpreter_->typed_tensor<float>(previous_state_c_idx_), previous_state_c_.get(), sizeof(float) * previous_state_size_);
-  memcpy(interpreter_->typed_tensor<float>(previous_state_h_idx_), previous_state_h_.get(), sizeof(float) * previous_state_size_);
+  assert(previous_state_c.size() == state_size_);
+  copy_vector_to_tensor(previous_state_c, previous_state_c_idx_, state_size_);
+  assert(previous_state_h.size() == state_size_);
+  copy_vector_to_tensor(previous_state_h, previous_state_h_idx_, state_size_);
 
   interpreter_->SetExecutionPlan(acoustic_exec_plan_);
   TfLiteStatus status = interpreter_->Invoke();
@@ -220,25 +223,23 @@ TFLiteModelState::infer(const vector<float>& mfcc,
     return;
   }
 
-  float* outputs = interpreter_->typed_tensor<float>(logits_idx_);
-
-  // The CTCDecoder works with log-probs.
-  for (int t = 0; t < n_frames * BATCH_SIZE * num_classes; ++t) {
-    logits_output.push_back(outputs[t]);
-  }
+  copy_tensor_to_vector(logits_idx_, n_frames * BATCH_SIZE * num_classes, logits_output);
 
-  memcpy(previous_state_c_.get(), interpreter_->typed_tensor<float>(new_state_c_idx_), sizeof(float) * previous_state_size_);
-  memcpy(previous_state_h_.get(), interpreter_->typed_tensor<float>(new_state_h_idx_), sizeof(float) * previous_state_size_);
+  state_c_output.clear();
+  state_c_output.reserve(state_size_);
+  copy_tensor_to_vector(new_state_c_idx_, state_size_, state_c_output);
+
+  state_h_output.clear();
+  state_h_output.reserve(state_size_);
+  copy_tensor_to_vector(new_state_h_idx_, state_size_, state_h_output);
 }
 
 void
-TFLiteModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_output)
+TFLiteModelState::compute_mfcc(const vector<float>& samples,
+                               vector<float>& mfcc_output)
 {
   // Feeding input_node
-  float* input_samples = interpreter_->typed_tensor<float>(input_samples_idx_);
-  for (int i = 0; i < samples.size(); ++i) {
-    input_samples[i] = samples[i];
-  }
+  copy_vector_to_tensor(samples, input_samples_idx_, samples.size());
 
   TfLiteStatus status = interpreter_->SetExecutionPlan(mfcc_exec_plan_);
   if (status != kTfLiteOk) {
@@ -261,8 +262,5 @@ TFLiteModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc
   }
   assert(num_elements / n_features_ == n_windows);
 
-  float* outputs = interpreter_->typed_tensor<float>(mfccs_idx_);
-  for (int i = 0; i < n_windows * n_features_; ++i) {
-    mfcc_output.push_back(outputs[i]);
-  }
+  copy_tensor_to_vector(mfccs_idx_, n_windows * n_features_, mfcc_output);
 }
diff --git a/native_client/tflitemodelstate.h b/native_client/tflitemodelstate.h
index ee5bfb6a..3a6d4971 100644
--- a/native_client/tflitemodelstate.h
+++ b/native_client/tflitemodelstate.h
@@ -14,10 +14,6 @@ struct TFLiteModelState : public ModelState
   std::unique_ptr<tflite::Interpreter> interpreter_;
   std::unique_ptr<tflite::FlatBufferModel> fbmodel_;
 
-  size_t previous_state_size_;
-  std::unique_ptr<float[]> previous_state_c_;
-  std::unique_ptr<float[]> previous_state_h_;
-
   int input_node_idx_;
   int previous_state_c_idx_;
   int previous_state_h_idx_;
@@ -40,13 +36,28 @@ struct TFLiteModelState : public ModelState
             const char* alphabet_path,
             unsigned int beam_width) override;
 
-  virtual int initialize_state() override;
-
   virtual void compute_mfcc(const std::vector<float>& audio_buffer, std::vector<float>& mfcc_output) override;
 
-  virtual void infer(const float* mfcc, unsigned int n_frames,
-                     std::vector<float>& logits_output) override;
+  virtual void infer(const std::vector<float>& mfcc,
+                     unsigned int n_frames,
+                     const std::vector<float>& previous_state_c,
+                     const std::vector<float>& previous_state_h,
+                     std::vector<float>& logits_output,
+                     std::vector<float>& state_c_output,
+                     std::vector<float>& state_h_output) override;
+
+private:
+  int get_tensor_by_name(const std::vector<int>& list, const char* name);
+  int get_input_tensor_by_name(const char* name);
+  int get_output_tensor_by_name(const char* name);
+  std::vector<int> find_parent_node_ids(int tensor_id);
+  void copy_vector_to_tensor(const std::vector<float>& vec,
+                             int tensor_idx,
+                             int num_elements);
+  void copy_tensor_to_vector(int tensor_idx,
+                             int num_elements,
+                             std::vector<float>& vec);
 };
 
 #endif // TFLITEMODELSTATE_H
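The MFCC/acoustic execution-plan split above hinges on `find_parent_node_ids` walking the graph backwards from a tensor to every node it transitively depends on. A simplified, self-contained sketch of that traversal, with a toy `Node` type and graph standing in for the TFLite interpreter:

#include <algorithm>
#include <deque>
#include <iostream>
#include <vector>

struct Node {
  std::vector<int> inputs;   // tensor ids consumed
  std::vector<int> outputs;  // tensor ids produced
};

template <typename T>
bool contains(const std::deque<T>& list, T value)
{
  return std::find(list.begin(), list.end(), value) != list.end();
}

std::vector<int> find_parent_node_ids(const std::vector<Node>& graph, int tensor_id)
{
  std::deque<int> parents;
  std::deque<int> frontier{tensor_id};
  while (!frontier.empty()) {
    int next_tensor_id = frontier.front();
    frontier.pop_front();
    // Find every node that produces the tensor at the head of the frontier.
    for (int node_id = 0; node_id < static_cast<int>(graph.size()); ++node_id) {
      const Node& node = graph[node_id];
      if (std::find(node.outputs.begin(), node.outputs.end(), next_tensor_id)
          != node.outputs.end()) {
        // Walking backwards, so prepend: the plan comes out in forward order.
        if (!contains(parents, node_id)) {
          parents.push_front(node_id);
        }
        for (int input : node.inputs) {
          if (!contains(frontier, input)) {
            frontier.push_back(input);
          }
        }
      }
    }
  }
  return std::vector<int>(parents.begin(), parents.end());
}

int main()
{
  // tensor 0 -> [node 0] -> tensor 1 -> [node 1] -> tensor 2 -> [node 2] -> tensor 3
  std::vector<Node> graph = {{{0}, {1}}, {{1}, {2}}, {{2}, {3}}};
  for (int id : find_parent_node_ids(graph, 2)) {
    std::cout << id << ' ';  // prints "0 1": only the producers of tensor 2
  }
  std::cout << '\n';
  return 0;
}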
node.attr().at("shape").shape(); + state_size_ = shape.dim(1).size(); } else if (node.name() == "logits_shape") { Tensor logits_shape = Tensor(DT_INT32, TensorShape({3})); if (!logits_shape.FromProto(node.attr().at("value").tensor())) { @@ -134,66 +137,83 @@ TFModelState::init(const char* model_path, return DS_ERR_OK; } -int -TFModelState::initialize_state() +Tensor +tensor_from_vector(const std::vector& vec, const TensorShape& shape) { - Status status = session_->Run({}, {}, {"initialize_state"}, nullptr); - if (!status.ok()) { - std::cerr << "Error running session: " << status << std::endl; - return DS_ERR_FAIL_RUN_SESS; + Tensor ret(DT_FLOAT, shape); + auto ret_mapped = ret.flat(); + int i; + for (i = 0; i < vec.size(); ++i) { + ret_mapped(i) = vec[i]; } - - return DS_ERR_OK; + for (; i < shape.num_elements(); ++i) { + ret_mapped(i) = 0.f; + } + return ret; } void -TFModelState::infer(const float* aMfcc, unsigned int n_frames, vector& logits_output) +copy_tensor_to_vector(const Tensor& tensor, vector& vec, int num_elements = -1) +{ + auto tensor_mapped = tensor.flat(); + if (num_elements == -1) { + num_elements = tensor.shape().num_elements(); + } + for (int i = 0; i < num_elements; ++i) { + vec.push_back(tensor_mapped(i)); + } +} + +void +TFModelState::infer(const std::vector& mfcc, + unsigned int n_frames, + const std::vector& previous_state_c, + const std::vector& previous_state_h, + vector& logits_output, + vector& state_c_output, + vector& state_h_output) { const size_t num_classes = alphabet_->GetSize() + 1; // +1 for blank - Tensor input(DT_FLOAT, TensorShape({BATCH_SIZE, n_steps_, 2*n_context_+1, n_features_})); - - auto input_mapped = input.flat(); - int i; - for (i = 0; i < n_frames*mfcc_feats_per_timestep_; ++i) { - input_mapped(i) = aMfcc[i]; - } - for (; i < n_steps_*mfcc_feats_per_timestep_; ++i) { - input_mapped(i) = 0.; - } + Tensor input = tensor_from_vector(mfcc, TensorShape({BATCH_SIZE, n_steps_, 2*n_context_+1, n_features_})); + Tensor previous_state_c_t = tensor_from_vector(previous_state_c, TensorShape({BATCH_SIZE, (long long)state_size_})); + Tensor previous_state_h_t = tensor_from_vector(previous_state_h, TensorShape({BATCH_SIZE, (long long)state_size_})); Tensor input_lengths(DT_INT32, TensorShape({1})); input_lengths.scalar()() = n_frames; vector outputs; Status status = session_->Run( - {{"input_node", input}, {"input_lengths", input_lengths}}, - {"logits"}, {}, &outputs); + { + {"input_node", input}, + {"input_lengths", input_lengths}, + {"previous_state_c", previous_state_c_t}, + {"previous_state_h", previous_state_h_t} + }, + {"logits", "new_state_c", "new_state_h"}, + {}, + &outputs); if (!status.ok()) { std::cerr << "Error running session: " << status << "\n"; return; } - auto logits_mapped = outputs[0].flat(); - // The CTCDecoder works with log-probs. 
- for (int t = 0; t < n_frames * BATCH_SIZE * num_classes; ++t) { - logits_output.push_back(logits_mapped(t)); - } + copy_tensor_to_vector(outputs[0], logits_output, n_frames * BATCH_SIZE * num_classes); + + state_c_output.clear(); + state_c_output.reserve(state_size_); + copy_tensor_to_vector(outputs[1], state_c_output); + + state_h_output.clear(); + state_h_output.reserve(state_size_); + copy_tensor_to_vector(outputs[2], state_h_output); } void TFModelState::compute_mfcc(const vector& samples, vector& mfcc_output) { - Tensor input(DT_FLOAT, TensorShape({audio_win_len_})); - auto input_mapped = input.flat(); - int i; - for (i = 0; i < samples.size(); ++i) { - input_mapped(i) = samples[i]; - } - for (; i < audio_win_len_; ++i) { - input_mapped(i) = 0.f; - } + Tensor input = tensor_from_vector(samples, TensorShape({audio_win_len_})); vector outputs; Status status = session_->Run({{"input_samples", input}}, {"mfccs"}, {}, &outputs); @@ -206,9 +226,5 @@ TFModelState::compute_mfcc(const vector& samples, vector& mfcc_out // The feature computation graph is hardcoded to one audio length for now const int n_windows = 1; assert(outputs[0].shape().num_elements() / n_features_ == n_windows); - - auto mfcc_mapped = outputs[0].flat(); - for (int i = 0; i < n_windows * n_features_; ++i) { - mfcc_output.push_back(mfcc_mapped(i)); - } + copy_tensor_to_vector(outputs[0], mfcc_output); } diff --git a/native_client/tfmodelstate.h b/native_client/tfmodelstate.h index c3dc7708..0ef7dcfe 100644 --- a/native_client/tfmodelstate.h +++ b/native_client/tfmodelstate.h @@ -24,11 +24,13 @@ struct TFModelState : public ModelState const char* alphabet_path, unsigned int beam_width) override; - virtual int initialize_state() override; - - virtual void infer(const float* mfcc, + virtual void infer(const std::vector& mfcc, unsigned int n_frames, - std::vector& logits_output) override; + const std::vector& previous_state_c, + const std::vector& previous_state_h, + std::vector& logits_output, + std::vector& state_c_output, + std::vector& state_h_output) override; virtual void compute_mfcc(const std::vector& audio_buffer, std::vector& mfcc_output) override;
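Taken together with the GRAPH_VERSION bump to 2, this is a breaking change to the exported-graph interface: the graph's set of inputs and outputs changed, so the version constant moves and clients can reject graphs they cannot drive. A toy end-to-end sketch of the resulting ownership model; `ToyModel` and `ToyStream` are hypothetical mirrors of `ModelState` and `StreamingState`, not types from this patch.

#include <vector>

// Hypothetical mirror of ModelState::infer's new shape: state in, state out.
struct ToyModel {
  unsigned int state_size_ = 4;

  void infer(const std::vector<float>& mfcc,
             const std::vector<float>& state_c_in,
             const std::vector<float>& state_h_in,
             std::vector<float>& logits,
             std::vector<float>& state_c_out,
             std::vector<float>& state_h_out)
  {
    std::vector<float> c = state_c_in, h = state_h_in; // consume inputs first
    // ... a real model would run the LSTM here; we just pass the state along.
    logits.assign(mfcc.size(), 0.f);
    state_c_out = std::move(c);
    state_h_out = std::move(h);
  }
};

// Hypothetical mirror of StreamingState: the stream owns the recurrent state.
struct ToyStream {
  ToyModel* model_;
  std::vector<float> previous_state_c_;
  std::vector<float> previous_state_h_;

  explicit ToyStream(ToyModel* model)
    : model_(model)
    , previous_state_c_(model->state_size_, 0.f)  // zero state at setup,
    , previous_state_h_(model->state_size_, 0.f)  // as DS_SetupStream now does
  {
  }

  std::vector<float> processBatch(const std::vector<float>& buf)
  {
    std::vector<float> logits;
    model_->infer(buf, previous_state_c_, previous_state_h_,
                  logits, previous_state_c_, previous_state_h_);
    return logits;
  }
};

int main()
{
  ToyModel model;
  ToyStream stream(&model);           // fresh stream, fresh zero state
  stream.processBatch({1.f, 2.f});    // state threads from batch to batch
  stream.processBatch({3.f, 4.f});
  return 0;
}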