Make DeepSpeech.py importable without side-effects

Reuben Morais 2017-02-02 21:26:27 -02:00
parent dc13d4be06
commit 0a1d3d49ca

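This change wraps the top-level training, export and logging code of DeepSpeech.py in an `if __name__ == "__main__":` guard, so the file can be imported (say, to reuse `train()` or `BiRNN`) without kicking off a training run. A minimal sketch of the pattern, with illustrative file and function names:

    # mymodule.py -- an illustrative sketch, not part of the commit
    def main():
        print 'expensive work happens only when run as a script'

    if __name__ == "__main__":
        # True when invoked as `python mymodule.py`, False when another
        # module does `import mymodule`, so importing has no side-effects.
        main()

The hunk below shows the guarded region as it reads after the change; the body is unchanged apart from being re-indented under the guard.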

@@ -1026,139 +1026,139 @@ def train():
     return train_wer, dev_wer, hibernation_path

if __name__ == "__main__":
    # As everything is prepared, we are now able to do the training.
    # Define CPU as device on which the multi-gpu training is orchestrated
    with tf.device('/cpu:0'):
        # Take start time for time measurement
        time_started = datetime.datetime.utcnow()

        # Train the network
        last_train_wer, last_dev_wer, hibernation_path = train()

        # Take final time for time measurement
        time_finished = datetime.datetime.utcnow()

        # Calculate duration in seconds
        duration = time_finished - time_started
        duration = duration.days * 86400 + duration.seconds
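        # For illustration: a timedelta of 2 days and 3600 seconds becomes
        # 2 * 86400 + 3600 = 176400 seconds; note that the microseconds
        # component of the timedelta is dropped by this arithmetic.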
        # Finally the model is tested against some unbiased data-set
        print "Testing model"
        result = run_set('test', model_path=hibernation_path, query_report=True)
        test_wer = print_report("Test", result)
    # Finally, we restore the trained variables into a simpler graph that we can export for serving.
    # Don't export a model if no export directory has been set
    if export_dir:
        with tf.device('/cpu:0'):
            tf.reset_default_graph()
            session = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True))

            # Run inference

            # Input tensor will be of shape [batch_size, n_steps, n_input + 2*n_input*n_context]
            input_tensor = tf.placeholder(tf.float32, [None, None, n_input + 2*n_input*n_context])

            # Calculate input sequence length. This is done by tiling n_steps, batch_size times.
            # If there are multiple sequences, it is assumed they are padded with zeros to be of
            # the same length.
            n_items = tf.slice(tf.shape(input_tensor), [0], [1])
            n_steps = tf.slice(tf.shape(input_tensor), [1], [1])
            seq_length = tf.tile(n_steps, n_items)
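            # For illustration, assuming the project defaults n_input = 26 and
            # n_context = 9: each time step then carries 26 + 2*26*9 = 494
            # features, and a zero-padded batch of shape [16, 300, 494] yields
            # n_items = [16], n_steps = [300] and seq_length = [300] * 16.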
            # Calculate the logits of the batch using BiRNN
            logits = BiRNN(input_tensor, tf.to_int64(seq_length), 0)

            # Beam search decode the batch
            decoded, _ = ctc_ops.ctc_beam_search_decoder(logits, seq_length, merge_repeated=False)
            decoded = tf.convert_to_tensor(
                [tf.sparse_tensor_to_dense(sparse_tensor) for sparse_tensor in decoded])

            # TODO: Transform the decoded output to a string
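            # A sketch of that TODO, assuming the label scheme 0 = space and
            # 1-26 = 'a'-'z', with `d` one dense, zero-padded row of `decoded`:
            #
            #   transcript = ''.join(' ' if c == 0 else chr(c + ord('a') - 1)
            #                        for c in d).strip()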
            # Create a saver and exporter using variables from the above newly created graph
            saver = tf.train.Saver(tf.all_variables())
            model_exporter = exporter.Exporter(saver)

            # Restore variables from training checkpoint
            # TODO: This restores the most recent checkpoint, but if we use validation to counteract
            # over-fitting, we may want to restore an earlier checkpoint.
            checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
            saver.restore(session, checkpoint.model_checkpoint_path)
            print 'Restored checkpoint at training epoch %d' % (int(checkpoint.model_checkpoint_path.split('-')[-1]) + 1)

            # Initialise the model exporter and export the model
            model_exporter.init(session.graph.as_graph_def(),
                                named_graph_signatures = {
                                    'inputs': exporter.generic_signature(
                                        { 'input': input_tensor }),
                                    'outputs': exporter.generic_signature(
                                        { 'outputs': decoded})})
            if remove_export:
                actual_export_dir = os.path.join(export_dir, '%08d' % export_version)
                if os.path.isdir(actual_export_dir):
                    print 'Removing old export'
                    shutil.rmtree(actual_export_dir)
            try:
                model_exporter.export(export_dir, tf.constant(export_version), session)
                print 'Model exported at %s' % (export_dir)
            except RuntimeError:
                print sys.exc_info()[1]
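            # For illustration: with export_version = 1, the exporter writes the
            # model under <export_dir>/00000001, matching the zero-padded
            # directory that the remove_export branch above deletes.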
    # Logging Hyper Parameters and Results
    # ====================================

    # Now, as training and test are done, we persist the results alongside
    # with the involved hyper parameters for further reporting.
    data_sets = read_data_sets()
    with open('%s/%s' % (log_dir, 'hyper.json'), 'w') as dump_file:
        json.dump({
            'context': {
                'time_started': time_started.isoformat(),
                'time_finished': time_finished.isoformat(),
                'git_hash': get_git_revision_hash(),
                'git_branch': get_git_branch()
            },
            'parameters': {
                'learning_rate': learning_rate,
                'beta1': beta1,
                'beta2': beta2,
                'epsilon': epsilon,
                'epochs': epochs,
                'train_batch_size': train_batch_size,
                'dev_batch_size': dev_batch_size,
                'test_batch_size': test_batch_size,
                'validation_step': validation_step,
                'dropout_rate': dropout_rate,
                'relu_clip': relu_clip,
                'n_input': n_input,
                'n_context': n_context,
                'n_hidden_1': n_hidden_1,
                'n_hidden_2': n_hidden_2,
                'n_hidden_3': n_hidden_3,
                'n_hidden_5': n_hidden_5,
                'n_hidden_6': n_hidden_6,
                'n_cell_dim': n_cell_dim,
                'n_character': n_character,
                'total_batches_train': data_sets.train.total_batches,
                'total_batches_validation': data_sets.dev.total_batches,
                'total_batches_test': data_sets.test.total_batches,
                'data_set': {
                    'name': ds_importer
                }
            },
            'results': {
                'duration': duration,
                'last_train_wer': last_train_wer,
                'last_validation_wer': last_dev_wer,
                'test_wer': test_wer
            }
        }, dump_file, sort_keys=True, indent=4)
    # Let's also re-populate a central JS file, that contains all the dumps at once.
    merge_logs(logs_dir)
    maybe_publish()
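Each run now leaves a self-contained hyper.json behind, so results can be compared across runs by loading the dumps back. A minimal sketch of a reader; `load_run` and the log directory are illustrative, not part of the project:

    import json

    def load_run(log_dir):
        # Parse one run's hyper.json as written by the block above.
        with open('%s/%s' % (log_dir, 'hyper.json')) as dump_file:
            run = json.load(dump_file)
        return run['parameters'], run['results']

    parameters, results = load_run('logs/run_001')  # hypothetical log directory
    print 'test WER was %f' % results['test_wer']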