Make DeepSpeech.py importable without side-effects
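This commit wraps the module-level training, export, and result-logging code of DeepSpeech.py in an `if __name__ == "__main__":` guard, so `import DeepSpeech` no longer starts a training run as a side effect. A minimal sketch of the pattern (the module and function names below are illustrative, not taken from this commit):

    # mymodule.py -- illustrative sketch of the guard pattern, not part of this commit
    def work():
        # expensive computation that should only run on demand
        return 42

    if __name__ == "__main__":
        # Executed only when run as a script (`python mymodule.py`);
        # skipped when another module does `import mymodule`.
        print(work())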
parent dc13d4be06
commit 0a1d3d49ca

Changed file: DeepSpeech.py (250 changed lines)
@@ -1026,139 +1026,139 @@ def train():
     return train_wer, dev_wer, hibernation_path

-# As everything is prepared, we are now able to do the training.
-
-# Define CPU as device on which the muti-gpu training is orchestrated
-with tf.device('/cpu:0'):
-    # Take start time for time measurement
-    time_started = datetime.datetime.utcnow()
-
-    # Train the network
-    last_train_wer, last_dev_wer, hibernation_path = train()
-
-    # Take final time for time measurement
-    time_finished = datetime.datetime.utcnow()
-
-    # Calculate duration in seconds
-    duration = time_finished - time_started
-    duration = duration.days * 86400 + duration.seconds
-
-    # Finally the model is tested against some unbiased data-set
-    print "Testing model"
-    result = run_set('test', model_path=hibernation_path, query_report=True)
-    test_wer = print_report("Test", result)
-
-
-# Finally, we restore the trained variables into a simpler graph that we can export for serving.
-# Don't export a model if no export directory has been set
-if export_dir:
-    with tf.device('/cpu:0'):
-        tf.reset_default_graph()
-        session = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True))
-
-        # Run inference
-
-        # Input tensor will be of shape [batch_size, n_steps, n_input + 2*n_input*n_context]
-        input_tensor = tf.placeholder(tf.float32, [None, None, n_input + 2*n_input*n_context])
-
-        # Calculate input sequence length. This is done by tiling n_steps, batch_size times.
-        # If there are multiple sequences, it is assumed they are padded with zeros to be of
-        # the same length.
-        n_items = tf.slice(tf.shape(input_tensor), [0], [1])
-        n_steps = tf.slice(tf.shape(input_tensor), [1], [1])
-        seq_length = tf.tile(n_steps, n_items)
-
-        # Calculate the logits of the batch using BiRNN
-        logits = BiRNN(input_tensor, tf.to_int64(seq_length), 0)
-
-        # Beam search decode the batch
-        decoded, _ = ctc_ops.ctc_beam_search_decoder(logits, seq_length, merge_repeated=False)
-        decoded = tf.convert_to_tensor(
-            [tf.sparse_tensor_to_dense(sparse_tensor) for sparse_tensor in decoded])
-
-        # TODO: Transform the decoded output to a string
-
-        # Create a saver and exporter using variables from the above newly created graph
-        saver = tf.train.Saver(tf.all_variables())
-        model_exporter = exporter.Exporter(saver)
-
-        # Restore variables from training checkpoint
-        # TODO: This restores the most recent checkpoint, but if we use validation to counterract
-        # over-fitting, we may want to restore an earlier checkpoint.
-        checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
-        saver.restore(session, checkpoint.model_checkpoint_path)
-        print 'Restored checkpoint at training epoch %d' % (int(checkpoint.model_checkpoint_path.split('-')[-1]) + 1)
-
-        # Initialise the model exporter and export the model
-        model_exporter.init(session.graph.as_graph_def(),
-                            named_graph_signatures = {
-                                'inputs': exporter.generic_signature(
-                                    { 'input': input_tensor }),
-                                'outputs': exporter.generic_signature(
-                                    { 'outputs': decoded})})
-        if remove_export:
-            actual_export_dir = os.path.join(export_dir, '%08d' % export_version)
-            if os.path.isdir(actual_export_dir):
-                print 'Removing old export'
-                shutil.rmtree(actual_export_dir)
-        try:
-            model_exporter.export(export_dir, tf.constant(export_version), session)
-            print 'Model exported at %s' % (export_dir)
-        except RuntimeError:
-            print sys.exc_info()[1]
-
-
-# Logging Hyper Parameters and Results
-# ====================================
-
-# Now, as training and test are done, we persist the results alongside
-# with the involved hyper parameters for further reporting.
-data_sets = read_data_sets()
-
-with open('%s/%s' % (log_dir, 'hyper.json'), 'w') as dump_file:
-    json.dump({
-        'context': {
-            'time_started': time_started.isoformat(),
-            'time_finished': time_finished.isoformat(),
-            'git_hash': get_git_revision_hash(),
-            'git_branch': get_git_branch()
-        },
-        'parameters': {
-            'learning_rate': learning_rate,
-            'beta1': beta1,
-            'beta2': beta2,
-            'epsilon': epsilon,
-            'epochs': epochs,
-            'train_batch_size': train_batch_size,
-            'dev_batch_size': dev_batch_size,
-            'test_batch_size': test_batch_size,
-            'validation_step': validation_step,
-            'dropout_rate': dropout_rate,
-            'relu_clip': relu_clip,
-            'n_input': n_input,
-            'n_context': n_context,
-            'n_hidden_1': n_hidden_1,
-            'n_hidden_2': n_hidden_2,
-            'n_hidden_3': n_hidden_3,
-            'n_hidden_5': n_hidden_5,
-            'n_hidden_6': n_hidden_6,
-            'n_cell_dim': n_cell_dim,
-            'n_character': n_character,
-            'total_batches_train': data_sets.train.total_batches,
-            'total_batches_validation': data_sets.dev.total_batches,
-            'total_batches_test': data_sets.test.total_batches,
-            'data_set': {
-                'name': ds_importer
-            }
-        },
-        'results': {
-            'duration': duration,
-            'last_train_wer': last_train_wer,
-            'last_validation_wer': last_dev_wer,
-            'test_wer': test_wer
-        }
-    }, dump_file, sort_keys=True, indent=4)
-
-# Let's also re-populate a central JS file, that contains all the dumps at once.
-merge_logs(logs_dir)
-maybe_publish()
+if __name__ == "__main__":
+    # As everything is prepared, we are now able to do the training.
+
+    # Define CPU as device on which the muti-gpu training is orchestrated
+    with tf.device('/cpu:0'):
+        # Take start time for time measurement
+        time_started = datetime.datetime.utcnow()
+
+        # Train the network
+        last_train_wer, last_dev_wer, hibernation_path = train()
+
+        # Take final time for time measurement
+        time_finished = datetime.datetime.utcnow()
+
+        # Calculate duration in seconds
+        duration = time_finished - time_started
+        duration = duration.days * 86400 + duration.seconds
+
+        # Finally the model is tested against some unbiased data-set
+        print "Testing model"
+        result = run_set('test', model_path=hibernation_path, query_report=True)
+        test_wer = print_report("Test", result)
+
+
+    # Finally, we restore the trained variables into a simpler graph that we can export for serving.
+    # Don't export a model if no export directory has been set
+    if export_dir:
+        with tf.device('/cpu:0'):
+            tf.reset_default_graph()
+            session = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True))

+            # Run inference
+
+            # Input tensor will be of shape [batch_size, n_steps, n_input + 2*n_input*n_context]
+            input_tensor = tf.placeholder(tf.float32, [None, None, n_input + 2*n_input*n_context])
+
+            # Calculate input sequence length. This is done by tiling n_steps, batch_size times.
+            # If there are multiple sequences, it is assumed they are padded with zeros to be of
+            # the same length.
+            n_items = tf.slice(tf.shape(input_tensor), [0], [1])
+            n_steps = tf.slice(tf.shape(input_tensor), [1], [1])
+            seq_length = tf.tile(n_steps, n_items)
+
+            # Calculate the logits of the batch using BiRNN
+            logits = BiRNN(input_tensor, tf.to_int64(seq_length), 0)
+
+            # Beam search decode the batch
+            decoded, _ = ctc_ops.ctc_beam_search_decoder(logits, seq_length, merge_repeated=False)
+            decoded = tf.convert_to_tensor(
+                [tf.sparse_tensor_to_dense(sparse_tensor) for sparse_tensor in decoded])
+
+            # TODO: Transform the decoded output to a string
+
+            # Create a saver and exporter using variables from the above newly created graph
+            saver = tf.train.Saver(tf.all_variables())
+            model_exporter = exporter.Exporter(saver)
+
+            # Restore variables from training checkpoint
+            # TODO: This restores the most recent checkpoint, but if we use validation to counterract
+            # over-fitting, we may want to restore an earlier checkpoint.
+            checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
+            saver.restore(session, checkpoint.model_checkpoint_path)
+            print 'Restored checkpoint at training epoch %d' % (int(checkpoint.model_checkpoint_path.split('-')[-1]) + 1)
+
+            # Initialise the model exporter and export the model
+            model_exporter.init(session.graph.as_graph_def(),
+                                named_graph_signatures = {
+                                    'inputs': exporter.generic_signature(
+                                        { 'input': input_tensor }),
+                                    'outputs': exporter.generic_signature(
+                                        { 'outputs': decoded})})
+            if remove_export:
+                actual_export_dir = os.path.join(export_dir, '%08d' % export_version)
+                if os.path.isdir(actual_export_dir):
+                    print 'Removing old export'
+                    shutil.rmtree(actual_export_dir)
+            try:
+                model_exporter.export(export_dir, tf.constant(export_version), session)
+                print 'Model exported at %s' % (export_dir)
+            except RuntimeError:
+                print sys.exc_info()[1]
+
+
+    # Logging Hyper Parameters and Results
+    # ====================================
+
+    # Now, as training and test are done, we persist the results alongside
+    # with the involved hyper parameters for further reporting.
+    data_sets = read_data_sets()
+
+    with open('%s/%s' % (log_dir, 'hyper.json'), 'w') as dump_file:
+        json.dump({
+            'context': {
+                'time_started': time_started.isoformat(),
+                'time_finished': time_finished.isoformat(),
+                'git_hash': get_git_revision_hash(),
+                'git_branch': get_git_branch()
+            },
+            'parameters': {
+                'learning_rate': learning_rate,
+                'beta1': beta1,
+                'beta2': beta2,
+                'epsilon': epsilon,
+                'epochs': epochs,
+                'train_batch_size': train_batch_size,
+                'dev_batch_size': dev_batch_size,
+                'test_batch_size': test_batch_size,
+                'validation_step': validation_step,
+                'dropout_rate': dropout_rate,
+                'relu_clip': relu_clip,
+                'n_input': n_input,
+                'n_context': n_context,
+                'n_hidden_1': n_hidden_1,
+                'n_hidden_2': n_hidden_2,
+                'n_hidden_3': n_hidden_3,
+                'n_hidden_5': n_hidden_5,
+                'n_hidden_6': n_hidden_6,
+                'n_cell_dim': n_cell_dim,
+                'n_character': n_character,
+                'total_batches_train': data_sets.train.total_batches,
+                'total_batches_validation': data_sets.dev.total_batches,
+                'total_batches_test': data_sets.test.total_batches,
+                'data_set': {
+                    'name': ds_importer
+                }
+            },
+            'results': {
+                'duration': duration,
+                'last_train_wer': last_train_wer,
+                'last_validation_wer': last_dev_wer,
+                'test_wer': test_wer
+            }
+        }, dump_file, sort_keys=True, indent=4)
+
+    # Let's also re-populate a central JS file, that contains all the dumps at once.
+    merge_logs(logs_dir)
+    maybe_publish()
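With the guard in place, the module can be imported and driven programmatically; running the file directly behaves exactly as before. A minimal usage sketch (assumes DeepSpeech.py is importable from the current path and that its module-level configuration, e.g. the data set importer, is set up as in the repository):

    # Importing no longer triggers training, exporting, or logging.
    import DeepSpeech

    # Start training explicitly; per the diff above, train() returns the
    # train WER, the dev WER, and the hibernation (checkpoint) path.
    train_wer, dev_wer, hibernation_path = DeepSpeech.train()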