Remove previous state model variable, track by hand in StreamingState instead

Reuben Morais 2019-06-06 16:40:19 -03:00
parent 6e78bac799
commit e51b9d987d
10 changed files with 213 additions and 202 deletions

View File: DeepSpeech.py

@@ -574,12 +574,8 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
         # no state management since n_step is expected to be dynamic too (see below)
         previous_state = previous_state_c = previous_state_h = None
     else:
-        if tflite:
-            previous_state_c = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
-            previous_state_h = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
-        else:
-            previous_state_c = variable_on_cpu('previous_state_c', [batch_size, Config.n_cell_dim], initializer=None)
-            previous_state_h = variable_on_cpu('previous_state_h', [batch_size, Config.n_cell_dim], initializer=None)
+        previous_state_c = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
+        previous_state_h = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')

        previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h)
@@ -605,7 +601,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
         logits = tf.squeeze(logits, [1])

     # Apply softmax for CTC decoder
-    logits = tf.nn.softmax(logits)
+    logits = tf.nn.softmax(logits, name='logits')

     if batch_size <= 0:
         if tflite:
@@ -618,14 +614,12 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
                 'input_lengths': seq_length,
             },
             {
-                'outputs': tf.identity(logits, name='logits'),
+                'outputs': logits,
             },
             layers
         )

     new_state_c, new_state_h = layers['rnn_output_state']
-    if tflite:
-        logits = tf.identity(logits, name='logits')
-        new_state_c = tf.identity(new_state_c, name='new_state_c')
-        new_state_h = tf.identity(new_state_h, name='new_state_h')
+    new_state_c = tf.identity(new_state_c, name='new_state_c')
+    new_state_h = tf.identity(new_state_h, name='new_state_h')
@@ -645,24 +639,6 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
            'new_state_h': new_state_h,
            'mfccs': mfccs,
        }
-    else:
-        zero_state = tf.zeros([batch_size, Config.n_cell_dim], tf.float32)
-        initialize_c = tf.assign(previous_state_c, zero_state)
-        initialize_h = tf.assign(previous_state_h, zero_state)
-        initialize_state = tf.group(initialize_c, initialize_h, name='initialize_state')
-        with tf.control_dependencies([tf.assign(previous_state_c, new_state_c), tf.assign(previous_state_h, new_state_h)]):
-            logits = tf.identity(logits, name='logits')
-
-        inputs = {
-            'input': input_tensor,
-            'input_lengths': seq_length,
-            'input_samples': input_samples,
-        }
-        outputs = {
-            'outputs': logits,
-            'initialize_state': initialize_state,
-            'mfccs': mfccs,
-        }

    return inputs, outputs, layers
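The removed else-branch above was the heart of the old design: the LSTM state lived in graph variables updated through control dependencies, and streams were reset via the initialize_state op. With placeholders, the caller owns the state instead. A minimal, self-contained sketch of the resulting pattern (TF1 API; the decaying toy "cell" below is a stand-in for the real acoustic model, and only the placeholder and output names mirror this graph):

    import numpy as np
    import tensorflow as tf

    n_cell_dim = 4
    prev_c = tf.placeholder(tf.float32, [1, n_cell_dim], name='previous_state_c')
    prev_h = tf.placeholder(tf.float32, [1, n_cell_dim], name='previous_state_h')
    # Toy state update standing in for the LSTM body of the real graph.
    new_c = tf.identity(prev_c * 0.9, name='new_state_c')
    new_h = tf.identity(prev_h * 0.9 + 0.1, name='new_state_h')

    with tf.Session() as session:
        # The client, not the graph, owns the state: start a stream from zeros...
        state_c = np.zeros([1, n_cell_dim], dtype=np.float32)
        state_h = np.zeros([1, n_cell_dim], dtype=np.float32)
        for _ in range(3):
            # ...and feed the previous state back in on every call.
            state_c, state_h = session.run([new_c, new_h], feed_dict={
                prev_c: state_c,
                prev_h: state_h,
            })
        # Resetting a stream is now just re-zeroing these buffers; no
        # initialize_state op has to run inside the graph.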
@@ -682,10 +658,12 @@ def export():
    output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
    output_names = ",".join(output_names_tensors + output_names_ops)

-    if not FLAGS.export_tflite:
-        mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
-    else:
+    mapping = None
+    if FLAGS.export_tflite:
        # Create a saver using variables from the above newly created graph
+        # Training graph uses LSTMFusedCell, but the TFLite inference graph uses
+        # a static RNN with a normal cell, so we need to rewrite the names to
+        # match the training weights when restoring.
        def fixup(name):
            if name.startswith('rnn/lstm_cell/'):
                return name.replace('rnn/lstm_cell/', 'lstm_fused_cell/')
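For context, a hedged sketch of how a name-remapping Saver like this is built. The dict comprehension and the stand-in variable below are assumptions for illustration, since the diff is truncated before the actual mapping line; tf.train.Saver does accept a dict from checkpoint names to graph variables:

    import tensorflow as tf

    def fixup(name):
        # From the diff above: static-RNN cell names in the TFLite inference
        # graph map back to the fused-cell names stored in the checkpoint.
        if name.startswith('rnn/lstm_cell/'):
            return name.replace('rnn/lstm_cell/', 'lstm_fused_cell/')
        return name

    # Hypothetical stand-in variable so the sketch runs on its own.
    _ = tf.get_variable('rnn/lstm_cell/kernel', shape=[2, 2])

    # Keys are names as stored in the checkpoint, values are variables in the
    # freshly built inference graph.
    mapping = {fixup(v.op.name): v for v in tf.global_variables()}
    saver = tf.train.Saver(mapping)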
@@ -710,7 +688,7 @@ def export():
    if not os.path.isdir(FLAGS.export_dir):
        os.makedirs(FLAGS.export_dir)

-    def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
+    def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=''):
        frozen = freeze_graph.freeze_graph_with_def_protos(
            input_graph_def=tf.get_default_graph().as_graph_def(),
            input_saver_def=saver.as_saver_def(),
@@ -731,7 +709,7 @@ def export():
        placeholder_type_enum=tf.float32.as_datatype_enum)

    if not FLAGS.export_tflite:
-        frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
+        frozen_graph = do_graph_freeze(output_node_names=output_names)
        frozen_graph.version = int(file_relative_read('GRAPH_VERSION').strip())

        # Add a no-op node to the graph with metadata information to be loaded by the native client
@@ -747,7 +725,7 @@ def export():
        with open(output_graph_path, 'wb') as fout:
            fout.write(frozen_graph.SerializeToString())
    else:
-        frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='')
+        frozen_graph = do_graph_freeze(output_node_names=output_names)
        output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))

        converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
@@ -771,8 +749,7 @@ def do_single_file_inference(input_file_path):
        inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)

        # Create a saver using variables from the above newly created graph
-        mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
-        saver = tf.train.Saver(mapping)
+        saver = tf.train.Saver()

        # Restore variables from training checkpoint
        # TODO: This restores the most recent checkpoint, but if we use validation to counteract
@@ -784,9 +761,10 @@ def do_single_file_inference(input_file_path):
        checkpoint_path = checkpoint.model_checkpoint_path
        saver.restore(session, checkpoint_path)

-        session.run(outputs['initialize_state'])
-
        features, features_len = audiofile_to_features(input_file_path)
+
+        previous_state_c = np.zeros([1, Config.n_cell_dim])
+        previous_state_h = np.zeros([1, Config.n_cell_dim])

        # Add batch dimension
        features = tf.expand_dims(features, 0)
@@ -799,6 +777,8 @@ def do_single_file_inference(input_file_path):
        logits = outputs['outputs'].eval(feed_dict={
            inputs['input']: features,
            inputs['input_lengths']: features_len,
+            inputs['previous_state_c']: previous_state_c,
+            inputs['previous_state_h']: previous_state_h,
        }, session=session)

        logits = np.squeeze(logits)

View File: GRAPH_VERSION

@@ -1 +1 @@
-1
+2

View File: native_client/BUILD

@@ -114,34 +114,26 @@ tf_cc_shared_object(
        ### => Trying to be more fine-grained
        ### Use bin/ops_in_graph.py to list all the ops used by a frozen graph.
        ### CPU only build, libdeepspeech.so file size reduced by ~50%
-        "//tensorflow/core/kernels:dense_update_ops", # Assign
-        "//tensorflow/core/kernels:constant_op", # Const
-        "//tensorflow/core/kernels:immutable_constant_op", # ImmutableConst
+        "//tensorflow/core/kernels:dense_update_ops", # Assign (remove once prod model no longer depends on it)
+        "//tensorflow/core/kernels:constant_op", # Placeholder
+        "//tensorflow/core/kernels:immutable_constant_op", # ImmutableConst (used in memmapped models)
        "//tensorflow/core/kernels:identity_op", # Identity
        "//tensorflow/core/kernels:softmax_op", # Softmax
        "//tensorflow/core/kernels:transpose_op", # Transpose
        "//tensorflow/core/kernels:reshape_op", # Reshape
        "//tensorflow/core/kernels:shape_ops", # Shape
        "//tensorflow/core/kernels:concat_op", # ConcatV2
-        "//tensorflow/core/kernels:split_op", # Split
-        "//tensorflow/core/kernels:variable_ops", # VariableV2
        "//tensorflow/core/kernels:relu_op", # Relu
        "//tensorflow/core/kernels:bias_op", # BiasAdd
        "//tensorflow/core/kernels:math", # Range, MatMul
-        "//tensorflow/core/kernels:control_flow_ops", # Enter
        "//tensorflow/core/kernels:tile_ops", # Tile
-        "//tensorflow/core/kernels:gather_op", # Gather
        "//tensorflow/core/kernels:mfcc_op", # Mfcc
        "//tensorflow/core/kernels:spectrogram_op", # AudioSpectrogram
        "//tensorflow/core/kernels:strided_slice_op", # StridedSlice
        "//tensorflow/core/kernels:slice_op", # Slice, needed by StridedSlice
        "//tensorflow/contrib/rnn:lstm_ops_kernels", # BlockLSTM
-        "//tensorflow/core/kernels:random_ops", # RandomGammaGrad
        "//tensorflow/core/kernels:pack_op", # Pack
        "//tensorflow/core/kernels:gather_nd_op", # GatherNd
-        #### Needed by production model produced without "--use_seq_length False"
-        #"//tensorflow/core/kernels:logging_ops", # Assert
-        #"//tensorflow/core/kernels:reverse_sequence_op", # ReverseSequence
    ],
}) + if_cuda([
    "//tensorflow/core:core",

View File: native_client/deepspeech.cc

@@ -67,6 +67,9 @@ struct StreamingState {
  vector<float> audio_buffer_;
  vector<float> mfcc_buffer_;
  vector<float> batch_buffer_;
+
+  vector<float> previous_state_c_;
+  vector<float> previous_state_h_;

  ModelState* model_;
  std::unique_ptr<DecoderState> decoder_state_;
@@ -233,7 +236,13 @@ void
 StreamingState::processBatch(const vector<float>& buf, unsigned int n_steps)
 {
   vector<float> logits;
-  model_->infer(buf.data(), n_steps, logits);
+  model_->infer(buf,
+                n_steps,
+                previous_state_c_,
+                previous_state_h_,
+                logits,
+                previous_state_c_,
+                previous_state_h_);

   const int cutoff_top_n = 40;
   const double cutoff_prob = 1.0;
@@ -326,11 +335,6 @@ DS_SetupStream(ModelState* aCtx,
 {
   *retval = nullptr;

-  int err = aCtx->initialize_state();
-  if (err != DS_ERR_OK) {
-    return err;
-  }
-
   std::unique_ptr<StreamingState> ctx(new StreamingState());
   if (!ctx) {
     std::cerr << "Could not allocate streaming state." << std::endl;
@@ -348,6 +352,8 @@ DS_SetupStream(ModelState* aCtx,
   ctx->mfcc_buffer_.reserve(aCtx->mfcc_feats_per_timestep_);
   ctx->mfcc_buffer_.resize(aCtx->n_features_*aCtx->n_context_, 0.f);
   ctx->batch_buffer_.reserve(aCtx->n_steps_ * aCtx->mfcc_feats_per_timestep_);
+  ctx->previous_state_c_.resize(aCtx->state_size_, 0.f);
+  ctx->previous_state_h_.resize(aCtx->state_size_, 0.f);

   ctx->model_ = aCtx;
   ctx->decoder_state_.reset(decoder_init(*aCtx->alphabet_, num_classes, aCtx->scorer_));
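A rough Python analogue (illustrative only, not part of the commit) of what StreamingState now does: every batch feeds the previous state into infer() and stores the returned state, so each stream carries its own zero-initialized buffers and concurrent streams never share graph-level variables.

    class FakeModel:
        # Stand-in for ModelState::infer; real backends run the acoustic model.
        def infer(self, buf, n_steps, state_c, state_h):
            return [0.0] * n_steps, state_c, state_h

    class StreamingState:
        def __init__(self, model, state_size):
            self.model = model
            self.previous_state_c = [0.0] * state_size  # fresh stream: zero state
            self.previous_state_h = [0.0] * state_size

        def process_batch(self, buf, n_steps):
            logits, self.previous_state_c, self.previous_state_h = \
                self.model.infer(buf, n_steps,
                                 self.previous_state_c, self.previous_state_h)
            return logits

    stream = StreamingState(FakeModel(), state_size=4)
    stream.process_batch(buf=[], n_steps=16)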

View File: native_client/modelstate.cc

@@ -17,6 +17,7 @@ ModelState::ModelState()
  , sample_rate_(DEFAULT_SAMPLE_RATE)
  , audio_win_len_(DEFAULT_WINDOW_LENGTH)
  , audio_win_step_(DEFAULT_WINDOW_STEP)
+  , state_size_(-1)
{
}

View File: native_client/modelstate.h

@@ -28,6 +28,7 @@ struct ModelState {
  unsigned int sample_rate_;
  unsigned int audio_win_len_;
  unsigned int audio_win_step_;
+  unsigned int state_size_;

  ModelState();
  virtual ~ModelState();
@@ -38,8 +39,6 @@ struct ModelState {
                   const char* alphabet_path,
                   unsigned int beam_width);

-  virtual int initialize_state() = 0;
-
  virtual void compute_mfcc(const std::vector<float>& audio_buffer, std::vector<float>& mfcc_output) = 0;

  /**
@@ -52,7 +51,13 @@ struct ModelState {
   *
   * @param[out] output_logits Where to store computed logits.
   */
-  virtual void infer(const float* mfcc, unsigned int n_frames, std::vector<float>& logits_output) = 0;
+  virtual void infer(const std::vector<float>& mfcc,
+                     unsigned int n_frames,
+                     const std::vector<float>& previous_state_c,
+                     const std::vector<float>& previous_state_h,
+                     std::vector<float>& logits_output,
+                     std::vector<float>& state_c_output,
+                     std::vector<float>& state_h_output) = 0;

  /**
   * @brief Perform decoding of the logits, using basic CTC decoder or

View File: native_client/tflitemodelstate.cc

@@ -4,14 +4,13 @@ using namespace tflite;
 using std::vector;

 int
-tflite_get_tensor_by_name(const Interpreter* interpreter,
-                          const vector<int>& list,
-                          const char* name)
+TFLiteModelState::get_tensor_by_name(const vector<int>& list,
+                                     const char* name)
 {
   int rv = -1;

   for (int i = 0; i < list.size(); ++i) {
-    const string& node_name = interpreter->tensor(list[i])->name;
+    const string& node_name = interpreter_->tensor(list[i])->name;
     if (node_name.compare(string(name)) == 0) {
       rv = i;
     }
@@ -22,17 +21,17 @@ tflite_get_tensor_by_name(const Interpreter* interpreter,
 }

 int
-tflite_get_input_tensor_by_name(const Interpreter* interpreter, const char* name)
+TFLiteModelState::get_input_tensor_by_name(const char* name)
 {
-  int idx = tflite_get_tensor_by_name(interpreter, interpreter->inputs(), name);
-  return interpreter->inputs()[idx];
+  int idx = get_tensor_by_name(interpreter_->inputs(), name);
+  return interpreter_->inputs()[idx];
 }

 int
-tflite_get_output_tensor_by_name(const Interpreter* interpreter, const char* name)
+TFLiteModelState::get_output_tensor_by_name(const char* name)
 {
-  int idx = tflite_get_tensor_by_name(interpreter, interpreter->outputs(), name);
-  return interpreter->outputs()[idx];
+  int idx = get_tensor_by_name(interpreter_->outputs(), name);
+  return interpreter_->outputs()[idx];
 }

 void
@@ -48,8 +47,8 @@ push_back_if_not_present(std::deque<int>& list, int value)
 // output, add it to the parent list, and add its input tensors to the frontier
 // list. Because we start from the final tensor and work backwards to the inputs,
 // the parents list is constructed in reverse, adding elements to its front.
-std::vector<int>
-tflite_find_parent_node_ids(Interpreter* interpreter, int tensor_id)
+vector<int>
+TFLiteModelState::find_parent_node_ids(int tensor_id)
 {
   std::deque<int> parents;
   std::deque<int> frontier;
@@ -58,8 +57,8 @@ tflite_find_parent_node_ids(Interpreter* interpreter, int tensor_id)
     int next_tensor_id = frontier.front();
     frontier.pop_front();
     // Find all nodes that have next_tensor_id as an output
-    for (int node_id = 0; node_id < interpreter->nodes_size(); ++node_id) {
-      TfLiteNode node = interpreter->node_and_registration(node_id)->first;
+    for (int node_id = 0; node_id < interpreter_->nodes_size(); ++node_id) {
+      TfLiteNode node = interpreter_->node_and_registration(node_id)->first;
       // Search node outputs for the tensor we're looking for
       for (int i = 0; i < node.outputs->size; ++i) {
         if (node.outputs->data[i] == next_tensor_id) {
@@ -74,16 +73,13 @@ tflite_find_parent_node_ids(Interpreter* interpreter, int tensor_id)
     }
   }

-  return std::vector<int>(parents.begin(), parents.end());
+  return vector<int>(parents.begin(), parents.end());
 }

 TFLiteModelState::TFLiteModelState()
   : ModelState()
   , interpreter_(nullptr)
   , fbmodel_(nullptr)
-  , previous_state_size_(0)
-  , previous_state_c_(nullptr)
-  , previous_state_h_(nullptr)
 {
 }
@@ -120,21 +116,21 @@ TFLiteModelState::init(const char* model_path,
   interpreter_->SetNumThreads(4);

   // Query all the index once
-  input_node_idx_ = tflite_get_input_tensor_by_name(interpreter_.get(), "input_node");
-  previous_state_c_idx_ = tflite_get_input_tensor_by_name(interpreter_.get(), "previous_state_c");
-  previous_state_h_idx_ = tflite_get_input_tensor_by_name(interpreter_.get(), "previous_state_h");
-  input_samples_idx_ = tflite_get_input_tensor_by_name(interpreter_.get(), "input_samples");
-  logits_idx_ = tflite_get_output_tensor_by_name(interpreter_.get(), "logits");
-  new_state_c_idx_ = tflite_get_output_tensor_by_name(interpreter_.get(), "new_state_c");
-  new_state_h_idx_ = tflite_get_output_tensor_by_name(interpreter_.get(), "new_state_h");
-  mfccs_idx_ = tflite_get_output_tensor_by_name(interpreter_.get(), "mfccs");
+  input_node_idx_ = get_input_tensor_by_name("input_node");
+  previous_state_c_idx_ = get_input_tensor_by_name("previous_state_c");
+  previous_state_h_idx_ = get_input_tensor_by_name("previous_state_h");
+  input_samples_idx_ = get_input_tensor_by_name("input_samples");
+  logits_idx_ = get_output_tensor_by_name("logits");
+  new_state_c_idx_ = get_output_tensor_by_name("new_state_c");
+  new_state_h_idx_ = get_output_tensor_by_name("new_state_h");
+  mfccs_idx_ = get_output_tensor_by_name("mfccs");

   // When we call Interpreter::Invoke, the whole graph is executed by default,
   // which means every time compute_mfcc is called the entire acoustic model is
   // also executed. To workaround that problem, we walk up the dependency DAG
   // from the mfccs output tensor to find all the relevant nodes required for
   // feature computation, building an execution plan that runs just those nodes.
-  auto mfcc_plan = tflite_find_parent_node_ids(interpreter_.get(), mfccs_idx_);
+  auto mfcc_plan = find_parent_node_ids(mfccs_idx_);
   auto orig_plan = interpreter_->execution_plan();

   // Remove MFCC nodes from original plan (all nodes) to create the acoustic model plan
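A compact Python rendering of the backward walk that find_parent_node_ids performs (an illustration under the assumption that nodes are given as (input_ids, output_ids) pairs, not the TFLite API itself): start from the target tensor, repeatedly find the nodes that produce it, and queue their inputs, building a reversed execution plan for just that subgraph.

    from collections import deque

    def find_parent_node_ids(nodes, tensor_id):
        # nodes: list of (input_tensor_ids, output_tensor_ids), indexed by node id
        parents, frontier = deque(), deque([tensor_id])
        while frontier:
            next_tensor_id = frontier.popleft()
            # Find all nodes that produce next_tensor_id...
            for node_id, (inputs, outputs) in enumerate(nodes):
                if next_tensor_id in outputs:
                    if node_id not in parents:
                        parents.appendleft(node_id)  # walking backwards: prepend
                    # ...and keep walking towards the graph inputs.
                    for t in inputs:
                        if t not in frontier:
                            frontier.append(t)
        return list(parents)

    # Toy graph: node 0 maps tensor 0 -> 1, node 1 maps tensor 1 -> 2.
    print(find_parent_node_ids([([0], [1]), ([1], [2])], 2))  # [0, 1]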
@@ -168,50 +164,57 @@ TFLiteModelState::init(const char* model_path,
   TfLiteIntArray* dims_c = interpreter_->tensor(previous_state_c_idx_)->dims;
   TfLiteIntArray* dims_h = interpreter_->tensor(previous_state_h_idx_)->dims;
   assert(dims_c->data[1] == dims_h->data[1]);
-  previous_state_size_ = dims_c->data[1];
-  previous_state_c_.reset(new float[previous_state_size_]());
-  previous_state_h_.reset(new float[previous_state_size_]());
-
-  // Set initial values for previous_state_c and previous_state_h
-  memset(previous_state_c_.get(), 0, sizeof(float) * previous_state_size_);
-  memset(previous_state_h_.get(), 0, sizeof(float) * previous_state_size_);
-
-  return DS_ERR_OK;
-}
-
-int
-TFLiteModelState::initialize_state()
-{
-  /* Ensure previous_state_{c,h} are not holding previous stream value */
-  memset(previous_state_c_.get(), 0, sizeof(float) * previous_state_size_);
-  memset(previous_state_h_.get(), 0, sizeof(float) * previous_state_size_);
+  assert(state_size_ > 0);
+  state_size_ = dims_c->data[1];

   return DS_ERR_OK;
 }

 void
-TFLiteModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logits_output)
+TFLiteModelState::copy_vector_to_tensor(const vector<float>& vec,
+                                        int tensor_idx,
+                                        int num_elements)
+{
+  float* tensor = interpreter_->typed_tensor<float>(tensor_idx);
+  int i;
+  for (i = 0; i < vec.size(); ++i) {
+    tensor[i] = vec[i];
+  }
+  for (; i < num_elements; ++i) {
+    tensor[i] = 0.f;
+  }
+}
+
+void
+TFLiteModelState::copy_tensor_to_vector(int tensor_idx,
+                                        int num_elements,
+                                        vector<float>& vec)
+{
+  float* tensor = interpreter_->typed_tensor<float>(tensor_idx);
+  for (int i = 0; i < num_elements; ++i) {
+    vec.push_back(tensor[i]);
+  }
+}
+
+void
+TFLiteModelState::infer(const vector<float>& mfcc,
+                        unsigned int n_frames,
+                        const vector<float>& previous_state_c,
+                        const vector<float>& previous_state_h,
+                        vector<float>& logits_output,
+                        vector<float>& state_c_output,
+                        vector<float>& state_h_output)
 {
   const size_t num_classes = alphabet_->GetSize() + 1; // +1 for blank

   // Feeding input_node
-  float* input_node = interpreter_->typed_tensor<float>(input_node_idx_);
-  {
-    int i;
-    for (i = 0; i < n_frames*mfcc_feats_per_timestep_; ++i) {
-      input_node[i] = aMfcc[i];
-    }
-    for (; i < n_steps_*mfcc_feats_per_timestep_; ++i) {
-      input_node[i] = 0;
-    }
-  }
-
-  assert(previous_state_size_ > 0);
+  copy_vector_to_tensor(mfcc, input_node_idx_, n_frames*mfcc_feats_per_timestep_);

   // Feeding previous_state_c, previous_state_h
-  memcpy(interpreter_->typed_tensor<float>(previous_state_c_idx_), previous_state_c_.get(), sizeof(float) * previous_state_size_);
-  memcpy(interpreter_->typed_tensor<float>(previous_state_h_idx_), previous_state_h_.get(), sizeof(float) * previous_state_size_);
+  assert(previous_state_c.size() == state_size_);
+  copy_vector_to_tensor(previous_state_c, previous_state_c_idx_, state_size_);
+  assert(previous_state_h.size() == state_size_);
+  copy_vector_to_tensor(previous_state_h, previous_state_h_idx_, state_size_);

   interpreter_->SetExecutionPlan(acoustic_exec_plan_);
   TfLiteStatus status = interpreter_->Invoke();
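The two copy helpers introduced above share one contract: write the vector's data first, then zero-fill up to num_elements so a short final batch never leaves stale values in the tensor. In numpy terms (illustrative only, not part of the commit):

    import numpy as np

    def copy_vector_to_tensor(vec, tensor, num_elements):
        tensor[:len(vec)] = vec              # real data first...
        tensor[len(vec):num_elements] = 0.0  # ...then zero-pad the remainder

    def copy_tensor_to_vector(tensor, num_elements):
        return list(tensor[:num_elements])

    buf = np.empty(8, dtype=np.float32)
    copy_vector_to_tensor([1.0, 2.0, 3.0], buf, 8)
    print(buf)  # [1. 2. 3. 0. 0. 0. 0. 0.]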
@@ -220,25 +223,23 @@ TFLiteModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>
     return;
   }

-  float* outputs = interpreter_->typed_tensor<float>(logits_idx_);
-
-  // The CTCDecoder works with log-probs.
-  for (int t = 0; t < n_frames * BATCH_SIZE * num_classes; ++t) {
-    logits_output.push_back(outputs[t]);
-  }
+  copy_tensor_to_vector(logits_idx_, n_frames * BATCH_SIZE * num_classes, logits_output);

-  memcpy(previous_state_c_.get(), interpreter_->typed_tensor<float>(new_state_c_idx_), sizeof(float) * previous_state_size_);
-  memcpy(previous_state_h_.get(), interpreter_->typed_tensor<float>(new_state_h_idx_), sizeof(float) * previous_state_size_);
+  state_c_output.clear();
+  state_c_output.reserve(state_size_);
+  copy_tensor_to_vector(new_state_c_idx_, state_size_, state_c_output);
+
+  state_h_output.clear();
+  state_h_output.reserve(state_size_);
+  copy_tensor_to_vector(new_state_h_idx_, state_size_, state_h_output);
 }

 void
-TFLiteModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_output)
+TFLiteModelState::compute_mfcc(const vector<float>& samples,
+                               vector<float>& mfcc_output)
 {
   // Feeding input_node
-  float* input_samples = interpreter_->typed_tensor<float>(input_samples_idx_);
-  for (int i = 0; i < samples.size(); ++i) {
-    input_samples[i] = samples[i];
-  }
+  copy_vector_to_tensor(samples, input_samples_idx_, samples.size());

   TfLiteStatus status = interpreter_->SetExecutionPlan(mfcc_exec_plan_);
   if (status != kTfLiteOk) {
@@ -261,8 +262,5 @@ TFLiteModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc
   }

   assert(num_elements / n_features_ == n_windows);
-  float* outputs = interpreter_->typed_tensor<float>(mfccs_idx_);
-  for (int i = 0; i < n_windows * n_features_; ++i) {
-    mfcc_output.push_back(outputs[i]);
-  }
+  copy_tensor_to_vector(mfccs_idx_, n_windows * n_features_, mfcc_output);
 }

View File: native_client/tflitemodelstate.h

@@ -14,10 +14,6 @@ struct TFLiteModelState : public ModelState
   std::unique_ptr<tflite::Interpreter> interpreter_;
   std::unique_ptr<tflite::FlatBufferModel> fbmodel_;

-  size_t previous_state_size_;
-  std::unique_ptr<float[]> previous_state_c_;
-  std::unique_ptr<float[]> previous_state_h_;
-
   int input_node_idx_;
   int previous_state_c_idx_;
   int previous_state_h_idx_;
@@ -40,13 +36,28 @@ struct TFLiteModelState : public ModelState
                    const char* alphabet_path,
                    unsigned int beam_width) override;

-  virtual int initialize_state() override;
-
   virtual void compute_mfcc(const std::vector<float>& audio_buffer,
                             std::vector<float>& mfcc_output) override;

-  virtual void infer(const float* mfcc, unsigned int n_frames,
-                     std::vector<float>& logits_output) override;
+  virtual void infer(const std::vector<float>& mfcc,
+                     unsigned int n_frames,
+                     const std::vector<float>& previous_state_c,
+                     const std::vector<float>& previous_state_h,
+                     std::vector<float>& logits_output,
+                     std::vector<float>& state_c_output,
+                     std::vector<float>& state_h_output) override;
+
+private:
+  int get_tensor_by_name(const std::vector<int>& list, const char* name);
+  int get_input_tensor_by_name(const char* name);
+  int get_output_tensor_by_name(const char* name);
+  std::vector<int> find_parent_node_ids(int tensor_id);
+  void copy_vector_to_tensor(const std::vector<float>& vec,
+                             int tensor_idx,
+                             int num_elements);
+  void copy_tensor_to_vector(int tensor_idx,
+                             int num_elements,
+                             std::vector<float>& vec);
 };

 #endif // TFLITEMODELSTATE_H

View File: native_client/tfmodelstate.cc

@@ -98,6 +98,9 @@ TFModelState::init(const char* model_path,
         n_context_ = (shape.dim(2).size()-1)/2;
         n_features_ = shape.dim(3).size();
         mfcc_feats_per_timestep_ = shape.dim(2).size() * shape.dim(3).size();
+      } else if (node.name() == "previous_state_c") {
+        const auto& shape = node.attr().at("shape").shape();
+        state_size_ = shape.dim(1).size();
       } else if (node.name() == "logits_shape") {
         Tensor logits_shape = Tensor(DT_INT32, TensorShape({3}));
         if (!logits_shape.FromProto(node.attr().at("value").tensor())) {
@@ -134,66 +137,83 @@ TFModelState::init(const char* model_path,
   return DS_ERR_OK;
 }

-int
-TFModelState::initialize_state()
+Tensor
+tensor_from_vector(const std::vector<float>& vec, const TensorShape& shape)
 {
-  Status status = session_->Run({}, {}, {"initialize_state"}, nullptr);
-  if (!status.ok()) {
-    std::cerr << "Error running session: " << status << std::endl;
-    return DS_ERR_FAIL_RUN_SESS;
+  Tensor ret(DT_FLOAT, shape);
+  auto ret_mapped = ret.flat<float>();
+  int i;
+  for (i = 0; i < vec.size(); ++i) {
+    ret_mapped(i) = vec[i];
   }
+  for (; i < shape.num_elements(); ++i) {
+    ret_mapped(i) = 0.f;
+  }
+  return ret;
+}

-  return DS_ERR_OK;
+void
+copy_tensor_to_vector(const Tensor& tensor, vector<float>& vec, int num_elements = -1)
+{
+  auto tensor_mapped = tensor.flat<float>();
+  if (num_elements == -1) {
+    num_elements = tensor.shape().num_elements();
+  }
+  for (int i = 0; i < num_elements; ++i) {
+    vec.push_back(tensor_mapped(i));
+  }
 }

 void
-TFModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logits_output)
+TFModelState::infer(const std::vector<float>& mfcc,
+                    unsigned int n_frames,
+                    const std::vector<float>& previous_state_c,
+                    const std::vector<float>& previous_state_h,
+                    vector<float>& logits_output,
+                    vector<float>& state_c_output,
+                    vector<float>& state_h_output)
 {
   const size_t num_classes = alphabet_->GetSize() + 1; // +1 for blank

-  Tensor input(DT_FLOAT, TensorShape({BATCH_SIZE, n_steps_, 2*n_context_+1, n_features_}));
-
-  auto input_mapped = input.flat<float>();
-  int i;
-  for (i = 0; i < n_frames*mfcc_feats_per_timestep_; ++i) {
-    input_mapped(i) = aMfcc[i];
-  }
-  for (; i < n_steps_*mfcc_feats_per_timestep_; ++i) {
-    input_mapped(i) = 0.;
-  }
+  Tensor input = tensor_from_vector(mfcc, TensorShape({BATCH_SIZE, n_steps_, 2*n_context_+1, n_features_}));
+  Tensor previous_state_c_t = tensor_from_vector(previous_state_c, TensorShape({BATCH_SIZE, (long long)state_size_}));
+  Tensor previous_state_h_t = tensor_from_vector(previous_state_h, TensorShape({BATCH_SIZE, (long long)state_size_}));

   Tensor input_lengths(DT_INT32, TensorShape({1}));
   input_lengths.scalar<int>()() = n_frames;

   vector<Tensor> outputs;
   Status status = session_->Run(
-    {{"input_node", input}, {"input_lengths", input_lengths}},
-    {"logits"}, {}, &outputs);
+    {
+      {"input_node", input},
+      {"input_lengths", input_lengths},
+      {"previous_state_c", previous_state_c_t},
+      {"previous_state_h", previous_state_h_t}
+    },
+    {"logits", "new_state_c", "new_state_h"},
+    {},
+    &outputs);

   if (!status.ok()) {
     std::cerr << "Error running session: " << status << "\n";
     return;
   }

-  auto logits_mapped = outputs[0].flat<float>();
-  // The CTCDecoder works with log-probs.
-  for (int t = 0; t < n_frames * BATCH_SIZE * num_classes; ++t) {
-    logits_output.push_back(logits_mapped(t));
-  }
+  copy_tensor_to_vector(outputs[0], logits_output, n_frames * BATCH_SIZE * num_classes);
+
+  state_c_output.clear();
+  state_c_output.reserve(state_size_);
+  copy_tensor_to_vector(outputs[1], state_c_output);
+
+  state_h_output.clear();
+  state_h_output.reserve(state_size_);
+  copy_tensor_to_vector(outputs[2], state_h_output);
 }

 void
 TFModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_output)
 {
-  Tensor input(DT_FLOAT, TensorShape({audio_win_len_}));
-  auto input_mapped = input.flat<float>();
-  int i;
-  for (i = 0; i < samples.size(); ++i) {
-    input_mapped(i) = samples[i];
-  }
-  for (; i < audio_win_len_; ++i) {
-    input_mapped(i) = 0.f;
-  }
+  Tensor input = tensor_from_vector(samples, TensorShape({audio_win_len_}));

   vector<Tensor> outputs;
   Status status = session_->Run({{"input_samples", input}}, {"mfccs"}, {}, &outputs);
@@ -206,9 +226,5 @@ TFModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_out
   // The feature computation graph is hardcoded to one audio length for now
   const int n_windows = 1;
   assert(outputs[0].shape().num_elements() / n_features_ == n_windows);
-
-  auto mfcc_mapped = outputs[0].flat<float>();
-  for (int i = 0; i < n_windows * n_features_; ++i) {
-    mfcc_output.push_back(mfcc_mapped(i));
-  }
+  copy_tensor_to_vector(outputs[0], mfcc_output);
 }

View File: native_client/tfmodelstate.h

@@ -24,11 +24,13 @@ struct TFModelState : public ModelState
                    const char* alphabet_path,
                    unsigned int beam_width) override;

-  virtual int initialize_state() override;
-
-  virtual void infer(const float* mfcc,
+  virtual void infer(const std::vector<float>& mfcc,
                      unsigned int n_frames,
-                     std::vector<float>& logits_output) override;
+                     const std::vector<float>& previous_state_c,
+                     const std::vector<float>& previous_state_h,
+                     std::vector<float>& logits_output,
+                     std::vector<float>& state_c_output,
+                     std::vector<float>& state_h_output) override;

   virtual void compute_mfcc(const std::vector<float>& audio_buffer,
                             std::vector<float>& mfcc_output) override;