Make DeepSpeech.py importable without side-effects

Reuben Morais 2017-02-02 21:26:27 -02:00
parent dc13d4be06
commit 0a1d3d49ca
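
Note: this change wraps the script's top-level training, export, and logging code in an `if __name__ == "__main__":` guard, the standard Python idiom for keeping a module importable without side effects: the guarded block runs only when the file is executed directly, never on import. A minimal sketch of the idiom (module and function names here are illustrative, not from DeepSpeech.py):

    # demo_module.py -- illustrative file, not part of DeepSpeech
    def train():
        """Importable definition; merely importing this module runs nothing."""
        return 'train_wer', 'dev_wer'

    if __name__ == "__main__":
        # Executed only for `python demo_module.py`, never for `import demo_module`.
        print(train())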


@@ -1026,10 +1026,10 @@ def train():
     return train_wer, dev_wer, hibernation_path
-# As everything is prepared, we are now able to do the training.
-# Define CPU as device on which the muti-gpu training is orchestrated
-with tf.device('/cpu:0'):
-    # Take start time for time measurement
-    time_started = datetime.datetime.utcnow()
+if __name__ == "__main__":
+    # As everything is prepared, we are now able to do the training.
+    # Define CPU as device on which the muti-gpu training is orchestrated
+    with tf.device('/cpu:0'):
+        # Take start time for time measurement
+        time_started = datetime.datetime.utcnow()
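
For readers unfamiliar with the pattern being indented here: in TensorFlow 1.x, pinning the orchestration code to `tf.device('/cpu:0')` keeps shared variables and coordination ops on the CPU while per-GPU towers do the heavy compute. A rough, self-contained sketch of that placement (the two-tower loop is illustrative, not this file's actual tower code):

    import tensorflow as tf  # TensorFlow 1.x, as used by DeepSpeech.py at this point

    with tf.device('/cpu:0'):
        # Shared orchestration state stays on the CPU.
        global_step = tf.Variable(0, trainable=False, name='global_step')
        tower_outputs = []
        for i in range(2):  # two illustrative GPU towers
            with tf.device('/gpu:%d' % i):
                tower_outputs.append(tf.square(tf.constant(float(i + 1))))
        total = tf.add_n(tower_outputs)

    # allow_soft_placement lets this fall back to the CPU on a GPU-less machine.
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as session:
        session.run(tf.global_variables_initializer())
        print(session.run(total))  # 1.0 + 4.0 = 5.0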
@@ -1049,9 +1049,9 @@ with tf.device('/cpu:0'):
-    test_wer = print_report("Test", result)
+        test_wer = print_report("Test", result)
-# Finally, we restore the trained variables into a simpler graph that we can export for serving.
-# Don't export a model if no export directory has been set
-if export_dir:
-    with tf.device('/cpu:0'):
-        tf.reset_default_graph()
-        session = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True))
+    # Finally, we restore the trained variables into a simpler graph that we can export for serving.
+    # Don't export a model if no export directory has been set
+    if export_dir:
+        with tf.device('/cpu:0'):
+            tf.reset_default_graph()
+            session = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True))
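
The export branch above calls `tf.reset_default_graph()` so the trained variables can be restored into a fresh, simpler graph for serving. A self-contained sketch of what resetting the default graph does (shapes and names are illustrative):

    import tensorflow as tf

    # Build a throwaway "training" op; the default graph is now non-empty.
    tf.constant(1.0, name='train_only_op')
    print(len(tf.get_default_graph().get_operations()))  # >= 1

    tf.reset_default_graph()  # discard the training graph entirely
    print(len(tf.get_default_graph().get_operations()))  # 0

    # Rebuild a minimal serving graph and run it with the same ConfigProto
    # options used in the diff.
    x = tf.placeholder(tf.float32, shape=[None, 2], name='input')
    y = tf.identity(x, name='output')
    config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)
    with tf.Session(config=config) as session:
        print(session.run(y, feed_dict={x: [[1.0, 2.0]]}))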
@@ -1108,14 +1108,14 @@ if export_dir:
-            print sys.exc_info()[1]
+                print sys.exc_info()[1]
-# Logging Hyper Parameters and Results
-# ====================================
-# Now, as training and test are done, we persist the results alongside
-# with the involved hyper parameters for further reporting.
-data_sets = read_data_sets()
-with open('%s/%s' % (log_dir, 'hyper.json'), 'w') as dump_file:
-    json.dump({
-        'context': {
-            'time_started': time_started.isoformat(),
+    # Logging Hyper Parameters and Results
+    # ====================================
+    # Now, as training and test are done, we persist the results alongside
+    # with the involved hyper parameters for further reporting.
+    data_sets = read_data_sets()
+    with open('%s/%s' % (log_dir, 'hyper.json'), 'w') as dump_file:
+        json.dump({
+            'context': {
+                'time_started': time_started.isoformat(),
@@ -1159,6 +1159,6 @@ with open('%s/%s' % (log_dir, 'hyper.json'), 'w') as dump_file:
-        }
-    }, dump_file, sort_keys=True, indent=4)
-# Let's also re-populate a central JS file, that contains all the dumps at once.
-merge_logs(logs_dir)
-maybe_publish()
+            }
+        }, dump_file, sort_keys=True, indent=4)
+    # Let's also re-populate a central JS file, that contains all the dumps at once.
+    merge_logs(logs_dir)
+    maybe_publish()
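
For reference, the hyper.json write in the last two hunks is a plain `json.dump` of nested dicts. A runnable sketch of the same pattern (the field values and the 'logs' directory are placeholders, not what DeepSpeech.py records):

    import datetime
    import json
    import os

    log_dir = 'logs'  # placeholder; DeepSpeech.py computes its own log_dir
    if not os.path.isdir(log_dir):
        os.makedirs(log_dir)
    time_started = datetime.datetime.utcnow()

    with open('%s/%s' % (log_dir, 'hyper.json'), 'w') as dump_file:
        json.dump({
            'context': {
                'time_started': time_started.isoformat(),
            },
            'results': {
                'test_wer': 0.21,  # placeholder value
            },
        }, dump_file, sort_keys=True, indent=4)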