Merge pull request #2146 from mozilla/refactor-model-impls
Refactor TF and TFLite implementations into their own classes/files and fix concurrent/interleaved stream bugs by tracking LSTM state in StreamingState
Commit a2306cf822
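The user-visible bug this PR fixes is that two interleaved streams shared a single LSTM state stored in `ModelState`, so feeding them alternately corrupted both transcriptions. The following is a minimal sketch of that usage pattern with the Python bindings, modeled on the `native_client/test/concurrent_streams.py` test added in this PR; the model/alphabet paths, audio files, and the `read_wav` helper are placeholders.

```python
# Sketch of interleaved streaming with the deepspeech Python bindings.
# Mirrors the concurrent_streams.py test added by this PR; paths are placeholders.
import numpy as np
import wave

from deepspeech import Model

def read_wav(path):
    # Assumes 16-bit mono PCM at the model's expected sample rate.
    with wave.open(path, 'rb') as fin:
        return np.frombuffer(fin.readframes(fin.getnframes()), np.int16)

ds = Model('output_graph.pbmm', 26, 9, 'alphabet.txt', 500)

audio1 = read_wav('first.wav')
audio2 = read_wav('second.wav')

# Two streams fed in alternation. With the LSTM state now held per-stream in
# StreamingState, the second stream no longer clobbers the first one's state.
stream1 = ds.setupStream(sample_rate=16000)
stream2 = ds.setupStream(sample_rate=16000)
for part1, part2 in zip(np.array_split(audio1, 10), np.array_split(audio2, 10)):
    ds.feedAudioContent(stream1, part1)
    ds.feedAudioContent(stream2, part2)
print(ds.finishStream(stream1))
print(ds.finishStream(stream2))
```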
@@ -574,12 +574,8 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
# no state management since n_step is expected to be dynamic too (see below)
previous_state = previous_state_c = previous_state_h = None
else:
if tflite:
previous_state_c = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
previous_state_h = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
else:
previous_state_c = variable_on_cpu('previous_state_c', [batch_size, Config.n_cell_dim], initializer=None)
previous_state_h = variable_on_cpu('previous_state_h', [batch_size, Config.n_cell_dim], initializer=None)
previous_state_c = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
previous_state_h = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')

previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h)

@@ -592,7 +588,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
rnn_impl = rnn_impl_lstmblockfusedcell

logits, layers = create_model(batch_x=input_tensor,
seq_length=seq_length if FLAGS.use_seq_length else None,
seq_length=seq_length if not FLAGS.export_tflite else None,
dropout=no_dropout,
previous_state=previous_state,
overlap=False,
@@ -605,7 +601,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
logits = tf.squeeze(logits, [1])

# Apply softmax for CTC decoder
logits = tf.nn.softmax(logits)
logits = tf.nn.softmax(logits, name='logits')

if batch_size <= 0:
if tflite:
@@ -618,51 +614,31 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
'input_lengths': seq_length,
},
{
'outputs': tf.identity(logits, name='logits'),
'outputs': logits,
},
layers
)

new_state_c, new_state_h = layers['rnn_output_state']
if tflite:
logits = tf.identity(logits, name='logits')
new_state_c = tf.identity(new_state_c, name='new_state_c')
new_state_h = tf.identity(new_state_h, name='new_state_h')
new_state_c = tf.identity(new_state_c, name='new_state_c')
new_state_h = tf.identity(new_state_h, name='new_state_h')

inputs = {
'input': input_tensor,
'previous_state_c': previous_state_c,
'previous_state_h': previous_state_h,
'input_samples': input_samples,
}
inputs = {
'input': input_tensor,
'previous_state_c': previous_state_c,
'previous_state_h': previous_state_h,
'input_samples': input_samples,
}

if FLAGS.use_seq_length:
inputs.update({'input_lengths': seq_length})
if not FLAGS.export_tflite:
inputs.update({'input_lengths': seq_length})

outputs = {
'outputs': logits,
'new_state_c': new_state_c,
'new_state_h': new_state_h,
'mfccs': mfccs,
}
else:
zero_state = tf.zeros([batch_size, Config.n_cell_dim], tf.float32)
initialize_c = tf.assign(previous_state_c, zero_state)
initialize_h = tf.assign(previous_state_h, zero_state)
initialize_state = tf.group(initialize_c, initialize_h, name='initialize_state')
with tf.control_dependencies([tf.assign(previous_state_c, new_state_c), tf.assign(previous_state_h, new_state_h)]):
logits = tf.identity(logits, name='logits')

inputs = {
'input': input_tensor,
'input_lengths': seq_length,
'input_samples': input_samples,
}
outputs = {
'outputs': logits,
'initialize_state': initialize_state,
'mfccs': mfccs,
}
outputs = {
'outputs': logits,
'new_state_c': new_state_c,
'new_state_h': new_state_h,
'mfccs': mfccs,
}

return inputs, outputs, layers

@@ -682,10 +658,12 @@ def export():
output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
output_names = ",".join(output_names_tensors + output_names_ops)

if not FLAGS.export_tflite:
mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
else:
mapping = None
if FLAGS.export_tflite:
# Create a saver using variables from the above newly created graph
# Training graph uses LSTMFusedCell, but the TFLite inference graph uses
# a static RNN with a normal cell, so we need to rewrite the names to
# match the training weights when restoring.
def fixup(name):
if name.startswith('rnn/lstm_cell/'):
return name.replace('rnn/lstm_cell/', 'lstm_fused_cell/')
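As the comment above explains, the TFLite inference graph names its LSTM variables differently from the training graph, so the checkpoint names have to be rewritten before restoring. A hedged TF1-style sketch of how such a remapping dict is typically handed to `tf.train.Saver` (which accepts a dict of checkpoint-name to variable) is shown below; it reuses the `fixup()` helper from the hunk above and is illustrative rather than the exact code in this commit.

```python
# Map each variable to its name in the training checkpoint, then build a Saver
# from that mapping so restore() finds the LSTMBlockFusedCell weights.
mapping = {fixup(v.op.name): v for v in tf.global_variables()}
saver = tf.train.Saver(mapping)
```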
@@ -710,7 +688,7 @@ def export():
if not os.path.isdir(FLAGS.export_dir):
os.makedirs(FLAGS.export_dir)

def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=''):
frozen = freeze_graph.freeze_graph_with_def_protos(
input_graph_def=tf.get_default_graph().as_graph_def(),
input_saver_def=saver.as_saver_def(),
@@ -731,7 +709,7 @@ def export():
placeholder_type_enum=tf.float32.as_datatype_enum)

if not FLAGS.export_tflite:
frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
frozen_graph = do_graph_freeze(output_node_names=output_names)
frozen_graph.version = int(file_relative_read('GRAPH_VERSION').strip())

# Add a no-op node to the graph with metadata information to be loaded by the native client
@@ -747,7 +725,7 @@ def export():
with open(output_graph_path, 'wb') as fout:
fout.write(frozen_graph.SerializeToString())
else:
frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='')
frozen_graph = do_graph_freeze(output_node_names=output_names)
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))

converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
@@ -771,8 +749,7 @@ def do_single_file_inference(input_file_path):
inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)

# Create a saver using variables from the above newly created graph
mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
saver = tf.train.Saver(mapping)
saver = tf.train.Saver()

# Restore variables from training checkpoint
# TODO: This restores the most recent checkpoint, but if we use validation to counteract
@@ -784,9 +761,10 @@ def do_single_file_inference(input_file_path):

checkpoint_path = checkpoint.model_checkpoint_path
saver.restore(session, checkpoint_path)
session.run(outputs['initialize_state'])

features, features_len = audiofile_to_features(input_file_path)
previous_state_c = np.zeros([1, Config.n_cell_dim])
previous_state_h = np.zeros([1, Config.n_cell_dim])

# Add batch dimension
features = tf.expand_dims(features, 0)
@@ -799,6 +777,8 @@ def do_single_file_inference(input_file_path):
logits = outputs['outputs'].eval(feed_dict={
inputs['input']: features,
inputs['input_lengths']: features_len,
inputs['previous_state_c']: previous_state_c,
inputs['previous_state_h']: previous_state_h,
}, session=session)

logits = np.squeeze(logits)
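With the `initialize_state` op gone, the exported graph is stateless: the client feeds `previous_state_c`/`previous_state_h` in and reads `new_state_c`/`new_state_h` back out, as the inputs/outputs dicts above show. A hedged sketch of how a caller would drive this graph chunk by chunk is below; the `feature_chunks` iterator and the session setup are hypothetical, while the tensor keys match those built in `create_inference_graph()`.

```python
# Illustrative loop over feature windows, feeding the LSTM state back in.
previous_state_c = np.zeros([1, Config.n_cell_dim])
previous_state_h = np.zeros([1, Config.n_cell_dim])

all_logits = []
for chunk, chunk_len in feature_chunks:  # hypothetical iterator over feature windows
    logits, previous_state_c, previous_state_h = session.run(
        [outputs['outputs'], outputs['new_state_c'], outputs['new_state_h']],
        feed_dict={
            inputs['input']: chunk,
            inputs['input_lengths']: chunk_len,
            inputs['previous_state_c']: previous_state_c,
            inputs['previous_state_h']: previous_state_h,
        })
    all_logits.append(logits)
```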
@@ -1 +1 @@
1
2
@@ -343,7 +343,7 @@ Refer to the corresponding [README.md](native_client/README.md) for information

### Exporting a model for TFLite

If you want to experiment with the TF Lite engine, you need to export a model that is compatible with it, then use the `--nouse_seq_length --export_tflite` flags. If you already have a trained model, you can re-export it for TFLite by running `DeepSpeech.py` again and specifying the same `checkpoint_dir` that you used for training, as well as passing `--nouse_seq_length --export_tflite --export_dir /model/export/destination`.
If you want to experiment with the TF Lite engine, you need to export a model that is compatible with it, then use the `--export_tflite` flag. If you already have a trained model, you can re-export it for TFLite by running `DeepSpeech.py` again and specifying the same `checkpoint_dir` that you used for training, as well as passing `--export_tflite --export_dir /model/export/destination`.

### Making a mmap-able model for inference

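For a quick sanity check of the exported model, the TFLite-built Python bindings load it through the same `Model` API used for protobuf models. The snippet below is an assumption-laden sketch: the exported file name, alphabet path, and the feature/beam constants must match your own export and training configuration.

```python
# Hypothetical check that an exported TFLite model loads with the Python bindings
# built against the TFLite backend; adjust paths and constants to your setup.
from deepspeech import Model

N_FEATURES = 26
N_CONTEXT = 9
BEAM_WIDTH = 500

ds = Model('/model/export/destination/output_graph.tflite',
           N_FEATURES, N_CONTEXT, 'alphabet.txt', BEAM_WIDTH)
```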
@@ -20,4 +20,4 @@ python -u DeepSpeech.py --noshow_progressbar \
--export_dir '/tmp/train_tflite' \
--lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
--lm_trie_path 'data/smoke_test/vocab.trie' \
--export_tflite --nouse_seq_length
--export_tflite
BIN  data/smoke_test/new-home-in-the-stars-16k.wav  (new file, binary file not shown)
@@ -70,9 +70,20 @@ tf_cc_shared_object(
srcs = ["deepspeech.cc",
"deepspeech.h",
"alphabet.h",
"modelstate.h",
"modelstate.cc",
"ds_version.h",
"ds_graph_version.h"] +
DECODER_SOURCES,
DECODER_SOURCES +
select({
"//native_client:tflite": [
"tflitemodelstate.h",
"tflitemodelstate.cc"
],
"//conditions:default": [
"tfmodelstate.h",
"tfmodelstate.cc"
]}),
copts = select({
# -fvisibility=hidden is not required on Windows, MSVC hides all declarations by default
"//tensorflow:windows": ["/w"],
@@ -103,34 +114,26 @@ tf_cc_shared_object(
### => Trying to be more fine-grained
### Use bin/ops_in_graph.py to list all the ops used by a frozen graph.
### CPU only build, libdeepspeech.so file size reduced by ~50%
"//tensorflow/core/kernels:dense_update_ops", # Assign
"//tensorflow/core/kernels:constant_op", # Const
"//tensorflow/core/kernels:immutable_constant_op", # ImmutableConst
"//tensorflow/core/kernels:dense_update_ops", # Assign (remove once prod model no longer depends on it)
"//tensorflow/core/kernels:constant_op", # Placeholder
"//tensorflow/core/kernels:immutable_constant_op", # ImmutableConst (used in memmapped models)
"//tensorflow/core/kernels:identity_op", # Identity
"//tensorflow/core/kernels:softmax_op", # Softmax
"//tensorflow/core/kernels:transpose_op", # Transpose
"//tensorflow/core/kernels:reshape_op", # Reshape
"//tensorflow/core/kernels:shape_ops", # Shape
"//tensorflow/core/kernels:concat_op", # ConcatV2
"//tensorflow/core/kernels:split_op", # Split
"//tensorflow/core/kernels:variable_ops", # VariableV2
"//tensorflow/core/kernels:relu_op", # Relu
"//tensorflow/core/kernels:bias_op", # BiasAdd
"//tensorflow/core/kernels:math", # Range, MatMul
"//tensorflow/core/kernels:control_flow_ops", # Enter
"//tensorflow/core/kernels:tile_ops", # Tile
"//tensorflow/core/kernels:gather_op", # Gather
"//tensorflow/core/kernels:mfcc_op", # Mfcc
"//tensorflow/core/kernels:spectrogram_op", # AudioSpectrogram
"//tensorflow/core/kernels:strided_slice_op", # StridedSlice
"//tensorflow/core/kernels:slice_op", # Slice, needed by StridedSlice
"//tensorflow/contrib/rnn:lstm_ops_kernels", # BlockLSTM
"//tensorflow/core/kernels:random_ops", # RandomGammaGrad
"//tensorflow/core/kernels:pack_op", # Pack
"//tensorflow/core/kernels:gather_nd_op", # GatherNd
#### Needed by production model produced without "--use_seq_length False"
#"//tensorflow/core/kernels:logging_ops", # Assert
#"//tensorflow/core/kernels:reverse_sequence_op", # ReverseSequence
],
}) + if_cuda([
"//tensorflow/core:core",

@@ -11,17 +11,14 @@

#include "deepspeech.h"
#include "alphabet.h"
#include "modelstate.h"

#include "native_client/ds_version.h"
#include "native_client/ds_graph_version.h"

#ifndef USE_TFLITE
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/memmapped_file_system.h"
#else // USE_TFLITE
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/kernels/register.h"
#include "tfmodelstate.h"
#else
#include "tflitemodelstate.h"
#endif // USE_TFLITE

#include "ctcdecode/ctc_beam_search_decoder.h"
@@ -36,23 +33,9 @@
#define LOGE(...)
#endif // __ANDROID__

//TODO: infer batch size from model/use dynamic batch size
constexpr unsigned int BATCH_SIZE = 1;

constexpr unsigned int DEFAULT_SAMPLE_RATE = 16000;
constexpr unsigned int DEFAULT_WINDOW_LENGTH = DEFAULT_SAMPLE_RATE * 0.032;
constexpr unsigned int DEFAULT_WINDOW_STEP = DEFAULT_SAMPLE_RATE * 0.02;

#ifndef USE_TFLITE
using namespace tensorflow;
#else
using namespace tflite;
#endif

using std::vector;

/* This is the actual implementation of the streaming inference API, with the
   Model class just forwarding the calls to this class.
/* This is the implementation of the streaming inference API.

   The streaming process uses three buffers that are fed eagerly as audio data
   is fed in. The buffers only hold the minimum amount of data needed to do a
@@ -75,17 +58,23 @@ using std::vector;
   API. When audio_buffer is full, features are computed from it and pushed to
   mfcc_buffer. When mfcc_buffer is full, the timestep is copied to batch_buffer.
   When batch_buffer is full, we do a single step through the acoustic model
   and accumulate results in the DecoderState structure.
   and accumulate the intermediate decoding state in the DecoderState structure.

   When finishStream() is called, we decode the accumulated logits and return
   the corresponding transcription.
   When finishStream() is called, we return the corresponding transcription from
   the current decoder state.
*/
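To make the three-buffer flow described in the comment above concrete, here is a short illustrative sketch in Python (not part of this commit); the window sizes and the `compute_mfcc`/`run_acoustic_model` calls are placeholders for the real model.

```python
# Toy model of the audio -> MFCC -> batch buffering pipeline described above.
def feed_audio(state, samples):
    for s in samples:
        state.audio_buffer.append(s / 32768.0)           # i16 sample to f32
        if len(state.audio_buffer) == state.audio_win_len:
            mfcc = compute_mfcc(state.audio_buffer)       # one feature vector
            push_mfcc(state, mfcc)
            del state.audio_buffer[:state.audio_win_step] # slide the audio window

def push_mfcc(state, mfcc):
    state.mfcc_buffer.extend(mfcc)
    if len(state.mfcc_buffer) == state.mfcc_feats_per_timestep:
        state.batch_buffer.extend(state.mfcc_buffer)      # one full context window
        del state.mfcc_buffer[:state.n_features]          # slide by one timestep
        if len(state.batch_buffer) == state.n_steps * state.mfcc_feats_per_timestep:
            run_acoustic_model(state, state.batch_buffer) # one step through the model
            state.batch_buffer.clear()
```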
struct StreamingState {
vector<float> audio_buffer;
vector<float> mfcc_buffer;
vector<float> batch_buffer;
ModelState* model;
std::unique_ptr<DecoderState> decoder_state;
vector<float> audio_buffer_;
vector<float> mfcc_buffer_;
vector<float> batch_buffer_;
vector<float> previous_state_c_;
vector<float> previous_state_h_;

ModelState* model_;
std::unique_ptr<DecoderState> decoder_state_;

StreamingState();
~StreamingState();

void feedAudioContent(const short* buffer, unsigned int buffer_size);
char* intermediateDecode();
@@ -100,133 +89,12 @@ struct StreamingState {
void processBatch(const vector<float>& buf, unsigned int n_steps);
};

struct ModelState {
|
||||
#ifndef USE_TFLITE
|
||||
MemmappedEnv* mmap_env;
|
||||
Session* session;
|
||||
GraphDef graph_def;
|
||||
#else // USE_TFLITE
|
||||
std::unique_ptr<Interpreter> interpreter;
|
||||
std::unique_ptr<FlatBufferModel> fbmodel;
|
||||
#endif // USE_TFLITE
|
||||
unsigned int ncep;
|
||||
unsigned int ncontext;
|
||||
Alphabet* alphabet;
|
||||
Scorer* scorer;
|
||||
unsigned int beam_width;
|
||||
unsigned int n_steps;
|
||||
unsigned int n_context;
|
||||
unsigned int n_features;
|
||||
unsigned int mfcc_feats_per_timestep;
|
||||
unsigned int sample_rate;
|
||||
unsigned int audio_win_len;
|
||||
unsigned int audio_win_step;
|
||||
|
||||
#ifdef USE_TFLITE
|
||||
size_t previous_state_size;
|
||||
std::unique_ptr<float[]> previous_state_c_;
|
||||
std::unique_ptr<float[]> previous_state_h_;
|
||||
|
||||
int input_node_idx;
|
||||
int previous_state_c_idx;
|
||||
int previous_state_h_idx;
|
||||
int input_samples_idx;
|
||||
|
||||
int logits_idx;
|
||||
int new_state_c_idx;
|
||||
int new_state_h_idx;
|
||||
int mfccs_idx;
|
||||
|
||||
std::vector<int> acoustic_exec_plan;
|
||||
std::vector<int> mfcc_exec_plan;
|
||||
#endif
|
||||
|
||||
ModelState();
|
||||
~ModelState();
|
||||
|
||||
/**
|
||||
* @brief Perform decoding of the logits, using basic CTC decoder or
|
||||
* CTC decoder with KenLM enabled
|
||||
*
|
||||
* @return String representing the decoded text.
|
||||
*/
|
||||
char* decode(DecoderState* state);
|
||||
|
||||
/**
|
||||
* @brief Perform decoding of the logits, using basic CTC decoder or
|
||||
* CTC decoder with KenLM enabled
|
||||
*
|
||||
* @return Vector of Output structs directly from the CTC decoder for additional processing.
|
||||
*/
|
||||
vector<Output> decode_raw(DecoderState* state);
|
||||
|
||||
/**
|
||||
* @brief Return character-level metadata including letter timings.
|
||||
*
|
||||
*
|
||||
* @return Metadata struct containing MetadataItem structs for each character.
|
||||
* The user is responsible for freeing Metadata by calling DS_FreeMetadata().
|
||||
*/
|
||||
Metadata* decode_metadata(DecoderState* state);
|
||||
|
||||
/**
|
||||
* @brief Do a single inference step in the acoustic model, with:
|
||||
* input=mfcc
|
||||
* input_lengths=[n_frames]
|
||||
*
|
||||
* @param mfcc batch input data
|
||||
* @param n_frames number of timesteps in the data
|
||||
*
|
||||
* @param[out] output_logits Where to store computed logits.
|
||||
*/
|
||||
void infer(const float* mfcc, unsigned int n_frames, vector<float>& logits_output);
|
||||
|
||||
void compute_mfcc(const vector<float>& audio_buffer, vector<float>& mfcc_output);
|
||||
};
|
||||
|
||||
ModelState::ModelState()
|
||||
:
|
||||
#ifndef USE_TFLITE
|
||||
mmap_env(nullptr)
|
||||
, session(nullptr)
|
||||
#else // USE_TFLITE
|
||||
interpreter(nullptr)
|
||||
, fbmodel(nullptr)
|
||||
#endif // USE_TFLITE
|
||||
, ncep(0)
|
||||
, ncontext(0)
|
||||
, alphabet(nullptr)
|
||||
, scorer(nullptr)
|
||||
, beam_width(0)
|
||||
, n_steps(-1)
|
||||
, n_context(-1)
|
||||
, n_features(-1)
|
||||
, mfcc_feats_per_timestep(-1)
|
||||
, sample_rate(DEFAULT_SAMPLE_RATE)
|
||||
, audio_win_len(DEFAULT_WINDOW_LENGTH)
|
||||
, audio_win_step(DEFAULT_WINDOW_STEP)
|
||||
#ifdef USE_TFLITE
|
||||
, previous_state_size(0)
|
||||
, previous_state_c_(nullptr)
|
||||
, previous_state_h_(nullptr)
|
||||
#endif
|
||||
StreamingState::StreamingState()
|
||||
{
|
||||
}
|
||||
|
||||
ModelState::~ModelState()
|
||||
StreamingState::~StreamingState()
|
||||
{
|
||||
#ifndef USE_TFLITE
|
||||
if (session) {
|
||||
Status status = session->Close();
|
||||
if (!status.ok()) {
|
||||
std::cerr << "Error closing TensorFlow session: " << status << std::endl;
|
||||
}
|
||||
}
|
||||
delete mmap_env;
|
||||
#endif // USE_TFLITE
|
||||
|
||||
delete scorer;
|
||||
delete alphabet;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
@ -243,19 +111,19 @@ StreamingState::feedAudioContent(const short* buffer,
|
||||
{
|
||||
// Consume all the data that was passed in, processing full buffers if needed
|
||||
while (buffer_size > 0) {
|
||||
while (buffer_size > 0 && audio_buffer.size() < model->audio_win_len) {
|
||||
while (buffer_size > 0 && audio_buffer_.size() < model_->audio_win_len_) {
|
||||
// Convert i16 sample into f32
|
||||
float multiplier = 1.0f / (1 << 15);
|
||||
audio_buffer.push_back((float)(*buffer) * multiplier);
|
||||
audio_buffer_.push_back((float)(*buffer) * multiplier);
|
||||
++buffer;
|
||||
--buffer_size;
|
||||
}
|
||||
|
||||
// If the buffer is full, process and shift it
|
||||
if (audio_buffer.size() == model->audio_win_len) {
|
||||
processAudioWindow(audio_buffer);
|
||||
if (audio_buffer_.size() == model_->audio_win_len_) {
|
||||
processAudioWindow(audio_buffer_);
|
||||
// Shift data by one step
|
||||
shift_buffer_left(audio_buffer, model->audio_win_step);
|
||||
shift_buffer_left(audio_buffer_, model_->audio_win_step_);
|
||||
}
|
||||
|
||||
// Repeat until buffer empty
|
||||
@ -265,21 +133,21 @@ StreamingState::feedAudioContent(const short* buffer,
|
||||
char*
|
||||
StreamingState::intermediateDecode()
|
||||
{
|
||||
return model->decode(decoder_state.get());
|
||||
return model_->decode(decoder_state_.get());
|
||||
}
|
||||
|
||||
char*
|
||||
StreamingState::finishStream()
|
||||
{
|
||||
finalizeStream();
|
||||
return model->decode(decoder_state.get());
|
||||
return model_->decode(decoder_state_.get());
|
||||
}
|
||||
|
||||
Metadata*
|
||||
StreamingState::finishStreamWithMetadata()
|
||||
{
|
||||
finalizeStream();
|
||||
return model->decode_metadata(decoder_state.get());
|
||||
return model_->decode_metadata(decoder_state_.get());
|
||||
}
|
||||
|
||||
void
|
||||
@ -287,8 +155,8 @@ StreamingState::processAudioWindow(const vector<float>& buf)
|
||||
{
|
||||
// Compute MFCC features
|
||||
vector<float> mfcc;
|
||||
mfcc.reserve(model->n_features);
|
||||
model->compute_mfcc(buf, mfcc);
|
||||
mfcc.reserve(model_->n_features_);
|
||||
model_->compute_mfcc(buf, mfcc);
|
||||
pushMfccBuffer(mfcc);
|
||||
}
|
||||
|
||||
@ -296,23 +164,23 @@ void
|
||||
StreamingState::finalizeStream()
|
||||
{
|
||||
// Flush audio buffer
|
||||
processAudioWindow(audio_buffer);
|
||||
processAudioWindow(audio_buffer_);
|
||||
|
||||
// Add empty mfcc vectors at end of sample
|
||||
for (int i = 0; i < model->n_context; ++i) {
|
||||
for (int i = 0; i < model_->n_context_; ++i) {
|
||||
addZeroMfccWindow();
|
||||
}
|
||||
|
||||
// Process final batch
|
||||
if (batch_buffer.size() > 0) {
|
||||
processBatch(batch_buffer, batch_buffer.size()/model->mfcc_feats_per_timestep);
|
||||
if (batch_buffer_.size() > 0) {
|
||||
processBatch(batch_buffer_, batch_buffer_.size()/model_->mfcc_feats_per_timestep_);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
StreamingState::addZeroMfccWindow()
|
||||
{
|
||||
vector<float> zero_buffer(model->n_features, 0.f);
|
||||
vector<float> zero_buffer(model_->n_features_, 0.f);
|
||||
pushMfccBuffer(zero_buffer);
|
||||
}
|
||||
|
||||
@ -332,15 +200,15 @@ StreamingState::pushMfccBuffer(const vector<float>& buf)
|
||||
auto end = buf.end();
|
||||
while (start != end) {
|
||||
// Copy from input buffer to mfcc_buffer, stopping if we have a full context window
|
||||
start = copy_up_to_n(start, end, std::back_inserter(mfcc_buffer),
|
||||
model->mfcc_feats_per_timestep - mfcc_buffer.size());
|
||||
assert(mfcc_buffer.size() <= model->mfcc_feats_per_timestep);
|
||||
start = copy_up_to_n(start, end, std::back_inserter(mfcc_buffer_),
|
||||
model_->mfcc_feats_per_timestep_ - mfcc_buffer_.size());
|
||||
assert(mfcc_buffer_.size() <= model_->mfcc_feats_per_timestep_);
|
||||
|
||||
// If we have a full context window
|
||||
if (mfcc_buffer.size() == model->mfcc_feats_per_timestep) {
|
||||
processMfccWindow(mfcc_buffer);
|
||||
if (mfcc_buffer_.size() == model_->mfcc_feats_per_timestep_) {
|
||||
processMfccWindow(mfcc_buffer_);
|
||||
// Shift data by one step of one mfcc feature vector
|
||||
shift_buffer_left(mfcc_buffer, model->n_features);
|
||||
shift_buffer_left(mfcc_buffer_, model_->n_features_);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -352,14 +220,14 @@ StreamingState::processMfccWindow(const vector<float>& buf)
|
||||
auto end = buf.end();
|
||||
while (start != end) {
|
||||
// Copy from input buffer to batch_buffer, stopping if we have a full batch
|
||||
start = copy_up_to_n(start, end, std::back_inserter(batch_buffer),
|
||||
model->n_steps * model->mfcc_feats_per_timestep - batch_buffer.size());
|
||||
assert(batch_buffer.size() <= model->n_steps * model->mfcc_feats_per_timestep);
|
||||
start = copy_up_to_n(start, end, std::back_inserter(batch_buffer_),
|
||||
model_->n_steps_ * model_->mfcc_feats_per_timestep_ - batch_buffer_.size());
|
||||
assert(batch_buffer_.size() <= model_->n_steps_ * model_->mfcc_feats_per_timestep_);
|
||||
|
||||
// If we have a full batch
|
||||
if (batch_buffer.size() == model->n_steps * model->mfcc_feats_per_timestep) {
|
||||
processBatch(batch_buffer, model->n_steps);
|
||||
batch_buffer.resize(0);
|
||||
if (batch_buffer_.size() == model_->n_steps_ * model_->mfcc_feats_per_timestep_) {
|
||||
processBatch(batch_buffer_, model_->n_steps_);
|
||||
batch_buffer_.resize(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -368,272 +236,33 @@ void
StreamingState::processBatch(const vector<float>& buf, unsigned int n_steps)
{
vector<float> logits;
model->infer(buf.data(), n_steps, logits);

model_->infer(buf,
              n_steps,
              previous_state_c_,
              previous_state_h_,
              logits,
              previous_state_c_,
              previous_state_h_);

const int cutoff_top_n = 40;
const double cutoff_prob = 1.0;
const size_t num_classes = model->alphabet->GetSize() + 1; // +1 for blank
const int n_frames = logits.size() / (BATCH_SIZE * num_classes);
const size_t num_classes = model_->alphabet_->GetSize() + 1; // +1 for blank
const int n_frames = logits.size() / (ModelState::BATCH_SIZE * num_classes);

// Convert logits to double
vector<double> inputs(logits.begin(), logits.end());

decoder_next(inputs.data(),
             *model->alphabet,
             decoder_state.get(),
             *model_->alphabet_,
             decoder_state_.get(),
             n_frames,
             num_classes,
             cutoff_prob,
             cutoff_top_n,
             model->beam_width,
             model->scorer);
             model_->beam_width_,
             model_->scorer_);
}
|
||||
void
|
||||
ModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logits_output)
|
||||
{
|
||||
const size_t num_classes = alphabet->GetSize() + 1; // +1 for blank
|
||||
|
||||
#ifndef USE_TFLITE
|
||||
Tensor input(DT_FLOAT, TensorShape({BATCH_SIZE, n_steps, 2*n_context+1, n_features}));
|
||||
|
||||
auto input_mapped = input.flat<float>();
|
||||
int i;
|
||||
for (i = 0; i < n_frames*mfcc_feats_per_timestep; ++i) {
|
||||
input_mapped(i) = aMfcc[i];
|
||||
}
|
||||
for (; i < n_steps*mfcc_feats_per_timestep; ++i) {
|
||||
input_mapped(i) = 0.;
|
||||
}
|
||||
|
||||
Tensor input_lengths(DT_INT32, TensorShape({1}));
|
||||
input_lengths.scalar<int>()() = n_frames;
|
||||
|
||||
vector<Tensor> outputs;
|
||||
Status status = session->Run(
|
||||
{{"input_node", input}, {"input_lengths", input_lengths}},
|
||||
{"logits"}, {}, &outputs);
|
||||
|
||||
if (!status.ok()) {
|
||||
std::cerr << "Error running session: " << status << "\n";
|
||||
return;
|
||||
}
|
||||
|
||||
auto logits_mapped = outputs[0].flat<float>();
|
||||
// The CTCDecoder works with log-probs.
|
||||
for (int t = 0; t < n_frames * BATCH_SIZE * num_classes; ++t) {
|
||||
logits_output.push_back(logits_mapped(t));
|
||||
}
|
||||
#else // USE_TFLITE
|
||||
// Feeding input_node
|
||||
float* input_node = interpreter->typed_tensor<float>(input_node_idx);
|
||||
{
|
||||
int i;
|
||||
for (i = 0; i < n_frames*mfcc_feats_per_timestep; ++i) {
|
||||
input_node[i] = aMfcc[i];
|
||||
}
|
||||
for (; i < n_steps*mfcc_feats_per_timestep; ++i) {
|
||||
input_node[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
assert(previous_state_size > 0);
|
||||
|
||||
// Feeding previous_state_c, previous_state_h
|
||||
memcpy(interpreter->typed_tensor<float>(previous_state_c_idx), previous_state_c_.get(), sizeof(float) * previous_state_size);
|
||||
memcpy(interpreter->typed_tensor<float>(previous_state_h_idx), previous_state_h_.get(), sizeof(float) * previous_state_size);
|
||||
|
||||
interpreter->SetExecutionPlan(acoustic_exec_plan);
|
||||
TfLiteStatus status = interpreter->Invoke();
|
||||
if (status != kTfLiteOk) {
|
||||
std::cerr << "Error running session: " << status << "\n";
|
||||
return;
|
||||
}
|
||||
|
||||
float* outputs = interpreter->typed_tensor<float>(logits_idx);
|
||||
|
||||
// The CTCDecoder works with log-probs.
|
||||
for (int t = 0; t < n_frames * BATCH_SIZE * num_classes; ++t) {
|
||||
logits_output.push_back(outputs[t]);
|
||||
}
|
||||
|
||||
memcpy(previous_state_c_.get(), interpreter->typed_tensor<float>(new_state_c_idx), sizeof(float) * previous_state_size);
|
||||
memcpy(previous_state_h_.get(), interpreter->typed_tensor<float>(new_state_h_idx), sizeof(float) * previous_state_size);
|
||||
#endif // USE_TFLITE
|
||||
}
|
||||
|
||||
void
|
||||
ModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_output)
|
||||
{
|
||||
#ifndef USE_TFLITE
|
||||
Tensor input(DT_FLOAT, TensorShape({audio_win_len}));
|
||||
auto input_mapped = input.flat<float>();
|
||||
int i;
|
||||
for (i = 0; i < samples.size(); ++i) {
|
||||
input_mapped(i) = samples[i];
|
||||
}
|
||||
for (; i < audio_win_len; ++i) {
|
||||
input_mapped(i) = 0.f;
|
||||
}
|
||||
|
||||
vector<Tensor> outputs;
|
||||
Status status = session->Run({{"input_samples", input}}, {"mfccs"}, {}, &outputs);
|
||||
|
||||
if (!status.ok()) {
|
||||
std::cerr << "Error running session: " << status << "\n";
|
||||
return;
|
||||
}
|
||||
|
||||
// The feature computation graph is hardcoded to one audio length for now
|
||||
const int n_windows = 1;
|
||||
assert(outputs[0].shape().num_elemements() / n_features == n_windows);
|
||||
|
||||
auto mfcc_mapped = outputs[0].flat<float>();
|
||||
for (int i = 0; i < n_windows * n_features; ++i) {
|
||||
mfcc_output.push_back(mfcc_mapped(i));
|
||||
}
|
||||
#else
|
||||
// Feeding input_node
|
||||
float* input_samples = interpreter->typed_tensor<float>(input_samples_idx);
|
||||
for (int i = 0; i < samples.size(); ++i) {
|
||||
input_samples[i] = samples[i];
|
||||
}
|
||||
|
||||
interpreter->SetExecutionPlan(mfcc_exec_plan);
|
||||
TfLiteStatus status = interpreter->Invoke();
|
||||
if (status != kTfLiteOk) {
|
||||
std::cerr << "Error running session: " << status << "\n";
|
||||
return;
|
||||
}
|
||||
|
||||
// The feature computation graph is hardcoded to one audio length for now
|
||||
int n_windows = 1;
|
||||
TfLiteIntArray* out_dims = interpreter->tensor(mfccs_idx)->dims;
|
||||
int num_elements = 1;
|
||||
for (int i = 0; i < out_dims->size; ++i) {
|
||||
num_elements *= out_dims->data[i];
|
||||
}
|
||||
assert(num_elements / n_features == n_windows);
|
||||
|
||||
float* outputs = interpreter->typed_tensor<float>(mfccs_idx);
|
||||
for (int i = 0; i < n_windows * n_features; ++i) {
|
||||
mfcc_output.push_back(outputs[i]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
char*
|
||||
ModelState::decode(DecoderState* state)
|
||||
{
|
||||
vector<Output> out = ModelState::decode_raw(state);
|
||||
return strdup(alphabet->LabelsToString(out[0].tokens).c_str());
|
||||
}
|
||||
|
||||
vector<Output>
|
||||
ModelState::decode_raw(DecoderState* state)
|
||||
{
|
||||
vector<Output> out = decoder_decode(state, *alphabet, beam_width, scorer);
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
Metadata*
|
||||
ModelState::decode_metadata(DecoderState* state)
|
||||
{
|
||||
vector<Output> out = decode_raw(state);
|
||||
|
||||
std::unique_ptr<Metadata> metadata(new Metadata());
|
||||
metadata->num_items = out[0].tokens.size();
|
||||
metadata->probability = out[0].probability;
|
||||
|
||||
std::unique_ptr<MetadataItem[]> items(new MetadataItem[metadata->num_items]());
|
||||
|
||||
// Loop through each character
|
||||
for (int i = 0; i < out[0].tokens.size(); ++i) {
|
||||
items[i].character = strdup(alphabet->StringFromLabel(out[0].tokens[i]).c_str());
|
||||
items[i].timestep = out[0].timesteps[i];
|
||||
items[i].start_time = out[0].timesteps[i] * ((float)audio_win_step / sample_rate);
|
||||
|
||||
if (items[i].start_time < 0) {
|
||||
items[i].start_time = 0;
|
||||
}
|
||||
}
|
||||
|
||||
metadata->items = items.release();
|
||||
return metadata.release();
|
||||
}
|
||||
|
||||
#ifdef USE_TFLITE
|
||||
int
|
||||
tflite_get_tensor_by_name(const ModelState* ctx, const vector<int>& list, const char* name)
|
||||
{
|
||||
int rv = -1;
|
||||
|
||||
for (int i = 0; i < list.size(); ++i) {
|
||||
const string& node_name = ctx->interpreter->tensor(list[i])->name;
|
||||
if (node_name.compare(string(name)) == 0) {
|
||||
rv = i;
|
||||
}
|
||||
}
|
||||
|
||||
assert(rv >= 0);
|
||||
return rv;
|
||||
}
|
||||
|
||||
int
|
||||
tflite_get_input_tensor_by_name(const ModelState* ctx, const char* name)
|
||||
{
|
||||
return ctx->interpreter->inputs()[tflite_get_tensor_by_name(ctx, ctx->interpreter->inputs(), name)];
|
||||
}
|
||||
|
||||
int
|
||||
tflite_get_output_tensor_by_name(const ModelState* ctx, const char* name)
|
||||
{
|
||||
return ctx->interpreter->outputs()[tflite_get_tensor_by_name(ctx, ctx->interpreter->outputs(), name)];
|
||||
}
|
||||
|
||||
void push_back_if_not_present(std::deque<int>& list, int value) {
|
||||
if (std::find(list.begin(), list.end(), value) == list.end()) {
|
||||
list.push_back(value);
|
||||
}
|
||||
}
|
||||
|
||||
// Backwards BFS on the node DAG. At each iteration we get the next tensor id
|
||||
// from the frontier list, then for each node which has that tensor id as an
|
||||
// output, add it to the parent list, and add its input tensors to the frontier
|
||||
// list. Because we start from the final tensor and work backwards to the inputs,
|
||||
// the parents list is constructed in reverse, adding elements to its front.
|
||||
std::vector<int>
|
||||
tflite_find_parent_node_ids(Interpreter* interpreter, int tensor_id)
|
||||
{
|
||||
std::deque<int> parents;
|
||||
std::deque<int> frontier;
|
||||
frontier.push_back(tensor_id);
|
||||
while (!frontier.empty()) {
|
||||
int next_tensor_id = frontier.front();
|
||||
frontier.pop_front();
|
||||
// Find all nodes that have next_tensor_id as an output
|
||||
for (int node_id = 0; node_id < interpreter->nodes_size(); ++node_id) {
|
||||
TfLiteNode node = interpreter->node_and_registration(node_id)->first;
|
||||
// Search node outputs for the tensor we're looking for
|
||||
for (int i = 0; i < node.outputs->size; ++i) {
|
||||
if (node.outputs->data[i] == next_tensor_id) {
|
||||
// This node is part of the parent tree, add it to the parent list and
|
||||
// add its input tensors to the frontier list
|
||||
parents.push_front(node_id);
|
||||
for (int j = 0; j < node.inputs->size; ++j) {
|
||||
push_back_if_not_present(frontier, node.inputs->data[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return std::vector<int>(parents.begin(), parents.end());
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
int
|
||||
DS_CreateModel(const char* aModelPath,
|
||||
unsigned int aNCep,
|
||||
@ -642,15 +271,6 @@ DS_CreateModel(const char* aModelPath,
|
||||
unsigned int aBeamWidth,
|
||||
ModelState** retval)
|
||||
{
|
||||
std::unique_ptr<ModelState> model(new ModelState());
|
||||
#ifndef USE_TFLITE
|
||||
model->mmap_env = new MemmappedEnv(Env::Default());
|
||||
#endif // USE_TFLITE
|
||||
model->ncep = aNCep;
|
||||
model->ncontext = aNContext;
|
||||
model->alphabet = new Alphabet(aAlphabetConfigPath);
|
||||
model->beam_width = aBeamWidth;
|
||||
|
||||
*retval = nullptr;
|
||||
|
||||
DS_PrintVersions();
|
||||
@ -660,183 +280,26 @@ DS_CreateModel(const char* aModelPath,
|
||||
return DS_ERR_NO_MODEL;
|
||||
}
|
||||
|
||||
std::unique_ptr<ModelState> model(
|
||||
#ifndef USE_TFLITE
|
||||
Status status;
|
||||
SessionOptions options;
|
||||
new TFModelState()
|
||||
#else
|
||||
new TFLiteModelState()
|
||||
#endif
|
||||
);
|
||||
|
||||
bool is_mmap = std::string(aModelPath).find(".pbmm") != std::string::npos;
|
||||
if (!is_mmap) {
|
||||
std::cerr << "Warning: reading entire model file into memory. Transform model file into an mmapped graph to reduce heap usage." << std::endl;
|
||||
} else {
|
||||
status = model->mmap_env->InitializeFromFile(aModelPath);
|
||||
if (!status.ok()) {
|
||||
std::cerr << status << std::endl;
|
||||
return DS_ERR_FAIL_INIT_MMAP;
|
||||
}
|
||||
|
||||
options.config.mutable_graph_options()
|
||||
->mutable_optimizer_options()
|
||||
->set_opt_level(::OptimizerOptions::L0);
|
||||
options.env = model->mmap_env;
|
||||
if (!model) {
|
||||
std::cerr << "Could not allocate model state." << std::endl;
|
||||
return DS_ERR_FAIL_CREATE_MODEL;
|
||||
}
|
||||
|
||||
status = NewSession(options, &model->session);
|
||||
if (!status.ok()) {
|
||||
std::cerr << status << std::endl;
|
||||
return DS_ERR_FAIL_INIT_SESS;
|
||||
}
|
||||
|
||||
if (is_mmap) {
|
||||
status = ReadBinaryProto(model->mmap_env,
|
||||
MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
|
||||
&model->graph_def);
|
||||
} else {
|
||||
status = ReadBinaryProto(Env::Default(), aModelPath, &model->graph_def);
|
||||
}
|
||||
if (!status.ok()) {
|
||||
std::cerr << status << std::endl;
|
||||
return DS_ERR_FAIL_READ_PROTOBUF;
|
||||
}
|
||||
|
||||
status = model->session->Create(model->graph_def);
|
||||
if (!status.ok()) {
|
||||
std::cerr << status << std::endl;
|
||||
return DS_ERR_FAIL_CREATE_SESS;
|
||||
}
|
||||
|
||||
int graph_version = model->graph_def.version();
|
||||
if (graph_version < DS_GRAPH_VERSION) {
|
||||
std::cerr << "Specified model file version (" << graph_version << ") is "
|
||||
<< "incompatible with minimum version supported by this client ("
|
||||
<< DS_GRAPH_VERSION << "). See "
|
||||
<< "https://github.com/mozilla/DeepSpeech/#model-compatibility "
|
||||
<< "for more information" << std::endl;
|
||||
return DS_ERR_MODEL_INCOMPATIBLE;
|
||||
}
|
||||
|
||||
for (int i = 0; i < model->graph_def.node_size(); ++i) {
|
||||
NodeDef node = model->graph_def.node(i);
|
||||
if (node.name() == "input_node") {
|
||||
const auto& shape = node.attr().at("shape").shape();
|
||||
model->n_steps = shape.dim(1).size();
|
||||
model->n_context = (shape.dim(2).size()-1)/2;
|
||||
model->n_features = shape.dim(3).size();
|
||||
model->mfcc_feats_per_timestep = shape.dim(2).size() * shape.dim(3).size();
|
||||
} else if (node.name() == "logits_shape") {
|
||||
Tensor logits_shape = Tensor(DT_INT32, TensorShape({3}));
|
||||
if (!logits_shape.FromProto(node.attr().at("value").tensor())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int final_dim_size = logits_shape.vec<int>()(2) - 1;
|
||||
if (final_dim_size != model->alphabet->GetSize()) {
|
||||
std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
|
||||
<< "has size " << model->alphabet->GetSize()
|
||||
<< ", but model has " << final_dim_size
|
||||
<< " classes in its output. Make sure you're passing an alphabet "
|
||||
<< "file with the same size as the one used for training."
|
||||
<< std::endl;
|
||||
return DS_ERR_INVALID_ALPHABET;
|
||||
}
|
||||
} else if (node.name() == "model_metadata") {
|
||||
int sample_rate = node.attr().at("sample_rate").i();
|
||||
model->sample_rate = sample_rate;
|
||||
int win_len_ms = node.attr().at("feature_win_len").i();
|
||||
int win_step_ms = node.attr().at("feature_win_step").i();
|
||||
model->audio_win_len = sample_rate * (win_len_ms / 1000.0);
|
||||
model->audio_win_step = sample_rate * (win_step_ms / 1000.0);
|
||||
}
|
||||
}
|
||||
|
||||
if (model->n_context == -1 || model->n_features == -1) {
|
||||
std::cerr << "Error: Could not infer input shape from model file. "
|
||||
<< "Make sure input_node is a 4D tensor with shape "
|
||||
<< "[batch_size=1, time, window_size, n_features]."
|
||||
<< std::endl;
|
||||
return DS_ERR_INVALID_SHAPE;
|
||||
int err = model->init(aModelPath, aNCep, aNContext, aAlphabetConfigPath, aBeamWidth);
|
||||
if (err != DS_ERR_OK) {
|
||||
return err;
|
||||
}
|
||||
|
||||
*retval = model.release();
|
||||
return DS_ERR_OK;
|
||||
#else // USE_TFLITE
|
||||
model->fbmodel = tflite::FlatBufferModel::BuildFromFile(aModelPath);
|
||||
if (!model->fbmodel) {
|
||||
std::cerr << "Error at reading model file " << aModelPath << std::endl;
|
||||
return DS_ERR_FAIL_INIT_MMAP;
|
||||
}
|
||||
|
||||
|
||||
tflite::ops::builtin::BuiltinOpResolver resolver;
|
||||
tflite::InterpreterBuilder(*model->fbmodel, resolver)(&model->interpreter);
|
||||
if (!model->interpreter) {
|
||||
std::cerr << "Error at InterpreterBuilder for model file " << aModelPath << std::endl;
|
||||
return DS_ERR_FAIL_INTERPRETER;
|
||||
}
|
||||
|
||||
model->interpreter->AllocateTensors();
|
||||
model->interpreter->SetNumThreads(4);
|
||||
|
||||
// Query all the index once
|
||||
model->input_node_idx = tflite_get_input_tensor_by_name(model.get(), "input_node");
|
||||
model->previous_state_c_idx = tflite_get_input_tensor_by_name(model.get(), "previous_state_c");
|
||||
model->previous_state_h_idx = tflite_get_input_tensor_by_name(model.get(), "previous_state_h");
|
||||
model->input_samples_idx = tflite_get_input_tensor_by_name(model.get(), "input_samples");
|
||||
model->logits_idx = tflite_get_output_tensor_by_name(model.get(), "logits");
|
||||
model->new_state_c_idx = tflite_get_output_tensor_by_name(model.get(), "new_state_c");
|
||||
model->new_state_h_idx = tflite_get_output_tensor_by_name(model.get(), "new_state_h");
|
||||
model->mfccs_idx = tflite_get_output_tensor_by_name(model.get(), "mfccs");
|
||||
|
||||
// When we call Interpreter::Invoke, the whole graph is executed by default,
|
||||
// which means every time compute_mfcc is called the entire acoustic model is
|
||||
// also executed. To workaround that problem, we walk up the dependency DAG
|
||||
// from the mfccs output tensor to find all the relevant nodes required for
|
||||
// feature computation, building an execution plan that runs just those nodes.
|
||||
auto mfcc_plan = tflite_find_parent_node_ids(model->interpreter.get(), model->mfccs_idx);
|
||||
auto orig_plan = model->interpreter->execution_plan();
|
||||
|
||||
// Remove MFCC nodes from original plan (all nodes) to create the acoustic model plan
|
||||
auto erase_begin = std::remove_if(orig_plan.begin(), orig_plan.end(), [&mfcc_plan](int elem) {
|
||||
return std::find(mfcc_plan.begin(), mfcc_plan.end(), elem) != mfcc_plan.end();
|
||||
});
|
||||
orig_plan.erase(erase_begin, orig_plan.end());
|
||||
|
||||
model->acoustic_exec_plan = std::move(orig_plan);
|
||||
model->mfcc_exec_plan = std::move(mfcc_plan);
|
||||
|
||||
TfLiteIntArray* dims_input_node = model->interpreter->tensor(model->input_node_idx)->dims;
|
||||
|
||||
model->n_steps = dims_input_node->data[1];
|
||||
model->n_context = (dims_input_node->data[2] - 1 ) / 2;
|
||||
model->n_features = dims_input_node->data[3];
|
||||
model->mfcc_feats_per_timestep = dims_input_node->data[2] * dims_input_node->data[3];
|
||||
|
||||
TfLiteIntArray* dims_logits = model->interpreter->tensor(model->logits_idx)->dims;
|
||||
const int final_dim_size = dims_logits->data[1] - 1;
|
||||
if (final_dim_size != model->alphabet->GetSize()) {
|
||||
std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
|
||||
<< "has size " << model->alphabet->GetSize()
|
||||
<< ", but model has " << final_dim_size
|
||||
<< " classes in its output. Make sure you're passing an alphabet "
|
||||
<< "file with the same size as the one used for training."
|
||||
<< std::endl;
|
||||
return DS_ERR_INVALID_ALPHABET;
|
||||
}
|
||||
|
||||
TfLiteIntArray* dims_c = model->interpreter->tensor(model->previous_state_c_idx)->dims;
|
||||
TfLiteIntArray* dims_h = model->interpreter->tensor(model->previous_state_h_idx)->dims;
|
||||
assert(dims_c->data[1] == dims_h->data[1]);
|
||||
|
||||
model->previous_state_size = dims_c->data[1];
|
||||
model->previous_state_c_.reset(new float[model->previous_state_size]());
|
||||
model->previous_state_h_.reset(new float[model->previous_state_size]());
|
||||
|
||||
// Set initial values for previous_state_c and previous_state_h
|
||||
memset(model->previous_state_c_.get(), 0, sizeof(float) * model->previous_state_size);
|
||||
memset(model->previous_state_h_.get(), 0, sizeof(float) * model->previous_state_size);
|
||||
|
||||
*retval = model.release();
|
||||
return DS_ERR_OK;
|
||||
#endif // USE_TFLITE
|
||||
}
|
||||
|
||||
void
|
||||
@@ -854,10 +317,10 @@ DS_EnableDecoderWithLM(ModelState* aCtx,
float aLMBeta)
{
try {
aCtx->scorer = new Scorer(aLMAlpha, aLMBeta,
                          aLMPath ? aLMPath : "",
                          aTriePath ? aTriePath : "",
                          *aCtx->alphabet);
aCtx->scorer_ = new Scorer(aLMAlpha, aLMBeta,
                           aLMPath ? aLMPath : "",
                           aTriePath ? aTriePath : "",
                           *aCtx->alphabet_);
return DS_ERR_OK;
} catch (...) {
return DS_ERR_INVALID_LM;
@@ -872,41 +335,28 @@ DS_SetupStream(ModelState* aCtx,
{
*retval = nullptr;

#ifndef USE_TFLITE
Status status = aCtx->session->Run({}, {}, {"initialize_state"}, nullptr);
if (!status.ok()) {
std::cerr << "Error running session: " << status << std::endl;
return DS_ERR_FAIL_RUN_SESS;
}
#endif // USE_TFLITE

std::unique_ptr<StreamingState> ctx(new StreamingState());
if (!ctx) {
std::cerr << "Could not allocate streaming state." << std::endl;
return DS_ERR_FAIL_CREATE_STREAM;
}

const size_t num_classes = aCtx->alphabet->GetSize() + 1; // +1 for blank
const size_t num_classes = aCtx->alphabet_->GetSize() + 1; // +1 for blank

// Default initial allocation = 3 seconds.
if (aPreAllocFrames == 0) {
aPreAllocFrames = 150;
}

ctx->audio_buffer.reserve(aCtx->audio_win_len);
ctx->mfcc_buffer.reserve(aCtx->mfcc_feats_per_timestep);
ctx->mfcc_buffer.resize(aCtx->n_features*aCtx->n_context, 0.f);
ctx->batch_buffer.reserve(aCtx->n_steps * aCtx->mfcc_feats_per_timestep);
ctx->audio_buffer_.reserve(aCtx->audio_win_len_);
ctx->mfcc_buffer_.reserve(aCtx->mfcc_feats_per_timestep_);
ctx->mfcc_buffer_.resize(aCtx->n_features_*aCtx->n_context_, 0.f);
ctx->batch_buffer_.reserve(aCtx->n_steps_ * aCtx->mfcc_feats_per_timestep_);
ctx->previous_state_c_.resize(aCtx->state_size_, 0.f);
ctx->previous_state_h_.resize(aCtx->state_size_, 0.f);
ctx->model_ = aCtx;

ctx->model = aCtx;

#ifdef USE_TFLITE
/* Ensure previous_state_{c,h} are not holding previous stream value */
memset(ctx->model->previous_state_c_.get(), 0, sizeof(float) * ctx->model->previous_state_size);
memset(ctx->model->previous_state_h_.get(), 0, sizeof(float) * ctx->model->previous_state_size);
#endif // USE_TFLITE

ctx->decoder_state.reset(decoder_init(*aCtx->alphabet, num_classes, aCtx->scorer));
ctx->decoder_state_.reset(decoder_init(*aCtx->alphabet_, num_classes, aCtx->scorer_));

*retval = ctx.release();
return DS_ERR_OK;
@@ -1012,4 +462,3 @@ DS_PrintVersions() {
LOGD("DeepSpeech: %s", ds_git_version());
#endif
}

@@ -52,6 +52,7 @@ enum DeepSpeech_Error_Codes
DS_ERR_FAIL_CREATE_STREAM = 0x3004,
DS_ERR_FAIL_READ_PROTOBUF = 0x3005,
DS_ERR_FAIL_CREATE_SESS = 0x3006,
DS_ERR_FAIL_CREATE_MODEL = 0x3007,
};

/**
native_client/modelstate.cc  (new file, 82 lines)
@@ -0,0 +1,82 @@
#include <vector>

#include "ctcdecode/ctc_beam_search_decoder.h"

#include "modelstate.h"

using std::vector;

ModelState::ModelState()
  : alphabet_(nullptr)
  , scorer_(nullptr)
  , beam_width_(-1)
  , n_steps_(-1)
  , n_context_(-1)
  , n_features_(-1)
  , mfcc_feats_per_timestep_(-1)
  , sample_rate_(DEFAULT_SAMPLE_RATE)
  , audio_win_len_(DEFAULT_WINDOW_LENGTH)
  , audio_win_step_(DEFAULT_WINDOW_STEP)
  , state_size_(-1)
{
}

ModelState::~ModelState()
{
  delete scorer_;
  delete alphabet_;
}

int
ModelState::init(const char* model_path,
                 unsigned int n_features,
                 unsigned int n_context,
                 const char* alphabet_path,
                 unsigned int beam_width)
{
  n_features_ = n_features;
  n_context_ = n_context;
  alphabet_ = new Alphabet(alphabet_path);
  beam_width_ = beam_width;
  return DS_ERR_OK;
}

vector<Output>
ModelState::decode_raw(DecoderState* state)
{
  vector<Output> out = decoder_decode(state, *alphabet_, beam_width_, scorer_);
  return out;
}

char*
ModelState::decode(DecoderState* state)
{
  vector<Output> out = decode_raw(state);
  return strdup(alphabet_->LabelsToString(out[0].tokens).c_str());
}

Metadata*
ModelState::decode_metadata(DecoderState* state)
{
  vector<Output> out = decode_raw(state);

  std::unique_ptr<Metadata> metadata(new Metadata());
  metadata->num_items = out[0].tokens.size();
  metadata->probability = out[0].probability;

  std::unique_ptr<MetadataItem[]> items(new MetadataItem[metadata->num_items]());

  // Loop through each character
  for (int i = 0; i < out[0].tokens.size(); ++i) {
    items[i].character = strdup(alphabet_->StringFromLabel(out[0].tokens[i]).c_str());
    items[i].timestep = out[0].timesteps[i];
    items[i].start_time = out[0].timesteps[i] * ((float)audio_win_step_ / sample_rate_);

    if (items[i].start_time < 0) {
      items[i].start_time = 0;
    }
  }

  metadata->items = items.release();
  return metadata.release();
}
93
native_client/modelstate.h
Normal file
93
native_client/modelstate.h
Normal file
@ -0,0 +1,93 @@
|
||||
#ifndef MODELSTATE_H
#define MODELSTATE_H

#include <vector>

#include "deepspeech.h"
#include "alphabet.h"

#include "ctcdecode/scorer.h"
#include "ctcdecode/output.h"
#include "ctcdecode/decoderstate.h"

struct ModelState {
  //TODO: infer batch size from model/use dynamic batch size
  static constexpr unsigned int BATCH_SIZE = 1;

  static constexpr unsigned int DEFAULT_SAMPLE_RATE = 16000;
  static constexpr unsigned int DEFAULT_WINDOW_LENGTH = DEFAULT_SAMPLE_RATE * 0.032;
  static constexpr unsigned int DEFAULT_WINDOW_STEP = DEFAULT_SAMPLE_RATE * 0.02;

  Alphabet* alphabet_;
  Scorer* scorer_;
  unsigned int beam_width_;
  unsigned int n_steps_;
  unsigned int n_context_;
  unsigned int n_features_;
  unsigned int mfcc_feats_per_timestep_;
  unsigned int sample_rate_;
  unsigned int audio_win_len_;
  unsigned int audio_win_step_;
  unsigned int state_size_;

  ModelState();
  virtual ~ModelState();

  virtual int init(const char* model_path,
                   unsigned int n_features,
                   unsigned int n_context,
                   const char* alphabet_path,
                   unsigned int beam_width);

  virtual void compute_mfcc(const std::vector<float>& audio_buffer, std::vector<float>& mfcc_output) = 0;

  /**
   * @brief Do a single inference step in the acoustic model, with:
   *        input=mfcc
   *        input_lengths=[n_frames]
   *
   * @param mfcc batch input data
   * @param n_frames number of timesteps in the data
   *
   * @param[out] output_logits Where to store computed logits.
   */
  virtual void infer(const std::vector<float>& mfcc,
                     unsigned int n_frames,
                     const std::vector<float>& previous_state_c,
                     const std::vector<float>& previous_state_h,
                     std::vector<float>& logits_output,
                     std::vector<float>& state_c_output,
                     std::vector<float>& state_h_output) = 0;

  /**
   * @brief Perform decoding of the logits, using basic CTC decoder or
   *        CTC decoder with KenLM enabled
   *
   * @param state Decoder state to use when decoding.
   *
   * @return Vector of Output structs directly from the CTC decoder for additional processing.
   */
  virtual std::vector<Output> decode_raw(DecoderState* state);

  /**
   * @brief Perform decoding of the logits, using basic CTC decoder or
   *        CTC decoder with KenLM enabled
   *
   * @param state Decoder state to use when decoding.
   *
   * @return String representing the decoded text.
   */
  virtual char* decode(DecoderState* state);

  /**
   * @brief Return character-level metadata including letter timings.
   *
   * @param state Decoder state to use when decoding.
   *
   * @return Metadata struct containing MetadataItem structs for each character.
   *         The user is responsible for freeing Metadata by calling DS_FreeMetadata().
   */
  virtual Metadata* decode_metadata(DecoderState* state);
};

#endif // MODELSTATE_H
native_client/test/concurrent_streams.py  (new file, 78 lines)
@@ -0,0 +1,78 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function

import argparse
import numpy as np
import wave

from deepspeech import Model


# These constants control the beam search decoder

# Beam width used in the CTC decoder when building candidate transcriptions
BEAM_WIDTH = 500

# The alpha hyperparameter of the CTC decoder. Language Model weight
LM_ALPHA = 0.75

# The beta hyperparameter of the CTC decoder. Word insertion bonus.
LM_BETA = 1.85


# These constants are tied to the shape of the graph used (changing them changes
# the geometry of the first layer), so make sure you use the same constants that
# were used during training

# Number of MFCC features to use
N_FEATURES = 26

# Size of the context window used for producing timesteps in the input vector
N_CONTEXT = 9


def main():
    parser = argparse.ArgumentParser(description='Running DeepSpeech inference.')
    parser.add_argument('--model', required=True,
                        help='Path to the model (protocol buffer binary file)')
    parser.add_argument('--alphabet', required=True,
                        help='Path to the configuration file specifying the alphabet used by the network')
    parser.add_argument('--lm', nargs='?',
                        help='Path to the language model binary file')
    parser.add_argument('--trie', nargs='?',
                        help='Path to the language model trie file created with native_client/generate_trie')
    parser.add_argument('--audio1', required=True,
                        help='First audio file to use in interleaved streams')
    parser.add_argument('--audio2', required=True,
                        help='Second audio file to use in interleaved streams')
    args = parser.parse_args()

    ds = Model(args.model, N_FEATURES, N_CONTEXT, args.alphabet, BEAM_WIDTH)

    if args.lm and args.trie:
        ds.enableDecoderWithLM(args.alphabet, args.lm, args.trie, LM_ALPHA, LM_BETA)

    with wave.open(args.audio1, 'rb') as fin:
        fs1 = fin.getframerate()
        audio1 = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)

    with wave.open(args.audio2, 'rb') as fin:
        fs2 = fin.getframerate()
        audio2 = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)

    stream1 = ds.setupStream(sample_rate=fs1)
    stream2 = ds.setupStream(sample_rate=fs2)

    splits1 = np.array_split(audio1, 10)
    splits2 = np.array_split(audio2, 10)

    for part1, part2 in zip(splits1, splits2):
        ds.feedAudioContent(stream1, part1)
        ds.feedAudioContent(stream2, part2)

    print(ds.finishStream(stream1))
    print(ds.finishStream(stream2))

if __name__ == '__main__':
    main()
269
native_client/tflitemodelstate.cc
Normal file
269
native_client/tflitemodelstate.cc
Normal file
@ -0,0 +1,269 @@
#include "tflitemodelstate.h"

using namespace tflite;
using std::vector;

int
TFLiteModelState::get_tensor_by_name(const vector<int>& list,
                                     const char* name)
{
  int rv = -1;

  for (int i = 0; i < list.size(); ++i) {
    const string& node_name = interpreter_->tensor(list[i])->name;
    if (node_name.compare(string(name)) == 0) {
      rv = i;
    }
  }

  assert(rv >= 0);
  return rv;
}

int
TFLiteModelState::get_input_tensor_by_name(const char* name)
{
  int idx = get_tensor_by_name(interpreter_->inputs(), name);
  return interpreter_->inputs()[idx];
}

int
TFLiteModelState::get_output_tensor_by_name(const char* name)
{
  int idx = get_tensor_by_name(interpreter_->outputs(), name);
  return interpreter_->outputs()[idx];
}

void
push_back_if_not_present(std::deque<int>& list, int value)
{
  if (std::find(list.begin(), list.end(), value) == list.end()) {
    list.push_back(value);
  }
}
// Backwards BFS on the node DAG. At each iteration we get the next tensor id
// from the frontier list, then for each node which has that tensor id as an
// output, add it to the parent list, and add its input tensors to the frontier
// list. Because we start from the final tensor and work backwards to the inputs,
// the parents list is constructed in reverse, adding elements to its front.
vector<int>
TFLiteModelState::find_parent_node_ids(int tensor_id)
{
  std::deque<int> parents;
  std::deque<int> frontier;
  frontier.push_back(tensor_id);
  while (!frontier.empty()) {
    int next_tensor_id = frontier.front();
    frontier.pop_front();
    // Find all nodes that have next_tensor_id as an output
    for (int node_id = 0; node_id < interpreter_->nodes_size(); ++node_id) {
      TfLiteNode node = interpreter_->node_and_registration(node_id)->first;
      // Search node outputs for the tensor we're looking for
      for (int i = 0; i < node.outputs->size; ++i) {
        if (node.outputs->data[i] == next_tensor_id) {
          // This node is part of the parent tree, add it to the parent list and
          // add its input tensors to the frontier list
          parents.push_front(node_id);
          for (int j = 0; j < node.inputs->size; ++j) {
            push_back_if_not_present(frontier, node.inputs->data[j]);
          }
        }
      }
    }
  }

  return vector<int>(parents.begin(), parents.end());
}
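To see the traversal in isolation, here is a minimal Python sketch of the same backwards walk over a toy graph; the node and tensor ids are invented for the example and nothing here is part of the commit:

from collections import deque

def find_parent_node_ids(nodes, tensor_id):
    # nodes: {node_id: (input_tensor_ids, output_tensor_ids)}
    parents, frontier = deque(), deque([tensor_id])
    while frontier:
        next_tensor_id = frontier.popleft()
        for node_id, (inputs, outputs) in nodes.items():
            if next_tensor_id in outputs:
                parents.appendleft(node_id)  # build the plan in reverse
                for t in inputs:
                    if t not in frontier:
                        frontier.append(t)
    return list(parents)

# Toy DAG: node 0 turns tensor 0 (audio) into tensor 1 (mfccs),
# node 1 turns tensor 1 into tensor 2 (logits).
toy = {0: ([0], [1]), 1: ([1], [2])}
print(find_parent_node_ids(toy, 1))  # [0] -> only the MFCC node is needed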
TFLiteModelState::TFLiteModelState()
  : ModelState()
  , interpreter_(nullptr)
  , fbmodel_(nullptr)
{
}

TFLiteModelState::~TFLiteModelState()
{
}

int
TFLiteModelState::init(const char* model_path,
                       unsigned int n_features,
                       unsigned int n_context,
                       const char* alphabet_path,
                       unsigned int beam_width)
{
  int err = ModelState::init(model_path, n_features, n_context, alphabet_path, beam_width);
  if (err != DS_ERR_OK) {
    return err;
  }

  fbmodel_ = tflite::FlatBufferModel::BuildFromFile(model_path);
  if (!fbmodel_) {
    std::cerr << "Error reading model file " << model_path << std::endl;
    return DS_ERR_FAIL_INIT_MMAP;
  }

  tflite::ops::builtin::BuiltinOpResolver resolver;
  tflite::InterpreterBuilder(*fbmodel_, resolver)(&interpreter_);
  if (!interpreter_) {
    std::cerr << "Error running InterpreterBuilder for model file " << model_path << std::endl;
    return DS_ERR_FAIL_INTERPRETER;
  }

  interpreter_->AllocateTensors();
  interpreter_->SetNumThreads(4);

  // Query all the tensor indices once
  input_node_idx_ = get_input_tensor_by_name("input_node");
  previous_state_c_idx_ = get_input_tensor_by_name("previous_state_c");
  previous_state_h_idx_ = get_input_tensor_by_name("previous_state_h");
  input_samples_idx_ = get_input_tensor_by_name("input_samples");
  logits_idx_ = get_output_tensor_by_name("logits");
  new_state_c_idx_ = get_output_tensor_by_name("new_state_c");
  new_state_h_idx_ = get_output_tensor_by_name("new_state_h");
  mfccs_idx_ = get_output_tensor_by_name("mfccs");

  // When we call Interpreter::Invoke, the whole graph is executed by default,
  // which means every time compute_mfcc is called the entire acoustic model is
  // also executed. To work around that problem, we walk up the dependency DAG
  // from the mfccs output tensor to find all the relevant nodes required for
  // feature computation, building an execution plan that runs just those nodes.
  auto mfcc_plan = find_parent_node_ids(mfccs_idx_);
  auto orig_plan = interpreter_->execution_plan();

  // Remove MFCC nodes from the original plan (all nodes) to create the acoustic model plan
  auto erase_begin = std::remove_if(orig_plan.begin(), orig_plan.end(), [&mfcc_plan](int elem) {
    return std::find(mfcc_plan.begin(), mfcc_plan.end(), elem) != mfcc_plan.end();
  });
  orig_plan.erase(erase_begin, orig_plan.end());

  acoustic_exec_plan_ = std::move(orig_plan);
  mfcc_exec_plan_ = std::move(mfcc_plan);

  TfLiteIntArray* dims_input_node = interpreter_->tensor(input_node_idx_)->dims;

  n_steps_ = dims_input_node->data[1];
  n_context_ = (dims_input_node->data[2] - 1) / 2;
  n_features_ = dims_input_node->data[3];
  mfcc_feats_per_timestep_ = dims_input_node->data[2] * dims_input_node->data[3];

  TfLiteIntArray* dims_logits = interpreter_->tensor(logits_idx_)->dims;
  const int final_dim_size = dims_logits->data[1] - 1;
  if (final_dim_size != alphabet_->GetSize()) {
    std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
              << "has size " << alphabet_->GetSize()
              << ", but model has " << final_dim_size
              << " classes in its output. Make sure you're passing an alphabet "
              << "file with the same size as the one used for training."
              << std::endl;
    return DS_ERR_INVALID_ALPHABET;
  }

  TfLiteIntArray* dims_c = interpreter_->tensor(previous_state_c_idx_)->dims;
  TfLiteIntArray* dims_h = interpreter_->tensor(previous_state_h_idx_)->dims;
  assert(dims_c->data[1] == dims_h->data[1]);
  state_size_ = dims_c->data[1];
  assert(state_size_ > 0);

  return DS_ERR_OK;
}
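The plan split itself is just an order-preserving subtraction. A small Python illustration with hypothetical node ids (not code from this commit):

full_plan = [0, 1, 2, 3, 4, 5]   # hypothetical node ids, in execution order
mfcc_plan = [0, 1]               # nodes reachable backwards from the 'mfccs' tensor

# Equivalent of the std::remove_if/erase idiom above: keep order, drop MFCC nodes
acoustic_plan = [node for node in full_plan if node not in mfcc_plan]
assert acoustic_plan == [2, 3, 4, 5]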
// Copy contents of vec into the tensor with index tensor_idx.
// If vec.size() < num_elements, set the remainder of the tensor values to zero.
void
TFLiteModelState::copy_vector_to_tensor(const vector<float>& vec,
                                        int tensor_idx,
                                        int num_elements)
{
  float* tensor = interpreter_->typed_tensor<float>(tensor_idx);
  int i;
  for (i = 0; i < vec.size(); ++i) {
    tensor[i] = vec[i];
  }
  for (; i < num_elements; ++i) {
    tensor[i] = 0.f;
  }
}

// Copy num_elements elements from the tensor with index tensor_idx into vec
void
TFLiteModelState::copy_tensor_to_vector(int tensor_idx,
                                        int num_elements,
                                        vector<float>& vec)
{
  float* tensor = interpreter_->typed_tensor<float>(tensor_idx);
  for (int i = 0; i < num_elements; ++i) {
    vec.push_back(tensor[i]);
  }
}
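The zero-padding matters because the input tensor always holds a fixed number of timesteps, while the final chunk of a stream may carry fewer frames. A tiny Python analogue of that padding behaviour (illustrative only, not part of the commit):

def pad_to(vec, num_elements):
    # Same idea as copy_vector_to_tensor: fixed-size buffer, short final chunk zero-padded
    return vec + [0.0] * (num_elements - len(vec))

assert pad_to([1.0, 2.0], 5) == [1.0, 2.0, 0.0, 0.0, 0.0]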
void
TFLiteModelState::infer(const vector<float>& mfcc,
                        unsigned int n_frames,
                        const vector<float>& previous_state_c,
                        const vector<float>& previous_state_h,
                        vector<float>& logits_output,
                        vector<float>& state_c_output,
                        vector<float>& state_h_output)
{
  const size_t num_classes = alphabet_->GetSize() + 1; // +1 for blank

  // Feeding input_node
  copy_vector_to_tensor(mfcc, input_node_idx_, n_frames*mfcc_feats_per_timestep_);

  // Feeding previous_state_c, previous_state_h
  assert(previous_state_c.size() == state_size_);
  copy_vector_to_tensor(previous_state_c, previous_state_c_idx_, state_size_);
  assert(previous_state_h.size() == state_size_);
  copy_vector_to_tensor(previous_state_h, previous_state_h_idx_, state_size_);

  interpreter_->SetExecutionPlan(acoustic_exec_plan_);
  TfLiteStatus status = interpreter_->Invoke();
  if (status != kTfLiteOk) {
    std::cerr << "Error running session: " << status << "\n";
    return;
  }

  copy_tensor_to_vector(logits_idx_, n_frames * BATCH_SIZE * num_classes, logits_output);

  state_c_output.clear();
  state_c_output.reserve(state_size_);
  copy_tensor_to_vector(new_state_c_idx_, state_size_, state_c_output);

  state_h_output.clear();
  state_h_output.reserve(state_size_);
  copy_tensor_to_vector(new_state_h_idx_, state_size_, state_h_output);
}
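Both backends now take the previous LSTM state as an explicit argument and return the new state, which is what lets each stream carry its own state instead of sharing hidden variables in the graph. A rough Python sketch of the calling pattern, where infer stands in for either backend's method (not code from this commit):

def run_stream(infer, mfcc_chunks, state_size, n_frames_per_chunk):
    # Each stream starts from a zero LSTM state and threads it forward itself,
    # so two interleaved streams can no longer clobber each other's state.
    state_c = [0.0] * state_size
    state_h = [0.0] * state_size
    all_logits = []
    for chunk in mfcc_chunks:
        logits, state_c, state_h = infer(chunk, n_frames_per_chunk, state_c, state_h)
        all_logits.extend(logits)
    return all_logits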
void
TFLiteModelState::compute_mfcc(const vector<float>& samples,
                               vector<float>& mfcc_output)
{
  // Feeding input_samples
  copy_vector_to_tensor(samples, input_samples_idx_, samples.size());

  TfLiteStatus status = interpreter_->SetExecutionPlan(mfcc_exec_plan_);
  if (status != kTfLiteOk) {
    std::cerr << "Error setting execution plan: " << status << "\n";
    return;
  }

  status = interpreter_->Invoke();
  if (status != kTfLiteOk) {
    std::cerr << "Error running session: " << status << "\n";
    return;
  }

  // The feature computation graph is hardcoded to one audio length for now
  int n_windows = 1;
  TfLiteIntArray* out_dims = interpreter_->tensor(mfccs_idx_)->dims;
  int num_elements = 1;
  for (int i = 0; i < out_dims->size; ++i) {
    num_elements *= out_dims->data[i];
  }
  assert(num_elements / n_features_ == n_windows);

  copy_tensor_to_vector(mfccs_idx_, n_windows * n_features_, mfcc_output);
}
63  native_client/tflitemodelstate.h  Normal file
@ -0,0 +1,63 @@
#ifndef TFLITEMODELSTATE_H
#define TFLITEMODELSTATE_H

#include <memory>
#include <vector>

#include "tensorflow/lite/model.h"
#include "tensorflow/lite/kernels/register.h"

#include "modelstate.h"

struct TFLiteModelState : public ModelState
{
  std::unique_ptr<tflite::Interpreter> interpreter_;
  std::unique_ptr<tflite::FlatBufferModel> fbmodel_;

  int input_node_idx_;
  int previous_state_c_idx_;
  int previous_state_h_idx_;
  int input_samples_idx_;

  int logits_idx_;
  int new_state_c_idx_;
  int new_state_h_idx_;
  int mfccs_idx_;

  std::vector<int> acoustic_exec_plan_;
  std::vector<int> mfcc_exec_plan_;

  TFLiteModelState();
  virtual ~TFLiteModelState();

  virtual int init(const char* model_path,
                   unsigned int n_features,
                   unsigned int n_context,
                   const char* alphabet_path,
                   unsigned int beam_width) override;

  virtual void compute_mfcc(const std::vector<float>& audio_buffer,
                            std::vector<float>& mfcc_output) override;

  virtual void infer(const std::vector<float>& mfcc,
                     unsigned int n_frames,
                     const std::vector<float>& previous_state_c,
                     const std::vector<float>& previous_state_h,
                     std::vector<float>& logits_output,
                     std::vector<float>& state_c_output,
                     std::vector<float>& state_h_output) override;

private:
  int get_tensor_by_name(const std::vector<int>& list, const char* name);
  int get_input_tensor_by_name(const char* name);
  int get_output_tensor_by_name(const char* name);
  std::vector<int> find_parent_node_ids(int tensor_id);
  void copy_vector_to_tensor(const std::vector<float>& vec,
                             int tensor_idx,
                             int num_elements);
  void copy_tensor_to_vector(int tensor_idx,
                             int num_elements,
                             std::vector<float>& vec);
};

#endif // TFLITEMODELSTATE_H
230  native_client/tfmodelstate.cc  Normal file
@ -0,0 +1,230 @@
#include "tfmodelstate.h"

#include "ds_graph_version.h"

using namespace tensorflow;
using std::vector;

TFModelState::TFModelState()
  : ModelState()
  , mmap_env_(nullptr)
  , session_(nullptr)
{
}

TFModelState::~TFModelState()
{
  if (session_) {
    Status status = session_->Close();
    if (!status.ok()) {
      std::cerr << "Error closing TensorFlow session: " << status << std::endl;
    }
  }
  delete mmap_env_;
}
int
TFModelState::init(const char* model_path,
                   unsigned int n_features,
                   unsigned int n_context,
                   const char* alphabet_path,
                   unsigned int beam_width)
{
  int err = ModelState::init(model_path, n_features, n_context, alphabet_path, beam_width);
  if (err != DS_ERR_OK) {
    return err;
  }

  Status status;
  SessionOptions options;

  mmap_env_ = new MemmappedEnv(Env::Default());

  bool is_mmap = std::string(model_path).find(".pbmm") != std::string::npos;
  if (!is_mmap) {
    std::cerr << "Warning: reading entire model file into memory. Transform model file into an mmapped graph to reduce heap usage." << std::endl;
  } else {
    status = mmap_env_->InitializeFromFile(model_path);
    if (!status.ok()) {
      std::cerr << status << std::endl;
      return DS_ERR_FAIL_INIT_MMAP;
    }

    options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_opt_level(::OptimizerOptions::L0);
    options.env = mmap_env_;
  }

  status = NewSession(options, &session_);
  if (!status.ok()) {
    std::cerr << status << std::endl;
    return DS_ERR_FAIL_INIT_SESS;
  }

  if (is_mmap) {
    status = ReadBinaryProto(mmap_env_,
                             MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
                             &graph_def_);
  } else {
    status = ReadBinaryProto(Env::Default(), model_path, &graph_def_);
  }
  if (!status.ok()) {
    std::cerr << status << std::endl;
    return DS_ERR_FAIL_READ_PROTOBUF;
  }

  status = session_->Create(graph_def_);
  if (!status.ok()) {
    std::cerr << status << std::endl;
    return DS_ERR_FAIL_CREATE_SESS;
  }

  int graph_version = graph_def_.version();
  if (graph_version < DS_GRAPH_VERSION) {
    std::cerr << "Specified model file version (" << graph_version << ") is "
              << "incompatible with minimum version supported by this client ("
              << DS_GRAPH_VERSION << "). See "
              << "https://github.com/mozilla/DeepSpeech/#model-compatibility "
              << "for more information" << std::endl;
    return DS_ERR_MODEL_INCOMPATIBLE;
  }

  for (int i = 0; i < graph_def_.node_size(); ++i) {
    NodeDef node = graph_def_.node(i);
    if (node.name() == "input_node") {
      const auto& shape = node.attr().at("shape").shape();
      n_steps_ = shape.dim(1).size();
      n_context_ = (shape.dim(2).size()-1)/2;
      n_features_ = shape.dim(3).size();
      mfcc_feats_per_timestep_ = shape.dim(2).size() * shape.dim(3).size();
    } else if (node.name() == "previous_state_c") {
      const auto& shape = node.attr().at("shape").shape();
      state_size_ = shape.dim(1).size();
    } else if (node.name() == "logits_shape") {
      Tensor logits_shape = Tensor(DT_INT32, TensorShape({3}));
      if (!logits_shape.FromProto(node.attr().at("value").tensor())) {
        continue;
      }

      int final_dim_size = logits_shape.vec<int>()(2) - 1;
      if (final_dim_size != alphabet_->GetSize()) {
        std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
                  << "has size " << alphabet_->GetSize()
                  << ", but model has " << final_dim_size
                  << " classes in its output. Make sure you're passing an alphabet "
                  << "file with the same size as the one used for training."
                  << std::endl;
        return DS_ERR_INVALID_ALPHABET;
      }
    } else if (node.name() == "model_metadata") {
      sample_rate_ = node.attr().at("sample_rate").i();
      int win_len_ms = node.attr().at("feature_win_len").i();
      int win_step_ms = node.attr().at("feature_win_step").i();
      audio_win_len_ = sample_rate_ * (win_len_ms / 1000.0);
      audio_win_step_ = sample_rate_ * (win_step_ms / 1000.0);
    }
  }

  if (n_context_ == -1 || n_features_ == -1) {
    std::cerr << "Error: Could not infer input shape from model file. "
              << "Make sure input_node is a 4D tensor with shape "
              << "[batch_size=1, time, window_size, n_features]."
              << std::endl;
    return DS_ERR_INVALID_SHAPE;
  }

  return DS_ERR_OK;
}
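The model_metadata branch converts millisecond window parameters into sample counts. A worked example with typical values (the 16 kHz sample rate, 32 ms window and 20 ms step are assumed here, not read from the diff):

sample_rate = 16000    # Hz, example value
feature_win_len = 32   # ms, example value
feature_win_step = 20  # ms, example value

audio_win_len = sample_rate * (feature_win_len / 1000.0)    # 512.0 samples per window
audio_win_step = sample_rate * (feature_win_step / 1000.0)  # 320.0 samples between windows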
Tensor
tensor_from_vector(const std::vector<float>& vec, const TensorShape& shape)
{
  Tensor ret(DT_FLOAT, shape);
  auto ret_mapped = ret.flat<float>();
  int i;
  for (i = 0; i < vec.size(); ++i) {
    ret_mapped(i) = vec[i];
  }
  for (; i < shape.num_elements(); ++i) {
    ret_mapped(i) = 0.f;
  }
  return ret;
}

void
copy_tensor_to_vector(const Tensor& tensor, vector<float>& vec, int num_elements = -1)
{
  auto tensor_mapped = tensor.flat<float>();
  if (num_elements == -1) {
    num_elements = tensor.shape().num_elements();
  }
  for (int i = 0; i < num_elements; ++i) {
    vec.push_back(tensor_mapped(i));
  }
}
void
TFModelState::infer(const std::vector<float>& mfcc,
                    unsigned int n_frames,
                    const std::vector<float>& previous_state_c,
                    const std::vector<float>& previous_state_h,
                    vector<float>& logits_output,
                    vector<float>& state_c_output,
                    vector<float>& state_h_output)
{
  const size_t num_classes = alphabet_->GetSize() + 1; // +1 for blank

  Tensor input = tensor_from_vector(mfcc, TensorShape({BATCH_SIZE, n_steps_, 2*n_context_+1, n_features_}));
  Tensor previous_state_c_t = tensor_from_vector(previous_state_c, TensorShape({BATCH_SIZE, (long long)state_size_}));
  Tensor previous_state_h_t = tensor_from_vector(previous_state_h, TensorShape({BATCH_SIZE, (long long)state_size_}));

  Tensor input_lengths(DT_INT32, TensorShape({1}));
  input_lengths.scalar<int>()() = n_frames;

  vector<Tensor> outputs;
  Status status = session_->Run(
    {
      {"input_node", input},
      {"input_lengths", input_lengths},
      {"previous_state_c", previous_state_c_t},
      {"previous_state_h", previous_state_h_t}
    },
    {"logits", "new_state_c", "new_state_h"},
    {},
    &outputs);

  if (!status.ok()) {
    std::cerr << "Error running session: " << status << "\n";
    return;
  }

  copy_tensor_to_vector(outputs[0], logits_output, n_frames * BATCH_SIZE * num_classes);

  state_c_output.clear();
  state_c_output.reserve(state_size_);
  copy_tensor_to_vector(outputs[1], state_c_output);

  state_h_output.clear();
  state_h_output.reserve(state_size_);
  copy_tensor_to_vector(outputs[2], state_h_output);
}
void
TFModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_output)
{
  Tensor input = tensor_from_vector(samples, TensorShape({audio_win_len_}));

  vector<Tensor> outputs;
  Status status = session_->Run({{"input_samples", input}}, {"mfccs"}, {}, &outputs);

  if (!status.ok()) {
    std::cerr << "Error running session: " << status << "\n";
    return;
  }

  // The feature computation graph is hardcoded to one audio length for now
  const int n_windows = 1;
  assert(outputs[0].shape().num_elements() / n_features_ == n_windows);
  copy_tensor_to_vector(outputs[0], mfcc_output);
}
39  native_client/tfmodelstate.h  Normal file
@ -0,0 +1,39 @@
#ifndef TFMODELSTATE_H
#define TFMODELSTATE_H

#include <vector>

#include "tensorflow/core/public/session.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/memmapped_file_system.h"

#include "modelstate.h"

struct TFModelState : public ModelState
{
  tensorflow::MemmappedEnv* mmap_env_;
  tensorflow::Session* session_;
  tensorflow::GraphDef graph_def_;

  TFModelState();
  virtual ~TFModelState();

  virtual int init(const char* model_path,
                   unsigned int n_features,
                   unsigned int n_context,
                   const char* alphabet_path,
                   unsigned int beam_width) override;

  virtual void infer(const std::vector<float>& mfcc,
                     unsigned int n_frames,
                     const std::vector<float>& previous_state_c,
                     const std::vector<float>& previous_state_h,
                     std::vector<float>& logits_output,
                     std::vector<float>& state_c_output,
                     std::vector<float>& state_h_output) override;

  virtual void compute_mfcc(const std::vector<float>& audio_buffer,
                            std::vector<float>& mfcc_output) override;
};

#endif // TFMODELSTATE_H
@ -39,4 +39,6 @@ LD_LIBRARY_PATH=${PY37_LDPATH}:$LD_LIBRARY_PATH pip install --verbose --only-bin

run_prod_inference_tests

run_prod_concurrent_stream_tests

virtualenv_deactivate "${pyver}" "${PYENV_NAME}"
@ -419,6 +419,26 @@ run_all_inference_tests()
  assert_correct_warning_upsampling "${phrase_pbmodel_withlm_mono_8k}"
}

run_prod_concurrent_stream_tests()
{
  set +e
  output=$(python ${TASKCLUSTER_TMP_DIR}/test_sources/concurrent_streams.py \
             --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} \
             --alphabet ${TASKCLUSTER_TMP_DIR}/alphabet.txt \
             --lm ${TASKCLUSTER_TMP_DIR}/lm.binary \
             --trie ${TASKCLUSTER_TMP_DIR}/trie \
             --audio1 ${TASKCLUSTER_TMP_DIR}/LDC93S1.wav \
             --audio2 ${TASKCLUSTER_TMP_DIR}/new-home-in-the-stars-16k.wav 2>/dev/null)
  status=$?
  set -e

  output1=$(echo "${output}" | head -n 1)
  output2=$(echo "${output}" | tail -n 1)

  assert_correct_ldc93s1_prodmodel "${output1}" "${status}"
  assert_correct_inference "${output2}" "i must find a new home in the stars" "${status}"
}

run_prod_inference_tests()
{
  set +e
@ -549,6 +569,7 @@ download_data()
  cp ${DS_ROOT_TASK}/DeepSpeech/ds/data/alphabet.txt ${TASKCLUSTER_TMP_DIR}/alphabet.txt
  cp ${DS_ROOT_TASK}/DeepSpeech/ds/data/smoke_test/vocab.pruned.lm ${TASKCLUSTER_TMP_DIR}/lm.binary
  cp ${DS_ROOT_TASK}/DeepSpeech/ds/data/smoke_test/vocab.trie ${TASKCLUSTER_TMP_DIR}/trie
  cp -R ${DS_ROOT_TASK}/DeepSpeech/ds/native_client/test ${TASKCLUSTER_TMP_DIR}/test_sources
}

download_material()
@ -73,7 +73,6 @@ def create_flags():
f.DEFINE_string('export_dir', '', 'directory in which exported models are stored - if omitted, the model won\'t get exported')
f.DEFINE_boolean('remove_export', False, 'whether to remove old exported models')
f.DEFINE_boolean('export_tflite', False, 'export a graph ready for TF Lite engine')
f.DEFINE_boolean('use_seq_length', True, 'have sequence_length in the exported graph (will make tfcompile unhappy)')
f.DEFINE_integer('n_steps', 16, 'how many timesteps to process at once by the export graph, higher values mean more latency')
f.DEFINE_string('export_language', '', 'language the model was trained on e.g. "en" or "English". Gets embedded into exported model.')