Merge branch 'more-metadata'
commit 5779d298e1
@@ -534,9 +534,9 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
     batch_size = batch_size if batch_size > 0 else None
 
     # Create feature computation graph
-    input_samples = tf.placeholder(tf.float32, [512], 'input_samples')
+    input_samples = tf.placeholder(tf.float32, [Config.audio_window_samples], 'input_samples')
     samples = tf.expand_dims(input_samples, -1)
-    mfccs, _ = samples_to_mfccs(samples, 16000)
+    mfccs, _ = samples_to_mfccs(samples, FLAGS.audio_sample_rate)
     mfccs = tf.identity(mfccs, name='mfccs')
 
     # Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
@@ -710,6 +710,17 @@ def export():
     if not FLAGS.export_tflite:
         frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
         frozen_graph.version = int(file_relative_read('GRAPH_VERSION').strip())
+
+        # Add a no-op node to the graph with metadata information to be loaded by the native client
+        metadata = frozen_graph.node.add()
+        metadata.name = 'model_metadata'
+        metadata.op = 'NoOp'
+        metadata.attr['sample_rate'].i = FLAGS.audio_sample_rate
+        metadata.attr['feature_win_len'].i = FLAGS.feature_win_len
+        metadata.attr['feature_win_step'].i = FLAGS.feature_win_step
+        if FLAGS.export_language:
+            metadata.attr['language'].s = FLAGS.export_language.encode('ascii')
+
         with open(output_graph_path, 'wb') as fout:
             fout.write(frozen_graph.SerializeToString())
     else:
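Note: the embedded metadata can be read back from the exported graph in Python as well. A minimal sketch (not part of this change; the graph path is hypothetical), mirroring the lookup the native client performs in DS_CreateModel below:

import tensorflow as tf

graph_def = tf.GraphDef()
with open('output_graph.pb', 'rb') as fin:  # hypothetical export path
    graph_def.ParseFromString(fin.read())

for node in graph_def.node:
    if node.name == 'model_metadata':
        print('sample_rate:', node.attr['sample_rate'].i)
        print('feature_win_len:', node.attr['feature_win_len'].i)
        print('feature_win_step:', node.attr['feature_win_step'].i)
        if 'language' in node.attr:
            print('language:', node.attr['language'].s.decode('ascii'))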
@@ -39,25 +39,9 @@
 //TODO: infer batch size from model/use dynamic batch size
 constexpr unsigned int BATCH_SIZE = 1;
 
-//TODO: use dynamic sample rate
-constexpr unsigned int SAMPLE_RATE = 16000;
-
-constexpr float AUDIO_WIN_LEN = 0.032f;
-constexpr float AUDIO_WIN_STEP = 0.02f;
-constexpr unsigned int AUDIO_WIN_LEN_SAMPLES = (unsigned int)(AUDIO_WIN_LEN * SAMPLE_RATE);
-constexpr unsigned int AUDIO_WIN_STEP_SAMPLES = (unsigned int)(AUDIO_WIN_STEP * SAMPLE_RATE);
-
-constexpr size_t WINDOW_SIZE = AUDIO_WIN_LEN * SAMPLE_RATE;
-
-std::array<float, WINDOW_SIZE> calc_hamming_window() {
-  std::array<float, WINDOW_SIZE> a{0};
-  for (int i = 0; i < WINDOW_SIZE; ++i) {
-    a[i] = 0.54 - 0.46 * std::cos(2*M_PI*i/(WINDOW_SIZE-1));
-  }
-  return a;
-}
-
-std::array<float, WINDOW_SIZE> hamming_window = calc_hamming_window();
+constexpr unsigned int DEFAULT_SAMPLE_RATE = 16000;
+constexpr unsigned int DEFAULT_WINDOW_LENGTH = DEFAULT_SAMPLE_RATE * 0.032;
+constexpr unsigned int DEFAULT_WINDOW_STEP = DEFAULT_SAMPLE_RATE * 0.02;
 
 #ifndef USE_TFLITE
 using namespace tensorflow;
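Note: the hand-rolled window/MFCC path on the client becomes dead code because feature computation now lives inside the exported graph (see the input_samples → mfccs nodes added in create_inference_graph above, and compute_mfcc below). For reference, the removed calc_hamming_window implemented the standard Hamming window; a Python equivalent of the formula:

import math

def hamming(n, N):
    # 0.54 - 0.46*cos(2*pi*n/(N-1)), as in the removed C++ loop
    return 0.54 - 0.46 * math.cos(2 * math.pi * n / (N - 1))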
@@ -134,6 +118,9 @@ struct ModelState {
   unsigned int n_context;
   unsigned int n_features;
   unsigned int mfcc_feats_per_timestep;
+  unsigned int sample_rate;
+  unsigned int audio_win_len;
+  unsigned int audio_win_step;
 
 #ifdef USE_TFLITE
   size_t previous_state_size;
@@ -220,6 +207,9 @@ ModelState::ModelState()
   , n_context(-1)
   , n_features(-1)
   , mfcc_feats_per_timestep(-1)
+  , sample_rate(DEFAULT_SAMPLE_RATE)
+  , audio_win_len(DEFAULT_WINDOW_LENGTH)
+  , audio_win_step(DEFAULT_WINDOW_STEP)
 #ifdef USE_TFLITE
   , previous_state_size(0)
   , previous_state_c_(nullptr)
@@ -258,7 +248,7 @@ StreamingState::feedAudioContent(const short* buffer,
 {
   // Consume all the data that was passed in, processing full buffers if needed
   while (buffer_size > 0) {
-    while (buffer_size > 0 && audio_buffer.size() < AUDIO_WIN_LEN_SAMPLES) {
+    while (buffer_size > 0 && audio_buffer.size() < model->audio_win_len) {
       // Convert i16 sample into f32
       float multiplier = 1.0f / (1 << 15);
       audio_buffer.push_back((float)(*buffer) * multiplier);
@@ -267,10 +257,10 @@ StreamingState::feedAudioContent(const short* buffer,
     }
 
     // If the buffer is full, process and shift it
-    if (audio_buffer.size() == AUDIO_WIN_LEN_SAMPLES) {
+    if (audio_buffer.size() == model->audio_win_len) {
       processAudioWindow(audio_buffer);
       // Shift data by one step
-      shift_buffer_left(audio_buffer, AUDIO_WIN_STEP_SAMPLES);
+      shift_buffer_left(audio_buffer, model->audio_win_step);
     }
 
     // Repeat until buffer empty
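Note: the windowing logic above is easier to see outside the C++; a minimal Python sketch (hypothetical process callback, default sizes matching the 16 kHz defaults):

def feed_audio(audio_buffer, samples, win_len=512, win_step=320, process=print):
    # Accumulate samples until one full window is buffered, process it,
    # then shift left by one step so consecutive windows overlap by
    # (win_len - win_step) samples -- the shift_buffer_left equivalent.
    for sample in samples:
        audio_buffer.append(sample)
        if len(audio_buffer) == win_len:
            process(list(audio_buffer))
            del audio_buffer[:win_step]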
@@ -461,13 +451,13 @@ void
 ModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_output)
 {
 #ifndef USE_TFLITE
-  Tensor input(DT_FLOAT, TensorShape({AUDIO_WIN_LEN_SAMPLES}));
+  Tensor input(DT_FLOAT, TensorShape({audio_win_len}));
   auto input_mapped = input.flat<float>();
   int i;
   for (i = 0; i < samples.size(); ++i) {
     input_mapped(i) = samples[i];
   }
-  for (; i < AUDIO_WIN_LEN_SAMPLES; ++i) {
+  for (; i < audio_win_len; ++i) {
     input_mapped(i) = 0.f;
   }
 
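Note: the trailing loop zero-pads a final, partially filled window out to audio_win_len before the MFCC graph runs; in Python terms (values hypothetical):

audio_win_len = 512
samples = [0.25, -0.5, 0.125]  # tail of a stream, shorter than one window
padded = samples + [0.0] * (audio_win_len - len(samples))
assert len(padded) == audio_win_len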
@@ -556,8 +546,8 @@ ModelState::decode_metadata(const vector<float>& logits)
   for (int i = 0; i < out[0].tokens.size(); ++i) {
     metadata->items[i].character = (char*)alphabet->StringFromLabel(out[0].tokens[i]).c_str();
     metadata->items[i].timestep = out[0].timesteps[i];
-    metadata->items[i].start_time = static_cast<float>(out[0].timesteps[i] * AUDIO_WIN_STEP);
+    metadata->items[i].start_time = out[0].timesteps[i] * ((float)audio_win_step / sample_rate);
 
     if (metadata->items[i].start_time < 0) {
       metadata->items[i].start_time = 0;
     }
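Note: each CTC timestep now maps to audio_win_step / sample_rate seconds instead of the hard-coded AUDIO_WIN_STEP. A worked example, assuming the 16 kHz default values:

sample_rate = 16000
audio_win_step = 320                                  # samples per feature step
seconds_per_timestep = audio_win_step / sample_rate   # 0.02 s
start_time = 50 * seconds_per_timestep                # timestep 50 -> 1.0 s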
@@ -700,6 +690,13 @@ DS_CreateModel(const char* aModelPath,
                   << std::endl;
         return DS_ERR_INVALID_ALPHABET;
       }
+    } else if (node.name() == "model_metadata") {
+      int sample_rate = node.attr().at("sample_rate").i();
+      model->sample_rate = sample_rate;
+      int win_len_ms = node.attr().at("feature_win_len").i();
+      int win_step_ms = node.attr().at("feature_win_step").i();
+      model->audio_win_len = sample_rate * (win_len_ms / 1000.0);
+      model->audio_win_step = sample_rate * (win_step_ms / 1000.0);
     }
   }
 
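Note: graphs exported before this change carry no model_metadata node, in which case the fields keep the DEFAULT_* values set in the ModelState constructor above. Also note the 1000.0 divisor: with integer operands the millisecond conversion would truncate. A quick Python illustration (floor division models C++ integer division):

win_len_ms = 32
assert win_len_ms // 1000 == 0       # integer division: the window collapses
assert win_len_ms / 1000.0 == 0.032  # float division: the intended fraction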
@@ -833,7 +830,7 @@ DS_SetupStream(ModelState* aCtx,
 
   ctx->accumulated_logits.reserve(aPreAllocFrames * BATCH_SIZE * num_classes);
 
-  ctx->audio_buffer.reserve(AUDIO_WIN_LEN_SAMPLES);
+  ctx->audio_buffer.reserve(aCtx->audio_win_len);
   ctx->mfcc_buffer.reserve(aCtx->mfcc_feats_per_timestep);
   ctx->mfcc_buffer.resize(aCtx->n_features*aCtx->n_context, 0.f);
   ctx->batch_buffer.reserve(aCtx->n_steps * aCtx->mfcc_feats_per_timestep);
@@ -92,6 +92,12 @@ def initialize_globals():
     # Units in the sixth layer = number of characters in the target language plus one
     c.n_hidden_6 = c.alphabet.size() + 1 # +1 for CTC blank label
 
+    # Size of audio window in samples
+    c.audio_window_samples = FLAGS.audio_sample_rate * (FLAGS.feature_win_len / 1000)
+
+    # Stride for feature computations in samples
+    c.audio_step_samples = FLAGS.audio_sample_rate * (FLAGS.feature_win_step / 1000)
+
     if FLAGS.one_shot_infer:
         if not os.path.exists(FLAGS.one_shot_infer):
            log_error('Path specified in --one_shot_infer is not a valid file.')
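Note: with the default flag values introduced below, the flag-driven quantities reproduce the constants previously hard-coded in samples_to_mfccs (window_size=512, stride=320), so default behavior is unchanged:

audio_sample_rate = 16000   # --audio_sample_rate default
feature_win_len = 32        # --feature_win_len default, in milliseconds
feature_win_step = 20       # --feature_win_step default, in milliseconds
assert audio_sample_rate * (feature_win_len / 1000) == 512.0    # audio_window_samples
assert audio_sample_rate * (feature_win_step / 1000) == 320.0   # audio_step_samples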
@@ -27,7 +27,10 @@ def read_csvs(csv_files):
 
 
 def samples_to_mfccs(samples, sample_rate):
-    spectrogram = contrib_audio.audio_spectrogram(samples, window_size=512, stride=320, magnitude_squared=True)
+    spectrogram = contrib_audio.audio_spectrogram(samples,
+                                                  window_size=Config.audio_window_samples,
+                                                  stride=Config.audio_step_samples,
+                                                  magnitude_squared=True)
     mfccs = contrib_audio.mfcc(spectrogram, sample_rate, dct_coefficient_count=Config.n_input)
     mfccs = tf.reshape(mfccs, [-1, Config.n_input])
 
@@ -19,6 +19,10 @@ def create_flags():
     tf.app.flags.DEFINE_string ('dev_cached_features_path', '', 'comma separated list of files specifying the dataset used for validation. multiple files will get merged')
     tf.app.flags.DEFINE_string ('test_cached_features_path', '', 'comma separated list of files specifying the dataset used for testing. multiple files will get merged')
 
+    tf.app.flags.DEFINE_integer ('feature_win_len', 32, 'feature extraction audio window length in milliseconds')
+    tf.app.flags.DEFINE_integer ('feature_win_step', 20, 'feature extraction window step length in milliseconds')
+    tf.app.flags.DEFINE_integer ('audio_sample_rate', 16000, 'sample rate value expected by model')
+
     # Global Constants
     # ================
 
@@ -73,6 +77,7 @@ def create_flags():
     tf.app.flags.DEFINE_boolean ('export_tflite', False, 'export a graph ready for TF Lite engine')
     tf.app.flags.DEFINE_boolean ('use_seq_length', True, 'have sequence_length in the exported graph (will make tfcompile unhappy)')
     tf.app.flags.DEFINE_integer ('n_steps', 16, 'how many timesteps to process at once by the export graph, higher values mean more latency')
+    tf.app.flags.DEFINE_string ('export_language', '', 'language the model was trained on e.g. "en" or "English". Gets embedded into exported model.')
 
     # Reporting
 
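Note: taken together, a hypothetical invocation exercising the new flags (assuming the usual DeepSpeech.py entry point; the values shown are the defaults):

python DeepSpeech.py --audio_sample_rate 16000 --feature_win_len 32 --feature_win_step 20 --export_language "en"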
|
Loading…
x
Reference in New Issue
Block a user