Remove previous state model variable, track by hand in StreamingState instead

Reuben Morais 2019-06-06 16:40:19 -03:00
parent 6e78bac799
commit e51b9d987d
10 changed files with 213 additions and 202 deletions

DeepSpeech.py

@@ -574,12 +574,8 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
# no state management since n_step is expected to be dynamic too (see below)
previous_state = previous_state_c = previous_state_h = None
else:
if tflite:
previous_state_c = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
previous_state_h = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
else:
previous_state_c = variable_on_cpu('previous_state_c', [batch_size, Config.n_cell_dim], initializer=None)
previous_state_h = variable_on_cpu('previous_state_h', [batch_size, Config.n_cell_dim], initializer=None)
previous_state_c = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
previous_state_h = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h)
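With the graph variables gone, the recurrent state lives entirely on the caller's side: each run feeds the previous state through the placeholders and reads the updated state back out. A minimal sketch of that driver loop, assuming an already-loaded session and the tensor names used in this graph (input_node, previous_state_c/h, logits, new_state_c/h); the feature chunking is hypothetical:

import numpy as np

# Hypothetical streaming driver; state is threaded through successive runs.
state_c = np.zeros([1, n_cell_dim], dtype=np.float32)
state_h = np.zeros([1, n_cell_dim], dtype=np.float32)
for chunk in feature_chunks:  # assumed iterable of [1, n_steps, window, n_features] arrays
    logits, state_c, state_h = session.run(
        ['logits:0', 'new_state_c:0', 'new_state_h:0'],
        feed_dict={'input_node:0': chunk,
                   'previous_state_c:0': state_c,
                   'previous_state_h:0': state_h})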
@@ -605,7 +601,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
logits = tf.squeeze(logits, [1])
# Apply softmax for CTC decoder
logits = tf.nn.softmax(logits)
logits = tf.nn.softmax(logits, name='logits')
if batch_size <= 0:
if tflite:
@@ -618,51 +614,31 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
'input_lengths': seq_length,
},
{
'outputs': tf.identity(logits, name='logits'),
'outputs': logits,
},
layers
)
new_state_c, new_state_h = layers['rnn_output_state']
if tflite:
logits = tf.identity(logits, name='logits')
new_state_c = tf.identity(new_state_c, name='new_state_c')
new_state_h = tf.identity(new_state_h, name='new_state_h')
new_state_c = tf.identity(new_state_c, name='new_state_c')
new_state_h = tf.identity(new_state_h, name='new_state_h')
inputs = {
'input': input_tensor,
'previous_state_c': previous_state_c,
'previous_state_h': previous_state_h,
'input_samples': input_samples,
}
inputs = {
'input': input_tensor,
'previous_state_c': previous_state_c,
'previous_state_h': previous_state_h,
'input_samples': input_samples,
}
if FLAGS.use_seq_length:
inputs.update({'input_lengths': seq_length})
if FLAGS.use_seq_length:
inputs.update({'input_lengths': seq_length})
outputs = {
'outputs': logits,
'new_state_c': new_state_c,
'new_state_h': new_state_h,
'mfccs': mfccs,
}
else:
zero_state = tf.zeros([batch_size, Config.n_cell_dim], tf.float32)
initialize_c = tf.assign(previous_state_c, zero_state)
initialize_h = tf.assign(previous_state_h, zero_state)
initialize_state = tf.group(initialize_c, initialize_h, name='initialize_state')
with tf.control_dependencies([tf.assign(previous_state_c, new_state_c), tf.assign(previous_state_h, new_state_h)]):
logits = tf.identity(logits, name='logits')
inputs = {
'input': input_tensor,
'input_lengths': seq_length,
'input_samples': input_samples,
}
outputs = {
'outputs': logits,
'initialize_state': initialize_state,
'mfccs': mfccs,
}
outputs = {
'outputs': logits,
'new_state_c': new_state_c,
'new_state_h': new_state_h,
'mfccs': mfccs,
}
return inputs, outputs, layers
@@ -682,10 +658,12 @@ def export():
output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
output_names = ",".join(output_names_tensors + output_names_ops)
if not FLAGS.export_tflite:
mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
else:
mapping = None
if FLAGS.export_tflite:
# Create a saver using variables from the above newly created graph
# Training graph uses LSTMFusedCell, but the TFLite inference graph uses
# a static RNN with a normal cell, so we need to rewrite the names to
# match the training weights when restoring.
def fixup(name):
if name.startswith('rnn/lstm_cell/'):
return name.replace('rnn/lstm_cell/', 'lstm_fused_cell/')
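The fixup above only rewrites names; a sketch of how such a mapping might be applied when constructing the export Saver (the exact surrounding code falls outside this hunk):

# Hedged sketch: map training-checkpoint variable names onto the
# TFLite inference graph's variables via fixup().
mapping = {fixup(v.op.name): v for v in tf.global_variables()}
saver = tf.train.Saver(mapping)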
@@ -710,7 +688,7 @@ def export():
if not os.path.isdir(FLAGS.export_dir):
os.makedirs(FLAGS.export_dir)
def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=''):
frozen = freeze_graph.freeze_graph_with_def_protos(
input_graph_def=tf.get_default_graph().as_graph_def(),
input_saver_def=saver.as_saver_def(),
@@ -731,7 +709,7 @@ def export():
placeholder_type_enum=tf.float32.as_datatype_enum)
if not FLAGS.export_tflite:
frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
frozen_graph = do_graph_freeze(output_node_names=output_names)
frozen_graph.version = int(file_relative_read('GRAPH_VERSION').strip())
# Add a no-op node to the graph with metadata information to be loaded by the native client
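For illustration, metadata can be attached to a frozen GraphDef as a NoOp node whose attributes carry the values; the node and attribute names below are assumptions, not necessarily the ones the native client reads:

metadata = frozen_graph.node.add()   # frozen_graph is a GraphDef proto
metadata.name = 'model_metadata'     # hypothetical node name
metadata.op = 'NoOp'
metadata.attr['sample_rate'].i = 16000   # illustrative attribute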
@@ -747,7 +725,7 @@ def export():
with open(output_graph_path, 'wb') as fout:
fout.write(frozen_graph.SerializeToString())
else:
frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='')
frozen_graph = do_graph_freeze(output_node_names=output_names)
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
@@ -771,8 +749,7 @@ def do_single_file_inference(input_file_path):
inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)
# Create a saver using variables from the above newly created graph
mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
saver = tf.train.Saver(mapping)
saver = tf.train.Saver()
# Restore variables from training checkpoint
# TODO: This restores the most recent checkpoint, but if we use validation to counteract
@@ -784,9 +761,10 @@ def do_single_file_inference(input_file_path):
checkpoint_path = checkpoint.model_checkpoint_path
saver.restore(session, checkpoint_path)
session.run(outputs['initialize_state'])
features, features_len = audiofile_to_features(input_file_path)
previous_state_c = np.zeros([1, Config.n_cell_dim])
previous_state_h = np.zeros([1, Config.n_cell_dim])
# Add batch dimension
features = tf.expand_dims(features, 0)
@@ -799,6 +777,8 @@ def do_single_file_inference(input_file_path):
logits = outputs['outputs'].eval(feed_dict={
inputs['input']: features,
inputs['input_lengths']: features_len,
inputs['previous_state_c']: previous_state_c,
inputs['previous_state_h']: previous_state_h,
}, session=session)
logits = np.squeeze(logits)

GRAPH_VERSION

@@ -1 +1 @@
1
2

native_client/BUILD

@@ -114,34 +114,26 @@ tf_cc_shared_object(
### => Trying to be more fine-grained
### Use bin/ops_in_graph.py to list all the ops used by a frozen graph.
### CPU only build, libdeepspeech.so file size reduced by ~50%
"//tensorflow/core/kernels:dense_update_ops", # Assign
"//tensorflow/core/kernels:constant_op", # Const
"//tensorflow/core/kernels:immutable_constant_op", # ImmutableConst
"//tensorflow/core/kernels:dense_update_ops", # Assign (remove once prod model no longer depends on it)
"//tensorflow/core/kernels:constant_op", # Placeholder
"//tensorflow/core/kernels:immutable_constant_op", # ImmutableConst (used in memmapped models)
"//tensorflow/core/kernels:identity_op", # Identity
"//tensorflow/core/kernels:softmax_op", # Softmax
"//tensorflow/core/kernels:transpose_op", # Transpose
"//tensorflow/core/kernels:reshape_op", # Reshape
"//tensorflow/core/kernels:shape_ops", # Shape
"//tensorflow/core/kernels:concat_op", # ConcatV2
"//tensorflow/core/kernels:split_op", # Split
"//tensorflow/core/kernels:variable_ops", # VariableV2
"//tensorflow/core/kernels:relu_op", # Relu
"//tensorflow/core/kernels:bias_op", # BiasAdd
"//tensorflow/core/kernels:math", # Range, MatMul
"//tensorflow/core/kernels:control_flow_ops", # Enter
"//tensorflow/core/kernels:tile_ops", # Tile
"//tensorflow/core/kernels:gather_op", # Gather
"//tensorflow/core/kernels:mfcc_op", # Mfcc
"//tensorflow/core/kernels:spectrogram_op", # AudioSpectrogram
"//tensorflow/core/kernels:strided_slice_op", # StridedSlice
"//tensorflow/core/kernels:slice_op", # Slice, needed by StridedSlice
"//tensorflow/contrib/rnn:lstm_ops_kernels", # BlockLSTM
"//tensorflow/core/kernels:random_ops", # RandomGammaGrad
"//tensorflow/core/kernels:pack_op", # Pack
"//tensorflow/core/kernels:gather_nd_op", # GatherNd
#### Needed by production model produced without "--use_seq_length False"
#"//tensorflow/core/kernels:logging_ops", # Assert
#"//tensorflow/core/kernels:reverse_sequence_op", # ReverseSequence
],
}) + if_cuda([
"//tensorflow/core:core",

native_client/deepspeech.cc

@@ -67,6 +67,9 @@ struct StreamingState {
vector<float> audio_buffer_;
vector<float> mfcc_buffer_;
vector<float> batch_buffer_;
vector<float> previous_state_c_;
vector<float> previous_state_h_;
ModelState* model_;
std::unique_ptr<DecoderState> decoder_state_;
@@ -233,7 +236,13 @@ void
StreamingState::processBatch(const vector<float>& buf, unsigned int n_steps)
{
vector<float> logits;
model_->infer(buf.data(), n_steps, logits);
model_->infer(buf,
n_steps,
previous_state_c_,
previous_state_h_,
logits,
previous_state_c_,
previous_state_h_);
const int cutoff_top_n = 40;
const double cutoff_prob = 1.0;
@@ -326,11 +335,6 @@ DS_SetupStream(ModelState* aCtx,
{
*retval = nullptr;
int err = aCtx->initialize_state();
if (err != DS_ERR_OK) {
return err;
}
std::unique_ptr<StreamingState> ctx(new StreamingState());
if (!ctx) {
std::cerr << "Could not allocate streaming state." << std::endl;
@@ -348,6 +352,8 @@ DS_SetupStream(ModelState* aCtx,
ctx->mfcc_buffer_.reserve(aCtx->mfcc_feats_per_timestep_);
ctx->mfcc_buffer_.resize(aCtx->n_features_*aCtx->n_context_, 0.f);
ctx->batch_buffer_.reserve(aCtx->n_steps_ * aCtx->mfcc_feats_per_timestep_);
ctx->previous_state_c_.resize(aCtx->state_size_, 0.f);
ctx->previous_state_h_.resize(aCtx->state_size_, 0.f);
ctx->model_ = aCtx;
ctx->decoder_state_.reset(decoder_init(*aCtx->alphabet_, num_classes, aCtx->scorer_));
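In effect, each StreamingState now plays the role the graph variable used to: it owns a zeroed per-stream state and feeds it back on every batch. A rough Python mirror of that ownership model (class and method names are illustrative, not the project's API):

import numpy as np

class PyStreamingState:
    def __init__(self, model, state_size):
        self.model = model  # assumed to expose infer(features, n_steps, state_c, state_h)
        self.state_c = np.zeros([1, state_size], dtype=np.float32)
        self.state_h = np.zeros([1, state_size], dtype=np.float32)

    def process_batch(self, features, n_steps):
        # New state replaces old state after every inference call.
        logits, self.state_c, self.state_h = self.model.infer(
            features, n_steps, self.state_c, self.state_h)
        return logits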

native_client/modelstate.cc

@@ -17,6 +17,7 @@ ModelState::ModelState()
, sample_rate_(DEFAULT_SAMPLE_RATE)
, audio_win_len_(DEFAULT_WINDOW_LENGTH)
, audio_win_step_(DEFAULT_WINDOW_STEP)
, state_size_(-1)
{
}

native_client/modelstate.h

@@ -28,6 +28,7 @@ struct ModelState {
unsigned int sample_rate_;
unsigned int audio_win_len_;
unsigned int audio_win_step_;
unsigned int state_size_;
ModelState();
virtual ~ModelState();
@@ -38,8 +39,6 @@ struct ModelState {
const char* alphabet_path,
unsigned int beam_width);
virtual int initialize_state() = 0;
virtual void compute_mfcc(const std::vector<float>& audio_buffer, std::vector<float>& mfcc_output) = 0;
/**
@@ -52,7 +51,13 @@
*
* @param[out] output_logits Where to store computed logits.
*/
virtual void infer(const float* mfcc, unsigned int n_frames, std::vector<float>& logits_output) = 0;
virtual void infer(const std::vector<float>& mfcc,
unsigned int n_frames,
const std::vector<float>& previous_state_c,
const std::vector<float>& previous_state_h,
std::vector<float>& logits_output,
std::vector<float>& state_c_output,
std::vector<float>& state_h_output) = 0;
/**
* @brief Perform decoding of the logits, using basic CTC decoder or

native_client/tflitemodelstate.cc

@@ -4,14 +4,13 @@ using namespace tflite;
using std::vector;
int
tflite_get_tensor_by_name(const Interpreter* interpreter,
const vector<int>& list,
const char* name)
TFLiteModelState::get_tensor_by_name(const vector<int>& list,
const char* name)
{
int rv = -1;
for (int i = 0; i < list.size(); ++i) {
const string& node_name = interpreter->tensor(list[i])->name;
const string& node_name = interpreter_->tensor(list[i])->name;
if (node_name.compare(string(name)) == 0) {
rv = i;
}
@@ -22,17 +21,17 @@ tflite_get_tensor_by_name(const Interpreter* interpreter,
}
int
tflite_get_input_tensor_by_name(const Interpreter* interpreter, const char* name)
TFLiteModelState::get_input_tensor_by_name(const char* name)
{
int idx = tflite_get_tensor_by_name(interpreter, interpreter->inputs(), name);
return interpreter->inputs()[idx];
int idx = get_tensor_by_name(interpreter_->inputs(), name);
return interpreter_->inputs()[idx];
}
int
tflite_get_output_tensor_by_name(const Interpreter* interpreter, const char* name)
TFLiteModelState::get_output_tensor_by_name(const char* name)
{
int idx = tflite_get_tensor_by_name(interpreter, interpreter->outputs(), name);
return interpreter->outputs()[idx];
int idx = get_tensor_by_name(interpreter_->outputs(), name);
return interpreter_->outputs()[idx];
}
void
@@ -48,8 +47,8 @@ push_back_if_not_present(std::deque<int>& list, int value)
// output, add it to the parent list, and add its input tensors to the frontier
// list. Because we start from the final tensor and work backwards to the inputs,
// the parents list is constructed in reverse, adding elements to its front.
std::vector<int>
tflite_find_parent_node_ids(Interpreter* interpreter, int tensor_id)
vector<int>
TFLiteModelState::find_parent_node_ids(int tensor_id)
{
std::deque<int> parents;
std::deque<int> frontier;
@@ -58,8 +57,8 @@ tflite_find_parent_node_ids(Interpreter* interpreter, int tensor_id)
int next_tensor_id = frontier.front();
frontier.pop_front();
// Find all nodes that have next_tensor_id as an output
for (int node_id = 0; node_id < interpreter->nodes_size(); ++node_id) {
TfLiteNode node = interpreter->node_and_registration(node_id)->first;
for (int node_id = 0; node_id < interpreter_->nodes_size(); ++node_id) {
TfLiteNode node = interpreter_->node_and_registration(node_id)->first;
// Search node outputs for the tensor we're looking for
for (int i = 0; i < node.outputs->size; ++i) {
if (node.outputs->data[i] == next_tensor_id) {
@@ -74,16 +73,13 @@ tflite_find_parent_node_ids(Interpreter* interpreter, int tensor_id)
}
}
return std::vector<int>(parents.begin(), parents.end());
return vector<int>(parents.begin(), parents.end());
}
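A hedged Python rendering of the backwards walk described in the comment above; the `nodes` dict mapping node ids to (input tensor ids, output tensor ids) stands in for the interpreter queries:

from collections import deque

def find_parent_node_ids(nodes, tensor_id):
    parents = deque()
    frontier = deque([tensor_id])
    while frontier:
        next_tensor_id = frontier.popleft()
        # Find every node that produces this tensor as an output.
        for node_id, (inputs, outputs) in nodes.items():
            if next_tensor_id in outputs:
                if node_id not in parents:
                    parents.appendleft(node_id)  # plan is built in reverse
                for t in inputs:
                    if t not in frontier:
                        frontier.append(t)
    return list(parents)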
TFLiteModelState::TFLiteModelState()
: ModelState()
, interpreter_(nullptr)
, fbmodel_(nullptr)
, previous_state_size_(0)
, previous_state_c_(nullptr)
, previous_state_h_(nullptr)
{
}
@@ -120,21 +116,21 @@ TFLiteModelState::init(const char* model_path,
interpreter_->SetNumThreads(4);
// Query all the indices once
input_node_idx_ = tflite_get_input_tensor_by_name(interpreter_.get(), "input_node");
previous_state_c_idx_ = tflite_get_input_tensor_by_name(interpreter_.get(), "previous_state_c");
previous_state_h_idx_ = tflite_get_input_tensor_by_name(interpreter_.get(), "previous_state_h");
input_samples_idx_ = tflite_get_input_tensor_by_name(interpreter_.get(), "input_samples");
logits_idx_ = tflite_get_output_tensor_by_name(interpreter_.get(), "logits");
new_state_c_idx_ = tflite_get_output_tensor_by_name(interpreter_.get(), "new_state_c");
new_state_h_idx_ = tflite_get_output_tensor_by_name(interpreter_.get(), "new_state_h");
mfccs_idx_ = tflite_get_output_tensor_by_name(interpreter_.get(), "mfccs");
input_node_idx_ = get_input_tensor_by_name("input_node");
previous_state_c_idx_ = get_input_tensor_by_name("previous_state_c");
previous_state_h_idx_ = get_input_tensor_by_name("previous_state_h");
input_samples_idx_ = get_input_tensor_by_name("input_samples");
logits_idx_ = get_output_tensor_by_name("logits");
new_state_c_idx_ = get_output_tensor_by_name("new_state_c");
new_state_h_idx_ = get_output_tensor_by_name("new_state_h");
mfccs_idx_ = get_output_tensor_by_name("mfccs");
// When we call Interpreter::Invoke, the whole graph is executed by default,
// which means every time compute_mfcc is called the entire acoustic model is
// also executed. To work around that problem, we walk up the dependency DAG
// from the mfccs output tensor to find all the relevant nodes required for
// feature computation, building an execution plan that runs just those nodes.
auto mfcc_plan = tflite_find_parent_node_ids(interpreter_.get(), mfccs_idx_);
auto mfcc_plan = find_parent_node_ids(mfccs_idx_);
auto orig_plan = interpreter_->execution_plan();
// Remove MFCC nodes from original plan (all nodes) to create the acoustic model plan
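The removal is a simple set difference over node ids; an illustrative Python version, assuming mfcc_plan and orig_plan are lists of node ids:

mfcc_set = set(mfcc_plan)
acoustic_plan = [node_id for node_id in orig_plan if node_id not in mfcc_set]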
@@ -168,50 +164,57 @@ TFLiteModelState::init(const char* model_path,
TfLiteIntArray* dims_c = interpreter_->tensor(previous_state_c_idx_)->dims;
TfLiteIntArray* dims_h = interpreter_->tensor(previous_state_h_idx_)->dims;
assert(dims_c->data[1] == dims_h->data[1]);
previous_state_size_ = dims_c->data[1];
previous_state_c_.reset(new float[previous_state_size_]());
previous_state_h_.reset(new float[previous_state_size_]());
// Set initial values for previous_state_c and previous_state_h
memset(previous_state_c_.get(), 0, sizeof(float) * previous_state_size_);
memset(previous_state_h_.get(), 0, sizeof(float) * previous_state_size_);
return DS_ERR_OK;
}
int
TFLiteModelState::initialize_state()
{
/* Ensure previous_state_{c,h} are not holding previous stream value */
memset(previous_state_c_.get(), 0, sizeof(float) * previous_state_size_);
memset(previous_state_h_.get(), 0, sizeof(float) * previous_state_size_);
assert(state_size_ > 0);
state_size_ = dims_c->data[1];
return DS_ERR_OK;
}
void
TFLiteModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logits_output)
TFLiteModelState::copy_vector_to_tensor(const vector<float>& vec,
int tensor_idx,
int num_elements)
{
float* tensor = interpreter_->typed_tensor<float>(tensor_idx);
int i;
for (i = 0; i < vec.size(); ++i) {
tensor[i] = vec[i];
}
for (; i < num_elements; ++i) {
tensor[i] = 0.f;
}
}
void
TFLiteModelState::copy_tensor_to_vector(int tensor_idx,
int num_elements,
vector<float>& vec)
{
float* tensor = interpreter_->typed_tensor<float>(tensor_idx);
for (int i = 0; i < num_elements; ++i) {
vec.push_back(tensor[i]);
}
}
void
TFLiteModelState::infer(const vector<float>& mfcc,
unsigned int n_frames,
const vector<float>& previous_state_c,
const vector<float>& previous_state_h,
vector<float>& logits_output,
vector<float>& state_c_output,
vector<float>& state_h_output)
{
const size_t num_classes = alphabet_->GetSize() + 1; // +1 for blank
// Feeding input_node
float* input_node = interpreter_->typed_tensor<float>(input_node_idx_);
{
int i;
for (i = 0; i < n_frames*mfcc_feats_per_timestep_; ++i) {
input_node[i] = aMfcc[i];
}
for (; i < n_steps_*mfcc_feats_per_timestep_; ++i) {
input_node[i] = 0;
}
}
assert(previous_state_size_ > 0);
copy_vector_to_tensor(mfcc, input_node_idx_, n_frames*mfcc_feats_per_timestep_);
// Feeding previous_state_c, previous_state_h
memcpy(interpreter_->typed_tensor<float>(previous_state_c_idx_), previous_state_c_.get(), sizeof(float) * previous_state_size_);
memcpy(interpreter_->typed_tensor<float>(previous_state_h_idx_), previous_state_h_.get(), sizeof(float) * previous_state_size_);
assert(previous_state_c.size() == state_size_);
copy_vector_to_tensor(previous_state_c, previous_state_c_idx_, state_size_);
assert(previous_state_h.size() == state_size_);
copy_vector_to_tensor(previous_state_h, previous_state_h_idx_, state_size_);
interpreter_->SetExecutionPlan(acoustic_exec_plan_);
TfLiteStatus status = interpreter_->Invoke();
@@ -220,25 +223,23 @@ TFLiteModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>
return;
}
float* outputs = interpreter_->typed_tensor<float>(logits_idx_);
copy_tensor_to_vector(logits_idx_, n_frames * BATCH_SIZE * num_classes, logits_output);
// The CTCDecoder works with log-probs.
for (int t = 0; t < n_frames * BATCH_SIZE * num_classes; ++t) {
logits_output.push_back(outputs[t]);
}
state_c_output.clear();
state_c_output.reserve(state_size_);
copy_tensor_to_vector(new_state_c_idx_, state_size_, state_c_output);
memcpy(previous_state_c_.get(), interpreter_->typed_tensor<float>(new_state_c_idx_), sizeof(float) * previous_state_size_);
memcpy(previous_state_h_.get(), interpreter_->typed_tensor<float>(new_state_h_idx_), sizeof(float) * previous_state_size_);
state_h_output.clear();
state_h_output.reserve(state_size_);
copy_tensor_to_vector(new_state_h_idx_, state_size_, state_h_output);
}
void
TFLiteModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_output)
TFLiteModelState::compute_mfcc(const vector<float>& samples,
vector<float>& mfcc_output)
{
// Feeding input_node
float* input_samples = interpreter_->typed_tensor<float>(input_samples_idx_);
for (int i = 0; i < samples.size(); ++i) {
input_samples[i] = samples[i];
}
copy_vector_to_tensor(samples, input_samples_idx_, samples.size());
TfLiteStatus status = interpreter_->SetExecutionPlan(mfcc_exec_plan_);
if (status != kTfLiteOk) {
@@ -261,8 +262,5 @@ TFLiteModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc
}
assert(num_elements / n_features_ == n_windows);
float* outputs = interpreter_->typed_tensor<float>(mfccs_idx_);
for (int i = 0; i < n_windows * n_features_; ++i) {
mfcc_output.push_back(outputs[i]);
}
copy_tensor_to_vector(mfccs_idx_, n_windows * n_features_, mfcc_output);
}

native_client/tflitemodelstate.h

@@ -14,10 +14,6 @@ struct TFLiteModelState : public ModelState
std::unique_ptr<tflite::Interpreter> interpreter_;
std::unique_ptr<tflite::FlatBufferModel> fbmodel_;
size_t previous_state_size_;
std::unique_ptr<float[]> previous_state_c_;
std::unique_ptr<float[]> previous_state_h_;
int input_node_idx_;
int previous_state_c_idx_;
int previous_state_h_idx_;
@@ -40,13 +36,28 @@ struct TFLiteModelState : public ModelState
const char* alphabet_path,
unsigned int beam_width) override;
virtual int initialize_state() override;
virtual void compute_mfcc(const std::vector<float>& audio_buffer,
std::vector<float>& mfcc_output) override;
virtual void infer(const float* mfcc, unsigned int n_frames,
std::vector<float>& logits_output) override;
virtual void infer(const std::vector<float>& mfcc,
unsigned int n_frames,
const std::vector<float>& previous_state_c,
const std::vector<float>& previous_state_h,
std::vector<float>& logits_output,
std::vector<float>& state_c_output,
std::vector<float>& state_h_output) override;
private:
int get_tensor_by_name(const std::vector<int>& list, const char* name);
int get_input_tensor_by_name(const char* name);
int get_output_tensor_by_name(const char* name);
std::vector<int> find_parent_node_ids(int tensor_id);
void copy_vector_to_tensor(const std::vector<float>& vec,
int tensor_idx,
int num_elements);
void copy_tensor_to_vector(int tensor_idx,
int num_elements,
std::vector<float>& vec);
};
#endif // TFLITEMODELSTATE_H

native_client/tfmodelstate.cc

@@ -98,6 +98,9 @@ TFModelState::init(const char* model_path,
n_context_ = (shape.dim(2).size()-1)/2;
n_features_ = shape.dim(3).size();
mfcc_feats_per_timestep_ = shape.dim(2).size() * shape.dim(3).size();
} else if (node.name() == "previous_state_c") {
const auto& shape = node.attr().at("shape").shape();
state_size_ = shape.dim(1).size();
} else if (node.name() == "logits_shape") {
Tensor logits_shape = Tensor(DT_INT32, TensorShape({3}));
if (!logits_shape.FromProto(node.attr().at("value").tensor())) {
@@ -134,66 +137,83 @@ TFModelState::init(const char* model_path,
return DS_ERR_OK;
}
int
TFModelState::initialize_state()
Tensor
tensor_from_vector(const std::vector<float>& vec, const TensorShape& shape)
{
Status status = session_->Run({}, {}, {"initialize_state"}, nullptr);
if (!status.ok()) {
std::cerr << "Error running session: " << status << std::endl;
return DS_ERR_FAIL_RUN_SESS;
Tensor ret(DT_FLOAT, shape);
auto ret_mapped = ret.flat<float>();
int i;
for (i = 0; i < vec.size(); ++i) {
ret_mapped(i) = vec[i];
}
return DS_ERR_OK;
for (; i < shape.num_elements(); ++i) {
ret_mapped(i) = 0.f;
}
return ret;
}
void
TFModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logits_output)
copy_tensor_to_vector(const Tensor& tensor, vector<float>& vec, int num_elements = -1)
{
auto tensor_mapped = tensor.flat<float>();
if (num_elements == -1) {
num_elements = tensor.shape().num_elements();
}
for (int i = 0; i < num_elements; ++i) {
vec.push_back(tensor_mapped(i));
}
}
void
TFModelState::infer(const std::vector<float>& mfcc,
unsigned int n_frames,
const std::vector<float>& previous_state_c,
const std::vector<float>& previous_state_h,
vector<float>& logits_output,
vector<float>& state_c_output,
vector<float>& state_h_output)
{
const size_t num_classes = alphabet_->GetSize() + 1; // +1 for blank
Tensor input(DT_FLOAT, TensorShape({BATCH_SIZE, n_steps_, 2*n_context_+1, n_features_}));
auto input_mapped = input.flat<float>();
int i;
for (i = 0; i < n_frames*mfcc_feats_per_timestep_; ++i) {
input_mapped(i) = aMfcc[i];
}
for (; i < n_steps_*mfcc_feats_per_timestep_; ++i) {
input_mapped(i) = 0.;
}
Tensor input = tensor_from_vector(mfcc, TensorShape({BATCH_SIZE, n_steps_, 2*n_context_+1, n_features_}));
Tensor previous_state_c_t = tensor_from_vector(previous_state_c, TensorShape({BATCH_SIZE, (long long)state_size_}));
Tensor previous_state_h_t = tensor_from_vector(previous_state_h, TensorShape({BATCH_SIZE, (long long)state_size_}));
Tensor input_lengths(DT_INT32, TensorShape({1}));
input_lengths.scalar<int>()() = n_frames;
vector<Tensor> outputs;
Status status = session_->Run(
{{"input_node", input}, {"input_lengths", input_lengths}},
{"logits"}, {}, &outputs);
{
{"input_node", input},
{"input_lengths", input_lengths},
{"previous_state_c", previous_state_c_t},
{"previous_state_h", previous_state_h_t}
},
{"logits", "new_state_c", "new_state_h"},
{},
&outputs);
if (!status.ok()) {
std::cerr << "Error running session: " << status << "\n";
return;
}
auto logits_mapped = outputs[0].flat<float>();
// The CTCDecoder works with log-probs.
for (int t = 0; t < n_frames * BATCH_SIZE * num_classes; ++t) {
logits_output.push_back(logits_mapped(t));
}
copy_tensor_to_vector(outputs[0], logits_output, n_frames * BATCH_SIZE * num_classes);
state_c_output.clear();
state_c_output.reserve(state_size_);
copy_tensor_to_vector(outputs[1], state_c_output);
state_h_output.clear();
state_h_output.reserve(state_size_);
copy_tensor_to_vector(outputs[2], state_h_output);
}
void
TFModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_output)
{
Tensor input(DT_FLOAT, TensorShape({audio_win_len_}));
auto input_mapped = input.flat<float>();
int i;
for (i = 0; i < samples.size(); ++i) {
input_mapped(i) = samples[i];
}
for (; i < audio_win_len_; ++i) {
input_mapped(i) = 0.f;
}
Tensor input = tensor_from_vector(samples, TensorShape({audio_win_len_}));
vector<Tensor> outputs;
Status status = session_->Run({{"input_samples", input}}, {"mfccs"}, {}, &outputs);
@@ -206,9 +226,5 @@ TFModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_out
// The feature computation graph is hardcoded to one audio length for now
const int n_windows = 1;
assert(outputs[0].shape().num_elements() / n_features_ == n_windows);
auto mfcc_mapped = outputs[0].flat<float>();
for (int i = 0; i < n_windows * n_features_; ++i) {
mfcc_output.push_back(mfcc_mapped(i));
}
copy_tensor_to_vector(outputs[0], mfcc_output);
}

native_client/tfmodelstate.h

@@ -24,11 +24,13 @@ struct TFModelState : public ModelState
const char* alphabet_path,
unsigned int beam_width) override;
virtual int initialize_state() override;
virtual void infer(const float* mfcc,
virtual void infer(const std::vector<float>& mfcc,
unsigned int n_frames,
std::vector<float>& logits_output) override;
const std::vector<float>& previous_state_c,
const std::vector<float>& previous_state_h,
std::vector<float>& logits_output,
std::vector<float>& state_c_output,
std::vector<float>& state_h_output) override;
virtual void compute_mfcc(const std::vector<float>& audio_buffer,
std::vector<float>& mfcc_output) override;