Make DeepSpeech.py importable without side-effects
parent dc13d4be06
commit 0a1d3d49ca

DeepSpeech.py  +250 −250
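
The change below moves everything that previously ran at module import time (training, model export, and hyper-parameter logging) under an if __name__ == "__main__": guard, so that import DeepSpeech no longer starts a training run. A minimal sketch of the pattern being applied, using a hypothetical do_training() stand-in rather than the actual DeepSpeech code:

    # Illustrative sketch only; do_training() is a hypothetical stand-in.
    def do_training():
        print('training...')

    if __name__ == '__main__':
        # Runs only when the file is executed directly (python DeepSpeech.py),
        # not when it is imported as a module (import DeepSpeech).
        do_training()
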
@@ -1026,139 +1026,139 @@ def train():
     return train_wer, dev_wer, hibernation_path
 
 
-# As everything is prepared, we are now able to do the training.
-# Define CPU as device on which the muti-gpu training is orchestrated
-with tf.device('/cpu:0'):
-    # Take start time for time measurement
-    time_started = datetime.datetime.utcnow()
-
-    # Train the network
-    last_train_wer, last_dev_wer, hibernation_path = train()
-
-    # Take final time for time measurement
-    time_finished = datetime.datetime.utcnow()
-
-    # Calculate duration in seconds
-    duration = time_finished - time_started
-    duration = duration.days * 86400 + duration.seconds
-
-    # Finally the model is tested against some unbiased data-set
-    print "Testing model"
-    result = run_set('test', model_path=hibernation_path, query_report=True)
-    test_wer = print_report("Test", result)
-
-
-# Finally, we restore the trained variables into a simpler graph that we can export for serving.
-# Don't export a model if no export directory has been set
-if export_dir:
-    with tf.device('/cpu:0'):
-        tf.reset_default_graph()
-        session = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True))
-
-        # Run inference
-
-        # Input tensor will be of shape [batch_size, n_steps, n_input + 2*n_input*n_context]
-        input_tensor = tf.placeholder(tf.float32, [None, None, n_input + 2*n_input*n_context])
-
-        # Calculate input sequence length. This is done by tiling n_steps, batch_size times.
-        # If there are multiple sequences, it is assumed they are padded with zeros to be of
-        # the same length.
-        n_items = tf.slice(tf.shape(input_tensor), [0], [1])
-        n_steps = tf.slice(tf.shape(input_tensor), [1], [1])
-        seq_length = tf.tile(n_steps, n_items)
-
-        # Calculate the logits of the batch using BiRNN
-        logits = BiRNN(input_tensor, tf.to_int64(seq_length), 0)
-
-        # Beam search decode the batch
-        decoded, _ = ctc_ops.ctc_beam_search_decoder(logits, seq_length, merge_repeated=False)
-        decoded = tf.convert_to_tensor(
-            [tf.sparse_tensor_to_dense(sparse_tensor) for sparse_tensor in decoded])
-
-        # TODO: Transform the decoded output to a string
-
-        # Create a saver and exporter using variables from the above newly created graph
-        saver = tf.train.Saver(tf.all_variables())
-        model_exporter = exporter.Exporter(saver)
-
-        # Restore variables from training checkpoint
-        # TODO: This restores the most recent checkpoint, but if we use validation to counterract
-        # over-fitting, we may want to restore an earlier checkpoint.
-        checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
-        saver.restore(session, checkpoint.model_checkpoint_path)
-        print 'Restored checkpoint at training epoch %d' % (int(checkpoint.model_checkpoint_path.split('-')[-1]) + 1)
-
-        # Initialise the model exporter and export the model
-        model_exporter.init(session.graph.as_graph_def(),
-                            named_graph_signatures = {
-                                'inputs': exporter.generic_signature(
-                                    { 'input': input_tensor }),
-                                'outputs': exporter.generic_signature(
-                                    { 'outputs': decoded})})
-        if remove_export:
-            actual_export_dir = os.path.join(export_dir, '%08d' % export_version)
-            if os.path.isdir(actual_export_dir):
-                print 'Removing old export'
-                shutil.rmtree(actual_export_dir)
-        try:
-            model_exporter.export(export_dir, tf.constant(export_version), session)
-            print 'Model exported at %s' % (export_dir)
-        except RuntimeError:
-            print sys.exc_info()[1]
-
-
-# Logging Hyper Parameters and Results
-# ====================================
-
-# Now, as training and test are done, we persist the results alongside
-# with the involved hyper parameters for further reporting.
-data_sets = read_data_sets()
-
-with open('%s/%s' % (log_dir, 'hyper.json'), 'w') as dump_file:
-    json.dump({
-        'context': {
-            'time_started': time_started.isoformat(),
-            'time_finished': time_finished.isoformat(),
-            'git_hash': get_git_revision_hash(),
-            'git_branch': get_git_branch()
-        },
-        'parameters': {
-            'learning_rate': learning_rate,
-            'beta1': beta1,
-            'beta2': beta2,
-            'epsilon': epsilon,
-            'epochs': epochs,
-            'train_batch_size': train_batch_size,
-            'dev_batch_size': dev_batch_size,
-            'test_batch_size': test_batch_size,
-            'validation_step': validation_step,
-            'dropout_rate': dropout_rate,
-            'relu_clip': relu_clip,
-            'n_input': n_input,
-            'n_context': n_context,
-            'n_hidden_1': n_hidden_1,
-            'n_hidden_2': n_hidden_2,
-            'n_hidden_3': n_hidden_3,
-            'n_hidden_5': n_hidden_5,
-            'n_hidden_6': n_hidden_6,
-            'n_cell_dim': n_cell_dim,
-            'n_character': n_character,
-            'total_batches_train': data_sets.train.total_batches,
-            'total_batches_validation': data_sets.dev.total_batches,
-            'total_batches_test': data_sets.test.total_batches,
-            'data_set': {
-                'name': ds_importer
-            }
-        },
-        'results': {
-            'duration': duration,
-            'last_train_wer': last_train_wer,
-            'last_validation_wer': last_dev_wer,
-            'test_wer': test_wer
-        }
-    }, dump_file, sort_keys=True, indent=4)
-
-# Let's also re-populate a central JS file, that contains all the dumps at once.
-merge_logs(logs_dir)
-maybe_publish()
+if __name__ == "__main__":
+    # As everything is prepared, we are now able to do the training.
+    # Define CPU as device on which the muti-gpu training is orchestrated
+    with tf.device('/cpu:0'):
+        # Take start time for time measurement
+        time_started = datetime.datetime.utcnow()
+
+        # Train the network
+        last_train_wer, last_dev_wer, hibernation_path = train()
+
+        # Take final time for time measurement
+        time_finished = datetime.datetime.utcnow()
+
+        # Calculate duration in seconds
+        duration = time_finished - time_started
+        duration = duration.days * 86400 + duration.seconds
+
+        # Finally the model is tested against some unbiased data-set
+        print "Testing model"
+        result = run_set('test', model_path=hibernation_path, query_report=True)
+        test_wer = print_report("Test", result)
+
+    # Finally, we restore the trained variables into a simpler graph that we can export for serving.
+    # Don't export a model if no export directory has been set
+    if export_dir:
+        with tf.device('/cpu:0'):
+            tf.reset_default_graph()
+            session = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True))
+
+            # Run inference
+
+            # Input tensor will be of shape [batch_size, n_steps, n_input + 2*n_input*n_context]
+            input_tensor = tf.placeholder(tf.float32, [None, None, n_input + 2*n_input*n_context])
+
+            # Calculate input sequence length. This is done by tiling n_steps, batch_size times.
+            # If there are multiple sequences, it is assumed they are padded with zeros to be of
+            # the same length.
+            n_items = tf.slice(tf.shape(input_tensor), [0], [1])
+            n_steps = tf.slice(tf.shape(input_tensor), [1], [1])
+            seq_length = tf.tile(n_steps, n_items)
+
+            # Calculate the logits of the batch using BiRNN
+            logits = BiRNN(input_tensor, tf.to_int64(seq_length), 0)
+
+            # Beam search decode the batch
+            decoded, _ = ctc_ops.ctc_beam_search_decoder(logits, seq_length, merge_repeated=False)
+            decoded = tf.convert_to_tensor(
+                [tf.sparse_tensor_to_dense(sparse_tensor) for sparse_tensor in decoded])
+
+            # TODO: Transform the decoded output to a string
+
+            # Create a saver and exporter using variables from the above newly created graph
+            saver = tf.train.Saver(tf.all_variables())
+            model_exporter = exporter.Exporter(saver)
+
+            # Restore variables from training checkpoint
+            # TODO: This restores the most recent checkpoint, but if we use validation to counterract
+            # over-fitting, we may want to restore an earlier checkpoint.
+            checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
+            saver.restore(session, checkpoint.model_checkpoint_path)
+            print 'Restored checkpoint at training epoch %d' % (int(checkpoint.model_checkpoint_path.split('-')[-1]) + 1)
+
+            # Initialise the model exporter and export the model
+            model_exporter.init(session.graph.as_graph_def(),
+                                named_graph_signatures = {
+                                    'inputs': exporter.generic_signature(
+                                        { 'input': input_tensor }),
+                                    'outputs': exporter.generic_signature(
+                                        { 'outputs': decoded})})
+            if remove_export:
+                actual_export_dir = os.path.join(export_dir, '%08d' % export_version)
+                if os.path.isdir(actual_export_dir):
+                    print 'Removing old export'
+                    shutil.rmtree(actual_export_dir)
+            try:
+                model_exporter.export(export_dir, tf.constant(export_version), session)
+                print 'Model exported at %s' % (export_dir)
+            except RuntimeError:
+                print sys.exc_info()[1]
+
+    # Logging Hyper Parameters and Results
+    # ====================================
+
+    # Now, as training and test are done, we persist the results alongside
+    # with the involved hyper parameters for further reporting.
+    data_sets = read_data_sets()
+
+    with open('%s/%s' % (log_dir, 'hyper.json'), 'w') as dump_file:
+        json.dump({
+            'context': {
+                'time_started': time_started.isoformat(),
+                'time_finished': time_finished.isoformat(),
+                'git_hash': get_git_revision_hash(),
+                'git_branch': get_git_branch()
+            },
+            'parameters': {
+                'learning_rate': learning_rate,
+                'beta1': beta1,
+                'beta2': beta2,
+                'epsilon': epsilon,
+                'epochs': epochs,
+                'train_batch_size': train_batch_size,
+                'dev_batch_size': dev_batch_size,
+                'test_batch_size': test_batch_size,
+                'validation_step': validation_step,
+                'dropout_rate': dropout_rate,
+                'relu_clip': relu_clip,
+                'n_input': n_input,
+                'n_context': n_context,
+                'n_hidden_1': n_hidden_1,
+                'n_hidden_2': n_hidden_2,
+                'n_hidden_3': n_hidden_3,
+                'n_hidden_5': n_hidden_5,
+                'n_hidden_6': n_hidden_6,
+                'n_cell_dim': n_cell_dim,
+                'n_character': n_character,
+                'total_batches_train': data_sets.train.total_batches,
+                'total_batches_validation': data_sets.dev.total_batches,
+                'total_batches_test': data_sets.test.total_batches,
+                'data_set': {
+                    'name': ds_importer
+                }
+            },
+            'results': {
+                'duration': duration,
+                'last_train_wer': last_train_wer,
+                'last_validation_wer': last_dev_wer,
+                'test_wer': test_wer
+            }
+        }, dump_file, sort_keys=True, indent=4)
+
+    # Let's also re-populate a central JS file, that contains all the dumps at once.
+    merge_logs(logs_dir)
+    maybe_publish()
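
With the guard in place, the module can be imported for its definitions without kicking off a run. A hypothetical usage sketch (assuming DeepSpeech.py is on the import path):

    import DeepSpeech   # no training side-effects on import
    # The training entry point defined in the module can still be
    # invoked explicitly when desired:
    # DeepSpeech.train()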