Merge branch 'more-metadata'
commit 5779d298e1
DeepSpeech.py
@@ -534,9 +534,9 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
     batch_size = batch_size if batch_size > 0 else None
 
     # Create feature computation graph
-    input_samples = tf.placeholder(tf.float32, [512], 'input_samples')
+    input_samples = tf.placeholder(tf.float32, [Config.audio_window_samples], 'input_samples')
     samples = tf.expand_dims(input_samples, -1)
-    mfccs, _ = samples_to_mfccs(samples, 16000)
+    mfccs, _ = samples_to_mfccs(samples, FLAGS.audio_sample_rate)
     mfccs = tf.identity(mfccs, name='mfccs')
 
     # Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
@@ -710,6 +710,17 @@ def export():
     if not FLAGS.export_tflite:
         frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
         frozen_graph.version = int(file_relative_read('GRAPH_VERSION').strip())
+
+        # Add a no-op node to the graph with metadata information to be loaded by the native client
+        metadata = frozen_graph.node.add()
+        metadata.name = 'model_metadata'
+        metadata.op = 'NoOp'
+        metadata.attr['sample_rate'].i = FLAGS.audio_sample_rate
+        metadata.attr['feature_win_len'].i = FLAGS.feature_win_len
+        metadata.attr['feature_win_step'].i = FLAGS.feature_win_step
+        if FLAGS.export_language:
+            metadata.attr['language'].s = FLAGS.export_language.encode('ascii')
+
         with open(output_graph_path, 'wb') as fout:
             fout.write(frozen_graph.SerializeToString())
     else:
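For reference, a minimal sketch of reading the embedded metadata back out of an exported graph, mirroring what the native client does. Assumptions: TensorFlow 1.x protobuf API, and 'output_graph.pb' as the export file name.

import tensorflow as tf

# Parse the frozen GraphDef and look for the metadata NoOp node added above
graph_def = tf.GraphDef()
with open('output_graph.pb', 'rb') as fin:  # assumed export path
    graph_def.ParseFromString(fin.read())

for node in graph_def.node:
    if node.name == 'model_metadata':
        print(node.attr['sample_rate'].i,
              node.attr['feature_win_len'].i,
              node.attr['feature_win_step'].i)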
native_client/deepspeech.cc
@@ -39,25 +39,9 @@
 //TODO: infer batch size from model/use dynamic batch size
 constexpr unsigned int BATCH_SIZE = 1;
 
-//TODO: use dynamic sample rate
-constexpr unsigned int SAMPLE_RATE = 16000;
-
-constexpr float AUDIO_WIN_LEN = 0.032f;
-constexpr float AUDIO_WIN_STEP = 0.02f;
-constexpr unsigned int AUDIO_WIN_LEN_SAMPLES = (unsigned int)(AUDIO_WIN_LEN * SAMPLE_RATE);
-constexpr unsigned int AUDIO_WIN_STEP_SAMPLES = (unsigned int)(AUDIO_WIN_STEP * SAMPLE_RATE);
-
-constexpr size_t WINDOW_SIZE = AUDIO_WIN_LEN * SAMPLE_RATE;
-
-std::array<float, WINDOW_SIZE> calc_hamming_window() {
-  std::array<float, WINDOW_SIZE> a{0};
-  for (int i = 0; i < WINDOW_SIZE; ++i) {
-    a[i] = 0.54 - 0.46 * std::cos(2*M_PI*i/(WINDOW_SIZE-1));
-  }
-  return a;
-}
-
-std::array<float, WINDOW_SIZE> hamming_window = calc_hamming_window();
+constexpr unsigned int DEFAULT_SAMPLE_RATE = 16000;
+constexpr unsigned int DEFAULT_WINDOW_LENGTH = DEFAULT_SAMPLE_RATE * 0.032;
+constexpr unsigned int DEFAULT_WINDOW_STEP = DEFAULT_SAMPLE_RATE * 0.02;
 
 #ifndef USE_TFLITE
 using namespace tensorflow;
@@ -134,6 +118,9 @@ struct ModelState {
   unsigned int n_context;
   unsigned int n_features;
   unsigned int mfcc_feats_per_timestep;
+  unsigned int sample_rate;
+  unsigned int audio_win_len;
+  unsigned int audio_win_step;
 
 #ifdef USE_TFLITE
   size_t previous_state_size;
@@ -220,6 +207,9 @@ ModelState::ModelState()
   , n_context(-1)
   , n_features(-1)
   , mfcc_feats_per_timestep(-1)
+  , sample_rate(DEFAULT_SAMPLE_RATE)
+  , audio_win_len(DEFAULT_WINDOW_LENGTH)
+  , audio_win_step(DEFAULT_WINDOW_STEP)
 #ifdef USE_TFLITE
   , previous_state_size(0)
   , previous_state_c_(nullptr)
@@ -258,7 +248,7 @@ StreamingState::feedAudioContent(const short* buffer,
 {
   // Consume all the data that was passed in, processing full buffers if needed
   while (buffer_size > 0) {
-    while (buffer_size > 0 && audio_buffer.size() < AUDIO_WIN_LEN_SAMPLES) {
+    while (buffer_size > 0 && audio_buffer.size() < model->audio_win_len) {
       // Convert i16 sample into f32
       float multiplier = 1.0f / (1 << 15);
       audio_buffer.push_back((float)(*buffer) * multiplier);
@@ -267,10 +257,10 @@ StreamingState::feedAudioContent(const short* buffer,
     }
 
     // If the buffer is full, process and shift it
-    if (audio_buffer.size() == AUDIO_WIN_LEN_SAMPLES) {
+    if (audio_buffer.size() == model->audio_win_len) {
       processAudioWindow(audio_buffer);
       // Shift data by one step
-      shift_buffer_left(audio_buffer, AUDIO_WIN_STEP_SAMPLES);
+      shift_buffer_left(audio_buffer, model->audio_win_step);
     }
 
     // Repeat until buffer empty
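For readers following the streaming path, here is a rough, self-contained Python sketch of the buffering loop above. The window/step sizes are the 16 kHz defaults, and the on_window callback is a stand-in for the MFCC/inference step.

def feed_audio_content(audio_buffer, samples, on_window, win_len=512, win_step=320):
    """Mirror of StreamingState::feedAudioContent: fill a window, process, shift."""
    i = 0
    while i < len(samples):
        # Fill the window, converting i16 samples to f32 in [-1, 1)
        while i < len(samples) and len(audio_buffer) < win_len:
            audio_buffer.append(samples[i] / 32768.0)
            i += 1
        # If the window is full, process it and shift left by one step
        if len(audio_buffer) == win_len:
            on_window(audio_buffer)
            del audio_buffer[:win_step]

# Example: one second of silence at 16 kHz yields 49 full windows
windows = []
feed_audio_content([], [0] * 16000, lambda w: windows.append(len(w)))
print(len(windows))  # 49 at the default window/step sizes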
@@ -461,13 +451,13 @@ void
 ModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_output)
 {
 #ifndef USE_TFLITE
-  Tensor input(DT_FLOAT, TensorShape({AUDIO_WIN_LEN_SAMPLES}));
+  Tensor input(DT_FLOAT, TensorShape({audio_win_len}));
   auto input_mapped = input.flat<float>();
   int i;
   for (i = 0; i < samples.size(); ++i) {
     input_mapped(i) = samples[i];
   }
-  for (; i < AUDIO_WIN_LEN_SAMPLES; ++i) {
+  for (; i < audio_win_len; ++i) {
     input_mapped(i) = 0.f;
   }
 
@@ -556,8 +546,8 @@ ModelState::decode_metadata(const vector<float>& logits)
   for (int i = 0; i < out[0].tokens.size(); ++i) {
     metadata->items[i].character = (char*)alphabet->StringFromLabel(out[0].tokens[i]).c_str();
     metadata->items[i].timestep = out[0].timesteps[i];
-    metadata->items[i].start_time = static_cast<float>(out[0].timesteps[i] * AUDIO_WIN_STEP);
+    metadata->items[i].start_time = out[0].timesteps[i] * ((float)audio_win_step / sample_rate);
     if (metadata->items[i].start_time < 0) {
       metadata->items[i].start_time = 0;
     }
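For intuition: the old code multiplied the timestep index by the hardcoded step duration (0.02 s), while the new code derives seconds from the per-model step size in samples. A worked example at the 16 kHz defaults:

# audio_win_step = 320 samples, sample_rate = 16000,
# so each timestep spans 320 / 16000 = 0.02 s; a token at timestep 50:
start_time = 50 * (320 / 16000)  # 1.0 second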
@@ -700,6 +690,13 @@ DS_CreateModel(const char* aModelPath,
                   << std::endl;
         return DS_ERR_INVALID_ALPHABET;
       }
+    } else if (node.name() == "model_metadata") {
+      int sample_rate = node.attr().at("sample_rate").i();
+      model->sample_rate = sample_rate;
+      int win_len_ms = node.attr().at("feature_win_len").i();
+      int win_step_ms = node.attr().at("feature_win_step").i();
+      model->audio_win_len = sample_rate * (win_len_ms / 1000.0);
+      model->audio_win_step = sample_rate * (win_step_ms / 1000.0);
     }
   }
 
@@ -833,7 +830,7 @@ DS_SetupStream(ModelState* aCtx,
 
   ctx->accumulated_logits.reserve(aPreAllocFrames * BATCH_SIZE * num_classes);
 
-  ctx->audio_buffer.reserve(AUDIO_WIN_LEN_SAMPLES);
+  ctx->audio_buffer.reserve(aCtx->audio_win_len);
   ctx->mfcc_buffer.reserve(aCtx->mfcc_feats_per_timestep);
   ctx->mfcc_buffer.resize(aCtx->n_features*aCtx->n_context, 0.f);
   ctx->batch_buffer.reserve(aCtx->n_steps * aCtx->mfcc_feats_per_timestep);
util/config.py
@@ -92,6 +92,12 @@ def initialize_globals():
     # Units in the sixth layer = number of characters in the target language plus one
     c.n_hidden_6 = c.alphabet.size() + 1 # +1 for CTC blank label
 
+    # Size of audio window in samples
+    c.audio_window_samples = FLAGS.audio_sample_rate * (FLAGS.feature_win_len / 1000)
+
+    # Stride for feature computations in samples
+    c.audio_step_samples = FLAGS.audio_sample_rate * (FLAGS.feature_win_step / 1000)
+
     if FLAGS.one_shot_infer:
         if not os.path.exists(FLAGS.one_shot_infer):
             log_error('Path specified in --one_shot_infer is not a valid file.')
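A quick sanity check that the derived values reproduce the constants previously hardcoded in the feature pipeline (note this relies on Python 3's true division, since the flags are integers):

audio_window_samples = 16000 * (32 / 1000)  # 512.0 samples (was hardcoded 512)
audio_step_samples   = 16000 * (20 / 1000)  # 320.0 samples (was hardcoded 320)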
util/feeding.py
@@ -27,7 +27,10 @@ def read_csvs(csv_files):
 
 
 def samples_to_mfccs(samples, sample_rate):
-    spectrogram = contrib_audio.audio_spectrogram(samples, window_size=512, stride=320, magnitude_squared=True)
+    spectrogram = contrib_audio.audio_spectrogram(samples,
+                                                  window_size=Config.audio_window_samples,
+                                                  stride=Config.audio_step_samples,
+                                                  magnitude_squared=True)
     mfccs = contrib_audio.mfcc(spectrogram, sample_rate, dct_coefficient_count=Config.n_input)
     mfccs = tf.reshape(mfccs, [-1, Config.n_input])
 
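With the window and stride now configurable, the spectrogram frame count follows the usual sliding-window relationship. An illustrative calculation at the defaults (the formula is the standard one, assumed here rather than quoted from the op's documentation):

n_samples = 16000          # one second of 16 kHz audio
window, stride = 512, 320  # audio_window_samples / audio_step_samples at defaults
n_frames = 1 + (n_samples - window) // stride  # 49 spectrogram frames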
util/flags.py
@@ -19,6 +19,10 @@ def create_flags():
     tf.app.flags.DEFINE_string ('dev_cached_features_path', '', 'comma separated list of files specifying the dataset used for validation. multiple files will get merged')
     tf.app.flags.DEFINE_string ('test_cached_features_path', '', 'comma separated list of files specifying the dataset used for testing. multiple files will get merged')
 
+    tf.app.flags.DEFINE_integer ('feature_win_len', 32, 'feature extraction audio window length in milliseconds')
+    tf.app.flags.DEFINE_integer ('feature_win_step', 20, 'feature extraction window step length in milliseconds')
+    tf.app.flags.DEFINE_integer ('audio_sample_rate', 16000, 'sample rate value expected by model')
+
     # Global Constants
     # ================
 
@@ -73,6 +77,7 @@ def create_flags():
     tf.app.flags.DEFINE_boolean ('export_tflite', False, 'export a graph ready for TF Lite engine')
     tf.app.flags.DEFINE_boolean ('use_seq_length', True, 'have sequence_length in the exported graph (will make tfcompile unhappy)')
     tf.app.flags.DEFINE_integer ('n_steps', 16, 'how many timesteps to process at once by the export graph, higher values mean more latency')
+    tf.app.flags.DEFINE_string ('export_language', '', 'language the model was trained on e.g. "en" or "English". Gets embedded into exported model.')
 
     # Reporting