Remove previous state model variable, track by hand in StreamingState instead
This commit is contained in:
parent
6e78bac799
commit
e51b9d987d
@ -574,12 +574,8 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
||||
# no state management since n_step is expected to be dynamic too (see below)
|
||||
previous_state = previous_state_c = previous_state_h = None
|
||||
else:
|
||||
if tflite:
|
||||
previous_state_c = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
|
||||
previous_state_h = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
|
||||
else:
|
||||
previous_state_c = variable_on_cpu('previous_state_c', [batch_size, Config.n_cell_dim], initializer=None)
|
||||
previous_state_h = variable_on_cpu('previous_state_h', [batch_size, Config.n_cell_dim], initializer=None)
|
||||
|
||||
previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h)
|
||||
|
||||
@ -605,7 +601,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
||||
logits = tf.squeeze(logits, [1])
|
||||
|
||||
# Apply softmax for CTC decoder
|
||||
logits = tf.nn.softmax(logits)
|
||||
logits = tf.nn.softmax(logits, name='logits')
|
||||
|
||||
if batch_size <= 0:
|
||||
if tflite:
|
||||
@ -618,14 +614,12 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
||||
'input_lengths': seq_length,
|
||||
},
|
||||
{
|
||||
'outputs': tf.identity(logits, name='logits'),
|
||||
'outputs': logits,
|
||||
},
|
||||
layers
|
||||
)
|
||||
|
||||
new_state_c, new_state_h = layers['rnn_output_state']
|
||||
if tflite:
|
||||
logits = tf.identity(logits, name='logits')
|
||||
new_state_c = tf.identity(new_state_c, name='new_state_c')
|
||||
new_state_h = tf.identity(new_state_h, name='new_state_h')
|
||||
|
||||
@ -645,24 +639,6 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
||||
'new_state_h': new_state_h,
|
||||
'mfccs': mfccs,
|
||||
}
|
||||
else:
|
||||
zero_state = tf.zeros([batch_size, Config.n_cell_dim], tf.float32)
|
||||
initialize_c = tf.assign(previous_state_c, zero_state)
|
||||
initialize_h = tf.assign(previous_state_h, zero_state)
|
||||
initialize_state = tf.group(initialize_c, initialize_h, name='initialize_state')
|
||||
with tf.control_dependencies([tf.assign(previous_state_c, new_state_c), tf.assign(previous_state_h, new_state_h)]):
|
||||
logits = tf.identity(logits, name='logits')
|
||||
|
||||
inputs = {
|
||||
'input': input_tensor,
|
||||
'input_lengths': seq_length,
|
||||
'input_samples': input_samples,
|
||||
}
|
||||
outputs = {
|
||||
'outputs': logits,
|
||||
'initialize_state': initialize_state,
|
||||
'mfccs': mfccs,
|
||||
}
|
||||
|
||||
return inputs, outputs, layers
|
||||
|
||||
@ -682,10 +658,12 @@ def export():
|
||||
output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
|
||||
output_names = ",".join(output_names_tensors + output_names_ops)
|
||||
|
||||
if not FLAGS.export_tflite:
|
||||
mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
|
||||
else:
|
||||
mapping = None
|
||||
if FLAGS.export_tflite:
|
||||
# Create a saver using variables from the above newly created graph
|
||||
# Training graph uses LSTMFusedCell, but the TFLite inference graph uses
|
||||
# a static RNN with a normal cell, so we need to rewrite the names to
|
||||
# match the training weights when restoring.
|
||||
def fixup(name):
|
||||
if name.startswith('rnn/lstm_cell/'):
|
||||
return name.replace('rnn/lstm_cell/', 'lstm_fused_cell/')
|
||||
@ -710,7 +688,7 @@ def export():
|
||||
if not os.path.isdir(FLAGS.export_dir):
|
||||
os.makedirs(FLAGS.export_dir)
|
||||
|
||||
def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
|
||||
def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=''):
|
||||
frozen = freeze_graph.freeze_graph_with_def_protos(
|
||||
input_graph_def=tf.get_default_graph().as_graph_def(),
|
||||
input_saver_def=saver.as_saver_def(),
|
||||
@ -731,7 +709,7 @@ def export():
|
||||
placeholder_type_enum=tf.float32.as_datatype_enum)
|
||||
|
||||
if not FLAGS.export_tflite:
|
||||
frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
|
||||
frozen_graph = do_graph_freeze(output_node_names=output_names)
|
||||
frozen_graph.version = int(file_relative_read('GRAPH_VERSION').strip())
|
||||
|
||||
# Add a no-op node to the graph with metadata information to be loaded by the native client
|
||||
@ -747,7 +725,7 @@ def export():
|
||||
with open(output_graph_path, 'wb') as fout:
|
||||
fout.write(frozen_graph.SerializeToString())
|
||||
else:
|
||||
frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='')
|
||||
frozen_graph = do_graph_freeze(output_node_names=output_names)
|
||||
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
|
||||
|
||||
converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
|
||||
@ -771,8 +749,7 @@ def do_single_file_inference(input_file_path):
|
||||
inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)
|
||||
|
||||
# Create a saver using variables from the above newly created graph
|
||||
mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
|
||||
saver = tf.train.Saver(mapping)
|
||||
saver = tf.train.Saver()
|
||||
|
||||
# Restore variables from training checkpoint
|
||||
# TODO: This restores the most recent checkpoint, but if we use validation to counteract
|
||||
@ -784,9 +761,10 @@ def do_single_file_inference(input_file_path):
|
||||
|
||||
checkpoint_path = checkpoint.model_checkpoint_path
|
||||
saver.restore(session, checkpoint_path)
|
||||
session.run(outputs['initialize_state'])
|
||||
|
||||
features, features_len = audiofile_to_features(input_file_path)
|
||||
previous_state_c = np.zeros([1, Config.n_cell_dim])
|
||||
previous_state_h = np.zeros([1, Config.n_cell_dim])
|
||||
|
||||
# Add batch dimension
|
||||
features = tf.expand_dims(features, 0)
|
||||
@ -799,6 +777,8 @@ def do_single_file_inference(input_file_path):
|
||||
logits = outputs['outputs'].eval(feed_dict={
|
||||
inputs['input']: features,
|
||||
inputs['input_lengths']: features_len,
|
||||
inputs['previous_state_c']: previous_state_c,
|
||||
inputs['previous_state_h']: previous_state_h,
|
||||
}, session=session)
|
||||
|
||||
logits = np.squeeze(logits)
|
||||
|
@ -1 +1 @@
|
||||
1
|
||||
2
|
@ -114,34 +114,26 @@ tf_cc_shared_object(
|
||||
### => Trying to be more fine-grained
|
||||
### Use bin/ops_in_graph.py to list all the ops used by a frozen graph.
|
||||
### CPU only build, libdeepspeech.so file size reduced by ~50%
|
||||
"//tensorflow/core/kernels:dense_update_ops", # Assign
|
||||
"//tensorflow/core/kernels:constant_op", # Const
|
||||
"//tensorflow/core/kernels:immutable_constant_op", # ImmutableConst
|
||||
"//tensorflow/core/kernels:dense_update_ops", # Assign (remove once prod model no longer depends on it)
|
||||
"//tensorflow/core/kernels:constant_op", # Placeholder
|
||||
"//tensorflow/core/kernels:immutable_constant_op", # ImmutableConst (used in memmapped models)
|
||||
"//tensorflow/core/kernels:identity_op", # Identity
|
||||
"//tensorflow/core/kernels:softmax_op", # Softmax
|
||||
"//tensorflow/core/kernels:transpose_op", # Transpose
|
||||
"//tensorflow/core/kernels:reshape_op", # Reshape
|
||||
"//tensorflow/core/kernels:shape_ops", # Shape
|
||||
"//tensorflow/core/kernels:concat_op", # ConcatV2
|
||||
"//tensorflow/core/kernels:split_op", # Split
|
||||
"//tensorflow/core/kernels:variable_ops", # VariableV2
|
||||
"//tensorflow/core/kernels:relu_op", # Relu
|
||||
"//tensorflow/core/kernels:bias_op", # BiasAdd
|
||||
"//tensorflow/core/kernels:math", # Range, MatMul
|
||||
"//tensorflow/core/kernels:control_flow_ops", # Enter
|
||||
"//tensorflow/core/kernels:tile_ops", # Tile
|
||||
"//tensorflow/core/kernels:gather_op", # Gather
|
||||
"//tensorflow/core/kernels:mfcc_op", # Mfcc
|
||||
"//tensorflow/core/kernels:spectrogram_op", # AudioSpectrogram
|
||||
"//tensorflow/core/kernels:strided_slice_op", # StridedSlice
|
||||
"//tensorflow/core/kernels:slice_op", # Slice, needed by StridedSlice
|
||||
"//tensorflow/contrib/rnn:lstm_ops_kernels", # BlockLSTM
|
||||
"//tensorflow/core/kernels:random_ops", # RandomGammaGrad
|
||||
"//tensorflow/core/kernels:pack_op", # Pack
|
||||
"//tensorflow/core/kernels:gather_nd_op", # GatherNd
|
||||
#### Needed by production model produced without "--use_seq_length False"
|
||||
#"//tensorflow/core/kernels:logging_ops", # Assert
|
||||
#"//tensorflow/core/kernels:reverse_sequence_op", # ReverseSequence
|
||||
],
|
||||
}) + if_cuda([
|
||||
"//tensorflow/core:core",
|
||||
|
@ -67,6 +67,9 @@ struct StreamingState {
|
||||
vector<float> audio_buffer_;
|
||||
vector<float> mfcc_buffer_;
|
||||
vector<float> batch_buffer_;
|
||||
vector<float> previous_state_c_;
|
||||
vector<float> previous_state_h_;
|
||||
|
||||
ModelState* model_;
|
||||
std::unique_ptr<DecoderState> decoder_state_;
|
||||
|
||||
@ -233,7 +236,13 @@ void
|
||||
StreamingState::processBatch(const vector<float>& buf, unsigned int n_steps)
|
||||
{
|
||||
vector<float> logits;
|
||||
model_->infer(buf.data(), n_steps, logits);
|
||||
model_->infer(buf,
|
||||
n_steps,
|
||||
previous_state_c_,
|
||||
previous_state_h_,
|
||||
logits,
|
||||
previous_state_c_,
|
||||
previous_state_h_);
|
||||
|
||||
const int cutoff_top_n = 40;
|
||||
const double cutoff_prob = 1.0;
|
||||
@ -326,11 +335,6 @@ DS_SetupStream(ModelState* aCtx,
|
||||
{
|
||||
*retval = nullptr;
|
||||
|
||||
int err = aCtx->initialize_state();
|
||||
if (err != DS_ERR_OK) {
|
||||
return err;
|
||||
}
|
||||
|
||||
std::unique_ptr<StreamingState> ctx(new StreamingState());
|
||||
if (!ctx) {
|
||||
std::cerr << "Could not allocate streaming state." << std::endl;
|
||||
@ -348,6 +352,8 @@ DS_SetupStream(ModelState* aCtx,
|
||||
ctx->mfcc_buffer_.reserve(aCtx->mfcc_feats_per_timestep_);
|
||||
ctx->mfcc_buffer_.resize(aCtx->n_features_*aCtx->n_context_, 0.f);
|
||||
ctx->batch_buffer_.reserve(aCtx->n_steps_ * aCtx->mfcc_feats_per_timestep_);
|
||||
ctx->previous_state_c_.resize(aCtx->state_size_, 0.f);
|
||||
ctx->previous_state_h_.resize(aCtx->state_size_, 0.f);
|
||||
ctx->model_ = aCtx;
|
||||
|
||||
ctx->decoder_state_.reset(decoder_init(*aCtx->alphabet_, num_classes, aCtx->scorer_));
|
||||
|
@ -17,6 +17,7 @@ ModelState::ModelState()
|
||||
, sample_rate_(DEFAULT_SAMPLE_RATE)
|
||||
, audio_win_len_(DEFAULT_WINDOW_LENGTH)
|
||||
, audio_win_step_(DEFAULT_WINDOW_STEP)
|
||||
, state_size_(-1)
|
||||
{
|
||||
}
|
||||
|
||||
|
@ -28,6 +28,7 @@ struct ModelState {
|
||||
unsigned int sample_rate_;
|
||||
unsigned int audio_win_len_;
|
||||
unsigned int audio_win_step_;
|
||||
unsigned int state_size_;
|
||||
|
||||
ModelState();
|
||||
virtual ~ModelState();
|
||||
@ -38,8 +39,6 @@ struct ModelState {
|
||||
const char* alphabet_path,
|
||||
unsigned int beam_width);
|
||||
|
||||
virtual int initialize_state() = 0;
|
||||
|
||||
virtual void compute_mfcc(const std::vector<float>& audio_buffer, std::vector<float>& mfcc_output) = 0;
|
||||
|
||||
/**
|
||||
@ -52,7 +51,13 @@ struct ModelState {
|
||||
*
|
||||
* @param[out] output_logits Where to store computed logits.
|
||||
*/
|
||||
virtual void infer(const float* mfcc, unsigned int n_frames, std::vector<float>& logits_output) = 0;
|
||||
virtual void infer(const std::vector<float>& mfcc,
|
||||
unsigned int n_frames,
|
||||
const std::vector<float>& previous_state_c,
|
||||
const std::vector<float>& previous_state_h,
|
||||
std::vector<float>& logits_output,
|
||||
std::vector<float>& state_c_output,
|
||||
std::vector<float>& state_h_output) = 0;
|
||||
|
||||
/**
|
||||
* @brief Perform decoding of the logits, using basic CTC decoder or
|
||||
|
@ -4,14 +4,13 @@ using namespace tflite;
|
||||
using std::vector;
|
||||
|
||||
int
|
||||
tflite_get_tensor_by_name(const Interpreter* interpreter,
|
||||
const vector<int>& list,
|
||||
TFLiteModelState::get_tensor_by_name(const vector<int>& list,
|
||||
const char* name)
|
||||
{
|
||||
int rv = -1;
|
||||
|
||||
for (int i = 0; i < list.size(); ++i) {
|
||||
const string& node_name = interpreter->tensor(list[i])->name;
|
||||
const string& node_name = interpreter_->tensor(list[i])->name;
|
||||
if (node_name.compare(string(name)) == 0) {
|
||||
rv = i;
|
||||
}
|
||||
@ -22,17 +21,17 @@ tflite_get_tensor_by_name(const Interpreter* interpreter,
|
||||
}
|
||||
|
||||
int
|
||||
tflite_get_input_tensor_by_name(const Interpreter* interpreter, const char* name)
|
||||
TFLiteModelState::get_input_tensor_by_name(const char* name)
|
||||
{
|
||||
int idx = tflite_get_tensor_by_name(interpreter, interpreter->inputs(), name);
|
||||
return interpreter->inputs()[idx];
|
||||
int idx = get_tensor_by_name(interpreter_->inputs(), name);
|
||||
return interpreter_->inputs()[idx];
|
||||
}
|
||||
|
||||
int
|
||||
tflite_get_output_tensor_by_name(const Interpreter* interpreter, const char* name)
|
||||
TFLiteModelState::get_output_tensor_by_name(const char* name)
|
||||
{
|
||||
int idx = tflite_get_tensor_by_name(interpreter, interpreter->outputs(), name);
|
||||
return interpreter->outputs()[idx];
|
||||
int idx = get_tensor_by_name(interpreter_->outputs(), name);
|
||||
return interpreter_->outputs()[idx];
|
||||
}
|
||||
|
||||
void
|
||||
@ -48,8 +47,8 @@ push_back_if_not_present(std::deque<int>& list, int value)
|
||||
// output, add it to the parent list, and add its input tensors to the frontier
|
||||
// list. Because we start from the final tensor and work backwards to the inputs,
|
||||
// the parents list is constructed in reverse, adding elements to its front.
|
||||
std::vector<int>
|
||||
tflite_find_parent_node_ids(Interpreter* interpreter, int tensor_id)
|
||||
vector<int>
|
||||
TFLiteModelState::find_parent_node_ids(int tensor_id)
|
||||
{
|
||||
std::deque<int> parents;
|
||||
std::deque<int> frontier;
|
||||
@ -58,8 +57,8 @@ tflite_find_parent_node_ids(Interpreter* interpreter, int tensor_id)
|
||||
int next_tensor_id = frontier.front();
|
||||
frontier.pop_front();
|
||||
// Find all nodes that have next_tensor_id as an output
|
||||
for (int node_id = 0; node_id < interpreter->nodes_size(); ++node_id) {
|
||||
TfLiteNode node = interpreter->node_and_registration(node_id)->first;
|
||||
for (int node_id = 0; node_id < interpreter_->nodes_size(); ++node_id) {
|
||||
TfLiteNode node = interpreter_->node_and_registration(node_id)->first;
|
||||
// Search node outputs for the tensor we're looking for
|
||||
for (int i = 0; i < node.outputs->size; ++i) {
|
||||
if (node.outputs->data[i] == next_tensor_id) {
|
||||
@ -74,16 +73,13 @@ tflite_find_parent_node_ids(Interpreter* interpreter, int tensor_id)
|
||||
}
|
||||
}
|
||||
|
||||
return std::vector<int>(parents.begin(), parents.end());
|
||||
return vector<int>(parents.begin(), parents.end());
|
||||
}
|
||||
|
||||
TFLiteModelState::TFLiteModelState()
|
||||
: ModelState()
|
||||
, interpreter_(nullptr)
|
||||
, fbmodel_(nullptr)
|
||||
, previous_state_size_(0)
|
||||
, previous_state_c_(nullptr)
|
||||
, previous_state_h_(nullptr)
|
||||
{
|
||||
}
|
||||
|
||||
@ -120,21 +116,21 @@ TFLiteModelState::init(const char* model_path,
|
||||
interpreter_->SetNumThreads(4);
|
||||
|
||||
// Query all the index once
|
||||
input_node_idx_ = tflite_get_input_tensor_by_name(interpreter_.get(), "input_node");
|
||||
previous_state_c_idx_ = tflite_get_input_tensor_by_name(interpreter_.get(), "previous_state_c");
|
||||
previous_state_h_idx_ = tflite_get_input_tensor_by_name(interpreter_.get(), "previous_state_h");
|
||||
input_samples_idx_ = tflite_get_input_tensor_by_name(interpreter_.get(), "input_samples");
|
||||
logits_idx_ = tflite_get_output_tensor_by_name(interpreter_.get(), "logits");
|
||||
new_state_c_idx_ = tflite_get_output_tensor_by_name(interpreter_.get(), "new_state_c");
|
||||
new_state_h_idx_ = tflite_get_output_tensor_by_name(interpreter_.get(), "new_state_h");
|
||||
mfccs_idx_ = tflite_get_output_tensor_by_name(interpreter_.get(), "mfccs");
|
||||
input_node_idx_ = get_input_tensor_by_name("input_node");
|
||||
previous_state_c_idx_ = get_input_tensor_by_name("previous_state_c");
|
||||
previous_state_h_idx_ = get_input_tensor_by_name("previous_state_h");
|
||||
input_samples_idx_ = get_input_tensor_by_name("input_samples");
|
||||
logits_idx_ = get_output_tensor_by_name("logits");
|
||||
new_state_c_idx_ = get_output_tensor_by_name("new_state_c");
|
||||
new_state_h_idx_ = get_output_tensor_by_name("new_state_h");
|
||||
mfccs_idx_ = get_output_tensor_by_name("mfccs");
|
||||
|
||||
// When we call Interpreter::Invoke, the whole graph is executed by default,
|
||||
// which means every time compute_mfcc is called the entire acoustic model is
|
||||
// also executed. To workaround that problem, we walk up the dependency DAG
|
||||
// from the mfccs output tensor to find all the relevant nodes required for
|
||||
// feature computation, building an execution plan that runs just those nodes.
|
||||
auto mfcc_plan = tflite_find_parent_node_ids(interpreter_.get(), mfccs_idx_);
|
||||
auto mfcc_plan = find_parent_node_ids(mfccs_idx_);
|
||||
auto orig_plan = interpreter_->execution_plan();
|
||||
|
||||
// Remove MFCC nodes from original plan (all nodes) to create the acoustic model plan
|
||||
@ -168,50 +164,57 @@ TFLiteModelState::init(const char* model_path,
|
||||
TfLiteIntArray* dims_c = interpreter_->tensor(previous_state_c_idx_)->dims;
|
||||
TfLiteIntArray* dims_h = interpreter_->tensor(previous_state_h_idx_)->dims;
|
||||
assert(dims_c->data[1] == dims_h->data[1]);
|
||||
|
||||
previous_state_size_ = dims_c->data[1];
|
||||
previous_state_c_.reset(new float[previous_state_size_]());
|
||||
previous_state_h_.reset(new float[previous_state_size_]());
|
||||
|
||||
// Set initial values for previous_state_c and previous_state_h
|
||||
memset(previous_state_c_.get(), 0, sizeof(float) * previous_state_size_);
|
||||
memset(previous_state_h_.get(), 0, sizeof(float) * previous_state_size_);
|
||||
|
||||
return DS_ERR_OK;
|
||||
}
|
||||
|
||||
int
|
||||
TFLiteModelState::initialize_state()
|
||||
{
|
||||
/* Ensure previous_state_{c,h} are not holding previous stream value */
|
||||
memset(previous_state_c_.get(), 0, sizeof(float) * previous_state_size_);
|
||||
memset(previous_state_h_.get(), 0, sizeof(float) * previous_state_size_);
|
||||
assert(state_size_ > 0);
|
||||
state_size_ = dims_c->data[1];
|
||||
|
||||
return DS_ERR_OK;
|
||||
}
|
||||
|
||||
void
|
||||
TFLiteModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logits_output)
|
||||
TFLiteModelState::copy_vector_to_tensor(const vector<float>& vec,
|
||||
int tensor_idx,
|
||||
int num_elements)
|
||||
{
|
||||
float* tensor = interpreter_->typed_tensor<float>(tensor_idx);
|
||||
int i;
|
||||
for (i = 0; i < vec.size(); ++i) {
|
||||
tensor[i] = vec[i];
|
||||
}
|
||||
for (; i < num_elements; ++i) {
|
||||
tensor[i] = 0.f;
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
TFLiteModelState::copy_tensor_to_vector(int tensor_idx,
|
||||
int num_elements,
|
||||
vector<float>& vec)
|
||||
{
|
||||
float* tensor = interpreter_->typed_tensor<float>(tensor_idx);
|
||||
for (int i = 0; i < num_elements; ++i) {
|
||||
vec.push_back(tensor[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
TFLiteModelState::infer(const vector<float>& mfcc,
|
||||
unsigned int n_frames,
|
||||
const vector<float>& previous_state_c,
|
||||
const vector<float>& previous_state_h,
|
||||
vector<float>& logits_output,
|
||||
vector<float>& state_c_output,
|
||||
vector<float>& state_h_output)
|
||||
{
|
||||
const size_t num_classes = alphabet_->GetSize() + 1; // +1 for blank
|
||||
|
||||
// Feeding input_node
|
||||
float* input_node = interpreter_->typed_tensor<float>(input_node_idx_);
|
||||
{
|
||||
int i;
|
||||
for (i = 0; i < n_frames*mfcc_feats_per_timestep_; ++i) {
|
||||
input_node[i] = aMfcc[i];
|
||||
}
|
||||
for (; i < n_steps_*mfcc_feats_per_timestep_; ++i) {
|
||||
input_node[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
assert(previous_state_size_ > 0);
|
||||
copy_vector_to_tensor(mfcc, input_node_idx_, n_frames*mfcc_feats_per_timestep_);
|
||||
|
||||
// Feeding previous_state_c, previous_state_h
|
||||
memcpy(interpreter_->typed_tensor<float>(previous_state_c_idx_), previous_state_c_.get(), sizeof(float) * previous_state_size_);
|
||||
memcpy(interpreter_->typed_tensor<float>(previous_state_h_idx_), previous_state_h_.get(), sizeof(float) * previous_state_size_);
|
||||
assert(previous_state_c.size() == state_size_);
|
||||
copy_vector_to_tensor(previous_state_c, previous_state_c_idx_, state_size_);
|
||||
assert(previous_state_h.size() == state_size_);
|
||||
copy_vector_to_tensor(previous_state_h, previous_state_h_idx_, state_size_);
|
||||
|
||||
interpreter_->SetExecutionPlan(acoustic_exec_plan_);
|
||||
TfLiteStatus status = interpreter_->Invoke();
|
||||
@ -220,25 +223,23 @@ TFLiteModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>
|
||||
return;
|
||||
}
|
||||
|
||||
float* outputs = interpreter_->typed_tensor<float>(logits_idx_);
|
||||
copy_tensor_to_vector(logits_idx_, n_frames * BATCH_SIZE * num_classes, logits_output);
|
||||
|
||||
// The CTCDecoder works with log-probs.
|
||||
for (int t = 0; t < n_frames * BATCH_SIZE * num_classes; ++t) {
|
||||
logits_output.push_back(outputs[t]);
|
||||
}
|
||||
state_c_output.clear();
|
||||
state_c_output.reserve(state_size_);
|
||||
copy_tensor_to_vector(new_state_c_idx_, state_size_, state_c_output);
|
||||
|
||||
memcpy(previous_state_c_.get(), interpreter_->typed_tensor<float>(new_state_c_idx_), sizeof(float) * previous_state_size_);
|
||||
memcpy(previous_state_h_.get(), interpreter_->typed_tensor<float>(new_state_h_idx_), sizeof(float) * previous_state_size_);
|
||||
state_h_output.clear();
|
||||
state_h_output.reserve(state_size_);
|
||||
copy_tensor_to_vector(new_state_h_idx_, state_size_, state_h_output);
|
||||
}
|
||||
|
||||
void
|
||||
TFLiteModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_output)
|
||||
TFLiteModelState::compute_mfcc(const vector<float>& samples,
|
||||
vector<float>& mfcc_output)
|
||||
{
|
||||
// Feeding input_node
|
||||
float* input_samples = interpreter_->typed_tensor<float>(input_samples_idx_);
|
||||
for (int i = 0; i < samples.size(); ++i) {
|
||||
input_samples[i] = samples[i];
|
||||
}
|
||||
copy_vector_to_tensor(samples, input_samples_idx_, samples.size());
|
||||
|
||||
TfLiteStatus status = interpreter_->SetExecutionPlan(mfcc_exec_plan_);
|
||||
if (status != kTfLiteOk) {
|
||||
@ -261,8 +262,5 @@ TFLiteModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc
|
||||
}
|
||||
assert(num_elements / n_features_ == n_windows);
|
||||
|
||||
float* outputs = interpreter_->typed_tensor<float>(mfccs_idx_);
|
||||
for (int i = 0; i < n_windows * n_features_; ++i) {
|
||||
mfcc_output.push_back(outputs[i]);
|
||||
}
|
||||
copy_tensor_to_vector(mfccs_idx_, n_windows * n_features_, mfcc_output);
|
||||
}
|
||||
|
@ -14,10 +14,6 @@ struct TFLiteModelState : public ModelState
|
||||
std::unique_ptr<tflite::Interpreter> interpreter_;
|
||||
std::unique_ptr<tflite::FlatBufferModel> fbmodel_;
|
||||
|
||||
size_t previous_state_size_;
|
||||
std::unique_ptr<float[]> previous_state_c_;
|
||||
std::unique_ptr<float[]> previous_state_h_;
|
||||
|
||||
int input_node_idx_;
|
||||
int previous_state_c_idx_;
|
||||
int previous_state_h_idx_;
|
||||
@ -40,13 +36,28 @@ struct TFLiteModelState : public ModelState
|
||||
const char* alphabet_path,
|
||||
unsigned int beam_width) override;
|
||||
|
||||
virtual int initialize_state() override;
|
||||
|
||||
virtual void compute_mfcc(const std::vector<float>& audio_buffer,
|
||||
std::vector<float>& mfcc_output) override;
|
||||
|
||||
virtual void infer(const float* mfcc, unsigned int n_frames,
|
||||
std::vector<float>& logits_output) override;
|
||||
virtual void infer(const std::vector<float>& mfcc,
|
||||
unsigned int n_frames,
|
||||
const std::vector<float>& previous_state_c,
|
||||
const std::vector<float>& previous_state_h,
|
||||
std::vector<float>& logits_output,
|
||||
std::vector<float>& state_c_output,
|
||||
std::vector<float>& state_h_output) override;
|
||||
|
||||
private:
|
||||
int get_tensor_by_name(const std::vector<int>& list, const char* name);
|
||||
int get_input_tensor_by_name(const char* name);
|
||||
int get_output_tensor_by_name(const char* name);
|
||||
std::vector<int> find_parent_node_ids(int tensor_id);
|
||||
void copy_vector_to_tensor(const std::vector<float>& vec,
|
||||
int tensor_idx,
|
||||
int num_elements);
|
||||
void copy_tensor_to_vector(int tensor_idx,
|
||||
int num_elements,
|
||||
std::vector<float>& vec);
|
||||
};
|
||||
|
||||
#endif // TFLITEMODELSTATE_H
|
||||
|
@ -98,6 +98,9 @@ TFModelState::init(const char* model_path,
|
||||
n_context_ = (shape.dim(2).size()-1)/2;
|
||||
n_features_ = shape.dim(3).size();
|
||||
mfcc_feats_per_timestep_ = shape.dim(2).size() * shape.dim(3).size();
|
||||
} else if (node.name() == "previous_state_c") {
|
||||
const auto& shape = node.attr().at("shape").shape();
|
||||
state_size_ = shape.dim(1).size();
|
||||
} else if (node.name() == "logits_shape") {
|
||||
Tensor logits_shape = Tensor(DT_INT32, TensorShape({3}));
|
||||
if (!logits_shape.FromProto(node.attr().at("value").tensor())) {
|
||||
@ -134,66 +137,83 @@ TFModelState::init(const char* model_path,
|
||||
return DS_ERR_OK;
|
||||
}
|
||||
|
||||
int
|
||||
TFModelState::initialize_state()
|
||||
Tensor
|
||||
tensor_from_vector(const std::vector<float>& vec, const TensorShape& shape)
|
||||
{
|
||||
Status status = session_->Run({}, {}, {"initialize_state"}, nullptr);
|
||||
if (!status.ok()) {
|
||||
std::cerr << "Error running session: " << status << std::endl;
|
||||
return DS_ERR_FAIL_RUN_SESS;
|
||||
Tensor ret(DT_FLOAT, shape);
|
||||
auto ret_mapped = ret.flat<float>();
|
||||
int i;
|
||||
for (i = 0; i < vec.size(); ++i) {
|
||||
ret_mapped(i) = vec[i];
|
||||
}
|
||||
|
||||
return DS_ERR_OK;
|
||||
for (; i < shape.num_elements(); ++i) {
|
||||
ret_mapped(i) = 0.f;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
void
|
||||
TFModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logits_output)
|
||||
copy_tensor_to_vector(const Tensor& tensor, vector<float>& vec, int num_elements = -1)
|
||||
{
|
||||
auto tensor_mapped = tensor.flat<float>();
|
||||
if (num_elements == -1) {
|
||||
num_elements = tensor.shape().num_elements();
|
||||
}
|
||||
for (int i = 0; i < num_elements; ++i) {
|
||||
vec.push_back(tensor_mapped(i));
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
TFModelState::infer(const std::vector<float>& mfcc,
|
||||
unsigned int n_frames,
|
||||
const std::vector<float>& previous_state_c,
|
||||
const std::vector<float>& previous_state_h,
|
||||
vector<float>& logits_output,
|
||||
vector<float>& state_c_output,
|
||||
vector<float>& state_h_output)
|
||||
{
|
||||
const size_t num_classes = alphabet_->GetSize() + 1; // +1 for blank
|
||||
|
||||
Tensor input(DT_FLOAT, TensorShape({BATCH_SIZE, n_steps_, 2*n_context_+1, n_features_}));
|
||||
|
||||
auto input_mapped = input.flat<float>();
|
||||
int i;
|
||||
for (i = 0; i < n_frames*mfcc_feats_per_timestep_; ++i) {
|
||||
input_mapped(i) = aMfcc[i];
|
||||
}
|
||||
for (; i < n_steps_*mfcc_feats_per_timestep_; ++i) {
|
||||
input_mapped(i) = 0.;
|
||||
}
|
||||
Tensor input = tensor_from_vector(mfcc, TensorShape({BATCH_SIZE, n_steps_, 2*n_context_+1, n_features_}));
|
||||
Tensor previous_state_c_t = tensor_from_vector(previous_state_c, TensorShape({BATCH_SIZE, (long long)state_size_}));
|
||||
Tensor previous_state_h_t = tensor_from_vector(previous_state_h, TensorShape({BATCH_SIZE, (long long)state_size_}));
|
||||
|
||||
Tensor input_lengths(DT_INT32, TensorShape({1}));
|
||||
input_lengths.scalar<int>()() = n_frames;
|
||||
|
||||
vector<Tensor> outputs;
|
||||
Status status = session_->Run(
|
||||
{{"input_node", input}, {"input_lengths", input_lengths}},
|
||||
{"logits"}, {}, &outputs);
|
||||
{
|
||||
{"input_node", input},
|
||||
{"input_lengths", input_lengths},
|
||||
{"previous_state_c", previous_state_c_t},
|
||||
{"previous_state_h", previous_state_h_t}
|
||||
},
|
||||
{"logits", "new_state_c", "new_state_h"},
|
||||
{},
|
||||
&outputs);
|
||||
|
||||
if (!status.ok()) {
|
||||
std::cerr << "Error running session: " << status << "\n";
|
||||
return;
|
||||
}
|
||||
|
||||
auto logits_mapped = outputs[0].flat<float>();
|
||||
// The CTCDecoder works with log-probs.
|
||||
for (int t = 0; t < n_frames * BATCH_SIZE * num_classes; ++t) {
|
||||
logits_output.push_back(logits_mapped(t));
|
||||
}
|
||||
copy_tensor_to_vector(outputs[0], logits_output, n_frames * BATCH_SIZE * num_classes);
|
||||
|
||||
state_c_output.clear();
|
||||
state_c_output.reserve(state_size_);
|
||||
copy_tensor_to_vector(outputs[1], state_c_output);
|
||||
|
||||
state_h_output.clear();
|
||||
state_h_output.reserve(state_size_);
|
||||
copy_tensor_to_vector(outputs[2], state_h_output);
|
||||
}
|
||||
|
||||
void
|
||||
TFModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_output)
|
||||
{
|
||||
Tensor input(DT_FLOAT, TensorShape({audio_win_len_}));
|
||||
auto input_mapped = input.flat<float>();
|
||||
int i;
|
||||
for (i = 0; i < samples.size(); ++i) {
|
||||
input_mapped(i) = samples[i];
|
||||
}
|
||||
for (; i < audio_win_len_; ++i) {
|
||||
input_mapped(i) = 0.f;
|
||||
}
|
||||
Tensor input = tensor_from_vector(samples, TensorShape({audio_win_len_}));
|
||||
|
||||
vector<Tensor> outputs;
|
||||
Status status = session_->Run({{"input_samples", input}}, {"mfccs"}, {}, &outputs);
|
||||
@ -206,9 +226,5 @@ TFModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_out
|
||||
// The feature computation graph is hardcoded to one audio length for now
|
||||
const int n_windows = 1;
|
||||
assert(outputs[0].shape().num_elements() / n_features_ == n_windows);
|
||||
|
||||
auto mfcc_mapped = outputs[0].flat<float>();
|
||||
for (int i = 0; i < n_windows * n_features_; ++i) {
|
||||
mfcc_output.push_back(mfcc_mapped(i));
|
||||
}
|
||||
copy_tensor_to_vector(outputs[0], mfcc_output);
|
||||
}
|
||||
|
@ -24,11 +24,13 @@ struct TFModelState : public ModelState
|
||||
const char* alphabet_path,
|
||||
unsigned int beam_width) override;
|
||||
|
||||
virtual int initialize_state() override;
|
||||
|
||||
virtual void infer(const float* mfcc,
|
||||
virtual void infer(const std::vector<float>& mfcc,
|
||||
unsigned int n_frames,
|
||||
std::vector<float>& logits_output) override;
|
||||
const std::vector<float>& previous_state_c,
|
||||
const std::vector<float>& previous_state_h,
|
||||
std::vector<float>& logits_output,
|
||||
std::vector<float>& state_c_output,
|
||||
std::vector<float>& state_h_output) override;
|
||||
|
||||
virtual void compute_mfcc(const std::vector<float>& audio_buffer,
|
||||
std::vector<float>& mfcc_output) override;
|
||||
|
Loading…
Reference in New Issue
Block a user