Refactor train.py into separate scripts
Currently train.py is overloaded with many independent features. Understanding the code and what will be the result of a training call requires untangling the entire script. It's also an error prone UX. This is a first step at separating independent parts into their own scripts.
This commit is contained in:
parent
fcbd92d0d7
commit
b85ad3ea74
287
training/deepspeech_training/deepspeech_model.py
Normal file
287
training/deepspeech_training/deepspeech_model.py
Normal file
@ -0,0 +1,287 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
|
||||||
|
DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
|
||||||
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
import tensorflow.compat.v1 as tfv1
|
||||||
|
|
||||||
|
tfv1.logging.set_verbosity({
|
||||||
|
'0': tfv1.logging.DEBUG,
|
||||||
|
'1': tfv1.logging.INFO,
|
||||||
|
'2': tfv1.logging.WARN,
|
||||||
|
'3': tfv1.logging.ERROR
|
||||||
|
}.get(DESIRED_LOG_LEVEL))
|
||||||
|
|
||||||
|
from .util.config import Config
|
||||||
|
from .util.feeding import audio_to_features
|
||||||
|
from .util.flags import FLAGS
|
||||||
|
|
||||||
|
|
||||||
|
def variable_on_cpu(name, shape, initializer):
|
||||||
|
r"""
|
||||||
|
Next we concern ourselves with graph creation.
|
||||||
|
However, before we do so we must introduce a utility function ``variable_on_cpu()``
|
||||||
|
used to create a variable in CPU memory.
|
||||||
|
"""
|
||||||
|
# Use the /cpu:0 device for scoped operations
|
||||||
|
with tf.device(Config.cpu_device):
|
||||||
|
# Create or get apropos variable
|
||||||
|
var = tfv1.get_variable(name=name, shape=shape, initializer=initializer)
|
||||||
|
return var
|
||||||
|
|
||||||
|
|
||||||
|
def create_overlapping_windows(batch_x):
|
||||||
|
batch_size = tf.shape(input=batch_x)[0]
|
||||||
|
window_width = 2 * Config.n_context + 1
|
||||||
|
num_channels = Config.n_input
|
||||||
|
|
||||||
|
# Create a constant convolution filter using an identity matrix, so that the
|
||||||
|
# convolution returns patches of the input tensor as is, and we can create
|
||||||
|
# overlapping windows over the MFCCs.
|
||||||
|
eye_filter = tf.constant(np.eye(window_width * num_channels)
|
||||||
|
.reshape(window_width, num_channels, window_width * num_channels), tf.float32)
|
||||||
|
|
||||||
|
# Create overlapping windows
|
||||||
|
batch_x = tf.nn.conv1d(input=batch_x, filters=eye_filter, stride=1, padding='SAME')
|
||||||
|
|
||||||
|
# Remove dummy depth dimension and reshape into [batch_size, n_windows, window_width, n_input]
|
||||||
|
batch_x = tf.reshape(batch_x, [batch_size, -1, window_width, num_channels])
|
||||||
|
|
||||||
|
return batch_x
|
||||||
|
|
||||||
|
|
||||||
|
def dense(name, x, units, dropout_rate=None, relu=True, layer_norm=False):
|
||||||
|
with tfv1.variable_scope(name):
|
||||||
|
bias = variable_on_cpu('bias', [units], tfv1.zeros_initializer())
|
||||||
|
weights = variable_on_cpu('weights', [x.shape[-1], units], tfv1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform"))
|
||||||
|
|
||||||
|
output = tf.nn.bias_add(tf.matmul(x, weights), bias)
|
||||||
|
|
||||||
|
if relu:
|
||||||
|
output = tf.minimum(tf.nn.relu(output), FLAGS.relu_clip)
|
||||||
|
|
||||||
|
if layer_norm:
|
||||||
|
with tfv1.variable_scope(name):
|
||||||
|
output = tf.contrib.layers.layer_norm(output)
|
||||||
|
|
||||||
|
if dropout_rate is not None:
|
||||||
|
output = tf.nn.dropout(output, rate=dropout_rate)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
|
||||||
|
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0'):
|
||||||
|
fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim,
|
||||||
|
forget_bias=0,
|
||||||
|
reuse=reuse,
|
||||||
|
name='cudnn_compatible_lstm_cell')
|
||||||
|
|
||||||
|
output, output_state = fw_cell(inputs=x,
|
||||||
|
dtype=tf.float32,
|
||||||
|
sequence_length=seq_length,
|
||||||
|
initial_state=previous_state)
|
||||||
|
|
||||||
|
return output, output_state
|
||||||
|
|
||||||
|
|
||||||
|
def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _):
|
||||||
|
assert previous_state is None # 'Passing previous state not supported with CuDNN backend'
|
||||||
|
|
||||||
|
# Hack: CudnnLSTM works similarly to Keras layers in that when you instantiate
|
||||||
|
# the object it creates the variables, and then you just call it several times
|
||||||
|
# to enable variable re-use. Because all of our code is structure in an old
|
||||||
|
# school TensorFlow structure where you can just call tf.get_variable again with
|
||||||
|
# reuse=True to reuse variables, we can't easily make use of the object oriented
|
||||||
|
# way CudnnLSTM is implemented, so we save a singleton instance in the function,
|
||||||
|
# emulating a static function variable.
|
||||||
|
if not rnn_impl_cudnn_rnn.cell:
|
||||||
|
# Forward direction cell:
|
||||||
|
fw_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers=1,
|
||||||
|
num_units=Config.n_cell_dim,
|
||||||
|
input_mode='linear_input',
|
||||||
|
direction='unidirectional',
|
||||||
|
dtype=tf.float32)
|
||||||
|
rnn_impl_cudnn_rnn.cell = fw_cell
|
||||||
|
|
||||||
|
output, output_state = rnn_impl_cudnn_rnn.cell(inputs=x,
|
||||||
|
sequence_lengths=seq_length)
|
||||||
|
|
||||||
|
return output, output_state
|
||||||
|
|
||||||
|
rnn_impl_cudnn_rnn.cell = None
|
||||||
|
|
||||||
|
|
||||||
|
def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
|
||||||
|
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell'):
|
||||||
|
# Forward direction cell:
|
||||||
|
fw_cell = tfv1.nn.rnn_cell.LSTMCell(Config.n_cell_dim,
|
||||||
|
forget_bias=0,
|
||||||
|
reuse=reuse,
|
||||||
|
name='cudnn_compatible_lstm_cell')
|
||||||
|
|
||||||
|
# Split rank N tensor into list of rank N-1 tensors
|
||||||
|
x = [x[l] for l in range(x.shape[0])]
|
||||||
|
|
||||||
|
output, output_state = tfv1.nn.static_rnn(cell=fw_cell,
|
||||||
|
inputs=x,
|
||||||
|
sequence_length=seq_length,
|
||||||
|
initial_state=previous_state,
|
||||||
|
dtype=tf.float32,
|
||||||
|
scope='cell_0')
|
||||||
|
|
||||||
|
output = tf.concat(output, 0)
|
||||||
|
|
||||||
|
return output, output_state
|
||||||
|
|
||||||
|
|
||||||
|
def create_model(batch_x, seq_length, dropout, reuse=False, batch_size=None, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell):
|
||||||
|
layers = {}
|
||||||
|
|
||||||
|
# Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
|
||||||
|
if not batch_size:
|
||||||
|
batch_size = tf.shape(input=batch_x)[0]
|
||||||
|
|
||||||
|
# Create overlapping feature windows if needed
|
||||||
|
if overlap:
|
||||||
|
batch_x = create_overlapping_windows(batch_x)
|
||||||
|
|
||||||
|
# Reshaping `batch_x` to a tensor with shape `[n_steps*batch_size, n_input + 2*n_input*n_context]`.
|
||||||
|
# This is done to prepare the batch for input into the first layer which expects a tensor of rank `2`.
|
||||||
|
|
||||||
|
# Permute n_steps and batch_size
|
||||||
|
batch_x = tf.transpose(a=batch_x, perm=[1, 0, 2, 3])
|
||||||
|
# Reshape to prepare input for first layer
|
||||||
|
batch_x = tf.reshape(batch_x, [-1, Config.n_input + 2*Config.n_input*Config.n_context]) # (n_steps*batch_size, n_input + 2*n_input*n_context)
|
||||||
|
layers['input_reshaped'] = batch_x
|
||||||
|
|
||||||
|
# The next three blocks will pass `batch_x` through three hidden layers with
|
||||||
|
# clipped RELU activation and dropout.
|
||||||
|
layers['layer_1'] = layer_1 = dense('layer_1', batch_x, Config.n_hidden_1, dropout_rate=dropout[0], layer_norm=FLAGS.layer_norm)
|
||||||
|
layers['layer_2'] = layer_2 = dense('layer_2', layer_1, Config.n_hidden_2, dropout_rate=dropout[1], layer_norm=FLAGS.layer_norm)
|
||||||
|
layers['layer_3'] = layer_3 = dense('layer_3', layer_2, Config.n_hidden_3, dropout_rate=dropout[2], layer_norm=FLAGS.layer_norm)
|
||||||
|
|
||||||
|
# `layer_3` is now reshaped into `[n_steps, batch_size, 2*n_cell_dim]`,
|
||||||
|
# as the LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`.
|
||||||
|
layer_3 = tf.reshape(layer_3, [-1, batch_size, Config.n_hidden_3])
|
||||||
|
|
||||||
|
# Run through parametrized RNN implementation, as we use different RNNs
|
||||||
|
# for training and inference
|
||||||
|
output, output_state = rnn_impl(layer_3, seq_length, previous_state, reuse)
|
||||||
|
|
||||||
|
# Reshape output from a tensor of shape [n_steps, batch_size, n_cell_dim]
|
||||||
|
# to a tensor of shape [n_steps*batch_size, n_cell_dim]
|
||||||
|
output = tf.reshape(output, [-1, Config.n_cell_dim])
|
||||||
|
layers['rnn_output'] = output
|
||||||
|
layers['rnn_output_state'] = output_state
|
||||||
|
|
||||||
|
# Now we feed `output` to the fifth hidden layer with clipped RELU activation
|
||||||
|
layers['layer_5'] = layer_5 = dense('layer_5', output, Config.n_hidden_5, dropout_rate=dropout[5], layer_norm=FLAGS.layer_norm)
|
||||||
|
|
||||||
|
# Now we apply a final linear layer creating `n_classes` dimensional vectors, the logits.
|
||||||
|
layers['layer_6'] = layer_6 = dense('layer_6', layer_5, Config.n_hidden_6, relu=False)
|
||||||
|
|
||||||
|
# Finally we reshape layer_6 from a tensor of shape [n_steps*batch_size, n_hidden_6]
|
||||||
|
# to the slightly more useful shape [n_steps, batch_size, n_hidden_6].
|
||||||
|
# Note, that this differs from the input in that it is time-major.
|
||||||
|
layer_6 = tf.reshape(layer_6, [-1, batch_size, Config.n_hidden_6], name='raw_logits')
|
||||||
|
layers['raw_logits'] = layer_6
|
||||||
|
|
||||||
|
# Output shape: [n_steps, batch_size, n_hidden_6]
|
||||||
|
return layer_6, layers
|
||||||
|
|
||||||
|
|
||||||
|
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
||||||
|
batch_size = batch_size if batch_size > 0 else None
|
||||||
|
|
||||||
|
# Create feature computation graph
|
||||||
|
input_samples = tfv1.placeholder(tf.float32, [Config.audio_window_samples], 'input_samples')
|
||||||
|
samples = tf.expand_dims(input_samples, -1)
|
||||||
|
mfccs, _ = audio_to_features(samples, FLAGS.audio_sample_rate)
|
||||||
|
mfccs = tf.identity(mfccs, name='mfccs')
|
||||||
|
|
||||||
|
# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
|
||||||
|
# This shape is read by the native_client in DS_CreateModel to know the
|
||||||
|
# value of n_steps, n_context and n_input. Make sure you update the code
|
||||||
|
# there if this shape is changed.
|
||||||
|
input_tensor = tfv1.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2 * Config.n_context + 1, Config.n_input], name='input_node')
|
||||||
|
seq_length = tfv1.placeholder(tf.int32, [batch_size], name='input_lengths')
|
||||||
|
|
||||||
|
if batch_size <= 0:
|
||||||
|
# no state management since n_step is expected to be dynamic too (see below)
|
||||||
|
previous_state = None
|
||||||
|
else:
|
||||||
|
previous_state_c = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
|
||||||
|
previous_state_h = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
|
||||||
|
|
||||||
|
previous_state = tf.nn.rnn_cell.LSTMStateTuple(previous_state_c, previous_state_h)
|
||||||
|
|
||||||
|
# One rate per layer
|
||||||
|
no_dropout = [None] * 6
|
||||||
|
|
||||||
|
if tflite:
|
||||||
|
rnn_impl = rnn_impl_static_rnn
|
||||||
|
else:
|
||||||
|
rnn_impl = rnn_impl_lstmblockfusedcell
|
||||||
|
|
||||||
|
logits, layers = create_model(batch_x=input_tensor,
|
||||||
|
batch_size=batch_size,
|
||||||
|
seq_length=seq_length if not FLAGS.export_tflite else None,
|
||||||
|
dropout=no_dropout,
|
||||||
|
previous_state=previous_state,
|
||||||
|
overlap=False,
|
||||||
|
rnn_impl=rnn_impl)
|
||||||
|
|
||||||
|
# TF Lite runtime will check that input dimensions are 1, 2 or 4
|
||||||
|
# by default we get 3, the middle one being batch_size which is forced to
|
||||||
|
# one on inference graph, so remove that dimension
|
||||||
|
if tflite:
|
||||||
|
logits = tf.squeeze(logits, [1])
|
||||||
|
|
||||||
|
# Apply softmax for CTC decoder
|
||||||
|
probs = tf.nn.softmax(logits, name='logits')
|
||||||
|
|
||||||
|
if batch_size <= 0:
|
||||||
|
if tflite:
|
||||||
|
raise NotImplementedError('dynamic batch_size does not support tflite nor streaming')
|
||||||
|
if n_steps > 0:
|
||||||
|
raise NotImplementedError('dynamic batch_size expect n_steps to be dynamic too')
|
||||||
|
return (
|
||||||
|
{
|
||||||
|
'input': input_tensor,
|
||||||
|
'input_lengths': seq_length,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'outputs': probs,
|
||||||
|
},
|
||||||
|
layers
|
||||||
|
)
|
||||||
|
|
||||||
|
new_state_c, new_state_h = layers['rnn_output_state']
|
||||||
|
new_state_c = tf.identity(new_state_c, name='new_state_c')
|
||||||
|
new_state_h = tf.identity(new_state_h, name='new_state_h')
|
||||||
|
|
||||||
|
inputs = {
|
||||||
|
'input': input_tensor,
|
||||||
|
'previous_state_c': previous_state_c,
|
||||||
|
'previous_state_h': previous_state_h,
|
||||||
|
'input_samples': input_samples,
|
||||||
|
}
|
||||||
|
|
||||||
|
if not FLAGS.export_tflite:
|
||||||
|
inputs['input_lengths'] = seq_length
|
||||||
|
|
||||||
|
outputs = {
|
||||||
|
'outputs': probs,
|
||||||
|
'new_state_c': new_state_c,
|
||||||
|
'new_state_h': new_state_h,
|
||||||
|
'mfccs': mfccs,
|
||||||
|
}
|
||||||
|
|
||||||
|
return inputs, outputs, layers
|
@ -133,6 +133,13 @@ def evaluate(test_csvs, create_model):
|
|||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
||||||
|
def test():
|
||||||
|
from .train import create_model # pylint: disable=cyclic-import,import-outside-toplevel
|
||||||
|
samples = evaluate(FLAGS.test_files.split(','), create_model)
|
||||||
|
if FLAGS.test_output_file:
|
||||||
|
save_samples_json(samples, FLAGS.test_output_file)
|
||||||
|
|
||||||
|
|
||||||
def main(_):
|
def main(_):
|
||||||
initialize_globals()
|
initialize_globals()
|
||||||
|
|
||||||
@ -141,16 +148,13 @@ def main(_):
|
|||||||
'the --test_files flag.')
|
'the --test_files flag.')
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
from .train import create_model # pylint: disable=cyclic-import,import-outside-toplevel
|
test()
|
||||||
samples = evaluate(FLAGS.test_files.split(','), create_model)
|
|
||||||
|
|
||||||
if FLAGS.test_output_file:
|
|
||||||
save_samples_json(samples, FLAGS.test_output_file)
|
|
||||||
|
|
||||||
|
|
||||||
def run_script():
|
def run_script():
|
||||||
create_flags()
|
create_flags()
|
||||||
absl.app.run(main)
|
absl.app.run(main)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
run_script()
|
run_script()
|
||||||
|
162
training/deepspeech_training/export.py
Normal file
162
training/deepspeech_training/export.py
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
|
||||||
|
DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
|
||||||
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL
|
||||||
|
|
||||||
|
import absl.app
|
||||||
|
import tensorflow as tf
|
||||||
|
import tensorflow.compat.v1 as tfv1
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
from .deepspeech_model import create_inference_graph
|
||||||
|
from .util.checkpoints import load_graph_for_evaluation
|
||||||
|
from .util.config import Config, initialize_globals
|
||||||
|
from .util.flags import create_flags, FLAGS
|
||||||
|
from .util.io import open_remote, rmtree_remote, listdir_remote, is_remote_path, isdir_remote
|
||||||
|
from .util.logging import log_error, log_info
|
||||||
|
|
||||||
|
|
||||||
|
def file_relative_read(fname):
|
||||||
|
return open(os.path.join(os.path.dirname(__file__), fname)).read()
|
||||||
|
|
||||||
|
|
||||||
|
def export():
|
||||||
|
r'''
|
||||||
|
Restores the trained variables into a simpler graph that will be exported for serving.
|
||||||
|
'''
|
||||||
|
log_info('Exporting the model...')
|
||||||
|
|
||||||
|
inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
|
||||||
|
|
||||||
|
graph_version = int(file_relative_read('GRAPH_VERSION').strip())
|
||||||
|
assert graph_version > 0
|
||||||
|
|
||||||
|
outputs['metadata_version'] = tf.constant([graph_version], name='metadata_version')
|
||||||
|
outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate')
|
||||||
|
outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
|
||||||
|
outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
|
||||||
|
outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width')
|
||||||
|
outputs['metadata_alphabet'] = tf.constant([Config.alphabet.Serialize()], name='metadata_alphabet')
|
||||||
|
|
||||||
|
if FLAGS.export_language:
|
||||||
|
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')
|
||||||
|
|
||||||
|
# Prevent further graph changes
|
||||||
|
tfv1.get_default_graph().finalize()
|
||||||
|
|
||||||
|
output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, tf.Tensor)]
|
||||||
|
output_names_ops = [op.name for op in outputs.values() if isinstance(op, tf.Operation)]
|
||||||
|
output_names = output_names_tensors + output_names_ops
|
||||||
|
|
||||||
|
with tf.Session() as session:
|
||||||
|
# Restore variables from checkpoint
|
||||||
|
load_graph_for_evaluation(session)
|
||||||
|
|
||||||
|
output_filename = FLAGS.export_file_name + '.pb'
|
||||||
|
if FLAGS.remove_export:
|
||||||
|
if isdir_remote(FLAGS.export_dir):
|
||||||
|
log_info('Removing old export')
|
||||||
|
rmtree_remote(FLAGS.export_dir)
|
||||||
|
|
||||||
|
output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
|
||||||
|
|
||||||
|
if not is_remote_path(FLAGS.export_dir) and not os.path.isdir(FLAGS.export_dir):
|
||||||
|
os.makedirs(FLAGS.export_dir)
|
||||||
|
|
||||||
|
frozen_graph = tfv1.graph_util.convert_variables_to_constants(
|
||||||
|
sess=session,
|
||||||
|
input_graph_def=tfv1.get_default_graph().as_graph_def(),
|
||||||
|
output_node_names=output_names)
|
||||||
|
|
||||||
|
frozen_graph = tfv1.graph_util.extract_sub_graph(
|
||||||
|
graph_def=frozen_graph,
|
||||||
|
dest_nodes=output_names)
|
||||||
|
|
||||||
|
if not FLAGS.export_tflite:
|
||||||
|
with open_remote(output_graph_path, 'wb') as fout:
|
||||||
|
fout.write(frozen_graph.SerializeToString())
|
||||||
|
else:
|
||||||
|
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
|
||||||
|
|
||||||
|
converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
|
||||||
|
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
||||||
|
# AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
|
||||||
|
converter.allow_custom_ops = True
|
||||||
|
tflite_model = converter.convert()
|
||||||
|
|
||||||
|
with open_remote(output_tflite_path, 'wb') as fout:
|
||||||
|
fout.write(tflite_model)
|
||||||
|
|
||||||
|
log_info('Models exported at %s' % (FLAGS.export_dir))
|
||||||
|
|
||||||
|
metadata_fname = os.path.join(FLAGS.export_dir, '{}_{}_{}.md'.format(
|
||||||
|
FLAGS.export_author_id,
|
||||||
|
FLAGS.export_model_name,
|
||||||
|
FLAGS.export_model_version))
|
||||||
|
|
||||||
|
model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow'
|
||||||
|
with open_remote(metadata_fname, 'w') as f:
|
||||||
|
f.write('---\n')
|
||||||
|
f.write('author: {}\n'.format(FLAGS.export_author_id))
|
||||||
|
f.write('model_name: {}\n'.format(FLAGS.export_model_name))
|
||||||
|
f.write('model_version: {}\n'.format(FLAGS.export_model_version))
|
||||||
|
f.write('contact_info: {}\n'.format(FLAGS.export_contact_info))
|
||||||
|
f.write('license: {}\n'.format(FLAGS.export_license))
|
||||||
|
f.write('language: {}\n'.format(FLAGS.export_language))
|
||||||
|
f.write('runtime: {}\n'.format(model_runtime))
|
||||||
|
f.write('min_ds_version: {}\n'.format(FLAGS.export_min_ds_version))
|
||||||
|
f.write('max_ds_version: {}\n'.format(FLAGS.export_max_ds_version))
|
||||||
|
f.write('acoustic_model_url: <replace this with a publicly available URL of the acoustic model>\n')
|
||||||
|
f.write('scorer_url: <replace this with a publicly available URL of the scorer, if present>\n')
|
||||||
|
f.write('---\n')
|
||||||
|
f.write('{}\n'.format(FLAGS.export_description))
|
||||||
|
|
||||||
|
log_info('Model metadata file saved to {}. Before submitting the exported model for publishing make sure all information in the metadata file is correct, and complete the URL fields.'.format(metadata_fname))
|
||||||
|
|
||||||
|
|
||||||
|
def package_zip():
|
||||||
|
# --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
|
||||||
|
export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '') # Force ending '/'
|
||||||
|
if is_remote_path(export_dir):
|
||||||
|
log_error("Cannot package remote path zip %s. Please do this manually." % export_dir)
|
||||||
|
return
|
||||||
|
|
||||||
|
zip_filename = os.path.dirname(export_dir)
|
||||||
|
|
||||||
|
shutil.copy(FLAGS.scorer_path, export_dir)
|
||||||
|
|
||||||
|
archive = shutil.make_archive(zip_filename, 'zip', export_dir)
|
||||||
|
log_info('Exported packaged model {}'.format(archive))
|
||||||
|
|
||||||
|
|
||||||
|
def main(_):
|
||||||
|
initialize_globals()
|
||||||
|
|
||||||
|
if FLAGS.export_dir:
|
||||||
|
tfv1.reset_default_graph()
|
||||||
|
|
||||||
|
if not FLAGS.export_zip:
|
||||||
|
# Export to folder
|
||||||
|
export()
|
||||||
|
else:
|
||||||
|
# Export and zip, TFLite only, creates package readable by Java example app
|
||||||
|
FLAGS.export_tflite = True
|
||||||
|
|
||||||
|
if listdir_remote(FLAGS.export_dir):
|
||||||
|
log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
export()
|
||||||
|
package_zip()
|
||||||
|
else:
|
||||||
|
log_error('Calling export script directly but no --export_dir specified')
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
create_flags()
|
||||||
|
absl.app.run(main)
|
@ -1,7 +1,5 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
from __future__ import absolute_import, division, print_function
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
@ -10,13 +8,13 @@ DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.a
|
|||||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL
|
||||||
|
|
||||||
import absl.app
|
import absl.app
|
||||||
import numpy as np
|
|
||||||
import progressbar
|
import progressbar
|
||||||
import shutil
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import tensorflow.compat.v1 as tfv1
|
import tensorflow.compat.v1 as tfv1
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
tfv1.logging.set_verbosity({
|
tfv1.logging.set_verbosity({
|
||||||
'0': tfv1.logging.DEBUG,
|
'0': tfv1.logging.DEBUG,
|
||||||
'1': tfv1.logging.INFO,
|
'1': tfv1.logging.INFO,
|
||||||
@ -24,197 +22,23 @@ tfv1.logging.set_verbosity({
|
|||||||
'3': tfv1.logging.ERROR
|
'3': tfv1.logging.ERROR
|
||||||
}.get(DESIRED_LOG_LEVEL))
|
}.get(DESIRED_LOG_LEVEL))
|
||||||
|
|
||||||
from datetime import datetime
|
from ds_ctcdecoder import Scorer
|
||||||
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
|
|
||||||
from .evaluate import evaluate
|
from . import export
|
||||||
from six.moves import zip, range
|
from . import evaluate
|
||||||
|
from . import training_graph_inference
|
||||||
|
from .deepspeech_model import create_model, rnn_impl_lstmblockfusedcell, rnn_impl_cudnn_rnn
|
||||||
from .util.config import Config, initialize_globals
|
from .util.config import Config, initialize_globals
|
||||||
from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation, reload_best_checkpoint
|
from .util.checkpoints import load_or_init_graph_for_training, reload_best_checkpoint
|
||||||
from .util.evaluate_tools import save_samples_json
|
from .util.feeding import create_dataset
|
||||||
from .util.feeding import create_dataset, audio_to_features, audiofile_to_features
|
|
||||||
from .util.flags import create_flags, FLAGS
|
from .util.flags import create_flags, FLAGS
|
||||||
from .util.helpers import check_ctcdecoder_version, ExceptionBox
|
from .util.helpers import check_ctcdecoder_version, ExceptionBox
|
||||||
from .util.logging import create_progressbar, log_debug, log_error, log_info, log_progress, log_warn
|
from .util.logging import create_progressbar, log_debug, log_error, log_info, log_progress, log_warn
|
||||||
from .util.io import open_remote, remove_remote, listdir_remote, is_remote_path, isdir_remote
|
from .util.io import open_remote, remove_remote, listdir_remote, is_remote_path
|
||||||
|
|
||||||
|
|
||||||
check_ctcdecoder_version()
|
check_ctcdecoder_version()
|
||||||
|
|
||||||
# Graph Creation
|
|
||||||
# ==============
|
|
||||||
|
|
||||||
def variable_on_cpu(name, shape, initializer):
|
|
||||||
r"""
|
|
||||||
Next we concern ourselves with graph creation.
|
|
||||||
However, before we do so we must introduce a utility function ``variable_on_cpu()``
|
|
||||||
used to create a variable in CPU memory.
|
|
||||||
"""
|
|
||||||
# Use the /cpu:0 device for scoped operations
|
|
||||||
with tf.device(Config.cpu_device):
|
|
||||||
# Create or get apropos variable
|
|
||||||
var = tfv1.get_variable(name=name, shape=shape, initializer=initializer)
|
|
||||||
return var
|
|
||||||
|
|
||||||
|
|
||||||
def create_overlapping_windows(batch_x):
|
|
||||||
batch_size = tf.shape(input=batch_x)[0]
|
|
||||||
window_width = 2 * Config.n_context + 1
|
|
||||||
num_channels = Config.n_input
|
|
||||||
|
|
||||||
# Create a constant convolution filter using an identity matrix, so that the
|
|
||||||
# convolution returns patches of the input tensor as is, and we can create
|
|
||||||
# overlapping windows over the MFCCs.
|
|
||||||
eye_filter = tf.constant(np.eye(window_width * num_channels)
|
|
||||||
.reshape(window_width, num_channels, window_width * num_channels), tf.float32) # pylint: disable=bad-continuation
|
|
||||||
|
|
||||||
# Create overlapping windows
|
|
||||||
batch_x = tf.nn.conv1d(input=batch_x, filters=eye_filter, stride=1, padding='SAME')
|
|
||||||
|
|
||||||
# Remove dummy depth dimension and reshape into [batch_size, n_windows, window_width, n_input]
|
|
||||||
batch_x = tf.reshape(batch_x, [batch_size, -1, window_width, num_channels])
|
|
||||||
|
|
||||||
return batch_x
|
|
||||||
|
|
||||||
|
|
||||||
def dense(name, x, units, dropout_rate=None, relu=True, layer_norm=False):
|
|
||||||
with tfv1.variable_scope(name):
|
|
||||||
bias = variable_on_cpu('bias', [units], tfv1.zeros_initializer())
|
|
||||||
weights = variable_on_cpu('weights', [x.shape[-1], units], tfv1.keras.initializers.VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform"))
|
|
||||||
|
|
||||||
output = tf.nn.bias_add(tf.matmul(x, weights), bias)
|
|
||||||
|
|
||||||
if relu:
|
|
||||||
output = tf.minimum(tf.nn.relu(output), FLAGS.relu_clip)
|
|
||||||
|
|
||||||
if layer_norm:
|
|
||||||
with tfv1.variable_scope(name):
|
|
||||||
output = tf.contrib.layers.layer_norm(output)
|
|
||||||
|
|
||||||
if dropout_rate is not None:
|
|
||||||
output = tf.nn.dropout(output, rate=dropout_rate)
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def rnn_impl_lstmblockfusedcell(x, seq_length, previous_state, reuse):
|
|
||||||
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell/cell_0'):
|
|
||||||
fw_cell = tf.contrib.rnn.LSTMBlockFusedCell(Config.n_cell_dim,
|
|
||||||
forget_bias=0,
|
|
||||||
reuse=reuse,
|
|
||||||
name='cudnn_compatible_lstm_cell')
|
|
||||||
|
|
||||||
output, output_state = fw_cell(inputs=x,
|
|
||||||
dtype=tf.float32,
|
|
||||||
sequence_length=seq_length,
|
|
||||||
initial_state=previous_state)
|
|
||||||
|
|
||||||
return output, output_state
|
|
||||||
|
|
||||||
|
|
||||||
def rnn_impl_cudnn_rnn(x, seq_length, previous_state, _):
|
|
||||||
assert previous_state is None # 'Passing previous state not supported with CuDNN backend'
|
|
||||||
|
|
||||||
# Hack: CudnnLSTM works similarly to Keras layers in that when you instantiate
|
|
||||||
# the object it creates the variables, and then you just call it several times
|
|
||||||
# to enable variable re-use. Because all of our code is structure in an old
|
|
||||||
# school TensorFlow structure where you can just call tf.get_variable again with
|
|
||||||
# reuse=True to reuse variables, we can't easily make use of the object oriented
|
|
||||||
# way CudnnLSTM is implemented, so we save a singleton instance in the function,
|
|
||||||
# emulating a static function variable.
|
|
||||||
if not rnn_impl_cudnn_rnn.cell:
|
|
||||||
# Forward direction cell:
|
|
||||||
fw_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers=1,
|
|
||||||
num_units=Config.n_cell_dim,
|
|
||||||
input_mode='linear_input',
|
|
||||||
direction='unidirectional',
|
|
||||||
dtype=tf.float32)
|
|
||||||
rnn_impl_cudnn_rnn.cell = fw_cell
|
|
||||||
|
|
||||||
output, output_state = rnn_impl_cudnn_rnn.cell(inputs=x,
|
|
||||||
sequence_lengths=seq_length)
|
|
||||||
|
|
||||||
return output, output_state
|
|
||||||
|
|
||||||
rnn_impl_cudnn_rnn.cell = None
|
|
||||||
|
|
||||||
|
|
||||||
def rnn_impl_static_rnn(x, seq_length, previous_state, reuse):
|
|
||||||
with tfv1.variable_scope('cudnn_lstm/rnn/multi_rnn_cell'):
|
|
||||||
# Forward direction cell:
|
|
||||||
fw_cell = tfv1.nn.rnn_cell.LSTMCell(Config.n_cell_dim,
|
|
||||||
forget_bias=0,
|
|
||||||
reuse=reuse,
|
|
||||||
name='cudnn_compatible_lstm_cell')
|
|
||||||
|
|
||||||
# Split rank N tensor into list of rank N-1 tensors
|
|
||||||
x = [x[l] for l in range(x.shape[0])]
|
|
||||||
|
|
||||||
output, output_state = tfv1.nn.static_rnn(cell=fw_cell,
|
|
||||||
inputs=x,
|
|
||||||
sequence_length=seq_length,
|
|
||||||
initial_state=previous_state,
|
|
||||||
dtype=tf.float32,
|
|
||||||
scope='cell_0')
|
|
||||||
|
|
||||||
output = tf.concat(output, 0)
|
|
||||||
|
|
||||||
return output, output_state
|
|
||||||
|
|
||||||
|
|
||||||
def create_model(batch_x, seq_length, dropout, reuse=False, batch_size=None, previous_state=None, overlap=True, rnn_impl=rnn_impl_lstmblockfusedcell):
|
|
||||||
layers = {}
|
|
||||||
|
|
||||||
# Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
|
|
||||||
if not batch_size:
|
|
||||||
batch_size = tf.shape(input=batch_x)[0]
|
|
||||||
|
|
||||||
# Create overlapping feature windows if needed
|
|
||||||
if overlap:
|
|
||||||
batch_x = create_overlapping_windows(batch_x)
|
|
||||||
|
|
||||||
# Reshaping `batch_x` to a tensor with shape `[n_steps*batch_size, n_input + 2*n_input*n_context]`.
|
|
||||||
# This is done to prepare the batch for input into the first layer which expects a tensor of rank `2`.
|
|
||||||
|
|
||||||
# Permute n_steps and batch_size
|
|
||||||
batch_x = tf.transpose(a=batch_x, perm=[1, 0, 2, 3])
|
|
||||||
# Reshape to prepare input for first layer
|
|
||||||
batch_x = tf.reshape(batch_x, [-1, Config.n_input + 2*Config.n_input*Config.n_context]) # (n_steps*batch_size, n_input + 2*n_input*n_context)
|
|
||||||
layers['input_reshaped'] = batch_x
|
|
||||||
|
|
||||||
# The next three blocks will pass `batch_x` through three hidden layers with
|
|
||||||
# clipped RELU activation and dropout.
|
|
||||||
layers['layer_1'] = layer_1 = dense('layer_1', batch_x, Config.n_hidden_1, dropout_rate=dropout[0], layer_norm=FLAGS.layer_norm)
|
|
||||||
layers['layer_2'] = layer_2 = dense('layer_2', layer_1, Config.n_hidden_2, dropout_rate=dropout[1], layer_norm=FLAGS.layer_norm)
|
|
||||||
layers['layer_3'] = layer_3 = dense('layer_3', layer_2, Config.n_hidden_3, dropout_rate=dropout[2], layer_norm=FLAGS.layer_norm)
|
|
||||||
|
|
||||||
# `layer_3` is now reshaped into `[n_steps, batch_size, 2*n_cell_dim]`,
|
|
||||||
# as the LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`.
|
|
||||||
layer_3 = tf.reshape(layer_3, [-1, batch_size, Config.n_hidden_3])
|
|
||||||
|
|
||||||
# Run through parametrized RNN implementation, as we use different RNNs
|
|
||||||
# for training and inference
|
|
||||||
output, output_state = rnn_impl(layer_3, seq_length, previous_state, reuse)
|
|
||||||
|
|
||||||
# Reshape output from a tensor of shape [n_steps, batch_size, n_cell_dim]
|
|
||||||
# to a tensor of shape [n_steps*batch_size, n_cell_dim]
|
|
||||||
output = tf.reshape(output, [-1, Config.n_cell_dim])
|
|
||||||
layers['rnn_output'] = output
|
|
||||||
layers['rnn_output_state'] = output_state
|
|
||||||
|
|
||||||
# Now we feed `output` to the fifth hidden layer with clipped RELU activation
|
|
||||||
layers['layer_5'] = layer_5 = dense('layer_5', output, Config.n_hidden_5, dropout_rate=dropout[5], layer_norm=FLAGS.layer_norm)
|
|
||||||
|
|
||||||
# Now we apply a final linear layer creating `n_classes` dimensional vectors, the logits.
|
|
||||||
layers['layer_6'] = layer_6 = dense('layer_6', layer_5, Config.n_hidden_6, relu=False)
|
|
||||||
|
|
||||||
# Finally we reshape layer_6 from a tensor of shape [n_steps*batch_size, n_hidden_6]
|
|
||||||
# to the slightly more useful shape [n_steps, batch_size, n_hidden_6].
|
|
||||||
# Note, that this differs from the input in that it is time-major.
|
|
||||||
layer_6 = tf.reshape(layer_6, [-1, batch_size, Config.n_hidden_6], name='raw_logits')
|
|
||||||
layers['raw_logits'] = layer_6
|
|
||||||
|
|
||||||
# Output shape: [n_steps, batch_size, n_hidden_6]
|
|
||||||
return layer_6, layers
|
|
||||||
|
|
||||||
|
|
||||||
# Accuracy and Loss
|
# Accuracy and Loss
|
||||||
# =================
|
# =================
|
||||||
@ -678,255 +502,6 @@ def train():
|
|||||||
log_debug('Session closed.')
|
log_debug('Session closed.')
|
||||||
|
|
||||||
|
|
||||||
def test():
|
|
||||||
samples = evaluate(FLAGS.test_files.split(','), create_model)
|
|
||||||
if FLAGS.test_output_file:
|
|
||||||
save_samples_json(samples, FLAGS.test_output_file)
|
|
||||||
|
|
||||||
|
|
||||||
def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
|
||||||
batch_size = batch_size if batch_size > 0 else None
|
|
||||||
|
|
||||||
# Create feature computation graph
|
|
||||||
input_samples = tfv1.placeholder(tf.float32, [Config.audio_window_samples], 'input_samples')
|
|
||||||
samples = tf.expand_dims(input_samples, -1)
|
|
||||||
mfccs, _ = audio_to_features(samples, FLAGS.audio_sample_rate)
|
|
||||||
mfccs = tf.identity(mfccs, name='mfccs')
|
|
||||||
|
|
||||||
# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
|
|
||||||
# This shape is read by the native_client in DS_CreateModel to know the
|
|
||||||
# value of n_steps, n_context and n_input. Make sure you update the code
|
|
||||||
# there if this shape is changed.
|
|
||||||
input_tensor = tfv1.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2 * Config.n_context + 1, Config.n_input], name='input_node')
|
|
||||||
seq_length = tfv1.placeholder(tf.int32, [batch_size], name='input_lengths')
|
|
||||||
|
|
||||||
if batch_size <= 0:
|
|
||||||
# no state management since n_step is expected to be dynamic too (see below)
|
|
||||||
previous_state = None
|
|
||||||
else:
|
|
||||||
previous_state_c = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
|
|
||||||
previous_state_h = tfv1.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
|
|
||||||
|
|
||||||
previous_state = tf.nn.rnn_cell.LSTMStateTuple(previous_state_c, previous_state_h)
|
|
||||||
|
|
||||||
# One rate per layer
|
|
||||||
no_dropout = [None] * 6
|
|
||||||
|
|
||||||
if tflite:
|
|
||||||
rnn_impl = rnn_impl_static_rnn
|
|
||||||
else:
|
|
||||||
rnn_impl = rnn_impl_lstmblockfusedcell
|
|
||||||
|
|
||||||
logits, layers = create_model(batch_x=input_tensor,
|
|
||||||
batch_size=batch_size,
|
|
||||||
seq_length=seq_length if not FLAGS.export_tflite else None,
|
|
||||||
dropout=no_dropout,
|
|
||||||
previous_state=previous_state,
|
|
||||||
overlap=False,
|
|
||||||
rnn_impl=rnn_impl)
|
|
||||||
|
|
||||||
# TF Lite runtime will check that input dimensions are 1, 2 or 4
|
|
||||||
# by default we get 3, the middle one being batch_size which is forced to
|
|
||||||
# one on inference graph, so remove that dimension
|
|
||||||
if tflite:
|
|
||||||
logits = tf.squeeze(logits, [1])
|
|
||||||
|
|
||||||
# Apply softmax for CTC decoder
|
|
||||||
probs = tf.nn.softmax(logits, name='logits')
|
|
||||||
|
|
||||||
if batch_size <= 0:
|
|
||||||
if tflite:
|
|
||||||
raise NotImplementedError('dynamic batch_size does not support tflite nor streaming')
|
|
||||||
if n_steps > 0:
|
|
||||||
raise NotImplementedError('dynamic batch_size expect n_steps to be dynamic too')
|
|
||||||
return (
|
|
||||||
{
|
|
||||||
'input': input_tensor,
|
|
||||||
'input_lengths': seq_length,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'outputs': probs,
|
|
||||||
},
|
|
||||||
layers
|
|
||||||
)
|
|
||||||
|
|
||||||
new_state_c, new_state_h = layers['rnn_output_state']
|
|
||||||
new_state_c = tf.identity(new_state_c, name='new_state_c')
|
|
||||||
new_state_h = tf.identity(new_state_h, name='new_state_h')
|
|
||||||
|
|
||||||
inputs = {
|
|
||||||
'input': input_tensor,
|
|
||||||
'previous_state_c': previous_state_c,
|
|
||||||
'previous_state_h': previous_state_h,
|
|
||||||
'input_samples': input_samples,
|
|
||||||
}
|
|
||||||
|
|
||||||
if not FLAGS.export_tflite:
|
|
||||||
inputs['input_lengths'] = seq_length
|
|
||||||
|
|
||||||
outputs = {
|
|
||||||
'outputs': probs,
|
|
||||||
'new_state_c': new_state_c,
|
|
||||||
'new_state_h': new_state_h,
|
|
||||||
'mfccs': mfccs,
|
|
||||||
}
|
|
||||||
|
|
||||||
return inputs, outputs, layers
|
|
||||||
|
|
||||||
|
|
||||||
def file_relative_read(fname):
|
|
||||||
return open(os.path.join(os.path.dirname(__file__), fname)).read()
|
|
||||||
|
|
||||||
|
|
||||||
def export():
|
|
||||||
r'''
|
|
||||||
Restores the trained variables into a simpler graph that will be exported for serving.
|
|
||||||
'''
|
|
||||||
log_info('Exporting the model...')
|
|
||||||
|
|
||||||
inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
|
|
||||||
|
|
||||||
graph_version = int(file_relative_read('GRAPH_VERSION').strip())
|
|
||||||
assert graph_version > 0
|
|
||||||
|
|
||||||
outputs['metadata_version'] = tf.constant([graph_version], name='metadata_version')
|
|
||||||
outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate')
|
|
||||||
outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
|
|
||||||
outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
|
|
||||||
outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width')
|
|
||||||
outputs['metadata_alphabet'] = tf.constant([Config.alphabet.Serialize()], name='metadata_alphabet')
|
|
||||||
|
|
||||||
if FLAGS.export_language:
|
|
||||||
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')
|
|
||||||
|
|
||||||
# Prevent further graph changes
|
|
||||||
tfv1.get_default_graph().finalize()
|
|
||||||
|
|
||||||
output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, tf.Tensor)]
|
|
||||||
output_names_ops = [op.name for op in outputs.values() if isinstance(op, tf.Operation)]
|
|
||||||
output_names = output_names_tensors + output_names_ops
|
|
||||||
|
|
||||||
with tf.Session() as session:
|
|
||||||
# Restore variables from checkpoint
|
|
||||||
load_graph_for_evaluation(session)
|
|
||||||
|
|
||||||
output_filename = FLAGS.export_file_name + '.pb'
|
|
||||||
if FLAGS.remove_export:
|
|
||||||
if isdir_remote(FLAGS.export_dir):
|
|
||||||
log_info('Removing old export')
|
|
||||||
remove_remote(FLAGS.export_dir)
|
|
||||||
|
|
||||||
output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
|
|
||||||
|
|
||||||
if not is_remote_path(FLAGS.export_dir) and not os.path.isdir(FLAGS.export_dir):
|
|
||||||
os.makedirs(FLAGS.export_dir)
|
|
||||||
|
|
||||||
frozen_graph = tfv1.graph_util.convert_variables_to_constants(
|
|
||||||
sess=session,
|
|
||||||
input_graph_def=tfv1.get_default_graph().as_graph_def(),
|
|
||||||
output_node_names=output_names)
|
|
||||||
|
|
||||||
frozen_graph = tfv1.graph_util.extract_sub_graph(
|
|
||||||
graph_def=frozen_graph,
|
|
||||||
dest_nodes=output_names)
|
|
||||||
|
|
||||||
if not FLAGS.export_tflite:
|
|
||||||
with open_remote(output_graph_path, 'wb') as fout:
|
|
||||||
fout.write(frozen_graph.SerializeToString())
|
|
||||||
else:
|
|
||||||
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
|
|
||||||
|
|
||||||
converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
|
|
||||||
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
|
||||||
# AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
|
|
||||||
converter.allow_custom_ops = True
|
|
||||||
tflite_model = converter.convert()
|
|
||||||
|
|
||||||
with open_remote(output_tflite_path, 'wb') as fout:
|
|
||||||
fout.write(tflite_model)
|
|
||||||
|
|
||||||
log_info('Models exported at %s' % (FLAGS.export_dir))
|
|
||||||
|
|
||||||
metadata_fname = os.path.join(FLAGS.export_dir, '{}_{}_{}.md'.format(
|
|
||||||
FLAGS.export_author_id,
|
|
||||||
FLAGS.export_model_name,
|
|
||||||
FLAGS.export_model_version))
|
|
||||||
|
|
||||||
model_runtime = 'tflite' if FLAGS.export_tflite else 'tensorflow'
|
|
||||||
with open_remote(metadata_fname, 'w') as f:
|
|
||||||
f.write('---\n')
|
|
||||||
f.write('author: {}\n'.format(FLAGS.export_author_id))
|
|
||||||
f.write('model_name: {}\n'.format(FLAGS.export_model_name))
|
|
||||||
f.write('model_version: {}\n'.format(FLAGS.export_model_version))
|
|
||||||
f.write('contact_info: {}\n'.format(FLAGS.export_contact_info))
|
|
||||||
f.write('license: {}\n'.format(FLAGS.export_license))
|
|
||||||
f.write('language: {}\n'.format(FLAGS.export_language))
|
|
||||||
f.write('runtime: {}\n'.format(model_runtime))
|
|
||||||
f.write('min_ds_version: {}\n'.format(FLAGS.export_min_ds_version))
|
|
||||||
f.write('max_ds_version: {}\n'.format(FLAGS.export_max_ds_version))
|
|
||||||
f.write('acoustic_model_url: <replace this with a publicly available URL of the acoustic model>\n')
|
|
||||||
f.write('scorer_url: <replace this with a publicly available URL of the scorer, if present>\n')
|
|
||||||
f.write('---\n')
|
|
||||||
f.write('{}\n'.format(FLAGS.export_description))
|
|
||||||
|
|
||||||
log_info('Model metadata file saved to {}. Before submitting the exported model for publishing make sure all information in the metadata file is correct, and complete the URL fields.'.format(metadata_fname))
|
|
||||||
|
|
||||||
|
|
||||||
def package_zip():
|
|
||||||
# --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
|
|
||||||
export_dir = os.path.join(os.path.abspath(FLAGS.export_dir), '') # Force ending '/'
|
|
||||||
if is_remote_path(export_dir):
|
|
||||||
log_error("Cannot package remote path zip %s. Please do this manually." % export_dir)
|
|
||||||
return
|
|
||||||
|
|
||||||
zip_filename = os.path.dirname(export_dir)
|
|
||||||
|
|
||||||
shutil.copy(FLAGS.scorer_path, export_dir)
|
|
||||||
|
|
||||||
archive = shutil.make_archive(zip_filename, 'zip', export_dir)
|
|
||||||
log_info('Exported packaged model {}'.format(archive))
|
|
||||||
|
|
||||||
|
|
||||||
def do_single_file_inference(input_file_path):
|
|
||||||
with tfv1.Session(config=Config.session_config) as session:
|
|
||||||
inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)
|
|
||||||
|
|
||||||
# Restore variables from training checkpoint
|
|
||||||
load_graph_for_evaluation(session)
|
|
||||||
|
|
||||||
features, features_len = audiofile_to_features(input_file_path)
|
|
||||||
previous_state_c = np.zeros([1, Config.n_cell_dim])
|
|
||||||
previous_state_h = np.zeros([1, Config.n_cell_dim])
|
|
||||||
|
|
||||||
# Add batch dimension
|
|
||||||
features = tf.expand_dims(features, 0)
|
|
||||||
features_len = tf.expand_dims(features_len, 0)
|
|
||||||
|
|
||||||
# Evaluate
|
|
||||||
features = create_overlapping_windows(features).eval(session=session)
|
|
||||||
features_len = features_len.eval(session=session)
|
|
||||||
|
|
||||||
probs = outputs['outputs'].eval(feed_dict={
|
|
||||||
inputs['input']: features,
|
|
||||||
inputs['input_lengths']: features_len,
|
|
||||||
inputs['previous_state_c']: previous_state_c,
|
|
||||||
inputs['previous_state_h']: previous_state_h,
|
|
||||||
}, session=session)
|
|
||||||
|
|
||||||
probs = np.squeeze(probs)
|
|
||||||
|
|
||||||
if FLAGS.scorer_path:
|
|
||||||
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
|
|
||||||
FLAGS.scorer_path, Config.alphabet)
|
|
||||||
else:
|
|
||||||
scorer = None
|
|
||||||
decoded = ctc_beam_search_decoder(probs, Config.alphabet, FLAGS.beam_width,
|
|
||||||
scorer=scorer, cutoff_prob=FLAGS.cutoff_prob,
|
|
||||||
cutoff_top_n=FLAGS.cutoff_top_n)
|
|
||||||
# Print highest probability result
|
|
||||||
print(decoded[0][1])
|
|
||||||
|
|
||||||
|
|
||||||
def early_training_checks():
|
def early_training_checks():
|
||||||
# Check for proper scorer early
|
# Check for proper scorer early
|
||||||
if FLAGS.scorer_path:
|
if FLAGS.scorer_path:
|
||||||
@ -948,38 +523,51 @@ def main(_):
|
|||||||
initialize_globals()
|
initialize_globals()
|
||||||
early_training_checks()
|
early_training_checks()
|
||||||
|
|
||||||
|
def deprecated_msg(prefix):
|
||||||
|
return ('{} Using the training script as a generic driver for all training '
|
||||||
|
'related functionality is deprecated and will be removed soon. Use '
|
||||||
|
'the specific scripts: train/evaluate/export/training_graph_inference.'.format(prefix))
|
||||||
|
|
||||||
if FLAGS.train_files:
|
if FLAGS.train_files:
|
||||||
tfv1.reset_default_graph()
|
tfv1.reset_default_graph()
|
||||||
tfv1.set_random_seed(FLAGS.random_seed)
|
tfv1.set_random_seed(FLAGS.random_seed)
|
||||||
train()
|
train()
|
||||||
|
else:
|
||||||
|
log_warn(deprecated_msg('Calling training script without --train_files.'))
|
||||||
|
|
||||||
if FLAGS.test_files:
|
if FLAGS.test_files:
|
||||||
|
log_warn(deprecated_msg('Specifying --test_files when calling train script.'))
|
||||||
tfv1.reset_default_graph()
|
tfv1.reset_default_graph()
|
||||||
test()
|
evaluate.test()
|
||||||
|
|
||||||
if FLAGS.export_dir and not FLAGS.export_zip:
|
if FLAGS.export_dir:
|
||||||
|
log_warn(deprecated_msg('Specifying --export_dir when calling train script.'))
|
||||||
tfv1.reset_default_graph()
|
tfv1.reset_default_graph()
|
||||||
export()
|
|
||||||
|
|
||||||
if FLAGS.export_zip:
|
if not FLAGS.export_zip:
|
||||||
tfv1.reset_default_graph()
|
# Export to folder
|
||||||
FLAGS.export_tflite = True
|
export.export()
|
||||||
|
else:
|
||||||
|
# Export and zip, TFLite only, creates package readable by Java example app
|
||||||
|
FLAGS.export_tflite = True
|
||||||
|
|
||||||
if listdir_remote(FLAGS.export_dir):
|
if listdir_remote(FLAGS.export_dir):
|
||||||
log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
|
log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
export()
|
export.export()
|
||||||
package_zip()
|
export.package_zip()
|
||||||
|
|
||||||
if FLAGS.one_shot_infer:
|
if FLAGS.one_shot_infer:
|
||||||
|
log_warn(deprecated_msg('Specifying --one_shot_infer when calling train script.'))
|
||||||
tfv1.reset_default_graph()
|
tfv1.reset_default_graph()
|
||||||
do_single_file_inference(FLAGS.one_shot_infer)
|
training_graph_inference.do_single_file_inference(FLAGS.one_shot_infer)
|
||||||
|
|
||||||
|
|
||||||
def run_script():
|
def run_script():
|
||||||
create_flags()
|
create_flags()
|
||||||
absl.app.run(main)
|
absl.app.run(main)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
run_script()
|
run_script()
|
||||||
|
77
training/deepspeech_training/training_graph_inference.py
Normal file
77
training/deepspeech_training/training_graph_inference.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
|
||||||
|
DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
|
||||||
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL
|
||||||
|
|
||||||
|
import absl.app
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
import tensorflow.compat.v1 as tfv1
|
||||||
|
|
||||||
|
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
|
||||||
|
from .deepspeech_model import create_inference_graph, create_overlapping_windows
|
||||||
|
from .util.checkpoints import load_graph_for_evaluation
|
||||||
|
from .util.config import Config, initialize_globals
|
||||||
|
from .util.feeding import audiofile_to_features
|
||||||
|
from .util.flags import create_flags, FLAGS
|
||||||
|
from .util.logging import log_error
|
||||||
|
|
||||||
|
|
||||||
|
def do_single_file_inference(input_file_path):
|
||||||
|
with tfv1.Session(config=Config.session_config) as session:
|
||||||
|
inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)
|
||||||
|
|
||||||
|
# Restore variables from training checkpoint
|
||||||
|
load_graph_for_evaluation(session)
|
||||||
|
|
||||||
|
features, features_len = audiofile_to_features(input_file_path)
|
||||||
|
previous_state_c = np.zeros([1, Config.n_cell_dim])
|
||||||
|
previous_state_h = np.zeros([1, Config.n_cell_dim])
|
||||||
|
|
||||||
|
# Add batch dimension
|
||||||
|
features = tf.expand_dims(features, 0)
|
||||||
|
features_len = tf.expand_dims(features_len, 0)
|
||||||
|
|
||||||
|
# Evaluate
|
||||||
|
features = create_overlapping_windows(features).eval(session=session)
|
||||||
|
features_len = features_len.eval(session=session)
|
||||||
|
|
||||||
|
probs = outputs['outputs'].eval(feed_dict={
|
||||||
|
inputs['input']: features,
|
||||||
|
inputs['input_lengths']: features_len,
|
||||||
|
inputs['previous_state_c']: previous_state_c,
|
||||||
|
inputs['previous_state_h']: previous_state_h,
|
||||||
|
}, session=session)
|
||||||
|
|
||||||
|
probs = np.squeeze(probs)
|
||||||
|
|
||||||
|
if FLAGS.scorer_path:
|
||||||
|
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
|
||||||
|
FLAGS.scorer_path, Config.alphabet)
|
||||||
|
else:
|
||||||
|
scorer = None
|
||||||
|
decoded = ctc_beam_search_decoder(probs, Config.alphabet, FLAGS.beam_width,
|
||||||
|
scorer=scorer, cutoff_prob=FLAGS.cutoff_prob,
|
||||||
|
cutoff_top_n=FLAGS.cutoff_top_n)
|
||||||
|
# Print highest probability result
|
||||||
|
print(decoded[0][1])
|
||||||
|
|
||||||
|
|
||||||
|
def main(_):
|
||||||
|
initialize_globals()
|
||||||
|
|
||||||
|
if FLAGS.one_shot_infer:
|
||||||
|
tfv1.reset_default_graph()
|
||||||
|
do_single_file_inference(FLAGS.one_shot_infer)
|
||||||
|
else:
|
||||||
|
log_error('Calling training_graph_inference script directly but no --one_shot_infer input audio file specified')
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
create_flags()
|
||||||
|
absl.app.run(main)
|
@ -15,6 +15,7 @@ from .helpers import parse_file_size
|
|||||||
from .augmentations import parse_augmentations
|
from .augmentations import parse_augmentations
|
||||||
from .io import path_exists_remote
|
from .io import path_exists_remote
|
||||||
|
|
||||||
|
|
||||||
class ConfigSingleton:
|
class ConfigSingleton:
|
||||||
_config = None
|
_config = None
|
||||||
|
|
||||||
|
@ -77,5 +77,11 @@ def remove_remote(filename):
|
|||||||
"""
|
"""
|
||||||
Wrapper that can remove local and remote files like `gs://...`
|
Wrapper that can remove local and remote files like `gs://...`
|
||||||
"""
|
"""
|
||||||
# Conditional import
|
return gfile.remove(filename)
|
||||||
return gfile.remove_remote(filename)
|
|
||||||
|
|
||||||
|
def rmtree_remote(foldername):
|
||||||
|
"""
|
||||||
|
Wrapper that can remove local and remote directories like `gs://...`
|
||||||
|
"""
|
||||||
|
return gfile.rmtree(foldername)
|
||||||
|
Loading…
Reference in New Issue
Block a user