Fix remote I/O handling in train
This commit is contained in:
parent
8f31072998
commit
a6322b384e
@ -35,7 +35,7 @@ from .util.feeding import create_dataset, audio_to_features, audiofile_to_featur
|
|||||||
from .util.flags import create_flags, FLAGS
|
from .util.flags import create_flags, FLAGS
|
||||||
from .util.helpers import check_ctcdecoder_version, ExceptionBox
|
from .util.helpers import check_ctcdecoder_version, ExceptionBox
|
||||||
from .util.logging import create_progressbar, log_debug, log_error, log_info, log_progress, log_warn
|
from .util.logging import create_progressbar, log_debug, log_error, log_info, log_progress, log_warn
|
||||||
from .util.io import open_remote, remove_remote, listdir_remote
|
from .util.io import open_remote, remove_remote, listdir_remote, is_remote_path, isdir_remote
|
||||||
|
|
||||||
check_ctcdecoder_version()
|
check_ctcdecoder_version()
|
||||||
|
|
||||||
@ -513,7 +513,8 @@ def train():
|
|||||||
best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')
|
best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')
|
||||||
|
|
||||||
# Save flags next to checkpoints
|
# Save flags next to checkpoints
|
||||||
os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
|
if not is_remote_path(FLAGS.save_checkpoint_dir):
|
||||||
|
os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
|
||||||
flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
|
flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
|
||||||
with open_remote(flags_file, 'w') as fout:
|
with open_remote(flags_file, 'w') as fout:
|
||||||
fout.write(FLAGS.flags_into_string())
|
fout.write(FLAGS.flags_into_string())
|
||||||
@ -813,11 +814,11 @@ def export():
|
|||||||
if FLAGS.remove_remote_export:
|
if FLAGS.remove_remote_export:
|
||||||
if isdir_remote(FLAGS.export_dir):
|
if isdir_remote(FLAGS.export_dir):
|
||||||
log_info('Removing old export')
|
log_info('Removing old export')
|
||||||
shutil.rmtree(FLAGS.export_dir)
|
remove_remote(FLAGS.export_dir)
|
||||||
|
|
||||||
output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
|
output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
|
||||||
|
|
||||||
if not isdir_remote(FLAGS.export_dir):
|
if not is_remote_path(FLAGS.export_dir) and not os.path.isdir(FLAGS.export_dir):
|
||||||
os.makedirs(FLAGS.export_dir)
|
os.makedirs(FLAGS.export_dir)
|
||||||
|
|
||||||
frozen_graph = tfv1.graph_util.convert_variables_to_constants(
|
frozen_graph = tfv1.graph_util.convert_variables_to_constants(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user